Mirror of https://github.com/transitive-bullshit/chatgpt-api
old-agentic-v1^2
parent
8f4fde0260
commit
a4d89fcba6
@@ -48,11 +48,7 @@ export class AnthropicChatModel<
   }
 
   public override get nameForModel(): string {
-    return 'anthropic_chat'
-  }
-
-  public override get nameForHuman(): string {
-    return 'AnthropicChatModel'
+    return 'anthropicChatCompletion'
   }
 
   protected override async _createChatCompletion(
@@ -6,15 +6,18 @@ import { printNode, zodToTs } from 'zod-to-ts'
 
+import * as errors from '@/errors'
 import * as types from '@/types'
+import { BaseTask } from '@/task'
 import { getCompiledTemplate } from '@/template'
 import {
   extractJSONArrayFromString,
   extractJSONObjectFromString
 } from '@/utils'
 
-import { BaseTask } from '../task'
 import { BaseLLM } from './llm'
-import { getNumTokensForChatMessages } from './llm-utils'
+import {
+  getChatMessageFunctionDefinitionsFromTasks,
+  getNumTokensForChatMessages
+} from './llm-utils'
 
 export abstract class BaseChatModel<
   TInput extends void | types.JsonObject = void,
@@ -79,7 +82,8 @@ export abstract class BaseChatModel<
   }
 
   protected abstract _createChatCompletion(
-    messages: types.ChatMessage[]
+    messages: types.ChatMessage[],
+    functions?: types.openai.ChatMessageFunction[]
   ): Promise<types.BaseChatCompletionResponse<TChatCompletionResponse>>
 
   public async buildMessages(
@@ -100,7 +104,7 @@ export abstract class BaseChatModel<
             : ''
         }
       })
-      .filter((message) => message.content)
+      .filter((message) => message.content || message.function_call)
 
     if (this._examples?.length) {
       // TODO: smarter example selection
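Side note on the `.filter` change above: with OpenAI-style function calling, an assistant message that requests a function call typically has `content: null` and carries a `function_call` object instead, so filtering on `content` alone would drop it from the rebuilt history. A minimal TypeScript illustration (the `search_web` name is hypothetical; the local `types.ChatMessage` shape is assumed to mirror the OpenAI chat API):

// Assumed message shape, mirroring the OpenAI chat completions API.
const assistantToolRequest = {
  role: 'assistant' as const,
  content: null, // no text when the model asks to call a function
  function_call: {
    name: 'search_web', // hypothetical function name
    arguments: '{"query": "latest TypeScript release"}'
  }
}

// The old predicate drops this message; the new one keeps it.
const keptBefore = Boolean(assistantToolRequest.content) // false
const keptAfter = Boolean(
  assistantToolRequest.content || assistantToolRequest.function_call
) // true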
@@ -161,9 +165,76 @@ export abstract class BaseChatModel<
     console.log('>>>')
     console.log(messages)
 
-    const completion = await this._createChatCompletion(messages)
+    let functions = this._modelParams?.functions
+    let isUsingTools = false
+
+    if (this.supportsTools) {
+      if (this._tools?.length) {
+        if (functions?.length) {
+          throw new Error(`Cannot specify both tools and functions`)
+        }
+
+        functions = getChatMessageFunctionDefinitionsFromTasks(this._tools)
+        isUsingTools = true
+      }
+    }
+
+    const completion = await this._createChatCompletion(messages, functions)
     ctx.metadata.completion = completion
 
+    if (completion.message.function_call) {
+      const functionCall = completion.message.function_call
+
+      if (isUsingTools) {
+        const tool = this._tools!.find(
+          (tool) => tool.nameForModel === functionCall.name
+        )
+
+        if (!tool) {
+          throw new errors.OutputValidationError(
+            `Unrecognized function call "${functionCall.name}"`
+          )
+        }
+
+        let functionArguments: any
+        try {
+          functionArguments = JSON.parse(jsonrepair(functionCall.arguments))
+        } catch (err: any) {
+          if (err instanceof JSONRepairError) {
+            throw new errors.OutputValidationError(err.message, { cause: err })
+          } else if (err instanceof SyntaxError) {
+            throw new errors.OutputValidationError(
+              `Invalid JSON object: ${err.message}`,
+              { cause: err }
+            )
+          } else {
+            throw err
+          }
+        }
+
+        // TODO: handle sub-task errors gracefully
+        const toolCallResponse = await tool.callWithMetadata(functionArguments)
+
+        // TODO: handle result as string or JSON
+        // TODO:
+        const taskCallContent = JSON.stringify(
+          toolCallResponse.result,
+          null,
+          1
+        ).replaceAll(/\n ?/gm, ' ')
+
+        messages.push(completion.message)
+        messages.push({
+          role: 'function',
+          name: functionCall.name,
+          content: taskCallContent
+        })
+
+        // TODO: re-invoke completion with new messages
+        throw new Error('TODO')
+      }
+    }
+
     let output: any = completion.message.content
 
     console.log('===')
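Taken together, the hunk above wires OpenAI-style function calling into BaseChatModel: registered tools are converted to function definitions, a returned `function_call` is matched back to a tool by `nameForModel`, its JSON arguments are repaired and parsed, the tool runs, and its result is appended as a `role: 'function'` message; the follow-up completion is still a TODO in this commit. Below is a minimal self-contained sketch of that loop, with hypothetical `Tool` and `callModel` stand-ins for the repo's task tools and `_createChatCompletion`; it is an illustration of the flow, not the repo's implementation.

import { jsonrepair } from 'jsonrepair'

// Hypothetical stand-ins for the repo's BaseTask tools and _createChatCompletion.
interface Tool {
  nameForModel: string
  call(args: unknown): Promise<unknown>
}

type Message = {
  role: 'system' | 'user' | 'assistant' | 'function'
  content: string | null
  name?: string
  function_call?: { name: string; arguments: string }
}

async function runWithTools(
  callModel: (messages: Message[]) => Promise<Message>,
  messages: Message[],
  tools: Tool[]
): Promise<Message> {
  for (;;) {
    const reply = await callModel(messages)
    const functionCall = reply.function_call
    if (!functionCall) return reply // plain text answer; done

    const tool = tools.find((t) => t.nameForModel === functionCall.name)
    if (!tool) {
      throw new Error(`Unrecognized function call "${functionCall.name}"`)
    }

    // Repair and parse the model-provided JSON arguments before invoking the tool.
    const args = JSON.parse(jsonrepair(functionCall.arguments))
    const result = await tool.call(args)

    // Feed the tool result back and loop; this is the re-invocation the
    // diff still marks as TODO.
    messages.push(reply)
    messages.push({
      role: 'function',
      name: functionCall.name,
      content: JSON.stringify(result)
    })
  }
}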
@@ -241,7 +312,7 @@ export abstract class BaseChatModel<
       }
     }
 
-    // TODO: this doesn't bode well, batman...
+    // TODO: fix typescript issue here with recursive types
     const safeResult = (outputSchema.safeParse as any)(output)
 
     if (!safeResult.success) {
@@ -78,7 +78,7 @@ export abstract class BaseLLM<
   }
 
   public override get nameForHuman(): string {
-    return `${this._provider}:chat:${this._model}`
+    return `${this.constructor.name} ${this._model}`
   }
 
   examples(examples: types.LLMExample[]): this {
@@ -1,4 +1,4 @@
-import { type SetOptional } from 'type-fest'
+import type { SetOptional } from 'type-fest'
 
 import * as types from '@/types'
 import { DEFAULT_OPENAI_MODEL } from '@/constants'
@@ -61,11 +61,11 @@ export class OpenAIChatModel<
   }
 
   public override get nameForModel(): string {
-    return 'openai_chat'
+    return 'openaiChatCompletion'
   }
 
   public override get nameForHuman(): string {
-    return 'OpenAIChatModel'
+    return `OpenAIChatModel ${this._model}`
   }
 
   public override get supportsTools(): boolean {
@@ -73,14 +73,16 @@ export class OpenAIChatModel<
   }
 
   protected override async _createChatCompletion(
-    messages: types.ChatMessage[]
+    messages: types.ChatMessage[],
+    functions?: types.openai.ChatMessageFunction[]
  ): Promise<
    types.BaseChatCompletionResponse<types.openai.ChatCompletionResponse>
  > {
    const res = await this._client.createChatCompletion({
      ...this._modelParams,
      model: this._model,
-     messages
+     messages,
+     functions
    })
 
    return res
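For context on the `functions` array now forwarded to `createChatCompletion`: in OpenAI's chat completions API, each function definition is a name, an optional description, and a JSON Schema describing the parameters. The exact output of `getChatMessageFunctionDefinitionsFromTasks` is not shown in this diff; the following is an assumed example of the general shape (the `search_web` tool is hypothetical).

// One entry of the `functions` array accepted by OpenAI's chat completions API.
// Assumed example only; the helper's real output is not part of this diff.
const exampleFunction: { name: string; description?: string; parameters: object } = {
  name: 'search_web', // should match the tool's nameForModel
  description: 'Search the web and return the top results',
  parameters: {
    type: 'object',
    properties: {
      query: { type: 'string', description: 'The search query' },
      numResults: { type: 'integer', minimum: 1, maximum: 10 }
    },
    required: ['query']
  }
}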