kopia lustrzana https://github.com/transitive-bullshit/chatgpt-api
feat: add support for messages which are thunks
rodzic
1e4befcc1c
commit
dc30f6ed26
|
@ -9,7 +9,11 @@ import * as types from '@/types'
|
|||
import { parseOutput } from '@/llms/parse-output'
|
||||
import { BaseTask } from '@/task'
|
||||
import { getCompiledTemplate } from '@/template'
|
||||
import { extractFunctionIdentifierFromString, stringifyForModel } from '@/utils'
|
||||
import {
|
||||
extractFunctionIdentifierFromString,
|
||||
isFunction,
|
||||
stringifyForModel
|
||||
} from '@/utils'
|
||||
|
||||
import { BaseLLM } from './llm'
|
||||
import {
|
||||
|
@ -23,7 +27,7 @@ export abstract class BaseChatCompletion<
|
|||
TModelParams extends Record<string, any> = Record<string, any>,
|
||||
TChatCompletionResponse extends Record<string, any> = Record<string, any>
|
||||
> extends BaseLLM<TInput, TOutput, TModelParams> {
|
||||
protected _messages: types.ChatMessage[]
|
||||
protected _messages: types.ChatMessageInput<TInput>[]
|
||||
protected _tools?: BaseTask<any, any>[]
|
||||
|
||||
constructor(
|
||||
|
@ -116,14 +120,17 @@ export abstract class BaseChatCompletion<
|
|||
input = this.inputSchema.parse(input)
|
||||
}
|
||||
|
||||
const messages = this._messages
|
||||
const messages: types.ChatMessage[] = this._messages
|
||||
.map((message) => {
|
||||
return {
|
||||
...message,
|
||||
content: message.content
|
||||
? getCompiledTemplate(dedent(message.content))(input).trim()
|
||||
? // support functions which return a string
|
||||
isFunction(message.content)
|
||||
? message.content(input!)
|
||||
: getCompiledTemplate(dedent(message.content))(input).trim()
|
||||
: ''
|
||||
}
|
||||
} as types.ChatMessage
|
||||
})
|
||||
.filter((message) => message.content || message.function_call)
|
||||
|
||||
|
@ -170,6 +177,7 @@ export abstract class BaseChatCompletion<
|
|||
'additionalProperties',
|
||||
'$schema'
|
||||
])
|
||||
|
||||
let label: string
|
||||
if (outputSchema instanceof z.ZodArray) {
|
||||
label = 'JSON array (minified)'
|
||||
|
|
|
@ -14,7 +14,7 @@ const openaiModelsSupportingFunctions = new Set([
|
|||
])
|
||||
|
||||
export class OpenAIChatCompletion<
|
||||
TInput extends types.TaskInput = any,
|
||||
TInput extends types.TaskInput = void,
|
||||
TOutput extends types.TaskOutput = string
|
||||
> extends BaseChatCompletion<
|
||||
TInput,
|
||||
|
|
|
@ -156,7 +156,7 @@ export abstract class BaseTask<
|
|||
/**
|
||||
* Calls this task with the given `input` and returns the result only.
|
||||
*/
|
||||
public async call(input?: TInput): Promise<TOutput> {
|
||||
public async call(input: TInput): Promise<TOutput> {
|
||||
const res = await this.callWithMetadata(input)
|
||||
return res.result
|
||||
}
|
||||
|
@ -165,7 +165,7 @@ export abstract class BaseTask<
|
|||
* Calls this task with the given `input` and returns the result along with metadata.
|
||||
*/
|
||||
public async callWithMetadata(
|
||||
input?: TInput,
|
||||
input: TInput,
|
||||
parentCtx?: types.TaskCallContext<any>
|
||||
): Promise<types.TaskResponse<TOutput>> {
|
||||
this.validate()
|
||||
|
|
10
src/types.ts
10
src/types.ts
|
@ -73,12 +73,18 @@ export interface LLMOptions<
|
|||
export type ChatMessage = openai.ChatMessage
|
||||
export type ChatMessageRole = openai.ChatMessageRole
|
||||
|
||||
export type ChatMessageInput<TInput extends TaskInput = void> =
|
||||
| ChatMessage
|
||||
| {
|
||||
content: (input: TInput | any) => string
|
||||
}
|
||||
|
||||
export interface ChatModelOptions<
|
||||
TInput extends TaskInput = void,
|
||||
TOutput extends TaskOutput = string,
|
||||
TModelParams extends Record<string, any> = Record<string, any>
|
||||
> extends BaseLLMOptions<TInput, TOutput, TModelParams> {
|
||||
messages: ChatMessage[]
|
||||
messages: ChatMessageInput<TInput>[]
|
||||
tools?: BaseTask<any, any>[]
|
||||
}
|
||||
|
||||
|
@ -136,7 +142,7 @@ export interface TaskCallContext<
|
|||
TInput extends TaskInput = void,
|
||||
TMetadata extends TaskResponseMetadata = TaskResponseMetadata
|
||||
> {
|
||||
input?: TInput
|
||||
input: TInput
|
||||
retryMessage?: string
|
||||
|
||||
attemptNumber: number
|
||||
|
|
|
@ -210,3 +210,8 @@ export function throttleKy(
|
|||
}
|
||||
})
|
||||
}
|
||||
|
||||
// eslint-disable-next-line @typescript-eslint/ban-types
|
||||
export function isFunction(value: any): value is Function {
|
||||
return typeof value === 'function'
|
||||
}
|
||||
|
|
|
@ -224,3 +224,29 @@ test('OpenAIChatCompletion ⇒ missing template variable', async (t) => {
|
|||
message: 'Template error: "numFacts" not defined in input - 1:10'
|
||||
})
|
||||
})
|
||||
|
||||
test('OpenAIChatCompletion ⇒ function inputs', async (t) => {
|
||||
t.timeout(2 * 60 * 1000)
|
||||
const agentic = createTestAgenticRuntime()
|
||||
|
||||
const builder = new OpenAIChatCompletion({
|
||||
agentic,
|
||||
modelParams: {
|
||||
temperature: 0,
|
||||
max_tokens: 30
|
||||
},
|
||||
messages: [
|
||||
{
|
||||
role: 'user',
|
||||
content: (input) => `tell me about ${input.topic}`
|
||||
}
|
||||
]
|
||||
})
|
||||
.input(z.object({ topic: z.string() }))
|
||||
.output(z.string())
|
||||
|
||||
const result = await builder.call({ topic: 'cats' })
|
||||
t.truthy(typeof result === 'string')
|
||||
|
||||
expectTypeOf(result).toMatchTypeOf<string>()
|
||||
})
|
||||
|
|
Ładowanie…
Reference in New Issue