kopia lustrzana https://github.com/transitive-bullshit/chatgpt-api
feat: add support for openai's structured output generation to @agentic/core
rodzic
e3a409df67
commit
f4b79d69b5
|
@ -10,6 +10,7 @@ jobs:
|
|||
fail-fast: true
|
||||
matrix:
|
||||
node-version:
|
||||
- 18
|
||||
- 20
|
||||
- 21
|
||||
- 22
|
||||
|
@ -27,7 +28,7 @@ jobs:
|
|||
uses: pnpm/action-setup@v3
|
||||
id: pnpm-install
|
||||
with:
|
||||
version: 9.6.0
|
||||
version: 9.7.0
|
||||
run_install: false
|
||||
|
||||
- name: Get pnpm store directory
|
||||
|
|
|
@ -12,6 +12,7 @@ async function main() {
|
|||
})
|
||||
|
||||
const result = await extractObject({
|
||||
name: 'extract-user',
|
||||
chatFn: chatModel.run.bind(chatModel),
|
||||
params: {
|
||||
messages: [
|
||||
|
@ -25,7 +26,8 @@ async function main() {
|
|||
name: z.string(),
|
||||
age: z.number(),
|
||||
location: z.string().optional()
|
||||
})
|
||||
}),
|
||||
strict: true
|
||||
})
|
||||
|
||||
console.log(result)
|
||||
|
|
|
@ -40,11 +40,11 @@
|
|||
"jsonrepair": "^3.6.1",
|
||||
"ky": "^1.5.0",
|
||||
"normalize-url": "^8.0.1",
|
||||
"openai-zod-to-json-schema": "^1.0.0",
|
||||
"p-map": "^7.0.2",
|
||||
"p-throttle": "^6.1.0",
|
||||
"quick-lru": "^7.0.0",
|
||||
"type-fest": "^4.21.0",
|
||||
"zod-to-json-schema": "^3.23.2",
|
||||
"zod-validation-error": "^3.3.0"
|
||||
},
|
||||
"peerDependencies": {
|
||||
|
@ -52,7 +52,7 @@
|
|||
},
|
||||
"devDependencies": {
|
||||
"@agentic/tsconfig": "workspace:*",
|
||||
"openai-fetch": "^2.0.4"
|
||||
"openai-fetch": "^2.1.0"
|
||||
},
|
||||
"publishConfig": {
|
||||
"access": "public"
|
||||
|
|
|
@ -10,15 +10,42 @@ import { asSchema, augmentSystemMessageWithJsonSchema } from './schema'
|
|||
import { getErrorMessage } from './utils'
|
||||
|
||||
export type AIChainParams<Result extends types.AIChainResult = string> = {
|
||||
/** Name of the chain */
|
||||
name: string
|
||||
|
||||
/** Chat completions function */
|
||||
chatFn: types.ChatFn
|
||||
|
||||
/** Description of the chain */
|
||||
description?: string
|
||||
|
||||
/** Optional chat completion params */
|
||||
params?: types.Simplify<
|
||||
Partial<Omit<types.ChatParams, 'tools' | 'functions'>>
|
||||
>
|
||||
|
||||
/** Optional tools */
|
||||
tools?: types.AIFunctionLike[]
|
||||
|
||||
/** Optional response schema */
|
||||
schema?: z.ZodType<Result> | types.Schema<Result>
|
||||
|
||||
/**
|
||||
* Whether or not the response schema should be treated as strict for
|
||||
* constrained structured output generation.
|
||||
*/
|
||||
strict?: boolean
|
||||
|
||||
/** Max number of LLM calls to allow */
|
||||
maxCalls?: number
|
||||
|
||||
/** Max number of retries to allow */
|
||||
maxRetries?: number
|
||||
|
||||
/** Max concurrency when invoking tool calls */
|
||||
toolCallConcurrency?: number
|
||||
|
||||
/** Whether or not to inject the schema into the context */
|
||||
injectSchemaIntoSystemMessage?: boolean
|
||||
}
|
||||
|
||||
|
@ -38,6 +65,8 @@ export type AIChainParams<Result extends types.AIChainResult = string> = {
|
|||
* exceeds `maxCalls` (`maxCalls` is expected to be >= `maxRetries`).
|
||||
*/
|
||||
export function createAIChain<Result extends types.AIChainResult = string>({
|
||||
name,
|
||||
description,
|
||||
chatFn,
|
||||
params,
|
||||
schema: rawSchema,
|
||||
|
@ -45,13 +74,26 @@ export function createAIChain<Result extends types.AIChainResult = string>({
|
|||
maxCalls = 5,
|
||||
maxRetries = 2,
|
||||
toolCallConcurrency = 8,
|
||||
injectSchemaIntoSystemMessage = true
|
||||
injectSchemaIntoSystemMessage = false,
|
||||
strict = false
|
||||
}: AIChainParams<Result>): types.AIChain<Result> {
|
||||
const functionSet = new AIFunctionSet(tools)
|
||||
const schema = rawSchema ? asSchema(rawSchema, { strict }) : undefined
|
||||
|
||||
const defaultParams: Partial<types.ChatParams> | undefined =
|
||||
rawSchema && !functionSet.size
|
||||
schema && !functionSet.size
|
||||
? {
|
||||
response_format: { type: 'json_object' }
|
||||
response_format: strict
|
||||
? {
|
||||
type: 'json_schema',
|
||||
json_schema: {
|
||||
name,
|
||||
description,
|
||||
strict,
|
||||
schema: schema.jsonSchema
|
||||
}
|
||||
}
|
||||
: { type: 'json_object' }
|
||||
}
|
||||
: undefined
|
||||
|
||||
|
@ -77,8 +119,6 @@ export function createAIChain<Result extends types.AIChainResult = string>({
|
|||
throw new Error('AIChain error: "messages" is empty')
|
||||
}
|
||||
|
||||
const schema = rawSchema ? asSchema(rawSchema) : undefined
|
||||
|
||||
if (schema && injectSchemaIntoSystemMessage) {
|
||||
const lastSystemMessageIndex = messages.findLastIndex(Msg.isSystem)
|
||||
const lastSystemMessageContent =
|
||||
|
@ -101,6 +141,7 @@ export function createAIChain<Result extends types.AIChainResult = string>({
|
|||
|
||||
do {
|
||||
++numCalls
|
||||
|
||||
const response = await chatFn({
|
||||
...modelParams,
|
||||
messages,
|
||||
|
@ -150,7 +191,9 @@ export function createAIChain<Result extends types.AIChainResult = string>({
|
|||
'Function calls are not supported; expected tool call'
|
||||
)
|
||||
} else if (Msg.isAssistant(message)) {
|
||||
if (schema && schema.validate) {
|
||||
if (message.refusal) {
|
||||
throw new AbortError(`Model refusal: ${message.refusal}`)
|
||||
} else if (schema && schema.validate) {
|
||||
const result = schema.validate(message.content)
|
||||
|
||||
if (result.success) {
|
||||
|
@ -177,7 +220,7 @@ export function createAIChain<Result extends types.AIChainResult = string>({
|
|||
|
||||
if (numErrors > maxRetries) {
|
||||
throw new Error(
|
||||
`Chain failed after ${numErrors} errors: ${err.message}`,
|
||||
`Chain ${name} failed after ${numErrors} errors: ${err.message}`,
|
||||
{
|
||||
cause: err
|
||||
}
|
||||
|
@ -186,6 +229,8 @@ export function createAIChain<Result extends types.AIChainResult = string>({
|
|||
}
|
||||
} while (numCalls < maxCalls)
|
||||
|
||||
throw new Error(`Chain aborted after reaching max ${maxCalls} calls`)
|
||||
throw new Error(
|
||||
`Chain "${name}" aborted after reaching max ${maxCalls} calls`
|
||||
)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
import { describe, expect, it } from 'vitest'
|
||||
import { describe, expect, test } from 'vitest'
|
||||
import { z } from 'zod'
|
||||
|
||||
import { createAIFunction } from './create-ai-function'
|
||||
|
@ -18,7 +18,7 @@ const fullName = createAIFunction(
|
|||
)
|
||||
|
||||
describe('createAIFunction()', () => {
|
||||
it('exposes OpenAI function calling spec', () => {
|
||||
test('exposes OpenAI function calling spec', () => {
|
||||
expect(fullName.spec.name).toEqual('fullName')
|
||||
expect(fullName.spec.description).toEqual(
|
||||
'Returns the full name of a person.'
|
||||
|
@ -34,7 +34,7 @@ describe('createAIFunction()', () => {
|
|||
})
|
||||
})
|
||||
|
||||
it('executes the function', async () => {
|
||||
test('executes the function', async () => {
|
||||
expect(await fullName('{"first": "John", "last": "Doe"}')).toEqual(
|
||||
'John Doe'
|
||||
)
|
||||
|
|
|
@ -22,6 +22,11 @@ export function createAIFunction<InputSchema extends z.ZodObject<any>, Output>(
|
|||
description?: string
|
||||
/** Zod schema for the arguments string. */
|
||||
inputSchema: InputSchema
|
||||
/**
|
||||
* Whether or not to enable structured output generation based on the given
|
||||
* zod schema.
|
||||
*/
|
||||
strict?: boolean
|
||||
},
|
||||
/** Implementation of the function to call with the parsed arguments. */
|
||||
implementation: (params: z.infer<InputSchema>) => types.MaybePromise<Output>
|
||||
|
@ -59,12 +64,15 @@ export function createAIFunction<InputSchema extends z.ZodObject<any>, Output>(
|
|||
return implementation(parsedInput)
|
||||
}
|
||||
|
||||
const strict = !!spec.strict
|
||||
|
||||
aiFunction.inputSchema = spec.inputSchema
|
||||
aiFunction.parseInput = parseInput
|
||||
aiFunction.spec = {
|
||||
name: spec.name,
|
||||
description: spec.description?.trim() ?? '',
|
||||
parameters: zodToJsonSchema(spec.inputSchema)
|
||||
parameters: zodToJsonSchema(spec.inputSchema, { strict }),
|
||||
strict
|
||||
}
|
||||
aiFunction.impl = implementation
|
||||
|
||||
|
|
|
@ -10,6 +10,7 @@ export interface PrivateAIFunctionMetadata {
|
|||
description: string
|
||||
inputSchema: z.AnyZodObject
|
||||
methodName: string
|
||||
strict?: boolean
|
||||
}
|
||||
|
||||
// Polyfill for `Symbol.metadata`
|
||||
|
@ -69,11 +70,13 @@ export function aiFunction<
|
|||
>({
|
||||
name,
|
||||
description,
|
||||
inputSchema
|
||||
inputSchema,
|
||||
strict
|
||||
}: {
|
||||
name?: string
|
||||
description: string
|
||||
inputSchema: InputSchema
|
||||
strict?: boolean
|
||||
}) {
|
||||
return (
|
||||
_targetMethod: (
|
||||
|
@ -99,7 +102,8 @@ export function aiFunction<
|
|||
name: name ?? methodName,
|
||||
description,
|
||||
inputSchema,
|
||||
methodName
|
||||
methodName,
|
||||
strict
|
||||
})
|
||||
|
||||
context.addInitializer(function () {
|
||||
|
|
|
@ -1,11 +1,11 @@
|
|||
import type * as OpenAI from 'openai-fetch'
|
||||
import { describe, expect, expectTypeOf, it } from 'vitest'
|
||||
import { describe, expect, expectTypeOf, test } from 'vitest'
|
||||
|
||||
import type * as types from './types'
|
||||
import { Msg } from './message'
|
||||
|
||||
describe('Msg', () => {
|
||||
it('creates a message and fixes indentation', () => {
|
||||
test('creates a message and fixes indentation', () => {
|
||||
const msgContent = `
|
||||
Hello, World!
|
||||
`
|
||||
|
@ -14,7 +14,7 @@ describe('Msg', () => {
|
|||
expect(msg.content).toEqual('Hello, World!')
|
||||
})
|
||||
|
||||
it('supports disabling indentation fixing', () => {
|
||||
test('supports disabling indentation fixing', () => {
|
||||
const msgContent = `
|
||||
Hello, World!
|
||||
`
|
||||
|
@ -22,7 +22,7 @@ describe('Msg', () => {
|
|||
expect(msg.content).toEqual('\n Hello, World!\n ')
|
||||
})
|
||||
|
||||
it('handles tool calls request', () => {
|
||||
test('handles tool calls request', () => {
|
||||
const msg = Msg.toolCall([
|
||||
{
|
||||
id: 'fake-tool-call-id',
|
||||
|
@ -37,13 +37,13 @@ describe('Msg', () => {
|
|||
expect(Msg.isToolCall(msg)).toBe(true)
|
||||
})
|
||||
|
||||
it('handles tool call response', () => {
|
||||
test('handles tool call response', () => {
|
||||
const msg = Msg.toolResult('Hello, World!', 'fake-tool-call-id')
|
||||
expectTypeOf(msg).toMatchTypeOf<types.Msg.ToolResult>()
|
||||
expect(Msg.isToolResult(msg)).toBe(true)
|
||||
})
|
||||
|
||||
it('prompt message types should interop with openai-fetch message types', () => {
|
||||
test('prompt message types should interop with openai-fetch message types', () => {
|
||||
expectTypeOf({} as OpenAI.ChatMessage).toMatchTypeOf<types.Msg>()
|
||||
expectTypeOf({} as types.Msg).toMatchTypeOf<OpenAI.ChatMessage>()
|
||||
expectTypeOf({} as types.Msg.System).toMatchTypeOf<OpenAI.ChatMessage>()
|
||||
|
|
|
@ -7,10 +7,17 @@ import { cleanStringForModel, stringifyForModel } from './utils'
|
|||
*/
|
||||
export interface Msg {
|
||||
/**
|
||||
* The contents of the message. `content` is required for all messages, and
|
||||
* may be null for assistant messages with function calls.
|
||||
* The contents of the message. `content` may be null for assistant messages
|
||||
* with function calls or `undefined` for assistant messages if a `refusal`
|
||||
* was given by the model.
|
||||
*/
|
||||
content: string | null
|
||||
content?: string | null
|
||||
|
||||
/**
|
||||
* The reason the model refused to generate this message or `undefined` if the
|
||||
* message was generated successfully.
|
||||
*/
|
||||
refusal?: string | null
|
||||
|
||||
/**
|
||||
* The role of the messages author. One of `system`, `user`, `assistant`,
|
||||
|
@ -95,7 +102,8 @@ export namespace Msg {
|
|||
export type Assistant = {
|
||||
role: 'assistant'
|
||||
name?: string
|
||||
content: string
|
||||
content?: string
|
||||
refusal?: string
|
||||
}
|
||||
|
||||
/** Message with arguments to call a function. */
|
||||
|
@ -262,7 +270,7 @@ export namespace Msg {
|
|||
return Msg.toolCall(msg.tool_calls)
|
||||
} else if (msg.content === null && msg.function_call != null) {
|
||||
return Msg.funcCall(msg.function_call)
|
||||
} else if (msg.content !== null) {
|
||||
} else if (msg.content !== null && msg.content !== undefined) {
|
||||
return Msg.assistant(msg.content)
|
||||
} else {
|
||||
// @TODO: probably don't want to error here
|
||||
|
|
|
@ -59,9 +59,10 @@ export function isZodSchema(value: unknown): value is z.ZodType {
|
|||
}
|
||||
|
||||
export function asSchema<TData>(
|
||||
schema: z.Schema<TData> | Schema<TData>
|
||||
schema: z.Schema<TData> | Schema<TData>,
|
||||
opts: { strict?: boolean } = {}
|
||||
): Schema<TData> {
|
||||
return isSchema(schema) ? schema : createSchemaFromZodSchema(schema)
|
||||
return isSchema(schema) ? schema : createSchemaFromZodSchema(schema, opts)
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -84,9 +85,10 @@ export function createSchema<TData = unknown>(
|
|||
}
|
||||
|
||||
export function createSchemaFromZodSchema<TData>(
|
||||
zodSchema: z.Schema<TData>
|
||||
zodSchema: z.Schema<TData>,
|
||||
opts: { strict?: boolean } = {}
|
||||
): Schema<TData> {
|
||||
return createSchema(zodToJsonSchema(zodSchema), {
|
||||
return createSchema(zodToJsonSchema(zodSchema, opts), {
|
||||
validate: (value) => {
|
||||
return safeParseStructuredOutput(value, zodSchema)
|
||||
}
|
||||
|
|
|
@ -9,7 +9,7 @@ export type { Msg } from './message'
|
|||
export type { Schema } from './schema'
|
||||
export type { KyInstance } from 'ky'
|
||||
export type { ThrottledFunction } from 'p-throttle'
|
||||
export type { SetRequired, Simplify } from 'type-fest'
|
||||
export type { SetOptional, SetRequired, Simplify } from 'type-fest'
|
||||
|
||||
export type Nullable<T> = T | null
|
||||
|
||||
|
@ -33,6 +33,15 @@ export interface AIFunctionSpec {
|
|||
|
||||
/** JSON schema spec of the function's input parameters */
|
||||
parameters: JSONSchema
|
||||
|
||||
/**
|
||||
* Whether to enable strict schema adherence when generating the function
|
||||
* parameters. If set to true, the model will always follow the exact schema
|
||||
* defined in the `schema` field. Only a subset of JSON Schema is supported
|
||||
* when `strict` is `true`. To learn more, read the
|
||||
* [Structured Outputs guide](https://platform.openai.com/docs/guides/structured-outputs).
|
||||
*/
|
||||
strict?: boolean
|
||||
}
|
||||
|
||||
export interface AIToolSpec {
|
||||
|
@ -102,7 +111,17 @@ export interface ChatParams {
|
|||
max_tokens?: number
|
||||
presence_penalty?: number
|
||||
frequency_penalty?: number
|
||||
response_format?: { type: 'text' | 'json_object' }
|
||||
response_format?:
|
||||
| {
|
||||
type: 'text'
|
||||
}
|
||||
| {
|
||||
type: 'json_object'
|
||||
}
|
||||
| {
|
||||
type: 'json_schema'
|
||||
json_schema: ResponseFormatJSONSchema
|
||||
}
|
||||
seed?: number
|
||||
stop?: string | null | Array<string>
|
||||
temperature?: number
|
||||
|
@ -111,6 +130,34 @@ export interface ChatParams {
|
|||
user?: string
|
||||
}
|
||||
|
||||
export interface ResponseFormatJSONSchema {
|
||||
/**
|
||||
* The name of the response format. Must be a-z, A-Z, 0-9, or contain
|
||||
* underscores and dashes, with a maximum length of 64.
|
||||
*/
|
||||
name: string
|
||||
|
||||
/**
|
||||
* A description of what the response format is for, used by the model to
|
||||
* determine how to respond in the format.
|
||||
*/
|
||||
description?: string
|
||||
|
||||
/**
|
||||
* The schema for the response format, described as a JSON Schema object.
|
||||
*/
|
||||
schema?: JSONSchema
|
||||
|
||||
/**
|
||||
* Whether to enable strict schema adherence when generating the output. If
|
||||
* set to true, the model will always follow the exact schema defined in the
|
||||
* `schema` field. Only a subset of JSON Schema is supported when `strict`
|
||||
* is `true`. To learn more, read the
|
||||
* [Structured Outputs guide](https://platform.openai.com/docs/guides/structured-outputs).
|
||||
*/
|
||||
strict?: boolean
|
||||
}
|
||||
|
||||
/** An OpenAI-compatible chat completions API */
|
||||
export type ChatFn = (
|
||||
params: Simplify<SetOptional<ChatParams, 'model'>>
|
||||
|
|
|
@ -1,10 +1,10 @@
|
|||
import { describe, expect, it } from 'vitest'
|
||||
import { describe, expect, test } from 'vitest'
|
||||
import { z } from 'zod'
|
||||
|
||||
import { zodToJsonSchema } from './zod-to-json-schema'
|
||||
|
||||
describe('zodToJsonSchema', () => {
|
||||
it('handles basic objects', () => {
|
||||
test('handles basic objects', () => {
|
||||
const params = zodToJsonSchema(
|
||||
z.object({
|
||||
name: z.string().min(1).describe('Name of the person'),
|
||||
|
@ -30,7 +30,7 @@ describe('zodToJsonSchema', () => {
|
|||
})
|
||||
})
|
||||
|
||||
it('handles enums and unions', () => {
|
||||
test('handles enums and unions', () => {
|
||||
const params = zodToJsonSchema(
|
||||
z.object({
|
||||
name: z.string().min(1).describe('Name of the person'),
|
||||
|
|
|
@ -1,13 +1,23 @@
|
|||
import type { z } from 'zod'
|
||||
import { zodToJsonSchema as zodToJsonSchemaImpl } from 'zod-to-json-schema'
|
||||
import { zodToJsonSchema as zodToJsonSchemaImpl } from 'openai-zod-to-json-schema'
|
||||
|
||||
import type * as types from './types'
|
||||
import { omit } from './utils'
|
||||
|
||||
/** Generate a JSON Schema from a Zod schema. */
|
||||
export function zodToJsonSchema(schema: z.ZodType): types.JSONSchema {
|
||||
export function zodToJsonSchema(
|
||||
schema: z.ZodType,
|
||||
{
|
||||
strict = false
|
||||
}: {
|
||||
strict?: boolean
|
||||
} = {}
|
||||
): types.JSONSchema {
|
||||
return omit(
|
||||
zodToJsonSchemaImpl(schema, { $refStrategy: 'none' }),
|
||||
zodToJsonSchemaImpl(schema, {
|
||||
$refStrategy: 'none',
|
||||
openaiStrictMode: strict
|
||||
}),
|
||||
'$schema',
|
||||
'default',
|
||||
'definitions',
|
||||
|
|
1527
pnpm-lock.yaml
1527
pnpm-lock.yaml
Plik diff jest za duży
Load Diff
Ładowanie…
Reference in New Issue