feat: tools and BaseTask.agentic refactor

Travis Fischer 2023-06-14 17:40:48 -07:00
rodzic 94cb9041b1
commit f95dd54436
15 zmienionych plików z 159 dodań i 59 usunięć

13
examples/functions.ts vendored
Wyświetl plik

@ -2,16 +2,21 @@ import { OpenAIClient } from '@agentic/openai-fetch'
import 'dotenv/config' import 'dotenv/config'
import { z } from 'zod' import { z } from 'zod'
import { Agentic, CalculatorTool } from '@/index' import { Agentic, CalculatorTool, WeatherTool } from '@/index'
async function main() { async function main() {
const openai = new OpenAIClient({ apiKey: process.env.OPENAI_API_KEY! }) const openai = new OpenAIClient({ apiKey: process.env.OPENAI_API_KEY! })
const agentic = new Agentic({ openai }) const agentic = new Agentic({ openai })
const example = await agentic const example = await agentic
.gpt3('What is 5 * 50?') .gpt3('What is the temperature in san francisco today?')
.tools([new CalculatorTool({ agentic })]) .tools([new CalculatorTool(), new WeatherTool()])
.output(z.object({ answer: z.number() })) .output(
z.object({
answer: z.number(),
units: z.union([z.literal('fahrenheit'), z.literal('celcius')])
})
)
.call() .call()
console.log(example) console.log(example)
} }

Wyświetl plik

