kopia lustrzana https://github.com/transitive-bullshit/chatgpt-api
feat: add template tests and improve template errors
rodzic
7213565c27
commit
d788cc1600
|
@ -13,7 +13,7 @@ async function main() {
|
||||||
.input(
|
.input(
|
||||||
z.object({
|
z.object({
|
||||||
topic: z.string(),
|
topic: z.string(),
|
||||||
numFacts: z.number().int().default(5).optional()
|
numFacts: z.number().int().default(5)
|
||||||
})
|
})
|
||||||
)
|
)
|
||||||
.output(z.object({ facts: z.array(z.string()) }))
|
.output(z.object({ facts: z.array(z.string()) }))
|
||||||
|
|
|
@ -74,3 +74,11 @@ export class OutputValidationError extends BaseError {
|
||||||
Error.captureStackTrace?.(this, this.constructor)
|
Error.captureStackTrace?.(this, this.constructor)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export class TemplateValidationError extends BaseError {
|
||||||
|
constructor(message: string, opts: ErrorOptions = {}) {
|
||||||
|
super(message, opts)
|
||||||
|
|
||||||
|
Error.captureStackTrace?.(this, this.constructor)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -50,12 +50,14 @@ export abstract class BaseLLM<
|
||||||
this._examples = options.examples
|
this._examples = options.examples
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO: use polymorphic `this` type to return correct BaseLLM subclass type
|
||||||
input<U>(inputSchema: ZodType<U>): BaseLLM<U, TOutput, TModelParams> {
|
input<U>(inputSchema: ZodType<U>): BaseLLM<U, TOutput, TModelParams> {
|
||||||
const refinedInstance = this as unknown as BaseLLM<U, TOutput, TModelParams>
|
const refinedInstance = this as unknown as BaseLLM<U, TOutput, TModelParams>
|
||||||
refinedInstance._inputSchema = inputSchema
|
refinedInstance._inputSchema = inputSchema
|
||||||
return refinedInstance
|
return refinedInstance
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO: use polymorphic `this` type to return correct BaseLLM subclass type
|
||||||
output<U>(outputSchema: ZodType<U>): BaseLLM<TInput, U, TModelParams> {
|
output<U>(outputSchema: ZodType<U>): BaseLLM<TInput, U, TModelParams> {
|
||||||
const refinedInstance = this as unknown as BaseLLM<TInput, U, TModelParams>
|
const refinedInstance = this as unknown as BaseLLM<TInput, U, TModelParams>
|
||||||
refinedInstance._outputSchema = outputSchema
|
refinedInstance._outputSchema = outputSchema
|
||||||
|
@ -84,12 +86,12 @@ export abstract class BaseLLM<
|
||||||
return `${this._provider}:chat:${this._model}`
|
return `${this._provider}:chat:${this._model}`
|
||||||
}
|
}
|
||||||
|
|
||||||
examples(examples: types.LLMExample[]) {
|
examples(examples: types.LLMExample[]): this {
|
||||||
this._examples = examples
|
this._examples = examples
|
||||||
return this
|
return this
|
||||||
}
|
}
|
||||||
|
|
||||||
modelParams(params: Partial<TModelParams>) {
|
modelParams(params: Partial<TModelParams>): this {
|
||||||
// We assume that modelParams does not include nested objects.
|
// We assume that modelParams does not include nested objects.
|
||||||
// If it did, we would need to do a deep merge.
|
// If it did, we would need to do a deep merge.
|
||||||
this._modelParams = { ...this._modelParams, ...params } as TModelParams
|
this._modelParams = { ...this._modelParams, ...params } as TModelParams
|
||||||
|
@ -154,6 +156,7 @@ export abstract class BaseChatModel<
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: validate input message variables against input schema
|
// TODO: validate input message variables against input schema
|
||||||
|
console.log({ input })
|
||||||
|
|
||||||
const messages = this._messages
|
const messages = this._messages
|
||||||
.map((message) => {
|
.map((message) => {
|
||||||
|
|
|
@ -1,21 +1,36 @@
|
||||||
import Handlebars from 'handlebars'
|
import Handlebars from 'handlebars'
|
||||||
import QuickLRU from 'quick-lru'
|
import QuickLRU from 'quick-lru'
|
||||||
|
|
||||||
const lru = new QuickLRU({ maxSize: 1000 })
|
import { TemplateValidationError } from './errors'
|
||||||
|
|
||||||
|
export type CompiledTemplate = (data: unknown) => string
|
||||||
|
|
||||||
|
const lru = new QuickLRU<string, CompiledTemplate>({ maxSize: 1000 })
|
||||||
|
|
||||||
export function getCompiledTemplate(template: string) {
|
export function getCompiledTemplate(template: string) {
|
||||||
let compiledTemplate = lru.get(template) as HandlebarsTemplateDelegate
|
let compiledTemplate = lru.get(template)
|
||||||
|
|
||||||
if (compiledTemplate) {
|
if (compiledTemplate) {
|
||||||
return compiledTemplate
|
return compiledTemplate
|
||||||
}
|
}
|
||||||
|
|
||||||
compiledTemplate = Handlebars.compile(template, {
|
const handlebarsTemplate = Handlebars.compile(template, {
|
||||||
noEscape: true,
|
noEscape: true,
|
||||||
|
strict: true,
|
||||||
knownHelpers: {},
|
knownHelpers: {},
|
||||||
knownHelpersOnly: true
|
knownHelpersOnly: true
|
||||||
})
|
})
|
||||||
|
|
||||||
|
compiledTemplate = (data: unknown) => {
|
||||||
|
try {
|
||||||
|
return handlebarsTemplate(data)
|
||||||
|
} catch (err: any) {
|
||||||
|
const msg = err.message?.replace('[object Object]', 'input')
|
||||||
|
const message = ['Template error', msg].filter(Boolean).join(': ')
|
||||||
|
throw new TemplateValidationError(message, { cause: err })
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
lru.set(template, compiledTemplate)
|
lru.set(template, compiledTemplate)
|
||||||
return compiledTemplate
|
return compiledTemplate
|
||||||
}
|
}
|
||||||
|
|
|
@ -3,7 +3,7 @@ import { expectTypeOf } from 'expect-type'
|
||||||
import sinon from 'sinon'
|
import sinon from 'sinon'
|
||||||
import { z } from 'zod'
|
import { z } from 'zod'
|
||||||
|
|
||||||
import { OutputValidationError } from '@/errors'
|
import { OutputValidationError, TemplateValidationError } from '@/errors'
|
||||||
import { OpenAIChatModel } from '@/llms/openai'
|
import { OpenAIChatModel } from '@/llms/openai'
|
||||||
|
|
||||||
import { createTestAgenticRuntime } from './_utils'
|
import { createTestAgenticRuntime } from './_utils'
|
||||||
|
@ -128,3 +128,61 @@ test('OpenAIChatModel ⇒ retry logic', async (t) => {
|
||||||
})
|
})
|
||||||
t.is(fakeCall.callCount, 3)
|
t.is(fakeCall.callCount, 3)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
test('OpenAIChatModel ⇒ template variables', async (t) => {
|
||||||
|
t.timeout(2 * 60 * 1000)
|
||||||
|
const agentic = createTestAgenticRuntime()
|
||||||
|
|
||||||
|
const query = agentic
|
||||||
|
.gpt3(`Give me {{numFacts}} random facts about {{topic}}`)
|
||||||
|
.input(
|
||||||
|
z.object({
|
||||||
|
topic: z.string(),
|
||||||
|
numFacts: z.number().int().default(5)
|
||||||
|
})
|
||||||
|
)
|
||||||
|
.output(z.object({ facts: z.array(z.string()) }))
|
||||||
|
.modelParams({ temperature: 0.5 })
|
||||||
|
|
||||||
|
const res0 = await query.call({ topic: 'cats' })
|
||||||
|
|
||||||
|
t.true(Array.isArray(res0.facts))
|
||||||
|
t.is(res0.facts.length, 5)
|
||||||
|
expectTypeOf(res0).toMatchTypeOf<{ facts: string[] }>()
|
||||||
|
|
||||||
|
for (const fact of res0.facts) {
|
||||||
|
t.true(typeof fact === 'string')
|
||||||
|
}
|
||||||
|
|
||||||
|
const res1 = await query.call({ topic: 'dogs', numFacts: 2 })
|
||||||
|
|
||||||
|
t.true(Array.isArray(res1.facts))
|
||||||
|
t.is(res1.facts.length, 2)
|
||||||
|
expectTypeOf(res1).toMatchTypeOf<{ facts: string[] }>()
|
||||||
|
|
||||||
|
for (const fact of res1.facts) {
|
||||||
|
t.true(typeof fact === 'string')
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
test.only('OpenAIChatModel ⇒ missing template variable', async (t) => {
|
||||||
|
t.timeout(2 * 60 * 1000)
|
||||||
|
const agentic = createTestAgenticRuntime()
|
||||||
|
|
||||||
|
const builder = agentic
|
||||||
|
.gpt3(`Give me {{numFacts}} random facts about {{topic}}`)
|
||||||
|
.input(
|
||||||
|
z.object({
|
||||||
|
topic: z.string(),
|
||||||
|
numFacts: z.number().int().default(5).optional()
|
||||||
|
})
|
||||||
|
)
|
||||||
|
.output(z.object({ facts: z.array(z.string()) }))
|
||||||
|
.modelParams({ temperature: 0.5 })
|
||||||
|
|
||||||
|
await t.throwsAsync(() => builder.call({ topic: 'cats' }), {
|
||||||
|
instanceOf: TemplateValidationError,
|
||||||
|
name: 'TemplateValidationError',
|
||||||
|
message: 'Template error: "numFacts" not defined in input - 1:10'
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
Ładowanie…
Reference in New Issue