diff --git a/examples/facts.ts b/examples/facts.ts index 7cf8977..0c308c0 100644 --- a/examples/facts.ts +++ b/examples/facts.ts @@ -13,7 +13,7 @@ async function main() { .input( z.object({ topic: z.string(), - numFacts: z.number().int().default(5).optional() + numFacts: z.number().int().default(5) }) ) .output(z.object({ facts: z.array(z.string()) })) diff --git a/src/errors.ts b/src/errors.ts index 1bb71c6..63f09a2 100644 --- a/src/errors.ts +++ b/src/errors.ts @@ -74,3 +74,11 @@ export class OutputValidationError extends BaseError { Error.captureStackTrace?.(this, this.constructor) } } + +export class TemplateValidationError extends BaseError { + constructor(message: string, opts: ErrorOptions = {}) { + super(message, opts) + + Error.captureStackTrace?.(this, this.constructor) + } +} diff --git a/src/llms/llm.ts b/src/llms/llm.ts index c014242..7eca5f0 100644 --- a/src/llms/llm.ts +++ b/src/llms/llm.ts @@ -50,12 +50,14 @@ export abstract class BaseLLM< this._examples = options.examples } + // TODO: use polymorphic `this` type to return correct BaseLLM subclass type input(inputSchema: ZodType): BaseLLM { const refinedInstance = this as unknown as BaseLLM refinedInstance._inputSchema = inputSchema return refinedInstance } + // TODO: use polymorphic `this` type to return correct BaseLLM subclass type output(outputSchema: ZodType): BaseLLM { const refinedInstance = this as unknown as BaseLLM refinedInstance._outputSchema = outputSchema @@ -84,12 +86,12 @@ export abstract class BaseLLM< return `${this._provider}:chat:${this._model}` } - examples(examples: types.LLMExample[]) { + examples(examples: types.LLMExample[]): this { this._examples = examples return this } - modelParams(params: Partial) { + modelParams(params: Partial): this { // We assume that modelParams does not include nested objects. // If it did, we would need to do a deep merge. this._modelParams = { ...this._modelParams, ...params } as TModelParams @@ -154,6 +156,7 @@ export abstract class BaseChatModel< } // TODO: validate input message variables against input schema + console.log({ input }) const messages = this._messages .map((message) => { diff --git a/src/template.ts b/src/template.ts index c2ce26e..894188e 100644 --- a/src/template.ts +++ b/src/template.ts @@ -1,21 +1,36 @@ import Handlebars from 'handlebars' import QuickLRU from 'quick-lru' -const lru = new QuickLRU({ maxSize: 1000 }) +import { TemplateValidationError } from './errors' + +export type CompiledTemplate = (data: unknown) => string + +const lru = new QuickLRU({ maxSize: 1000 }) export function getCompiledTemplate(template: string) { - let compiledTemplate = lru.get(template) as HandlebarsTemplateDelegate + let compiledTemplate = lru.get(template) if (compiledTemplate) { return compiledTemplate } - compiledTemplate = Handlebars.compile(template, { + const handlebarsTemplate = Handlebars.compile(template, { noEscape: true, + strict: true, knownHelpers: {}, knownHelpersOnly: true }) + compiledTemplate = (data: unknown) => { + try { + return handlebarsTemplate(data) + } catch (err: any) { + const msg = err.message?.replace('[object Object]', 'input') + const message = ['Template error', msg].filter(Boolean).join(': ') + throw new TemplateValidationError(message, { cause: err }) + } + } + lru.set(template, compiledTemplate) return compiledTemplate } diff --git a/test/openai.test.ts b/test/openai.test.ts index f97fafb..9863706 100644 --- a/test/openai.test.ts +++ b/test/openai.test.ts @@ -3,7 +3,7 @@ import { expectTypeOf } from 'expect-type' import sinon from 'sinon' import { z } from 'zod' -import { OutputValidationError } from '@/errors' +import { OutputValidationError, TemplateValidationError } from '@/errors' import { OpenAIChatModel } from '@/llms/openai' import { createTestAgenticRuntime } from './_utils' @@ -128,3 +128,61 @@ test('OpenAIChatModel ⇒ retry logic', async (t) => { }) t.is(fakeCall.callCount, 3) }) + +test('OpenAIChatModel ⇒ template variables', async (t) => { + t.timeout(2 * 60 * 1000) + const agentic = createTestAgenticRuntime() + + const query = agentic + .gpt3(`Give me {{numFacts}} random facts about {{topic}}`) + .input( + z.object({ + topic: z.string(), + numFacts: z.number().int().default(5) + }) + ) + .output(z.object({ facts: z.array(z.string()) })) + .modelParams({ temperature: 0.5 }) + + const res0 = await query.call({ topic: 'cats' }) + + t.true(Array.isArray(res0.facts)) + t.is(res0.facts.length, 5) + expectTypeOf(res0).toMatchTypeOf<{ facts: string[] }>() + + for (const fact of res0.facts) { + t.true(typeof fact === 'string') + } + + const res1 = await query.call({ topic: 'dogs', numFacts: 2 }) + + t.true(Array.isArray(res1.facts)) + t.is(res1.facts.length, 2) + expectTypeOf(res1).toMatchTypeOf<{ facts: string[] }>() + + for (const fact of res1.facts) { + t.true(typeof fact === 'string') + } +}) + +test.only('OpenAIChatModel ⇒ missing template variable', async (t) => { + t.timeout(2 * 60 * 1000) + const agentic = createTestAgenticRuntime() + + const builder = agentic + .gpt3(`Give me {{numFacts}} random facts about {{topic}}`) + .input( + z.object({ + topic: z.string(), + numFacts: z.number().int().default(5).optional() + }) + ) + .output(z.object({ facts: z.array(z.string()) })) + .modelParams({ temperature: 0.5 }) + + await t.throwsAsync(() => builder.call({ topic: 'cats' }), { + instanceOf: TemplateValidationError, + name: 'TemplateValidationError', + message: 'Template error: "numFacts" not defined in input - 1:10' + }) +})