feat: improve openai function/task/tool support

Travis Fischer 2023-06-13 21:39:19 -07:00
rodzic c9d6619855
commit 0ace3bb9b8
20 zmienionych plików z 184 dodań i 46 usunięć

Wyświetl plik

@ -0,0 +1,26 @@
# Snapshot report for `test/llms/llm-utils.test.ts`
The actual snapshot is saved in `llm-utils.test.ts.snap`.
Generated by [AVA](https://avajs.dev).
## getChatMessageFunctionDefinitionFromTask
> Snapshot 1
{
description: 'Useful for getting the result of a math expression. The input to this tool should be a valid mathematical expression that could be executed by a simple calculator.',
name: 'calculator',
parameters: {
properties: {
expression: {
description: 'mathematical expression to evaluate',
type: 'string',
},
},
required: [
'expression',
],
type: 'object',
},
}

Plik binarny nie jest wyświetlany.

Wyświetl plik

@ -1,5 +1,4 @@
import { ZodTypeAny } from 'zod' import * as types from '@/types'
import { Agentic } from '@/agentic' import { Agentic } from '@/agentic'
import { BaseTask } from '@/task' import { BaseTask } from '@/task'
@ -37,8 +36,8 @@ export class HumanFeedbackMechanismCLI extends HumanFeedbackMechanism {
} }
export function withHumanFeedback< export function withHumanFeedback<
TInput extends ZodTypeAny = ZodTypeAny, TInput extends void | types.JsonObject = void,
TOutput extends ZodTypeAny = ZodTypeAny TOutput extends types.JsonValue = string
>( >(
task: BaseTask<TInput, TOutput>, task: BaseTask<TInput, TOutput>,
options: HumanFeedbackOptions = { options: HumanFeedbackOptions = {

Wyświetl plik

@ -9,8 +9,8 @@ import { BaseChatModel } from './chat'
const defaultStopSequences = [anthropic.HUMAN_PROMPT] const defaultStopSequences = [anthropic.HUMAN_PROMPT]
export class AnthropicChatModel< export class AnthropicChatModel<
TInput = any, TInput extends void | types.JsonObject = any,
TOutput = string TOutput extends types.JsonValue = string
> extends BaseChatModel< > extends BaseChatModel<
TInput, TInput,
TOutput, TOutput,

Wyświetl plik

@ -3,7 +3,6 @@ import pMap from 'p-map'
import { dedent } from 'ts-dedent' import { dedent } from 'ts-dedent'
import { type SetRequired } from 'type-fest' import { type SetRequired } from 'type-fest'
import { ZodType, z } from 'zod' import { ZodType, z } from 'zod'
import { zodToJsonSchema } from 'zod-to-json-schema'
import { printNode, zodToTs } from 'zod-to-ts' import { printNode, zodToTs } from 'zod-to-ts'
import * as errors from '@/errors' import * as errors from '@/errors'
@ -19,8 +18,8 @@ import { BaseTask } from '../task'
import { BaseLLM } from './llm' import { BaseLLM } from './llm'
export abstract class BaseChatModel< export abstract class BaseChatModel<
TInput = void, TInput extends void | types.JsonObject = void,
TOutput = string, TOutput extends types.JsonValue = string,
TModelParams extends Record<string, any> = Record<string, any>, TModelParams extends Record<string, any> = Record<string, any>,
TChatCompletionResponse extends Record<string, any> = Record<string, any> TChatCompletionResponse extends Record<string, any> = Record<string, any>
> extends BaseLLM<TInput, TOutput, TModelParams> { > extends BaseLLM<TInput, TOutput, TModelParams> {
@ -40,7 +39,9 @@ export abstract class BaseChatModel<
} }
// TODO: use polymorphic `this` type to return correct BaseLLM subclass type // TODO: use polymorphic `this` type to return correct BaseLLM subclass type
input<U>(inputSchema: ZodType<U>): BaseChatModel<U, TOutput, TModelParams> { input<U extends void | types.JsonObject>(
inputSchema: ZodType<U>
): BaseChatModel<U, TOutput, TModelParams> {
const refinedInstance = this as unknown as BaseChatModel< const refinedInstance = this as unknown as BaseChatModel<
U, U,
TOutput, TOutput,
@ -51,7 +52,9 @@ export abstract class BaseChatModel<
} }
// TODO: use polymorphic `this` type to return correct BaseLLM subclass type // TODO: use polymorphic `this` type to return correct BaseLLM subclass type
output<U>(outputSchema: ZodType<U>): BaseChatModel<TInput, U, TModelParams> { output<U extends types.JsonValue>(
outputSchema: ZodType<U>
): BaseChatModel<TInput, U, TModelParams> {
const refinedInstance = this as unknown as BaseChatModel< const refinedInstance = this as unknown as BaseChatModel<
TInput, TInput,
U, U,
@ -72,14 +75,14 @@ export abstract class BaseChatModel<
return this return this
} }
protected abstract _createChatCompletion(
messages: types.ChatMessage[]
): Promise<types.BaseChatCompletionResponse<TChatCompletionResponse>>
public get supportsTools(): boolean { public get supportsTools(): boolean {
return false return false
} }
protected abstract _createChatCompletion(
messages: types.ChatMessage[]
): Promise<types.BaseChatCompletionResponse<TChatCompletionResponse>>
public async buildMessages( public async buildMessages(
input?: TInput, input?: TInput,
ctx?: types.TaskCallContext<TInput> ctx?: types.TaskCallContext<TInput>
@ -239,7 +242,8 @@ export abstract class BaseChatModel<
} }
} }
const safeResult = outputSchema.safeParse(output) // TODO: this doesn't bode well, batman...
const safeResult = (outputSchema.safeParse as any)(output)
if (!safeResult.success) { if (!safeResult.success) {
throw new errors.ZodOutputValidationError(safeResult.error) throw new errors.ZodOutputValidationError(safeResult.error)

Wyświetl plik

@ -1,4 +1,5 @@
export * from './llm' export * from './llm'
export * from './llm-utils'
export * from './chat' export * from './chat'
export * from './openai' export * from './openai'
export * from './anthropic' export * from './anthropic'

Wyświetl plik

@ -0,0 +1,38 @@
import { zodToJsonSchema } from 'zod-to-json-schema'
import * as types from '@/types'
import { BaseTask } from '@/task'
import { isValidTaskIdentifier } from '@/utils'
export function getChatMessageFunctionDefinitionFromTask(
task: BaseTask<any, any>
): types.openai.ChatMessageFunction {
const name = task.nameForModel
if (!isValidTaskIdentifier(name)) {
throw new Error(`Invalid task name "${name}"`)
}
const jsonSchema = zodToJsonSchema(task.inputSchema, {
name,
$refStrategy: 'none'
})
const parameters: any = jsonSchema.definitions?.[name]
if (parameters) {
if (parameters.additionalProperties === false) {
delete parameters['additionalProperties']
}
}
return {
name,
description: task.descForModel || task.nameForHuman,
parameters
}
}
export function getChatMessageFunctionDefinitionsFromTasks(
tasks: BaseTask<any, any>[]
): types.openai.ChatMessageFunction[] {
return tasks.map(getChatMessageFunctionDefinitionFromTask)
}

Wyświetl plik

@ -7,8 +7,8 @@ import { Tokenizer, getTokenizerForModel } from '@/tokenizer'
// TODO: TInput should only be allowed to be void or an object // TODO: TInput should only be allowed to be void or an object
export abstract class BaseLLM< export abstract class BaseLLM<
TInput = void, TInput extends void | types.JsonObject = void,
TOutput = string, TOutput extends types.JsonValue = string,
TModelParams extends Record<string, any> = Record<string, any> TModelParams extends Record<string, any> = Record<string, any>
> extends BaseTask<TInput, TOutput> { > extends BaseTask<TInput, TOutput> {
protected _inputSchema: ZodType<TInput> | undefined protected _inputSchema: ZodType<TInput> | undefined
@ -38,14 +38,18 @@ export abstract class BaseLLM<
} }
// TODO: use polymorphic `this` type to return correct BaseLLM subclass type // TODO: use polymorphic `this` type to return correct BaseLLM subclass type
input<U>(inputSchema: ZodType<U>): BaseLLM<U, TOutput, TModelParams> { input<U extends void | types.JsonObject>(
inputSchema: ZodType<U>
): BaseLLM<U, TOutput, TModelParams> {
const refinedInstance = this as unknown as BaseLLM<U, TOutput, TModelParams> const refinedInstance = this as unknown as BaseLLM<U, TOutput, TModelParams>
refinedInstance._inputSchema = inputSchema refinedInstance._inputSchema = inputSchema
return refinedInstance return refinedInstance
} }
// TODO: use polymorphic `this` type to return correct BaseLLM subclass type // TODO: use polymorphic `this` type to return correct BaseLLM subclass type
output<U>(outputSchema: ZodType<U>): BaseLLM<TInput, U, TModelParams> { output<U extends types.JsonValue>(
outputSchema: ZodType<U>
): BaseLLM<TInput, U, TModelParams> {
const refinedInstance = this as unknown as BaseLLM<TInput, U, TModelParams> const refinedInstance = this as unknown as BaseLLM<TInput, U, TModelParams>
refinedInstance._outputSchema = outputSchema refinedInstance._outputSchema = outputSchema
return refinedInstance return refinedInstance

Wyświetl plik

@ -13,8 +13,8 @@ const openaiModelsSupportingFunctions = new Set([
]) ])
export class OpenAIChatModel< export class OpenAIChatModel<
TInput = any, TInput extends void | types.JsonObject = any,
TOutput = string TOutput extends types.JsonValue = string
> extends BaseChatModel< > extends BaseChatModel<
TInput, TInput,
TOutput, TOutput,

Wyświetl plik

@ -18,7 +18,10 @@ import { Agentic } from '@/agentic'
* - Native function calls * - Native function calls
* - Invoking sub-agents * - Invoking sub-agents
*/ */
export abstract class BaseTask<TInput = void, TOutput = string> { export abstract class BaseTask<
TInput extends void | types.JsonObject = void,
TOutput extends types.JsonValue = string
> {
protected _agentic: Agentic protected _agentic: Agentic
protected _id: string protected _id: string
@ -26,6 +29,10 @@ export abstract class BaseTask<TInput = void, TOutput = string> {
protected _retryConfig: types.RetryConfig protected _retryConfig: types.RetryConfig
constructor(options: types.BaseTaskOptions) { constructor(options: types.BaseTaskOptions) {
if (!options.agentic) {
throw new Error('Passing "agentic" is required when creating a Task')
}
this._agentic = options.agentic this._agentic = options.agentic
this._timeoutMs = options.timeoutMs this._timeoutMs = options.timeoutMs
this._retryConfig = options.retryConfig ?? { this._retryConfig = options.retryConfig ?? {
@ -49,7 +56,7 @@ export abstract class BaseTask<TInput = void, TOutput = string> {
public abstract get nameForModel(): string public abstract get nameForModel(): string
public get nameForHuman(): string { public get nameForHuman(): string {
return this.nameForModel return this.constructor.name
} }
public get descForModel(): string { public get descForModel(): string {
@ -67,11 +74,17 @@ export abstract class BaseTask<TInput = void, TOutput = string> {
return this return this
} }
/**
* Calls this task with the given `input` and returns the result only.
*/
public async call(input?: TInput): Promise<TOutput> { public async call(input?: TInput): Promise<TOutput> {
const res = await this.callWithMetadata(input) const res = await this.callWithMetadata(input)
return res.result return res.result
} }
/**
* Calls this task with the given `input` and returns the result along with metadata.
*/
public async callWithMetadata( public async callWithMetadata(
input?: TInput input?: TInput
): Promise<types.TaskResponse<TOutput>> { ): Promise<types.TaskResponse<TOutput>> {
@ -126,6 +139,9 @@ export abstract class BaseTask<TInput = void, TOutput = string> {
} }
} }
/**
* Subclasses must implement the core `_call` logic for this task.
*/
protected abstract _call(ctx: types.TaskCallContext<TInput>): Promise<TOutput> protected abstract _call(ctx: types.TaskCallContext<TInput>): Promise<TOutput>
// TODO // TODO

Wyświetl plik

@ -4,7 +4,9 @@ import { z } from 'zod'
import * as types from '@/types' import * as types from '@/types'
import { BaseTask } from '@/task' import { BaseTask } from '@/task'
export const CalculatorInputSchema = z.string().describe('expression') export const CalculatorInputSchema = z.object({
expression: z.string().describe('mathematical expression to evaluate')
})
export const CalculatorOutputSchema = z export const CalculatorOutputSchema = z
.number() .number()
.describe('result of calculating the expression') .describe('result of calculating the expression')
@ -44,7 +46,7 @@ export class CalculatorTool extends BaseTask<
protected override async _call( protected override async _call(
ctx: types.TaskCallContext<CalculatorInput> ctx: types.TaskCallContext<CalculatorInput>
): Promise<CalculatorOutput> { ): Promise<CalculatorOutput> {
const result = Parser.evaluate(ctx.input!) const result = Parser.evaluate(ctx.input!.expression)
return result return result
} }
} }

Wyświetl plik

@ -7,7 +7,7 @@ import { BaseTask } from '@/task'
export const NovuNotificationToolInputSchema = z.object({ export const NovuNotificationToolInputSchema = z.object({
name: z.string(), name: z.string(),
payload: z.record(z.unknown()), payload: z.record(z.any()),
to: z.array( to: z.array(
z.object({ z.object({
subscriberId: z.string(), subscriberId: z.string(),

Wyświetl plik

@ -1,7 +1,7 @@
import * as openai from '@agentic/openai-fetch' import * as openai from '@agentic/openai-fetch'
import * as anthropic from '@anthropic-ai/sdk' import * as anthropic from '@anthropic-ai/sdk'
import type { Options as RetryOptions } from 'p-retry' import type { Options as RetryOptions } from 'p-retry'
import type { JsonObject } from 'type-fest' import type { JsonObject, JsonValue } from 'type-fest'
import { SafeParseReturnType, ZodType, ZodTypeAny, output, z } from 'zod' import { SafeParseReturnType, ZodType, ZodTypeAny, output, z } from 'zod'
import type { Agentic } from './agentic' import type { Agentic } from './agentic'
@ -9,6 +9,7 @@ import type { BaseTask } from './task'
export { openai } export { openai }
export { anthropic } export { anthropic }
export type { JsonObject, JsonValue }
export type ParsedData<T extends ZodTypeAny> = T extends ZodTypeAny export type ParsedData<T extends ZodTypeAny> = T extends ZodTypeAny
? output<T> ? output<T>
@ -32,8 +33,8 @@ export interface BaseTaskOptions {
} }
export interface BaseLLMOptions< export interface BaseLLMOptions<
TInput = void, TInput extends void | JsonObject = void,
TOutput = string, TOutput extends JsonValue = string,
TModelParams extends Record<string, any> = Record<string, any> TModelParams extends Record<string, any> = Record<string, any>
> extends BaseTaskOptions { > extends BaseTaskOptions {
inputSchema?: ZodType<TInput> inputSchema?: ZodType<TInput>
@ -46,8 +47,8 @@ export interface BaseLLMOptions<
} }
export interface LLMOptions< export interface LLMOptions<
TInput = void, TInput extends void | JsonObject = void,
TOutput = string, TOutput extends JsonValue = string,
TModelParams extends Record<string, any> = Record<string, any> TModelParams extends Record<string, any> = Record<string, any>
> extends BaseLLMOptions<TInput, TOutput, TModelParams> { > extends BaseLLMOptions<TInput, TOutput, TModelParams> {
promptTemplate?: string promptTemplate?: string
@ -59,8 +60,8 @@ export type ChatMessage = openai.ChatMessage
export type ChatMessageRole = openai.ChatMessageRole export type ChatMessageRole = openai.ChatMessageRole
export interface ChatModelOptions< export interface ChatModelOptions<
TInput = void, TInput extends void | JsonObject = void,
TOutput = string, TOutput extends JsonValue = string,
TModelParams extends Record<string, any> = Record<string, any> TModelParams extends Record<string, any> = Record<string, any>
> extends BaseLLMOptions<TInput, TOutput, TModelParams> { > extends BaseLLMOptions<TInput, TOutput, TModelParams> {
messages: ChatMessage[] messages: ChatMessage[]
@ -105,7 +106,7 @@ export interface LLMTaskResponseMetadata<
} }
export interface TaskResponse< export interface TaskResponse<
TOutput = string, TOutput extends JsonValue = string,
TMetadata extends TaskResponseMetadata = TaskResponseMetadata TMetadata extends TaskResponseMetadata = TaskResponseMetadata
> { > {
result: TOutput result: TOutput
@ -113,7 +114,7 @@ export interface TaskResponse<
} }
export interface TaskCallContext< export interface TaskCallContext<
TInput = void, TInput extends void | JsonObject = void,
TMetadata extends TaskResponseMetadata = TaskResponseMetadata TMetadata extends TaskResponseMetadata = TaskResponseMetadata
> { > {
input?: TInput input?: TInput

Wyświetl plik

@ -2,14 +2,22 @@ import { customAlphabet, urlAlphabet } from 'nanoid'
import * as types from './types' import * as types from './types'
export const extractJSONObjectFromString = (text: string): string | undefined => export function extractJSONObjectFromString(text: string): string | undefined {
text.match(/\{(.|\n)*\}/gm)?.[0] return text.match(/\{(.|\n)*\}/gm)?.[0]
}
export const extractJSONArrayFromString = (text: string): string | undefined => export function extractJSONArrayFromString(text: string): string | undefined {
text.match(/\[(.|\n)*\]/gm)?.[0] return text.match(/\[(.|\n)*\]/gm)?.[0]
}
export const sleep = (ms: number) => export function sleep(ms: number) {
new Promise((resolve) => setTimeout(resolve, ms)) return new Promise((resolve) => setTimeout(resolve, ms))
}
export const defaultIDGeneratorFn: types.IDGeneratorFunction = export const defaultIDGeneratorFn: types.IDGeneratorFunction =
customAlphabet(urlAlphabet) customAlphabet(urlAlphabet)
const taskNameRegex = /^[a-zA-Z_][a-zA-Z0-9_-]{0,63}$/
export function isValidTaskIdentifier(id: string): boolean {
return !!id && taskNameRegex.test(id)
}

Wyświetl plik

@ -9,11 +9,11 @@ test('CalculatorTool', async (t) => {
const agentic = createTestAgenticRuntime() const agentic = createTestAgenticRuntime()
const tool = new CalculatorTool({ agentic }) const tool = new CalculatorTool({ agentic })
const res = await tool.call('1 + 1') const res = await tool.call({ expression: '1 + 1' })
t.is(res, 2) t.is(res, 2)
expectTypeOf(res).toMatchTypeOf<number>() expectTypeOf(res).toMatchTypeOf<number>()
const res2 = await tool.callWithMetadata('cos(0)') const res2 = await tool.callWithMetadata({ expression: 'cos(0)' })
t.is(res2.result, 1) t.is(res2.result, 1)
expectTypeOf(res2.result).toMatchTypeOf<number>() expectTypeOf(res2.result).toMatchTypeOf<number>()

Wyświetl plik

@ -3,7 +3,7 @@ import { expectTypeOf } from 'expect-type'
import { AnthropicChatModel } from '@/llms/anthropic' import { AnthropicChatModel } from '@/llms/anthropic'
import { createTestAgenticRuntime } from './_utils' import { createTestAgenticRuntime } from '../_utils'
test('AnthropicChatModel ⇒ string output', async (t) => { test('AnthropicChatModel ⇒ string output', async (t) => {
t.timeout(2 * 60 * 1000) t.timeout(2 * 60 * 1000)

Wyświetl plik

@ -0,0 +1,19 @@
import test from 'ava'
import { getChatMessageFunctionDefinitionFromTask } from '@/llms/llm-utils'
import { CalculatorTool } from '@/tools/calculator'
import { createTestAgenticRuntime } from '../_utils'
test('getChatMessageFunctionDefinitionFromTask', async (t) => {
const agentic = createTestAgenticRuntime()
const tool = new CalculatorTool({ agentic })
const functionDefinition = getChatMessageFunctionDefinitionFromTask(tool)
t.is(functionDefinition.name, 'calculator')
t.is(functionDefinition.description, tool.descForModel)
console.log(JSON.stringify(functionDefinition, null, 2))
t.snapshot(functionDefinition)
})

Wyświetl plik

@ -6,7 +6,7 @@ import { z } from 'zod'
import { OutputValidationError, TemplateValidationError } from '@/errors' import { OutputValidationError, TemplateValidationError } from '@/errors'
import { BaseChatModel, OpenAIChatModel } from '@/llms' import { BaseChatModel, OpenAIChatModel } from '@/llms'
import { createTestAgenticRuntime } from './_utils' import { createTestAgenticRuntime } from '../_utils'
test('OpenAIChatModel ⇒ types', async (t) => { test('OpenAIChatModel ⇒ types', async (t) => {
const agentic = createTestAgenticRuntime() const agentic = createTestAgenticRuntime()

20
test/utils.test.ts vendored 100644
Wyświetl plik

@ -0,0 +1,20 @@
import test from 'ava'
import { isValidTaskIdentifier } from '@/utils'
test('isValidTaskIdentifier - valid', async (t) => {
t.true(isValidTaskIdentifier('foo'))
t.true(isValidTaskIdentifier('foo_bar_179'))
t.true(isValidTaskIdentifier('fooBarBAZ'))
t.true(isValidTaskIdentifier('foo-bar-baz_'))
t.true(isValidTaskIdentifier('_'))
t.true(isValidTaskIdentifier('_foo___'))
})
test('isValidTaskIdentifier - invalid', async (t) => {
t.false(isValidTaskIdentifier(null as any))
t.false(isValidTaskIdentifier(''))
t.false(isValidTaskIdentifier('-'))
t.false(isValidTaskIdentifier('x'.repeat(65)))
t.false(isValidTaskIdentifier('-foo'))
})