kopia lustrzana https://github.com/transitive-bullshit/chatgpt-api
feat: tools and BaseTask.agentic refactor
rodzic
94cb9041b1
commit
f95dd54436
|
@ -2,16 +2,21 @@ import { OpenAIClient } from '@agentic/openai-fetch'
|
|||
import 'dotenv/config'
|
||||
import { z } from 'zod'
|
||||
|
||||
import { Agentic, CalculatorTool } from '@/index'
|
||||
import { Agentic, CalculatorTool, WeatherTool } from '@/index'
|
||||
|
||||
async function main() {
|
||||
const openai = new OpenAIClient({ apiKey: process.env.OPENAI_API_KEY! })
|
||||
const agentic = new Agentic({ openai })
|
||||
|
||||
const example = await agentic
|
||||
.gpt3('What is 5 * 50?')
|
||||
.tools([new CalculatorTool({ agentic })])
|
||||
.output(z.object({ answer: z.number() }))
|
||||
.gpt3('What is the temperature in san francisco today?')
|
||||
.tools([new CalculatorTool(), new WeatherTool()])
|
||||
.output(
|
||||
z.object({
|
||||
answer: z.number(),
|
||||
units: z.union([z.literal('fahrenheit'), z.literal('celcius')])
|
||||
})
|
||||
)
|
||||
.call()
|
||||
console.log(example)
|
||||
}
|
||||
|
|
|
@ -101,7 +101,7 @@
|
|||
]
|
||||
},
|
||||
"ava": {
|
||||
"snapshotDir": ".snapshots",
|
||||
"snapshotDir": "test/.snapshots",
|
||||
"extensions": {
|
||||
"ts": "module"
|
||||
},
|
||||
|
|
|
@ -1,11 +1,10 @@
|
|||
import * as types from '@/types'
|
||||
import { DEFAULT_OPENAI_MODEL } from '@/constants'
|
||||
import { OpenAIChatCompletion } from '@/llms/openai'
|
||||
|
||||
import * as types from './types'
|
||||
import { DEFAULT_OPENAI_MODEL } from './constants'
|
||||
import {
|
||||
HumanFeedbackMechanism,
|
||||
HumanFeedbackMechanismCLI
|
||||
} from './human-feedback'
|
||||
import { OpenAIChatCompletion } from './llms/openai'
|
||||
import { defaultIDGeneratorFn } from './utils'
|
||||
|
||||
export class Agentic {
|
||||
|
@ -20,7 +19,9 @@ export class Agentic {
|
|||
'provider' | 'model' | 'modelParams' | 'timeoutMs' | 'retryConfig'
|
||||
>
|
||||
protected _defaultHumanFeedbackMechamism?: HumanFeedbackMechanism
|
||||
|
||||
protected _idGeneratorFn: types.IDGeneratorFunction
|
||||
protected _id: string
|
||||
|
||||
constructor(opts: {
|
||||
openai?: types.openai.OpenAIClient
|
||||
|
@ -33,6 +34,10 @@ export class Agentic {
|
|||
defaultHumanFeedbackMechanism?: HumanFeedbackMechanism
|
||||
idGeneratorFn?: types.IDGeneratorFunction
|
||||
}) {
|
||||
if (!globalThis.__agentic?.deref()) {
|
||||
globalThis.__agentic = new WeakRef(this)
|
||||
}
|
||||
|
||||
this._openai = opts.openai
|
||||
this._anthropic = opts.anthropic
|
||||
|
||||
|
@ -59,6 +64,7 @@ export class Agentic {
|
|||
new HumanFeedbackMechanismCLI({ agentic: this })
|
||||
|
||||
this._idGeneratorFn = opts.idGeneratorFn ?? defaultIDGeneratorFn
|
||||
this._id = this._idGeneratorFn()
|
||||
}
|
||||
|
||||
public get openai(): types.openai.OpenAIClient | undefined {
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
import * as types from '@/types'
|
||||
import { Agentic } from '@/agentic'
|
||||
import { BaseTask } from '@/task'
|
||||
import * as types from './types'
|
||||
import { Agentic } from './agentic'
|
||||
import { BaseTask } from './task'
|
||||
|
||||
export type HumanFeedbackType = 'confirm' | 'selectOne' | 'selectN'
|
||||
|
||||
|
|
|
@ -38,7 +38,7 @@ export class AnthropicChatCompletion<
|
|||
...options
|
||||
})
|
||||
|
||||
if (this._agentic.anthropic) {
|
||||
if (this._agentic?.anthropic) {
|
||||
this._client = this._agentic.anthropic
|
||||
} else {
|
||||
throw new Error(
|
||||
|
|
|
@ -72,13 +72,38 @@ export abstract class BaseChatCompletion<
|
|||
}
|
||||
|
||||
this._tools = tools
|
||||
for (const tool of tools) {
|
||||
tool.agentic = this.agentic
|
||||
}
|
||||
|
||||
return this
|
||||
}
|
||||
|
||||
/**
|
||||
* Whether or not this chat completion model directly supports the use of tools.
|
||||
*/
|
||||
public get supportsTools(): boolean {
|
||||
return false
|
||||
}
|
||||
|
||||
public override validate() {
|
||||
super.validate()
|
||||
|
||||
if (this._tools) {
|
||||
for (const tool of this._tools) {
|
||||
if (!tool.agentic) {
|
||||
tool.agentic = this.agentic
|
||||
} else if (tool.agentic !== this.agentic) {
|
||||
throw new Error(
|
||||
`Task "${this.nameForHuman}" has a different Agentic runtime instance than the tool "${tool.nameForHuman}"`
|
||||
)
|
||||
}
|
||||
|
||||
tool.validate()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
protected abstract _createChatCompletion(
|
||||
messages: types.ChatMessage[],
|
||||
functions?: types.openai.ChatMessageFunction[]
|
||||
|
@ -133,9 +158,22 @@ export abstract class BaseChatCompletion<
|
|||
.replace(/^ {4}/gm, ' ')
|
||||
.replace(/;$/gm, '')
|
||||
|
||||
const label =
|
||||
this._outputSchema instanceof z.ZodArray
|
||||
? 'JSON array'
|
||||
: this._outputSchema instanceof z.ZodObject
|
||||
? 'JSON object'
|
||||
: this._outputSchema instanceof z.ZodNumber
|
||||
? 'number'
|
||||
: this._outputSchema instanceof z.ZodString
|
||||
? 'string'
|
||||
: this._outputSchema instanceof z.ZodBoolean
|
||||
? 'boolean'
|
||||
: 'JSON value'
|
||||
|
||||
messages.push({
|
||||
role: 'system',
|
||||
content: dedent`Do not output code. Output JSON only in the following TypeScript format:
|
||||
content: dedent`Do not output code. Output a single ${label} in the following TypeScript format:
|
||||
\`\`\`ts
|
||||
${tsTypeString}
|
||||
\`\`\``
|
||||
|
@ -285,6 +323,11 @@ export abstract class BaseChatCompletion<
|
|||
try {
|
||||
const trimmedOutput = extractJSONObjectFromString(output)
|
||||
output = JSON.parse(jsonrepair(trimmedOutput ?? output))
|
||||
|
||||
if (Array.isArray(output)) {
|
||||
// TODO
|
||||
output = output[0]
|
||||
}
|
||||
} catch (err: any) {
|
||||
if (err instanceof JSONRepairError) {
|
||||
throw new errors.OutputValidationError(err.message, { cause: err })
|
||||
|
|
|
@ -38,7 +38,7 @@ export class OpenAIChatCompletion<
|
|||
...options
|
||||
})
|
||||
|
||||
if (this._agentic.openai) {
|
||||
if (this._agentic?.openai) {
|
||||
this._client = this._agentic.openai
|
||||
} else {
|
||||
throw new Error(
|
||||
|
@ -93,6 +93,30 @@ export class OpenAIChatCompletion<
|
|||
return openaiModelsSupportingFunctions.has(this._model)
|
||||
}
|
||||
|
||||
public override validate() {
|
||||
super.validate()
|
||||
|
||||
if (!this._client) {
|
||||
throw new Error(
|
||||
'OpenAIChatCompletion requires an OpenAI client to be configured on the Agentic runtime'
|
||||
)
|
||||
}
|
||||
|
||||
if (!this.supportsTools) {
|
||||
if (this._tools) {
|
||||
throw new Error(
|
||||
`This OpenAI chat model "${this.nameForHuman}" does not support tools`
|
||||
)
|
||||
}
|
||||
|
||||
if (this._modelParams?.functions) {
|
||||
throw new Error(
|
||||
`This OpenAI chat model "${this.nameForHuman}" does not support functions`
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
protected override async _createChatCompletion(
|
||||
messages: types.ChatMessage[],
|
||||
functions?: types.openai.ChatMessageFunction[]
|
||||
|
|
44
src/task.ts
44
src/task.ts
|
@ -1,9 +1,10 @@
|
|||
import pRetry, { FailedAttemptError } from 'p-retry'
|
||||
import { ZodType } from 'zod'
|
||||
|
||||
import * as errors from '@/errors'
|
||||
import * as types from '@/types'
|
||||
import { Agentic } from '@/agentic'
|
||||
import * as errors from './errors'
|
||||
import * as types from './types'
|
||||
import type { Agentic } from './agentic'
|
||||
import { defaultIDGeneratorFn, isValidTaskIdentifier } from './utils'
|
||||
|
||||
/**
|
||||
* A `Task` is an async function call that may be non-deterministic. It has
|
||||
|
@ -28,24 +29,27 @@ export abstract class BaseTask<
|
|||
protected _timeoutMs?: number
|
||||
protected _retryConfig: types.RetryConfig
|
||||
|
||||
constructor(options: types.BaseTaskOptions) {
|
||||
if (!options.agentic) {
|
||||
throw new Error('Passing "agentic" is required when creating a Task')
|
||||
}
|
||||
constructor(options: types.BaseTaskOptions = {}) {
|
||||
this._agentic = options.agentic ?? globalThis.__agentic?.deref()
|
||||
|
||||
this._agentic = options.agentic
|
||||
this._timeoutMs = options.timeoutMs
|
||||
this._retryConfig = options.retryConfig ?? {
|
||||
retries: 3,
|
||||
strategy: 'default'
|
||||
}
|
||||
this._id = options.id ?? this._agentic.idGeneratorFn()
|
||||
|
||||
this._id =
|
||||
options.id ?? this._agentic?.idGeneratorFn() ?? defaultIDGeneratorFn()
|
||||
}
|
||||
|
||||
public get agentic(): Agentic {
|
||||
return this._agentic
|
||||
}
|
||||
|
||||
public set agentic(agentic: Agentic) {
|
||||
this._agentic = agentic
|
||||
}
|
||||
|
||||
public get id(): string {
|
||||
return this._id
|
||||
}
|
||||
|
@ -53,7 +57,10 @@ export abstract class BaseTask<
|
|||
public abstract get inputSchema(): ZodType<TInput>
|
||||
public abstract get outputSchema(): ZodType<TOutput>
|
||||
|
||||
public abstract get nameForModel(): string
|
||||
public get nameForModel(): string {
|
||||
const name = this.constructor.name
|
||||
return name[0].toLowerCase() + name.slice(1)
|
||||
}
|
||||
|
||||
public get nameForHuman(): string {
|
||||
return this.constructor.name
|
||||
|
@ -63,6 +70,19 @@ export abstract class BaseTask<
|
|||
return ''
|
||||
}
|
||||
|
||||
public validate() {
|
||||
if (!this._agentic) {
|
||||
throw new Error(
|
||||
`Task "${this.nameForHuman}" is missing a required "agentic" instance`
|
||||
)
|
||||
}
|
||||
|
||||
const nameForModel = this.nameForModel
|
||||
if (!isValidTaskIdentifier(nameForModel)) {
|
||||
throw new Error(`Task field nameForModel "${nameForModel}" is invalid`)
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: is this really necessary?
|
||||
public clone(): BaseTask<TInput, TOutput> {
|
||||
// TODO: override in subclass if needed
|
||||
|
@ -88,6 +108,8 @@ export abstract class BaseTask<
|
|||
public async callWithMetadata(
|
||||
input?: TInput
|
||||
): Promise<types.TaskResponse<TOutput>> {
|
||||
this.validate()
|
||||
|
||||
if (this.inputSchema) {
|
||||
const safeInput = this.inputSchema.safeParse(input)
|
||||
|
||||
|
@ -104,7 +126,7 @@ export abstract class BaseTask<
|
|||
metadata: {
|
||||
taskName: this.nameForModel,
|
||||
taskId: this.id,
|
||||
callId: this._agentic.idGeneratorFn()
|
||||
callId: this._agentic!.idGeneratorFn()
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -19,7 +19,7 @@ export class CalculatorTool extends BaseTask<
|
|||
CalculatorInput,
|
||||
CalculatorOutput
|
||||
> {
|
||||
constructor(opts: types.BaseTaskOptions) {
|
||||
constructor(opts: types.BaseTaskOptions = {}) {
|
||||
super(opts)
|
||||
}
|
||||
|
||||
|
|
|
@ -1,26 +1,20 @@
|
|||
import * as metaphor from '@/services/metaphor'
|
||||
import * as types from '@/types'
|
||||
import { Agentic } from '@/agentic'
|
||||
import { BaseTask } from '@/task'
|
||||
|
||||
export class MetaphorSearchTool extends BaseTask<
|
||||
metaphor.MetaphorSearchInput,
|
||||
metaphor.MetaphorSearchOutput
|
||||
> {
|
||||
_metaphorClient: metaphor.MetaphorClient
|
||||
protected _metaphorClient: metaphor.MetaphorClient
|
||||
|
||||
constructor({
|
||||
agentic,
|
||||
metaphorClient = new metaphor.MetaphorClient(),
|
||||
...rest
|
||||
...opts
|
||||
}: {
|
||||
agentic: Agentic
|
||||
metaphorClient?: metaphor.MetaphorClient
|
||||
} & types.BaseTaskOptions) {
|
||||
super({
|
||||
agentic,
|
||||
...rest
|
||||
})
|
||||
} & types.BaseTaskOptions = {}) {
|
||||
super(opts)
|
||||
|
||||
this._metaphorClient = metaphorClient
|
||||
}
|
||||
|
@ -34,7 +28,7 @@ export class MetaphorSearchTool extends BaseTask<
|
|||
}
|
||||
|
||||
public override get nameForModel(): string {
|
||||
return 'metaphor_web_search'
|
||||
return 'metaphorWebSearch'
|
||||
}
|
||||
|
||||
protected override async _call(
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
import { z } from 'zod'
|
||||
|
||||
import * as types from '@/types'
|
||||
import { Agentic } from '@/agentic'
|
||||
import { NovuClient } from '@/services/novu'
|
||||
import { BaseTask } from '@/task'
|
||||
|
||||
|
@ -39,18 +38,15 @@ export class NovuNotificationTool extends BaseTask<
|
|||
NovuNotificationToolInput,
|
||||
NovuNotificationToolOutput
|
||||
> {
|
||||
_novuClient: NovuClient
|
||||
protected _novuClient: NovuClient
|
||||
|
||||
constructor({
|
||||
agentic,
|
||||
novuClient = new NovuClient()
|
||||
novuClient = new NovuClient(),
|
||||
...opts
|
||||
}: {
|
||||
agentic: Agentic
|
||||
novuClient?: NovuClient
|
||||
}) {
|
||||
super({
|
||||
agentic
|
||||
})
|
||||
} & types.BaseTaskOptions = {}) {
|
||||
super(opts)
|
||||
|
||||
this._novuClient = novuClient
|
||||
}
|
||||
|
@ -64,7 +60,7 @@ export class NovuNotificationTool extends BaseTask<
|
|||
}
|
||||
|
||||
public override get nameForModel(): string {
|
||||
return 'novu_send_notification'
|
||||
return 'novuSendNotification'
|
||||
}
|
||||
|
||||
protected override async _call(
|
||||
|
|
|
@ -9,7 +9,12 @@ export const WeatherInputSchema = z.object({
|
|||
.string()
|
||||
.describe(
|
||||
'Location to get the weather for. Can be a city name like "Paris", a zipcode like "53121", an international postal code like "SW1", or a latitude and longitude like "48.8567,2.3508"'
|
||||
)
|
||||
),
|
||||
|
||||
units: z
|
||||
.union([z.literal('imperial'), z.literal('metric')])
|
||||
.default('imperial')
|
||||
.optional()
|
||||
})
|
||||
export type WeatherInput = z.infer<typeof WeatherInputSchema>
|
||||
|
||||
|
@ -33,8 +38,8 @@ const ConditionSchema = z.object({
|
|||
const CurrentSchema = z.object({
|
||||
last_updated_epoch: z.number(),
|
||||
last_updated: z.string(),
|
||||
temp_c: z.number(),
|
||||
temp_f: z.number(),
|
||||
temp_c: z.number().describe('temperature in celsius'),
|
||||
temp_f: z.number().describe('temperature in fahrenheit'),
|
||||
is_day: z.number(),
|
||||
condition: ConditionSchema,
|
||||
wind_mph: z.number(),
|
||||
|
@ -65,14 +70,19 @@ export type WeatherOutput = z.infer<typeof WeatherOutputSchema>
|
|||
export class WeatherTool extends BaseTask<WeatherInput, WeatherOutput> {
|
||||
client: WeatherClient
|
||||
|
||||
constructor(
|
||||
opts: {
|
||||
weather: WeatherClient
|
||||
} & types.BaseTaskOptions
|
||||
) {
|
||||
constructor({
|
||||
weather = new WeatherClient({ apiKey: process.env.WEATHER_API_KEY }),
|
||||
...opts
|
||||
}: {
|
||||
weather?: WeatherClient
|
||||
} & types.BaseTaskOptions = {}) {
|
||||
super(opts)
|
||||
|
||||
this.client = opts.weather
|
||||
if (!weather) {
|
||||
throw new Error(`Error WeatherTool missing required "weather" client`)
|
||||
}
|
||||
|
||||
this.client = weather
|
||||
}
|
||||
|
||||
public override get inputSchema() {
|
||||
|
@ -92,7 +102,7 @@ export class WeatherTool extends BaseTask<WeatherInput, WeatherOutput> {
|
|||
}
|
||||
|
||||
public override get descForModel(): string {
|
||||
return 'Useful for getting the current weather at a location'
|
||||
return 'Useful for getting the current weather at a location.'
|
||||
}
|
||||
|
||||
protected override async _call(
|
||||
|
|
|
@ -20,7 +20,7 @@ export type SafeParsedData<T extends ZodTypeAny> = T extends ZodTypeAny
|
|||
: never
|
||||
|
||||
export interface BaseTaskOptions {
|
||||
agentic: Agentic
|
||||
agentic?: Agentic
|
||||
|
||||
timeoutMs?: number
|
||||
retryConfig?: RetryConfig
|
||||
|
|
Ładowanie…
Reference in New Issue