kopia lustrzana https://github.com/transitive-bullshit/chatgpt-api
feat: add template tests and improve template errors
rodzic
7213565c27
commit
d788cc1600
|
@ -13,7 +13,7 @@ async function main() {
|
|||
.input(
|
||||
z.object({
|
||||
topic: z.string(),
|
||||
numFacts: z.number().int().default(5).optional()
|
||||
numFacts: z.number().int().default(5)
|
||||
})
|
||||
)
|
||||
.output(z.object({ facts: z.array(z.string()) }))
|
||||
|
|
|
@ -74,3 +74,11 @@ export class OutputValidationError extends BaseError {
|
|||
Error.captureStackTrace?.(this, this.constructor)
|
||||
}
|
||||
}
|
||||
|
||||
export class TemplateValidationError extends BaseError {
|
||||
constructor(message: string, opts: ErrorOptions = {}) {
|
||||
super(message, opts)
|
||||
|
||||
Error.captureStackTrace?.(this, this.constructor)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -50,12 +50,14 @@ export abstract class BaseLLM<
|
|||
this._examples = options.examples
|
||||
}
|
||||
|
||||
// TODO: use polymorphic `this` type to return correct BaseLLM subclass type
|
||||
input<U>(inputSchema: ZodType<U>): BaseLLM<U, TOutput, TModelParams> {
|
||||
const refinedInstance = this as unknown as BaseLLM<U, TOutput, TModelParams>
|
||||
refinedInstance._inputSchema = inputSchema
|
||||
return refinedInstance
|
||||
}
|
||||
|
||||
// TODO: use polymorphic `this` type to return correct BaseLLM subclass type
|
||||
output<U>(outputSchema: ZodType<U>): BaseLLM<TInput, U, TModelParams> {
|
||||
const refinedInstance = this as unknown as BaseLLM<TInput, U, TModelParams>
|
||||
refinedInstance._outputSchema = outputSchema
|
||||
|
@ -84,12 +86,12 @@ export abstract class BaseLLM<
|
|||
return `${this._provider}:chat:${this._model}`
|
||||
}
|
||||
|
||||
examples(examples: types.LLMExample[]) {
|
||||
examples(examples: types.LLMExample[]): this {
|
||||
this._examples = examples
|
||||
return this
|
||||
}
|
||||
|
||||
modelParams(params: Partial<TModelParams>) {
|
||||
modelParams(params: Partial<TModelParams>): this {
|
||||
// We assume that modelParams does not include nested objects.
|
||||
// If it did, we would need to do a deep merge.
|
||||
this._modelParams = { ...this._modelParams, ...params } as TModelParams
|
||||
|
@ -154,6 +156,7 @@ export abstract class BaseChatModel<
|
|||
}
|
||||
|
||||
// TODO: validate input message variables against input schema
|
||||
console.log({ input })
|
||||
|
||||
const messages = this._messages
|
||||
.map((message) => {
|
||||
|
|
|
@ -1,21 +1,36 @@
|
|||
import Handlebars from 'handlebars'
|
||||
import QuickLRU from 'quick-lru'
|
||||
|
||||
const lru = new QuickLRU({ maxSize: 1000 })
|
||||
import { TemplateValidationError } from './errors'
|
||||
|
||||
export type CompiledTemplate = (data: unknown) => string
|
||||
|
||||
const lru = new QuickLRU<string, CompiledTemplate>({ maxSize: 1000 })
|
||||
|
||||
export function getCompiledTemplate(template: string) {
|
||||
let compiledTemplate = lru.get(template) as HandlebarsTemplateDelegate
|
||||
let compiledTemplate = lru.get(template)
|
||||
|
||||
if (compiledTemplate) {
|
||||
return compiledTemplate
|
||||
}
|
||||
|
||||
compiledTemplate = Handlebars.compile(template, {
|
||||
const handlebarsTemplate = Handlebars.compile(template, {
|
||||
noEscape: true,
|
||||
strict: true,
|
||||
knownHelpers: {},
|
||||
knownHelpersOnly: true
|
||||
})
|
||||
|
||||
compiledTemplate = (data: unknown) => {
|
||||
try {
|
||||
return handlebarsTemplate(data)
|
||||
} catch (err: any) {
|
||||
const msg = err.message?.replace('[object Object]', 'input')
|
||||
const message = ['Template error', msg].filter(Boolean).join(': ')
|
||||
throw new TemplateValidationError(message, { cause: err })
|
||||
}
|
||||
}
|
||||
|
||||
lru.set(template, compiledTemplate)
|
||||
return compiledTemplate
|
||||
}
|
||||
|
|
|
@ -3,7 +3,7 @@ import { expectTypeOf } from 'expect-type'
|
|||
import sinon from 'sinon'
|
||||
import { z } from 'zod'
|
||||
|
||||
import { OutputValidationError } from '@/errors'
|
||||
import { OutputValidationError, TemplateValidationError } from '@/errors'
|
||||
import { OpenAIChatModel } from '@/llms/openai'
|
||||
|
||||
import { createTestAgenticRuntime } from './_utils'
|
||||
|
@ -128,3 +128,61 @@ test('OpenAIChatModel ⇒ retry logic', async (t) => {
|
|||
})
|
||||
t.is(fakeCall.callCount, 3)
|
||||
})
|
||||
|
||||
test('OpenAIChatModel ⇒ template variables', async (t) => {
|
||||
t.timeout(2 * 60 * 1000)
|
||||
const agentic = createTestAgenticRuntime()
|
||||
|
||||
const query = agentic
|
||||
.gpt3(`Give me {{numFacts}} random facts about {{topic}}`)
|
||||
.input(
|
||||
z.object({
|
||||
topic: z.string(),
|
||||
numFacts: z.number().int().default(5)
|
||||
})
|
||||
)
|
||||
.output(z.object({ facts: z.array(z.string()) }))
|
||||
.modelParams({ temperature: 0.5 })
|
||||
|
||||
const res0 = await query.call({ topic: 'cats' })
|
||||
|
||||
t.true(Array.isArray(res0.facts))
|
||||
t.is(res0.facts.length, 5)
|
||||
expectTypeOf(res0).toMatchTypeOf<{ facts: string[] }>()
|
||||
|
||||
for (const fact of res0.facts) {
|
||||
t.true(typeof fact === 'string')
|
||||
}
|
||||
|
||||
const res1 = await query.call({ topic: 'dogs', numFacts: 2 })
|
||||
|
||||
t.true(Array.isArray(res1.facts))
|
||||
t.is(res1.facts.length, 2)
|
||||
expectTypeOf(res1).toMatchTypeOf<{ facts: string[] }>()
|
||||
|
||||
for (const fact of res1.facts) {
|
||||
t.true(typeof fact === 'string')
|
||||
}
|
||||
})
|
||||
|
||||
test.only('OpenAIChatModel ⇒ missing template variable', async (t) => {
|
||||
t.timeout(2 * 60 * 1000)
|
||||
const agentic = createTestAgenticRuntime()
|
||||
|
||||
const builder = agentic
|
||||
.gpt3(`Give me {{numFacts}} random facts about {{topic}}`)
|
||||
.input(
|
||||
z.object({
|
||||
topic: z.string(),
|
||||
numFacts: z.number().int().default(5).optional()
|
||||
})
|
||||
)
|
||||
.output(z.object({ facts: z.array(z.string()) }))
|
||||
.modelParams({ temperature: 0.5 })
|
||||
|
||||
await t.throwsAsync(() => builder.call({ topic: 'cats' }), {
|
||||
instanceOf: TemplateValidationError,
|
||||
name: 'TemplateValidationError',
|
||||
message: 'Template error: "numFacts" not defined in input - 1:10'
|
||||
})
|
||||
})
|
||||
|
|
Ładowanie…
Reference in New Issue