feat: add AIChain

pull/659/head
Travis Fischer 2024-07-26 14:40:34 -07:00
rodzic a58f026eae
commit 391da4f996
13 zmienionych plików z 897 dodań i 278 usunięć

Wyświetl plik

@ -0,0 +1,39 @@
#!/usr/bin/env node
import 'dotenv/config'
import {
createAIChain,
Msg,
PerigonClient,
SerperClient
} from '@agentic/stdlib'
import { ChatModel } from '@dexaai/dexter'
/**
 * Example script: builds an AIChain that can call the Perigon
 * `search_news_stories` function and the tools exposed by a `SerperClient`,
 * asks it a single news question, and prints the chain's final answer.
 */
async function main() {
  const systemPrompt =
    'You are a helpful assistant. Be as concise as possible. Respond in markdown. Always cite your sources.'
  const question =
    'Summarize the latest news stories about the upcoming US election.'

  const perigon = new PerigonClient()
  const serper = new SerperClient()

  const chatModel = new ChatModel({
    params: { model: 'gpt-4o', temperature: 0 },
    debug: true
  })

  const runChain = createAIChain({
    chatFn: chatModel.run.bind(chatModel),
    tools: [perigon.functions.pick('search_news_stories'), serper],
    params: { messages: [Msg.system(systemPrompt)] }
  })

  const answer = await runChain(question)
  console.log(answer)
}

await main()

Wyświetl plik

@ -0,0 +1,32 @@
#!/usr/bin/env node
import 'dotenv/config'
import { createAIChain, Msg } from '@agentic/stdlib'
import { ChatModel } from '@dexaai/dexter'
import { z } from 'zod'
/**
 * Example script: uses an AIChain with a Zod schema to extract a structured
 * user object from free-form text, then prints the validated result.
 */
async function main() {
  const inputText =
    'Bob Vance is 42 years old and lives in Brooklyn, NY. He is a software engineer.'

  const chatModel = new ChatModel({
    params: { model: 'gpt-4o', temperature: 0 },
    debug: true
  })

  const extractUser = createAIChain({
    chatFn: chatModel.run.bind(chatModel),
    params: {
      messages: [Msg.system('Extract a JSON user object from the given text.')]
    },
    // The chain validates the model's final reply against this shape.
    schema: z.object({
      name: z.string(),
      age: z.number(),
      location: z.string().optional()
    })
  })

  const user = await extractUser(inputText)
  console.log(user)
}

await main()

Wyświetl plik

