kopia lustrzana https://github.com/transitive-bullshit/chatgpt-api
feat: add support for openai's structured output generation to @agentic/core
rodzic
e3a409df67
commit
f4b79d69b5
|
@ -10,6 +10,7 @@ jobs:
|
||||||
fail-fast: true
|
fail-fast: true
|
||||||
matrix:
|
matrix:
|
||||||
node-version:
|
node-version:
|
||||||
|
- 18
|
||||||
- 20
|
- 20
|
||||||
- 21
|
- 21
|
||||||
- 22
|
- 22
|
||||||
|
@ -27,7 +28,7 @@ jobs:
|
||||||
uses: pnpm/action-setup@v3
|
uses: pnpm/action-setup@v3
|
||||||
id: pnpm-install
|
id: pnpm-install
|
||||||
with:
|
with:
|
||||||
version: 9.6.0
|
version: 9.7.0
|
||||||
run_install: false
|
run_install: false
|
||||||
|
|
||||||
- name: Get pnpm store directory
|
- name: Get pnpm store directory
|
||||||
|
|
|
@ -12,6 +12,7 @@ async function main() {
|
||||||
})
|
})
|
||||||
|
|
||||||
const result = await extractObject({
|
const result = await extractObject({
|
||||||
|
name: 'extract-user',
|
||||||
chatFn: chatModel.run.bind(chatModel),
|
chatFn: chatModel.run.bind(chatModel),
|
||||||
params: {
|
params: {
|
||||||
messages: [
|
messages: [
|
||||||
|
@ -25,7 +26,8 @@ async function main() {
|
||||||
name: z.string(),
|
name: z.string(),
|
||||||
age: z.number(),
|
age: z.number(),
|
||||||
location: z.string().optional()
|
location: z.string().optional()
|
||||||
})
|
}),
|
||||||
|
strict: true
|
||||||
})
|
})
|
||||||
|
|
||||||
console.log(result)
|
console.log(result)
|
||||||
|
|
|
@ -40,11 +40,11 @@
|
||||||
"jsonrepair": "^3.6.1",
|
"jsonrepair": "^3.6.1",
|
||||||
"ky": "^1.5.0",
|
"ky": "^1.5.0",
|
||||||
"normalize-url": "^8.0.1",
|
"normalize-url": "^8.0.1",
|
||||||
|
"openai-zod-to-json-schema": "^1.0.0",
|
||||||
"p-map": "^7.0.2",
|
"p-map": "^7.0.2",
|
||||||
"p-throttle": "^6.1.0",
|
"p-throttle": "^6.1.0",
|
||||||
"quick-lru": "^7.0.0",
|
"quick-lru": "^7.0.0",
|
||||||
"type-fest": "^4.21.0",
|
"type-fest": "^4.21.0",
|
||||||
"zod-to-json-schema": "^3.23.2",
|
|
||||||
"zod-validation-error": "^3.3.0"
|
"zod-validation-error": "^3.3.0"
|
||||||
},
|
},
|
||||||
"peerDependencies": {
|
"peerDependencies": {
|
||||||
|
@ -52,7 +52,7 @@
|
||||||
},
|
},
|
||||||
"devDependencies": {
|
"devDependencies": {
|
||||||
"@agentic/tsconfig": "workspace:*",
|
"@agentic/tsconfig": "workspace:*",
|
||||||
"openai-fetch": "^2.0.4"
|
"openai-fetch": "^2.1.0"
|
||||||
},
|
},
|
||||||
"publishConfig": {
|
"publishConfig": {
|
||||||
"access": "public"
|
"access": "public"
|
||||||
|
|
|
@ -10,15 +10,42 @@ import { asSchema, augmentSystemMessageWithJsonSchema } from './schema'
|
||||||
import { getErrorMessage } from './utils'
|
import { getErrorMessage } from './utils'
|
||||||
|
|
||||||
export type AIChainParams<Result extends types.AIChainResult = string> = {
|
export type AIChainParams<Result extends types.AIChainResult = string> = {
|
||||||
|
/** Name of the chain */
|
||||||
|
name: string
|
||||||
|
|
||||||
|
/** Chat completions function */
|
||||||
chatFn: types.ChatFn
|
chatFn: types.ChatFn
|
||||||
|
|
||||||
|
/** Description of the chain */
|
||||||
|
description?: string
|
||||||
|
|
||||||
|
/** Optional chat completion params */
|
||||||
params?: types.Simplify<
|
params?: types.Simplify<
|
||||||
Partial<Omit<types.ChatParams, 'tools' | 'functions'>>
|
Partial<Omit<types.ChatParams, 'tools' | 'functions'>>
|
||||||
>
|
>
|
||||||
|
|
||||||
|
/** Optional tools */
|
||||||
tools?: types.AIFunctionLike[]
|
tools?: types.AIFunctionLike[]
|
||||||
|
|
||||||
|
/** Optional response schema */
|
||||||
schema?: z.ZodType<Result> | types.Schema<Result>
|
schema?: z.ZodType<Result> | types.Schema<Result>
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Whether or not the response schema should be treated as strict for
|
||||||
|
* constrained structured output generation.
|
||||||
|
*/
|
||||||
|
strict?: boolean
|
||||||
|
|
||||||
|
/** Max number of LLM calls to allow */
|
||||||
maxCalls?: number
|
maxCalls?: number
|
||||||
|
|
||||||
|
/** Max number of retries to allow */
|
||||||
maxRetries?: number
|
maxRetries?: number
|
||||||
|
|
||||||
|
/** Max concurrency when invoking tool calls */
|
||||||
toolCallConcurrency?: number
|
toolCallConcurrency?: number
|
||||||
|
|
||||||
|
/** Whether or not to inject the schema into the context */
|
||||||
injectSchemaIntoSystemMessage?: boolean
|
injectSchemaIntoSystemMessage?: boolean
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -38,6 +65,8 @@ export type AIChainParams<Result extends types.AIChainResult = string> = {
|
||||||
* exceeds `maxCalls` (`maxCalls` is expected to be >= `maxRetries`).
|
* exceeds `maxCalls` (`maxCalls` is expected to be >= `maxRetries`).
|
||||||
*/
|
*/
|
||||||
export function createAIChain<Result extends types.AIChainResult = string>({
|
export function createAIChain<Result extends types.AIChainResult = string>({
|
||||||
|
name,
|
||||||
|
description,
|
||||||
chatFn,
|
chatFn,
|
||||||
params,
|
params,
|
||||||
schema: rawSchema,
|
schema: rawSchema,
|
||||||
|
@ -45,13 +74,26 @@ export function createAIChain<Result extends types.AIChainResult = string>({
|
||||||
maxCalls = 5,
|
maxCalls = 5,
|
||||||
maxRetries = 2,
|
maxRetries = 2,
|
||||||
toolCallConcurrency = 8,
|
toolCallConcurrency = 8,
|
||||||
injectSchemaIntoSystemMessage = true
|
injectSchemaIntoSystemMessage = false,
|
||||||
|
strict = false
|
||||||
}: AIChainParams<Result>): types.AIChain<Result> {
|
}: AIChainParams<Result>): types.AIChain<Result> {
|
||||||
const functionSet = new AIFunctionSet(tools)
|
const functionSet = new AIFunctionSet(tools)
|
||||||
|
const schema = rawSchema ? asSchema(rawSchema, { strict }) : undefined
|
||||||
|
|
||||||
const defaultParams: Partial<types.ChatParams> | undefined =
|
const defaultParams: Partial<types.ChatParams> | undefined =
|
||||||
rawSchema && !functionSet.size
|
schema && !functionSet.size
|
||||||
? {
|
? {
|
||||||
response_format: { type: 'json_object' }
|
response_format: strict
|
||||||
|
? {
|
||||||
|
type: 'json_schema',
|
||||||
|
json_schema: {
|
||||||
|
name,
|
||||||
|
description,
|
||||||
|
strict,
|
||||||
|
schema: schema.jsonSchema
|
||||||
|
}
|
||||||
|
}
|
||||||
|
: { type: 'json_object' }
|
||||||
}
|
}
|
||||||
: undefined
|
: undefined
|
||||||
|
|
||||||
|
@ -77,8 +119,6 @@ export function createAIChain<Result extends types.AIChainResult = string>({
|
||||||
throw new Error('AIChain error: "messages" is empty')
|
throw new Error('AIChain error: "messages" is empty')
|
||||||
}
|
}
|
||||||
|
|
||||||
const schema = rawSchema ? asSchema(rawSchema) : undefined
|
|
||||||
|
|
||||||
if (schema && injectSchemaIntoSystemMessage) {
|
if (schema && injectSchemaIntoSystemMessage) {
|
||||||
const lastSystemMessageIndex = messages.findLastIndex(Msg.isSystem)
|
const lastSystemMessageIndex = messages.findLastIndex(Msg.isSystem)
|
||||||
const lastSystemMessageContent =
|
const lastSystemMessageContent =
|
||||||
|
@ -101,6 +141,7 @@ export function createAIChain<Result extends types.AIChainResult = string>({
|
||||||
|
|
||||||
do {
|
do {
|
||||||
++numCalls
|
++numCalls
|
||||||
|
|
||||||
const response = await chatFn({
|
const response = await chatFn({
|
||||||
...modelParams,
|
...modelParams,
|
||||||
messages,
|
messages,
|
||||||
|
@ -150,7 +191,9 @@ export function createAIChain<Result extends types.AIChainResult = string>({
|
||||||
'Function calls are not supported; expected tool call'
|
'Function calls are not supported; expected tool call'
|
||||||
)
|
)
|
||||||
} else if (Msg.isAssistant(message)) {
|
} else if (Msg.isAssistant(message)) {
|
||||||
if (schema && schema.validate) {
|
if (message.refusal) {
|
||||||
|
throw new AbortError(`Model refusal: ${message.refusal}`)
|
||||||
|
} else if (schema && schema.validate) {
|
||||||
const result = schema.validate(message.content)
|
const result = schema.validate(message.content)
|
||||||
|
|
||||||
if (result.success) {
|
if (result.success) {
|
||||||
|
@ -177,7 +220,7 @@ export function createAIChain<Result extends types.AIChainResult = string>({
|
||||||
|
|
||||||
if (numErrors > maxRetries) {
|
if (numErrors > maxRetries) {
|
||||||
throw new Error(
|
throw new Error(
|
||||||
`Chain failed after ${numErrors} errors: ${err.message}`,
|
`Chain ${name} failed after ${numErrors} errors: ${err.message}`,
|
||||||
{
|
{
|
||||||
cause: err
|
cause: err
|
||||||
}
|
}
|
||||||
|
@ -186,6 +229,8 @@ export function createAIChain<Result extends types.AIChainResult = string>({
|
||||||
}
|
}
|
||||||
} while (numCalls < maxCalls)
|
} while (numCalls < maxCalls)
|
||||||
|
|
||||||
throw new Error(`Chain aborted after reaching max ${maxCalls} calls`)
|
throw new Error(
|
||||||
|
`Chain "${name}" aborted after reaching max ${maxCalls} calls`
|
||||||
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
import { describe, expect, it } from 'vitest'
|
import { describe, expect, test } from 'vitest'
|
||||||
import { z } from 'zod'
|
import { z } from 'zod'
|
||||||
|
|
||||||
import { createAIFunction } from './create-ai-function'
|
import { createAIFunction } from './create-ai-function'
|
||||||
|
@ -18,7 +18,7 @@ const fullName = createAIFunction(
|
||||||
)
|
)
|
||||||
|
|
||||||
describe('createAIFunction()', () => {
|
describe('createAIFunction()', () => {
|
||||||
it('exposes OpenAI function calling spec', () => {
|
test('exposes OpenAI function calling spec', () => {
|
||||||
expect(fullName.spec.name).toEqual('fullName')
|
expect(fullName.spec.name).toEqual('fullName')
|
||||||
expect(fullName.spec.description).toEqual(
|
expect(fullName.spec.description).toEqual(
|
||||||
'Returns the full name of a person.'
|
'Returns the full name of a person.'
|
||||||
|
@ -34,7 +34,7 @@ describe('createAIFunction()', () => {
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
it('executes the function', async () => {
|
test('executes the function', async () => {
|
||||||
expect(await fullName('{"first": "John", "last": "Doe"}')).toEqual(
|
expect(await fullName('{"first": "John", "last": "Doe"}')).toEqual(
|
||||||
'John Doe'
|
'John Doe'
|
||||||
)
|
)
|
||||||
|
|
|
@ -22,6 +22,11 @@ export function createAIFunction<InputSchema extends z.ZodObject<any>, Output>(
|
||||||
description?: string
|
description?: string
|
||||||
/** Zod schema for the arguments string. */
|
/** Zod schema for the arguments string. */
|
||||||
inputSchema: InputSchema
|
inputSchema: InputSchema
|
||||||
|
/**
|
||||||
|
* Whether or not to enable structured output generation based on the given
|
||||||
|
* zod schema.
|
||||||
|
*/
|
||||||
|
strict?: boolean
|
||||||
},
|
},
|
||||||
/** Implementation of the function to call with the parsed arguments. */
|
/** Implementation of the function to call with the parsed arguments. */
|
||||||
implementation: (params: z.infer<InputSchema>) => types.MaybePromise<Output>
|
implementation: (params: z.infer<InputSchema>) => types.MaybePromise<Output>
|
||||||
|
@ -59,12 +64,15 @@ export function createAIFunction<InputSchema extends z.ZodObject<any>, Output>(
|
||||||
return implementation(parsedInput)
|
return implementation(parsedInput)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const strict = !!spec.strict
|
||||||
|
|
||||||
aiFunction.inputSchema = spec.inputSchema
|
aiFunction.inputSchema = spec.inputSchema
|
||||||
aiFunction.parseInput = parseInput
|
aiFunction.parseInput = parseInput
|
||||||
aiFunction.spec = {
|
aiFunction.spec = {
|
||||||
name: spec.name,
|
name: spec.name,
|
||||||
description: spec.description?.trim() ?? '',
|
description: spec.description?.trim() ?? '',
|
||||||
parameters: zodToJsonSchema(spec.inputSchema)
|
parameters: zodToJsonSchema(spec.inputSchema, { strict }),
|
||||||
|
strict
|
||||||
}
|
}
|
||||||
aiFunction.impl = implementation
|
aiFunction.impl = implementation
|
||||||
|
|
||||||
|
|
|
@ -10,6 +10,7 @@ export interface PrivateAIFunctionMetadata {
|
||||||
description: string
|
description: string
|
||||||
inputSchema: z.AnyZodObject
|
inputSchema: z.AnyZodObject
|
||||||
methodName: string
|
methodName: string
|
||||||
|
strict?: boolean
|
||||||
}
|
}
|
||||||
|
|
||||||
// Polyfill for `Symbol.metadata`
|
// Polyfill for `Symbol.metadata`
|
||||||
|
@ -69,11 +70,13 @@ export function aiFunction<
|
||||||
>({
|
>({
|
||||||
name,
|
name,
|
||||||
description,
|
description,
|
||||||
inputSchema
|
inputSchema,
|
||||||
|
strict
|
||||||
}: {
|
}: {
|
||||||
name?: string
|
name?: string
|
||||||
description: string
|
description: string
|
||||||
inputSchema: InputSchema
|
inputSchema: InputSchema
|
||||||
|
strict?: boolean
|
||||||
}) {
|
}) {
|
||||||
return (
|
return (
|
||||||
_targetMethod: (
|
_targetMethod: (
|
||||||
|
@ -99,7 +102,8 @@ export function aiFunction<
|
||||||
name: name ?? methodName,
|
name: name ?? methodName,
|
||||||
description,
|
description,
|
||||||
inputSchema,
|
inputSchema,
|
||||||
methodName
|
methodName,
|
||||||
|
strict
|
||||||
})
|
})
|
||||||
|
|
||||||
context.addInitializer(function () {
|
context.addInitializer(function () {
|
||||||
|
|
|
@ -1,11 +1,11 @@
|
||||||
import type * as OpenAI from 'openai-fetch'
|
import type * as OpenAI from 'openai-fetch'
|
||||||
import { describe, expect, expectTypeOf, it } from 'vitest'
|
import { describe, expect, expectTypeOf, test } from 'vitest'
|
||||||
|
|
||||||
import type * as types from './types'
|
import type * as types from './types'
|
||||||
import { Msg } from './message'
|
import { Msg } from './message'
|
||||||
|
|
||||||
describe('Msg', () => {
|
describe('Msg', () => {
|
||||||
it('creates a message and fixes indentation', () => {
|
test('creates a message and fixes indentation', () => {
|
||||||
const msgContent = `
|
const msgContent = `
|
||||||
Hello, World!
|
Hello, World!
|
||||||
`
|
`
|
||||||
|
@ -14,7 +14,7 @@ describe('Msg', () => {
|
||||||
expect(msg.content).toEqual('Hello, World!')
|
expect(msg.content).toEqual('Hello, World!')
|
||||||
})
|
})
|
||||||
|
|
||||||
it('supports disabling indentation fixing', () => {
|
test('supports disabling indentation fixing', () => {
|
||||||
const msgContent = `
|
const msgContent = `
|
||||||
Hello, World!
|
Hello, World!
|
||||||
`
|
`
|
||||||
|
@ -22,7 +22,7 @@ describe('Msg', () => {
|
||||||
expect(msg.content).toEqual('\n Hello, World!\n ')
|
expect(msg.content).toEqual('\n Hello, World!\n ')
|
||||||
})
|
})
|
||||||
|
|
||||||
it('handles tool calls request', () => {
|
test('handles tool calls request', () => {
|
||||||
const msg = Msg.toolCall([
|
const msg = Msg.toolCall([
|
||||||
{
|
{
|
||||||
id: 'fake-tool-call-id',
|
id: 'fake-tool-call-id',
|
||||||
|
@ -37,13 +37,13 @@ describe('Msg', () => {
|
||||||
expect(Msg.isToolCall(msg)).toBe(true)
|
expect(Msg.isToolCall(msg)).toBe(true)
|
||||||
})
|
})
|
||||||
|
|
||||||
it('handles tool call response', () => {
|
test('handles tool call response', () => {
|
||||||
const msg = Msg.toolResult('Hello, World!', 'fake-tool-call-id')
|
const msg = Msg.toolResult('Hello, World!', 'fake-tool-call-id')
|
||||||
expectTypeOf(msg).toMatchTypeOf<types.Msg.ToolResult>()
|
expectTypeOf(msg).toMatchTypeOf<types.Msg.ToolResult>()
|
||||||
expect(Msg.isToolResult(msg)).toBe(true)
|
expect(Msg.isToolResult(msg)).toBe(true)
|
||||||
})
|
})
|
||||||
|
|
||||||
it('prompt message types should interop with openai-fetch message types', () => {
|
test('prompt message types should interop with openai-fetch message types', () => {
|
||||||
expectTypeOf({} as OpenAI.ChatMessage).toMatchTypeOf<types.Msg>()
|
expectTypeOf({} as OpenAI.ChatMessage).toMatchTypeOf<types.Msg>()
|
||||||
expectTypeOf({} as types.Msg).toMatchTypeOf<OpenAI.ChatMessage>()
|
expectTypeOf({} as types.Msg).toMatchTypeOf<OpenAI.ChatMessage>()
|
||||||
expectTypeOf({} as types.Msg.System).toMatchTypeOf<OpenAI.ChatMessage>()
|
expectTypeOf({} as types.Msg.System).toMatchTypeOf<OpenAI.ChatMessage>()
|
||||||
|
|
|
@ -7,10 +7,17 @@ import { cleanStringForModel, stringifyForModel } from './utils'
|
||||||
*/
|
*/
|
||||||
export interface Msg {
|
export interface Msg {
|
||||||
/**
|
/**
|
||||||
* The contents of the message. `content` is required for all messages, and
|
* The contents of the message. `content` may be null for assistant messages
|
||||||
* may be null for assistant messages with function calls.
|
* with function calls or `undefined` for assistant messages if a `refusal`
|
||||||
|
* was given by the model.
|
||||||
*/
|
*/
|
||||||
content: string | null
|
content?: string | null
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The reason the model refused to generate this message or `undefined` if the
|
||||||
|
* message was generated successfully.
|
||||||
|
*/
|
||||||
|
refusal?: string | null
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* The role of the messages author. One of `system`, `user`, `assistant`,
|
* The role of the messages author. One of `system`, `user`, `assistant`,
|
||||||
|
@ -95,7 +102,8 @@ export namespace Msg {
|
||||||
export type Assistant = {
|
export type Assistant = {
|
||||||
role: 'assistant'
|
role: 'assistant'
|
||||||
name?: string
|
name?: string
|
||||||
content: string
|
content?: string
|
||||||
|
refusal?: string
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Message with arguments to call a function. */
|
/** Message with arguments to call a function. */
|
||||||
|
@ -262,7 +270,7 @@ export namespace Msg {
|
||||||
return Msg.toolCall(msg.tool_calls)
|
return Msg.toolCall(msg.tool_calls)
|
||||||
} else if (msg.content === null && msg.function_call != null) {
|
} else if (msg.content === null && msg.function_call != null) {
|
||||||
return Msg.funcCall(msg.function_call)
|
return Msg.funcCall(msg.function_call)
|
||||||
} else if (msg.content !== null) {
|
} else if (msg.content !== null && msg.content !== undefined) {
|
||||||
return Msg.assistant(msg.content)
|
return Msg.assistant(msg.content)
|
||||||
} else {
|
} else {
|
||||||
// @TODO: probably don't want to error here
|
// @TODO: probably don't want to error here
|
||||||
|
|
|
@ -59,9 +59,10 @@ export function isZodSchema(value: unknown): value is z.ZodType {
|
||||||
}
|
}
|
||||||
|
|
||||||
export function asSchema<TData>(
|
export function asSchema<TData>(
|
||||||
schema: z.Schema<TData> | Schema<TData>
|
schema: z.Schema<TData> | Schema<TData>,
|
||||||
|
opts: { strict?: boolean } = {}
|
||||||
): Schema<TData> {
|
): Schema<TData> {
|
||||||
return isSchema(schema) ? schema : createSchemaFromZodSchema(schema)
|
return isSchema(schema) ? schema : createSchemaFromZodSchema(schema, opts)
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -84,9 +85,10 @@ export function createSchema<TData = unknown>(
|
||||||
}
|
}
|
||||||
|
|
||||||
export function createSchemaFromZodSchema<TData>(
|
export function createSchemaFromZodSchema<TData>(
|
||||||
zodSchema: z.Schema<TData>
|
zodSchema: z.Schema<TData>,
|
||||||
|
opts: { strict?: boolean } = {}
|
||||||
): Schema<TData> {
|
): Schema<TData> {
|
||||||
return createSchema(zodToJsonSchema(zodSchema), {
|
return createSchema(zodToJsonSchema(zodSchema, opts), {
|
||||||
validate: (value) => {
|
validate: (value) => {
|
||||||
return safeParseStructuredOutput(value, zodSchema)
|
return safeParseStructuredOutput(value, zodSchema)
|
||||||
}
|
}
|
||||||
|
|
|
@ -9,7 +9,7 @@ export type { Msg } from './message'
|
||||||
export type { Schema } from './schema'
|
export type { Schema } from './schema'
|
||||||
export type { KyInstance } from 'ky'
|
export type { KyInstance } from 'ky'
|
||||||
export type { ThrottledFunction } from 'p-throttle'
|
export type { ThrottledFunction } from 'p-throttle'
|
||||||
export type { SetRequired, Simplify } from 'type-fest'
|
export type { SetOptional, SetRequired, Simplify } from 'type-fest'
|
||||||
|
|
||||||
export type Nullable<T> = T | null
|
export type Nullable<T> = T | null
|
||||||
|
|
||||||
|
@ -33,6 +33,15 @@ export interface AIFunctionSpec {
|
||||||
|
|
||||||
/** JSON schema spec of the function's input parameters */
|
/** JSON schema spec of the function's input parameters */
|
||||||
parameters: JSONSchema
|
parameters: JSONSchema
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Whether to enable strict schema adherence when generating the function
|
||||||
|
* parameters. If set to true, the model will always follow the exact schema
|
||||||
|
* defined in the `schema` field. Only a subset of JSON Schema is supported
|
||||||
|
* when `strict` is `true`. To learn more, read the
|
||||||
|
* [Structured Outputs guide](https://platform.openai.com/docs/guides/structured-outputs).
|
||||||
|
*/
|
||||||
|
strict?: boolean
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface AIToolSpec {
|
export interface AIToolSpec {
|
||||||
|
@ -102,7 +111,17 @@ export interface ChatParams {
|
||||||
max_tokens?: number
|
max_tokens?: number
|
||||||
presence_penalty?: number
|
presence_penalty?: number
|
||||||
frequency_penalty?: number
|
frequency_penalty?: number
|
||||||
response_format?: { type: 'text' | 'json_object' }
|
response_format?:
|
||||||
|
| {
|
||||||
|
type: 'text'
|
||||||
|
}
|
||||||
|
| {
|
||||||
|
type: 'json_object'
|
||||||
|
}
|
||||||
|
| {
|
||||||
|
type: 'json_schema'
|
||||||
|
json_schema: ResponseFormatJSONSchema
|
||||||
|
}
|
||||||
seed?: number
|
seed?: number
|
||||||
stop?: string | null | Array<string>
|
stop?: string | null | Array<string>
|
||||||
temperature?: number
|
temperature?: number
|
||||||
|
@ -111,6 +130,34 @@ export interface ChatParams {
|
||||||
user?: string
|
user?: string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export interface ResponseFormatJSONSchema {
|
||||||
|
/**
|
||||||
|
* The name of the response format. Must be a-z, A-Z, 0-9, or contain
|
||||||
|
* underscores and dashes, with a maximum length of 64.
|
||||||
|
*/
|
||||||
|
name: string
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A description of what the response format is for, used by the model to
|
||||||
|
* determine how to respond in the format.
|
||||||
|
*/
|
||||||
|
description?: string
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The schema for the response format, described as a JSON Schema object.
|
||||||
|
*/
|
||||||
|
schema?: JSONSchema
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Whether to enable strict schema adherence when generating the output. If
|
||||||
|
* set to true, the model will always follow the exact schema defined in the
|
||||||
|
* `schema` field. Only a subset of JSON Schema is supported when `strict`
|
||||||
|
* is `true`. To learn more, read the
|
||||||
|
* [Structured Outputs guide](https://platform.openai.com/docs/guides/structured-outputs).
|
||||||
|
*/
|
||||||
|
strict?: boolean
|
||||||
|
}
|
||||||
|
|
||||||
/** An OpenAI-compatible chat completions API */
|
/** An OpenAI-compatible chat completions API */
|
||||||
export type ChatFn = (
|
export type ChatFn = (
|
||||||
params: Simplify<SetOptional<ChatParams, 'model'>>
|
params: Simplify<SetOptional<ChatParams, 'model'>>
|
||||||
|
|
|
@ -1,10 +1,10 @@
|
||||||
import { describe, expect, it } from 'vitest'
|
import { describe, expect, test } from 'vitest'
|
||||||
import { z } from 'zod'
|
import { z } from 'zod'
|
||||||
|
|
||||||
import { zodToJsonSchema } from './zod-to-json-schema'
|
import { zodToJsonSchema } from './zod-to-json-schema'
|
||||||
|
|
||||||
describe('zodToJsonSchema', () => {
|
describe('zodToJsonSchema', () => {
|
||||||
it('handles basic objects', () => {
|
test('handles basic objects', () => {
|
||||||
const params = zodToJsonSchema(
|
const params = zodToJsonSchema(
|
||||||
z.object({
|
z.object({
|
||||||
name: z.string().min(1).describe('Name of the person'),
|
name: z.string().min(1).describe('Name of the person'),
|
||||||
|
@ -30,7 +30,7 @@ describe('zodToJsonSchema', () => {
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
it('handles enums and unions', () => {
|
test('handles enums and unions', () => {
|
||||||
const params = zodToJsonSchema(
|
const params = zodToJsonSchema(
|
||||||
z.object({
|
z.object({
|
||||||
name: z.string().min(1).describe('Name of the person'),
|
name: z.string().min(1).describe('Name of the person'),
|
||||||
|
|
|
@ -1,13 +1,23 @@
|
||||||
import type { z } from 'zod'
|
import type { z } from 'zod'
|
||||||
import { zodToJsonSchema as zodToJsonSchemaImpl } from 'zod-to-json-schema'
|
import { zodToJsonSchema as zodToJsonSchemaImpl } from 'openai-zod-to-json-schema'
|
||||||
|
|
||||||
import type * as types from './types'
|
import type * as types from './types'
|
||||||
import { omit } from './utils'
|
import { omit } from './utils'
|
||||||
|
|
||||||
/** Generate a JSON Schema from a Zod schema. */
|
/** Generate a JSON Schema from a Zod schema. */
|
||||||
export function zodToJsonSchema(schema: z.ZodType): types.JSONSchema {
|
export function zodToJsonSchema(
|
||||||
|
schema: z.ZodType,
|
||||||
|
{
|
||||||
|
strict = false
|
||||||
|
}: {
|
||||||
|
strict?: boolean
|
||||||
|
} = {}
|
||||||
|
): types.JSONSchema {
|
||||||
return omit(
|
return omit(
|
||||||
zodToJsonSchemaImpl(schema, { $refStrategy: 'none' }),
|
zodToJsonSchemaImpl(schema, {
|
||||||
|
$refStrategy: 'none',
|
||||||
|
openaiStrictMode: strict
|
||||||
|
}),
|
||||||
'$schema',
|
'$schema',
|
||||||
'default',
|
'default',
|
||||||
'definitions',
|
'definitions',
|
||||||
|
|
1527
pnpm-lock.yaml
1527
pnpm-lock.yaml
Plik diff jest za duży
Load Diff
Ładowanie…
Reference in New Issue