diff --git a/examples/functions.ts b/examples/functions.ts index 8b239b08..99a070ac 100644 --- a/examples/functions.ts +++ b/examples/functions.ts @@ -2,16 +2,21 @@ import { OpenAIClient } from '@agentic/openai-fetch' import 'dotenv/config' import { z } from 'zod' -import { Agentic, CalculatorTool } from '@/index' +import { Agentic, CalculatorTool, WeatherTool } from '@/index' async function main() { const openai = new OpenAIClient({ apiKey: process.env.OPENAI_API_KEY! }) const agentic = new Agentic({ openai }) const example = await agentic - .gpt3('What is 5 * 50?') - .tools([new CalculatorTool({ agentic })]) - .output(z.object({ answer: z.number() })) + .gpt3('What is the temperature in san francisco today?') + .tools([new CalculatorTool(), new WeatherTool()]) + .output( + z.object({ + answer: z.number(), + units: z.union([z.literal('fahrenheit'), z.literal('celcius')]) + }) + ) .call() console.log(example) } diff --git a/package.json b/package.json index 82fd944c..9bb4e975 100644 --- a/package.json +++ b/package.json @@ -101,7 +101,7 @@ ] }, "ava": { - "snapshotDir": ".snapshots", + "snapshotDir": "test/.snapshots", "extensions": { "ts": "module" }, diff --git a/src/agentic.ts b/src/agentic.ts index dcd757b4..67e9a923 100644 --- a/src/agentic.ts +++ b/src/agentic.ts @@ -1,11 +1,10 @@ -import * as types from '@/types' -import { DEFAULT_OPENAI_MODEL } from '@/constants' -import { OpenAIChatCompletion } from '@/llms/openai' - +import * as types from './types' +import { DEFAULT_OPENAI_MODEL } from './constants' import { HumanFeedbackMechanism, HumanFeedbackMechanismCLI } from './human-feedback' +import { OpenAIChatCompletion } from './llms/openai' import { defaultIDGeneratorFn } from './utils' export class Agentic { @@ -20,7 +19,9 @@ export class Agentic { 'provider' | 'model' | 'modelParams' | 'timeoutMs' | 'retryConfig' > protected _defaultHumanFeedbackMechamism?: HumanFeedbackMechanism + protected _idGeneratorFn: types.IDGeneratorFunction + protected _id: string constructor(opts: { openai?: types.openai.OpenAIClient @@ -33,6 +34,10 @@ export class Agentic { defaultHumanFeedbackMechanism?: HumanFeedbackMechanism idGeneratorFn?: types.IDGeneratorFunction }) { + if (!globalThis.__agentic?.deref()) { + globalThis.__agentic = new WeakRef(this) + } + this._openai = opts.openai this._anthropic = opts.anthropic @@ -59,6 +64,7 @@ export class Agentic { new HumanFeedbackMechanismCLI({ agentic: this }) this._idGeneratorFn = opts.idGeneratorFn ?? defaultIDGeneratorFn + this._id = this._idGeneratorFn() } public get openai(): types.openai.OpenAIClient | undefined { diff --git a/src/human-feedback.ts b/src/human-feedback.ts index 40c1d5a3..6f9e3d2d 100644 --- a/src/human-feedback.ts +++ b/src/human-feedback.ts @@ -1,6 +1,6 @@ -import * as types from '@/types' -import { Agentic } from '@/agentic' -import { BaseTask } from '@/task' +import * as types from './types' +import { Agentic } from './agentic' +import { BaseTask } from './task' export type HumanFeedbackType = 'confirm' | 'selectOne' | 'selectN' diff --git a/src/llms/anthropic.ts b/src/llms/anthropic.ts index 52a21c74..2a267cdd 100644 --- a/src/llms/anthropic.ts +++ b/src/llms/anthropic.ts @@ -38,7 +38,7 @@ export class AnthropicChatCompletion< ...options }) - if (this._agentic.anthropic) { + if (this._agentic?.anthropic) { this._client = this._agentic.anthropic } else { throw new Error( diff --git a/src/llms/chat.ts b/src/llms/chat.ts index acd925e8..9ab4aa1c 100644 --- a/src/llms/chat.ts +++ b/src/llms/chat.ts @@ -72,13 +72,38 @@ export abstract class BaseChatCompletion< } this._tools = tools + for (const tool of tools) { + tool.agentic = this.agentic + } + return this } + /** + * Whether or not this chat completion model directly supports the use of tools. + */ public get supportsTools(): boolean { return false } + public override validate() { + super.validate() + + if (this._tools) { + for (const tool of this._tools) { + if (!tool.agentic) { + tool.agentic = this.agentic + } else if (tool.agentic !== this.agentic) { + throw new Error( + `Task "${this.nameForHuman}" has a different Agentic runtime instance than the tool "${tool.nameForHuman}"` + ) + } + + tool.validate() + } + } + } + protected abstract _createChatCompletion( messages: types.ChatMessage[], functions?: types.openai.ChatMessageFunction[] @@ -133,9 +158,22 @@ export abstract class BaseChatCompletion< .replace(/^ {4}/gm, ' ') .replace(/;$/gm, '') + const label = + this._outputSchema instanceof z.ZodArray + ? 'JSON array' + : this._outputSchema instanceof z.ZodObject + ? 'JSON object' + : this._outputSchema instanceof z.ZodNumber + ? 'number' + : this._outputSchema instanceof z.ZodString + ? 'string' + : this._outputSchema instanceof z.ZodBoolean + ? 'boolean' + : 'JSON value' + messages.push({ role: 'system', - content: dedent`Do not output code. Output JSON only in the following TypeScript format: + content: dedent`Do not output code. Output a single ${label} in the following TypeScript format: \`\`\`ts ${tsTypeString} \`\`\`` @@ -285,6 +323,11 @@ export abstract class BaseChatCompletion< try { const trimmedOutput = extractJSONObjectFromString(output) output = JSON.parse(jsonrepair(trimmedOutput ?? output)) + + if (Array.isArray(output)) { + // TODO + output = output[0] + } } catch (err: any) { if (err instanceof JSONRepairError) { throw new errors.OutputValidationError(err.message, { cause: err }) diff --git a/src/llms/openai.ts b/src/llms/openai.ts index 1ec90238..ee8f5017 100644 --- a/src/llms/openai.ts +++ b/src/llms/openai.ts @@ -38,7 +38,7 @@ export class OpenAIChatCompletion< ...options }) - if (this._agentic.openai) { + if (this._agentic?.openai) { this._client = this._agentic.openai } else { throw new Error( @@ -93,6 +93,30 @@ export class OpenAIChatCompletion< return openaiModelsSupportingFunctions.has(this._model) } + public override validate() { + super.validate() + + if (!this._client) { + throw new Error( + 'OpenAIChatCompletion requires an OpenAI client to be configured on the Agentic runtime' + ) + } + + if (!this.supportsTools) { + if (this._tools) { + throw new Error( + `This OpenAI chat model "${this.nameForHuman}" does not support tools` + ) + } + + if (this._modelParams?.functions) { + throw new Error( + `This OpenAI chat model "${this.nameForHuman}" does not support functions` + ) + } + } + } + protected override async _createChatCompletion( messages: types.ChatMessage[], functions?: types.openai.ChatMessageFunction[] diff --git a/src/task.ts b/src/task.ts index e6f504a0..b0b98131 100644 --- a/src/task.ts +++ b/src/task.ts @@ -1,9 +1,10 @@ import pRetry, { FailedAttemptError } from 'p-retry' import { ZodType } from 'zod' -import * as errors from '@/errors' -import * as types from '@/types' -import { Agentic } from '@/agentic' +import * as errors from './errors' +import * as types from './types' +import type { Agentic } from './agentic' +import { defaultIDGeneratorFn, isValidTaskIdentifier } from './utils' /** * A `Task` is an async function call that may be non-deterministic. It has @@ -28,24 +29,27 @@ export abstract class BaseTask< protected _timeoutMs?: number protected _retryConfig: types.RetryConfig - constructor(options: types.BaseTaskOptions) { - if (!options.agentic) { - throw new Error('Passing "agentic" is required when creating a Task') - } + constructor(options: types.BaseTaskOptions = {}) { + this._agentic = options.agentic ?? globalThis.__agentic?.deref() - this._agentic = options.agentic this._timeoutMs = options.timeoutMs this._retryConfig = options.retryConfig ?? { retries: 3, strategy: 'default' } - this._id = options.id ?? this._agentic.idGeneratorFn() + + this._id = + options.id ?? this._agentic?.idGeneratorFn() ?? defaultIDGeneratorFn() } public get agentic(): Agentic { return this._agentic } + public set agentic(agentic: Agentic) { + this._agentic = agentic + } + public get id(): string { return this._id } @@ -53,7 +57,10 @@ export abstract class BaseTask< public abstract get inputSchema(): ZodType public abstract get outputSchema(): ZodType - public abstract get nameForModel(): string + public get nameForModel(): string { + const name = this.constructor.name + return name[0].toLowerCase() + name.slice(1) + } public get nameForHuman(): string { return this.constructor.name @@ -63,6 +70,19 @@ export abstract class BaseTask< return '' } + public validate() { + if (!this._agentic) { + throw new Error( + `Task "${this.nameForHuman}" is missing a required "agentic" instance` + ) + } + + const nameForModel = this.nameForModel + if (!isValidTaskIdentifier(nameForModel)) { + throw new Error(`Task field nameForModel "${nameForModel}" is invalid`) + } + } + // TODO: is this really necessary? public clone(): BaseTask { // TODO: override in subclass if needed @@ -88,6 +108,8 @@ export abstract class BaseTask< public async callWithMetadata( input?: TInput ): Promise> { + this.validate() + if (this.inputSchema) { const safeInput = this.inputSchema.safeParse(input) @@ -104,7 +126,7 @@ export abstract class BaseTask< metadata: { taskName: this.nameForModel, taskId: this.id, - callId: this._agentic.idGeneratorFn() + callId: this._agentic!.idGeneratorFn() } } diff --git a/src/tools/calculator.ts b/src/tools/calculator.ts index 6850374d..82a0aef6 100644 --- a/src/tools/calculator.ts +++ b/src/tools/calculator.ts @@ -19,7 +19,7 @@ export class CalculatorTool extends BaseTask< CalculatorInput, CalculatorOutput > { - constructor(opts: types.BaseTaskOptions) { + constructor(opts: types.BaseTaskOptions = {}) { super(opts) } diff --git a/src/tools/metaphor.ts b/src/tools/metaphor.ts index 0dec2821..7cf00f75 100644 --- a/src/tools/metaphor.ts +++ b/src/tools/metaphor.ts @@ -1,26 +1,20 @@ import * as metaphor from '@/services/metaphor' import * as types from '@/types' -import { Agentic } from '@/agentic' import { BaseTask } from '@/task' export class MetaphorSearchTool extends BaseTask< metaphor.MetaphorSearchInput, metaphor.MetaphorSearchOutput > { - _metaphorClient: metaphor.MetaphorClient + protected _metaphorClient: metaphor.MetaphorClient constructor({ - agentic, metaphorClient = new metaphor.MetaphorClient(), - ...rest + ...opts }: { - agentic: Agentic metaphorClient?: metaphor.MetaphorClient - } & types.BaseTaskOptions) { - super({ - agentic, - ...rest - }) + } & types.BaseTaskOptions = {}) { + super(opts) this._metaphorClient = metaphorClient } @@ -34,7 +28,7 @@ export class MetaphorSearchTool extends BaseTask< } public override get nameForModel(): string { - return 'metaphor_web_search' + return 'metaphorWebSearch' } protected override async _call( diff --git a/src/tools/novu.ts b/src/tools/novu.ts index 2fec0283..46b9ec71 100644 --- a/src/tools/novu.ts +++ b/src/tools/novu.ts @@ -1,7 +1,6 @@ import { z } from 'zod' import * as types from '@/types' -import { Agentic } from '@/agentic' import { NovuClient } from '@/services/novu' import { BaseTask } from '@/task' @@ -39,18 +38,15 @@ export class NovuNotificationTool extends BaseTask< NovuNotificationToolInput, NovuNotificationToolOutput > { - _novuClient: NovuClient + protected _novuClient: NovuClient constructor({ - agentic, - novuClient = new NovuClient() + novuClient = new NovuClient(), + ...opts }: { - agentic: Agentic novuClient?: NovuClient - }) { - super({ - agentic - }) + } & types.BaseTaskOptions = {}) { + super(opts) this._novuClient = novuClient } @@ -64,7 +60,7 @@ export class NovuNotificationTool extends BaseTask< } public override get nameForModel(): string { - return 'novu_send_notification' + return 'novuSendNotification' } protected override async _call( diff --git a/src/tools/weather.ts b/src/tools/weather.ts index f84c45f9..58edc322 100644 --- a/src/tools/weather.ts +++ b/src/tools/weather.ts @@ -9,7 +9,12 @@ export const WeatherInputSchema = z.object({ .string() .describe( 'Location to get the weather for. Can be a city name like "Paris", a zipcode like "53121", an international postal code like "SW1", or a latitude and longitude like "48.8567,2.3508"' - ) + ), + + units: z + .union([z.literal('imperial'), z.literal('metric')]) + .default('imperial') + .optional() }) export type WeatherInput = z.infer @@ -33,8 +38,8 @@ const ConditionSchema = z.object({ const CurrentSchema = z.object({ last_updated_epoch: z.number(), last_updated: z.string(), - temp_c: z.number(), - temp_f: z.number(), + temp_c: z.number().describe('temperature in celsius'), + temp_f: z.number().describe('temperature in fahrenheit'), is_day: z.number(), condition: ConditionSchema, wind_mph: z.number(), @@ -65,14 +70,19 @@ export type WeatherOutput = z.infer export class WeatherTool extends BaseTask { client: WeatherClient - constructor( - opts: { - weather: WeatherClient - } & types.BaseTaskOptions - ) { + constructor({ + weather = new WeatherClient({ apiKey: process.env.WEATHER_API_KEY }), + ...opts + }: { + weather?: WeatherClient + } & types.BaseTaskOptions = {}) { super(opts) - this.client = opts.weather + if (!weather) { + throw new Error(`Error WeatherTool missing required "weather" client`) + } + + this.client = weather } public override get inputSchema() { @@ -92,7 +102,7 @@ export class WeatherTool extends BaseTask { } public override get descForModel(): string { - return 'Useful for getting the current weather at a location' + return 'Useful for getting the current weather at a location.' } protected override async _call( diff --git a/src/types.ts b/src/types.ts index 37c16bbf..c18c2269 100644 --- a/src/types.ts +++ b/src/types.ts @@ -20,7 +20,7 @@ export type SafeParsedData = T extends ZodTypeAny : never export interface BaseTaskOptions { - agentic: Agentic + agentic?: Agentic timeoutMs?: number retryConfig?: RetryConfig diff --git a/.snapshots/test/llms/llm-utils.test.ts.md b/test/.snapshots/test/llms/llm-utils.test.ts.md similarity index 100% rename from .snapshots/test/llms/llm-utils.test.ts.md rename to test/.snapshots/test/llms/llm-utils.test.ts.md diff --git a/.snapshots/test/llms/llm-utils.test.ts.snap b/test/.snapshots/test/llms/llm-utils.test.ts.snap similarity index 100% rename from .snapshots/test/llms/llm-utils.test.ts.snap rename to test/.snapshots/test/llms/llm-utils.test.ts.snap