kopia lustrzana https://github.com/transitive-bullshit/chatgpt-api
Merge pull request #658 from transitive-bullshit/feature/openai-structured-outputs
commit
934c4b3536
|
@ -10,6 +10,7 @@ jobs:
|
||||||
fail-fast: true
|
fail-fast: true
|
||||||
matrix:
|
matrix:
|
||||||
node-version:
|
node-version:
|
||||||
|
- 18
|
||||||
- 20
|
- 20
|
||||||
- 21
|
- 21
|
||||||
- 22
|
- 22
|
||||||
|
@ -18,33 +19,23 @@ jobs:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Install pnpm
|
||||||
|
uses: pnpm/action-setup@v4
|
||||||
|
with:
|
||||||
|
version: 9.7.0
|
||||||
|
run_install: false
|
||||||
|
|
||||||
- name: Install Node.js
|
- name: Install Node.js
|
||||||
uses: actions/setup-node@v4
|
uses: actions/setup-node@v4
|
||||||
with:
|
with:
|
||||||
node-version: ${{ matrix.node-version }}
|
node-version: ${{ matrix.node-version }}
|
||||||
|
cache: 'pnpm'
|
||||||
- name: Install pnpm
|
|
||||||
uses: pnpm/action-setup@v3
|
|
||||||
id: pnpm-install
|
|
||||||
with:
|
|
||||||
version: 9.6.0
|
|
||||||
run_install: false
|
|
||||||
|
|
||||||
- name: Get pnpm store directory
|
|
||||||
shell: bash
|
|
||||||
run: |
|
|
||||||
echo "STORE_PATH=$(pnpm store path --silent)" >> $GITHUB_ENV
|
|
||||||
|
|
||||||
- name: Setup pnpm cache
|
|
||||||
uses: actions/cache@v4
|
|
||||||
with:
|
|
||||||
path: ${{ env.STORE_PATH }}
|
|
||||||
key: ${{ runner.os }}-pnpm-store-${{ hashFiles('**/pnpm-lock.yaml') }}
|
|
||||||
restore-keys: |
|
|
||||||
${{ runner.os }}-pnpm-store-
|
|
||||||
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: pnpm install --frozen-lockfile
|
run: pnpm install --frozen-lockfile --strict-peer-dependencies
|
||||||
|
|
||||||
|
- name: Run build
|
||||||
|
run: pnpm build
|
||||||
|
|
||||||
- name: Run test
|
- name: Run test
|
||||||
run: pnpm run test
|
run: pnpm test
|
||||||
|
|
|
@ -16,6 +16,7 @@ async function main() {
|
||||||
})
|
})
|
||||||
|
|
||||||
const chain = createAIChain({
|
const chain = createAIChain({
|
||||||
|
name: 'search_news',
|
||||||
chatFn: chatModel.run.bind(chatModel),
|
chatFn: chatModel.run.bind(chatModel),
|
||||||
tools: [perigon.functions.pick('search_news_stories'), serper],
|
tools: [perigon.functions.pick('search_news_stories'), serper],
|
||||||
params: {
|
params: {
|
||||||
|
|
|
@ -12,6 +12,7 @@ async function main() {
|
||||||
})
|
})
|
||||||
|
|
||||||
const result = await extractObject({
|
const result = await extractObject({
|
||||||
|
name: 'extract-user',
|
||||||
chatFn: chatModel.run.bind(chatModel),
|
chatFn: chatModel.run.bind(chatModel),
|
||||||
params: {
|
params: {
|
||||||
messages: [
|
messages: [
|
||||||
|
@ -25,7 +26,8 @@ async function main() {
|
||||||
name: z.string(),
|
name: z.string(),
|
||||||
age: z.number(),
|
age: z.number(),
|
||||||
location: z.string().optional()
|
location: z.string().optional()
|
||||||
})
|
}),
|
||||||
|
strict: true
|
||||||
})
|
})
|
||||||
|
|
||||||
console.log(result)
|
console.log(result)
|
||||||
|
|
|
@ -7,7 +7,7 @@
|
||||||
"type": "git",
|
"type": "git",
|
||||||
"url": "git+https://github.com/transitive-bullshit/agentic.git"
|
"url": "git+https://github.com/transitive-bullshit/agentic.git"
|
||||||
},
|
},
|
||||||
"packageManager": "pnpm@9.6.0",
|
"packageManager": "pnpm@9.7.0",
|
||||||
"engines": {
|
"engines": {
|
||||||
"node": ">=18"
|
"node": ">=18"
|
||||||
},
|
},
|
||||||
|
@ -26,6 +26,7 @@
|
||||||
"release:build": "run-s build",
|
"release:build": "run-s build",
|
||||||
"release:version": "changeset version",
|
"release:version": "changeset version",
|
||||||
"release:publish": "changeset publish",
|
"release:publish": "changeset publish",
|
||||||
|
"pretest": "run-s build",
|
||||||
"precommit": "lint-staged",
|
"precommit": "lint-staged",
|
||||||
"preinstall": "npx only-allow pnpm",
|
"preinstall": "npx only-allow pnpm",
|
||||||
"prepare": "husky"
|
"prepare": "husky"
|
||||||
|
@ -45,7 +46,7 @@
|
||||||
"prettier": "^3.3.3",
|
"prettier": "^3.3.3",
|
||||||
"tsup": "^8.2.4",
|
"tsup": "^8.2.4",
|
||||||
"tsx": "^4.16.5",
|
"tsx": "^4.16.5",
|
||||||
"turbo": "^2.0.11",
|
"turbo": "^2.0.12",
|
||||||
"typescript": "^5.5.4",
|
"typescript": "^5.5.4",
|
||||||
"vitest": "2.0.5",
|
"vitest": "2.0.5",
|
||||||
"zod": "^3.23.8"
|
"zod": "^3.23.8"
|
||||||
|
|
|
@ -40,11 +40,11 @@
|
||||||
"jsonrepair": "^3.6.1",
|
"jsonrepair": "^3.6.1",
|
||||||
"ky": "^1.5.0",
|
"ky": "^1.5.0",
|
||||||
"normalize-url": "^8.0.1",
|
"normalize-url": "^8.0.1",
|
||||||
|
"openai-zod-to-json-schema": "^1.0.0",
|
||||||
"p-map": "^7.0.2",
|
"p-map": "^7.0.2",
|
||||||
"p-throttle": "^6.1.0",
|
"p-throttle": "^6.1.0",
|
||||||
"quick-lru": "^7.0.0",
|
"quick-lru": "^7.0.0",
|
||||||
"type-fest": "^4.21.0",
|
"type-fest": "^4.21.0",
|
||||||
"zod-to-json-schema": "^3.23.2",
|
|
||||||
"zod-validation-error": "^3.3.0"
|
"zod-validation-error": "^3.3.0"
|
||||||
},
|
},
|
||||||
"peerDependencies": {
|
"peerDependencies": {
|
||||||
|
@ -52,7 +52,7 @@
|
||||||
},
|
},
|
||||||
"devDependencies": {
|
"devDependencies": {
|
||||||
"@agentic/tsconfig": "workspace:*",
|
"@agentic/tsconfig": "workspace:*",
|
||||||
"openai-fetch": "^2.0.4"
|
"openai-fetch": "^3.0.0"
|
||||||
},
|
},
|
||||||
"publishConfig": {
|
"publishConfig": {
|
||||||
"access": "public"
|
"access": "public"
|
||||||
|
|
|
@ -10,15 +10,42 @@ import { asSchema, augmentSystemMessageWithJsonSchema } from './schema'
|
||||||
import { getErrorMessage } from './utils'
|
import { getErrorMessage } from './utils'
|
||||||
|
|
||||||
export type AIChainParams<Result extends types.AIChainResult = string> = {
|
export type AIChainParams<Result extends types.AIChainResult = string> = {
|
||||||
|
/** Name of the chain */
|
||||||
|
name: string
|
||||||
|
|
||||||
|
/** Chat completions function */
|
||||||
chatFn: types.ChatFn
|
chatFn: types.ChatFn
|
||||||
|
|
||||||
|
/** Description of the chain */
|
||||||
|
description?: string
|
||||||
|
|
||||||
|
/** Optional chat completion params */
|
||||||
params?: types.Simplify<
|
params?: types.Simplify<
|
||||||
Partial<Omit<types.ChatParams, 'tools' | 'functions'>>
|
Partial<Omit<types.ChatParams, 'tools' | 'functions'>>
|
||||||
>
|
>
|
||||||
|
|
||||||
|
/** Optional tools */
|
||||||
tools?: types.AIFunctionLike[]
|
tools?: types.AIFunctionLike[]
|
||||||
|
|
||||||
|
/** Optional response schema */
|
||||||
schema?: z.ZodType<Result> | types.Schema<Result>
|
schema?: z.ZodType<Result> | types.Schema<Result>
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Whether or not the response schema should be treated as strict for
|
||||||
|
* constrained structured output generation.
|
||||||
|
*/
|
||||||
|
strict?: boolean
|
||||||
|
|
||||||
|
/** Max number of LLM calls to allow */
|
||||||
maxCalls?: number
|
maxCalls?: number
|
||||||
|
|
||||||
|
/** Max number of retries to allow */
|
||||||
maxRetries?: number
|
maxRetries?: number
|
||||||
|
|
||||||
|
/** Max concurrency when invoking tool calls */
|
||||||
toolCallConcurrency?: number
|
toolCallConcurrency?: number
|
||||||
|
|
||||||
|
/** Whether or not to inject the schema into the context */
|
||||||
injectSchemaIntoSystemMessage?: boolean
|
injectSchemaIntoSystemMessage?: boolean
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -38,6 +65,8 @@ export type AIChainParams<Result extends types.AIChainResult = string> = {
|
||||||
* exceeds `maxCalls` (`maxCalls` is expected to be >= `maxRetries`).
|
* exceeds `maxCalls` (`maxCalls` is expected to be >= `maxRetries`).
|
||||||
*/
|
*/
|
||||||
export function createAIChain<Result extends types.AIChainResult = string>({
|
export function createAIChain<Result extends types.AIChainResult = string>({
|
||||||
|
name,
|
||||||
|
description,
|
||||||
chatFn,
|
chatFn,
|
||||||
params,
|
params,
|
||||||
schema: rawSchema,
|
schema: rawSchema,
|
||||||
|
@ -45,13 +74,28 @@ export function createAIChain<Result extends types.AIChainResult = string>({
|
||||||
maxCalls = 5,
|
maxCalls = 5,
|
||||||
maxRetries = 2,
|
maxRetries = 2,
|
||||||
toolCallConcurrency = 8,
|
toolCallConcurrency = 8,
|
||||||
injectSchemaIntoSystemMessage = true
|
injectSchemaIntoSystemMessage = false,
|
||||||
|
strict = false
|
||||||
}: AIChainParams<Result>): types.AIChain<Result> {
|
}: AIChainParams<Result>): types.AIChain<Result> {
|
||||||
const functionSet = new AIFunctionSet(tools)
|
const functionSet = new AIFunctionSet(tools)
|
||||||
|
const schema = rawSchema ? asSchema(rawSchema, { strict }) : undefined
|
||||||
|
|
||||||
|
// TODO: support custom stopping criteria (like setting a flag in a tool call)
|
||||||
|
|
||||||
const defaultParams: Partial<types.ChatParams> | undefined =
|
const defaultParams: Partial<types.ChatParams> | undefined =
|
||||||
rawSchema && !functionSet.size
|
schema && !functionSet.size
|
||||||
? {
|
? {
|
||||||
response_format: { type: 'json_object' }
|
response_format: strict
|
||||||
|
? {
|
||||||
|
type: 'json_schema',
|
||||||
|
json_schema: {
|
||||||
|
name,
|
||||||
|
description,
|
||||||
|
strict,
|
||||||
|
schema: schema.jsonSchema
|
||||||
|
}
|
||||||
|
}
|
||||||
|
: { type: 'json_object' }
|
||||||
}
|
}
|
||||||
: undefined
|
: undefined
|
||||||
|
|
||||||
|
@ -77,8 +121,6 @@ export function createAIChain<Result extends types.AIChainResult = string>({
|
||||||
throw new Error('AIChain error: "messages" is empty')
|
throw new Error('AIChain error: "messages" is empty')
|
||||||
}
|
}
|
||||||
|
|
||||||
const schema = rawSchema ? asSchema(rawSchema) : undefined
|
|
||||||
|
|
||||||
if (schema && injectSchemaIntoSystemMessage) {
|
if (schema && injectSchemaIntoSystemMessage) {
|
||||||
const lastSystemMessageIndex = messages.findLastIndex(Msg.isSystem)
|
const lastSystemMessageIndex = messages.findLastIndex(Msg.isSystem)
|
||||||
const lastSystemMessageContent =
|
const lastSystemMessageContent =
|
||||||
|
@ -101,6 +143,7 @@ export function createAIChain<Result extends types.AIChainResult = string>({
|
||||||
|
|
||||||
do {
|
do {
|
||||||
++numCalls
|
++numCalls
|
||||||
|
|
||||||
const response = await chatFn({
|
const response = await chatFn({
|
||||||
...modelParams,
|
...modelParams,
|
||||||
messages,
|
messages,
|
||||||
|
@ -149,15 +192,11 @@ export function createAIChain<Result extends types.AIChainResult = string>({
|
||||||
throw new AbortError(
|
throw new AbortError(
|
||||||
'Function calls are not supported; expected tool call'
|
'Function calls are not supported; expected tool call'
|
||||||
)
|
)
|
||||||
|
} else if (Msg.isRefusal(message)) {
|
||||||
|
throw new AbortError(`Model refusal: ${message.refusal}`)
|
||||||
} else if (Msg.isAssistant(message)) {
|
} else if (Msg.isAssistant(message)) {
|
||||||
if (schema && schema.validate) {
|
if (schema) {
|
||||||
const result = schema.validate(message.content)
|
return schema.parse(message.content)
|
||||||
|
|
||||||
if (result.success) {
|
|
||||||
return result.data
|
|
||||||
}
|
|
||||||
|
|
||||||
throw new Error(result.error)
|
|
||||||
} else {
|
} else {
|
||||||
return message.content as Result
|
return message.content as Result
|
||||||
}
|
}
|
||||||
|
@ -169,6 +208,8 @@ export function createAIChain<Result extends types.AIChainResult = string>({
|
||||||
throw err
|
throw err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
console.warn(`Chain "${name}" error:`, err.message)
|
||||||
|
|
||||||
messages.push(
|
messages.push(
|
||||||
Msg.user(
|
Msg.user(
|
||||||
`There was an error validating the response. Please check the error message and try again.\nError:\n${getErrorMessage(err)}`
|
`There was an error validating the response. Please check the error message and try again.\nError:\n${getErrorMessage(err)}`
|
||||||
|
@ -177,7 +218,7 @@ export function createAIChain<Result extends types.AIChainResult = string>({
|
||||||
|
|
||||||
if (numErrors > maxRetries) {
|
if (numErrors > maxRetries) {
|
||||||
throw new Error(
|
throw new Error(
|
||||||
`Chain failed after ${numErrors} errors: ${err.message}`,
|
`Chain ${name} failed after ${numErrors} errors: ${err.message}`,
|
||||||
{
|
{
|
||||||
cause: err
|
cause: err
|
||||||
}
|
}
|
||||||
|
@ -186,6 +227,8 @@ export function createAIChain<Result extends types.AIChainResult = string>({
|
||||||
}
|
}
|
||||||
} while (numCalls < maxCalls)
|
} while (numCalls < maxCalls)
|
||||||
|
|
||||||
throw new Error(`Chain aborted after reaching max ${maxCalls} calls`)
|
throw new Error(
|
||||||
|
`Chain "${name}" aborted after reaching max ${maxCalls} calls`
|
||||||
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
import { describe, expect, it } from 'vitest'
|
import { describe, expect, test } from 'vitest'
|
||||||
import { z } from 'zod'
|
import { z } from 'zod'
|
||||||
|
|
||||||
import { createAIFunction } from './create-ai-function'
|
import { createAIFunction } from './create-ai-function'
|
||||||
|
@ -18,7 +18,7 @@ const fullName = createAIFunction(
|
||||||
)
|
)
|
||||||
|
|
||||||
describe('createAIFunction()', () => {
|
describe('createAIFunction()', () => {
|
||||||
it('exposes OpenAI function calling spec', () => {
|
test('exposes OpenAI function calling spec', () => {
|
||||||
expect(fullName.spec.name).toEqual('fullName')
|
expect(fullName.spec.name).toEqual('fullName')
|
||||||
expect(fullName.spec.description).toEqual(
|
expect(fullName.spec.description).toEqual(
|
||||||
'Returns the full name of a person.'
|
'Returns the full name of a person.'
|
||||||
|
@ -34,7 +34,7 @@ describe('createAIFunction()', () => {
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
it('executes the function', async () => {
|
test('executes the function', async () => {
|
||||||
expect(await fullName('{"first": "John", "last": "Doe"}')).toEqual(
|
expect(await fullName('{"first": "John", "last": "Doe"}')).toEqual(
|
||||||
'John Doe'
|
'John Doe'
|
||||||
)
|
)
|
||||||
|
|
|
@ -20,8 +20,13 @@ export function createAIFunction<InputSchema extends z.ZodObject<any>, Output>(
|
||||||
name: string
|
name: string
|
||||||
/** Description of the function. */
|
/** Description of the function. */
|
||||||
description?: string
|
description?: string
|
||||||
/** Zod schema for the arguments string. */
|
/** Zod schema for the function parameters. */
|
||||||
inputSchema: InputSchema
|
inputSchema: InputSchema
|
||||||
|
/**
|
||||||
|
* Whether or not to enable structured output generation based on the given
|
||||||
|
* zod schema.
|
||||||
|
*/
|
||||||
|
strict?: boolean
|
||||||
},
|
},
|
||||||
/** Implementation of the function to call with the parsed arguments. */
|
/** Implementation of the function to call with the parsed arguments. */
|
||||||
implementation: (params: z.infer<InputSchema>) => types.MaybePromise<Output>
|
implementation: (params: z.infer<InputSchema>) => types.MaybePromise<Output>
|
||||||
|
@ -59,12 +64,15 @@ export function createAIFunction<InputSchema extends z.ZodObject<any>, Output>(
|
||||||
return implementation(parsedInput)
|
return implementation(parsedInput)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const strict = !!spec.strict
|
||||||
|
|
||||||
aiFunction.inputSchema = spec.inputSchema
|
aiFunction.inputSchema = spec.inputSchema
|
||||||
aiFunction.parseInput = parseInput
|
aiFunction.parseInput = parseInput
|
||||||
aiFunction.spec = {
|
aiFunction.spec = {
|
||||||
name: spec.name,
|
name: spec.name,
|
||||||
description: spec.description?.trim() ?? '',
|
description: spec.description?.trim() ?? '',
|
||||||
parameters: zodToJsonSchema(spec.inputSchema)
|
parameters: zodToJsonSchema(spec.inputSchema, { strict }),
|
||||||
|
strict
|
||||||
}
|
}
|
||||||
aiFunction.impl = implementation
|
aiFunction.impl = implementation
|
||||||
|
|
||||||
|
|
|
@ -10,6 +10,7 @@ export interface PrivateAIFunctionMetadata {
|
||||||
description: string
|
description: string
|
||||||
inputSchema: z.AnyZodObject
|
inputSchema: z.AnyZodObject
|
||||||
methodName: string
|
methodName: string
|
||||||
|
strict?: boolean
|
||||||
}
|
}
|
||||||
|
|
||||||
// Polyfill for `Symbol.metadata`
|
// Polyfill for `Symbol.metadata`
|
||||||
|
@ -69,11 +70,13 @@ export function aiFunction<
|
||||||
>({
|
>({
|
||||||
name,
|
name,
|
||||||
description,
|
description,
|
||||||
inputSchema
|
inputSchema,
|
||||||
|
strict
|
||||||
}: {
|
}: {
|
||||||
name?: string
|
name?: string
|
||||||
description: string
|
description: string
|
||||||
inputSchema: InputSchema
|
inputSchema: InputSchema
|
||||||
|
strict?: boolean
|
||||||
}) {
|
}) {
|
||||||
return (
|
return (
|
||||||
_targetMethod: (
|
_targetMethod: (
|
||||||
|
@ -99,7 +102,8 @@ export function aiFunction<
|
||||||
name: name ?? methodName,
|
name: name ?? methodName,
|
||||||
description,
|
description,
|
||||||
inputSchema,
|
inputSchema,
|
||||||
methodName
|
methodName,
|
||||||
|
strict
|
||||||
})
|
})
|
||||||
|
|
||||||
context.addInitializer(function () {
|
context.addInitializer(function () {
|
||||||
|
|
|
@ -1,11 +1,11 @@
|
||||||
import type * as OpenAI from 'openai-fetch'
|
import type * as OpenAI from 'openai-fetch'
|
||||||
import { describe, expect, expectTypeOf, it } from 'vitest'
|
import { describe, expect, expectTypeOf, test } from 'vitest'
|
||||||
|
|
||||||
import type * as types from './types'
|
import type * as types from './types'
|
||||||
import { Msg } from './message'
|
import { Msg } from './message'
|
||||||
|
|
||||||
describe('Msg', () => {
|
describe('Msg', () => {
|
||||||
it('creates a message and fixes indentation', () => {
|
test('creates a message and fixes indentation', () => {
|
||||||
const msgContent = `
|
const msgContent = `
|
||||||
Hello, World!
|
Hello, World!
|
||||||
`
|
`
|
||||||
|
@ -14,7 +14,7 @@ describe('Msg', () => {
|
||||||
expect(msg.content).toEqual('Hello, World!')
|
expect(msg.content).toEqual('Hello, World!')
|
||||||
})
|
})
|
||||||
|
|
||||||
it('supports disabling indentation fixing', () => {
|
test('supports disabling indentation fixing', () => {
|
||||||
const msgContent = `
|
const msgContent = `
|
||||||
Hello, World!
|
Hello, World!
|
||||||
`
|
`
|
||||||
|
@ -22,7 +22,7 @@ describe('Msg', () => {
|
||||||
expect(msg.content).toEqual('\n Hello, World!\n ')
|
expect(msg.content).toEqual('\n Hello, World!\n ')
|
||||||
})
|
})
|
||||||
|
|
||||||
it('handles tool calls request', () => {
|
test('handles tool calls request', () => {
|
||||||
const msg = Msg.toolCall([
|
const msg = Msg.toolCall([
|
||||||
{
|
{
|
||||||
id: 'fake-tool-call-id',
|
id: 'fake-tool-call-id',
|
||||||
|
@ -37,13 +37,13 @@ describe('Msg', () => {
|
||||||
expect(Msg.isToolCall(msg)).toBe(true)
|
expect(Msg.isToolCall(msg)).toBe(true)
|
||||||
})
|
})
|
||||||
|
|
||||||
it('handles tool call response', () => {
|
test('handles tool call response', () => {
|
||||||
const msg = Msg.toolResult('Hello, World!', 'fake-tool-call-id')
|
const msg = Msg.toolResult('Hello, World!', 'fake-tool-call-id')
|
||||||
expectTypeOf(msg).toMatchTypeOf<types.Msg.ToolResult>()
|
expectTypeOf(msg).toMatchTypeOf<types.Msg.ToolResult>()
|
||||||
expect(Msg.isToolResult(msg)).toBe(true)
|
expect(Msg.isToolResult(msg)).toBe(true)
|
||||||
})
|
})
|
||||||
|
|
||||||
it('prompt message types should interop with openai-fetch message types', () => {
|
test('prompt message types should interop with openai-fetch message types', () => {
|
||||||
expectTypeOf({} as OpenAI.ChatMessage).toMatchTypeOf<types.Msg>()
|
expectTypeOf({} as OpenAI.ChatMessage).toMatchTypeOf<types.Msg>()
|
||||||
expectTypeOf({} as types.Msg).toMatchTypeOf<OpenAI.ChatMessage>()
|
expectTypeOf({} as types.Msg).toMatchTypeOf<OpenAI.ChatMessage>()
|
||||||
expectTypeOf({} as types.Msg.System).toMatchTypeOf<OpenAI.ChatMessage>()
|
expectTypeOf({} as types.Msg.System).toMatchTypeOf<OpenAI.ChatMessage>()
|
||||||
|
|
|
@ -7,10 +7,17 @@ import { cleanStringForModel, stringifyForModel } from './utils'
|
||||||
*/
|
*/
|
||||||
export interface Msg {
|
export interface Msg {
|
||||||
/**
|
/**
|
||||||
* The contents of the message. `content` is required for all messages, and
|
* The contents of the message. `content` may be null for assistant messages
|
||||||
* may be null for assistant messages with function calls.
|
* with function calls or `undefined` for assistant messages if a `refusal`
|
||||||
|
* was given by the model.
|
||||||
*/
|
*/
|
||||||
content: string | null
|
content?: string | null
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The reason the model refused to generate this message or `undefined` if the
|
||||||
|
* message was generated successfully.
|
||||||
|
*/
|
||||||
|
refusal?: string | null
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* The role of the messages author. One of `system`, `user`, `assistant`,
|
* The role of the messages author. One of `system`, `user`, `assistant`,
|
||||||
|
@ -43,6 +50,15 @@ export interface Msg {
|
||||||
name?: string
|
name?: string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export interface LegacyMsg {
|
||||||
|
content: string | null
|
||||||
|
role: Msg.Role
|
||||||
|
function_call?: Msg.Call.Function
|
||||||
|
tool_calls?: Msg.Call.Tool[]
|
||||||
|
tool_call_id?: string
|
||||||
|
name?: string
|
||||||
|
}
|
||||||
|
|
||||||
/** Narrowed OpenAI Message types. */
|
/** Narrowed OpenAI Message types. */
|
||||||
export namespace Msg {
|
export namespace Msg {
|
||||||
/** Possible roles for a message. */
|
/** Possible roles for a message. */
|
||||||
|
@ -98,6 +114,13 @@ export namespace Msg {
|
||||||
content: string
|
content: string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/** Message with refusal reason from the assistant. */
|
||||||
|
export type Refusal = {
|
||||||
|
role: 'assistant'
|
||||||
|
name?: string
|
||||||
|
refusal: string
|
||||||
|
}
|
||||||
|
|
||||||
/** Message with arguments to call a function. */
|
/** Message with arguments to call a function. */
|
||||||
export type FuncCall = {
|
export type FuncCall = {
|
||||||
role: 'assistant'
|
role: 'assistant'
|
||||||
|
@ -185,6 +208,27 @@ export namespace Msg {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Create an assistant refusal message. Cleans indentation and newlines by
|
||||||
|
* default.
|
||||||
|
*/
|
||||||
|
export function refusal(
|
||||||
|
refusal: string,
|
||||||
|
opts?: {
|
||||||
|
/** Custom name for the message. */
|
||||||
|
name?: string
|
||||||
|
/** Whether to clean extra newlines and indentation. Defaults to true. */
|
||||||
|
cleanRefusal?: boolean
|
||||||
|
}
|
||||||
|
): Msg.Refusal {
|
||||||
|
const { name, cleanRefusal = true } = opts ?? {}
|
||||||
|
return {
|
||||||
|
role: 'assistant',
|
||||||
|
refusal: cleanRefusal ? cleanStringForModel(refusal) : refusal,
|
||||||
|
...(name ? { name } : {})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/** Create a function call message with argumets. */
|
/** Create a function call message with argumets. */
|
||||||
export function funcCall(
|
export function funcCall(
|
||||||
function_call: {
|
function_call: {
|
||||||
|
@ -249,7 +293,7 @@ export namespace Msg {
|
||||||
// @TODO
|
// @TODO
|
||||||
response: any
|
response: any
|
||||||
// response: ChatModel.EnrichedResponse
|
// response: ChatModel.EnrichedResponse
|
||||||
): Msg.Assistant | Msg.FuncCall | Msg.ToolCall {
|
): Msg.Assistant | Msg.Refusal | Msg.FuncCall | Msg.ToolCall {
|
||||||
const msg = response.choices[0].message as Msg
|
const msg = response.choices[0].message as Msg
|
||||||
return narrowResponseMessage(msg)
|
return narrowResponseMessage(msg)
|
||||||
}
|
}
|
||||||
|
@ -257,13 +301,15 @@ export namespace Msg {
|
||||||
/** Narrow a message received from the API. It only responds with role=assistant */
|
/** Narrow a message received from the API. It only responds with role=assistant */
|
||||||
export function narrowResponseMessage(
|
export function narrowResponseMessage(
|
||||||
msg: Msg
|
msg: Msg
|
||||||
): Msg.Assistant | Msg.FuncCall | Msg.ToolCall {
|
): Msg.Assistant | Msg.Refusal | Msg.FuncCall | Msg.ToolCall {
|
||||||
if (msg.content === null && msg.tool_calls != null) {
|
if (msg.content === null && msg.tool_calls != null) {
|
||||||
return Msg.toolCall(msg.tool_calls)
|
return Msg.toolCall(msg.tool_calls)
|
||||||
} else if (msg.content === null && msg.function_call != null) {
|
} else if (msg.content === null && msg.function_call != null) {
|
||||||
return Msg.funcCall(msg.function_call)
|
return Msg.funcCall(msg.function_call)
|
||||||
} else if (msg.content !== null) {
|
} else if (msg.content !== null && msg.content !== undefined) {
|
||||||
return Msg.assistant(msg.content)
|
return Msg.assistant(msg.content)
|
||||||
|
} else if (msg.refusal != null) {
|
||||||
|
return Msg.refusal(msg.refusal)
|
||||||
} else {
|
} else {
|
||||||
// @TODO: probably don't want to error here
|
// @TODO: probably don't want to error here
|
||||||
console.log('Invalid message', msg)
|
console.log('Invalid message', msg)
|
||||||
|
@ -283,6 +329,10 @@ export namespace Msg {
|
||||||
export function isAssistant(message: Msg): message is Msg.Assistant {
|
export function isAssistant(message: Msg): message is Msg.Assistant {
|
||||||
return message.role === 'assistant' && message.content !== null
|
return message.role === 'assistant' && message.content !== null
|
||||||
}
|
}
|
||||||
|
/** Check if a message is an assistant refusal message. */
|
||||||
|
export function isRefusal(message: Msg): message is Msg.Refusal {
|
||||||
|
return message.role === 'assistant' && message.refusal !== null
|
||||||
|
}
|
||||||
/** Check if a message is a function call message with arguments. */
|
/** Check if a message is a function call message with arguments. */
|
||||||
export function isFuncCall(message: Msg): message is Msg.FuncCall {
|
export function isFuncCall(message: Msg): message is Msg.FuncCall {
|
||||||
return message.role === 'assistant' && message.function_call != null
|
return message.role === 'assistant' && message.function_call != null
|
||||||
|
@ -304,6 +354,7 @@ export namespace Msg {
|
||||||
export function narrow(message: Msg.System): Msg.System
|
export function narrow(message: Msg.System): Msg.System
|
||||||
export function narrow(message: Msg.User): Msg.User
|
export function narrow(message: Msg.User): Msg.User
|
||||||
export function narrow(message: Msg.Assistant): Msg.Assistant
|
export function narrow(message: Msg.Assistant): Msg.Assistant
|
||||||
|
export function narrow(message: Msg.Assistant): Msg.Refusal
|
||||||
export function narrow(message: Msg.FuncCall): Msg.FuncCall
|
export function narrow(message: Msg.FuncCall): Msg.FuncCall
|
||||||
export function narrow(message: Msg.FuncResult): Msg.FuncResult
|
export function narrow(message: Msg.FuncResult): Msg.FuncResult
|
||||||
export function narrow(message: Msg.ToolCall): Msg.ToolCall
|
export function narrow(message: Msg.ToolCall): Msg.ToolCall
|
||||||
|
@ -314,6 +365,7 @@ export namespace Msg {
|
||||||
| Msg.System
|
| Msg.System
|
||||||
| Msg.User
|
| Msg.User
|
||||||
| Msg.Assistant
|
| Msg.Assistant
|
||||||
|
| Msg.Refusal
|
||||||
| Msg.FuncCall
|
| Msg.FuncCall
|
||||||
| Msg.FuncResult
|
| Msg.FuncResult
|
||||||
| Msg.ToolCall
|
| Msg.ToolCall
|
||||||
|
@ -327,6 +379,9 @@ export namespace Msg {
|
||||||
if (isAssistant(message)) {
|
if (isAssistant(message)) {
|
||||||
return message
|
return message
|
||||||
}
|
}
|
||||||
|
if (isRefusal(message)) {
|
||||||
|
return message
|
||||||
|
}
|
||||||
if (isFuncCall(message)) {
|
if (isFuncCall(message)) {
|
||||||
return message
|
return message
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
import type { z } from 'zod'
|
import type { z } from 'zod'
|
||||||
|
|
||||||
import type * as types from './types'
|
import type * as types from './types'
|
||||||
import { safeParseStructuredOutput } from './parse-structured-output'
|
import { parseStructuredOutput } from './parse-structured-output'
|
||||||
import { stringifyForModel } from './utils'
|
import { stringifyForModel } from './utils'
|
||||||
import { zodToJsonSchema } from './zod-to-json-schema'
|
import { zodToJsonSchema } from './zod-to-json-schema'
|
||||||
|
|
||||||
|
@ -9,7 +9,6 @@ import { zodToJsonSchema } from './zod-to-json-schema'
|
||||||
* Used to mark schemas so we can support both Zod and custom schemas.
|
* Used to mark schemas so we can support both Zod and custom schemas.
|
||||||
*/
|
*/
|
||||||
export const schemaSymbol = Symbol('agentic.schema')
|
export const schemaSymbol = Symbol('agentic.schema')
|
||||||
export const validatorSymbol = Symbol('agentic.validator')
|
|
||||||
|
|
||||||
export type Schema<TData = unknown> = {
|
export type Schema<TData = unknown> = {
|
||||||
/**
|
/**
|
||||||
|
@ -18,10 +17,18 @@ export type Schema<TData = unknown> = {
|
||||||
readonly jsonSchema: types.JSONSchema
|
readonly jsonSchema: types.JSONSchema
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Optional. Validates that the structure of a value matches this schema,
|
* Parses the value, validates that it matches this schema, and returns a
|
||||||
* and returns a typed version of the value if it does.
|
* typed version of the value if it does. Throw an error if the value does
|
||||||
|
* not match the schema.
|
||||||
*/
|
*/
|
||||||
readonly validate?: types.ValidatorFn<TData>
|
readonly parse: types.ParseFn<TData>
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Parses the value, validates that it matches this schema, and returns a
|
||||||
|
* typed version of the value if it does. Returns an error message if the
|
||||||
|
* value does not match the schema, and will never throw an error.
|
||||||
|
*/
|
||||||
|
readonly safeParse: types.SafeParseFn<TData>
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Used to mark schemas so we can support both Zod and custom schemas.
|
* Used to mark schemas so we can support both Zod and custom schemas.
|
||||||
|
@ -41,7 +48,7 @@ export function isSchema(value: unknown): value is Schema {
|
||||||
schemaSymbol in value &&
|
schemaSymbol in value &&
|
||||||
value[schemaSymbol] === true &&
|
value[schemaSymbol] === true &&
|
||||||
'jsonSchema' in value &&
|
'jsonSchema' in value &&
|
||||||
'validate' in value
|
'parse' in value
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -59,9 +66,10 @@ export function isZodSchema(value: unknown): value is z.ZodType {
|
||||||
}
|
}
|
||||||
|
|
||||||
export function asSchema<TData>(
|
export function asSchema<TData>(
|
||||||
schema: z.Schema<TData> | Schema<TData>
|
schema: z.Schema<TData> | Schema<TData>,
|
||||||
|
opts: { strict?: boolean } = {}
|
||||||
): Schema<TData> {
|
): Schema<TData> {
|
||||||
return isSchema(schema) ? schema : createSchemaFromZodSchema(schema)
|
return isSchema(schema) ? schema : createSchemaFromZodSchema(schema, opts)
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -70,25 +78,38 @@ export function asSchema<TData>(
|
||||||
export function createSchema<TData = unknown>(
|
export function createSchema<TData = unknown>(
|
||||||
jsonSchema: types.JSONSchema,
|
jsonSchema: types.JSONSchema,
|
||||||
{
|
{
|
||||||
validate
|
parse = (value) => value as TData,
|
||||||
|
safeParse
|
||||||
}: {
|
}: {
|
||||||
validate?: types.ValidatorFn<TData>
|
parse?: types.ParseFn<TData>
|
||||||
|
safeParse?: types.SafeParseFn<TData>
|
||||||
} = {}
|
} = {}
|
||||||
): Schema<TData> {
|
): Schema<TData> {
|
||||||
|
safeParse ??= (value: unknown) => {
|
||||||
|
try {
|
||||||
|
const result = parse(value)
|
||||||
|
return { success: true, data: result }
|
||||||
|
} catch (err: any) {
|
||||||
|
return { success: false, error: err.message ?? String(err) }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return {
|
return {
|
||||||
[schemaSymbol]: true,
|
[schemaSymbol]: true,
|
||||||
_type: undefined as TData,
|
_type: undefined as TData,
|
||||||
jsonSchema,
|
jsonSchema,
|
||||||
validate
|
parse,
|
||||||
|
safeParse
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
export function createSchemaFromZodSchema<TData>(
|
export function createSchemaFromZodSchema<TData>(
|
||||||
zodSchema: z.Schema<TData>
|
zodSchema: z.Schema<TData>,
|
||||||
|
opts: { strict?: boolean } = {}
|
||||||
): Schema<TData> {
|
): Schema<TData> {
|
||||||
return createSchema(zodToJsonSchema(zodSchema), {
|
return createSchema(zodToJsonSchema(zodSchema, opts), {
|
||||||
validate: (value) => {
|
parse: (value) => {
|
||||||
return safeParseStructuredOutput(value, zodSchema)
|
return parseStructuredOutput(value, zodSchema)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
|
@ -3,13 +3,13 @@ import type { z } from 'zod'
|
||||||
|
|
||||||
import type { AIFunctionSet } from './ai-function-set'
|
import type { AIFunctionSet } from './ai-function-set'
|
||||||
import type { AIFunctionsProvider } from './fns'
|
import type { AIFunctionsProvider } from './fns'
|
||||||
import type { Msg } from './message'
|
import type { LegacyMsg, Msg } from './message'
|
||||||
|
|
||||||
export type { Msg } from './message'
|
export type { Msg } from './message'
|
||||||
export type { Schema } from './schema'
|
export type { Schema } from './schema'
|
||||||
export type { KyInstance } from 'ky'
|
export type { KyInstance } from 'ky'
|
||||||
export type { ThrottledFunction } from 'p-throttle'
|
export type { ThrottledFunction } from 'p-throttle'
|
||||||
export type { SetRequired, Simplify } from 'type-fest'
|
export type { SetOptional, SetRequired, Simplify } from 'type-fest'
|
||||||
|
|
||||||
export type Nullable<T> = T | null
|
export type Nullable<T> = T | null
|
||||||
|
|
||||||
|
@ -33,6 +33,13 @@ export interface AIFunctionSpec {
|
||||||
|
|
||||||
/** JSON schema spec of the function's input parameters */
|
/** JSON schema spec of the function's input parameters */
|
||||||
parameters: JSONSchema
|
parameters: JSONSchema
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Whether to enable strict schema adherence when generating the function
|
||||||
|
* parameters. Currently only supported by OpenAI's
|
||||||
|
* [Structured Outputs](https://platform.openai.com/docs/guides/structured-outputs).
|
||||||
|
*/
|
||||||
|
strict?: boolean
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface AIToolSpec {
|
export interface AIToolSpec {
|
||||||
|
@ -102,7 +109,17 @@ export interface ChatParams {
|
||||||
max_tokens?: number
|
max_tokens?: number
|
||||||
presence_penalty?: number
|
presence_penalty?: number
|
||||||
frequency_penalty?: number
|
frequency_penalty?: number
|
||||||
response_format?: { type: 'text' | 'json_object' }
|
response_format?:
|
||||||
|
| {
|
||||||
|
type: 'text'
|
||||||
|
}
|
||||||
|
| {
|
||||||
|
type: 'json_object'
|
||||||
|
}
|
||||||
|
| {
|
||||||
|
type: 'json_schema'
|
||||||
|
json_schema: ResponseFormatJSONSchema
|
||||||
|
}
|
||||||
seed?: number
|
seed?: number
|
||||||
stop?: string | null | Array<string>
|
stop?: string | null | Array<string>
|
||||||
temperature?: number
|
temperature?: number
|
||||||
|
@ -111,10 +128,53 @@ export interface ChatParams {
|
||||||
user?: string
|
user?: string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export type LegacyChatParams = Simplify<
|
||||||
|
Omit<ChatParams, 'messages'> & { messages: LegacyMsg[] }
|
||||||
|
>
|
||||||
|
|
||||||
|
export interface ResponseFormatJSONSchema {
|
||||||
|
/**
|
||||||
|
* The name of the response format. Must be a-z, A-Z, 0-9, or contain
|
||||||
|
* underscores and dashes, with a maximum length of 64.
|
||||||
|
*/
|
||||||
|
name: string
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A description of what the response format is for, used by the model to
|
||||||
|
* determine how to respond in the format.
|
||||||
|
*/
|
||||||
|
description?: string
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The schema for the response format, described as a JSON Schema object.
|
||||||
|
*/
|
||||||
|
schema?: JSONSchema
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Whether to enable strict schema adherence when generating the output. If
|
||||||
|
* set to true, the model will always follow the exact schema defined in the
|
||||||
|
* `schema` field. Only a subset of JSON Schema is supported when `strict`
|
||||||
|
* is `true`. Currently only supported by OpenAI's
|
||||||
|
* [Structured Outputs](https://platform.openai.com/docs/guides/structured-outputs).
|
||||||
|
*/
|
||||||
|
strict?: boolean
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* OpenAI has changed some of their types, so instead of trying to support all
|
||||||
|
* possible types, for these params, just relax them for now.
|
||||||
|
*/
|
||||||
|
export type RelaxedChatParams = Simplify<
|
||||||
|
Omit<ChatParams, 'messages' | 'response_format'> & {
|
||||||
|
messages: any[]
|
||||||
|
response_format?: any
|
||||||
|
}
|
||||||
|
>
|
||||||
|
|
||||||
/** An OpenAI-compatible chat completions API */
|
/** An OpenAI-compatible chat completions API */
|
||||||
export type ChatFn = (
|
export type ChatFn = (
|
||||||
params: Simplify<SetOptional<ChatParams, 'model'>>
|
params: Simplify<SetOptional<RelaxedChatParams, 'model'>>
|
||||||
) => Promise<{ message: Msg }>
|
) => Promise<{ message: Msg | LegacyMsg }>
|
||||||
|
|
||||||
export type AIChainResult = string | Record<string, any>
|
export type AIChainResult = string | Record<string, any>
|
||||||
|
|
||||||
|
@ -134,4 +194,5 @@ export type SafeParseResult<TData> =
|
||||||
error: string
|
error: string
|
||||||
}
|
}
|
||||||
|
|
||||||
export type ValidatorFn<TData> = (value: unknown) => SafeParseResult<TData>
|
export type ParseFn<TData> = (value: unknown) => TData
|
||||||
|
export type SafeParseFn<TData> = (value: unknown) => SafeParseResult<TData>
|
||||||
|
|
|
@ -1,10 +1,10 @@
|
||||||
import { describe, expect, it } from 'vitest'
|
import { describe, expect, test } from 'vitest'
|
||||||
import { z } from 'zod'
|
import { z } from 'zod'
|
||||||
|
|
||||||
import { zodToJsonSchema } from './zod-to-json-schema'
|
import { zodToJsonSchema } from './zod-to-json-schema'
|
||||||
|
|
||||||
describe('zodToJsonSchema', () => {
|
describe('zodToJsonSchema', () => {
|
||||||
it('handles basic objects', () => {
|
test('handles basic objects', () => {
|
||||||
const params = zodToJsonSchema(
|
const params = zodToJsonSchema(
|
||||||
z.object({
|
z.object({
|
||||||
name: z.string().min(1).describe('Name of the person'),
|
name: z.string().min(1).describe('Name of the person'),
|
||||||
|
@ -30,7 +30,7 @@ describe('zodToJsonSchema', () => {
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
it('handles enums and unions', () => {
|
test('handles enums and unions', () => {
|
||||||
const params = zodToJsonSchema(
|
const params = zodToJsonSchema(
|
||||||
z.object({
|
z.object({
|
||||||
name: z.string().min(1).describe('Name of the person'),
|
name: z.string().min(1).describe('Name of the person'),
|
||||||
|
|
|
@ -1,13 +1,23 @@
|
||||||
import type { z } from 'zod'
|
import type { z } from 'zod'
|
||||||
import { zodToJsonSchema as zodToJsonSchemaImpl } from 'zod-to-json-schema'
|
import { zodToJsonSchema as zodToJsonSchemaImpl } from 'openai-zod-to-json-schema'
|
||||||
|
|
||||||
import type * as types from './types'
|
import type * as types from './types'
|
||||||
import { omit } from './utils'
|
import { omit } from './utils'
|
||||||
|
|
||||||
/** Generate a JSON Schema from a Zod schema. */
|
/** Generate a JSON Schema from a Zod schema. */
|
||||||
export function zodToJsonSchema(schema: z.ZodType): types.JSONSchema {
|
export function zodToJsonSchema(
|
||||||
|
schema: z.ZodType,
|
||||||
|
{
|
||||||
|
strict = false
|
||||||
|
}: {
|
||||||
|
strict?: boolean
|
||||||
|
} = {}
|
||||||
|
): types.JSONSchema {
|
||||||
return omit(
|
return omit(
|
||||||
zodToJsonSchemaImpl(schema, { $refStrategy: 'none' }),
|
zodToJsonSchemaImpl(schema, {
|
||||||
|
$refStrategy: 'none',
|
||||||
|
openaiStrictMode: strict
|
||||||
|
}),
|
||||||
'$schema',
|
'$schema',
|
||||||
'default',
|
'default',
|
||||||
'definitions',
|
'definitions',
|
||||||
|
|
Plik diff jest za duży
Load Diff
|
@ -12,13 +12,7 @@
|
||||||
"dependsOn": ["^clean"]
|
"dependsOn": ["^clean"]
|
||||||
},
|
},
|
||||||
"test": {
|
"test": {
|
||||||
"dependsOn": [
|
"dependsOn": ["test:format", "test:lint", "test:typecheck", "test:unit"]
|
||||||
"build",
|
|
||||||
"test:format",
|
|
||||||
"test:lint",
|
|
||||||
"test:typecheck",
|
|
||||||
"test:unit"
|
|
||||||
]
|
|
||||||
},
|
},
|
||||||
"test:lint": {
|
"test:lint": {
|
||||||
"dependsOn": ["^test:lint"],
|
"dependsOn": ["^test:lint"],
|
||||||
|
|
Ładowanie…
Reference in New Issue