kopia lustrzana https://github.com/transitive-bullshit/chatgpt-api
feat: add AIChain
rodzic
a58f026eae
commit
391da4f996
|
@ -0,0 +1,39 @@
|
||||||
|
#!/usr/bin/env node
|
||||||
|
import 'dotenv/config'
|
||||||
|
|
||||||
|
import {
|
||||||
|
createAIChain,
|
||||||
|
Msg,
|
||||||
|
PerigonClient,
|
||||||
|
SerperClient
|
||||||
|
} from '@agentic/stdlib'
|
||||||
|
import { ChatModel } from '@dexaai/dexter'
|
||||||
|
|
||||||
|
async function main() {
|
||||||
|
const perigon = new PerigonClient()
|
||||||
|
const serper = new SerperClient()
|
||||||
|
|
||||||
|
const chatModel = new ChatModel({
|
||||||
|
params: { model: 'gpt-4o', temperature: 0 },
|
||||||
|
debug: true
|
||||||
|
})
|
||||||
|
|
||||||
|
const chain = createAIChain({
|
||||||
|
chatFn: chatModel.run.bind(chatModel),
|
||||||
|
tools: [perigon.functions.pick('search_news_stories'), serper],
|
||||||
|
params: {
|
||||||
|
messages: [
|
||||||
|
Msg.system(
|
||||||
|
'You are a helpful assistant. Be as concise as possible. Respond in markdown. Always cite your sources.'
|
||||||
|
)
|
||||||
|
]
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
const result = await chain(
|
||||||
|
'Summarize the latest news stories about the upcoming US election.'
|
||||||
|
)
|
||||||
|
console.log(result)
|
||||||
|
}
|
||||||
|
|
||||||
|
await main()
|
|
@ -0,0 +1,32 @@
|
||||||
|
#!/usr/bin/env node
|
||||||
|
import 'dotenv/config'
|
||||||
|
|
||||||
|
import { createAIChain, Msg } from '@agentic/stdlib'
|
||||||
|
import { ChatModel } from '@dexaai/dexter'
|
||||||
|
import { z } from 'zod'
|
||||||
|
|
||||||
|
async function main() {
|
||||||
|
const chatModel = new ChatModel({
|
||||||
|
params: { model: 'gpt-4o', temperature: 0 },
|
||||||
|
debug: true
|
||||||
|
})
|
||||||
|
|
||||||
|
const chain = createAIChain({
|
||||||
|
chatFn: chatModel.run.bind(chatModel),
|
||||||
|
params: {
|
||||||
|
messages: [Msg.system('Extract a JSON user object from the given text.')]
|
||||||
|
},
|
||||||
|
schema: z.object({
|
||||||
|
name: z.string(),
|
||||||
|
age: z.number(),
|
||||||
|
location: z.string().optional()
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
const result = await chain(
|
||||||
|
'Bob Vance is 42 years old and lives in Brooklyn, NY. He is a software engineer.'
|
||||||
|
)
|
||||||
|
console.log(result)
|
||||||
|
}
|
||||||
|
|
||||||
|
await main()
|
|
@ -98,7 +98,8 @@
|
||||||
"quick-lru": "^7.0.0",
|
"quick-lru": "^7.0.0",
|
||||||
"type-fest": "^4.21.0",
|
"type-fest": "^4.21.0",
|
||||||
"zod": "^3.23.3",
|
"zod": "^3.23.3",
|
||||||
"zod-to-json-schema": "^3.23.1"
|
"zod-to-json-schema": "^3.23.1",
|
||||||
|
"zod-validation-error": "^3.3.0"
|
||||||
},
|
},
|
||||||
"devDependencies": {
|
"devDependencies": {
|
||||||
"@aws-sdk/client-sso-oidc": "^3.616.0",
|
"@aws-sdk/client-sso-oidc": "^3.616.0",
|
||||||
|
|
626
pnpm-lock.yaml
626
pnpm-lock.yaml
Plik diff jest za duży
Load Diff
|
@ -0,0 +1,189 @@
|
||||||
|
import type { SetOptional } from 'type-fest'
|
||||||
|
import type { ZodType } from 'zod'
|
||||||
|
import pMap from 'p-map'
|
||||||
|
|
||||||
|
import type * as types from './types.js'
|
||||||
|
import { AIFunctionSet } from './ai-function-set.js'
|
||||||
|
import { AbortError } from './errors.js'
|
||||||
|
import { Msg } from './message.js'
|
||||||
|
import { asSchema, augmentSystemMessageWithJsonSchema } from './schema.js'
|
||||||
|
import { getErrorMessage } from './utils.js'
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates a chain of chat completion calls that can be invoked as a single
|
||||||
|
* function. It is meant to simplify the process of resolving tool calls
|
||||||
|
* and optionally adding validation to the final result.
|
||||||
|
*
|
||||||
|
* The returned function will invoke the `chatFn` up to `maxCalls` times,
|
||||||
|
* resolving any tool calls to the included `functions` and retrying if
|
||||||
|
* necessary up to `maxRetries`.
|
||||||
|
*
|
||||||
|
* The chain ends when a non-tool call is returned, and the final result can
|
||||||
|
* optionally be validated against a Zod schema, which defaults to a `string`.
|
||||||
|
*
|
||||||
|
* To prevent possible infinite loops, the chain will throw an error if it
|
||||||
|
* exceeds `maxCalls` (`maxCalls` is expected to be >= `maxRetries`).
|
||||||
|
*/
|
||||||
|
export function createAIChain<Result extends types.AIChainResult = string>({
|
||||||
|
chatFn,
|
||||||
|
params,
|
||||||
|
schema: rawSchema,
|
||||||
|
tools,
|
||||||
|
maxCalls = 5,
|
||||||
|
maxRetries = 2,
|
||||||
|
toolCallConcurrency = 8,
|
||||||
|
injectSchemaIntoSystemMessage = true
|
||||||
|
}: {
|
||||||
|
chatFn: types.ChatFn
|
||||||
|
params?: types.Simplify<
|
||||||
|
Partial<Omit<types.ChatParams, 'tools' | 'functions'>>
|
||||||
|
>
|
||||||
|
tools?: types.AIFunctionLike[]
|
||||||
|
schema?: ZodType<Result> | types.Schema<Result>
|
||||||
|
maxCalls?: number
|
||||||
|
maxRetries?: number
|
||||||
|
toolCallConcurrency?: number
|
||||||
|
injectSchemaIntoSystemMessage?: boolean
|
||||||
|
}): types.AIChain<Result> {
|
||||||
|
const functionSet = new AIFunctionSet(tools)
|
||||||
|
const defaultParams: Partial<types.ChatParams> | undefined =
|
||||||
|
rawSchema && !functionSet.size
|
||||||
|
? {
|
||||||
|
response_format: { type: 'json_object' }
|
||||||
|
}
|
||||||
|
: undefined
|
||||||
|
|
||||||
|
return async (chatParams) => {
|
||||||
|
const { messages, ...modelParams }: SetOptional<types.ChatParams, 'model'> =
|
||||||
|
typeof chatParams === 'string'
|
||||||
|
? {
|
||||||
|
...defaultParams,
|
||||||
|
...params,
|
||||||
|
messages: [...(params?.messages ?? []), Msg.user(chatParams)]
|
||||||
|
}
|
||||||
|
: {
|
||||||
|
...defaultParams,
|
||||||
|
...params,
|
||||||
|
...chatParams,
|
||||||
|
messages: [
|
||||||
|
...(params?.messages ?? []),
|
||||||
|
...(chatParams.messages ?? [])
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!messages.length) {
|
||||||
|
throw new Error('AIChain error: "messages" is empty')
|
||||||
|
}
|
||||||
|
|
||||||
|
const schema = rawSchema ? asSchema(rawSchema) : undefined
|
||||||
|
|
||||||
|
if (schema && injectSchemaIntoSystemMessage) {
|
||||||
|
const lastSystemMessageIndex = messages.findLastIndex(Msg.isSystem)
|
||||||
|
const lastSystemMessageContent =
|
||||||
|
messages[lastSystemMessageIndex]?.content!
|
||||||
|
|
||||||
|
const systemMessage = augmentSystemMessageWithJsonSchema({
|
||||||
|
system: lastSystemMessageContent,
|
||||||
|
schema: schema.jsonSchema
|
||||||
|
})
|
||||||
|
|
||||||
|
if (lastSystemMessageIndex >= 0) {
|
||||||
|
messages[lastSystemMessageIndex] = Msg.system(systemMessage!)
|
||||||
|
} else {
|
||||||
|
messages.unshift(Msg.system(systemMessage))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let numCalls = 0
|
||||||
|
let numErrors = 0
|
||||||
|
|
||||||
|
do {
|
||||||
|
++numCalls
|
||||||
|
const response = await chatFn({
|
||||||
|
...modelParams,
|
||||||
|
messages,
|
||||||
|
tools: functionSet.size ? functionSet.toolSpecs : undefined
|
||||||
|
})
|
||||||
|
|
||||||
|
const { message } = response
|
||||||
|
messages.push(message)
|
||||||
|
|
||||||
|
try {
|
||||||
|
if (Msg.isToolCall(message)) {
|
||||||
|
if (!functionSet.size) {
|
||||||
|
throw new AbortError('No functions provided to handle tool call')
|
||||||
|
}
|
||||||
|
|
||||||
|
// Synchronously validate that all tool calls reference valid functions
|
||||||
|
for (const toolCall of message.tool_calls) {
|
||||||
|
const func = functionSet.get(toolCall.function.name)
|
||||||
|
|
||||||
|
if (!func) {
|
||||||
|
throw new Error(
|
||||||
|
`No function found with name ${toolCall.function.name}`
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
await pMap(
|
||||||
|
message.tool_calls,
|
||||||
|
async (toolCall) => {
|
||||||
|
const func = functionSet.get(toolCall.function.name)!
|
||||||
|
|
||||||
|
// TODO: ideally we'd differentiate between tool argument validation
|
||||||
|
// errors versus errors thrown from the tool implementation. Errors
|
||||||
|
// from the underlying tool could be things like network errors, which
|
||||||
|
// should be retried locally without re-calling the LLM.
|
||||||
|
const result = await func(toolCall.function.arguments)
|
||||||
|
|
||||||
|
const toolResult = Msg.toolResult(result, toolCall.id)
|
||||||
|
messages.push(toolResult)
|
||||||
|
},
|
||||||
|
{
|
||||||
|
concurrency: toolCallConcurrency
|
||||||
|
}
|
||||||
|
)
|
||||||
|
} else if (Msg.isFuncCall(message)) {
|
||||||
|
throw new AbortError(
|
||||||
|
'Function calls are not supported; expected tool call'
|
||||||
|
)
|
||||||
|
} else if (Msg.isAssistant(message)) {
|
||||||
|
if (schema && schema.validate) {
|
||||||
|
const result = schema.validate(message.content)
|
||||||
|
|
||||||
|
if (result.success) {
|
||||||
|
return result.data
|
||||||
|
}
|
||||||
|
|
||||||
|
throw new Error(result.error)
|
||||||
|
} else {
|
||||||
|
return message.content as Result
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} catch (err: any) {
|
||||||
|
numErrors++
|
||||||
|
|
||||||
|
if (err instanceof AbortError) {
|
||||||
|
throw err
|
||||||
|
}
|
||||||
|
|
||||||
|
messages.push(
|
||||||
|
Msg.user(
|
||||||
|
`There was an error validating the response. Please check the error message and try again.\nError:\n${getErrorMessage(err)}`
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
if (numErrors > maxRetries) {
|
||||||
|
throw new Error(
|
||||||
|
`Chain failed after ${numErrors} errors: ${err.message}`,
|
||||||
|
{
|
||||||
|
cause: err
|
||||||
|
}
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} while (numCalls < maxCalls)
|
||||||
|
|
||||||
|
throw new Error(`Chain aborted after reaching max ${maxCalls} calls`)
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,5 +1,7 @@
|
||||||
export class RetryableError extends Error {}
|
export class RetryableError extends Error {}
|
||||||
|
|
||||||
|
export class AbortError extends Error {}
|
||||||
|
|
||||||
export class ParseError extends RetryableError {}
|
export class ParseError extends RetryableError {}
|
||||||
|
|
||||||
export class TimeoutError extends Error {}
|
export class TimeoutError extends Error {}
|
||||||
|
|
|
@ -1,9 +1,11 @@
|
||||||
export * from './ai-function-set.js'
|
export * from './ai-function-set.js'
|
||||||
|
export * from './create-ai-chain.js'
|
||||||
export * from './create-ai-function.js'
|
export * from './create-ai-function.js'
|
||||||
export * from './errors.js'
|
export * from './errors.js'
|
||||||
export * from './fns.js'
|
export * from './fns.js'
|
||||||
export * from './message.js'
|
export * from './message.js'
|
||||||
export * from './parse-structured-output.js'
|
export * from './parse-structured-output.js'
|
||||||
|
export * from './schema.js'
|
||||||
export * from './services/index.js'
|
export * from './services/index.js'
|
||||||
export * from './tools/search-and-crawl.js'
|
export * from './tools/search-and-crawl.js'
|
||||||
export type * from './types.js'
|
export type * from './types.js'
|
||||||
|
|
|
@ -1,20 +1,10 @@
|
||||||
import type { JsonValue } from 'type-fest'
|
import type { JsonObject, JsonValue } from 'type-fest'
|
||||||
import { jsonrepair, JSONRepairError } from 'jsonrepair'
|
import { jsonrepair, JSONRepairError } from 'jsonrepair'
|
||||||
import { z, type ZodType } from 'zod'
|
import { z, type ZodType } from 'zod'
|
||||||
|
import { fromZodError } from 'zod-validation-error'
|
||||||
|
|
||||||
import { ParseError } from './errors.js'
|
import { ParseError } from './errors.js'
|
||||||
|
import { type SafeParseResult } from './types.js'
|
||||||
export type SafeParseResult<T> =
|
|
||||||
| {
|
|
||||||
success: true
|
|
||||||
data: T
|
|
||||||
error?: never
|
|
||||||
}
|
|
||||||
| {
|
|
||||||
success: false
|
|
||||||
data?: never
|
|
||||||
error: string
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Parses a string which is expected to contain a structured JSON value.
|
* Parses a string which is expected to contain a structured JSON value.
|
||||||
|
@ -25,15 +15,18 @@ export type SafeParseResult<T> =
|
||||||
* The JSON value is then parsed against a `zod` schema to enforce the shape of
|
* The JSON value is then parsed against a `zod` schema to enforce the shape of
|
||||||
* the output.
|
* the output.
|
||||||
*
|
*
|
||||||
* @param output - string to parse
|
|
||||||
* @param outputSchema - zod schema
|
|
||||||
*
|
|
||||||
* @returns parsed output
|
* @returns parsed output
|
||||||
*/
|
*/
|
||||||
export function parseStructuredOutput<T>(
|
export function parseStructuredOutput<T>(
|
||||||
output: string,
|
value: unknown,
|
||||||
outputSchema: ZodType<T>
|
outputSchema: ZodType<T>
|
||||||
): T {
|
): T {
|
||||||
|
if (!value || typeof value !== 'string') {
|
||||||
|
throw new Error('Invalid output: expected string')
|
||||||
|
}
|
||||||
|
|
||||||
|
const output = value as string
|
||||||
|
|
||||||
let result
|
let result
|
||||||
if (outputSchema instanceof z.ZodArray || 'element' in outputSchema) {
|
if (outputSchema instanceof z.ZodArray || 'element' in outputSchema) {
|
||||||
result = parseArrayOutput(output)
|
result = parseArrayOutput(output)
|
||||||
|
@ -55,16 +48,25 @@ export function parseStructuredOutput<T>(
|
||||||
const safeResult = (outputSchema.safeParse as any)(result)
|
const safeResult = (outputSchema.safeParse as any)(result)
|
||||||
|
|
||||||
if (!safeResult.success) {
|
if (!safeResult.success) {
|
||||||
throw new ParseError(safeResult.error)
|
throw fromZodError(safeResult.error)
|
||||||
}
|
}
|
||||||
|
|
||||||
return safeResult.data
|
return safeResult.data
|
||||||
}
|
}
|
||||||
|
|
||||||
export function safeParseStructuredOutput<T>(
|
export function safeParseStructuredOutput<T>(
|
||||||
output: string,
|
value: unknown,
|
||||||
outputSchema: ZodType<T>
|
outputSchema: ZodType<T>
|
||||||
): SafeParseResult<T> {
|
): SafeParseResult<T> {
|
||||||
|
if (!value || typeof value !== 'string') {
|
||||||
|
return {
|
||||||
|
success: false,
|
||||||
|
error: 'Invalid output: expected string'
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const output = value as string
|
||||||
|
|
||||||
try {
|
try {
|
||||||
const data = parseStructuredOutput<T>(output, outputSchema)
|
const data = parseStructuredOutput<T>(output, outputSchema)
|
||||||
return {
|
return {
|
||||||
|
@ -72,7 +74,7 @@ export function safeParseStructuredOutput<T>(
|
||||||
data
|
data
|
||||||
}
|
}
|
||||||
} catch (err: any) {
|
} catch (err: any) {
|
||||||
console.error(err)
|
// console.error(err)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
success: false,
|
success: false,
|
||||||
|
@ -179,18 +181,16 @@ const BOOLEAN_OUTPUTS: Record<string, boolean> = {
|
||||||
* @param output - string to parse
|
* @param output - string to parse
|
||||||
* @returns parsed array
|
* @returns parsed array
|
||||||
*/
|
*/
|
||||||
export function parseArrayOutput(output: string): Array<any> {
|
export function parseArrayOutput(output: string): JsonValue[] {
|
||||||
try {
|
try {
|
||||||
const arrayOutput = extractJSONFromString(output, 'array')
|
const arrayOutput = extractJSONFromString(output, 'array')
|
||||||
if (arrayOutput.length === 0) {
|
if (arrayOutput.length === 0) {
|
||||||
throw new ParseError(`Invalid JSON array: ${output}`)
|
throw new ParseError('Invalid JSON array')
|
||||||
}
|
}
|
||||||
|
|
||||||
const parsedOutput = arrayOutput[0]
|
const parsedOutput = arrayOutput[0]
|
||||||
if (!Array.isArray(parsedOutput)) {
|
if (!Array.isArray(parsedOutput)) {
|
||||||
throw new ParseError(
|
throw new ParseError('Expected JSON array')
|
||||||
`Invalid JSON array: ${JSON.stringify(parsedOutput)}`
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return parsedOutput
|
return parsedOutput
|
||||||
|
@ -211,24 +211,24 @@ export function parseArrayOutput(output: string): Array<any> {
|
||||||
* @param output - string to parse
|
* @param output - string to parse
|
||||||
* @returns parsed object
|
* @returns parsed object
|
||||||
*/
|
*/
|
||||||
export function parseObjectOutput(output: string) {
|
export function parseObjectOutput(output: string): JsonObject {
|
||||||
try {
|
try {
|
||||||
const arrayOutput = extractJSONFromString(output, 'object')
|
const arrayOutput = extractJSONFromString(output, 'object')
|
||||||
if (arrayOutput.length === 0) {
|
if (arrayOutput.length === 0) {
|
||||||
throw new ParseError(`Invalid JSON object: ${output}`)
|
throw new ParseError('Invalid JSON object')
|
||||||
}
|
}
|
||||||
|
|
||||||
let parsedOutput = arrayOutput[0]
|
let parsedOutput = arrayOutput[0]
|
||||||
if (Array.isArray(parsedOutput)) {
|
if (Array.isArray(parsedOutput)) {
|
||||||
// TODO
|
// TODO
|
||||||
parsedOutput = parsedOutput[0]
|
parsedOutput = parsedOutput[0]
|
||||||
} else if (typeof parsedOutput !== 'object') {
|
|
||||||
throw new ParseError(
|
|
||||||
`Invalid JSON object: ${JSON.stringify(parsedOutput)}`
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return parsedOutput
|
if (!parsedOutput || typeof parsedOutput !== 'object') {
|
||||||
|
throw new ParseError('Expected JSON object')
|
||||||
|
}
|
||||||
|
|
||||||
|
return parsedOutput as JsonObject
|
||||||
} catch (err: any) {
|
} catch (err: any) {
|
||||||
if (err instanceof JSONRepairError) {
|
if (err instanceof JSONRepairError) {
|
||||||
throw new ParseError(err.message, { cause: err })
|
throw new ParseError(err.message, { cause: err })
|
||||||
|
|
|
@ -0,0 +1,119 @@
|
||||||
|
import type { z } from 'zod'
|
||||||
|
|
||||||
|
import type * as types from './types.js'
|
||||||
|
import { safeParseStructuredOutput } from './parse-structured-output.js'
|
||||||
|
import { stringifyForModel } from './utils.js'
|
||||||
|
import { zodToJsonSchema } from './zod-to-json-schema.js'
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Used to mark schemas so we can support both Zod and custom schemas.
|
||||||
|
*/
|
||||||
|
export const schemaSymbol = Symbol('agentic..schema')
|
||||||
|
export const validatorSymbol = Symbol('agentic.validator')
|
||||||
|
|
||||||
|
export type Schema<TData = unknown> = {
|
||||||
|
/**
|
||||||
|
* The JSON Schema.
|
||||||
|
*/
|
||||||
|
readonly jsonSchema: types.JSONSchema
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Optional. Validates that the structure of a value matches this schema,
|
||||||
|
* and returns a typed version of the value if it does.
|
||||||
|
*/
|
||||||
|
readonly validate?: types.ValidatorFn<TData>
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Used to mark schemas so we can support both Zod and custom schemas.
|
||||||
|
*/
|
||||||
|
[schemaSymbol]: true
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Schema type for inference.
|
||||||
|
*/
|
||||||
|
_type: TData
|
||||||
|
}
|
||||||
|
|
||||||
|
export function isSchema(value: unknown): value is Schema {
|
||||||
|
return (
|
||||||
|
typeof value === 'object' &&
|
||||||
|
value !== null &&
|
||||||
|
schemaSymbol in value &&
|
||||||
|
value[schemaSymbol] === true &&
|
||||||
|
'jsonSchema' in value &&
|
||||||
|
'validate' in value
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
export function isZodSchema(value: unknown): value is z.ZodType {
|
||||||
|
return (
|
||||||
|
typeof value === 'object' &&
|
||||||
|
value !== null &&
|
||||||
|
'_type' in value &&
|
||||||
|
'_output' in value &&
|
||||||
|
'_input' in value &&
|
||||||
|
'_def' in value &&
|
||||||
|
'parse' in value &&
|
||||||
|
'safeParse' in value
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
export function asSchema<TData>(
|
||||||
|
schema: z.Schema<TData> | Schema<TData>
|
||||||
|
): Schema<TData> {
|
||||||
|
return isSchema(schema) ? schema : createSchemaFromZodSchema(schema)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Create a schema from a JSON Schema.
|
||||||
|
*/
|
||||||
|
export function createSchema<TData = unknown>(
|
||||||
|
jsonSchema: types.JSONSchema,
|
||||||
|
{
|
||||||
|
validate
|
||||||
|
}: {
|
||||||
|
validate?: types.ValidatorFn<TData>
|
||||||
|
} = {}
|
||||||
|
): Schema<TData> {
|
||||||
|
return {
|
||||||
|
[schemaSymbol]: true,
|
||||||
|
_type: undefined as TData,
|
||||||
|
jsonSchema,
|
||||||
|
validate
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
export function createSchemaFromZodSchema<TData>(
|
||||||
|
zodSchema: z.Schema<TData>
|
||||||
|
): Schema<TData> {
|
||||||
|
return createSchema(zodToJsonSchema(zodSchema), {
|
||||||
|
validate: (value) => {
|
||||||
|
return safeParseStructuredOutput(value, zodSchema)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
const DEFAULT_SCHEMA_PREFIX = `
|
||||||
|
---
|
||||||
|
|
||||||
|
Respond with JSON using the following JSON schema:
|
||||||
|
|
||||||
|
\`\`\`json`
|
||||||
|
const DEFAULT_SCHEMA_SUFFIX = '```'
|
||||||
|
|
||||||
|
export function augmentSystemMessageWithJsonSchema({
|
||||||
|
schema,
|
||||||
|
system,
|
||||||
|
schemaPrefix = DEFAULT_SCHEMA_PREFIX,
|
||||||
|
schemaSuffix = DEFAULT_SCHEMA_SUFFIX
|
||||||
|
}: {
|
||||||
|
schema: types.JSONSchema
|
||||||
|
system?: string
|
||||||
|
schemaPrefix?: string
|
||||||
|
schemaSuffix?: string
|
||||||
|
}): string {
|
||||||
|
return [system, schemaPrefix, stringifyForModel(schema), schemaSuffix]
|
||||||
|
.filter(Boolean)
|
||||||
|
.join('\n')
|
||||||
|
.trim()
|
||||||
|
}
|
|
@ -3,6 +3,7 @@ import pThrottle from 'p-throttle'
|
||||||
import z from 'zod'
|
import z from 'zod'
|
||||||
|
|
||||||
import { aiFunction, AIFunctionsProvider } from '../fns.js'
|
import { aiFunction, AIFunctionsProvider } from '../fns.js'
|
||||||
|
import { isZodSchema } from '../schema.js'
|
||||||
import { assert, delay, getEnv, throttleKy } from '../utils.js'
|
import { assert, delay, getEnv, throttleKy } from '../utils.js'
|
||||||
import { zodToJsonSchema } from '../zod-to-json-schema.js'
|
import { zodToJsonSchema } from '../zod-to-json-schema.js'
|
||||||
|
|
||||||
|
@ -157,7 +158,7 @@ export class FirecrawlClient extends AIFunctionsProvider {
|
||||||
|
|
||||||
if (opts?.extractorOptions?.extractionSchema) {
|
if (opts?.extractorOptions?.extractionSchema) {
|
||||||
let schema = opts.extractorOptions.extractionSchema
|
let schema = opts.extractorOptions.extractionSchema
|
||||||
if (schema instanceof z.ZodSchema) {
|
if (isZodSchema(schema)) {
|
||||||
schema = zodToJsonSchema(schema)
|
schema = zodToJsonSchema(schema)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
67
src/types.ts
67
src/types.ts
|
@ -1,4 +1,4 @@
|
||||||
import type { Jsonifiable } from 'type-fest'
|
import type { Jsonifiable, SetOptional, Simplify } from 'type-fest'
|
||||||
import type { z } from 'zod'
|
import type { z } from 'zod'
|
||||||
|
|
||||||
import type { AIFunctionSet } from './ai-function-set.js'
|
import type { AIFunctionSet } from './ai-function-set.js'
|
||||||
|
@ -6,8 +6,10 @@ import type { AIFunctionsProvider } from './fns.js'
|
||||||
import type { Msg } from './message.js'
|
import type { Msg } from './message.js'
|
||||||
|
|
||||||
export type { Msg } from './message.js'
|
export type { Msg } from './message.js'
|
||||||
|
export type { Schema } from './schema.js'
|
||||||
export type { KyInstance } from 'ky'
|
export type { KyInstance } from 'ky'
|
||||||
export type { ThrottledFunction } from 'p-throttle'
|
export type { ThrottledFunction } from 'p-throttle'
|
||||||
|
export type { Simplify } from 'type-fest'
|
||||||
|
|
||||||
export type Nullable<T> = T | null
|
export type Nullable<T> = T | null
|
||||||
|
|
||||||
|
@ -17,7 +19,15 @@ export type DeepNullable<T> = T extends object
|
||||||
|
|
||||||
export type MaybePromise<T> = T | Promise<T>
|
export type MaybePromise<T> = T | Promise<T>
|
||||||
|
|
||||||
export type RelaxedJsonifiable = Jsonifiable | Record<string, Jsonifiable>
|
// TODO: use a more specific type
|
||||||
|
export type JSONSchema = Record<string, unknown>
|
||||||
|
|
||||||
|
export type RelaxedJsonifiable =
|
||||||
|
| Jsonifiable
|
||||||
|
| Record<string, unknown>
|
||||||
|
| JSONSchema
|
||||||
|
|
||||||
|
export type Context = object
|
||||||
|
|
||||||
export interface AIFunctionSpec {
|
export interface AIFunctionSpec {
|
||||||
/** AI Function name. */
|
/** AI Function name. */
|
||||||
|
@ -27,7 +37,7 @@ export interface AIFunctionSpec {
|
||||||
description: string
|
description: string
|
||||||
|
|
||||||
/** JSON schema spec of the function's input parameters */
|
/** JSON schema spec of the function's input parameters */
|
||||||
parameters: Record<string, unknown>
|
parameters: JSONSchema
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface AIToolSpec {
|
export interface AIToolSpec {
|
||||||
|
@ -79,3 +89,54 @@ export interface AIFunction<
|
||||||
// TODO: this `any` shouldn't be necessary, but it is for `createAIFunction` results to be assignable to `AIFunctionLike`
|
// TODO: this `any` shouldn't be necessary, but it is for `createAIFunction` results to be assignable to `AIFunctionLike`
|
||||||
impl: (params: z.infer<InputSchema> | any) => MaybePromise<Return>
|
impl: (params: z.infer<InputSchema> | any) => MaybePromise<Return>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export interface ChatParams {
|
||||||
|
messages: Msg[]
|
||||||
|
model: string & {}
|
||||||
|
functions?: AIFunctionSpec[]
|
||||||
|
function_call?: 'none' | 'auto' | { name: string }
|
||||||
|
tools?: AIToolSpec[]
|
||||||
|
tool_choice?:
|
||||||
|
| 'none'
|
||||||
|
| 'auto'
|
||||||
|
| 'required'
|
||||||
|
| { type: 'function'; function: { name: string } }
|
||||||
|
parallel_tool_calls?: boolean
|
||||||
|
logit_bias?: Record<string, number>
|
||||||
|
logprobs?: boolean
|
||||||
|
max_tokens?: number
|
||||||
|
presence_penalty?: number
|
||||||
|
frequency_penalty?: number
|
||||||
|
response_format?: { type: 'text' | 'json_object' }
|
||||||
|
seed?: number
|
||||||
|
stop?: string | null | Array<string>
|
||||||
|
temperature?: number
|
||||||
|
top_logprobs?: number
|
||||||
|
top_p?: number
|
||||||
|
user?: string
|
||||||
|
}
|
||||||
|
|
||||||
|
/** An OpenAI-compatible chat completions API */
|
||||||
|
export type ChatFn = (
|
||||||
|
params: Simplify<SetOptional<ChatParams, 'model'>>
|
||||||
|
) => Promise<{ message: Msg }>
|
||||||
|
|
||||||
|
export type AIChainResult = string | Record<string, any>
|
||||||
|
|
||||||
|
export type AIChain<Result extends AIChainResult = string> = (
|
||||||
|
params:
|
||||||
|
| string
|
||||||
|
| Simplify<SetOptional<Omit<ChatParams, 'tools' | 'functions'>, 'model'>>
|
||||||
|
) => Promise<Result>
|
||||||
|
|
||||||
|
export type SafeParseResult<TData> =
|
||||||
|
| {
|
||||||
|
success: true
|
||||||
|
data: TData
|
||||||
|
}
|
||||||
|
| {
|
||||||
|
success: false
|
||||||
|
error: string
|
||||||
|
}
|
||||||
|
|
||||||
|
export type ValidatorFn<TData> = (value: unknown) => SafeParseResult<TData>
|
||||||
|
|
26
src/utils.ts
26
src/utils.ts
|
@ -1,4 +1,3 @@
|
||||||
import type { Jsonifiable } from 'type-fest'
|
|
||||||
import dedent from 'dedent'
|
import dedent from 'dedent'
|
||||||
import hashObjectImpl, { type Options as HashObjectOptions } from 'hash-object'
|
import hashObjectImpl, { type Options as HashObjectOptions } from 'hash-object'
|
||||||
|
|
||||||
|
@ -167,7 +166,9 @@ export function sanitizeSearchParams(
|
||||||
/**
|
/**
|
||||||
* Stringifies a JSON value in a way that's optimized for use with LLM prompts.
|
* Stringifies a JSON value in a way that's optimized for use with LLM prompts.
|
||||||
*/
|
*/
|
||||||
export function stringifyForModel(jsonObject?: Jsonifiable): string {
|
export function stringifyForModel(
|
||||||
|
jsonObject?: types.RelaxedJsonifiable
|
||||||
|
): string {
|
||||||
if (jsonObject === undefined) {
|
if (jsonObject === undefined) {
|
||||||
return ''
|
return ''
|
||||||
}
|
}
|
||||||
|
@ -208,3 +209,24 @@ export function isAIFunction(obj: any): obj is types.AIFunction {
|
||||||
|
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export function getErrorMessage(error?: unknown): string {
|
||||||
|
if (!error) {
|
||||||
|
return 'unknown error'
|
||||||
|
}
|
||||||
|
|
||||||
|
if (typeof error === 'string') {
|
||||||
|
return error
|
||||||
|
}
|
||||||
|
|
||||||
|
const message = (error as any).message
|
||||||
|
if (message && typeof message === 'string') {
|
||||||
|
return message
|
||||||
|
}
|
||||||
|
|
||||||
|
try {
|
||||||
|
return JSON.stringify(error)
|
||||||
|
} catch {
|
||||||
|
return 'unknown error'
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -1,10 +1,11 @@
|
||||||
import type { z } from 'zod'
|
import type { z } from 'zod'
|
||||||
import { zodToJsonSchema as zodToJsonSchemaImpl } from 'zod-to-json-schema'
|
import { zodToJsonSchema as zodToJsonSchemaImpl } from 'zod-to-json-schema'
|
||||||
|
|
||||||
|
import type * as types from './types.js'
|
||||||
import { omit } from './utils.js'
|
import { omit } from './utils.js'
|
||||||
|
|
||||||
/** Generate a JSON Schema from a Zod schema. */
|
/** Generate a JSON Schema from a Zod schema. */
|
||||||
export function zodToJsonSchema(schema: z.ZodType): Record<string, unknown> {
|
export function zodToJsonSchema(schema: z.ZodType): types.JSONSchema {
|
||||||
return omit(
|
return omit(
|
||||||
zodToJsonSchemaImpl(schema, { $refStrategy: 'none' }),
|
zodToJsonSchemaImpl(schema, { $refStrategy: 'none' }),
|
||||||
'$schema',
|
'$schema',
|
||||||
|
|
Ładowanie…
Reference in New Issue