feat: add support for messages which are thunks

old-agentic-v1^2
Travis Fischer 2023-06-19 12:47:32 -07:00
rodzic 1e4befcc1c
commit dc30f6ed26
6 zmienionych plików z 55 dodań i 10 usunięć

Wyświetl plik

@ -9,7 +9,11 @@ import * as types from '@/types'
import { parseOutput } from '@/llms/parse-output'
import { BaseTask } from '@/task'
import { getCompiledTemplate } from '@/template'
import { extractFunctionIdentifierFromString, stringifyForModel } from '@/utils'
import {
extractFunctionIdentifierFromString,
isFunction,
stringifyForModel
} from '@/utils'
import { BaseLLM } from './llm'
import {
@ -23,7 +27,7 @@ export abstract class BaseChatCompletion<
TModelParams extends Record<string, any> = Record<string, any>,
TChatCompletionResponse extends Record<string, any> = Record<string, any>
> extends BaseLLM<TInput, TOutput, TModelParams> {
protected _messages: types.ChatMessage[]
protected _messages: types.ChatMessageInput<TInput>[]
protected _tools?: BaseTask<any, any>[]
constructor(
@ -116,14 +120,17 @@ export abstract class BaseChatCompletion<
input = this.inputSchema.parse(input)
}
const messages = this._messages
const messages: types.ChatMessage[] = this._messages
.map((message) => {
return {
...message,
content: message.content
? getCompiledTemplate(dedent(message.content))(input).trim()
? // support functions which return a string
isFunction(message.content)
? message.content(input!)
: getCompiledTemplate(dedent(message.content))(input).trim()
: ''
}
} as types.ChatMessage
})
.filter((message) => message.content || message.function_call)
@ -170,6 +177,7 @@ export abstract class BaseChatCompletion<
'additionalProperties',
'$schema'
])
let label: string
if (outputSchema instanceof z.ZodArray) {
label = 'JSON array (minified)'

Wyświetl plik

@ -14,7 +14,7 @@ const openaiModelsSupportingFunctions = new Set([
])
export class OpenAIChatCompletion<
TInput extends types.TaskInput = any,
TInput extends types.TaskInput = void,
TOutput extends types.TaskOutput = string
> extends BaseChatCompletion<
TInput,

Wyświetl plik

@ -156,7 +156,7 @@ export abstract class BaseTask<
/**
* Calls this task with the given `input` and returns the result only.
*/
public async call(input?: TInput): Promise<TOutput> {
public async call(input: TInput): Promise<TOutput> {
const res = await this.callWithMetadata(input)
return res.result
}
@ -165,7 +165,7 @@ export abstract class BaseTask<
* Calls this task with the given `input` and returns the result along with metadata.
*/
public async callWithMetadata(
input?: TInput,
input: TInput,
parentCtx?: types.TaskCallContext<any>
): Promise<types.TaskResponse<TOutput>> {
this.validate()

Wyświetl plik

@ -73,12 +73,18 @@ export interface LLMOptions<
export type ChatMessage = openai.ChatMessage
export type ChatMessageRole = openai.ChatMessageRole
export type ChatMessageInput<TInput extends TaskInput = void> =
| ChatMessage
| {
content: (input: TInput | any) => string
}
export interface ChatModelOptions<
TInput extends TaskInput = void,
TOutput extends TaskOutput = string,
TModelParams extends Record<string, any> = Record<string, any>
> extends BaseLLMOptions<TInput, TOutput, TModelParams> {
messages: ChatMessage[]
messages: ChatMessageInput<TInput>[]
tools?: BaseTask<any, any>[]
}
@ -136,7 +142,7 @@ export interface TaskCallContext<
TInput extends TaskInput = void,
TMetadata extends TaskResponseMetadata = TaskResponseMetadata
> {
input?: TInput
input: TInput
retryMessage?: string
attemptNumber: number

Wyświetl plik

@ -210,3 +210,8 @@ export function throttleKy(
}
})
}
// eslint-disable-next-line @typescript-eslint/ban-types
export function isFunction(value: any): value is Function {
return typeof value === 'function'
}

Wyświetl plik

@ -224,3 +224,29 @@ test('OpenAIChatCompletion ⇒ missing template variable', async (t) => {
message: 'Template error: "numFacts" not defined in input - 1:10'
})
})
test('OpenAIChatCompletion ⇒ function inputs', async (t) => {
t.timeout(2 * 60 * 1000)
const agentic = createTestAgenticRuntime()
const builder = new OpenAIChatCompletion({
agentic,
modelParams: {
temperature: 0,
max_tokens: 30
},
messages: [
{
role: 'user',
content: (input) => `tell me about ${input.topic}`
}
]
})
.input(z.object({ topic: z.string() }))
.output(z.string())
const result = await builder.call({ topic: 'cats' })
t.truthy(typeof result === 'string')
expectTypeOf(result).toMatchTypeOf<string>()
})