Merge pull request #658 from transitive-bullshit/feature/openai-structured-outputs

old-agentic
Travis Fischer 2024-08-08 01:30:54 -05:00 zatwierdzone przez GitHub
commit 934c4b3536
17 zmienionych plików z 1436 dodań i 442 usunięć

Wyświetl plik

@ -10,6 +10,7 @@ jobs:
fail-fast: true fail-fast: true
matrix: matrix:
node-version: node-version:
- 18
- 20 - 20
- 21 - 21
- 22 - 22
@ -18,33 +19,23 @@ jobs:
- name: Checkout - name: Checkout
uses: actions/checkout@v4 uses: actions/checkout@v4
- name: Install pnpm
uses: pnpm/action-setup@v4
with:
version: 9.7.0
run_install: false
- name: Install Node.js - name: Install Node.js
uses: actions/setup-node@v4 uses: actions/setup-node@v4
with: with:
node-version: ${{ matrix.node-version }} node-version: ${{ matrix.node-version }}
cache: 'pnpm'
- name: Install pnpm
uses: pnpm/action-setup@v3
id: pnpm-install
with:
version: 9.6.0
run_install: false
- name: Get pnpm store directory
shell: bash
run: |
echo "STORE_PATH=$(pnpm store path --silent)" >> $GITHUB_ENV
- name: Setup pnpm cache
uses: actions/cache@v4
with:
path: ${{ env.STORE_PATH }}
key: ${{ runner.os }}-pnpm-store-${{ hashFiles('**/pnpm-lock.yaml') }}
restore-keys: |
${{ runner.os }}-pnpm-store-
- name: Install dependencies - name: Install dependencies
run: pnpm install --frozen-lockfile run: pnpm install --frozen-lockfile --strict-peer-dependencies
- name: Run build
run: pnpm build
- name: Run test - name: Run test
run: pnpm run test run: pnpm test

Wyświetl plik

