old-agentic-v1^2
Travis Fischer 2023-06-13 20:14:10 -07:00
rodzic 7afbc0d259
commit c9d6619855
3 zmienionych plików z 40 dodań i 2 usunięć

Wyświetl plik

@ -62,6 +62,12 @@ export abstract class BaseChatModel<
}
tools(tools: BaseTask<any, any>[]): this {
if (!this.supportsTools) {
throw new Error(
`This Chat model "${this.nameForHuman}" does not support tools`
)
}
this._tools = tools
return this
}
@ -70,6 +76,10 @@ export abstract class BaseChatModel<
messages: types.ChatMessage[]
): Promise<types.BaseChatCompletionResponse<TChatCompletionResponse>>
public get supportsTools(): boolean {
return false
}
public async buildMessages(
input?: TInput,
ctx?: types.TaskCallContext<TInput>

Wyświetl plik

@ -5,6 +5,13 @@ import { DEFAULT_OPENAI_MODEL } from '@/constants'
import { BaseChatModel } from './chat'
const openaiModelsSupportingFunctions = new Set([
'gpt-4-0613',
'gpt-4-32k-0613',
'gpt-3.5-turbo-0613',
'gpt-3.5-turbo-16k'
])
export class OpenAIChatModel<
TInput = any,
TOutput = string
@ -23,9 +30,10 @@ export class OpenAIChatModel<
SetOptional<Omit<types.openai.ChatCompletionParams, 'messages'>, 'model'>
>
) {
const model = options.modelParams?.model || DEFAULT_OPENAI_MODEL
super({
provider: 'openai',
model: options.modelParams?.model || DEFAULT_OPENAI_MODEL,
model,
...options
})
@ -36,6 +44,20 @@ export class OpenAIChatModel<
'OpenAIChatModel requires an OpenAI client to be configured on the Agentic runtime'
)
}
if (!this.supportsTools) {
if (this._tools) {
throw new Error(
`This OpenAI chat model "${this.nameForHuman}" does not support tools`
)
}
if (this._modelParams?.functions) {
throw new Error(
`This OpenAI chat model "${this.nameForHuman}" does not support functions`
)
}
}
}
public override get nameForModel(): string {
@ -46,6 +68,10 @@ export class OpenAIChatModel<
return 'OpenAIChatModel'
}
public override get supportsTools(): boolean {
return openaiModelsSupportingFunctions.has(this._model)
}
protected override async _createChatCompletion(
messages: types.ChatMessage[]
): Promise<

Wyświetl plik

@ -8,6 +8,7 @@ test('OpenAIClient - createChatCompletion - functions', async (t) => {
const openai = createOpenAITestClient()
const model = 'gpt-3.5-turbo-0613'
// const model = 'gpt-3.5-turbo-16k'
const messages: types.ChatMessage[] = [
{
role: 'user',
@ -47,7 +48,8 @@ test('OpenAIClient - createChatCompletion - functions', async (t) => {
t.is(res0.message.function_call!.name, 'get_current_weather')
const args = JSON.parse(res0.message.function_call!.arguments)
t.deepEqual(args, { location: 'Boston' })
t.is(typeof args.location, 'string')
t.true(args.location.toLowerCase().includes('boston'))
const weatherMock = { temperature: 22, unit: 'celsius', description: 'Sunny' }