diff --git a/examples/functions.ts b/examples/functions.ts index ee28cb7..8b239b0 100644 --- a/examples/functions.ts +++ b/examples/functions.ts @@ -9,7 +9,7 @@ async function main() { const agentic = new Agentic({ openai }) const example = await agentic - .gpt4('What is 5 * 50?') + .gpt3('What is 5 * 50?') .tools([new CalculatorTool({ agentic })]) .output(z.object({ answer: z.number() })) .call() diff --git a/src/agentic.ts b/src/agentic.ts index 4e2092b..dcd757b 100644 --- a/src/agentic.ts +++ b/src/agentic.ts @@ -1,6 +1,6 @@ import * as types from '@/types' import { DEFAULT_OPENAI_MODEL } from '@/constants' -import { OpenAIChatModel } from '@/llms/openai' +import { OpenAIChatCompletion } from '@/llms/openai' import { HumanFeedbackMechanism, @@ -77,7 +77,7 @@ export class Agentic { return this._idGeneratorFn } - llm( + openaiChatCompletion( promptOrChatCompletionParams: | string | Partial // TODO: make more strict @@ -101,13 +101,16 @@ export class Agentic { } } - return new OpenAIChatModel({ + return new OpenAIChatCompletion({ agentic: this, ...(this._openaiModelDefaults as any), // TODO ...options }) } + /** + * Shortcut for creating an OpenAI chat completion call with the `gpt-3.5-turbo` model. + */ gpt3( promptOrChatCompletionParams: | string @@ -132,7 +135,7 @@ export class Agentic { } } - return new OpenAIChatModel({ + return new OpenAIChatCompletion({ agentic: this, ...(this._openaiModelDefaults as any), // TODO model: 'gpt-3.5-turbo', @@ -140,6 +143,9 @@ export class Agentic { }) } + /** + * Shortcut for creating an OpenAI chat completion call with the `gpt-4` model. + */ gpt4( promptOrChatCompletionParams: | string @@ -164,7 +170,7 @@ export class Agentic { } } - return new OpenAIChatModel({ + return new OpenAIChatCompletion({ agentic: this, ...(this._openaiModelDefaults as any), // TODO model: 'gpt-4', diff --git a/src/llms/anthropic.ts b/src/llms/anthropic.ts index b4e9372..52a21c7 100644 --- a/src/llms/anthropic.ts +++ b/src/llms/anthropic.ts @@ -4,14 +4,14 @@ import { type SetOptional } from 'type-fest' import * as types from '@/types' import { DEFAULT_ANTHROPIC_MODEL } from '@/constants' -import { BaseChatModel } from './chat' +import { BaseChatCompletion } from './chat' const defaultStopSequences = [anthropic.HUMAN_PROMPT] -export class AnthropicChatModel< +export class AnthropicChatCompletion< TInput extends void | types.JsonObject = any, TOutput extends types.JsonValue = string -> extends BaseChatModel< +> extends BaseChatCompletion< TInput, TOutput, SetOptional< @@ -42,7 +42,7 @@ export class AnthropicChatModel< this._client = this._agentic.anthropic } else { throw new Error( - 'AnthropicChatModel requires an Anthropic client to be configured on the Agentic runtime' + 'AnthropicChatCompletion requires an Anthropic client to be configured on the Agentic runtime' ) } } @@ -91,8 +91,8 @@ export class AnthropicChatModel< } } - public override clone(): AnthropicChatModel { - return new AnthropicChatModel({ + public override clone(): AnthropicChatCompletion { + return new AnthropicChatCompletion({ agentic: this._agentic, timeoutMs: this._timeoutMs, retryConfig: this._retryConfig, diff --git a/src/llms/chat.ts b/src/llms/chat.ts index 251ffc6..acd925e 100644 --- a/src/llms/chat.ts +++ b/src/llms/chat.ts @@ -19,7 +19,7 @@ import { getNumTokensForChatMessages } from './llm-utils' -export abstract class BaseChatModel< +export abstract class BaseChatCompletion< TInput extends void | types.JsonObject = void, TOutput extends types.JsonValue = string, TModelParams extends Record = Record, @@ -43,8 +43,8 @@ export abstract class BaseChatModel< // TODO: use polymorphic `this` type to return correct BaseLLM subclass type input( inputSchema: ZodType - ): BaseChatModel { - const refinedInstance = this as unknown as BaseChatModel< + ): BaseChatCompletion { + const refinedInstance = this as unknown as BaseChatCompletion< U, TOutput, TModelParams @@ -56,8 +56,8 @@ export abstract class BaseChatModel< // TODO: use polymorphic `this` type to return correct BaseLLM subclass type output( outputSchema: ZodType - ): BaseChatModel { - const refinedInstance = this as unknown as BaseChatModel< + ): BaseChatCompletion { + const refinedInstance = this as unknown as BaseChatCompletion< TInput, U, TModelParams @@ -66,11 +66,9 @@ export abstract class BaseChatModel< return refinedInstance } - tools(tools: BaseTask[]): this { + public tools(tools: BaseTask[]): this { if (!this.supportsTools) { - throw new Error( - `This Chat model "${this.nameForHuman}" does not support tools` - ) + throw new Error(`This LLM "${this.nameForHuman}" does not support tools`) } this._tools = tools @@ -179,20 +177,30 @@ export abstract class BaseChatModel< } } - const completion = await this._createChatCompletion(messages, functions) - ctx.metadata.completion = completion + let output: any - if (completion.message.function_call) { - const functionCall = completion.message.function_call + do { + console.log('<<< completion', { messages, functions }) + const completion = await this._createChatCompletion(messages, functions) + console.log('>>> completion', completion.message) + ctx.metadata.completion = completion + + if (completion.message.function_call) { + const functionCall = completion.message.function_call + + if (!isUsingTools) { + // TODO: not sure what we should do in this case... + output = functionCall + break + } - if (isUsingTools) { const tool = this._tools!.find( (tool) => tool.nameForModel === functionCall.name ) if (!tool) { throw new errors.OutputValidationError( - `Unrecognized function call "${functionCall.name}"` + `Function not found "${functionCall.name}"` ) } @@ -201,7 +209,9 @@ export abstract class BaseChatModel< functionArguments = JSON.parse(jsonrepair(functionCall.arguments)) } catch (err: any) { if (err instanceof JSONRepairError) { - throw new errors.OutputValidationError(err.message, { cause: err }) + throw new errors.OutputValidationError(err.message, { + cause: err + }) } else if (err instanceof SyntaxError) { throw new errors.OutputValidationError( `Invalid JSON object: ${err.message}`, @@ -212,11 +222,22 @@ export abstract class BaseChatModel< } } + console.log('>>> sub-task', { + task: functionCall.name, + input: functionArguments + }) + // TODO: handle sub-task errors gracefully const toolCallResponse = await tool.callWithMetadata(functionArguments) + console.log('<<< sub-task', { + task: functionCall.name, + input: functionArguments, + output: toolCallResponse.result + }) + // TODO: handle result as string or JSON - // TODO: + // TODO: better support for JSON spacing const taskCallContent = JSON.stringify( toolCallResponse.result, null, @@ -231,12 +252,11 @@ export abstract class BaseChatModel< content: taskCallContent }) - // TODO: re-invoke completion with new messages - throw new Error('TODO') + continue } - } - let output: any = completion.message.content + output = completion.message.content + } while (output === undefined) console.log('===') console.log(output) diff --git a/src/llms/openai.ts b/src/llms/openai.ts index ba1bcd2..1ec9023 100644 --- a/src/llms/openai.ts +++ b/src/llms/openai.ts @@ -2,8 +2,9 @@ import type { SetOptional } from 'type-fest' import * as types from '@/types' import { DEFAULT_OPENAI_MODEL } from '@/constants' +import { BaseTask } from '@/task' -import { BaseChatModel } from './chat' +import { BaseChatCompletion } from './chat' const openaiModelsSupportingFunctions = new Set([ 'gpt-4-0613', @@ -12,10 +13,10 @@ const openaiModelsSupportingFunctions = new Set([ 'gpt-3.5-turbo-16k' ]) -export class OpenAIChatModel< +export class OpenAIChatCompletion< TInput extends void | types.JsonObject = any, TOutput extends types.JsonValue = string -> extends BaseChatModel< +> extends BaseChatCompletion< TInput, TOutput, SetOptional, 'model'>, @@ -41,7 +42,7 @@ export class OpenAIChatModel< this._client = this._agentic.openai } else { throw new Error( - 'OpenAIChatModel requires an OpenAI client to be configured on the Agentic runtime' + 'OpenAIChatCompletion requires an OpenAI client to be configured on the Agentic runtime' ) } @@ -65,7 +66,27 @@ export class OpenAIChatModel< } public override get nameForHuman(): string { - return `OpenAIChatModel ${this._model}` + return `OpenAIChatCompletion ${this._model}` + } + + public override tools(tools: BaseTask[]): this { + if (!this.supportsTools) { + switch (this._model) { + case 'gpt-4': + this._model = 'gpt-4-0613' + break + + case 'gpt-4-32k': + this._model = 'gpt-4-32k-0613' + break + + case 'gpt-3.5-turbo': + this._model = 'gpt-3.5-turbo-0613' + break + } + } + + return super.tools(tools) } public override get supportsTools(): boolean { @@ -86,8 +107,8 @@ export class OpenAIChatModel< }) } - public override clone(): OpenAIChatModel { - return new OpenAIChatModel({ + public override clone(): OpenAIChatCompletion { + return new OpenAIChatCompletion({ agentic: this._agentic, timeoutMs: this._timeoutMs, retryConfig: this._retryConfig, diff --git a/test/llms/anthropic.test.ts b/test/llms/anthropic.test.ts index bf940ed..158b3ca 100644 --- a/test/llms/anthropic.test.ts +++ b/test/llms/anthropic.test.ts @@ -1,15 +1,15 @@ import test from 'ava' import { expectTypeOf } from 'expect-type' -import { AnthropicChatModel } from '@/llms/anthropic' +import { AnthropicChatCompletion } from '@/llms/anthropic' import { createTestAgenticRuntime } from '../_utils' -test('AnthropicChatModel ⇒ string output', async (t) => { +test('AnthropicChatCompletion ⇒ string output', async (t) => { t.timeout(2 * 60 * 1000) const agentic = createTestAgenticRuntime() - const builder = new AnthropicChatModel({ + const builder = new AnthropicChatCompletion({ agentic, modelParams: { temperature: 0, diff --git a/test/llms/openai.test.ts b/test/llms/openai.test.ts index c08bcc1..1b8da04 100644 --- a/test/llms/openai.test.ts +++ b/test/llms/openai.test.ts @@ -4,16 +4,16 @@ import sinon from 'sinon' import { z } from 'zod' import { OutputValidationError, TemplateValidationError } from '@/errors' -import { BaseChatModel, OpenAIChatModel } from '@/llms' +import { BaseChatCompletion, OpenAIChatCompletion } from '@/llms' import { createTestAgenticRuntime } from '../_utils' -test('OpenAIChatModel ⇒ types', async (t) => { +test('OpenAIChatCompletion ⇒ types', async (t) => { const agentic = createTestAgenticRuntime() const b = agentic.gpt4('test') t.pass() - expectTypeOf(b).toMatchTypeOf>() + expectTypeOf(b).toMatchTypeOf>() expectTypeOf( b.input( @@ -22,7 +22,7 @@ test('OpenAIChatModel ⇒ types', async (t) => { }) ) ).toMatchTypeOf< - BaseChatModel< + BaseChatCompletion< { foo: string }, @@ -37,7 +37,7 @@ test('OpenAIChatModel ⇒ types', async (t) => { }) ) ).toMatchTypeOf< - BaseChatModel< + BaseChatCompletion< any, { bar?: string @@ -46,11 +46,11 @@ test('OpenAIChatModel ⇒ types', async (t) => { >() }) -test('OpenAIChatModel ⇒ string output', async (t) => { +test('OpenAIChatCompletion ⇒ string output', async (t) => { t.timeout(2 * 60 * 1000) const agentic = createTestAgenticRuntime() - const builder = new OpenAIChatModel({ + const builder = new OpenAIChatCompletion({ agentic, modelParams: { temperature: 0, @@ -80,11 +80,11 @@ test('OpenAIChatModel ⇒ string output', async (t) => { expectTypeOf(result2).toMatchTypeOf() }) -test('OpenAIChatModel ⇒ json output', async (t) => { +test('OpenAIChatCompletion ⇒ json output', async (t) => { t.timeout(2 * 60 * 1000) const agentic = createTestAgenticRuntime() - const builder = new OpenAIChatModel({ + const builder = new OpenAIChatCompletion({ agentic, modelParams: { temperature: 0.5 @@ -106,11 +106,11 @@ test('OpenAIChatModel ⇒ json output', async (t) => { expectTypeOf(result).toMatchTypeOf<{ foo: string; bar: number }>() }) -test('OpenAIChatModel ⇒ boolean output', async (t) => { +test('OpenAIChatCompletion ⇒ boolean output', async (t) => { t.timeout(2 * 60 * 1000) const agentic = createTestAgenticRuntime() - const builder = new OpenAIChatModel({ + const builder = new OpenAIChatCompletion({ agentic, modelParams: { temperature: 0, @@ -130,11 +130,11 @@ test('OpenAIChatModel ⇒ boolean output', async (t) => { expectTypeOf(result).toMatchTypeOf() }) -test('OpenAIChatModel ⇒ retry logic', async (t) => { +test('OpenAIChatCompletion ⇒ retry logic', async (t) => { t.timeout(2 * 60 * 1000) const agentic = createTestAgenticRuntime() - const builder = new OpenAIChatModel({ + const builder = new OpenAIChatCompletion({ agentic, modelParams: { temperature: 0, @@ -167,7 +167,7 @@ test('OpenAIChatModel ⇒ retry logic', async (t) => { t.is(fakeCall.callCount, 3) }) -test('OpenAIChatModel ⇒ template variables', async (t) => { +test('OpenAIChatCompletion ⇒ template variables', async (t) => { t.timeout(2 * 60 * 1000) const agentic = createTestAgenticRuntime() @@ -203,7 +203,7 @@ test('OpenAIChatModel ⇒ template variables', async (t) => { } }) -test('OpenAIChatModel ⇒ missing template variable', async (t) => { +test('OpenAIChatCompletion ⇒ missing template variable', async (t) => { t.timeout(2 * 60 * 1000) const agentic = createTestAgenticRuntime()