@ -16,6 +16,7 @@ async function main() {
}) })
const chain = createAIChain({ const chain = createAIChain({
name: 'search_news',
chatFn: chatModel.run.bind(chatModel), chatFn: chatModel.run.bind(chatModel),
tools: [perigon.functions.pick('search_news_stories'), serper], tools: [perigon.functions.pick('search_news_stories'), serper],
params: { params: {

Wyświetl plik

@ -12,6 +12,7 @@ async function main() {
}) })
const result = await extractObject({ const result = await extractObject({
name: 'extract-user',
chatFn: chatModel.run.bind(chatModel), chatFn: chatModel.run.bind(chatModel),
params: { params: {
messages: [ messages: [
@ -25,7 +26,8 @@ async function main() {
name: z.string(), name: z.string(),
age: z.number(), age: z.number(),
location: z.string().optional() location: z.string().optional()
}) }),
strict: true
}) })
console.log(result) console.log(result)

Wyświetl plik

@ -7,7 +7,7 @@
"type": "git", "type": "git",
"url": "git+https://github.com/transitive-bullshit/agentic.git" "url": "git+https://github.com/transitive-bullshit/agentic.git"
}, },
"packageManager": "pnpm@9.6.0", "packageManager": "pnpm@9.7.0",
"engines": { "engines": {
"node": ">=18" "node": ">=18"
}, },
@ -26,6 +26,7 @@
"release:build": "run-s build", "release:build": "run-s build",
"release:version": "changeset version", "release:version": "changeset version",
"release:publish": "changeset publish", "release:publish": "changeset publish",
"pretest": "run-s build",
"precommit": "lint-staged", "precommit": "lint-staged",
"preinstall": "npx only-allow pnpm", "preinstall": "npx only-allow pnpm",
"prepare": "husky" "prepare": "husky"
@ -45,7 +46,7 @@
"prettier": "^3.3.3", "prettier": "^3.3.3",
"tsup": "^8.2.4", "tsup": "^8.2.4",
"tsx": "^4.16.5", "tsx": "^4.16.5",
"turbo": "^2.0.11", "turbo": "^2.0.12",
"typescript": "^5.5.4", "typescript": "^5.5.4",
"vitest": "2.0.5", "vitest": "2.0.5",
"zod": "^3.23.8" "zod": "^3.23.8"

Wyświetl plik

@ -40,11 +40,11 @@
"jsonrepair": "^3.6.1", "jsonrepair": "^3.6.1",
"ky": "^1.5.0", "ky": "^1.5.0",
"normalize-url": "^8.0.1", "normalize-url": "^8.0.1",
"openai-zod-to-json-schema": "^1.0.0",
"p-map": "^7.0.2", "p-map": "^7.0.2",
"p-throttle": "^6.1.0", "p-throttle": "^6.1.0",
"quick-lru": "^7.0.0", "quick-lru": "^7.0.0",
"type-fest": "^4.21.0", "type-fest": "^4.21.0",
"zod-to-json-schema": "^3.23.2",
"zod-validation-error": "^3.3.0" "zod-validation-error": "^3.3.0"
}, },
"peerDependencies": { "peerDependencies": {
@ -52,7 +52,7 @@
}, },
"devDependencies": { "devDependencies": {
"@agentic/tsconfig": "workspace:*", "@agentic/tsconfig": "workspace:*",
"openai-fetch": "^2.0.4" "openai-fetch": "^3.0.0"
}, },
"publishConfig": { "publishConfig": {
"access": "public" "access": "public"

Wyświetl plik

@ -10,15 +10,42 @@ import { asSchema, augmentSystemMessageWithJsonSchema } from './schema'
import { getErrorMessage } from './utils' import { getErrorMessage } from './utils'
export type AIChainParams<Result extends types.AIChainResult = string> = { export type AIChainParams<Result extends types.AIChainResult = string> = {
/** Name of the chain */
name: string
/** Chat completions function */
chatFn: types.ChatFn chatFn: types.ChatFn
/** Description of the chain */
description?: string
/** Optional chat completion params */
params?: types.Simplify< params?: types.Simplify<
Partial<Omit<types.ChatParams, 'tools' | 'functions'>> Partial<Omit<types.ChatParams, 'tools' | 'functions'>>
> >
/** Optional tools */
tools?: types.AIFunctionLike[] tools?: types.AIFunctionLike[]
/** Optional response schema */
schema?: z.ZodType<Result> | types.Schema<Result> schema?: z.ZodType<Result> | types.Schema<Result>
/**
* Whether or not the response schema should be treated as strict for
* constrained structured output generation.
*/
strict?: boolean
/** Max number of LLM calls to allow */
maxCalls?: number maxCalls?: number
/** Max number of retries to allow */
maxRetries?: number maxRetries?: number
/** Max concurrency when invoking tool calls */
toolCallConcurrency?: number toolCallConcurrency?: number
/** Whether or not to inject the schema into the context */
injectSchemaIntoSystemMessage?: boolean injectSchemaIntoSystemMessage?: boolean
} }
@ -38,6 +65,8 @@ export type AIChainParams<Result extends types.AIChainResult = string> = {
* exceeds `maxCalls` (`maxCalls` is expected to be >= `maxRetries`). * exceeds `maxCalls` (`maxCalls` is expected to be >= `maxRetries`).
*/ */
export function createAIChain<Result extends types.AIChainResult = string>({ export function createAIChain<Result extends types.AIChainResult = string>({
name,
description,
chatFn, chatFn,
params, params,
schema: rawSchema, schema: rawSchema,
@ -45,13 +74,28 @@ export function createAIChain<Result extends types.AIChainResult = string>({
maxCalls = 5, maxCalls = 5,
maxRetries = 2, maxRetries = 2,
toolCallConcurrency = 8, toolCallConcurrency = 8,
injectSchemaIntoSystemMessage = true injectSchemaIntoSystemMessage = false,
strict = false
}: AIChainParams<Result>): types.AIChain<Result> { }: AIChainParams<Result>): types.AIChain<Result> {
const functionSet = new AIFunctionSet(tools) const functionSet = new AIFunctionSet(tools)
const schema = rawSchema ? asSchema(rawSchema, { strict }) : undefined
// TODO: support custom stopping criteria (like setting a flag in a tool call)
const defaultParams: Partial<types.ChatParams> | undefined = const defaultParams: Partial<types.ChatParams> | undefined =
rawSchema && !functionSet.size schema && !functionSet.size
? { ? {
response_format: { type: 'json_object' } response_format: strict
? {
type: 'json_schema',
json_schema: {
name,
description,
strict,
schema: schema.jsonSchema
}
}
: { type: 'json_object' }
} }
: undefined : undefined
@ -77,8 +121,6 @@ export function createAIChain<Result extends types.AIChainResult = string>({
throw new Error('AIChain error: "messages" is empty') throw new Error('AIChain error: "messages" is empty')
} }
const schema = rawSchema ? asSchema(rawSchema) : undefined
if (schema && injectSchemaIntoSystemMessage) { if (schema && injectSchemaIntoSystemMessage) {
const lastSystemMessageIndex = messages.findLastIndex(Msg.isSystem) const lastSystemMessageIndex = messages.findLastIndex(Msg.isSystem)
const lastSystemMessageContent = const lastSystemMessageContent =
@ -101,6 +143,7 @@ export function createAIChain<Result extends types.AIChainResult = string>({
do { do {
++numCalls ++numCalls
const response = await chatFn({ const response = await chatFn({
...modelParams, ...modelParams,
messages, messages,
@ -149,15 +192,11 @@ export function createAIChain<Result extends types.AIChainResult = string>({
throw new AbortError( throw new AbortError(
'Function calls are not supported; expected tool call' 'Function calls are not supported; expected tool call'
) )
} else if (Msg.isRefusal(message)) {
throw new AbortError(`Model refusal: ${message.refusal}`)
} else if (Msg.isAssistant(message)) { } else if (Msg.isAssistant(message)) {
if (schema && schema.validate) { if (schema) {
const result = schema.validate(message.content) return schema.parse(message.content)
if (result.success) {
return result.data
}
throw new Error(result.error)
} else { } else {
return message.content as Result return message.content as Result
} }
@ -169,6 +208,8 @@ export function createAIChain<Result extends types.AIChainResult = string>({
throw err throw err
} }
console.warn(`Chain "${name}" error:`, err.message)
messages.push( messages.push(
Msg.user( Msg.user(
`There was an error validating the response. Please check the error message and try again.\nError:\n${getErrorMessage(err)}` `There was an error validating the response. Please check the error message and try again.\nError:\n${getErrorMessage(err)}`
@ -177,7 +218,7 @@ export function createAIChain<Result extends types.AIChainResult = string>({
if (numErrors > maxRetries) { if (numErrors > maxRetries) {
throw new Error( throw new Error(
`Chain failed after ${numErrors} errors: ${err.message}`, `Chain ${name} failed after ${numErrors} errors: ${err.message}`,
{ {
cause: err cause: err
} }
@ -186,6 +227,8 @@ export function createAIChain<Result extends types.AIChainResult = string>({
} }
} while (numCalls < maxCalls) } while (numCalls < maxCalls)
throw new Error(`Chain aborted after reaching max ${maxCalls} calls`) throw new Error(
`Chain "${name}" aborted after reaching max ${maxCalls} calls`
)
} }
} }

Wyświetl plik

@ -1,4 +1,4 @@
import { describe, expect, it } from 'vitest' import { describe, expect, test } from 'vitest'
import { z } from 'zod' import { z } from 'zod'
import { createAIFunction } from './create-ai-function' import { createAIFunction } from './create-ai-function'
@ -18,7 +18,7 @@ const fullName = createAIFunction(
) )
describe('createAIFunction()', () => { describe('createAIFunction()', () => {
it('exposes OpenAI function calling spec', () => { test('exposes OpenAI function calling spec', () => {
expect(fullName.spec.name).toEqual('fullName') expect(fullName.spec.name).toEqual('fullName')
expect(fullName.spec.description).toEqual( expect(fullName.spec.description).toEqual(
'Returns the full name of a person.' 'Returns the full name of a person.'
@ -34,7 +34,7 @@ describe('createAIFunction()', () => {
}) })
}) })
it('executes the function', async () => { test('executes the function', async () => {
expect(await fullName('{"first": "John", "last": "Doe"}')).toEqual( expect(await fullName('{"first": "John", "last": "Doe"}')).toEqual(
'John Doe' 'John Doe'
) )

Wyświetl plik

@ -20,8 +20,13 @@ export function createAIFunction<InputSchema extends z.ZodObject<any>, Output>(
name: string name: string
/** Description of the function. */ /** Description of the function. */
description?: string description?: string
/** Zod schema for the arguments string. */ /** Zod schema for the function parameters. */
inputSchema: InputSchema inputSchema: InputSchema
/**
* Whether or not to enable structured output generation based on the given
* zod schema.
*/
strict?: boolean
}, },
/** Implementation of the function to call with the parsed arguments. */ /** Implementation of the function to call with the parsed arguments. */
implementation: (params: z.infer<InputSchema>) => types.MaybePromise<Output> implementation: (params: z.infer<InputSchema>) => types.MaybePromise<Output>
@ -59,12 +64,15 @@ export function createAIFunction<InputSchema extends z.ZodObject<any>, Output>(
return implementation(parsedInput) return implementation(parsedInput)
} }
const strict = !!spec.strict
aiFunction.inputSchema = spec.inputSchema aiFunction.inputSchema = spec.inputSchema
aiFunction.parseInput = parseInput aiFunction.parseInput = parseInput
aiFunction.spec = { aiFunction.spec = {
name: spec.name, name: spec.name,
description: spec.description?.trim() ?? '', description: spec.description?.trim() ?? '',
parameters: zodToJsonSchema(spec.inputSchema) parameters: zodToJsonSchema(spec.inputSchema, { strict }),
strict
} }
aiFunction.impl = implementation aiFunction.impl = implementation

Wyświetl plik

@ -10,6 +10,7 @@ export interface PrivateAIFunctionMetadata {
description: string description: string
inputSchema: z.AnyZodObject inputSchema: z.AnyZodObject
methodName: string methodName: string
strict?: boolean
} }
// Polyfill for `Symbol.metadata` // Polyfill for `Symbol.metadata`
@ -69,11 +70,13 @@ export function aiFunction<
>({ >({
name, name,
description, description,
inputSchema inputSchema,
strict
}: { }: {
name?: string name?: string
description: string description: string
inputSchema: InputSchema inputSchema: InputSchema
strict?: boolean
}) { }) {
return ( return (
_targetMethod: ( _targetMethod: (
@ -99,7 +102,8 @@ export function aiFunction<
name: name ?? methodName, name: name ?? methodName,
description, description,
inputSchema, inputSchema,
methodName methodName,
strict
}) })
context.addInitializer(function () { context.addInitializer(function () {

Wyświetl plik

@ -1,11 +1,11 @@
import type * as OpenAI from 'openai-fetch' import type * as OpenAI from 'openai-fetch'
import { describe, expect, expectTypeOf, it } from 'vitest' import { describe, expect, expectTypeOf, test } from 'vitest'
import type * as types from './types' import type * as types from './types'
import { Msg } from './message' import { Msg } from './message'
describe('Msg', () => { describe('Msg', () => {
it('creates a message and fixes indentation', () => { test('creates a message and fixes indentation', () => {
const msgContent = ` const msgContent = `
Hello, World! Hello, World!
` `
@ -14,7 +14,7 @@ describe('Msg', () => {
expect(msg.content).toEqual('Hello, World!') expect(msg.content).toEqual('Hello, World!')
}) })
it('supports disabling indentation fixing', () => { test('supports disabling indentation fixing', () => {
const msgContent = ` const msgContent = `
Hello, World! Hello, World!
` `
@ -22,7 +22,7 @@ describe('Msg', () => {
expect(msg.content).toEqual('\n Hello, World!\n ') expect(msg.content).toEqual('\n Hello, World!\n ')
}) })
it('handles tool calls request', () => { test('handles tool calls request', () => {
const msg = Msg.toolCall([ const msg = Msg.toolCall([
{ {
id: 'fake-tool-call-id', id: 'fake-tool-call-id',
@ -37,13 +37,13 @@ describe('Msg', () => {
expect(Msg.isToolCall(msg)).toBe(true) expect(Msg.isToolCall(msg)).toBe(true)
}) })
it('handles tool call response', () => { test('handles tool call response', () => {
const msg = Msg.toolResult('Hello, World!', 'fake-tool-call-id') const msg = Msg.toolResult('Hello, World!', 'fake-tool-call-id')
expectTypeOf(msg).toMatchTypeOf<types.Msg.ToolResult>() expectTypeOf(msg).toMatchTypeOf<types.Msg.ToolResult>()
expect(Msg.isToolResult(msg)).toBe(true) expect(Msg.isToolResult(msg)).toBe(true)
}) })
it('prompt message types should interop with openai-fetch message types', () => { test('prompt message types should interop with openai-fetch message types', () => {
expectTypeOf({} as OpenAI.ChatMessage).toMatchTypeOf<types.Msg>() expectTypeOf({} as OpenAI.ChatMessage).toMatchTypeOf<types.Msg>()
expectTypeOf({} as types.Msg).toMatchTypeOf<OpenAI.ChatMessage>() expectTypeOf({} as types.Msg).toMatchTypeOf<OpenAI.ChatMessage>()
expectTypeOf({} as types.Msg.System).toMatchTypeOf<OpenAI.ChatMessage>() expectTypeOf({} as types.Msg.System).toMatchTypeOf<OpenAI.ChatMessage>()

Wyświetl plik

@ -7,10 +7,17 @@ import { cleanStringForModel, stringifyForModel } from './utils'
*/ */
export interface Msg { export interface Msg {
/** /**
* The contents of the message. `content` is required for all messages, and * The contents of the message. `content` may be null for assistant messages
* may be null for assistant messages with function calls. * with function calls or `undefined` for assistant messages if a `refusal`
* was given by the model.
*/ */
content: string | null content?: string | null
/**
* The reason the model refused to generate this message or `undefined` if the
* message was generated successfully.
*/
refusal?: string | null
/** /**
* The role of the messages author. One of `system`, `user`, `assistant`, * The role of the messages author. One of `system`, `user`, `assistant`,
@ -43,6 +50,15 @@ export interface Msg {
name?: string name?: string
} }
export interface LegacyMsg {
content: string | null
role: Msg.Role
function_call?: Msg.Call.Function
tool_calls?: Msg.Call.Tool[]
tool_call_id?: string
name?: string
}
/** Narrowed OpenAI Message types. */ /** Narrowed OpenAI Message types. */
export namespace Msg { export namespace Msg {
/** Possible roles for a message. */ /** Possible roles for a message. */
@ -98,6 +114,13 @@ export namespace Msg {
content: string content: string
} }
/** Message with refusal reason from the assistant. */
export type Refusal = {
role: 'assistant'
name?: string
refusal: string
}
/** Message with arguments to call a function. */ /** Message with arguments to call a function. */
export type FuncCall = { export type FuncCall = {
role: 'assistant' role: 'assistant'
@ -185,6 +208,27 @@ export namespace Msg {
} }
} }
/**
* Create an assistant refusal message. Cleans indentation and newlines by
* default.
*/
export function refusal(
refusal: string,
opts?: {
/** Custom name for the message. */
name?: string
/** Whether to clean extra newlines and indentation. Defaults to true. */
cleanRefusal?: boolean
}
): Msg.Refusal {
const { name, cleanRefusal = true } = opts ?? {}
return {
role: 'assistant',
refusal: cleanRefusal ? cleanStringForModel(refusal) : refusal,
...(name ? { name } : {})
}
}
/** Create a function call message with argumets. */ /** Create a function call message with argumets. */
export function funcCall( export function funcCall(
function_call: { function_call: {
@ -249,7 +293,7 @@ export namespace Msg {
// @TODO // @TODO
response: any response: any
// response: ChatModel.EnrichedResponse // response: ChatModel.EnrichedResponse
): Msg.Assistant | Msg.FuncCall | Msg.ToolCall { ): Msg.Assistant | Msg.Refusal | Msg.FuncCall | Msg.ToolCall {
const msg = response.choices[0].message as Msg const msg = response.choices[0].message as Msg
return narrowResponseMessage(msg) return narrowResponseMessage(msg)
} }
@ -257,13 +301,15 @@ export namespace Msg {
/** Narrow a message received from the API. It only responds with role=assistant */ /** Narrow a message received from the API. It only responds with role=assistant */
export function narrowResponseMessage( export function narrowResponseMessage(
msg: Msg msg: Msg
): Msg.Assistant | Msg.FuncCall | Msg.ToolCall { ): Msg.Assistant | Msg.Refusal | Msg.FuncCall | Msg.ToolCall {
if (msg.content === null && msg.tool_calls != null) { if (msg.content === null && msg.tool_calls != null) {
return Msg.toolCall(msg.tool_calls) return Msg.toolCall(msg.tool_calls)
} else if (msg.content === null && msg.function_call != null) { } else if (msg.content === null && msg.function_call != null) {
return Msg.funcCall(msg.function_call) return Msg.funcCall(msg.function_call)
} else if (msg.content !== null) { } else if (msg.content !== null && msg.content !== undefined) {
return Msg.assistant(msg.content) return Msg.assistant(msg.content)
} else if (msg.refusal != null) {
return Msg.refusal(msg.refusal)
} else { } else {
// @TODO: probably don't want to error here // @TODO: probably don't want to error here
console.log('Invalid message', msg) console.log('Invalid message', msg)
@ -283,6 +329,10 @@ export namespace Msg {
export function isAssistant(message: Msg): message is Msg.Assistant { export function isAssistant(message: Msg): message is Msg.Assistant {
return message.role === 'assistant' && message.content !== null return message.role === 'assistant' && message.content !== null
} }
/** Check if a message is an assistant refusal message. */
export function isRefusal(message: Msg): message is Msg.Refusal {
return message.role === 'assistant' && message.refusal !== null
}
/** Check if a message is a function call message with arguments. */ /** Check if a message is a function call message with arguments. */
export function isFuncCall(message: Msg): message is Msg.FuncCall { export function isFuncCall(message: Msg): message is Msg.FuncCall {
return message.role === 'assistant' && message.function_call != null return message.role === 'assistant' && message.function_call != null
@ -304,6 +354,7 @@ export namespace Msg {
export function narrow(message: Msg.System): Msg.System export function narrow(message: Msg.System): Msg.System
export function narrow(message: Msg.User): Msg.User export function narrow(message: Msg.User): Msg.User
export function narrow(message: Msg.Assistant): Msg.Assistant export function narrow(message: Msg.Assistant): Msg.Assistant
export function narrow(message: Msg.Assistant): Msg.Refusal
export function narrow(message: Msg.FuncCall): Msg.FuncCall export function narrow(message: Msg.FuncCall): Msg.FuncCall
export function narrow(message: Msg.FuncResult): Msg.FuncResult export function narrow(message: Msg.FuncResult): Msg.FuncResult
export function narrow(message: Msg.ToolCall): Msg.ToolCall export function narrow(message: Msg.ToolCall): Msg.ToolCall
@ -314,6 +365,7 @@ export namespace Msg {
| Msg.System | Msg.System
| Msg.User | Msg.User
| Msg.Assistant | Msg.Assistant
| Msg.Refusal
| Msg.FuncCall | Msg.FuncCall
| Msg.FuncResult | Msg.FuncResult
| Msg.ToolCall | Msg.ToolCall
@ -327,6 +379,9 @@ export namespace Msg {
if (isAssistant(message)) { if (isAssistant(message)) {
return message return message
} }
if (isRefusal(message)) {
return message
}
if (isFuncCall(message)) { if (isFuncCall(message)) {
return message return message
} }

Wyświetl plik

@ -1,7 +1,7 @@
import type { z } from 'zod' import type { z } from 'zod'
import type * as types from './types' import type * as types from './types'
import { safeParseStructuredOutput } from './parse-structured-output' import { parseStructuredOutput } from './parse-structured-output'
import { stringifyForModel } from './utils' import { stringifyForModel } from './utils'
import { zodToJsonSchema } from './zod-to-json-schema' import { zodToJsonSchema } from './zod-to-json-schema'
@ -9,7 +9,6 @@ import { zodToJsonSchema } from './zod-to-json-schema'
* Used to mark schemas so we can support both Zod and custom schemas. * Used to mark schemas so we can support both Zod and custom schemas.
*/ */
export const schemaSymbol = Symbol('agentic.schema') export const schemaSymbol = Symbol('agentic.schema')
export const validatorSymbol = Symbol('agentic.validator')
export type Schema<TData = unknown> = { export type Schema<TData = unknown> = {
/** /**
@ -18,10 +17,18 @@ export type Schema<TData = unknown> = {
readonly jsonSchema: types.JSONSchema readonly jsonSchema: types.JSONSchema
/** /**
* Optional. Validates that the structure of a value matches this schema, * Parses the value, validates that it matches this schema, and returns a
* and returns a typed version of the value if it does. * typed version of the value if it does. Throw an error if the value does
* not match the schema.
*/ */
readonly validate?: types.ValidatorFn<TData> readonly parse: types.ParseFn<TData>
/**
* Parses the value, validates that it matches this schema, and returns a
* typed version of the value if it does. Returns an error message if the
* value does not match the schema, and will never throw an error.
*/
readonly safeParse: types.SafeParseFn<TData>
/** /**
* Used to mark schemas so we can support both Zod and custom schemas. * Used to mark schemas so we can support both Zod and custom schemas.
@ -41,7 +48,7 @@ export function isSchema(value: unknown): value is Schema {
schemaSymbol in value && schemaSymbol in value &&
value[schemaSymbol] === true && value[schemaSymbol] === true &&
'jsonSchema' in value && 'jsonSchema' in value &&
'validate' in value 'parse' in value
) )
} }
@ -59,9 +66,10 @@ export function isZodSchema(value: unknown): value is z.ZodType {
} }
export function asSchema<TData>( export function asSchema<TData>(
schema: z.Schema<TData> | Schema<TData> schema: z.Schema<TData> | Schema<TData>,
opts: { strict?: boolean } = {}
): Schema<TData> { ): Schema<TData> {
return isSchema(schema) ? schema : createSchemaFromZodSchema(schema) return isSchema(schema) ? schema : createSchemaFromZodSchema(schema, opts)
} }
/** /**
@ -70,25 +78,38 @@ export function asSchema<TData>(
export function createSchema<TData = unknown>( export function createSchema<TData = unknown>(
jsonSchema: types.JSONSchema, jsonSchema: types.JSONSchema,
{ {
validate parse = (value) => value as TData,
safeParse
}: { }: {
validate?: types.ValidatorFn<TData> parse?: types.ParseFn<TData>
safeParse?: types.SafeParseFn<TData>
} = {} } = {}
): Schema<TData> { ): Schema<TData> {
safeParse ??= (value: unknown) => {
try {
const result = parse(value)
return { success: true, data: result }
} catch (err: any) {
return { success: false, error: err.message ?? String(err) }
}
}
return { return {
[schemaSymbol]: true, [schemaSymbol]: true,
_type: undefined as TData, _type: undefined as TData,
jsonSchema, jsonSchema,
validate parse,
safeParse
} }
} }
export function createSchemaFromZodSchema<TData>( export function createSchemaFromZodSchema<TData>(
zodSchema: z.Schema<TData> zodSchema: z.Schema<TData>,
opts: { strict?: boolean } = {}
): Schema<TData> { ): Schema<TData> {
return createSchema(zodToJsonSchema(zodSchema), { return createSchema(zodToJsonSchema(zodSchema, opts), {
validate: (value) => { parse: (value) => {
return safeParseStructuredOutput(value, zodSchema) return parseStructuredOutput(value, zodSchema)
} }
}) })
} }

Wyświetl plik

@ -3,13 +3,13 @@ import type { z } from 'zod'
import type { AIFunctionSet } from './ai-function-set' import type { AIFunctionSet } from './ai-function-set'
import type { AIFunctionsProvider } from './fns' import type { AIFunctionsProvider } from './fns'
import type { Msg } from './message' import type { LegacyMsg, Msg } from './message'
export type { Msg } from './message' export type { Msg } from './message'
export type { Schema } from './schema' export type { Schema } from './schema'
export type { KyInstance } from 'ky' export type { KyInstance } from 'ky'
export type { ThrottledFunction } from 'p-throttle' export type { ThrottledFunction } from 'p-throttle'
export type { SetRequired, Simplify } from 'type-fest' export type { SetOptional, SetRequired, Simplify } from 'type-fest'
export type Nullable<T> = T | null export type Nullable<T> = T | null
@ -33,6 +33,13 @@ export interface AIFunctionSpec {
/** JSON schema spec of the function's input parameters */ /** JSON schema spec of the function's input parameters */
parameters: JSONSchema parameters: JSONSchema
/**
* Whether to enable strict schema adherence when generating the function
* parameters. Currently only supported by OpenAI's
* [Structured Outputs](https://platform.openai.com/docs/guides/structured-outputs).
*/
strict?: boolean
} }
export interface AIToolSpec { export interface AIToolSpec {
@ -102,7 +109,17 @@ export interface ChatParams {
max_tokens?: number max_tokens?: number
presence_penalty?: number presence_penalty?: number
frequency_penalty?: number frequency_penalty?: number
response_format?: { type: 'text' | 'json_object' } response_format?:
| {
type: 'text'
}
| {
type: 'json_object'
}
| {
type: 'json_schema'
json_schema: ResponseFormatJSONSchema
}
seed?: number seed?: number
stop?: string | null | Array<string> stop?: string | null | Array<string>
temperature?: number temperature?: number
@ -111,10 +128,53 @@ export interface ChatParams {
user?: string user?: string
} }
export type LegacyChatParams = Simplify<
Omit<ChatParams, 'messages'> & { messages: LegacyMsg[] }
>
export interface ResponseFormatJSONSchema {
/**
* The name of the response format. Must be a-z, A-Z, 0-9, or contain
* underscores and dashes, with a maximum length of 64.
*/
name: string
/**
* A description of what the response format is for, used by the model to
* determine how to respond in the format.
*/
description?: string
/**
* The schema for the response format, described as a JSON Schema object.
*/
schema?: JSONSchema
/**
* Whether to enable strict schema adherence when generating the output. If
* set to true, the model will always follow the exact schema defined in the
* `schema` field. Only a subset of JSON Schema is supported when `strict`
* is `true`. Currently only supported by OpenAI's
* [Structured Outputs](https://platform.openai.com/docs/guides/structured-outputs).
*/
strict?: boolean
}
/**
* OpenAI has changed some of their types, so instead of trying to support all
* possible types, for these params, just relax them for now.
*/
export type RelaxedChatParams = Simplify<
Omit<ChatParams, 'messages' | 'response_format'> & {
messages: any[]
response_format?: any
}
>
/** An OpenAI-compatible chat completions API */ /** An OpenAI-compatible chat completions API */
export type ChatFn = ( export type ChatFn = (
params: Simplify<SetOptional<ChatParams, 'model'>> params: Simplify<SetOptional<RelaxedChatParams, 'model'>>
) => Promise<{ message: Msg }> ) => Promise<{ message: Msg | LegacyMsg }>
export type AIChainResult = string | Record<string, any> export type AIChainResult = string | Record<string, any>
@ -134,4 +194,5 @@ export type SafeParseResult<TData> =
error: string error: string
} }
export type ValidatorFn<TData> = (value: unknown) => SafeParseResult<TData> export type ParseFn<TData> = (value: unknown) => TData
export type SafeParseFn<TData> = (value: unknown) => SafeParseResult<TData>

Wyświetl plik

@ -1,10 +1,10 @@
import { describe, expect, it } from 'vitest' import { describe, expect, test } from 'vitest'
import { z } from 'zod' import { z } from 'zod'
import { zodToJsonSchema } from './zod-to-json-schema' import { zodToJsonSchema } from './zod-to-json-schema'
describe('zodToJsonSchema', () => { describe('zodToJsonSchema', () => {
it('handles basic objects', () => { test('handles basic objects', () => {
const params = zodToJsonSchema( const params = zodToJsonSchema(
z.object({ z.object({
name: z.string().min(1).describe('Name of the person'), name: z.string().min(1).describe('Name of the person'),
@ -30,7 +30,7 @@ describe('zodToJsonSchema', () => {
}) })
}) })
it('handles enums and unions', () => { test('handles enums and unions', () => {
const params = zodToJsonSchema( const params = zodToJsonSchema(
z.object({ z.object({
name: z.string().min(1).describe('Name of the person'), name: z.string().min(1).describe('Name of the person'),

Wyświetl plik

@ -1,13 +1,23 @@
import type { z } from 'zod' import type { z } from 'zod'
import { zodToJsonSchema as zodToJsonSchemaImpl } from 'zod-to-json-schema' import { zodToJsonSchema as zodToJsonSchemaImpl } from 'openai-zod-to-json-schema'
import type * as types from './types' import type * as types from './types'
import { omit } from './utils' import { omit } from './utils'
/** Generate a JSON Schema from a Zod schema. */ /** Generate a JSON Schema from a Zod schema. */
export function zodToJsonSchema(schema: z.ZodType): types.JSONSchema { export function zodToJsonSchema(
schema: z.ZodType,
{
strict = false
}: {
strict?: boolean
} = {}
): types.JSONSchema {
return omit( return omit(
zodToJsonSchemaImpl(schema, { $refStrategy: 'none' }), zodToJsonSchemaImpl(schema, {
$refStrategy: 'none',
openaiStrictMode: strict
}),
'$schema', '$schema',
'default', 'default',
'definitions', 'definitions',

Plik diff jest za duży Load Diff

Wyświetl plik

@ -12,13 +12,7 @@
"dependsOn": ["^clean"] "dependsOn": ["^clean"]
}, },
"test": { "test": {
"dependsOn": [ "dependsOn": ["test:format", "test:lint", "test:typecheck", "test:unit"]
"build",
"test:format",
"test:lint",
"test:typecheck",
"test:unit"
]
}, },
"test:lint": { "test:lint": {
"dependsOn": ["^test:lint"], "dependsOn": ["^test:lint"],