feat: add support for openai's structured output generation to @agentic/core

pull/658/head
Travis Fischer 2024-08-07 04:52:49 -05:00
rodzic e3a409df67
commit f4b79d69b5
14 zmienionych plików z 1548 dodań i 188 usunięć

Wyświetl plik

@ -10,6 +10,7 @@ jobs:
fail-fast: true fail-fast: true
matrix: matrix:
node-version: node-version:
- 18
- 20 - 20
- 21 - 21
- 22 - 22
@ -27,7 +28,7 @@ jobs:
uses: pnpm/action-setup@v3 uses: pnpm/action-setup@v3
id: pnpm-install id: pnpm-install
with: with:
version: 9.6.0 version: 9.7.0
run_install: false run_install: false
- name: Get pnpm store directory - name: Get pnpm store directory

Wyświetl plik

@ -12,6 +12,7 @@ async function main() {
}) })
const result = await extractObject({ const result = await extractObject({
name: 'extract-user',
chatFn: chatModel.run.bind(chatModel), chatFn: chatModel.run.bind(chatModel),
params: { params: {
messages: [ messages: [
@ -25,7 +26,8 @@ async function main() {
name: z.string(), name: z.string(),
age: z.number(), age: z.number(),
location: z.string().optional() location: z.string().optional()
}) }),
strict: true
}) })
console.log(result) console.log(result)

Wyświetl plik

@ -40,11 +40,11 @@
"jsonrepair": "^3.6.1", "jsonrepair": "^3.6.1",
"ky": "^1.5.0", "ky": "^1.5.0",
"normalize-url": "^8.0.1", "normalize-url": "^8.0.1",
"openai-zod-to-json-schema": "^1.0.0",
"p-map": "^7.0.2", "p-map": "^7.0.2",
"p-throttle": "^6.1.0", "p-throttle": "^6.1.0",
"quick-lru": "^7.0.0", "quick-lru": "^7.0.0",
"type-fest": "^4.21.0", "type-fest": "^4.21.0",
"zod-to-json-schema": "^3.23.2",
"zod-validation-error": "^3.3.0" "zod-validation-error": "^3.3.0"
}, },
"peerDependencies": { "peerDependencies": {
@ -52,7 +52,7 @@
}, },
"devDependencies": { "devDependencies": {
"@agentic/tsconfig": "workspace:*", "@agentic/tsconfig": "workspace:*",
"openai-fetch": "^2.0.4" "openai-fetch": "^2.1.0"
}, },
"publishConfig": { "publishConfig": {
"access": "public" "access": "public"

Wyświetl plik

@ -10,15 +10,42 @@ import { asSchema, augmentSystemMessageWithJsonSchema } from './schema'
import { getErrorMessage } from './utils' import { getErrorMessage } from './utils'
export type AIChainParams<Result extends types.AIChainResult = string> = { export type AIChainParams<Result extends types.AIChainResult = string> = {
/** Name of the chain */
name: string
/** Chat completions function */
chatFn: types.ChatFn chatFn: types.ChatFn
/** Description of the chain */
description?: string
/** Optional chat completion params */
params?: types.Simplify< params?: types.Simplify<
Partial<Omit<types.ChatParams, 'tools' | 'functions'>> Partial<Omit<types.ChatParams, 'tools' | 'functions'>>
> >
/** Optional tools */
tools?: types.AIFunctionLike[] tools?: types.AIFunctionLike[]
/** Optional response schema */
schema?: z.ZodType<Result> | types.Schema<Result> schema?: z.ZodType<Result> | types.Schema<Result>
/**
* Whether or not the response schema should be treated as strict for
* constrained structured output generation.
*/
strict?: boolean
/** Max number of LLM calls to allow */
maxCalls?: number maxCalls?: number
/** Max number of retries to allow */
maxRetries?: number maxRetries?: number
/** Max concurrency when invoking tool calls */
toolCallConcurrency?: number toolCallConcurrency?: number
/** Whether or not to inject the schema into the context */
injectSchemaIntoSystemMessage?: boolean injectSchemaIntoSystemMessage?: boolean
} }
@ -38,6 +65,8 @@ export type AIChainParams<Result extends types.AIChainResult = string> = {
* exceeds `maxCalls` (`maxCalls` is expected to be >= `maxRetries`). * exceeds `maxCalls` (`maxCalls` is expected to be >= `maxRetries`).
*/ */
export function createAIChain<Result extends types.AIChainResult = string>({ export function createAIChain<Result extends types.AIChainResult = string>({
name,
description,
chatFn, chatFn,
params, params,
schema: rawSchema, schema: rawSchema,
@ -45,13 +74,26 @@ export function createAIChain<Result extends types.AIChainResult = string>({
maxCalls = 5, maxCalls = 5,
maxRetries = 2, maxRetries = 2,
toolCallConcurrency = 8, toolCallConcurrency = 8,
injectSchemaIntoSystemMessage = true injectSchemaIntoSystemMessage = false,
strict = false
}: AIChainParams<Result>): types.AIChain<Result> { }: AIChainParams<Result>): types.AIChain<Result> {
const functionSet = new AIFunctionSet(tools) const functionSet = new AIFunctionSet(tools)
const schema = rawSchema ? asSchema(rawSchema, { strict }) : undefined
const defaultParams: Partial<types.ChatParams> | undefined = const defaultParams: Partial<types.ChatParams> | undefined =
rawSchema && !functionSet.size schema && !functionSet.size
? { ? {
response_format: { type: 'json_object' } response_format: strict
? {
type: 'json_schema',
json_schema: {
name,
description,
strict,
schema: schema.jsonSchema
}
}
: { type: 'json_object' }
} }
: undefined : undefined
@ -77,8 +119,6 @@ export function createAIChain<Result extends types.AIChainResult = string>({
throw new Error('AIChain error: "messages" is empty') throw new Error('AIChain error: "messages" is empty')
} }
const schema = rawSchema ? asSchema(rawSchema) : undefined
if (schema && injectSchemaIntoSystemMessage) { if (schema && injectSchemaIntoSystemMessage) {
const lastSystemMessageIndex = messages.findLastIndex(Msg.isSystem) const lastSystemMessageIndex = messages.findLastIndex(Msg.isSystem)
const lastSystemMessageContent = const lastSystemMessageContent =
@ -101,6 +141,7 @@ export function createAIChain<Result extends types.AIChainResult = string>({
do { do {
++numCalls ++numCalls
const response = await chatFn({ const response = await chatFn({
...modelParams, ...modelParams,
messages, messages,
@ -150,7 +191,9 @@ export function createAIChain<Result extends types.AIChainResult = string>({
'Function calls are not supported; expected tool call' 'Function calls are not supported; expected tool call'
) )
} else if (Msg.isAssistant(message)) { } else if (Msg.isAssistant(message)) {
if (schema && schema.validate) { if (message.refusal) {
throw new AbortError(`Model refusal: ${message.refusal}`)
} else if (schema && schema.validate) {
const result = schema.validate(message.content) const result = schema.validate(message.content)
if (result.success) { if (result.success) {
@ -177,7 +220,7 @@ export function createAIChain<Result extends types.AIChainResult = string>({
if (numErrors > maxRetries) { if (numErrors > maxRetries) {
throw new Error( throw new Error(
`Chain failed after ${numErrors} errors: ${err.message}`, `Chain ${name} failed after ${numErrors} errors: ${err.message}`,
{ {
cause: err cause: err
} }
@ -186,6 +229,8 @@ export function createAIChain<Result extends types.AIChainResult = string>({
} }
} while (numCalls < maxCalls) } while (numCalls < maxCalls)
throw new Error(`Chain aborted after reaching max ${maxCalls} calls`) throw new Error(
`Chain "${name}" aborted after reaching max ${maxCalls} calls`
)
} }
} }

Wyświetl plik

@ -1,4 +1,4 @@
import { describe, expect, it } from 'vitest' import { describe, expect, test } from 'vitest'
import { z } from 'zod' import { z } from 'zod'
import { createAIFunction } from './create-ai-function' import { createAIFunction } from './create-ai-function'
@ -18,7 +18,7 @@ const fullName = createAIFunction(
) )
describe('createAIFunction()', () => { describe('createAIFunction()', () => {
it('exposes OpenAI function calling spec', () => { test('exposes OpenAI function calling spec', () => {
expect(fullName.spec.name).toEqual('fullName') expect(fullName.spec.name).toEqual('fullName')
expect(fullName.spec.description).toEqual( expect(fullName.spec.description).toEqual(
'Returns the full name of a person.' 'Returns the full name of a person.'
@ -34,7 +34,7 @@ describe('createAIFunction()', () => {
}) })
}) })
it('executes the function', async () => { test('executes the function', async () => {
expect(await fullName('{"first": "John", "last": "Doe"}')).toEqual( expect(await fullName('{"first": "John", "last": "Doe"}')).toEqual(
'John Doe' 'John Doe'
) )

Wyświetl plik

@ -22,6 +22,11 @@ export function createAIFunction<InputSchema extends z.ZodObject<any>, Output>(
description?: string description?: string
/** Zod schema for the arguments string. */ /** Zod schema for the arguments string. */
inputSchema: InputSchema inputSchema: InputSchema
/**
* Whether or not to enable structured output generation based on the given
* zod schema.
*/
strict?: boolean
}, },
/** Implementation of the function to call with the parsed arguments. */ /** Implementation of the function to call with the parsed arguments. */
implementation: (params: z.infer<InputSchema>) => types.MaybePromise<Output> implementation: (params: z.infer<InputSchema>) => types.MaybePromise<Output>
@ -59,12 +64,15 @@ export function createAIFunction<InputSchema extends z.ZodObject<any>, Output>(
return implementation(parsedInput) return implementation(parsedInput)
} }
const strict = !!spec.strict
aiFunction.inputSchema = spec.inputSchema aiFunction.inputSchema = spec.inputSchema
aiFunction.parseInput = parseInput aiFunction.parseInput = parseInput
aiFunction.spec = { aiFunction.spec = {
name: spec.name, name: spec.name,
description: spec.description?.trim() ?? '', description: spec.description?.trim() ?? '',
parameters: zodToJsonSchema(spec.inputSchema) parameters: zodToJsonSchema(spec.inputSchema, { strict }),
strict
} }
aiFunction.impl = implementation aiFunction.impl = implementation

Wyświetl plik

@ -10,6 +10,7 @@ export interface PrivateAIFunctionMetadata {
description: string description: string
inputSchema: z.AnyZodObject inputSchema: z.AnyZodObject
methodName: string methodName: string
strict?: boolean
} }
// Polyfill for `Symbol.metadata` // Polyfill for `Symbol.metadata`
@ -69,11 +70,13 @@ export function aiFunction<
>({ >({
name, name,
description, description,
inputSchema inputSchema,
strict
}: { }: {
name?: string name?: string
description: string description: string
inputSchema: InputSchema inputSchema: InputSchema
strict?: boolean
}) { }) {
return ( return (
_targetMethod: ( _targetMethod: (
@ -99,7 +102,8 @@ export function aiFunction<
name: name ?? methodName, name: name ?? methodName,
description, description,
inputSchema, inputSchema,
methodName methodName,
strict
}) })
context.addInitializer(function () { context.addInitializer(function () {

Wyświetl plik

@ -1,11 +1,11 @@
import type * as OpenAI from 'openai-fetch' import type * as OpenAI from 'openai-fetch'
import { describe, expect, expectTypeOf, it } from 'vitest' import { describe, expect, expectTypeOf, test } from 'vitest'
import type * as types from './types' import type * as types from './types'
import { Msg } from './message' import { Msg } from './message'
describe('Msg', () => { describe('Msg', () => {
it('creates a message and fixes indentation', () => { test('creates a message and fixes indentation', () => {
const msgContent = ` const msgContent = `
Hello, World! Hello, World!
` `
@ -14,7 +14,7 @@ describe('Msg', () => {
expect(msg.content).toEqual('Hello, World!') expect(msg.content).toEqual('Hello, World!')
}) })
it('supports disabling indentation fixing', () => { test('supports disabling indentation fixing', () => {
const msgContent = ` const msgContent = `
Hello, World! Hello, World!
` `
@ -22,7 +22,7 @@ describe('Msg', () => {
expect(msg.content).toEqual('\n Hello, World!\n ') expect(msg.content).toEqual('\n Hello, World!\n ')
}) })
it('handles tool calls request', () => { test('handles tool calls request', () => {
const msg = Msg.toolCall([ const msg = Msg.toolCall([
{ {
id: 'fake-tool-call-id', id: 'fake-tool-call-id',
@ -37,13 +37,13 @@ describe('Msg', () => {
expect(Msg.isToolCall(msg)).toBe(true) expect(Msg.isToolCall(msg)).toBe(true)
}) })
it('handles tool call response', () => { test('handles tool call response', () => {
const msg = Msg.toolResult('Hello, World!', 'fake-tool-call-id') const msg = Msg.toolResult('Hello, World!', 'fake-tool-call-id')
expectTypeOf(msg).toMatchTypeOf<types.Msg.ToolResult>() expectTypeOf(msg).toMatchTypeOf<types.Msg.ToolResult>()
expect(Msg.isToolResult(msg)).toBe(true) expect(Msg.isToolResult(msg)).toBe(true)
}) })
it('prompt message types should interop with openai-fetch message types', () => { test('prompt message types should interop with openai-fetch message types', () => {
expectTypeOf({} as OpenAI.ChatMessage).toMatchTypeOf<types.Msg>() expectTypeOf({} as OpenAI.ChatMessage).toMatchTypeOf<types.Msg>()
expectTypeOf({} as types.Msg).toMatchTypeOf<OpenAI.ChatMessage>() expectTypeOf({} as types.Msg).toMatchTypeOf<OpenAI.ChatMessage>()
expectTypeOf({} as types.Msg.System).toMatchTypeOf<OpenAI.ChatMessage>() expectTypeOf({} as types.Msg.System).toMatchTypeOf<OpenAI.ChatMessage>()

Wyświetl plik

@ -7,10 +7,17 @@ import { cleanStringForModel, stringifyForModel } from './utils'
*/ */
export interface Msg { export interface Msg {
/** /**
* The contents of the message. `content` is required for all messages, and * The contents of the message. `content` may be null for assistant messages
* may be null for assistant messages with function calls. * with function calls or `undefined` for assistant messages if a `refusal`
* was given by the model.
*/ */
content: string | null content?: string | null
/**
* The reason the model refused to generate this message or `undefined` if the
* message was generated successfully.
*/
refusal?: string | null
/** /**
* The role of the messages author. One of `system`, `user`, `assistant`, * The role of the messages author. One of `system`, `user`, `assistant`,
@ -95,7 +102,8 @@ export namespace Msg {
export type Assistant = { export type Assistant = {
role: 'assistant' role: 'assistant'
name?: string name?: string
content: string content?: string
refusal?: string
} }
/** Message with arguments to call a function. */ /** Message with arguments to call a function. */
@ -262,7 +270,7 @@ export namespace Msg {
return Msg.toolCall(msg.tool_calls) return Msg.toolCall(msg.tool_calls)
} else if (msg.content === null && msg.function_call != null) { } else if (msg.content === null && msg.function_call != null) {
return Msg.funcCall(msg.function_call) return Msg.funcCall(msg.function_call)
} else if (msg.content !== null) { } else if (msg.content !== null && msg.content !== undefined) {
return Msg.assistant(msg.content) return Msg.assistant(msg.content)
} else { } else {
// @TODO: probably don't want to error here // @TODO: probably don't want to error here

Wyświetl plik

@ -59,9 +59,10 @@ export function isZodSchema(value: unknown): value is z.ZodType {
} }
export function asSchema<TData>( export function asSchema<TData>(
schema: z.Schema<TData> | Schema<TData> schema: z.Schema<TData> | Schema<TData>,
opts: { strict?: boolean } = {}
): Schema<TData> { ): Schema<TData> {
return isSchema(schema) ? schema : createSchemaFromZodSchema(schema) return isSchema(schema) ? schema : createSchemaFromZodSchema(schema, opts)
} }
/** /**
@ -84,9 +85,10 @@ export function createSchema<TData = unknown>(
} }
export function createSchemaFromZodSchema<TData>( export function createSchemaFromZodSchema<TData>(
zodSchema: z.Schema<TData> zodSchema: z.Schema<TData>,
opts: { strict?: boolean } = {}
): Schema<TData> { ): Schema<TData> {
return createSchema(zodToJsonSchema(zodSchema), { return createSchema(zodToJsonSchema(zodSchema, opts), {
validate: (value) => { validate: (value) => {
return safeParseStructuredOutput(value, zodSchema) return safeParseStructuredOutput(value, zodSchema)
} }

Wyświetl plik

@ -9,7 +9,7 @@ export type { Msg } from './message'
export type { Schema } from './schema' export type { Schema } from './schema'
export type { KyInstance } from 'ky' export type { KyInstance } from 'ky'
export type { ThrottledFunction } from 'p-throttle' export type { ThrottledFunction } from 'p-throttle'
export type { SetRequired, Simplify } from 'type-fest' export type { SetOptional, SetRequired, Simplify } from 'type-fest'
export type Nullable<T> = T | null export type Nullable<T> = T | null
@ -33,6 +33,15 @@ export interface AIFunctionSpec {
/** JSON schema spec of the function's input parameters */ /** JSON schema spec of the function's input parameters */
parameters: JSONSchema parameters: JSONSchema
/**
* Whether to enable strict schema adherence when generating the function
* parameters. If set to true, the model will always follow the exact schema
* defined in the `schema` field. Only a subset of JSON Schema is supported
* when `strict` is `true`. To learn more, read the
* [Structured Outputs guide](https://platform.openai.com/docs/guides/structured-outputs).
*/
strict?: boolean
} }
export interface AIToolSpec { export interface AIToolSpec {
@ -102,7 +111,17 @@ export interface ChatParams {
max_tokens?: number max_tokens?: number
presence_penalty?: number presence_penalty?: number
frequency_penalty?: number frequency_penalty?: number
response_format?: { type: 'text' | 'json_object' } response_format?:
| {
type: 'text'
}
| {
type: 'json_object'
}
| {
type: 'json_schema'
json_schema: ResponseFormatJSONSchema
}
seed?: number seed?: number
stop?: string | null | Array<string> stop?: string | null | Array<string>
temperature?: number temperature?: number
@ -111,6 +130,34 @@ export interface ChatParams {
user?: string user?: string
} }
export interface ResponseFormatJSONSchema {
/**
* The name of the response format. Must be a-z, A-Z, 0-9, or contain
* underscores and dashes, with a maximum length of 64.
*/
name: string
/**
* A description of what the response format is for, used by the model to
* determine how to respond in the format.
*/
description?: string
/**
* The schema for the response format, described as a JSON Schema object.
*/
schema?: JSONSchema
/**
* Whether to enable strict schema adherence when generating the output. If
* set to true, the model will always follow the exact schema defined in the
* `schema` field. Only a subset of JSON Schema is supported when `strict`
* is `true`. To learn more, read the
* [Structured Outputs guide](https://platform.openai.com/docs/guides/structured-outputs).
*/
strict?: boolean
}
/** An OpenAI-compatible chat completions API */ /** An OpenAI-compatible chat completions API */
export type ChatFn = ( export type ChatFn = (
params: Simplify<SetOptional<ChatParams, 'model'>> params: Simplify<SetOptional<ChatParams, 'model'>>

Wyświetl plik

@ -1,10 +1,10 @@
import { describe, expect, it } from 'vitest' import { describe, expect, test } from 'vitest'
import { z } from 'zod' import { z } from 'zod'
import { zodToJsonSchema } from './zod-to-json-schema' import { zodToJsonSchema } from './zod-to-json-schema'
describe('zodToJsonSchema', () => { describe('zodToJsonSchema', () => {
it('handles basic objects', () => { test('handles basic objects', () => {
const params = zodToJsonSchema( const params = zodToJsonSchema(
z.object({ z.object({
name: z.string().min(1).describe('Name of the person'), name: z.string().min(1).describe('Name of the person'),
@ -30,7 +30,7 @@ describe('zodToJsonSchema', () => {
}) })
}) })
it('handles enums and unions', () => { test('handles enums and unions', () => {
const params = zodToJsonSchema( const params = zodToJsonSchema(
z.object({ z.object({
name: z.string().min(1).describe('Name of the person'), name: z.string().min(1).describe('Name of the person'),

Wyświetl plik

@ -1,13 +1,23 @@
import type { z } from 'zod' import type { z } from 'zod'
import { zodToJsonSchema as zodToJsonSchemaImpl } from 'zod-to-json-schema' import { zodToJsonSchema as zodToJsonSchemaImpl } from 'openai-zod-to-json-schema'
import type * as types from './types' import type * as types from './types'
import { omit } from './utils' import { omit } from './utils'
/** Generate a JSON Schema from a Zod schema. */ /** Generate a JSON Schema from a Zod schema. */
export function zodToJsonSchema(schema: z.ZodType): types.JSONSchema { export function zodToJsonSchema(
schema: z.ZodType,
{
strict = false
}: {
strict?: boolean
} = {}
): types.JSONSchema {
return omit( return omit(
zodToJsonSchemaImpl(schema, { $refStrategy: 'none' }), zodToJsonSchemaImpl(schema, {
$refStrategy: 'none',
openaiStrictMode: strict
}),
'$schema', '$schema',
'default', 'default',
'definitions', 'definitions',

Plik diff jest za duży Load Diff