kopia lustrzana https://github.com/transitive-bullshit/chatgpt-api
old-agentic-v1^2
rodzic
8f4fde0260
commit
a4d89fcba6
|
@ -48,11 +48,7 @@ export class AnthropicChatModel<
|
||||||
}
|
}
|
||||||
|
|
||||||
public override get nameForModel(): string {
|
public override get nameForModel(): string {
|
||||||
return 'anthropic_chat'
|
return 'anthropicChatCompletion'
|
||||||
}
|
|
||||||
|
|
||||||
public override get nameForHuman(): string {
|
|
||||||
return 'AnthropicChatModel'
|
|
||||||
}
|
}
|
||||||
|
|
||||||
protected override async _createChatCompletion(
|
protected override async _createChatCompletion(
|
||||||
|
|
|
@ -6,15 +6,18 @@ import { printNode, zodToTs } from 'zod-to-ts'
|
||||||
|
|
||||||
import * as errors from '@/errors'
|
import * as errors from '@/errors'
|
||||||
import * as types from '@/types'
|
import * as types from '@/types'
|
||||||
|
import { BaseTask } from '@/task'
|
||||||
import { getCompiledTemplate } from '@/template'
|
import { getCompiledTemplate } from '@/template'
|
||||||
import {
|
import {
|
||||||
extractJSONArrayFromString,
|
extractJSONArrayFromString,
|
||||||
extractJSONObjectFromString
|
extractJSONObjectFromString
|
||||||
} from '@/utils'
|
} from '@/utils'
|
||||||
|
|
||||||
import { BaseTask } from '../task'
|
|
||||||
import { BaseLLM } from './llm'
|
import { BaseLLM } from './llm'
|
||||||
import { getNumTokensForChatMessages } from './llm-utils'
|
import {
|
||||||
|
getChatMessageFunctionDefinitionsFromTasks,
|
||||||
|
getNumTokensForChatMessages
|
||||||
|
} from './llm-utils'
|
||||||
|
|
||||||
export abstract class BaseChatModel<
|
export abstract class BaseChatModel<
|
||||||
TInput extends void | types.JsonObject = void,
|
TInput extends void | types.JsonObject = void,
|
||||||
|
@ -79,7 +82,8 @@ export abstract class BaseChatModel<
|
||||||
}
|
}
|
||||||
|
|
||||||
protected abstract _createChatCompletion(
|
protected abstract _createChatCompletion(
|
||||||
messages: types.ChatMessage[]
|
messages: types.ChatMessage[],
|
||||||
|
functions?: types.openai.ChatMessageFunction[]
|
||||||
): Promise<types.BaseChatCompletionResponse<TChatCompletionResponse>>
|
): Promise<types.BaseChatCompletionResponse<TChatCompletionResponse>>
|
||||||
|
|
||||||
public async buildMessages(
|
public async buildMessages(
|
||||||
|
@ -100,7 +104,7 @@ export abstract class BaseChatModel<
|
||||||
: ''
|
: ''
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
.filter((message) => message.content)
|
.filter((message) => message.content || message.function_call)
|
||||||
|
|
||||||
if (this._examples?.length) {
|
if (this._examples?.length) {
|
||||||
// TODO: smarter example selection
|
// TODO: smarter example selection
|
||||||
|
@ -161,9 +165,76 @@ export abstract class BaseChatModel<
|
||||||
console.log('>>>')
|
console.log('>>>')
|
||||||
console.log(messages)
|
console.log(messages)
|
||||||
|
|
||||||
const completion = await this._createChatCompletion(messages)
|
let functions = this._modelParams?.functions
|
||||||
|
let isUsingTools = false
|
||||||
|
|
||||||
|
if (this.supportsTools) {
|
||||||
|
if (this._tools?.length) {
|
||||||
|
if (functions?.length) {
|
||||||
|
throw new Error(`Cannot specify both tools and functions`)
|
||||||
|
}
|
||||||
|
|
||||||
|
functions = getChatMessageFunctionDefinitionsFromTasks(this._tools)
|
||||||
|
isUsingTools = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const completion = await this._createChatCompletion(messages, functions)
|
||||||
ctx.metadata.completion = completion
|
ctx.metadata.completion = completion
|
||||||
|
|
||||||
|
if (completion.message.function_call) {
|
||||||
|
const functionCall = completion.message.function_call
|
||||||
|
|
||||||
|
if (isUsingTools) {
|
||||||
|
const tool = this._tools!.find(
|
||||||
|
(tool) => tool.nameForModel === functionCall.name
|
||||||
|
)
|
||||||
|
|
||||||
|
if (!tool) {
|
||||||
|
throw new errors.OutputValidationError(
|
||||||
|
`Unrecognized function call "${functionCall.name}"`
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
let functionArguments: any
|
||||||
|
try {
|
||||||
|
functionArguments = JSON.parse(jsonrepair(functionCall.arguments))
|
||||||
|
} catch (err: any) {
|
||||||
|
if (err instanceof JSONRepairError) {
|
||||||
|
throw new errors.OutputValidationError(err.message, { cause: err })
|
||||||
|
} else if (err instanceof SyntaxError) {
|
||||||
|
throw new errors.OutputValidationError(
|
||||||
|
`Invalid JSON object: ${err.message}`,
|
||||||
|
{ cause: err }
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
throw err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: handle sub-task errors gracefully
|
||||||
|
const toolCallResponse = await tool.callWithMetadata(functionArguments)
|
||||||
|
|
||||||
|
// TODO: handle result as string or JSON
|
||||||
|
// TODO:
|
||||||
|
const taskCallContent = JSON.stringify(
|
||||||
|
toolCallResponse.result,
|
||||||
|
null,
|
||||||
|
1
|
||||||
|
).replaceAll(/\n ?/gm, ' ')
|
||||||
|
|
||||||
|
messages.push(completion.message)
|
||||||
|
messages.push({
|
||||||
|
role: 'function',
|
||||||
|
name: functionCall.name,
|
||||||
|
content: taskCallContent
|
||||||
|
})
|
||||||
|
|
||||||
|
// TODO: re-invoke completion with new messages
|
||||||
|
throw new Error('TODO')
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
let output: any = completion.message.content
|
let output: any = completion.message.content
|
||||||
|
|
||||||
console.log('===')
|
console.log('===')
|
||||||
|
@ -241,7 +312,7 @@ export abstract class BaseChatModel<
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: this doesn't bode well, batman...
|
// TODO: fix typescript issue here with recursive types
|
||||||
const safeResult = (outputSchema.safeParse as any)(output)
|
const safeResult = (outputSchema.safeParse as any)(output)
|
||||||
|
|
||||||
if (!safeResult.success) {
|
if (!safeResult.success) {
|
||||||
|
|
|
@ -78,7 +78,7 @@ export abstract class BaseLLM<
|
||||||
}
|
}
|
||||||
|
|
||||||
public override get nameForHuman(): string {
|
public override get nameForHuman(): string {
|
||||||
return `${this._provider}:chat:${this._model}`
|
return `${this.constructor.name} ${this._model}`
|
||||||
}
|
}
|
||||||
|
|
||||||
examples(examples: types.LLMExample[]): this {
|
examples(examples: types.LLMExample[]): this {
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
import { type SetOptional } from 'type-fest'
|
import type { SetOptional } from 'type-fest'
|
||||||
|
|
||||||
import * as types from '@/types'
|
import * as types from '@/types'
|
||||||
import { DEFAULT_OPENAI_MODEL } from '@/constants'
|
import { DEFAULT_OPENAI_MODEL } from '@/constants'
|
||||||
|
@ -61,11 +61,11 @@ export class OpenAIChatModel<
|
||||||
}
|
}
|
||||||
|
|
||||||
public override get nameForModel(): string {
|
public override get nameForModel(): string {
|
||||||
return 'openai_chat'
|
return 'openaiChatCompletion'
|
||||||
}
|
}
|
||||||
|
|
||||||
public override get nameForHuman(): string {
|
public override get nameForHuman(): string {
|
||||||
return 'OpenAIChatModel'
|
return `OpenAIChatModel ${this._model}`
|
||||||
}
|
}
|
||||||
|
|
||||||
public override get supportsTools(): boolean {
|
public override get supportsTools(): boolean {
|
||||||
|
@ -73,14 +73,16 @@ export class OpenAIChatModel<
|
||||||
}
|
}
|
||||||
|
|
||||||
protected override async _createChatCompletion(
|
protected override async _createChatCompletion(
|
||||||
messages: types.ChatMessage[]
|
messages: types.ChatMessage[],
|
||||||
|
functions?: types.openai.ChatMessageFunction[]
|
||||||
): Promise<
|
): Promise<
|
||||||
types.BaseChatCompletionResponse<types.openai.ChatCompletionResponse>
|
types.BaseChatCompletionResponse<types.openai.ChatCompletionResponse>
|
||||||
> {
|
> {
|
||||||
const res = await this._client.createChatCompletion({
|
const res = await this._client.createChatCompletion({
|
||||||
...this._modelParams,
|
...this._modelParams,
|
||||||
model: this._model,
|
model: this._model,
|
||||||
messages
|
messages,
|
||||||
|
functions
|
||||||
})
|
})
|
||||||
|
|
||||||
return res
|
return res
|
||||||
|
|
Ładowanie…
Reference in New Issue