From 23593d79b90c11e3583599e42f553bb062ac3d8e Mon Sep 17 00:00:00 2001 From: Travis Fischer Date: Tue, 13 Jun 2023 20:14:10 -0700 Subject: [PATCH] =?UTF-8?q?=E2=9B=B7?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- legacy/src/llms/chat.ts | 10 ++++++++++ legacy/src/llms/openai.ts | 28 +++++++++++++++++++++++++++- legacy/test/services/openai.test.ts | 4 +++- 3 files changed, 40 insertions(+), 2 deletions(-) diff --git a/legacy/src/llms/chat.ts b/legacy/src/llms/chat.ts index 3032f2b5..edeb6686 100644 --- a/legacy/src/llms/chat.ts +++ b/legacy/src/llms/chat.ts @@ -62,6 +62,12 @@ export abstract class BaseChatModel< } tools(tools: BaseTask[]): this { + if (!this.supportsTools) { + throw new Error( + `This Chat model "${this.nameForHuman}" does not support tools` + ) + } + this._tools = tools return this } @@ -70,6 +76,10 @@ export abstract class BaseChatModel< messages: types.ChatMessage[] ): Promise> + public get supportsTools(): boolean { + return false + } + public async buildMessages( input?: TInput, ctx?: types.TaskCallContext diff --git a/legacy/src/llms/openai.ts b/legacy/src/llms/openai.ts index 95735461..95781013 100644 --- a/legacy/src/llms/openai.ts +++ b/legacy/src/llms/openai.ts @@ -5,6 +5,13 @@ import { DEFAULT_OPENAI_MODEL } from '@/constants' import { BaseChatModel } from './chat' +const openaiModelsSupportingFunctions = new Set([ + 'gpt-4-0613', + 'gpt-4-32k-0613', + 'gpt-3.5-turbo-0613', + 'gpt-3.5-turbo-16k' +]) + export class OpenAIChatModel< TInput = any, TOutput = string @@ -23,9 +30,10 @@ export class OpenAIChatModel< SetOptional, 'model'> > ) { + const model = options.modelParams?.model || DEFAULT_OPENAI_MODEL super({ provider: 'openai', - model: options.modelParams?.model || DEFAULT_OPENAI_MODEL, + model, ...options }) @@ -36,6 +44,20 @@ export class OpenAIChatModel< 'OpenAIChatModel requires an OpenAI client to be configured on the Agentic runtime' ) } + + if (!this.supportsTools) { + if (this._tools) { + throw new Error( + `This OpenAI chat model "${this.nameForHuman}" does not support tools` + ) + } + + if (this._modelParams?.functions) { + throw new Error( + `This OpenAI chat model "${this.nameForHuman}" does not support functions` + ) + } + } } public override get nameForModel(): string { @@ -46,6 +68,10 @@ export class OpenAIChatModel< return 'OpenAIChatModel' } + public override get supportsTools(): boolean { + return openaiModelsSupportingFunctions.has(this._model) + } + protected override async _createChatCompletion( messages: types.ChatMessage[] ): Promise< diff --git a/legacy/test/services/openai.test.ts b/legacy/test/services/openai.test.ts index 9d84727f..b4222184 100644 --- a/legacy/test/services/openai.test.ts +++ b/legacy/test/services/openai.test.ts @@ -8,6 +8,7 @@ test('OpenAIClient - createChatCompletion - functions', async (t) => { const openai = createOpenAITestClient() const model = 'gpt-3.5-turbo-0613' + // const model = 'gpt-3.5-turbo-16k' const messages: types.ChatMessage[] = [ { role: 'user', @@ -47,7 +48,8 @@ test('OpenAIClient - createChatCompletion - functions', async (t) => { t.is(res0.message.function_call!.name, 'get_current_weather') const args = JSON.parse(res0.message.function_call!.arguments) - t.deepEqual(args, { location: 'Boston' }) + t.is(typeof args.location, 'string') + t.true(args.location.toLowerCase().includes('boston')) const weatherMock = { temperature: 22, unit: 'celsius', description: 'Sunny' }