@ -98,7 +98,8 @@
"quick-lru": "^7.0.0",
"type-fest": "^4.21.0",
"zod": "^3.23.3",
"zod-to-json-schema": "^3.23.1"
"zod-to-json-schema": "^3.23.1",
"zod-validation-error": "^3.3.0"
},
"devDependencies": {
"@aws-sdk/client-sso-oidc": "^3.616.0",

Plik diff jest za duży Load Diff

Wyświetl plik

@ -0,0 +1,189 @@
import type { SetOptional } from 'type-fest'
import type { ZodType } from 'zod'
import pMap from 'p-map'
import type * as types from './types.js'
import { AIFunctionSet } from './ai-function-set.js'
import { AbortError } from './errors.js'
import { Msg } from './message.js'
import { asSchema, augmentSystemMessageWithJsonSchema } from './schema.js'
import { getErrorMessage } from './utils.js'
/**
 * Creates a chain of chat completion calls that can be invoked as a single
 * function. It is meant to simplify the process of resolving tool calls
 * and optionally adding validation to the final result.
 *
 * The returned function will invoke the `chatFn` up to `maxCalls` times,
 * resolving any tool calls to the included `tools` and retrying if
 * necessary up to `maxRetries`.
 *
 * The chain ends when a non-tool call is returned, and the final result can
 * optionally be validated against a Zod schema, which defaults to a `string`.
 *
 * To prevent possible infinite loops, the chain will throw an error if it
 * exceeds `maxCalls` (`maxCalls` is expected to be >= `maxRetries`).
 *
 * Options:
 * - `chatFn`: performs a single chat completion call.
 * - `params`: default chat params; its `messages` are prepended to the
 *   messages of every invocation.
 * - `tools`: functions made available to the model for tool calls.
 * - `schema`: Zod schema or `Schema` used to validate the final message.
 * - `maxCalls`: maximum number of `chatFn` invocations per chain call.
 * - `maxRetries`: maximum number of recoverable errors tolerated.
 * - `toolCallConcurrency`: concurrency limit when resolving tool calls.
 * - `injectSchemaIntoSystemMessage`: when a schema is given, append its JSON
 *   schema to the last system message (or add a system message if none).
 *
 * The returned chain throws an `AbortError` (never retried) if the model
 * issues a tool call when no tools were provided or issues a function call,
 * and a plain `Error` if there are no messages, if more than `maxRetries`
 * errors occur, or if `maxCalls` is reached without a final result.
 */
export function createAIChain<Result extends types.AIChainResult = string>({
  chatFn,
  params,
  schema: rawSchema,
  tools,
  maxCalls = 5,
  maxRetries = 2,
  toolCallConcurrency = 8,
  injectSchemaIntoSystemMessage = true
}: {
  chatFn: types.ChatFn
  params?: types.Simplify<
    Partial<Omit<types.ChatParams, 'tools' | 'functions'>>
  >
  tools?: types.AIFunctionLike[]
  schema?: ZodType<Result> | types.Schema<Result>
  maxCalls?: number
  maxRetries?: number
  toolCallConcurrency?: number
  injectSchemaIntoSystemMessage?: boolean
}): types.AIChain<Result> {
  const functionSet = new AIFunctionSet(tools)

  // With a schema and no tools, default to requesting a JSON object response.
  // Both `params` and per-call chat params are spread after this and can
  // therefore override it.
  const defaultParams: Partial<types.ChatParams> | undefined =
    rawSchema && !functionSet.size
      ? {
          response_format: { type: 'json_object' }
        }
      : undefined

  return async (chatParams) => {
    // `messages` is always a fresh array, so pushing to it below never
    // mutates the caller's `params.messages` or `chatParams.messages`.
    const { messages, ...modelParams }: SetOptional<types.ChatParams, 'model'> =
      typeof chatParams === 'string'
        ? {
            ...defaultParams,
            ...params,
            messages: [...(params?.messages ?? []), Msg.user(chatParams)]
          }
        : {
            ...defaultParams,
            ...params,
            ...chatParams,
            messages: [
              ...(params?.messages ?? []),
              ...(chatParams.messages ?? [])
            ]
          }

    if (!messages.length) {
      throw new Error('AIChain error: "messages" is empty')
    }

    const schema = rawSchema ? asSchema(rawSchema) : undefined
    if (schema && injectSchemaIntoSystemMessage) {
      const lastSystemMessageIndex = messages.findLastIndex(Msg.isSystem)

      // When there is no system message the index is -1 and the lookup yields
      // `undefined`, in which case the schema prompt stands on its own.
      const lastSystemMessageContent =
        messages[lastSystemMessageIndex]?.content ?? undefined

      const systemMessage = augmentSystemMessageWithJsonSchema({
        system: lastSystemMessageContent,
        schema: schema.jsonSchema
      })

      if (lastSystemMessageIndex >= 0) {
        messages[lastSystemMessageIndex] = Msg.system(systemMessage)
      } else {
        messages.unshift(Msg.system(systemMessage))
      }
    }

    let numCalls = 0
    let numErrors = 0

    do {
      ++numCalls

      const response = await chatFn({
        ...modelParams,
        messages,
        tools: functionSet.size ? functionSet.toolSpecs : undefined
      })

      const { message } = response
      messages.push(message)

      try {
        if (Msg.isToolCall(message)) {
          if (!functionSet.size) {
            throw new AbortError('No functions provided to handle tool call')
          }

          // Synchronously validate that all tool calls reference valid functions
          for (const toolCall of message.tool_calls) {
            const func = functionSet.get(toolCall.function.name)

            if (!func) {
              throw new Error(
                `No function found with name ${toolCall.function.name}`
              )
            }
          }

          await pMap(
            message.tool_calls,
            async (toolCall) => {
              const func = functionSet.get(toolCall.function.name)!

              // TODO: ideally we'd differentiate between tool argument validation
              // errors versus errors thrown from the tool implementation. Errors
              // from the underlying tool could be things like network errors, which
              // should be retried locally without re-calling the LLM.
              const result = await func(toolCall.function.arguments)

              const toolResult = Msg.toolResult(result, toolCall.id)
              messages.push(toolResult)
            },
            {
              concurrency: toolCallConcurrency
            }
          )
        } else if (Msg.isFuncCall(message)) {
          throw new AbortError(
            'Function calls are not supported; expected tool call'
          )
        } else if (Msg.isAssistant(message)) {
          if (schema && schema.validate) {
            const result = schema.validate(message.content)

            if (result.success) {
              return result.data
            }

            throw new Error(result.error)
          } else {
            return message.content as Result
          }
        }
      } catch (err: unknown) {
        numErrors++

        // Abort errors are unrecoverable and must never be retried.
        if (err instanceof AbortError) {
          throw err
        }

        // Feed the failure back to the model so the next call can correct it.
        messages.push(
          Msg.user(
            `There was an error validating the response. Please check the error message and try again.\nError:\n${getErrorMessage(err)}`
          )
        )

        if (numErrors > maxRetries) {
          // Use `getErrorMessage` so that non-Error throwables do not render
          // as "undefined" in the final error message.
          throw new Error(
            `Chain failed after ${numErrors} errors: ${getErrorMessage(err)}`,
            {
              cause: err
            }
          )
        }
      }
    } while (numCalls < maxCalls)

    throw new Error(`Chain aborted after reaching max ${maxCalls} calls`)
  }
}

Wyświetl plik

@ -1,5 +1,7 @@
/**
* Base class for errors that signal a failure which may be retried.
* NOTE(review): no code in view catches this type directly — confirm how
* callers use it.
*/
export class RetryableError extends Error {}
/**
* Signals an unrecoverable failure. `createAIChain` rethrows errors of this
* type immediately instead of counting them towards its retry budget.
*/
export class AbortError extends Error {}
/**
* Thrown when model output cannot be parsed into the expected structure
* (used by the structured-output parsing helpers). Being a `RetryableError`,
* it marks the failure as retryable.
*/
export class ParseError extends RetryableError {}
/**
* Presumably signals that an operation exceeded its time limit — no usage is
* visible here; verify against callers.
*/
export class TimeoutError extends Error {}

Wyświetl plik

@ -1,9 +1,11 @@
export * from './ai-function-set.js'
export * from './create-ai-chain.js'
export * from './create-ai-function.js'
export * from './errors.js'
export * from './fns.js'
export * from './message.js'
export * from './parse-structured-output.js'
export * from './schema.js'
export * from './services/index.js'
export * from './tools/search-and-crawl.js'
export type * from './types.js'

Wyświetl plik

@ -1,20 +1,10 @@
import type { JsonValue } from 'type-fest'
import type { JsonObject, JsonValue } from 'type-fest'
import { jsonrepair, JSONRepairError } from 'jsonrepair'
import { z, type ZodType } from 'zod'
import { fromZodError } from 'zod-validation-error'
import { ParseError } from './errors.js'
export type SafeParseResult<T> =
| {
success: true
data: T
error?: never
}
| {
success: false
data?: never
error: string
}
import { type SafeParseResult } from './types.js'
/**
* Parses a string which is expected to contain a structured JSON value.
@ -25,15 +15,18 @@ export type SafeParseResult<T> =
* The JSON value is then parsed against a `zod` schema to enforce the shape of
* the output.
*
* @param output - string to parse
* @param outputSchema - zod schema
*
* @returns parsed output
*/
export function parseStructuredOutput<T>(
output: string,
value: unknown,
outputSchema: ZodType<T>
): T {
if (!value || typeof value !== 'string') {
throw new Error('Invalid output: expected string')
}
const output = value as string
let result
if (outputSchema instanceof z.ZodArray || 'element' in outputSchema) {
result = parseArrayOutput(output)
@ -55,16 +48,25 @@ export function parseStructuredOutput<T>(
const safeResult = (outputSchema.safeParse as any)(result)
if (!safeResult.success) {
throw new ParseError(safeResult.error)
throw fromZodError(safeResult.error)
}
return safeResult.data
}
export function safeParseStructuredOutput<T>(
output: string,
value: unknown,
outputSchema: ZodType<T>
): SafeParseResult<T> {
if (!value || typeof value !== 'string') {
return {
success: false,
error: 'Invalid output: expected string'
}
}
const output = value as string
try {
const data = parseStructuredOutput<T>(output, outputSchema)
return {
@ -72,7 +74,7 @@ export function safeParseStructuredOutput<T>(
data
}
} catch (err: any) {
console.error(err)
// console.error(err)
return {
success: false,
@ -179,18 +181,16 @@ const BOOLEAN_OUTPUTS: Record<string, boolean> = {
* @param output - string to parse
* @returns parsed array
*/
export function parseArrayOutput(output: string): Array<any> {
export function parseArrayOutput(output: string): JsonValue[] {
try {
const arrayOutput = extractJSONFromString(output, 'array')
if (arrayOutput.length === 0) {
throw new ParseError(`Invalid JSON array: ${output}`)
throw new ParseError('Invalid JSON array')
}
const parsedOutput = arrayOutput[0]
if (!Array.isArray(parsedOutput)) {
throw new ParseError(
`Invalid JSON array: ${JSON.stringify(parsedOutput)}`
)
throw new ParseError('Expected JSON array')
}
return parsedOutput
@ -211,24 +211,24 @@ export function parseArrayOutput(output: string): Array<any> {
* @param output - string to parse
* @returns parsed object
*/
export function parseObjectOutput(output: string) {
export function parseObjectOutput(output: string): JsonObject {
try {
const arrayOutput = extractJSONFromString(output, 'object')
if (arrayOutput.length === 0) {
throw new ParseError(`Invalid JSON object: ${output}`)
throw new ParseError('Invalid JSON object')
}
let parsedOutput = arrayOutput[0]
if (Array.isArray(parsedOutput)) {
// TODO
parsedOutput = parsedOutput[0]
} else if (typeof parsedOutput !== 'object') {
throw new ParseError(
`Invalid JSON object: ${JSON.stringify(parsedOutput)}`
)
}
return parsedOutput
if (!parsedOutput || typeof parsedOutput !== 'object') {
throw new ParseError('Expected JSON object')
}
return parsedOutput as JsonObject
} catch (err: any) {
if (err instanceof JSONRepairError) {
throw new ParseError(err.message, { cause: err })

119
src/schema.ts 100644
Wyświetl plik

@ -0,0 +1,119 @@
import type { z } from 'zod'
import type * as types from './types.js'
import { safeParseStructuredOutput } from './parse-structured-output.js'
import { stringifyForModel } from './utils.js'
import { zodToJsonSchema } from './zod-to-json-schema.js'
/**
 * Used to mark schemas so we can support both Zod and custom schemas.
 *
 * The strings passed to `Symbol()` are debug descriptions only; each symbol
 * is unique regardless of its description.
 */
export const schemaSymbol = Symbol('agentic.schema')
export const validatorSymbol = Symbol('agentic.validator')
/**
* A JSON Schema bundled with an optional validator and a type-level data
* type. Instances are created by `createSchema` or
* `createSchemaFromZodSchema`.
*/
export type Schema<TData = unknown> = {
/**
* The JSON Schema.
*/
readonly jsonSchema: types.JSONSchema
/**
* Optional. Validates that the structure of a value matches this schema,
* and returns a typed version of the value if it does.
*/
readonly validate?: types.ValidatorFn<TData>
/**
* Used to mark schemas so we can support both Zod and custom schemas.
*/
[schemaSymbol]: true
/**
* Schema type for inference. `createSchema` sets this to `undefined` at
* runtime, so it is only meaningful at the type level.
*/
_type: TData
}
/**
 * Type guard which checks whether a value is a `Schema`: a non-null object
 * marked with `schemaSymbol` that carries a `jsonSchema` property.
 *
 * The `validate` property is intentionally not required because it is
 * optional on the `Schema` type; requiring it would reject valid custom
 * schemas that omit a validator.
 *
 * @param value - value to test
 * @returns `true` if `value` is a `Schema`
 */
export function isSchema(value: unknown): value is Schema {
  return (
    typeof value === 'object' &&
    value !== null &&
    schemaSymbol in value &&
    value[schemaSymbol] === true &&
    'jsonSchema' in value
  )
}
/**
 * Duck-typed check for Zod schemas: the value must be a non-null object
 * exposing every property listed below. No `instanceof` check is involved.
 *
 * @param value - value to test
 * @returns `true` if `value` looks like a Zod schema
 */
export function isZodSchema(value: unknown): value is z.ZodType {
  if (typeof value !== 'object' || value === null) {
    return false
  }

  const candidate: object = value
  const requiredKeys = [
    '_type',
    '_output',
    '_input',
    '_def',
    'parse',
    'safeParse'
  ]

  return requiredKeys.every((key) => key in candidate)
}
/**
 * Normalizes a Zod schema or a `Schema` into a `Schema`. Values which are
 * already a `Schema` are returned as-is; anything else is converted via
 * `createSchemaFromZodSchema`.
 */
export function asSchema<TData>(
  schema: z.Schema<TData> | Schema<TData>
): Schema<TData> {
  if (isSchema(schema)) {
    return schema
  }

  return createSchemaFromZodSchema(schema)
}
/**
 * Builds a `Schema` from a raw JSON Schema.
 *
 * @param jsonSchema - the JSON Schema to wrap
 * @param options - optional settings; `validate` is an optional validator
 * stored on the returned schema
 * @returns a `Schema` marked with `schemaSymbol`
 */
export function createSchema<TData = unknown>(
  jsonSchema: types.JSONSchema,
  options: {
    validate?: types.ValidatorFn<TData>
  } = {}
): Schema<TData> {
  const schema: Schema<TData> = {
    [schemaSymbol]: true,
    // Never populated at runtime; exists purely for type inference.
    _type: undefined as TData,
    jsonSchema,
    validate: options.validate
  }

  return schema
}
/**
 * Builds a `Schema` from a Zod schema: the JSON Schema is derived with
 * `zodToJsonSchema`, and the validator delegates to
 * `safeParseStructuredOutput` using the given Zod schema.
 */
export function createSchemaFromZodSchema<TData>(
  zodSchema: z.Schema<TData>
): Schema<TData> {
  const jsonSchema = zodToJsonSchema(zodSchema)
  const validate: types.ValidatorFn<TData> = (value) =>
    safeParseStructuredOutput(value, zodSchema)

  return createSchema(jsonSchema, { validate })
}
/**
* Default text placed between the system prompt and the serialized JSON
* schema by `augmentSystemMessageWithJsonSchema`. It ends by opening a fenced
* `json` code block.
*/
const DEFAULT_SCHEMA_PREFIX = `
---
Respond with JSON using the following JSON schema:
\`\`\`json`
/**
* Default text placed after the serialized JSON schema; closes the code fence
* opened by the default prefix.
*/
const DEFAULT_SCHEMA_SUFFIX = '```'
/**
 * Builds a system prompt that instructs the model to respond according to a
 * JSON schema.
 *
 * The optional existing `system` prompt, the prefix, the stringified schema
 * and the suffix are joined with newlines (empty parts are skipped), and the
 * result is trimmed.
 *
 * @returns the augmented system prompt
 */
export function augmentSystemMessageWithJsonSchema({
  schema,
  system,
  schemaPrefix = DEFAULT_SCHEMA_PREFIX,
  schemaSuffix = DEFAULT_SCHEMA_SUFFIX
}: {
  schema: types.JSONSchema
  system?: string
  schemaPrefix?: string
  schemaSuffix?: string
}): string {
  const sections: Array<string | undefined> = [
    system,
    schemaPrefix,
    stringifyForModel(schema),
    schemaSuffix
  ]
  const nonEmptySections = sections.filter((section) => Boolean(section))

  return nonEmptySections.join('\n').trim()
}

Wyświetl plik

@ -3,6 +3,7 @@ import pThrottle from 'p-throttle'
import z from 'zod'
import { aiFunction, AIFunctionsProvider } from '../fns.js'
import { isZodSchema } from '../schema.js'
import { assert, delay, getEnv, throttleKy } from '../utils.js'
import { zodToJsonSchema } from '../zod-to-json-schema.js'
@ -157,7 +158,7 @@ export class FirecrawlClient extends AIFunctionsProvider {
if (opts?.extractorOptions?.extractionSchema) {
let schema = opts.extractorOptions.extractionSchema
if (schema instanceof z.ZodSchema) {
if (isZodSchema(schema)) {
schema = zodToJsonSchema(schema)
}

Wyświetl plik

@ -1,4 +1,4 @@
import type { Jsonifiable } from 'type-fest'
import type { Jsonifiable, SetOptional, Simplify } from 'type-fest'
import type { z } from 'zod'
import type { AIFunctionSet } from './ai-function-set.js'
@ -6,8 +6,10 @@ import type { AIFunctionsProvider } from './fns.js'
import type { Msg } from './message.js'
export type { Msg } from './message.js'
export type { Schema } from './schema.js'
export type { KyInstance } from 'ky'
export type { ThrottledFunction } from 'p-throttle'
export type { Simplify } from 'type-fest'
export type Nullable<T> = T | null
@ -17,7 +19,15 @@ export type DeepNullable<T> = T extends object
export type MaybePromise<T> = T | Promise<T>
export type RelaxedJsonifiable = Jsonifiable | Record<string, Jsonifiable>
// TODO: use a more specific type
export type JSONSchema = Record<string, unknown>
export type RelaxedJsonifiable =
| Jsonifiable
| Record<string, unknown>
| JSONSchema
export type Context = object
export interface AIFunctionSpec {
/** AI Function name. */
@ -27,7 +37,7 @@ export interface AIFunctionSpec {
description: string
/** JSON schema spec of the function's input parameters */
parameters: Record<string, unknown>
parameters: JSONSchema
}
export interface AIToolSpec {
@ -79,3 +89,54 @@ export interface AIFunction<
// TODO: this `any` shouldn't be necessary, but it is for `createAIFunction` results to be assignable to `AIFunctionLike`
impl: (params: z.infer<InputSchema> | any) => MaybePromise<Return>
}
/**
* Request parameters for a chat completion call. `ChatFn` accepts this shape
* with `model` made optional.
* NOTE(review): the field names appear to mirror the OpenAI chat completions
* request format — confirm against the targeted API version.
*/
export interface ChatParams {
/** Conversation messages sent to the model. */
messages: Msg[]
model: string & {}
/**
* Function specs / function selection. Note that `createAIChain` omits
* `functions` from its accepted params and aborts if the model responds with
* a function call rather than a tool call.
*/
functions?: AIFunctionSpec[]
function_call?: 'none' | 'auto' | { name: string }
/** Tool specs offered to the model, and how the model may choose among them. */
tools?: AIToolSpec[]
tool_choice?:
| 'none'
| 'auto'
| 'required'
| { type: 'function'; function: { name: string } }
parallel_tool_calls?: boolean
logit_bias?: Record<string, number>
logprobs?: boolean
max_tokens?: number
presence_penalty?: number
frequency_penalty?: number
/**
* Requested response format. `createAIChain` defaults this to `json_object`
* when a schema is given and no tools are provided.
*/
response_format?: { type: 'text' | 'json_object' }
seed?: number
stop?: string | null | Array<string>
temperature?: number
top_logprobs?: number
top_p?: number
user?: string
}
/**
* An OpenAI-compatible chat completions API.
*
* Takes `ChatParams` (with `model` optional) and resolves to an object
* holding the model's response `message`.
*/
export type ChatFn = (
params: Simplify<SetOptional<ChatParams, 'model'>>
) => Promise<{ message: Msg }>
/** Value types an `AIChain` may resolve to: a string or an object. */
export type AIChainResult = string | Record<string, any>
/**
* Callable returned by `createAIChain`. It accepts either a plain string
* (which `createAIChain` wraps in a user message) or chat params without
* `tools` / `functions` and with `model` optional, and resolves to the
* chain's final result.
*/
export type AIChain<Result extends AIChainResult = string> = (
params:
| string
| Simplify<SetOptional<Omit<ChatParams, 'tools' | 'functions'>, 'model'>>
) => Promise<Result>
/**
* Outcome of a validation, discriminated by `success`: on success it carries
* the typed `data`; on failure it carries an `error` message string.
*/
export type SafeParseResult<TData> =
| {
success: true
data: TData
}
| {
success: false
error: string
}
/**
* Validates an unknown value and reports the outcome as a `SafeParseResult`.
*/
export type ValidatorFn<TData> = (value: unknown) => SafeParseResult<TData>

Wyświetl plik

@ -1,4 +1,3 @@
import type { Jsonifiable } from 'type-fest'
import dedent from 'dedent'
import hashObjectImpl, { type Options as HashObjectOptions } from 'hash-object'
@ -167,7 +166,9 @@ export function sanitizeSearchParams(
/**
* Stringifies a JSON value in a way that's optimized for use with LLM prompts.
*/
export function stringifyForModel(jsonObject?: Jsonifiable): string {
export function stringifyForModel(
jsonObject?: types.RelaxedJsonifiable
): string {
if (jsonObject === undefined) {
return ''
}
@ -208,3 +209,24 @@ export function isAIFunction(obj: any): obj is types.AIFunction {
return true
}
/**
 * Extracts a human-readable message from an arbitrary thrown value.
 *
 * Resolution order:
 * 1. falsy values yield `'unknown error'`
 * 2. strings are returned as-is
 * 3. a non-empty string `message` property is returned
 * 4. otherwise the value is JSON-stringified, falling back to
 *    `'unknown error'` if it cannot be serialized
 *
 * @param error - the thrown value, which may be of any type
 * @returns a message describing the error; always a string
 */
export function getErrorMessage(error?: unknown): string {
  if (!error) {
    return 'unknown error'
  }

  if (typeof error === 'string') {
    return error
  }

  const message = (error as { message?: unknown }).message
  if (message && typeof message === 'string') {
    return message
  }

  try {
    // Despite its declared return type, `JSON.stringify` returns `undefined`
    // for values such as functions and symbols; it throws for BigInts and
    // circular structures (handled by the `catch` below).
    const serialized: string | undefined = JSON.stringify(error)
    return serialized ?? 'unknown error'
  } catch {
    return 'unknown error'
  }
}

Wyświetl plik

@ -1,10 +1,11 @@
import type { z } from 'zod'
import { zodToJsonSchema as zodToJsonSchemaImpl } from 'zod-to-json-schema'
import type * as types from './types.js'
import { omit } from './utils.js'
/** Generate a JSON Schema from a Zod schema. */
export function zodToJsonSchema(schema: z.ZodType): Record<string, unknown> {
export function zodToJsonSchema(schema: z.ZodType): types.JSONSchema {
return omit(
zodToJsonSchemaImpl(schema, { $refStrategy: 'none' }),
'$schema',