feat: add support for messages which are thunks

old-agentic-v1^2
Travis Fischer 2023-06-19 12:47:32 -07:00
rodzic 1e4befcc1c
commit dc30f6ed26
6 zmienionych plików z 55 dodań i 10 usunięć

Wyświetl plik

@ -9,7 +9,11 @@ import * as types from '@/types'
import { parseOutput } from '@/llms/parse-output' import { parseOutput } from '@/llms/parse-output'
import { BaseTask } from '@/task' import { BaseTask } from '@/task'
import { getCompiledTemplate } from '@/template' import { getCompiledTemplate } from '@/template'
import { extractFunctionIdentifierFromString, stringifyForModel } from '@/utils' import {
extractFunctionIdentifierFromString,
isFunction,
stringifyForModel
} from '@/utils'
import { BaseLLM } from './llm' import { BaseLLM } from './llm'
import { import {
@ -23,7 +27,7 @@ export abstract class BaseChatCompletion<
TModelParams extends Record<string, any> = Record<string, any>, TModelParams extends Record<string, any> = Record<string, any>,
TChatCompletionResponse extends Record<string, any> = Record<string, any> TChatCompletionResponse extends Record<string, any> = Record<string, any>
> extends BaseLLM<TInput, TOutput, TModelParams> { > extends BaseLLM<TInput, TOutput, TModelParams> {
protected _messages: types.ChatMessage[] protected _messages: types.ChatMessageInput<TInput>[]
protected _tools?: BaseTask<any, any>[] protected _tools?: BaseTask<any, any>[]
constructor( constructor(
@ -116,14 +120,17 @@ export abstract class BaseChatCompletion<
input = this.inputSchema.parse(input) input = this.inputSchema.parse(input)
} }
const messages = this._messages const messages: types.ChatMessage[] = this._messages
.map((message) => { .map((message) => {
return { return {
...message, ...message,
content: message.content content: message.content
? getCompiledTemplate(dedent(message.content))(input).trim() ? // support functions which return a string
isFunction(message.content)
? message.content(input!)
: getCompiledTemplate(dedent(message.content))(input).trim()
: '' : ''
} } as types.ChatMessage
}) })
.filter((message) => message.content || message.function_call) .filter((message) => message.content || message.function_call)
@ -170,6 +177,7 @@ export abstract class BaseChatCompletion<
'additionalProperties', 'additionalProperties',
'$schema' '$schema'
]) ])
let label: string let label: string
if (outputSchema instanceof z.ZodArray) { if (outputSchema instanceof z.ZodArray) {
label = 'JSON array (minified)' label = 'JSON array (minified)'

Wyświetl plik

@ -14,7 +14,7 @@ const openaiModelsSupportingFunctions = new Set([
]) ])
export class OpenAIChatCompletion< export class OpenAIChatCompletion<
TInput extends types.TaskInput = any, TInput extends types.TaskInput = void,
TOutput extends types.TaskOutput = string TOutput extends types.TaskOutput = string
> extends BaseChatCompletion< > extends BaseChatCompletion<
TInput, TInput,

Wyświetl plik

@ -156,7 +156,7 @@ export abstract class BaseTask<
/** /**
* Calls this task with the given `input` and returns the result only. * Calls this task with the given `input` and returns the result only.
*/ */
public async call(input?: TInput): Promise<TOutput> { public async call(input: TInput): Promise<TOutput> {
const res = await this.callWithMetadata(input) const res = await this.callWithMetadata(input)
return res.result return res.result
} }
@ -165,7 +165,7 @@ export abstract class BaseTask<
* Calls this task with the given `input` and returns the result along with metadata. * Calls this task with the given `input` and returns the result along with metadata.
*/ */
public async callWithMetadata( public async callWithMetadata(
input?: TInput, input: TInput,
parentCtx?: types.TaskCallContext<any> parentCtx?: types.TaskCallContext<any>
): Promise<types.TaskResponse<TOutput>> { ): Promise<types.TaskResponse<TOutput>> {
this.validate() this.validate()

Wyświetl plik

@ -73,12 +73,18 @@ export interface LLMOptions<
export type ChatMessage = openai.ChatMessage export type ChatMessage = openai.ChatMessage
export type ChatMessageRole = openai.ChatMessageRole export type ChatMessageRole = openai.ChatMessageRole
export type ChatMessageInput<TInput extends TaskInput = void> =
| ChatMessage
| {
content: (input: TInput | any) => string
}
export interface ChatModelOptions< export interface ChatModelOptions<
TInput extends TaskInput = void, TInput extends TaskInput = void,
TOutput extends TaskOutput = string, TOutput extends TaskOutput = string,
TModelParams extends Record<string, any> = Record<string, any> TModelParams extends Record<string, any> = Record<string, any>
> extends BaseLLMOptions<TInput, TOutput, TModelParams> { > extends BaseLLMOptions<TInput, TOutput, TModelParams> {
messages: ChatMessage[] messages: ChatMessageInput<TInput>[]
tools?: BaseTask<any, any>[] tools?: BaseTask<any, any>[]
} }
@ -136,7 +142,7 @@ export interface TaskCallContext<
TInput extends TaskInput = void, TInput extends TaskInput = void,
TMetadata extends TaskResponseMetadata = TaskResponseMetadata TMetadata extends TaskResponseMetadata = TaskResponseMetadata
> { > {
input?: TInput input: TInput
retryMessage?: string retryMessage?: string
attemptNumber: number attemptNumber: number

Wyświetl plik

@ -210,3 +210,8 @@ export function throttleKy(
} }
}) })
} }
// eslint-disable-next-line @typescript-eslint/ban-types
export function isFunction(value: any): value is Function {
return typeof value === 'function'
}

Wyświetl plik

@ -224,3 +224,29 @@ test('OpenAIChatCompletion ⇒ missing template variable', async (t) => {
message: 'Template error: "numFacts" not defined in input - 1:10' message: 'Template error: "numFacts" not defined in input - 1:10'
}) })
}) })
test('OpenAIChatCompletion ⇒ function inputs', async (t) => {
t.timeout(2 * 60 * 1000)
const agentic = createTestAgenticRuntime()
const builder = new OpenAIChatCompletion({
agentic,
modelParams: {
temperature: 0,
max_tokens: 30
},
messages: [
{
role: 'user',
content: (input) => `tell me about ${input.topic}`
}
]
})
.input(z.object({ topic: z.string() }))
.output(z.string())
const result = await builder.call({ topic: 'cats' })
t.truthy(typeof result === 'string')
expectTypeOf(result).toMatchTypeOf<string>()
})