diff --git a/package.json b/package.json index ca641c3..3fe2de5 100644 --- a/package.json +++ b/package.json @@ -29,6 +29,7 @@ "dist" ], "scripts": { + "preinstall": "npx only-allow pnpm", "build": "tsup", "dev": "tsup --watch", "clean": "del dist", @@ -71,6 +72,7 @@ "lint-staged": "^15.2.4", "np": "^10.0.5", "npm-run-all2": "^6.2.0", + "only-allow": "^1.2.1", "prettier": "^3.2.5", "tsup": "^8.0.2", "tsx": "^4.10.5", diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index a7f8541..c95a9ec 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -84,6 +84,9 @@ importers: npm-run-all2: specifier: ^6.2.0 version: 6.2.0 + only-allow: + specifier: ^1.2.1 + version: 1.2.1 prettier: specifier: ^3.2.5 version: 3.2.5 @@ -2505,6 +2508,10 @@ packages: resolution: {integrity: sha512-VXJjc87FScF88uafS3JllDgvAm+c/Slfz06lorj2uAY34rlUu0Nt+v8wreiImcrgAjjIHp1rXpTDlLOGw29WwQ==} engines: {node: '>=18'} + only-allow@1.2.1: + resolution: {integrity: sha512-M7CJbmv7UCopc0neRKdzfoGWaVZC+xC1925GitKH9EAqYFzX9//25Q7oX4+jw0tiCCj+t5l6VZh8UPH23NZkMA==} + hasBin: true + open@10.1.0: resolution: {integrity: sha512-mnkeQ1qP5Ue2wd+aivTD3NHd/lZ96Lu0jgf0pwktLPtx6cTZiH7tyeGRRHs0zX0rbrahXPnXlUnbeXyaBBuIaw==} engines: {node: '>=18'} @@ -3429,6 +3436,10 @@ packages: resolution: {integrity: sha512-K4jVyjnBdgvc86Y6BkaLZEN933SwYOuBFkdmBu9ZfkcAbdVbpITnDmjvZ/aQjRXQrv5EPkTnD1s39GiiqbngCw==} engines: {node: '>= 0.4'} + which-pm-runs@1.1.0: + resolution: {integrity: sha512-n1brCuqClxfFfq/Rb0ICg9giSZqCS+pLtccdag6C2HyufBrh3fBOiy9nb6ggRMvWOVH5GrdJskj5iGTZNxd7SA==} + engines: {node: '>=4'} + which-typed-array@1.1.15: resolution: {integrity: sha512-oV0jmFtUky6CXfkqehVvBP/LSWJ2sy4vWMioiENyJLePrBO/yKyV9OyJySfAKosh+RYkIl5zJCNZ8/4JncrpdA==} engines: {node: '>= 0.4'} @@ -6040,6 +6051,10 @@ snapshots: dependencies: mimic-function: 5.0.1 + only-allow@1.2.1: + dependencies: + which-pm-runs: 1.1.0 + open@10.1.0: dependencies: default-browser: 5.2.1 @@ -7011,6 +7026,8 @@ snapshots: is-weakmap: 2.0.2 is-weakset: 2.0.3 + which-pm-runs@1.1.0: {} + which-typed-array@1.1.15: dependencies: available-typed-arrays: 1.0.7 diff --git a/src/ai-function-set.ts b/src/ai-function-set.ts new file mode 100644 index 0000000..f2ca814 --- /dev/null +++ b/src/ai-function-set.ts @@ -0,0 +1,70 @@ +import type { AIToolSet } from './ai-tool-set.js' +import type * as types from './types.ts' + +export class AIFunctionSet implements Iterable { + protected readonly _map: Map + + constructor(functions?: readonly types.AIFunction[]) { + this._map = new Map(functions ? functions.map((fn) => [fn.name, fn]) : null) + } + + get size(): number { + return this._map.size + } + + add(fn: types.AIFunction): this { + this._map.set(fn.name, fn) + return this + } + + get(name: string): types.AIFunction | undefined { + return this._map.get(name) + } + + set(name: string, fn: types.AIFunction): this { + this._map.set(name, fn) + return this + } + + has(name: string): boolean { + return this._map.has(name) + } + + clear(): void { + this._map.clear() + } + + delete(name: string): boolean { + return this._map.delete(name) + } + + pick(...keys: string[]): AIFunctionSet { + const keysToIncludeSet = new Set(keys) + return new AIFunctionSet( + Array.from(this).filter((fn) => keysToIncludeSet.has(fn.spec.name)) + ) + } + + omit(...keys: string[]): AIFunctionSet { + const keysToExcludeSet = new Set(keys) + return new AIFunctionSet( + Array.from(this).filter((fn) => !keysToExcludeSet.has(fn.spec.name)) + ) + } + + get entries(): IterableIterator { + return this._map.values() + } + + [Symbol.iterator](): Iterator { + return this.entries + } + + static fromAIToolSet(tools: AIToolSet): AIFunctionSet { + return new AIFunctionSet( + Array.from(tools) + .filter((tool) => tool.spec.type === 'function') + .map((tool) => tool.function) + ) + } +} diff --git a/src/ai-function.test.ts b/src/ai-function.test.ts new file mode 100644 index 0000000..8b5262a --- /dev/null +++ b/src/ai-function.test.ts @@ -0,0 +1,42 @@ +import { describe, expect, it } from 'vitest' +import { z } from 'zod' + +import { createAIFunction } from './ai-function.js' + +const fullName = createAIFunction( + { + name: 'fullName', + description: 'Returns the full name of a person.', + inputSchema: z.object({ + first: z.string(), + last: z.string() + }) + }, + async ({ first, last }) => { + return `${first} ${last}` + } +) + +describe('createAIFunction()', () => { + it('exposes OpenAI function calling spec', () => { + expect(fullName.spec.name).toEqual('fullName') + expect(fullName.spec.description).toEqual( + 'Returns the full name of a person.' + ) + expect(fullName.spec.parameters).toEqual({ + properties: { + first: { type: 'string' }, + last: { type: 'string' } + }, + required: ['first', 'last'], + type: 'object', + additionalProperties: false + }) + }) + + it('executes the function', async () => { + expect(await fullName('{"first": "John", "last": "Doe"}')).toEqual( + 'John Doe' + ) + }) +}) diff --git a/src/ai-function.ts b/src/ai-function.ts new file mode 100644 index 0000000..504c972 --- /dev/null +++ b/src/ai-function.ts @@ -0,0 +1,60 @@ +import type { z } from 'zod' + +import type * as types from './types.js' +import { parseStructuredOutput } from './parse-structured-output.js' +import { assert } from './utils.js' +import { zodToJsonSchema } from './zod-to-json-schema.js' + +/** + * Create a function meant to be used with OpenAI tool or function calling. + * + * The returned function will parse the arguments string and call the + * implementation function with the parsed arguments. + * + * The `spec` property of the returned function is the spec for adding the + * function to the OpenAI API `functions` property. + */ +export function createAIFunction, Return>( + spec: { + /** Name of the function. */ + name: string + /** Description of the function. */ + description?: string + /** Zod schema for the arguments string. */ + inputSchema: InputSchema + }, + /** Implementation of the function to call with the parsed arguments. */ + implementation: (params: z.infer) => types.MaybePromise +): types.AIFunction { + /** Parse the arguments string, optionally reading from a message. */ + const parseInput = (input: string | types.Msg) => { + if (typeof input === 'string') { + return parseStructuredOutput(input, spec.inputSchema) + } else { + const args = input.function_call?.arguments + assert( + args, + `Missing required function_call.arguments for function ${spec.name}` + ) + return parseStructuredOutput(args, spec.inputSchema) + } + } + + // Call the implementation function with the parsed arguments. + const aiFunction: types.AIFunction = ( + input: string | types.Msg + ) => { + const parsedInput = parseInput(input) + return implementation(parsedInput) + } + + aiFunction.inputSchema = spec.inputSchema + aiFunction.parseInput = parseInput + aiFunction.spec = { + name: spec.name, + description: spec.description?.trim() ?? '', + parameters: zodToJsonSchema(spec.inputSchema) + } + + return aiFunction +} diff --git a/src/ai-tool-set.ts b/src/ai-tool-set.ts new file mode 100644 index 0000000..cb1d77e --- /dev/null +++ b/src/ai-tool-set.ts @@ -0,0 +1,88 @@ +import type * as types from './types.js' +import { AIFunctionSet } from './ai-function-set.js' + +export class AIToolSet implements Iterable { + protected _map: Map + + constructor(tools?: readonly types.AITool[]) { + this._map = new Map( + tools ? tools.map((tool) => [tool.function.name, tool]) : [] + ) + } + + get size(): number { + return this._map.size + } + + add(tool: types.AITool): this { + this._map.set(tool.function.name, tool) + return this + } + + get(name: string): types.AITool | undefined { + return this._map.get(name) + } + + set(name: string, tool: types.AITool): this { + this._map.set(name, tool) + return this + } + + has(name: string): boolean { + return this._map.has(name) + } + + clear(): void { + this._map.clear() + } + + delete(name: string): boolean { + return this._map.delete(name) + } + + pick(...keys: string[]): AIToolSet { + const keysToIncludeSet = new Set(keys) + return new AIToolSet( + Array.from(this).filter((tool) => + keysToIncludeSet.has(tool.function.name) + ) + ) + } + + omit(...keys: string[]): AIToolSet { + const keysToExcludeSet = new Set(keys) + return new AIToolSet( + Array.from(this).filter( + (tool) => !keysToExcludeSet.has(tool.function.name) + ) + ) + } + + get entries(): IterableIterator { + return this._map.values() + } + + [Symbol.iterator](): Iterator { + return this.entries + } + + static fromAIFunctionSet(functions: AIFunctionSet): AIToolSet { + return new AIToolSet( + Array.from(functions).map((fn) => ({ + function: fn, + spec: { + type: 'function' as const, + function: fn.spec + } + })) + ) + } + + static fromFunctions(functions: types.AIFunction[]): AIToolSet { + return AIToolSet.fromAIFunctionSet(new AIFunctionSet(functions)) + } + + static fromTools(tools: types.AITool[]): AIToolSet { + return new AIToolSet(tools) + } +} diff --git a/src/fns.ts b/src/fns.ts index 06a5d14..cef242a 100644 --- a/src/fns.ts +++ b/src/fns.ts @@ -1,44 +1,40 @@ import './symbol-polyfill.js' -import type { z } from 'zod' +import type * as z from 'zod' import type * as types from './types.js' -import { FunctionSet } from './function-set.js' -import { ToolSet } from './tool-set.js' -import { zodToJsonSchema } from './zod-to-json-schema.js' +import { createAIFunction } from './ai-function.js' +import { AIFunctionSet } from './ai-function-set.js' +import { AIToolSet } from './ai-tool-set.js' +import { assert } from './utils.js' export const invocableMetadataKey = Symbol('invocable') export abstract class AIToolsProvider { - private _tools?: ToolSet - private _functions?: FunctionSet + private _tools?: AIToolSet + private _functions?: AIFunctionSet - get namespace() { - return this.constructor.name - } - - get tools(): ToolSet { + get tools(): AIToolSet { if (!this._tools) { - this._tools = ToolSet.fromFunctionSet(this.functions) + this._tools = AIToolSet.fromAIFunctionSet(this.functions) } return this._tools } - get functions(): FunctionSet { + get functions(): AIFunctionSet { if (!this._functions) { const metadata = this.constructor[Symbol.metadata] const invocables = (metadata?.invocables as Invocable[]) ?? [] - const namespace = this.namespace - const functions = invocables.map((invocable) => ({ - ...invocable, - name: invocable.name ?? `${namespace}_${invocable.propertyKey}`, - callback: (this as any)[invocable.propertyKey].bind(target) - })) + const aiFunctions = invocables.map((invocable) => { + const impl = (this as any)[invocable.methodName]?.bind(this) + assert(impl) - const functions = invocables.map(getFunctionSpec) - this._functions = new FunctionSet(functions) + return createAIFunction(invocable, impl) + }) + + this._functions = new AIFunctionSet(aiFunctions) } return this._functions @@ -48,29 +44,15 @@ export abstract class AIToolsProvider { export interface Invocable { name: string description?: string - inputSchema?: z.AnyZodObject - callback: (args: Record) => Promise -} - -function getFunctionSpec(invocable: Invocable): types.AIFunctionSpec { - const { name, description, inputSchema } = invocable - - return { - name, - description, - parameters: inputSchema - ? zodToJsonSchema(inputSchema) - : { - type: 'object', - properties: {} - } - } + inputSchema: z.AnyZodObject + methodName: string } export function aiFunction< This, - Args extends any[], - Return extends Promise + InputSchema extends z.SomeZodObject, + OptionalArgs extends Array, + Return extends types.MaybePromise >({ name, description, @@ -78,16 +60,21 @@ export function aiFunction< }: { name?: string description?: string - - // params must be an object, so the underlying function should only expect a - // single parameter - inputSchema?: z.AnyZodObject + inputSchema: InputSchema }) { return ( - targetMethod: (this: This, ...args: Args) => Return, + _targetMethod: ( + this: This, + input: z.infer, + ...optionalArgs: OptionalArgs + ) => Return, context: ClassMethodDecoratorContext< This, - (this: This, ...args: Args) => Return + ( + this: This, + input: z.infer, + ...optionalArgs: OptionalArgs + ) => Return > ) => { const methodName = String(context.name) @@ -99,18 +86,11 @@ export function aiFunction< name: name ?? methodName, description, inputSchema, - callback: targetMethod + methodName }) - return targetMethod - - // function replacementMethod(this: This, ...args: Args): Return { - // console.log(`LOG: Entering method '${methodName}'.`) - // const result = targetMethod.call(this, ...args) - // console.log(`LOG: Exiting method '${methodName}'.`) - // return result - // } - - // return replacementMethod + // context.addInitializer(function () { + // ;(this as any)[methodName] = (this as any)[methodName].bind(this) + // }) } } diff --git a/src/function-set.ts b/src/function-set.ts deleted file mode 100644 index 0f6085c..0000000 --- a/src/function-set.ts +++ /dev/null @@ -1,70 +0,0 @@ -import type { ToolSet } from './tool-set.js' -import type * as types from './types.ts' - -export class FunctionSet implements Iterable { - protected _map: Map - - constructor(functions?: readonly types.AIFunctionSpec[] | null) { - this._map = new Map(functions ? functions.map((fn) => [fn.name, fn]) : null) - } - - get size(): number { - return this._map.size - } - - add(fn: types.AIFunctionSpec): this { - this._map.set(fn.name, fn) - return this - } - - get(name: string): types.AIFunctionSpec | undefined { - return this._map.get(name) - } - - set(name: string, fn: types.AIFunctionSpec): this { - this._map.set(name, fn) - return this - } - - has(name: string): boolean { - return this._map.has(name) - } - - clear(): void { - this._map.clear() - } - - delete(name: string): boolean { - return this._map.delete(name) - } - - pick(...keys: string[]): FunctionSet { - const keysToIncludeSet = new Set(keys) - return new FunctionSet( - Array.from(this).filter((fn) => keysToIncludeSet.has(fn.name)) - ) - } - - omit(...keys: string[]): FunctionSet { - const keysToExcludeSet = new Set(keys) - return new FunctionSet( - Array.from(this).filter((fn) => !keysToExcludeSet.has(fn.name)) - ) - } - - get entries(): IterableIterator { - return this._map.values() - } - - [Symbol.iterator](): Iterator { - return this.entries - } - - static fromToolSet(toolSet: ToolSet): FunctionSet { - return new FunctionSet( - Array.from(toolSet) - .filter((tool) => tool.type === 'function') - .map((tool) => tool.function) - ) - } -} diff --git a/src/index.ts b/src/index.ts index 43c8648..2cfdd4c 100644 --- a/src/index.ts +++ b/src/index.ts @@ -1,6 +1,11 @@ -export * from './function-set.js' +export * from './ai-function.js' +export * from './ai-function-set.js' +export * from './ai-tool-set.js' +export * from './errors.js' +export * from './fns.js' export * from './parse-structured-output.js' export * from './services/index.js' -export * from './tool-set.js' +export * from './stringify-for-model.js' export type * from './types.js' export * from './utils.js' +export * from './zod-to-json-schema.js' diff --git a/src/tool-set.ts b/src/tool-set.ts deleted file mode 100644 index b0250df..0000000 --- a/src/tool-set.ts +++ /dev/null @@ -1,85 +0,0 @@ -import type * as types from './types.ts' -import { FunctionSet } from './function-set.js' - -export class ToolSet implements Iterable { - protected _map: Map - - constructor(tools?: readonly types.AIToolSpec[] | null) { - this._map = new Map( - tools ? tools.map((tool) => [tool.function.name, tool]) : null - ) - } - - get size(): number { - return this._map.size - } - - add(tool: types.AIToolSpec): this { - this._map.set(tool.function.name, tool) - return this - } - - get(name: string): types.AIToolSpec | undefined { - return this._map.get(name) - } - - set(name: string, tool: types.AIToolSpec): this { - this._map.set(name, tool) - return this - } - - has(name: string): boolean { - return this._map.has(name) - } - - clear(): void { - this._map.clear() - } - - delete(name: string): boolean { - return this._map.delete(name) - } - - pick(...keys: string[]): ToolSet { - const keysToIncludeSet = new Set(keys) - return new ToolSet( - Array.from(this).filter((tool) => - keysToIncludeSet.has(tool.function.name) - ) - ) - } - - omit(...keys: string[]): ToolSet { - const keysToExcludeSet = new Set(keys) - return new ToolSet( - Array.from(this).filter( - (tool) => !keysToExcludeSet.has(tool.function.name) - ) - ) - } - - get entries(): IterableIterator { - return this._map.values() - } - - [Symbol.iterator](): Iterator { - return this.entries - } - - static fromFunctionSet(functionSet: FunctionSet): ToolSet { - return new ToolSet( - Array.from(functionSet).map((fn) => ({ - type: 'function' as const, - function: fn - })) - ) - } - - static fromFunctionSpecs(functionSpecs: types.AIFunctionSpec[]): ToolSet { - return ToolSet.fromFunctionSet(new FunctionSet(functionSpecs)) - } - - static fromToolSpecs(toolSpecs: types.AIToolSpec[]): ToolSet { - return new ToolSet(toolSpecs) - } -} diff --git a/src/types.ts b/src/types.ts index bb1979c..8ae4cb2 100644 --- a/src/types.ts +++ b/src/types.ts @@ -1,9 +1,16 @@ +import type { Jsonifiable } from 'type-fest' +import type { z } from 'zod' + export type { KyInstance } from 'ky' export type { ThrottledFunction } from 'p-throttle' // TODO export type DeepNullable = T | null +export type MaybePromise = T | Promise + +export type RelaxedJsonifiable = Jsonifiable | Record + export interface AIFunctionSpec { name: string description?: string @@ -16,16 +23,58 @@ export interface AIToolSpec { } /** - * Generic/default OpenAI message without any narrowing applied + * A function meant to be used with LLM function calling. + */ +export interface AIFunction< + InputSchema extends z.ZodObject = z.ZodObject, + Return = any +> { + /** The implementation of the function, with arg parsing and validation. */ + (input: string | Msg): MaybePromise + + /** The Zod schema for the arguments string. */ + inputSchema: InputSchema + + /** Parse the function arguments from a message. */ + parseInput(input: string | Msg): z.infer + + /** The function spec for the OpenAI API `functions` property. */ + spec: AIFunctionSpec +} + +/** + * A tool meant to be used with LLM function calling. + */ +export interface AITool< + InputSchema extends z.ZodObject = z.ZodObject, + Return = any +> { + function: AIFunction + + /** The tool spec for the OpenAI API `tools` property. */ + spec: AIToolSpec +} + +/** + * Generic/default OpenAI message without any narrowing applied. */ export interface Msg { - /** The contents of the message. `content` is required for all messages, and may be null for assistant messages with function calls. */ + /** + * The contents of the message. `content` is required for all messages, and + * may be null for assistant messages with function calls. + */ content: string | null - /** The role of the messages author. One of `system`, `user`, `assistant`, 'tool', or `function`. */ + /** + * The role of the messages author. One of `system`, `user`, `assistant`, + * 'tool', or `function`. + */ role: Msg.Role - /** The name and arguments of a function that should be called, as generated by the model. */ + /** + * The name and arguments of a function that should be called, as generated + * by the model. + */ function_call?: Msg.Call.Function /** The tool calls generated by the model, such as function calls. */ @@ -45,15 +94,21 @@ export interface Msg { name?: string } -/** Narrowed ChatModel.Message types. */ +/** Narrowed Message types. */ export namespace Msg { - /** The possible roles for a message. */ + /** Possible roles for a message. */ export type Role = 'system' | 'user' | 'assistant' | 'function' | 'tool' export namespace Call { - /** The name and arguments of a function that should be called, as generated by the model. */ + /** + * The name and arguments of a function that should be called, as generated + * by the model. + */ export type Function = { - /** The arguments to call the function with, as generated by the model in JSON format. */ + /** + * The arguments to call the function with, as generated by the model in + * JSON format. + */ arguments: string /** The name of the function to call. */