@ -101,7 +101,7 @@
] ]
}, },
"ava": { "ava": {
"snapshotDir": ".snapshots", "snapshotDir": "test/.snapshots",
"extensions": { "extensions": {
"ts": "module" "ts": "module"
}, },

Wyświetl plik

@ -1,11 +1,10 @@
import * as types from '@/types' import * as types from './types'
import { DEFAULT_OPENAI_MODEL } from '@/constants' import { DEFAULT_OPENAI_MODEL } from './constants'
import { OpenAIChatCompletion } from '@/llms/openai'
import { import {
HumanFeedbackMechanism, HumanFeedbackMechanism,
HumanFeedbackMechanismCLI HumanFeedbackMechanismCLI
} from './human-feedback' } from './human-feedback'
import { OpenAIChatCompletion } from './llms/openai'
import { defaultIDGeneratorFn } from './utils' import { defaultIDGeneratorFn } from './utils'
export class Agentic { export class Agentic {
@ -20,7 +19,9 @@ export class Agentic {
'provider' | 'model' | 'modelParams' | 'timeoutMs' | 'retryConfig' 'provider' | 'model' | 'modelParams' | 'timeoutMs' | 'retryConfig'
> >
protected _defaultHumanFeedbackMechamism?: HumanFeedbackMechanism protected _defaultHumanFeedbackMechamism?: HumanFeedbackMechanism
protected _idGeneratorFn: types.IDGeneratorFunction protected _idGeneratorFn: types.IDGeneratorFunction
protected _id: string
constructor(opts: { constructor(opts: {
openai?: types.openai.OpenAIClient openai?: types.openai.OpenAIClient
@ -33,6 +34,10 @@ export class Agentic {
defaultHumanFeedbackMechanism?: HumanFeedbackMechanism defaultHumanFeedbackMechanism?: HumanFeedbackMechanism
idGeneratorFn?: types.IDGeneratorFunction idGeneratorFn?: types.IDGeneratorFunction
}) { }) {
if (!globalThis.__agentic?.deref()) {
globalThis.__agentic = new WeakRef(this)
}
this._openai = opts.openai this._openai = opts.openai
this._anthropic = opts.anthropic this._anthropic = opts.anthropic
@ -59,6 +64,7 @@ export class Agentic {
new HumanFeedbackMechanismCLI({ agentic: this }) new HumanFeedbackMechanismCLI({ agentic: this })
this._idGeneratorFn = opts.idGeneratorFn ?? defaultIDGeneratorFn this._idGeneratorFn = opts.idGeneratorFn ?? defaultIDGeneratorFn
this._id = this._idGeneratorFn()
} }
public get openai(): types.openai.OpenAIClient | undefined { public get openai(): types.openai.OpenAIClient | undefined {

Wyświetl plik

@ -1,6 +1,6 @@
import * as types from '@/types' import * as types from './types'
import { Agentic } from '@/agentic' import { Agentic } from './agentic'
import { BaseTask } from '@/task' import { BaseTask } from './task'
export type HumanFeedbackType = 'confirm' | 'selectOne' | 'selectN' export type HumanFeedbackType = 'confirm' | 'selectOne' | 'selectN'

Wyświetl plik

@ -38,7 +38,7 @@ export class AnthropicChatCompletion<
...options ...options
}) })
if (this._agentic.anthropic) { if (this._agentic?.anthropic) {
this._client = this._agentic.anthropic this._client = this._agentic.anthropic
} else { } else {
throw new Error( throw new Error(

Wyświetl plik

@ -72,13 +72,38 @@ export abstract class BaseChatCompletion<
} }
this._tools = tools this._tools = tools
for (const tool of tools) {
tool.agentic = this.agentic
}
return this return this
} }
/**
* Whether or not this chat completion model directly supports the use of tools.
*/
public get supportsTools(): boolean { public get supportsTools(): boolean {
return false return false
} }
public override validate() {
super.validate()
if (this._tools) {
for (const tool of this._tools) {
if (!tool.agentic) {
tool.agentic = this.agentic
} else if (tool.agentic !== this.agentic) {
throw new Error(
`Task "${this.nameForHuman}" has a different Agentic runtime instance than the tool "${tool.nameForHuman}"`
)
}
tool.validate()
}
}
}
protected abstract _createChatCompletion( protected abstract _createChatCompletion(
messages: types.ChatMessage[], messages: types.ChatMessage[],
functions?: types.openai.ChatMessageFunction[] functions?: types.openai.ChatMessageFunction[]
@ -133,9 +158,22 @@ export abstract class BaseChatCompletion<
.replace(/^ {4}/gm, ' ') .replace(/^ {4}/gm, ' ')
.replace(/;$/gm, '') .replace(/;$/gm, '')
const label =
this._outputSchema instanceof z.ZodArray
? 'JSON array'
: this._outputSchema instanceof z.ZodObject
? 'JSON object'
: this._outputSchema instanceof z.ZodNumber
? 'number'
: this._outputSchema instanceof z.ZodString
? 'string'
: this._outputSchema instanceof z.ZodBoolean
? 'boolean'
: 'JSON value'
messages.push({ messages.push({
role: 'system', role: 'system',
content: dedent`Do not output code. Output JSON only in the following TypeScript format: content: dedent`Do not output code. Output a single ${label} in the following TypeScript format:
\`\`\`ts \`\`\`ts
${tsTypeString} ${tsTypeString}
\`\`\`` \`\`\``
@ -285,6 +323,11 @@ export abstract class BaseChatCompletion<
try { try {
const trimmedOutput = extractJSONObjectFromString(output) const trimmedOutput = extractJSONObjectFromString(output)
output = JSON.parse(jsonrepair(trimmedOutput ?? output)) output = JSON.parse(jsonrepair(trimmedOutput ?? output))
if (Array.isArray(output)) {
// TODO
output = output[0]
}
} catch (err: any) { } catch (err: any) {
if (err instanceof JSONRepairError) { if (err instanceof JSONRepairError) {
throw new errors.OutputValidationError(err.message, { cause: err }) throw new errors.OutputValidationError(err.message, { cause: err })

Wyświetl plik

@ -38,7 +38,7 @@ export class OpenAIChatCompletion<
...options ...options
}) })
if (this._agentic.openai) { if (this._agentic?.openai) {
this._client = this._agentic.openai this._client = this._agentic.openai
} else { } else {
throw new Error( throw new Error(
@ -93,6 +93,30 @@ export class OpenAIChatCompletion<
return openaiModelsSupportingFunctions.has(this._model) return openaiModelsSupportingFunctions.has(this._model)
} }
public override validate() {
super.validate()
if (!this._client) {
throw new Error(
'OpenAIChatCompletion requires an OpenAI client to be configured on the Agentic runtime'
)
}
if (!this.supportsTools) {
if (this._tools) {
throw new Error(
`This OpenAI chat model "${this.nameForHuman}" does not support tools`
)
}
if (this._modelParams?.functions) {
throw new Error(
`This OpenAI chat model "${this.nameForHuman}" does not support functions`
)
}
}
}
protected override async _createChatCompletion( protected override async _createChatCompletion(
messages: types.ChatMessage[], messages: types.ChatMessage[],
functions?: types.openai.ChatMessageFunction[] functions?: types.openai.ChatMessageFunction[]

Wyświetl plik

@ -1,9 +1,10 @@
import pRetry, { FailedAttemptError } from 'p-retry' import pRetry, { FailedAttemptError } from 'p-retry'
import { ZodType } from 'zod' import { ZodType } from 'zod'
import * as errors from '@/errors' import * as errors from './errors'
import * as types from '@/types' import * as types from './types'
import { Agentic } from '@/agentic' import type { Agentic } from './agentic'
import { defaultIDGeneratorFn, isValidTaskIdentifier } from './utils'
/** /**
* A `Task` is an async function call that may be non-deterministic. It has * A `Task` is an async function call that may be non-deterministic. It has
@ -28,24 +29,27 @@ export abstract class BaseTask<
protected _timeoutMs?: number protected _timeoutMs?: number
protected _retryConfig: types.RetryConfig protected _retryConfig: types.RetryConfig
constructor(options: types.BaseTaskOptions) { constructor(options: types.BaseTaskOptions = {}) {
if (!options.agentic) { this._agentic = options.agentic ?? globalThis.__agentic?.deref()
throw new Error('Passing "agentic" is required when creating a Task')
}
this._agentic = options.agentic
this._timeoutMs = options.timeoutMs this._timeoutMs = options.timeoutMs
this._retryConfig = options.retryConfig ?? { this._retryConfig = options.retryConfig ?? {
retries: 3, retries: 3,
strategy: 'default' strategy: 'default'
} }
this._id = options.id ?? this._agentic.idGeneratorFn()
this._id =
options.id ?? this._agentic?.idGeneratorFn() ?? defaultIDGeneratorFn()
} }
public get agentic(): Agentic { public get agentic(): Agentic {
return this._agentic return this._agentic
} }
public set agentic(agentic: Agentic) {
this._agentic = agentic
}
public get id(): string { public get id(): string {
return this._id return this._id
} }
@ -53,7 +57,10 @@ export abstract class BaseTask<
public abstract get inputSchema(): ZodType<TInput> public abstract get inputSchema(): ZodType<TInput>
public abstract get outputSchema(): ZodType<TOutput> public abstract get outputSchema(): ZodType<TOutput>
public abstract get nameForModel(): string public get nameForModel(): string {
const name = this.constructor.name
return name[0].toLowerCase() + name.slice(1)
}
public get nameForHuman(): string { public get nameForHuman(): string {
return this.constructor.name return this.constructor.name
@ -63,6 +70,19 @@ export abstract class BaseTask<
return '' return ''
} }
public validate() {
if (!this._agentic) {
throw new Error(
`Task "${this.nameForHuman}" is missing a required "agentic" instance`
)
}
const nameForModel = this.nameForModel
if (!isValidTaskIdentifier(nameForModel)) {
throw new Error(`Task field nameForModel "${nameForModel}" is invalid`)
}
}
// TODO: is this really necessary? // TODO: is this really necessary?
public clone(): BaseTask<TInput, TOutput> { public clone(): BaseTask<TInput, TOutput> {
// TODO: override in subclass if needed // TODO: override in subclass if needed
@ -88,6 +108,8 @@ export abstract class BaseTask<
public async callWithMetadata( public async callWithMetadata(
input?: TInput input?: TInput
): Promise<types.TaskResponse<TOutput>> { ): Promise<types.TaskResponse<TOutput>> {
this.validate()
if (this.inputSchema) { if (this.inputSchema) {
const safeInput = this.inputSchema.safeParse(input) const safeInput = this.inputSchema.safeParse(input)
@ -104,7 +126,7 @@ export abstract class BaseTask<
metadata: { metadata: {
taskName: this.nameForModel, taskName: this.nameForModel,
taskId: this.id, taskId: this.id,
callId: this._agentic.idGeneratorFn() callId: this._agentic!.idGeneratorFn()
} }
} }

Wyświetl plik

@ -19,7 +19,7 @@ export class CalculatorTool extends BaseTask<
CalculatorInput, CalculatorInput,
CalculatorOutput CalculatorOutput
> { > {
constructor(opts: types.BaseTaskOptions) { constructor(opts: types.BaseTaskOptions = {}) {
super(opts) super(opts)
} }

Wyświetl plik

@ -1,26 +1,20 @@
import * as metaphor from '@/services/metaphor' import * as metaphor from '@/services/metaphor'
import * as types from '@/types' import * as types from '@/types'
import { Agentic } from '@/agentic'
import { BaseTask } from '@/task' import { BaseTask } from '@/task'
export class MetaphorSearchTool extends BaseTask< export class MetaphorSearchTool extends BaseTask<
metaphor.MetaphorSearchInput, metaphor.MetaphorSearchInput,
metaphor.MetaphorSearchOutput metaphor.MetaphorSearchOutput
> { > {
_metaphorClient: metaphor.MetaphorClient protected _metaphorClient: metaphor.MetaphorClient
constructor({ constructor({
agentic,
metaphorClient = new metaphor.MetaphorClient(), metaphorClient = new metaphor.MetaphorClient(),
...rest ...opts
}: { }: {
agentic: Agentic
metaphorClient?: metaphor.MetaphorClient metaphorClient?: metaphor.MetaphorClient
} & types.BaseTaskOptions) { } & types.BaseTaskOptions = {}) {
super({ super(opts)
agentic,
...rest
})
this._metaphorClient = metaphorClient this._metaphorClient = metaphorClient
} }
@ -34,7 +28,7 @@ export class MetaphorSearchTool extends BaseTask<
} }
public override get nameForModel(): string { public override get nameForModel(): string {
return 'metaphor_web_search' return 'metaphorWebSearch'
} }
protected override async _call( protected override async _call(

Wyświetl plik

@ -1,7 +1,6 @@
import { z } from 'zod' import { z } from 'zod'
import * as types from '@/types' import * as types from '@/types'
import { Agentic } from '@/agentic'
import { NovuClient } from '@/services/novu' import { NovuClient } from '@/services/novu'
import { BaseTask } from '@/task' import { BaseTask } from '@/task'
@ -39,18 +38,15 @@ export class NovuNotificationTool extends BaseTask<
NovuNotificationToolInput, NovuNotificationToolInput,
NovuNotificationToolOutput NovuNotificationToolOutput
> { > {
_novuClient: NovuClient protected _novuClient: NovuClient
constructor({ constructor({
agentic, novuClient = new NovuClient(),
novuClient = new NovuClient() ...opts
}: { }: {
agentic: Agentic
novuClient?: NovuClient novuClient?: NovuClient
}) { } & types.BaseTaskOptions = {}) {
super({ super(opts)
agentic
})
this._novuClient = novuClient this._novuClient = novuClient
} }
@ -64,7 +60,7 @@ export class NovuNotificationTool extends BaseTask<
} }
public override get nameForModel(): string { public override get nameForModel(): string {
return 'novu_send_notification' return 'novuSendNotification'
} }
protected override async _call( protected override async _call(

Wyświetl plik

@ -9,7 +9,12 @@ export const WeatherInputSchema = z.object({
.string() .string()
.describe( .describe(
'Location to get the weather for. Can be a city name like "Paris", a zipcode like "53121", an international postal code like "SW1", or a latitude and longitude like "48.8567,2.3508"' 'Location to get the weather for. Can be a city name like "Paris", a zipcode like "53121", an international postal code like "SW1", or a latitude and longitude like "48.8567,2.3508"'
) ),
units: z
.union([z.literal('imperial'), z.literal('metric')])
.default('imperial')
.optional()
}) })
export type WeatherInput = z.infer<typeof WeatherInputSchema> export type WeatherInput = z.infer<typeof WeatherInputSchema>
@ -33,8 +38,8 @@ const ConditionSchema = z.object({
const CurrentSchema = z.object({ const CurrentSchema = z.object({
last_updated_epoch: z.number(), last_updated_epoch: z.number(),
last_updated: z.string(), last_updated: z.string(),
temp_c: z.number(), temp_c: z.number().describe('temperature in celsius'),
temp_f: z.number(), temp_f: z.number().describe('temperature in fahrenheit'),
is_day: z.number(), is_day: z.number(),
condition: ConditionSchema, condition: ConditionSchema,
wind_mph: z.number(), wind_mph: z.number(),
@ -65,14 +70,19 @@ export type WeatherOutput = z.infer<typeof WeatherOutputSchema>
export class WeatherTool extends BaseTask<WeatherInput, WeatherOutput> { export class WeatherTool extends BaseTask<WeatherInput, WeatherOutput> {
client: WeatherClient client: WeatherClient
constructor( constructor({
opts: { weather = new WeatherClient({ apiKey: process.env.WEATHER_API_KEY }),
weather: WeatherClient ...opts
} & types.BaseTaskOptions }: {
) { weather?: WeatherClient
} & types.BaseTaskOptions = {}) {
super(opts) super(opts)
this.client = opts.weather if (!weather) {
throw new Error(`Error WeatherTool missing required "weather" client`)
}
this.client = weather
} }
public override get inputSchema() { public override get inputSchema() {
@ -92,7 +102,7 @@ export class WeatherTool extends BaseTask<WeatherInput, WeatherOutput> {
} }
public override get descForModel(): string { public override get descForModel(): string {
return 'Useful for getting the current weather at a location' return 'Useful for getting the current weather at a location.'
} }
protected override async _call( protected override async _call(

Wyświetl plik

@ -20,7 +20,7 @@ export type SafeParsedData<T extends ZodTypeAny> = T extends ZodTypeAny
: never : never
export interface BaseTaskOptions { export interface BaseTaskOptions {
agentic: Agentic agentic?: Agentic
timeoutMs?: number timeoutMs?: number
retryConfig?: RetryConfig retryConfig?: RetryConfig