old-agentic-v1^2
Travis Fischer 2023-06-13 20:14:10 -07:00
rodzic 7afbc0d259
commit c9d6619855
3 zmienionych plików z 40 dodań i 2 usunięć

Wyświetl plik

@ -62,6 +62,12 @@ export abstract class BaseChatModel<
} }
tools(tools: BaseTask<any, any>[]): this { tools(tools: BaseTask<any, any>[]): this {
if (!this.supportsTools) {
throw new Error(
`This Chat model "${this.nameForHuman}" does not support tools`
)
}
this._tools = tools this._tools = tools
return this return this
} }
@ -70,6 +76,10 @@ export abstract class BaseChatModel<
messages: types.ChatMessage[] messages: types.ChatMessage[]
): Promise<types.BaseChatCompletionResponse<TChatCompletionResponse>> ): Promise<types.BaseChatCompletionResponse<TChatCompletionResponse>>
public get supportsTools(): boolean {
return false
}
public async buildMessages( public async buildMessages(
input?: TInput, input?: TInput,
ctx?: types.TaskCallContext<TInput> ctx?: types.TaskCallContext<TInput>

Wyświetl plik

@ -5,6 +5,13 @@ import { DEFAULT_OPENAI_MODEL } from '@/constants'
import { BaseChatModel } from './chat' import { BaseChatModel } from './chat'
const openaiModelsSupportingFunctions = new Set([
'gpt-4-0613',
'gpt-4-32k-0613',
'gpt-3.5-turbo-0613',
'gpt-3.5-turbo-16k'
])
export class OpenAIChatModel< export class OpenAIChatModel<
TInput = any, TInput = any,
TOutput = string TOutput = string
@ -23,9 +30,10 @@ export class OpenAIChatModel<
SetOptional<Omit<types.openai.ChatCompletionParams, 'messages'>, 'model'> SetOptional<Omit<types.openai.ChatCompletionParams, 'messages'>, 'model'>
> >
) { ) {
const model = options.modelParams?.model || DEFAULT_OPENAI_MODEL
super({ super({
provider: 'openai', provider: 'openai',
model: options.modelParams?.model || DEFAULT_OPENAI_MODEL, model,
...options ...options
}) })
@ -36,6 +44,20 @@ export class OpenAIChatModel<
'OpenAIChatModel requires an OpenAI client to be configured on the Agentic runtime' 'OpenAIChatModel requires an OpenAI client to be configured on the Agentic runtime'
) )
} }
if (!this.supportsTools) {
if (this._tools) {
throw new Error(
`This OpenAI chat model "${this.nameForHuman}" does not support tools`
)
}
if (this._modelParams?.functions) {
throw new Error(
`This OpenAI chat model "${this.nameForHuman}" does not support functions`
)
}
}
} }
public override get nameForModel(): string { public override get nameForModel(): string {
@ -46,6 +68,10 @@ export class OpenAIChatModel<
return 'OpenAIChatModel' return 'OpenAIChatModel'
} }
public override get supportsTools(): boolean {
return openaiModelsSupportingFunctions.has(this._model)
}
protected override async _createChatCompletion( protected override async _createChatCompletion(
messages: types.ChatMessage[] messages: types.ChatMessage[]
): Promise< ): Promise<

Wyświetl plik

@ -8,6 +8,7 @@ test('OpenAIClient - createChatCompletion - functions', async (t) => {
const openai = createOpenAITestClient() const openai = createOpenAITestClient()
const model = 'gpt-3.5-turbo-0613' const model = 'gpt-3.5-turbo-0613'
// const model = 'gpt-3.5-turbo-16k'
const messages: types.ChatMessage[] = [ const messages: types.ChatMessage[] = [
{ {
role: 'user', role: 'user',
@ -47,7 +48,8 @@ test('OpenAIClient - createChatCompletion - functions', async (t) => {
t.is(res0.message.function_call!.name, 'get_current_weather') t.is(res0.message.function_call!.name, 'get_current_weather')
const args = JSON.parse(res0.message.function_call!.arguments) const args = JSON.parse(res0.message.function_call!.arguments)
t.deepEqual(args, { location: 'Boston' }) t.is(typeof args.location, 'string')
t.true(args.location.toLowerCase().includes('boston'))
const weatherMock = { temperature: 22, unit: 'celsius', description: 'Sunny' } const weatherMock = { temperature: 22, unit: 'celsius', description: 'Sunny' }