kopia lustrzana https://github.com/transitive-bullshit/chatgpt-api
Merge pull request #658 from transitive-bullshit/feature/openai-structured-outputs
commit
3b1f6c3bc3
|
@ -10,6 +10,7 @@ jobs:
|
|||
fail-fast: true
|
||||
matrix:
|
||||
node-version:
|
||||
- 18
|
||||
- 20
|
||||
- 21
|
||||
- 22
|
||||
|
@ -18,33 +19,23 @@ jobs:
|
|||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Install pnpm
|
||||
uses: pnpm/action-setup@v4
|
||||
with:
|
||||
version: 9.7.0
|
||||
run_install: false
|
||||
|
||||
- name: Install Node.js
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: ${{ matrix.node-version }}
|
||||
|
||||
- name: Install pnpm
|
||||
uses: pnpm/action-setup@v3
|
||||
id: pnpm-install
|
||||
with:
|
||||
version: 9.6.0
|
||||
run_install: false
|
||||
|
||||
- name: Get pnpm store directory
|
||||
shell: bash
|
||||
run: |
|
||||
echo "STORE_PATH=$(pnpm store path --silent)" >> $GITHUB_ENV
|
||||
|
||||
- name: Setup pnpm cache
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ${{ env.STORE_PATH }}
|
||||
key: ${{ runner.os }}-pnpm-store-${{ hashFiles('**/pnpm-lock.yaml') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-pnpm-store-
|
||||
cache: 'pnpm'
|
||||
|
||||
- name: Install dependencies
|
||||
run: pnpm install --frozen-lockfile
|
||||
run: pnpm install --frozen-lockfile --strict-peer-dependencies
|
||||
|
||||
- name: Run build
|
||||
run: pnpm build
|
||||
|
||||
- name: Run test
|
||||
run: pnpm run test
|
||||
run: pnpm test
|
||||
|
|
|
@ -16,6 +16,7 @@ async function main() {
|
|||
})
|
||||
|
||||
const chain = createAIChain({
|
||||
name: 'search_news',
|
||||
chatFn: chatModel.run.bind(chatModel),
|
||||
tools: [perigon.functions.pick('search_news_stories'), serper],
|
||||
params: {
|
||||
|
|
|
@ -12,6 +12,7 @@ async function main() {
|
|||
})
|
||||
|
||||
const result = await extractObject({
|
||||
name: 'extract-user',
|
||||
chatFn: chatModel.run.bind(chatModel),
|
||||
params: {
|
||||
messages: [
|
||||
|
@ -25,7 +26,8 @@ async function main() {
|
|||
name: z.string(),
|
||||
age: z.number(),
|
||||
location: z.string().optional()
|
||||
})
|
||||
}),
|
||||
strict: true
|
||||
})
|
||||
|
||||
console.log(result)
|
||||
|
|
|
@ -7,7 +7,7 @@
|
|||
"type": "git",
|
||||
"url": "git+https://github.com/transitive-bullshit/agentic.git"
|
||||
},
|
||||
"packageManager": "pnpm@9.6.0",
|
||||
"packageManager": "pnpm@9.7.0",
|
||||
"engines": {
|
||||
"node": ">=18"
|
||||
},
|
||||
|
@ -26,6 +26,7 @@
|
|||
"release:build": "run-s build",
|
||||
"release:version": "changeset version",
|
||||
"release:publish": "changeset publish",
|
||||
"pretest": "run-s build",
|
||||
"precommit": "lint-staged",
|
||||
"preinstall": "npx only-allow pnpm",
|
||||
"prepare": "husky"
|
||||
|
@ -45,7 +46,7 @@
|
|||
"prettier": "^3.3.3",
|
||||
"tsup": "^8.2.4",
|
||||
"tsx": "^4.16.5",
|
||||
"turbo": "^2.0.11",
|
||||
"turbo": "^2.0.12",
|
||||
"typescript": "^5.5.4",
|
||||
"vitest": "2.0.5",
|
||||
"zod": "^3.23.8"
|
||||
|
|
|
@ -40,11 +40,11 @@
|
|||
"jsonrepair": "^3.6.1",
|
||||
"ky": "^1.5.0",
|
||||
"normalize-url": "^8.0.1",
|
||||
"openai-zod-to-json-schema": "^1.0.0",
|
||||
"p-map": "^7.0.2",
|
||||
"p-throttle": "^6.1.0",
|
||||
"quick-lru": "^7.0.0",
|
||||
"type-fest": "^4.21.0",
|
||||
"zod-to-json-schema": "^3.23.2",
|
||||
"zod-validation-error": "^3.3.0"
|
||||
},
|
||||
"peerDependencies": {
|
||||
|
@ -52,7 +52,7 @@
|
|||
},
|
||||
"devDependencies": {
|
||||
"@agentic/tsconfig": "workspace:*",
|
||||
"openai-fetch": "^2.0.4"
|
||||
"openai-fetch": "^3.0.0"
|
||||
},
|
||||
"publishConfig": {
|
||||
"access": "public"
|
||||
|
|
|
@ -10,15 +10,42 @@ import { asSchema, augmentSystemMessageWithJsonSchema } from './schema'
|
|||
import { getErrorMessage } from './utils'
|
||||
|
||||
export type AIChainParams<Result extends types.AIChainResult = string> = {
|
||||
/** Name of the chain */
|
||||
name: string
|
||||
|
||||
/** Chat completions function */
|
||||
chatFn: types.ChatFn
|
||||
|
||||
/** Description of the chain */
|
||||
description?: string
|
||||
|
||||
/** Optional chat completion params */
|
||||
params?: types.Simplify<
|
||||
Partial<Omit<types.ChatParams, 'tools' | 'functions'>>
|
||||
>
|
||||
|
||||
/** Optional tools */
|
||||
tools?: types.AIFunctionLike[]
|
||||
|
||||
/** Optional response schema */
|
||||
schema?: z.ZodType<Result> | types.Schema<Result>
|
||||
|
||||
/**
|
||||
* Whether or not the response schema should be treated as strict for
|
||||
* constrained structured output generation.
|
||||
*/
|
||||
strict?: boolean
|
||||
|
||||
/** Max number of LLM calls to allow */
|
||||
maxCalls?: number
|
||||
|
||||
/** Max number of retries to allow */
|
||||
maxRetries?: number
|
||||
|
||||
/** Max concurrency when invoking tool calls */
|
||||
toolCallConcurrency?: number
|
||||
|
||||
/** Whether or not to inject the schema into the context */
|
||||
injectSchemaIntoSystemMessage?: boolean
|
||||
}
|
||||
|
||||
|
@ -38,6 +65,8 @@ export type AIChainParams<Result extends types.AIChainResult = string> = {
|
|||
* exceeds `maxCalls` (`maxCalls` is expected to be >= `maxRetries`).
|
||||
*/
|
||||
export function createAIChain<Result extends types.AIChainResult = string>({
|
||||
name,
|
||||
description,
|
||||
chatFn,
|
||||
params,
|
||||
schema: rawSchema,
|
||||
|
@ -45,13 +74,28 @@ export function createAIChain<Result extends types.AIChainResult = string>({
|
|||
maxCalls = 5,
|
||||
maxRetries = 2,
|
||||
toolCallConcurrency = 8,
|
||||
injectSchemaIntoSystemMessage = true
|
||||
injectSchemaIntoSystemMessage = false,
|
||||
strict = false
|
||||
}: AIChainParams<Result>): types.AIChain<Result> {
|
||||
const functionSet = new AIFunctionSet(tools)
|
||||
const schema = rawSchema ? asSchema(rawSchema, { strict }) : undefined
|
||||
|
||||
// TODO: support custom stopping criteria (like setting a flag in a tool call)
|
||||
|
||||
const defaultParams: Partial<types.ChatParams> | undefined =
|
||||
rawSchema && !functionSet.size
|
||||
schema && !functionSet.size
|
||||
? {
|
||||
response_format: { type: 'json_object' }
|
||||
response_format: strict
|
||||
? {
|
||||
type: 'json_schema',
|
||||
json_schema: {
|
||||
name,
|
||||
description,
|
||||
strict,
|
||||
schema: schema.jsonSchema
|
||||
}
|
||||
}
|
||||
: { type: 'json_object' }
|
||||
}
|
||||
: undefined
|
||||
|
||||
|
@ -77,8 +121,6 @@ export function createAIChain<Result extends types.AIChainResult = string>({
|
|||
throw new Error('AIChain error: "messages" is empty')
|
||||
}
|
||||
|
||||
const schema = rawSchema ? asSchema(rawSchema) : undefined
|
||||
|
||||
if (schema && injectSchemaIntoSystemMessage) {
|
||||
const lastSystemMessageIndex = messages.findLastIndex(Msg.isSystem)
|
||||
const lastSystemMessageContent =
|
||||
|
@ -101,6 +143,7 @@ export function createAIChain<Result extends types.AIChainResult = string>({
|
|||
|
||||
do {
|
||||
++numCalls
|
||||
|
||||
const response = await chatFn({
|
||||
...modelParams,
|
||||
messages,
|
||||
|
@ -149,15 +192,11 @@ export function createAIChain<Result extends types.AIChainResult = string>({
|
|||
throw new AbortError(
|
||||
'Function calls are not supported; expected tool call'
|
||||
)
|
||||
} else if (Msg.isRefusal(message)) {
|
||||
throw new AbortError(`Model refusal: ${message.refusal}`)
|
||||
} else if (Msg.isAssistant(message)) {
|
||||
if (schema && schema.validate) {
|
||||
const result = schema.validate(message.content)
|
||||
|
||||
if (result.success) {
|
||||
return result.data
|
||||
}
|
||||
|
||||
throw new Error(result.error)
|
||||
if (schema) {
|
||||
return schema.parse(message.content)
|
||||
} else {
|
||||
return message.content as Result
|
||||
}
|
||||
|
@ -169,6 +208,8 @@ export function createAIChain<Result extends types.AIChainResult = string>({
|
|||
throw err
|
||||
}
|
||||
|
||||
console.warn(`Chain "${name}" error:`, err.message)
|
||||
|
||||
messages.push(
|
||||
Msg.user(
|
||||
`There was an error validating the response. Please check the error message and try again.\nError:\n${getErrorMessage(err)}`
|
||||
|
@ -177,7 +218,7 @@ export function createAIChain<Result extends types.AIChainResult = string>({
|
|||
|
||||
if (numErrors > maxRetries) {
|
||||
throw new Error(
|
||||
`Chain failed after ${numErrors} errors: ${err.message}`,
|
||||
`Chain ${name} failed after ${numErrors} errors: ${err.message}`,
|
||||
{
|
||||
cause: err
|
||||
}
|
||||
|
@ -186,6 +227,8 @@ export function createAIChain<Result extends types.AIChainResult = string>({
|
|||
}
|
||||
} while (numCalls < maxCalls)
|
||||
|
||||
throw new Error(`Chain aborted after reaching max ${maxCalls} calls`)
|
||||
throw new Error(
|
||||
`Chain "${name}" aborted after reaching max ${maxCalls} calls`
|
||||
)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
import { describe, expect, it } from 'vitest'
|
||||
import { describe, expect, test } from 'vitest'
|
||||
import { z } from 'zod'
|
||||
|
||||
import { createAIFunction } from './create-ai-function'
|
||||
|
@ -18,7 +18,7 @@ const fullName = createAIFunction(
|
|||
)
|
||||
|
||||
describe('createAIFunction()', () => {
|
||||
it('exposes OpenAI function calling spec', () => {
|
||||
test('exposes OpenAI function calling spec', () => {
|
||||
expect(fullName.spec.name).toEqual('fullName')
|
||||
expect(fullName.spec.description).toEqual(
|
||||
'Returns the full name of a person.'
|
||||
|
@ -34,7 +34,7 @@ describe('createAIFunction()', () => {
|
|||
})
|
||||
})
|
||||
|
||||
it('executes the function', async () => {
|
||||
test('executes the function', async () => {
|
||||
expect(await fullName('{"first": "John", "last": "Doe"}')).toEqual(
|
||||
'John Doe'
|
||||
)
|
||||
|
|
|
@ -20,8 +20,13 @@ export function createAIFunction<InputSchema extends z.ZodObject<any>, Output>(
|
|||
name: string
|
||||
/** Description of the function. */
|
||||
description?: string
|
||||
/** Zod schema for the arguments string. */
|
||||
/** Zod schema for the function parameters. */
|
||||
inputSchema: InputSchema
|
||||
/**
|
||||
* Whether or not to enable structured output generation based on the given
|
||||
* zod schema.
|
||||
*/
|
||||
strict?: boolean
|
||||
},
|
||||
/** Implementation of the function to call with the parsed arguments. */
|
||||
implementation: (params: z.infer<InputSchema>) => types.MaybePromise<Output>
|
||||
|
@ -59,12 +64,15 @@ export function createAIFunction<InputSchema extends z.ZodObject<any>, Output>(
|
|||
return implementation(parsedInput)
|
||||
}
|
||||
|
||||
const strict = !!spec.strict
|
||||
|
||||
aiFunction.inputSchema = spec.inputSchema
|
||||
aiFunction.parseInput = parseInput
|
||||
aiFunction.spec = {
|
||||
name: spec.name,
|
||||
description: spec.description?.trim() ?? '',
|
||||
parameters: zodToJsonSchema(spec.inputSchema)
|
||||
parameters: zodToJsonSchema(spec.inputSchema, { strict }),
|
||||
strict
|
||||
}
|
||||
aiFunction.impl = implementation
|
||||
|
||||
|
|
|
@ -10,6 +10,7 @@ export interface PrivateAIFunctionMetadata {
|
|||
description: string
|
||||
inputSchema: z.AnyZodObject
|
||||
methodName: string
|
||||
strict?: boolean
|
||||
}
|
||||
|
||||
// Polyfill for `Symbol.metadata`
|
||||
|
@ -69,11 +70,13 @@ export function aiFunction<
|
|||
>({
|
||||
name,
|
||||
description,
|
||||
inputSchema
|
||||
inputSchema,
|
||||
strict
|
||||
}: {
|
||||
name?: string
|
||||
description: string
|
||||
inputSchema: InputSchema
|
||||
strict?: boolean
|
||||
}) {
|
||||
return (
|
||||
_targetMethod: (
|
||||
|
@ -99,7 +102,8 @@ export function aiFunction<
|
|||
name: name ?? methodName,
|
||||
description,
|
||||
inputSchema,
|
||||
methodName
|
||||
methodName,
|
||||
strict
|
||||
})
|
||||
|
||||
context.addInitializer(function () {
|
||||
|
|
|
@ -1,11 +1,11 @@
|
|||
import type * as OpenAI from 'openai-fetch'
|
||||
import { describe, expect, expectTypeOf, it } from 'vitest'
|
||||
import { describe, expect, expectTypeOf, test } from 'vitest'
|
||||
|
||||
import type * as types from './types'
|
||||
import { Msg } from './message'
|
||||
|
||||
describe('Msg', () => {
|
||||
it('creates a message and fixes indentation', () => {
|
||||
test('creates a message and fixes indentation', () => {
|
||||
const msgContent = `
|
||||
Hello, World!
|
||||
`
|
||||
|
@ -14,7 +14,7 @@ describe('Msg', () => {
|
|||
expect(msg.content).toEqual('Hello, World!')
|
||||
})
|
||||
|
||||
it('supports disabling indentation fixing', () => {
|
||||
test('supports disabling indentation fixing', () => {
|
||||
const msgContent = `
|
||||
Hello, World!
|
||||
`
|
||||
|
@ -22,7 +22,7 @@ describe('Msg', () => {
|
|||
expect(msg.content).toEqual('\n Hello, World!\n ')
|
||||
})
|
||||
|
||||
it('handles tool calls request', () => {
|
||||
test('handles tool calls request', () => {
|
||||
const msg = Msg.toolCall([
|
||||
{
|
||||
id: 'fake-tool-call-id',
|
||||
|
@ -37,13 +37,13 @@ describe('Msg', () => {
|
|||
expect(Msg.isToolCall(msg)).toBe(true)
|
||||
})
|
||||
|
||||
it('handles tool call response', () => {
|
||||
test('handles tool call response', () => {
|
||||
const msg = Msg.toolResult('Hello, World!', 'fake-tool-call-id')
|
||||
expectTypeOf(msg).toMatchTypeOf<types.Msg.ToolResult>()
|
||||
expect(Msg.isToolResult(msg)).toBe(true)
|
||||
})
|
||||
|
||||
it('prompt message types should interop with openai-fetch message types', () => {
|
||||
test('prompt message types should interop with openai-fetch message types', () => {
|
||||
expectTypeOf({} as OpenAI.ChatMessage).toMatchTypeOf<types.Msg>()
|
||||
expectTypeOf({} as types.Msg).toMatchTypeOf<OpenAI.ChatMessage>()
|
||||
expectTypeOf({} as types.Msg.System).toMatchTypeOf<OpenAI.ChatMessage>()
|
||||
|
|
|
@ -7,10 +7,17 @@ import { cleanStringForModel, stringifyForModel } from './utils'
|
|||
*/
|
||||
export interface Msg {
|
||||
/**
|
||||
* The contents of the message. `content` is required for all messages, and
|
||||
* may be null for assistant messages with function calls.
|
||||
* The contents of the message. `content` may be null for assistant messages
|
||||
* with function calls or `undefined` for assistant messages if a `refusal`
|
||||
* was given by the model.
|
||||
*/
|
||||
content: string | null
|
||||
content?: string | null
|
||||
|
||||
/**
|
||||
* The reason the model refused to generate this message or `undefined` if the
|
||||
* message was generated successfully.
|
||||
*/
|
||||
refusal?: string | null
|
||||
|
||||
/**
|
||||
* The role of the messages author. One of `system`, `user`, `assistant`,
|
||||
|
@ -43,6 +50,15 @@ export interface Msg {
|
|||
name?: string
|
||||
}
|
||||
|
||||
export interface LegacyMsg {
|
||||
content: string | null
|
||||
role: Msg.Role
|
||||
function_call?: Msg.Call.Function
|
||||
tool_calls?: Msg.Call.Tool[]
|
||||
tool_call_id?: string
|
||||
name?: string
|
||||
}
|
||||
|
||||
/** Narrowed OpenAI Message types. */
|
||||
export namespace Msg {
|
||||
/** Possible roles for a message. */
|
||||
|
@ -98,6 +114,13 @@ export namespace Msg {
|
|||
content: string
|
||||
}
|
||||
|
||||
/** Message with refusal reason from the assistant. */
|
||||
export type Refusal = {
|
||||
role: 'assistant'
|
||||
name?: string
|
||||
refusal: string
|
||||
}
|
||||
|
||||
/** Message with arguments to call a function. */
|
||||
export type FuncCall = {
|
||||
role: 'assistant'
|
||||
|
@ -185,6 +208,27 @@ export namespace Msg {
|
|||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Create an assistant refusal message. Cleans indentation and newlines by
|
||||
* default.
|
||||
*/
|
||||
export function refusal(
|
||||
refusal: string,
|
||||
opts?: {
|
||||
/** Custom name for the message. */
|
||||
name?: string
|
||||
/** Whether to clean extra newlines and indentation. Defaults to true. */
|
||||
cleanRefusal?: boolean
|
||||
}
|
||||
): Msg.Refusal {
|
||||
const { name, cleanRefusal = true } = opts ?? {}
|
||||
return {
|
||||
role: 'assistant',
|
||||
refusal: cleanRefusal ? cleanStringForModel(refusal) : refusal,
|
||||
...(name ? { name } : {})
|
||||
}
|
||||
}
|
||||
|
||||
/** Create a function call message with argumets. */
|
||||
export function funcCall(
|
||||
function_call: {
|
||||
|
@ -249,7 +293,7 @@ export namespace Msg {
|
|||
// @TODO
|
||||
response: any
|
||||
// response: ChatModel.EnrichedResponse
|
||||
): Msg.Assistant | Msg.FuncCall | Msg.ToolCall {
|
||||
): Msg.Assistant | Msg.Refusal | Msg.FuncCall | Msg.ToolCall {
|
||||
const msg = response.choices[0].message as Msg
|
||||
return narrowResponseMessage(msg)
|
||||
}
|
||||
|
@ -257,13 +301,15 @@ export namespace Msg {
|
|||
/** Narrow a message received from the API. It only responds with role=assistant */
|
||||
export function narrowResponseMessage(
|
||||
msg: Msg
|
||||
): Msg.Assistant | Msg.FuncCall | Msg.ToolCall {
|
||||
): Msg.Assistant | Msg.Refusal | Msg.FuncCall | Msg.ToolCall {
|
||||
if (msg.content === null && msg.tool_calls != null) {
|
||||
return Msg.toolCall(msg.tool_calls)
|
||||
} else if (msg.content === null && msg.function_call != null) {
|
||||
return Msg.funcCall(msg.function_call)
|
||||
} else if (msg.content !== null) {
|
||||
} else if (msg.content !== null && msg.content !== undefined) {
|
||||
return Msg.assistant(msg.content)
|
||||
} else if (msg.refusal != null) {
|
||||
return Msg.refusal(msg.refusal)
|
||||
} else {
|
||||
// @TODO: probably don't want to error here
|
||||
console.log('Invalid message', msg)
|
||||
|
@ -283,6 +329,10 @@ export namespace Msg {
|
|||
export function isAssistant(message: Msg): message is Msg.Assistant {
|
||||
return message.role === 'assistant' && message.content !== null
|
||||
}
|
||||
/** Check if a message is an assistant refusal message. */
|
||||
export function isRefusal(message: Msg): message is Msg.Refusal {
|
||||
return message.role === 'assistant' && message.refusal !== null
|
||||
}
|
||||
/** Check if a message is a function call message with arguments. */
|
||||
export function isFuncCall(message: Msg): message is Msg.FuncCall {
|
||||
return message.role === 'assistant' && message.function_call != null
|
||||
|
@ -304,6 +354,7 @@ export namespace Msg {
|
|||
export function narrow(message: Msg.System): Msg.System
|
||||
export function narrow(message: Msg.User): Msg.User
|
||||
export function narrow(message: Msg.Assistant): Msg.Assistant
|
||||
export function narrow(message: Msg.Assistant): Msg.Refusal
|
||||
export function narrow(message: Msg.FuncCall): Msg.FuncCall
|
||||
export function narrow(message: Msg.FuncResult): Msg.FuncResult
|
||||
export function narrow(message: Msg.ToolCall): Msg.ToolCall
|
||||
|
@ -314,6 +365,7 @@ export namespace Msg {
|
|||
| Msg.System
|
||||
| Msg.User
|
||||
| Msg.Assistant
|
||||
| Msg.Refusal
|
||||
| Msg.FuncCall
|
||||
| Msg.FuncResult
|
||||
| Msg.ToolCall
|
||||
|
@ -327,6 +379,9 @@ export namespace Msg {
|
|||
if (isAssistant(message)) {
|
||||
return message
|
||||
}
|
||||
if (isRefusal(message)) {
|
||||
return message
|
||||
}
|
||||
if (isFuncCall(message)) {
|
||||
return message
|
||||
}
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
import type { z } from 'zod'
|
||||
|
||||
import type * as types from './types'
|
||||
import { safeParseStructuredOutput } from './parse-structured-output'
|
||||
import { parseStructuredOutput } from './parse-structured-output'
|
||||
import { stringifyForModel } from './utils'
|
||||
import { zodToJsonSchema } from './zod-to-json-schema'
|
||||
|
||||
|
@ -9,7 +9,6 @@ import { zodToJsonSchema } from './zod-to-json-schema'
|
|||
* Used to mark schemas so we can support both Zod and custom schemas.
|
||||
*/
|
||||
export const schemaSymbol = Symbol('agentic.schema')
|
||||
export const validatorSymbol = Symbol('agentic.validator')
|
||||
|
||||
export type Schema<TData = unknown> = {
|
||||
/**
|
||||
|
@ -18,10 +17,18 @@ export type Schema<TData = unknown> = {
|
|||
readonly jsonSchema: types.JSONSchema
|
||||
|
||||
/**
|
||||
* Optional. Validates that the structure of a value matches this schema,
|
||||
* and returns a typed version of the value if it does.
|
||||
* Parses the value, validates that it matches this schema, and returns a
|
||||
* typed version of the value if it does. Throw an error if the value does
|
||||
* not match the schema.
|
||||
*/
|
||||
readonly validate?: types.ValidatorFn<TData>
|
||||
readonly parse: types.ParseFn<TData>
|
||||
|
||||
/**
|
||||
* Parses the value, validates that it matches this schema, and returns a
|
||||
* typed version of the value if it does. Returns an error message if the
|
||||
* value does not match the schema, and will never throw an error.
|
||||
*/
|
||||
readonly safeParse: types.SafeParseFn<TData>
|
||||
|
||||
/**
|
||||
* Used to mark schemas so we can support both Zod and custom schemas.
|
||||
|
@ -41,7 +48,7 @@ export function isSchema(value: unknown): value is Schema {
|
|||
schemaSymbol in value &&
|
||||
value[schemaSymbol] === true &&
|
||||
'jsonSchema' in value &&
|
||||
'validate' in value
|
||||
'parse' in value
|
||||
)
|
||||
}
|
||||
|
||||
|
@ -59,9 +66,10 @@ export function isZodSchema(value: unknown): value is z.ZodType {
|
|||
}
|
||||
|
||||
export function asSchema<TData>(
|
||||
schema: z.Schema<TData> | Schema<TData>
|
||||
schema: z.Schema<TData> | Schema<TData>,
|
||||
opts: { strict?: boolean } = {}
|
||||
): Schema<TData> {
|
||||
return isSchema(schema) ? schema : createSchemaFromZodSchema(schema)
|
||||
return isSchema(schema) ? schema : createSchemaFromZodSchema(schema, opts)
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -70,25 +78,38 @@ export function asSchema<TData>(
|
|||
export function createSchema<TData = unknown>(
|
||||
jsonSchema: types.JSONSchema,
|
||||
{
|
||||
validate
|
||||
parse = (value) => value as TData,
|
||||
safeParse
|
||||
}: {
|
||||
validate?: types.ValidatorFn<TData>
|
||||
parse?: types.ParseFn<TData>
|
||||
safeParse?: types.SafeParseFn<TData>
|
||||
} = {}
|
||||
): Schema<TData> {
|
||||
safeParse ??= (value: unknown) => {
|
||||
try {
|
||||
const result = parse(value)
|
||||
return { success: true, data: result }
|
||||
} catch (err: any) {
|
||||
return { success: false, error: err.message ?? String(err) }
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
[schemaSymbol]: true,
|
||||
_type: undefined as TData,
|
||||
jsonSchema,
|
||||
validate
|
||||
parse,
|
||||
safeParse
|
||||
}
|
||||
}
|
||||
|
||||
export function createSchemaFromZodSchema<TData>(
|
||||
zodSchema: z.Schema<TData>
|
||||
zodSchema: z.Schema<TData>,
|
||||
opts: { strict?: boolean } = {}
|
||||
): Schema<TData> {
|
||||
return createSchema(zodToJsonSchema(zodSchema), {
|
||||
validate: (value) => {
|
||||
return safeParseStructuredOutput(value, zodSchema)
|
||||
return createSchema(zodToJsonSchema(zodSchema, opts), {
|
||||
parse: (value) => {
|
||||
return parseStructuredOutput(value, zodSchema)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
|
@ -3,13 +3,13 @@ import type { z } from 'zod'
|
|||
|
||||
import type { AIFunctionSet } from './ai-function-set'
|
||||
import type { AIFunctionsProvider } from './fns'
|
||||
import type { Msg } from './message'
|
||||
import type { LegacyMsg, Msg } from './message'
|
||||
|
||||
export type { Msg } from './message'
|
||||
export type { Schema } from './schema'
|
||||
export type { KyInstance } from 'ky'
|
||||
export type { ThrottledFunction } from 'p-throttle'
|
||||
export type { SetRequired, Simplify } from 'type-fest'
|
||||
export type { SetOptional, SetRequired, Simplify } from 'type-fest'
|
||||
|
||||
export type Nullable<T> = T | null
|
||||
|
||||
|
@ -33,6 +33,13 @@ export interface AIFunctionSpec {
|
|||
|
||||
/** JSON schema spec of the function's input parameters */
|
||||
parameters: JSONSchema
|
||||
|
||||
/**
|
||||
* Whether to enable strict schema adherence when generating the function
|
||||
* parameters. Currently only supported by OpenAI's
|
||||
* [Structured Outputs](https://platform.openai.com/docs/guides/structured-outputs).
|
||||
*/
|
||||
strict?: boolean
|
||||
}
|
||||
|
||||
export interface AIToolSpec {
|
||||
|
@ -102,7 +109,17 @@ export interface ChatParams {
|
|||
max_tokens?: number
|
||||
presence_penalty?: number
|
||||
frequency_penalty?: number
|
||||
response_format?: { type: 'text' | 'json_object' }
|
||||
response_format?:
|
||||
| {
|
||||
type: 'text'
|
||||
}
|
||||
| {
|
||||
type: 'json_object'
|
||||
}
|
||||
| {
|
||||
type: 'json_schema'
|
||||
json_schema: ResponseFormatJSONSchema
|
||||
}
|
||||
seed?: number
|
||||
stop?: string | null | Array<string>
|
||||
temperature?: number
|
||||
|
@ -111,10 +128,53 @@ export interface ChatParams {
|
|||
user?: string
|
||||
}
|
||||
|
||||
export type LegacyChatParams = Simplify<
|
||||
Omit<ChatParams, 'messages'> & { messages: LegacyMsg[] }
|
||||
>
|
||||
|
||||
export interface ResponseFormatJSONSchema {
|
||||
/**
|
||||
* The name of the response format. Must be a-z, A-Z, 0-9, or contain
|
||||
* underscores and dashes, with a maximum length of 64.
|
||||
*/
|
||||
name: string
|
||||
|
||||
/**
|
||||
* A description of what the response format is for, used by the model to
|
||||
* determine how to respond in the format.
|
||||
*/
|
||||
description?: string
|
||||
|
||||
/**
|
||||
* The schema for the response format, described as a JSON Schema object.
|
||||
*/
|
||||
schema?: JSONSchema
|
||||
|
||||
/**
|
||||
* Whether to enable strict schema adherence when generating the output. If
|
||||
* set to true, the model will always follow the exact schema defined in the
|
||||
* `schema` field. Only a subset of JSON Schema is supported when `strict`
|
||||
* is `true`. Currently only supported by OpenAI's
|
||||
* [Structured Outputs](https://platform.openai.com/docs/guides/structured-outputs).
|
||||
*/
|
||||
strict?: boolean
|
||||
}
|
||||
|
||||
/**
|
||||
* OpenAI has changed some of their types, so instead of trying to support all
|
||||
* possible types, for these params, just relax them for now.
|
||||
*/
|
||||
export type RelaxedChatParams = Simplify<
|
||||
Omit<ChatParams, 'messages' | 'response_format'> & {
|
||||
messages: any[]
|
||||
response_format?: any
|
||||
}
|
||||
>
|
||||
|
||||
/** An OpenAI-compatible chat completions API */
|
||||
export type ChatFn = (
|
||||
params: Simplify<SetOptional<ChatParams, 'model'>>
|
||||
) => Promise<{ message: Msg }>
|
||||
params: Simplify<SetOptional<RelaxedChatParams, 'model'>>
|
||||
) => Promise<{ message: Msg | LegacyMsg }>
|
||||
|
||||
export type AIChainResult = string | Record<string, any>
|
||||
|
||||
|
@ -134,4 +194,5 @@ export type SafeParseResult<TData> =
|
|||
error: string
|
||||
}
|
||||
|
||||
export type ValidatorFn<TData> = (value: unknown) => SafeParseResult<TData>
|
||||
export type ParseFn<TData> = (value: unknown) => TData
|
||||
export type SafeParseFn<TData> = (value: unknown) => SafeParseResult<TData>
|
||||
|
|
|
@ -1,10 +1,10 @@
|
|||
import { describe, expect, it } from 'vitest'
|
||||
import { describe, expect, test } from 'vitest'
|
||||
import { z } from 'zod'
|
||||
|
||||
import { zodToJsonSchema } from './zod-to-json-schema'
|
||||
|
||||
describe('zodToJsonSchema', () => {
|
||||
it('handles basic objects', () => {
|
||||
test('handles basic objects', () => {
|
||||
const params = zodToJsonSchema(
|
||||
z.object({
|
||||
name: z.string().min(1).describe('Name of the person'),
|
||||
|
@ -30,7 +30,7 @@ describe('zodToJsonSchema', () => {
|
|||
})
|
||||
})
|
||||
|
||||
it('handles enums and unions', () => {
|
||||
test('handles enums and unions', () => {
|
||||
const params = zodToJsonSchema(
|
||||
z.object({
|
||||
name: z.string().min(1).describe('Name of the person'),
|
||||
|
|
|
@ -1,13 +1,23 @@
|
|||
import type { z } from 'zod'
|
||||
import { zodToJsonSchema as zodToJsonSchemaImpl } from 'zod-to-json-schema'
|
||||
import { zodToJsonSchema as zodToJsonSchemaImpl } from 'openai-zod-to-json-schema'
|
||||
|
||||
import type * as types from './types'
|
||||
import { omit } from './utils'
|
||||
|
||||
/** Generate a JSON Schema from a Zod schema. */
|
||||
export function zodToJsonSchema(schema: z.ZodType): types.JSONSchema {
|
||||
export function zodToJsonSchema(
|
||||
schema: z.ZodType,
|
||||
{
|
||||
strict = false
|
||||
}: {
|
||||
strict?: boolean
|
||||
} = {}
|
||||
): types.JSONSchema {
|
||||
return omit(
|
||||
zodToJsonSchemaImpl(schema, { $refStrategy: 'none' }),
|
||||
zodToJsonSchemaImpl(schema, {
|
||||
$refStrategy: 'none',
|
||||
openaiStrictMode: strict
|
||||
}),
|
||||
'$schema',
|
||||
'default',
|
||||
'definitions',
|
||||
|
|
1497
pnpm-lock.yaml
1497
pnpm-lock.yaml
Plik diff jest za duży
Load Diff
|
@ -12,13 +12,7 @@
|
|||
"dependsOn": ["^clean"]
|
||||
},
|
||||
"test": {
|
||||
"dependsOn": [
|
||||
"build",
|
||||
"test:format",
|
||||
"test:lint",
|
||||
"test:typecheck",
|
||||
"test:unit"
|
||||
]
|
||||
"dependsOn": ["test:format", "test:lint", "test:typecheck", "test:unit"]
|
||||
},
|
||||
"test:lint": {
|
||||
"dependsOn": ["^test:lint"],
|
||||
|
|
Ładowanie…
Reference in New Issue