kopia lustrzana https://github.com/transitive-bullshit/chatgpt-api
feat: add support for messages which are thunks
rodzic
1e4befcc1c
commit
dc30f6ed26
|
@ -9,7 +9,11 @@ import * as types from '@/types'
|
||||||
import { parseOutput } from '@/llms/parse-output'
|
import { parseOutput } from '@/llms/parse-output'
|
||||||
import { BaseTask } from '@/task'
|
import { BaseTask } from '@/task'
|
||||||
import { getCompiledTemplate } from '@/template'
|
import { getCompiledTemplate } from '@/template'
|
||||||
import { extractFunctionIdentifierFromString, stringifyForModel } from '@/utils'
|
import {
|
||||||
|
extractFunctionIdentifierFromString,
|
||||||
|
isFunction,
|
||||||
|
stringifyForModel
|
||||||
|
} from '@/utils'
|
||||||
|
|
||||||
import { BaseLLM } from './llm'
|
import { BaseLLM } from './llm'
|
||||||
import {
|
import {
|
||||||
|
@ -23,7 +27,7 @@ export abstract class BaseChatCompletion<
|
||||||
TModelParams extends Record<string, any> = Record<string, any>,
|
TModelParams extends Record<string, any> = Record<string, any>,
|
||||||
TChatCompletionResponse extends Record<string, any> = Record<string, any>
|
TChatCompletionResponse extends Record<string, any> = Record<string, any>
|
||||||
> extends BaseLLM<TInput, TOutput, TModelParams> {
|
> extends BaseLLM<TInput, TOutput, TModelParams> {
|
||||||
protected _messages: types.ChatMessage[]
|
protected _messages: types.ChatMessageInput<TInput>[]
|
||||||
protected _tools?: BaseTask<any, any>[]
|
protected _tools?: BaseTask<any, any>[]
|
||||||
|
|
||||||
constructor(
|
constructor(
|
||||||
|
@ -116,14 +120,17 @@ export abstract class BaseChatCompletion<
|
||||||
input = this.inputSchema.parse(input)
|
input = this.inputSchema.parse(input)
|
||||||
}
|
}
|
||||||
|
|
||||||
const messages = this._messages
|
const messages: types.ChatMessage[] = this._messages
|
||||||
.map((message) => {
|
.map((message) => {
|
||||||
return {
|
return {
|
||||||
...message,
|
...message,
|
||||||
content: message.content
|
content: message.content
|
||||||
? getCompiledTemplate(dedent(message.content))(input).trim()
|
? // support functions which return a string
|
||||||
|
isFunction(message.content)
|
||||||
|
? message.content(input!)
|
||||||
|
: getCompiledTemplate(dedent(message.content))(input).trim()
|
||||||
: ''
|
: ''
|
||||||
}
|
} as types.ChatMessage
|
||||||
})
|
})
|
||||||
.filter((message) => message.content || message.function_call)
|
.filter((message) => message.content || message.function_call)
|
||||||
|
|
||||||
|
@ -170,6 +177,7 @@ export abstract class BaseChatCompletion<
|
||||||
'additionalProperties',
|
'additionalProperties',
|
||||||
'$schema'
|
'$schema'
|
||||||
])
|
])
|
||||||
|
|
||||||
let label: string
|
let label: string
|
||||||
if (outputSchema instanceof z.ZodArray) {
|
if (outputSchema instanceof z.ZodArray) {
|
||||||
label = 'JSON array (minified)'
|
label = 'JSON array (minified)'
|
||||||
|
|
|
@ -14,7 +14,7 @@ const openaiModelsSupportingFunctions = new Set([
|
||||||
])
|
])
|
||||||
|
|
||||||
export class OpenAIChatCompletion<
|
export class OpenAIChatCompletion<
|
||||||
TInput extends types.TaskInput = any,
|
TInput extends types.TaskInput = void,
|
||||||
TOutput extends types.TaskOutput = string
|
TOutput extends types.TaskOutput = string
|
||||||
> extends BaseChatCompletion<
|
> extends BaseChatCompletion<
|
||||||
TInput,
|
TInput,
|
||||||
|
|
|
@ -156,7 +156,7 @@ export abstract class BaseTask<
|
||||||
/**
|
/**
|
||||||
* Calls this task with the given `input` and returns the result only.
|
* Calls this task with the given `input` and returns the result only.
|
||||||
*/
|
*/
|
||||||
public async call(input?: TInput): Promise<TOutput> {
|
public async call(input: TInput): Promise<TOutput> {
|
||||||
const res = await this.callWithMetadata(input)
|
const res = await this.callWithMetadata(input)
|
||||||
return res.result
|
return res.result
|
||||||
}
|
}
|
||||||
|
@ -165,7 +165,7 @@ export abstract class BaseTask<
|
||||||
* Calls this task with the given `input` and returns the result along with metadata.
|
* Calls this task with the given `input` and returns the result along with metadata.
|
||||||
*/
|
*/
|
||||||
public async callWithMetadata(
|
public async callWithMetadata(
|
||||||
input?: TInput,
|
input: TInput,
|
||||||
parentCtx?: types.TaskCallContext<any>
|
parentCtx?: types.TaskCallContext<any>
|
||||||
): Promise<types.TaskResponse<TOutput>> {
|
): Promise<types.TaskResponse<TOutput>> {
|
||||||
this.validate()
|
this.validate()
|
||||||
|
|
10
src/types.ts
10
src/types.ts
|
@ -73,12 +73,18 @@ export interface LLMOptions<
|
||||||
export type ChatMessage = openai.ChatMessage
|
export type ChatMessage = openai.ChatMessage
|
||||||
export type ChatMessageRole = openai.ChatMessageRole
|
export type ChatMessageRole = openai.ChatMessageRole
|
||||||
|
|
||||||
|
export type ChatMessageInput<TInput extends TaskInput = void> =
|
||||||
|
| ChatMessage
|
||||||
|
| {
|
||||||
|
content: (input: TInput | any) => string
|
||||||
|
}
|
||||||
|
|
||||||
export interface ChatModelOptions<
|
export interface ChatModelOptions<
|
||||||
TInput extends TaskInput = void,
|
TInput extends TaskInput = void,
|
||||||
TOutput extends TaskOutput = string,
|
TOutput extends TaskOutput = string,
|
||||||
TModelParams extends Record<string, any> = Record<string, any>
|
TModelParams extends Record<string, any> = Record<string, any>
|
||||||
> extends BaseLLMOptions<TInput, TOutput, TModelParams> {
|
> extends BaseLLMOptions<TInput, TOutput, TModelParams> {
|
||||||
messages: ChatMessage[]
|
messages: ChatMessageInput<TInput>[]
|
||||||
tools?: BaseTask<any, any>[]
|
tools?: BaseTask<any, any>[]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -136,7 +142,7 @@ export interface TaskCallContext<
|
||||||
TInput extends TaskInput = void,
|
TInput extends TaskInput = void,
|
||||||
TMetadata extends TaskResponseMetadata = TaskResponseMetadata
|
TMetadata extends TaskResponseMetadata = TaskResponseMetadata
|
||||||
> {
|
> {
|
||||||
input?: TInput
|
input: TInput
|
||||||
retryMessage?: string
|
retryMessage?: string
|
||||||
|
|
||||||
attemptNumber: number
|
attemptNumber: number
|
||||||
|
|
|
@ -210,3 +210,8 @@ export function throttleKy(
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// eslint-disable-next-line @typescript-eslint/ban-types
|
||||||
|
export function isFunction(value: any): value is Function {
|
||||||
|
return typeof value === 'function'
|
||||||
|
}
|
||||||
|
|
|
@ -224,3 +224,29 @@ test('OpenAIChatCompletion ⇒ missing template variable', async (t) => {
|
||||||
message: 'Template error: "numFacts" not defined in input - 1:10'
|
message: 'Template error: "numFacts" not defined in input - 1:10'
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
|
test('OpenAIChatCompletion ⇒ function inputs', async (t) => {
|
||||||
|
t.timeout(2 * 60 * 1000)
|
||||||
|
const agentic = createTestAgenticRuntime()
|
||||||
|
|
||||||
|
const builder = new OpenAIChatCompletion({
|
||||||
|
agentic,
|
||||||
|
modelParams: {
|
||||||
|
temperature: 0,
|
||||||
|
max_tokens: 30
|
||||||
|
},
|
||||||
|
messages: [
|
||||||
|
{
|
||||||
|
role: 'user',
|
||||||
|
content: (input) => `tell me about ${input.topic}`
|
||||||
|
}
|
||||||
|
]
|
||||||
|
})
|
||||||
|
.input(z.object({ topic: z.string() }))
|
||||||
|
.output(z.string())
|
||||||
|
|
||||||
|
const result = await builder.call({ topic: 'cats' })
|
||||||
|
t.truthy(typeof result === 'string')
|
||||||
|
|
||||||
|
expectTypeOf(result).toMatchTypeOf<string>()
|
||||||
|
})
|
||||||
|
|
Ładowanie…
Reference in New Issue