feat: add support for openai's structured output generation to @agentic/core

pull/658/head
Travis Fischer 2024-08-07 04:52:49 -05:00
rodzic e3a409df67
commit f4b79d69b5
14 zmienionych plików z 1548 dodań i 188 usunięć

Wyświetl plik

@@ -10,6 +10,7 @@ jobs:
fail-fast: true
matrix:
node-version:
- 18
- 20
- 21
- 22
@@ -27,7 +28,7 @@ jobs:
uses: pnpm/action-setup@v3
id: pnpm-install
with:
version: 9.6.0
version: 9.7.0
run_install: false
- name: Get pnpm store directory

Wyświetl plik

@@ -12,6 +12,7 @@ async function main() {
})
const result = await extractObject({
name: 'extract-user',
chatFn: chatModel.run.bind(chatModel),
params: {
messages: [
@@ -25,7 +26,8 @@ async function main() {
name: z.string(),
age: z.number(),
location: z.string().optional()
})
}),
strict: true
})
console.log(result)

Wyświetl plik

@@ -40,11 +40,11 @@
"jsonrepair": "^3.6.1",
"ky": "^1.5.0",
"normalize-url": "^8.0.1",
"openai-zod-to-json-schema": "^1.0.0",
"p-map": "^7.0.2",
"p-throttle": "^6.1.0",
"quick-lru": "^7.0.0",
"type-fest": "^4.21.0",
"zod-to-json-schema": "^3.23.2",
"zod-validation-error": "^3.3.0"
},
"peerDependencies": {
@@ -52,7 +52,7 @@
},
"devDependencies": {
"@agentic/tsconfig": "workspace:*",
"openai-fetch": "^2.0.4"
"openai-fetch": "^2.1.0"
},
"publishConfig": {
"access": "public"

Wyświetl plik

@@ -10,15 +10,42 @@ import { asSchema, augmentSystemMessageWithJsonSchema } from './schema'
import { getErrorMessage } from './utils'
export type AIChainParams<Result extends types.AIChainResult = string> = {
/** Name of the chain */
name: string
/** Chat completions function */
chatFn: types.ChatFn
/** Description of the chain */
description?: string
/** Optional chat completion params */
params?: types.Simplify<
Partial<Omit<types.ChatParams, 'tools' | 'functions'>>
>
/** Optional tools */
tools?: types.AIFunctionLike[]
/** Optional response schema */
schema?: z.ZodType<Result> | types.Schema<Result>
/**
* Whether or not the response schema should be treated as strict for
* constrained structured output generation.
*/
strict?: boolean
/** Max number of LLM calls to allow */
maxCalls?: number
/** Max number of retries to allow */
maxRetries?: number
/** Max concurrency when invoking tool calls */
toolCallConcurrency?: number
/** Whether or not to inject the schema into the context */
injectSchemaIntoSystemMessage?: boolean
}
@@ -38,6 +65,8 @@ export type AIChainParams<Result extends types.AIChainResult = string> = {
* exceeds `maxCalls` (`maxCalls` is expected to be >= `maxRetries`).
*/
export function createAIChain<Result extends types.AIChainResult = string>({
name,
description,
chatFn,
params,
schema: rawSchema,
@@ -45,13 +74,26 @@ export function createAIChain<Result extends types.AIChainResult = string>({
maxCalls = 5,
maxRetries = 2,
toolCallConcurrency = 8,
injectSchemaIntoSystemMessage = true
injectSchemaIntoSystemMessage = false,
strict = false
}: AIChainParams<Result>): types.AIChain<Result> {
const functionSet = new AIFunctionSet(tools)
const schema = rawSchema ? asSchema(rawSchema, { strict }) : undefined
const defaultParams: Partial<types.ChatParams> | undefined =
rawSchema && !functionSet.size
schema && !functionSet.size
? {
response_format: { type: 'json_object' }
response_format: strict
? {
type: 'json_schema',
json_schema: {
name,
description,
strict,
schema: schema.jsonSchema
}
}
: { type: 'json_object' }
}
: undefined
@@ -77,8 +119,6 @@ export function createAIChain<Result extends types.AIChainResult = string>({
throw new Error('AIChain error: "messages" is empty')
}
const schema = rawSchema ? asSchema(rawSchema) : undefined
if (schema && injectSchemaIntoSystemMessage) {
const lastSystemMessageIndex = messages.findLastIndex(Msg.isSystem)
const lastSystemMessageContent =
@@ -101,6 +141,7 @@ export function createAIChain<Result extends types.AIChainResult = string>({
do {
++numCalls
const response = await chatFn({
...modelParams,
messages,
@@ -150,7 +191,9 @@ export function createAIChain<Result extends types.AIChainResult = string>({
'Function calls are not supported; expected tool call'
)
} else if (Msg.isAssistant(message)) {
if (schema && schema.validate) {
if (message.refusal) {
throw new AbortError(`Model refusal: ${message.refusal}`)
} else if (schema && schema.validate) {
const result = schema.validate(message.content)
if (result.success) {
@@ -177,7 +220,7 @@ export function createAIChain<Result extends types.AIChainResult = string>({
if (numErrors > maxRetries) {
throw new Error(
`Chain failed after ${numErrors} errors: ${err.message}`,
`Chain ${name} failed after ${numErrors} errors: ${err.message}`,
{
cause: err
}
@@ -186,6 +229,8 @@ export function createAIChain<Result extends types.AIChainResult = string>({
}
} while (numCalls < maxCalls)
throw new Error(`Chain aborted after reaching max ${maxCalls} calls`)
throw new Error(
`Chain "${name}" aborted after reaching max ${maxCalls} calls`
)
}
}

Wyświetl plik

@@ -1,4 +1,4 @@
import { describe, expect, it } from 'vitest'
import { describe, expect, test } from 'vitest'
import { z } from 'zod'
import { createAIFunction } from './create-ai-function'
@@ -18,7 +18,7 @@ const fullName = createAIFunction(
)
describe('createAIFunction()', () => {
it('exposes OpenAI function calling spec', () => {
test('exposes OpenAI function calling spec', () => {
expect(fullName.spec.name).toEqual('fullName')
expect(fullName.spec.description).toEqual(
'Returns the full name of a person.'
@@ -34,7 +34,7 @@ describe('createAIFunction()', () => {
})
})
it('executes the function', async () => {
test('executes the function', async () => {
expect(await fullName('{"first": "John", "last": "Doe"}')).toEqual(
'John Doe'
)

Wyświetl plik

@@ -22,6 +22,11 @@ export function createAIFunction<InputSchema extends z.ZodObject<any>, Output>(
description?: string
/** Zod schema for the arguments string. */
inputSchema: InputSchema
/**
* Whether or not to enable structured output generation based on the given
* zod schema.
*/
strict?: boolean
},
/** Implementation of the function to call with the parsed arguments. */
implementation: (params: z.infer<InputSchema>) => types.MaybePromise<Output>
@@ -59,12 +64,15 @@ export function createAIFunction<InputSchema extends z.ZodObject<any>, Output>(
return implementation(parsedInput)
}
const strict = !!spec.strict
aiFunction.inputSchema = spec.inputSchema
aiFunction.parseInput = parseInput
aiFunction.spec = {
name: spec.name,
description: spec.description?.trim() ?? '',
parameters: zodToJsonSchema(spec.inputSchema)
parameters: zodToJsonSchema(spec.inputSchema, { strict }),
strict
}
aiFunction.impl = implementation

Wyświetl plik

@@ -10,6 +10,7 @@ export interface PrivateAIFunctionMetadata {
description: string
inputSchema: z.AnyZodObject
methodName: string
strict?: boolean
}
// Polyfill for `Symbol.metadata`
@@ -69,11 +70,13 @@ export function aiFunction<
>({
name,
description,
inputSchema
inputSchema,
strict
}: {
name?: string
description: string
inputSchema: InputSchema
strict?: boolean
}) {
return (
_targetMethod: (
@@ -99,7 +102,8 @@ export function aiFunction<
name: name ?? methodName,
description,
inputSchema,
methodName
methodName,
strict
})
context.addInitializer(function () {

Wyświetl plik

@@ -1,11 +1,11 @@
import type * as OpenAI from 'openai-fetch'
import { describe, expect, expectTypeOf, it } from 'vitest'
import { describe, expect, expectTypeOf, test } from 'vitest'
import type * as types from './types'
import { Msg } from './message'
describe('Msg', () => {
it('creates a message and fixes indentation', () => {
test('creates a message and fixes indentation', () => {
const msgContent = `
Hello, World!
`
@@ -14,7 +14,7 @@ describe('Msg', () => {
expect(msg.content).toEqual('Hello, World!')
})
it('supports disabling indentation fixing', () => {
test('supports disabling indentation fixing', () => {
const msgContent = `
Hello, World!
`
@@ -22,7 +22,7 @@ describe('Msg', () => {
expect(msg.content).toEqual('\n Hello, World!\n ')
})
it('handles tool calls request', () => {
test('handles tool calls request', () => {
const msg = Msg.toolCall([
{
id: 'fake-tool-call-id',
@@ -37,13 +37,13 @@ describe('Msg', () => {
expect(Msg.isToolCall(msg)).toBe(true)
})
it('handles tool call response', () => {
test('handles tool call response', () => {
const msg = Msg.toolResult('Hello, World!', 'fake-tool-call-id')
expectTypeOf(msg).toMatchTypeOf<types.Msg.ToolResult>()
expect(Msg.isToolResult(msg)).toBe(true)
})
it('prompt message types should interop with openai-fetch message types', () => {
test('prompt message types should interop with openai-fetch message types', () => {
expectTypeOf({} as OpenAI.ChatMessage).toMatchTypeOf<types.Msg>()
expectTypeOf({} as types.Msg).toMatchTypeOf<OpenAI.ChatMessage>()
expectTypeOf({} as types.Msg.System).toMatchTypeOf<OpenAI.ChatMessage>()

Wyświetl plik

@@ -7,10 +7,17 @@ import { cleanStringForModel, stringifyForModel } from './utils'
*/
export interface Msg {
/**
* The contents of the message. `content` is required for all messages, and
* may be null for assistant messages with function calls.
* The contents of the message. `content` may be null for assistant messages
* with function calls or `undefined` for assistant messages if a `refusal`
* was given by the model.
*/
content: string | null
content?: string | null
/**
* The reason the model refused to generate this message or `undefined` if the
* message was generated successfully.
*/
refusal?: string | null
/**
* The role of the messages author. One of `system`, `user`, `assistant`,
@@ -95,7 +102,8 @@ export namespace Msg {
export type Assistant = {
role: 'assistant'
name?: string
content: string
content?: string
refusal?: string
}
/** Message with arguments to call a function. */
@@ -262,7 +270,7 @@ export namespace Msg {
return Msg.toolCall(msg.tool_calls)
} else if (msg.content === null && msg.function_call != null) {
return Msg.funcCall(msg.function_call)
} else if (msg.content !== null) {
} else if (msg.content !== null && msg.content !== undefined) {
return Msg.assistant(msg.content)
} else {
// @TODO: probably don't want to error here

Wyświetl plik

@@ -59,9 +59,10 @@ export function isZodSchema(value: unknown): value is z.ZodType {
}
export function asSchema<TData>(
schema: z.Schema<TData> | Schema<TData>
schema: z.Schema<TData> | Schema<TData>,
opts: { strict?: boolean } = {}
): Schema<TData> {
return isSchema(schema) ? schema : createSchemaFromZodSchema(schema)
return isSchema(schema) ? schema : createSchemaFromZodSchema(schema, opts)
}
/**
@@ -84,9 +85,10 @@ export function createSchema<TData = unknown>(
}
export function createSchemaFromZodSchema<TData>(
zodSchema: z.Schema<TData>
zodSchema: z.Schema<TData>,
opts: { strict?: boolean } = {}
): Schema<TData> {
return createSchema(zodToJsonSchema(zodSchema), {
return createSchema(zodToJsonSchema(zodSchema, opts), {
validate: (value) => {
return safeParseStructuredOutput(value, zodSchema)
}

Wyświetl plik

@@ -9,7 +9,7 @@ export type { Msg } from './message'
export type { Schema } from './schema'
export type { KyInstance } from 'ky'
export type { ThrottledFunction } from 'p-throttle'
export type { SetRequired, Simplify } from 'type-fest'
export type { SetOptional, SetRequired, Simplify } from 'type-fest'
export type Nullable<T> = T | null
@@ -33,6 +33,15 @@ export interface AIFunctionSpec {
/** JSON schema spec of the function's input parameters */
parameters: JSONSchema
/**
* Whether to enable strict schema adherence when generating the function
* parameters. If set to true, the model will always follow the exact schema
* defined in the `schema` field. Only a subset of JSON Schema is supported
* when `strict` is `true`. To learn more, read the
* [Structured Outputs guide](https://platform.openai.com/docs/guides/structured-outputs).
*/
strict?: boolean
}
export interface AIToolSpec {
@@ -102,7 +111,17 @@ export interface ChatParams {
max_tokens?: number
presence_penalty?: number
frequency_penalty?: number
response_format?: { type: 'text' | 'json_object' }
response_format?:
| {
type: 'text'
}
| {
type: 'json_object'
}
| {
type: 'json_schema'
json_schema: ResponseFormatJSONSchema
}
seed?: number
stop?: string | null | Array<string>
temperature?: number
@@ -111,6 +130,34 @@ export interface ChatParams {
user?: string
}
export interface ResponseFormatJSONSchema {
/**
* The name of the response format. Must be a-z, A-Z, 0-9, or contain
* underscores and dashes, with a maximum length of 64.
*/
name: string
/**
* A description of what the response format is for, used by the model to
* determine how to respond in the format.
*/
description?: string
/**
* The schema for the response format, described as a JSON Schema object.
*/
schema?: JSONSchema
/**
* Whether to enable strict schema adherence when generating the output. If
* set to true, the model will always follow the exact schema defined in the
* `schema` field. Only a subset of JSON Schema is supported when `strict`
* is `true`. To learn more, read the
* [Structured Outputs guide](https://platform.openai.com/docs/guides/structured-outputs).
*/
strict?: boolean
}
/** An OpenAI-compatible chat completions API */
export type ChatFn = (
params: Simplify<SetOptional<ChatParams, 'model'>>

Wyświetl plik

@@ -1,10 +1,10 @@
import { describe, expect, it } from 'vitest'
import { describe, expect, test } from 'vitest'
import { z } from 'zod'
import { zodToJsonSchema } from './zod-to-json-schema'
describe('zodToJsonSchema', () => {
it('handles basic objects', () => {
test('handles basic objects', () => {
const params = zodToJsonSchema(
z.object({
name: z.string().min(1).describe('Name of the person'),
@@ -30,7 +30,7 @@ describe('zodToJsonSchema', () => {
})
})
it('handles enums and unions', () => {
test('handles enums and unions', () => {
const params = zodToJsonSchema(
z.object({
name: z.string().min(1).describe('Name of the person'),

Wyświetl plik

@@ -1,13 +1,23 @@
import type { z } from 'zod'
import { zodToJsonSchema as zodToJsonSchemaImpl } from 'zod-to-json-schema'
import { zodToJsonSchema as zodToJsonSchemaImpl } from 'openai-zod-to-json-schema'
import type * as types from './types'
import { omit } from './utils'
/** Generate a JSON Schema from a Zod schema. */
export function zodToJsonSchema(schema: z.ZodType): types.JSONSchema {
export function zodToJsonSchema(
schema: z.ZodType,
{
strict = false
}: {
strict?: boolean
} = {}
): types.JSONSchema {
return omit(
zodToJsonSchemaImpl(schema, { $refStrategy: 'none' }),
zodToJsonSchemaImpl(schema, {
$refStrategy: 'none',
openaiStrictMode: strict
}),
'$schema',
'default',
'definitions',

Plik diff jest za duży Load Diff