old-agentic-v1^2
Travis Fischer 2023-06-14 00:56:27 -07:00
rodzic 8f4fde0260
commit a4d89fcba6
4 zmienionych plików z 86 dodań i 17 usunięć

Wyświetl plik

@ -48,11 +48,7 @@ export class AnthropicChatModel<
} }
public override get nameForModel(): string { public override get nameForModel(): string {
return 'anthropic_chat' return 'anthropicChatCompletion'
}
public override get nameForHuman(): string {
return 'AnthropicChatModel'
} }
protected override async _createChatCompletion( protected override async _createChatCompletion(

Wyświetl plik

@ -6,15 +6,18 @@ import { printNode, zodToTs } from 'zod-to-ts'
import * as errors from '@/errors' import * as errors from '@/errors'
import * as types from '@/types' import * as types from '@/types'
import { BaseTask } from '@/task'
import { getCompiledTemplate } from '@/template' import { getCompiledTemplate } from '@/template'
import { import {
extractJSONArrayFromString, extractJSONArrayFromString,
extractJSONObjectFromString extractJSONObjectFromString
} from '@/utils' } from '@/utils'
import { BaseTask } from '../task'
import { BaseLLM } from './llm' import { BaseLLM } from './llm'
import { getNumTokensForChatMessages } from './llm-utils' import {
getChatMessageFunctionDefinitionsFromTasks,
getNumTokensForChatMessages
} from './llm-utils'
export abstract class BaseChatModel< export abstract class BaseChatModel<
TInput extends void | types.JsonObject = void, TInput extends void | types.JsonObject = void,
@ -79,7 +82,8 @@ export abstract class BaseChatModel<
} }
protected abstract _createChatCompletion( protected abstract _createChatCompletion(
messages: types.ChatMessage[] messages: types.ChatMessage[],
functions?: types.openai.ChatMessageFunction[]
): Promise<types.BaseChatCompletionResponse<TChatCompletionResponse>> ): Promise<types.BaseChatCompletionResponse<TChatCompletionResponse>>
public async buildMessages( public async buildMessages(
@ -100,7 +104,7 @@ export abstract class BaseChatModel<
: '' : ''
} }
}) })
.filter((message) => message.content) .filter((message) => message.content || message.function_call)
if (this._examples?.length) { if (this._examples?.length) {
// TODO: smarter example selection // TODO: smarter example selection
@ -161,9 +165,76 @@ export abstract class BaseChatModel<
console.log('>>>') console.log('>>>')
console.log(messages) console.log(messages)
const completion = await this._createChatCompletion(messages) let functions = this._modelParams?.functions
let isUsingTools = false
if (this.supportsTools) {
if (this._tools?.length) {
if (functions?.length) {
throw new Error(`Cannot specify both tools and functions`)
}
functions = getChatMessageFunctionDefinitionsFromTasks(this._tools)
isUsingTools = true
}
}
const completion = await this._createChatCompletion(messages, functions)
ctx.metadata.completion = completion ctx.metadata.completion = completion
if (completion.message.function_call) {
const functionCall = completion.message.function_call
if (isUsingTools) {
const tool = this._tools!.find(
(tool) => tool.nameForModel === functionCall.name
)
if (!tool) {
throw new errors.OutputValidationError(
`Unrecognized function call "${functionCall.name}"`
)
}
let functionArguments: any
try {
functionArguments = JSON.parse(jsonrepair(functionCall.arguments))
} catch (err: any) {
if (err instanceof JSONRepairError) {
throw new errors.OutputValidationError(err.message, { cause: err })
} else if (err instanceof SyntaxError) {
throw new errors.OutputValidationError(
`Invalid JSON object: ${err.message}`,
{ cause: err }
)
} else {
throw err
}
}
// TODO: handle sub-task errors gracefully
const toolCallResponse = await tool.callWithMetadata(functionArguments)
// TODO: handle result as string or JSON
// TODO:
const taskCallContent = JSON.stringify(
toolCallResponse.result,
null,
1
).replaceAll(/\n ?/gm, ' ')
messages.push(completion.message)
messages.push({
role: 'function',
name: functionCall.name,
content: taskCallContent
})
// TODO: re-invoke completion with new messages
throw new Error('TODO')
}
}
let output: any = completion.message.content let output: any = completion.message.content
console.log('===') console.log('===')
@ -241,7 +312,7 @@ export abstract class BaseChatModel<
} }
} }
// TODO: this doesn't bode well, batman... // TODO: fix typescript issue here with recursive types
const safeResult = (outputSchema.safeParse as any)(output) const safeResult = (outputSchema.safeParse as any)(output)
if (!safeResult.success) { if (!safeResult.success) {

Wyświetl plik

@ -78,7 +78,7 @@ export abstract class BaseLLM<
} }
public override get nameForHuman(): string { public override get nameForHuman(): string {
return `${this._provider}:chat:${this._model}` return `${this.constructor.name} ${this._model}`
} }
examples(examples: types.LLMExample[]): this { examples(examples: types.LLMExample[]): this {

Wyświetl plik

@ -1,4 +1,4 @@
import { type SetOptional } from 'type-fest' import type { SetOptional } from 'type-fest'
import * as types from '@/types' import * as types from '@/types'
import { DEFAULT_OPENAI_MODEL } from '@/constants' import { DEFAULT_OPENAI_MODEL } from '@/constants'
@ -61,11 +61,11 @@ export class OpenAIChatModel<
} }
public override get nameForModel(): string { public override get nameForModel(): string {
return 'openai_chat' return 'openaiChatCompletion'
} }
public override get nameForHuman(): string { public override get nameForHuman(): string {
return 'OpenAIChatModel' return `OpenAIChatModel ${this._model}`
} }
public override get supportsTools(): boolean { public override get supportsTools(): boolean {
@ -73,14 +73,16 @@ export class OpenAIChatModel<
} }
protected override async _createChatCompletion( protected override async _createChatCompletion(
messages: types.ChatMessage[] messages: types.ChatMessage[],
functions?: types.openai.ChatMessageFunction[]
): Promise< ): Promise<
types.BaseChatCompletionResponse<types.openai.ChatCompletionResponse> types.BaseChatCompletionResponse<types.openai.ChatCompletionResponse>
> { > {
const res = await this._client.createChatCompletion({ const res = await this._client.createChatCompletion({
...this._modelParams, ...this._modelParams,
model: this._model, model: this._model,
messages messages,
functions
}) })
return res return res