From f668be17811e1af60900d0263f1eeca6c13867e7 Mon Sep 17 00:00:00 2001 From: Travis Fischer Date: Tue, 13 Jun 2023 21:39:19 -0700 Subject: [PATCH] feat: improve openai function/task/tool support --- .../.snapshots/test/llms/llm-utils.test.ts.md | 26 ++++++++++++ .../test/llms/llm-utils.test.ts.snap | Bin 0 -> 495 bytes legacy/src/human-feedback.ts | 7 ++-- legacy/src/llms/anthropic.ts | 4 +- legacy/src/llms/chat.ts | 24 ++++++----- legacy/src/llms/index.ts | 1 + legacy/src/llms/llm-utils.ts | 38 ++++++++++++++++++ legacy/src/llms/llm.ts | 12 ++++-- legacy/src/llms/openai.ts | 4 +- legacy/src/task.ts | 20 ++++++++- legacy/src/tools/calculator.ts | 6 ++- legacy/src/tools/novu.ts | 2 +- legacy/src/types.ts | 19 ++++----- legacy/src/utils.ts | 20 ++++++--- legacy/test/calculator.test.ts | 4 +- legacy/test/{ => llms}/anthropic.test.ts | 2 +- legacy/test/llms/llm-utils.test.ts | 19 +++++++++ legacy/test/{ => llms}/openai.test.ts | 2 +- .../{openai.test.ts => openai-fetch.test.ts} | 0 legacy/test/utils.test.ts | 20 +++++++++ 20 files changed, 184 insertions(+), 46 deletions(-) create mode 100644 legacy/.snapshots/test/llms/llm-utils.test.ts.md create mode 100644 legacy/.snapshots/test/llms/llm-utils.test.ts.snap create mode 100644 legacy/src/llms/llm-utils.ts rename legacy/test/{ => llms}/anthropic.test.ts (93%) create mode 100644 legacy/test/llms/llm-utils.test.ts rename legacy/test/{ => llms}/openai.test.ts (99%) rename legacy/test/services/{openai.test.ts => openai-fetch.test.ts} (100%) create mode 100644 legacy/test/utils.test.ts diff --git a/legacy/.snapshots/test/llms/llm-utils.test.ts.md b/legacy/.snapshots/test/llms/llm-utils.test.ts.md new file mode 100644 index 00000000..94402b40 --- /dev/null +++ b/legacy/.snapshots/test/llms/llm-utils.test.ts.md @@ -0,0 +1,26 @@ +# Snapshot report for `test/llms/llm-utils.test.ts` + +The actual snapshot is saved in `llm-utils.test.ts.snap`. + +Generated by [AVA](https://avajs.dev). + +## getChatMessageFunctionDefinitionFromTask + +> Snapshot 1 + + { + description: 'Useful for getting the result of a math expression. The input to this tool should be a valid mathematical expression that could be executed by a simple calculator.', + name: 'calculator', + parameters: { + properties: { + expression: { + description: 'mathematical expression to evaluate', + type: 'string', + }, + }, + required: [ + 'expression', + ], + type: 'object', + }, + } diff --git a/legacy/.snapshots/test/llms/llm-utils.test.ts.snap b/legacy/.snapshots/test/llms/llm-utils.test.ts.snap new file mode 100644 index 0000000000000000000000000000000000000000..95ea163e34a340d66cfd1fe71f1b4efebdf5dfce GIT binary patch literal 495 zcmVS&n;l`sNwX+NPVS#p~Siue;nJ* literal 0 HcmV?d00001 diff --git a/legacy/src/human-feedback.ts b/legacy/src/human-feedback.ts index 2a71ab30..40c1d5a3 100644 --- a/legacy/src/human-feedback.ts +++ b/legacy/src/human-feedback.ts @@ -1,5 +1,4 @@ -import { ZodTypeAny } from 'zod' - +import * as types from '@/types' import { Agentic } from '@/agentic' import { BaseTask } from '@/task' @@ -37,8 +36,8 @@ export class HumanFeedbackMechanismCLI extends HumanFeedbackMechanism { } export function withHumanFeedback< - TInput extends ZodTypeAny = ZodTypeAny, - TOutput extends ZodTypeAny = ZodTypeAny + TInput extends void | types.JsonObject = void, + TOutput extends types.JsonValue = string >( task: BaseTask, options: HumanFeedbackOptions = { diff --git a/legacy/src/llms/anthropic.ts b/legacy/src/llms/anthropic.ts index f29b17e9..e7da9fa1 100644 --- a/legacy/src/llms/anthropic.ts +++ b/legacy/src/llms/anthropic.ts @@ -9,8 +9,8 @@ import { BaseChatModel } from './chat' const defaultStopSequences = [anthropic.HUMAN_PROMPT] export class AnthropicChatModel< - TInput = any, - TOutput = string + TInput extends void | types.JsonObject = any, + TOutput extends types.JsonValue = string > extends BaseChatModel< TInput, TOutput, diff --git a/legacy/src/llms/chat.ts b/legacy/src/llms/chat.ts index edeb6686..5127f202 100644 --- a/legacy/src/llms/chat.ts +++ b/legacy/src/llms/chat.ts @@ -3,7 +3,6 @@ import pMap from 'p-map' import { dedent } from 'ts-dedent' import { type SetRequired } from 'type-fest' import { ZodType, z } from 'zod' -import { zodToJsonSchema } from 'zod-to-json-schema' import { printNode, zodToTs } from 'zod-to-ts' import * as errors from '@/errors' @@ -19,8 +18,8 @@ import { BaseTask } from '../task' import { BaseLLM } from './llm' export abstract class BaseChatModel< - TInput = void, - TOutput = string, + TInput extends void | types.JsonObject = void, + TOutput extends types.JsonValue = string, TModelParams extends Record = Record, TChatCompletionResponse extends Record = Record > extends BaseLLM { @@ -40,7 +39,9 @@ export abstract class BaseChatModel< } // TODO: use polymorphic `this` type to return correct BaseLLM subclass type - input(inputSchema: ZodType): BaseChatModel { + input( + inputSchema: ZodType + ): BaseChatModel { const refinedInstance = this as unknown as BaseChatModel< U, TOutput, @@ -51,7 +52,9 @@ export abstract class BaseChatModel< } // TODO: use polymorphic `this` type to return correct BaseLLM subclass type - output(outputSchema: ZodType): BaseChatModel { + output( + outputSchema: ZodType + ): BaseChatModel { const refinedInstance = this as unknown as BaseChatModel< TInput, U, @@ -72,14 +75,14 @@ export abstract class BaseChatModel< return this } - protected abstract _createChatCompletion( - messages: types.ChatMessage[] - ): Promise> - public get supportsTools(): boolean { return false } + protected abstract _createChatCompletion( + messages: types.ChatMessage[] + ): Promise> + public async buildMessages( input?: TInput, ctx?: types.TaskCallContext @@ -239,7 +242,8 @@ export abstract class BaseChatModel< } } - const safeResult = outputSchema.safeParse(output) + // TODO: this doesn't bode well, batman... + const safeResult = (outputSchema.safeParse as any)(output) if (!safeResult.success) { throw new errors.ZodOutputValidationError(safeResult.error) diff --git a/legacy/src/llms/index.ts b/legacy/src/llms/index.ts index 0b742fe2..b3409acb 100644 --- a/legacy/src/llms/index.ts +++ b/legacy/src/llms/index.ts @@ -1,4 +1,5 @@ export * from './llm' +export * from './llm-utils' export * from './chat' export * from './openai' export * from './anthropic' diff --git a/legacy/src/llms/llm-utils.ts b/legacy/src/llms/llm-utils.ts new file mode 100644 index 00000000..25b0c4be --- /dev/null +++ b/legacy/src/llms/llm-utils.ts @@ -0,0 +1,38 @@ +import { zodToJsonSchema } from 'zod-to-json-schema' + +import * as types from '@/types' +import { BaseTask } from '@/task' +import { isValidTaskIdentifier } from '@/utils' + +export function getChatMessageFunctionDefinitionFromTask( + task: BaseTask +): types.openai.ChatMessageFunction { + const name = task.nameForModel + if (!isValidTaskIdentifier(name)) { + throw new Error(`Invalid task name "${name}"`) + } + + const jsonSchema = zodToJsonSchema(task.inputSchema, { + name, + $refStrategy: 'none' + }) + + const parameters: any = jsonSchema.definitions?.[name] + if (parameters) { + if (parameters.additionalProperties === false) { + delete parameters['additionalProperties'] + } + } + + return { + name, + description: task.descForModel || task.nameForHuman, + parameters + } +} + +export function getChatMessageFunctionDefinitionsFromTasks( + tasks: BaseTask[] +): types.openai.ChatMessageFunction[] { + return tasks.map(getChatMessageFunctionDefinitionFromTask) +} diff --git a/legacy/src/llms/llm.ts b/legacy/src/llms/llm.ts index 932d7c94..98e3c27b 100644 --- a/legacy/src/llms/llm.ts +++ b/legacy/src/llms/llm.ts @@ -7,8 +7,8 @@ import { Tokenizer, getTokenizerForModel } from '@/tokenizer' // TODO: TInput should only be allowed to be void or an object export abstract class BaseLLM< - TInput = void, - TOutput = string, + TInput extends void | types.JsonObject = void, + TOutput extends types.JsonValue = string, TModelParams extends Record = Record > extends BaseTask { protected _inputSchema: ZodType | undefined @@ -38,14 +38,18 @@ export abstract class BaseLLM< } // TODO: use polymorphic `this` type to return correct BaseLLM subclass type - input(inputSchema: ZodType): BaseLLM { + input( + inputSchema: ZodType + ): BaseLLM { const refinedInstance = this as unknown as BaseLLM refinedInstance._inputSchema = inputSchema return refinedInstance } // TODO: use polymorphic `this` type to return correct BaseLLM subclass type - output(outputSchema: ZodType): BaseLLM { + output( + outputSchema: ZodType + ): BaseLLM { const refinedInstance = this as unknown as BaseLLM refinedInstance._outputSchema = outputSchema return refinedInstance diff --git a/legacy/src/llms/openai.ts b/legacy/src/llms/openai.ts index 95781013..e26cd294 100644 --- a/legacy/src/llms/openai.ts +++ b/legacy/src/llms/openai.ts @@ -13,8 +13,8 @@ const openaiModelsSupportingFunctions = new Set([ ]) export class OpenAIChatModel< - TInput = any, - TOutput = string + TInput extends void | types.JsonObject = any, + TOutput extends types.JsonValue = string > extends BaseChatModel< TInput, TOutput, diff --git a/legacy/src/task.ts b/legacy/src/task.ts index 44ea49c5..e6f504a0 100644 --- a/legacy/src/task.ts +++ b/legacy/src/task.ts @@ -18,7 +18,10 @@ import { Agentic } from '@/agentic' * - Native function calls * - Invoking sub-agents */ -export abstract class BaseTask { +export abstract class BaseTask< + TInput extends void | types.JsonObject = void, + TOutput extends types.JsonValue = string +> { protected _agentic: Agentic protected _id: string @@ -26,6 +29,10 @@ export abstract class BaseTask { protected _retryConfig: types.RetryConfig constructor(options: types.BaseTaskOptions) { + if (!options.agentic) { + throw new Error('Passing "agentic" is required when creating a Task') + } + this._agentic = options.agentic this._timeoutMs = options.timeoutMs this._retryConfig = options.retryConfig ?? { @@ -49,7 +56,7 @@ export abstract class BaseTask { public abstract get nameForModel(): string public get nameForHuman(): string { - return this.nameForModel + return this.constructor.name } public get descForModel(): string { @@ -67,11 +74,17 @@ export abstract class BaseTask { return this } + /** + * Calls this task with the given `input` and returns the result only. + */ public async call(input?: TInput): Promise { const res = await this.callWithMetadata(input) return res.result } + /** + * Calls this task with the given `input` and returns the result along with metadata. + */ public async callWithMetadata( input?: TInput ): Promise> { @@ -126,6 +139,9 @@ export abstract class BaseTask { } } + /** + * Subclasses must implement the core `_call` logic for this task. + */ protected abstract _call(ctx: types.TaskCallContext): Promise // TODO diff --git a/legacy/src/tools/calculator.ts b/legacy/src/tools/calculator.ts index fe67b017..6850374d 100644 --- a/legacy/src/tools/calculator.ts +++ b/legacy/src/tools/calculator.ts @@ -4,7 +4,9 @@ import { z } from 'zod' import * as types from '@/types' import { BaseTask } from '@/task' -export const CalculatorInputSchema = z.string().describe('expression') +export const CalculatorInputSchema = z.object({ + expression: z.string().describe('mathematical expression to evaluate') +}) export const CalculatorOutputSchema = z .number() .describe('result of calculating the expression') @@ -44,7 +46,7 @@ export class CalculatorTool extends BaseTask< protected override async _call( ctx: types.TaskCallContext ): Promise { - const result = Parser.evaluate(ctx.input!) + const result = Parser.evaluate(ctx.input!.expression) return result } } diff --git a/legacy/src/tools/novu.ts b/legacy/src/tools/novu.ts index 5f218da3..2fec0283 100644 --- a/legacy/src/tools/novu.ts +++ b/legacy/src/tools/novu.ts @@ -7,7 +7,7 @@ import { BaseTask } from '@/task' export const NovuNotificationToolInputSchema = z.object({ name: z.string(), - payload: z.record(z.unknown()), + payload: z.record(z.any()), to: z.array( z.object({ subscriberId: z.string(), diff --git a/legacy/src/types.ts b/legacy/src/types.ts index e2991eb4..37c16bbf 100644 --- a/legacy/src/types.ts +++ b/legacy/src/types.ts @@ -1,7 +1,7 @@ import * as openai from '@agentic/openai-fetch' import * as anthropic from '@anthropic-ai/sdk' import type { Options as RetryOptions } from 'p-retry' -import type { JsonObject } from 'type-fest' +import type { JsonObject, JsonValue } from 'type-fest' import { SafeParseReturnType, ZodType, ZodTypeAny, output, z } from 'zod' import type { Agentic } from './agentic' @@ -9,6 +9,7 @@ import type { BaseTask } from './task' export { openai } export { anthropic } +export type { JsonObject, JsonValue } export type ParsedData = T extends ZodTypeAny ? output @@ -32,8 +33,8 @@ export interface BaseTaskOptions { } export interface BaseLLMOptions< - TInput = void, - TOutput = string, + TInput extends void | JsonObject = void, + TOutput extends JsonValue = string, TModelParams extends Record = Record > extends BaseTaskOptions { inputSchema?: ZodType @@ -46,8 +47,8 @@ export interface BaseLLMOptions< } export interface LLMOptions< - TInput = void, - TOutput = string, + TInput extends void | JsonObject = void, + TOutput extends JsonValue = string, TModelParams extends Record = Record > extends BaseLLMOptions { promptTemplate?: string @@ -59,8 +60,8 @@ export type ChatMessage = openai.ChatMessage export type ChatMessageRole = openai.ChatMessageRole export interface ChatModelOptions< - TInput = void, - TOutput = string, + TInput extends void | JsonObject = void, + TOutput extends JsonValue = string, TModelParams extends Record = Record > extends BaseLLMOptions { messages: ChatMessage[] @@ -105,7 +106,7 @@ export interface LLMTaskResponseMetadata< } export interface TaskResponse< - TOutput = string, + TOutput extends JsonValue = string, TMetadata extends TaskResponseMetadata = TaskResponseMetadata > { result: TOutput @@ -113,7 +114,7 @@ export interface TaskResponse< } export interface TaskCallContext< - TInput = void, + TInput extends void | JsonObject = void, TMetadata extends TaskResponseMetadata = TaskResponseMetadata > { input?: TInput diff --git a/legacy/src/utils.ts b/legacy/src/utils.ts index 1863a359..d3e8d17b 100644 --- a/legacy/src/utils.ts +++ b/legacy/src/utils.ts @@ -2,14 +2,22 @@ import { customAlphabet, urlAlphabet } from 'nanoid' import * as types from './types' -export const extractJSONObjectFromString = (text: string): string | undefined => - text.match(/\{(.|\n)*\}/gm)?.[0] +export function extractJSONObjectFromString(text: string): string | undefined { + return text.match(/\{(.|\n)*\}/gm)?.[0] +} -export const extractJSONArrayFromString = (text: string): string | undefined => - text.match(/\[(.|\n)*\]/gm)?.[0] +export function extractJSONArrayFromString(text: string): string | undefined { + return text.match(/\[(.|\n)*\]/gm)?.[0] +} -export const sleep = (ms: number) => - new Promise((resolve) => setTimeout(resolve, ms)) +export function sleep(ms: number) { + return new Promise((resolve) => setTimeout(resolve, ms)) +} export const defaultIDGeneratorFn: types.IDGeneratorFunction = customAlphabet(urlAlphabet) + +const taskNameRegex = /^[a-zA-Z_][a-zA-Z0-9_-]{0,63}$/ +export function isValidTaskIdentifier(id: string): boolean { + return !!id && taskNameRegex.test(id) +} diff --git a/legacy/test/calculator.test.ts b/legacy/test/calculator.test.ts index a12e0f5c..055fea62 100644 --- a/legacy/test/calculator.test.ts +++ b/legacy/test/calculator.test.ts @@ -9,11 +9,11 @@ test('CalculatorTool', async (t) => { const agentic = createTestAgenticRuntime() const tool = new CalculatorTool({ agentic }) - const res = await tool.call('1 + 1') + const res = await tool.call({ expression: '1 + 1' }) t.is(res, 2) expectTypeOf(res).toMatchTypeOf() - const res2 = await tool.callWithMetadata('cos(0)') + const res2 = await tool.callWithMetadata({ expression: 'cos(0)' }) t.is(res2.result, 1) expectTypeOf(res2.result).toMatchTypeOf() diff --git a/legacy/test/anthropic.test.ts b/legacy/test/llms/anthropic.test.ts similarity index 93% rename from legacy/test/anthropic.test.ts rename to legacy/test/llms/anthropic.test.ts index 00f4e2d4..bf940edf 100644 --- a/legacy/test/anthropic.test.ts +++ b/legacy/test/llms/anthropic.test.ts @@ -3,7 +3,7 @@ import { expectTypeOf } from 'expect-type' import { AnthropicChatModel } from '@/llms/anthropic' -import { createTestAgenticRuntime } from './_utils' +import { createTestAgenticRuntime } from '../_utils' test('AnthropicChatModel ⇒ string output', async (t) => { t.timeout(2 * 60 * 1000) diff --git a/legacy/test/llms/llm-utils.test.ts b/legacy/test/llms/llm-utils.test.ts new file mode 100644 index 00000000..f6d14fdd --- /dev/null +++ b/legacy/test/llms/llm-utils.test.ts @@ -0,0 +1,19 @@ +import test from 'ava' + +import { getChatMessageFunctionDefinitionFromTask } from '@/llms/llm-utils' +import { CalculatorTool } from '@/tools/calculator' + +import { createTestAgenticRuntime } from '../_utils' + +test('getChatMessageFunctionDefinitionFromTask', async (t) => { + const agentic = createTestAgenticRuntime() + + const tool = new CalculatorTool({ agentic }) + const functionDefinition = getChatMessageFunctionDefinitionFromTask(tool) + + t.is(functionDefinition.name, 'calculator') + t.is(functionDefinition.description, tool.descForModel) + + console.log(JSON.stringify(functionDefinition, null, 2)) + t.snapshot(functionDefinition) +}) diff --git a/legacy/test/openai.test.ts b/legacy/test/llms/openai.test.ts similarity index 99% rename from legacy/test/openai.test.ts rename to legacy/test/llms/openai.test.ts index 6314f2a0..c08bcc11 100644 --- a/legacy/test/openai.test.ts +++ b/legacy/test/llms/openai.test.ts @@ -6,7 +6,7 @@ import { z } from 'zod' import { OutputValidationError, TemplateValidationError } from '@/errors' import { BaseChatModel, OpenAIChatModel } from '@/llms' -import { createTestAgenticRuntime } from './_utils' +import { createTestAgenticRuntime } from '../_utils' test('OpenAIChatModel ⇒ types', async (t) => { const agentic = createTestAgenticRuntime() diff --git a/legacy/test/services/openai.test.ts b/legacy/test/services/openai-fetch.test.ts similarity index 100% rename from legacy/test/services/openai.test.ts rename to legacy/test/services/openai-fetch.test.ts diff --git a/legacy/test/utils.test.ts b/legacy/test/utils.test.ts new file mode 100644 index 00000000..dbf4c2fa --- /dev/null +++ b/legacy/test/utils.test.ts @@ -0,0 +1,20 @@ +import test from 'ava' + +import { isValidTaskIdentifier } from '@/utils' + +test('isValidTaskIdentifier - valid', async (t) => { + t.true(isValidTaskIdentifier('foo')) + t.true(isValidTaskIdentifier('foo_bar_179')) + t.true(isValidTaskIdentifier('fooBarBAZ')) + t.true(isValidTaskIdentifier('foo-bar-baz_')) + t.true(isValidTaskIdentifier('_')) + t.true(isValidTaskIdentifier('_foo___')) +}) + +test('isValidTaskIdentifier - invalid', async (t) => { + t.false(isValidTaskIdentifier(null as any)) + t.false(isValidTaskIdentifier('')) + t.false(isValidTaskIdentifier('-')) + t.false(isValidTaskIdentifier('x'.repeat(65))) + t.false(isValidTaskIdentifier('-foo')) +})