old-agentic-v1^2
Travis Fischer 2023-06-14 00:56:27 -07:00
rodzic 8f4fde0260
commit a4d89fcba6
4 zmienionych plików z 86 dodań i 17 usunięć

Wyświetl plik

@ -48,11 +48,7 @@ export class AnthropicChatModel<
}
public override get nameForModel(): string {
return 'anthropic_chat'
}
public override get nameForHuman(): string {
return 'AnthropicChatModel'
return 'anthropicChatCompletion'
}
protected override async _createChatCompletion(

Wyświetl plik

@ -6,15 +6,18 @@ import { printNode, zodToTs } from 'zod-to-ts'
import * as errors from '@/errors'
import * as types from '@/types'
import { BaseTask } from '@/task'
import { getCompiledTemplate } from '@/template'
import {
extractJSONArrayFromString,
extractJSONObjectFromString
} from '@/utils'
import { BaseTask } from '../task'
import { BaseLLM } from './llm'
import { getNumTokensForChatMessages } from './llm-utils'
import {
getChatMessageFunctionDefinitionsFromTasks,
getNumTokensForChatMessages
} from './llm-utils'
export abstract class BaseChatModel<
TInput extends void | types.JsonObject = void,
@ -79,7 +82,8 @@ export abstract class BaseChatModel<
}
protected abstract _createChatCompletion(
messages: types.ChatMessage[]
messages: types.ChatMessage[],
functions?: types.openai.ChatMessageFunction[]
): Promise<types.BaseChatCompletionResponse<TChatCompletionResponse>>
public async buildMessages(
@ -100,7 +104,7 @@ export abstract class BaseChatModel<
: ''
}
})
.filter((message) => message.content)
.filter((message) => message.content || message.function_call)
if (this._examples?.length) {
// TODO: smarter example selection
@ -161,9 +165,76 @@ export abstract class BaseChatModel<
console.log('>>>')
console.log(messages)
const completion = await this._createChatCompletion(messages)
let functions = this._modelParams?.functions
let isUsingTools = false
if (this.supportsTools) {
if (this._tools?.length) {
if (functions?.length) {
throw new Error(`Cannot specify both tools and functions`)
}
functions = getChatMessageFunctionDefinitionsFromTasks(this._tools)
isUsingTools = true
}
}
const completion = await this._createChatCompletion(messages, functions)
ctx.metadata.completion = completion
if (completion.message.function_call) {
const functionCall = completion.message.function_call
if (isUsingTools) {
const tool = this._tools!.find(
(tool) => tool.nameForModel === functionCall.name
)
if (!tool) {
throw new errors.OutputValidationError(
`Unrecognized function call "${functionCall.name}"`
)
}
let functionArguments: any
try {
functionArguments = JSON.parse(jsonrepair(functionCall.arguments))
} catch (err: any) {
if (err instanceof JSONRepairError) {
throw new errors.OutputValidationError(err.message, { cause: err })
} else if (err instanceof SyntaxError) {
throw new errors.OutputValidationError(
`Invalid JSON object: ${err.message}`,
{ cause: err }
)
} else {
throw err
}
}
// TODO: handle sub-task errors gracefully
const toolCallResponse = await tool.callWithMetadata(functionArguments)
// TODO: handle result as string or JSON
// TODO:
const taskCallContent = JSON.stringify(
toolCallResponse.result,
null,
1
).replaceAll(/\n ?/gm, ' ')
messages.push(completion.message)
messages.push({
role: 'function',
name: functionCall.name,
content: taskCallContent
})
// TODO: re-invoke completion with new messages
throw new Error('TODO')
}
}
let output: any = completion.message.content
console.log('===')
@ -241,7 +312,7 @@ export abstract class BaseChatModel<
}
}
// TODO: this doesn't bode well, batman...
// TODO: fix typescript issue here with recursive types
const safeResult = (outputSchema.safeParse as any)(output)
if (!safeResult.success) {

Wyświetl plik

@ -78,7 +78,7 @@ export abstract class BaseLLM<
}
public override get nameForHuman(): string {
return `${this._provider}:chat:${this._model}`
return `${this.constructor.name} ${this._model}`
}
examples(examples: types.LLMExample[]): this {

Wyświetl plik

@ -1,4 +1,4 @@
import { type SetOptional } from 'type-fest'
import type { SetOptional } from 'type-fest'
import * as types from '@/types'
import { DEFAULT_OPENAI_MODEL } from '@/constants'
@ -61,11 +61,11 @@ export class OpenAIChatModel<
}
public override get nameForModel(): string {
return 'openai_chat'
return 'openaiChatCompletion'
}
public override get nameForHuman(): string {
return 'OpenAIChatModel'
return `OpenAIChatModel ${this._model}`
}
public override get supportsTools(): boolean {
@ -73,14 +73,16 @@ export class OpenAIChatModel<
}
protected override async _createChatCompletion(
messages: types.ChatMessage[]
messages: types.ChatMessage[],
functions?: types.openai.ChatMessageFunction[]
): Promise<
types.BaseChatCompletionResponse<types.openai.ChatCompletionResponse>
> {
const res = await this._client.createChatCompletion({
...this._modelParams,
model: this._model,
messages
messages,
functions
})
return res