From a4d89fcba6ca9507fae232000ef4a42906f67c03 Mon Sep 17 00:00:00 2001 From: Travis Fischer Date: Wed, 14 Jun 2023 00:56:27 -0700 Subject: [PATCH] =?UTF-8?q?=F0=9F=91=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/llms/anthropic.ts | 6 +--- src/llms/chat.ts | 83 +++++++++++++++++++++++++++++++++++++++---- src/llms/llm.ts | 2 +- src/llms/openai.ts | 12 ++++--- 4 files changed, 86 insertions(+), 17 deletions(-) diff --git a/src/llms/anthropic.ts b/src/llms/anthropic.ts index e7da9fa..b4e9372 100644 --- a/src/llms/anthropic.ts +++ b/src/llms/anthropic.ts @@ -48,11 +48,7 @@ export class AnthropicChatModel< } public override get nameForModel(): string { - return 'anthropic_chat' - } - - public override get nameForHuman(): string { - return 'AnthropicChatModel' + return 'anthropicChatCompletion' } protected override async _createChatCompletion( diff --git a/src/llms/chat.ts b/src/llms/chat.ts index 5a1ade2..7ebb74f 100644 --- a/src/llms/chat.ts +++ b/src/llms/chat.ts @@ -6,15 +6,18 @@ import { printNode, zodToTs } from 'zod-to-ts' import * as errors from '@/errors' import * as types from '@/types' +import { BaseTask } from '@/task' import { getCompiledTemplate } from '@/template' import { extractJSONArrayFromString, extractJSONObjectFromString } from '@/utils' -import { BaseTask } from '../task' import { BaseLLM } from './llm' -import { getNumTokensForChatMessages } from './llm-utils' +import { + getChatMessageFunctionDefinitionsFromTasks, + getNumTokensForChatMessages +} from './llm-utils' export abstract class BaseChatModel< TInput extends void | types.JsonObject = void, @@ -79,7 +82,8 @@ export abstract class BaseChatModel< } protected abstract _createChatCompletion( - messages: types.ChatMessage[] + messages: types.ChatMessage[], + functions?: types.openai.ChatMessageFunction[] ): Promise> public async buildMessages( @@ -100,7 +104,7 @@ export abstract class BaseChatModel< : '' } }) - .filter((message) => message.content) + .filter((message) => message.content || message.function_call) if (this._examples?.length) { // TODO: smarter example selection @@ -161,9 +165,76 @@ export abstract class BaseChatModel< console.log('>>>') console.log(messages) - const completion = await this._createChatCompletion(messages) + let functions = this._modelParams?.functions + let isUsingTools = false + + if (this.supportsTools) { + if (this._tools?.length) { + if (functions?.length) { + throw new Error(`Cannot specify both tools and functions`) + } + + functions = getChatMessageFunctionDefinitionsFromTasks(this._tools) + isUsingTools = true + } + } + + const completion = await this._createChatCompletion(messages, functions) ctx.metadata.completion = completion + if (completion.message.function_call) { + const functionCall = completion.message.function_call + + if (isUsingTools) { + const tool = this._tools!.find( + (tool) => tool.nameForModel === functionCall.name + ) + + if (!tool) { + throw new errors.OutputValidationError( + `Unrecognized function call "${functionCall.name}"` + ) + } + + let functionArguments: any + try { + functionArguments = JSON.parse(jsonrepair(functionCall.arguments)) + } catch (err: any) { + if (err instanceof JSONRepairError) { + throw new errors.OutputValidationError(err.message, { cause: err }) + } else if (err instanceof SyntaxError) { + throw new errors.OutputValidationError( + `Invalid JSON object: ${err.message}`, + { cause: err } + ) + } else { + throw err + } + } + + // TODO: handle sub-task errors gracefully + const toolCallResponse = await tool.callWithMetadata(functionArguments) + + // TODO: handle result as string or JSON + // TODO: + const taskCallContent = JSON.stringify( + toolCallResponse.result, + null, + 1 + ).replaceAll(/\n ?/gm, ' ') + + messages.push(completion.message) + messages.push({ + role: 'function', + name: functionCall.name, + content: taskCallContent + }) + + // TODO: re-invoke completion with new messages + throw new Error('TODO') + } + } + let output: any = completion.message.content console.log('===') @@ -241,7 +312,7 @@ export abstract class BaseChatModel< } } - // TODO: this doesn't bode well, batman... + // TODO: fix typescript issue here with recursive types const safeResult = (outputSchema.safeParse as any)(output) if (!safeResult.success) { diff --git a/src/llms/llm.ts b/src/llms/llm.ts index 98e3c27..37583e5 100644 --- a/src/llms/llm.ts +++ b/src/llms/llm.ts @@ -78,7 +78,7 @@ export abstract class BaseLLM< } public override get nameForHuman(): string { - return `${this._provider}:chat:${this._model}` + return `${this.constructor.name} ${this._model}` } examples(examples: types.LLMExample[]): this { diff --git a/src/llms/openai.ts b/src/llms/openai.ts index e26cd29..c1b3920 100644 --- a/src/llms/openai.ts +++ b/src/llms/openai.ts @@ -1,4 +1,4 @@ -import { type SetOptional } from 'type-fest' +import type { SetOptional } from 'type-fest' import * as types from '@/types' import { DEFAULT_OPENAI_MODEL } from '@/constants' @@ -61,11 +61,11 @@ export class OpenAIChatModel< } public override get nameForModel(): string { - return 'openai_chat' + return 'openaiChatCompletion' } public override get nameForHuman(): string { - return 'OpenAIChatModel' + return `OpenAIChatModel ${this._model}` } public override get supportsTools(): boolean { @@ -73,14 +73,16 @@ export class OpenAIChatModel< } protected override async _createChatCompletion( - messages: types.ChatMessage[] + messages: types.ChatMessage[], + functions?: types.openai.ChatMessageFunction[] ): Promise< types.BaseChatCompletionResponse > { const res = await this._client.createChatCompletion({ ...this._modelParams, model: this._model, - messages + messages, + functions }) return res