diff --git a/legacy/examples/dexter/bin/election-news-chain.ts b/legacy/examples/dexter/bin/election-news-chain.ts index beee439c..3e82830e 100644 --- a/legacy/examples/dexter/bin/election-news-chain.ts +++ b/legacy/examples/dexter/bin/election-news-chain.ts @@ -16,6 +16,7 @@ async function main() { }) const chain = createAIChain({ + name: 'search_news', chatFn: chatModel.run.bind(chatModel), tools: [perigon.functions.pick('search_news_stories'), serper], params: { diff --git a/legacy/packages/core/src/create-ai-chain.ts b/legacy/packages/core/src/create-ai-chain.ts index 0e48cb68..d321416f 100644 --- a/legacy/packages/core/src/create-ai-chain.ts +++ b/legacy/packages/core/src/create-ai-chain.ts @@ -80,6 +80,8 @@ export function createAIChain({ const functionSet = new AIFunctionSet(tools) const schema = rawSchema ? asSchema(rawSchema, { strict }) : undefined + // TODO: support custom stopping criteria (like setting a flag in a tool call) + const defaultParams: Partial | undefined = schema && !functionSet.size ? { @@ -190,10 +192,10 @@ export function createAIChain({ throw new AbortError( 'Function calls are not supported; expected tool call' ) + } else if (Msg.isRefusal(message)) { + throw new AbortError(`Model refusal: ${message.refusal}`) } else if (Msg.isAssistant(message)) { - if (message.refusal) { - throw new AbortError(`Model refusal: ${message.refusal}`) - } else if (schema && schema.validate) { + if (schema && schema.validate) { const result = schema.validate(message.content) if (result.success) { @@ -212,6 +214,8 @@ export function createAIChain({ throw err } + console.warn(`Chain "${name}" error:`, err.message) + messages.push( Msg.user( `There was an error validating the response. 
Please check the error message and try again.\nError:\n${getErrorMessage(err)}` diff --git a/legacy/packages/core/src/message.ts b/legacy/packages/core/src/message.ts index d6788440..5570d1aa 100644 --- a/legacy/packages/core/src/message.ts +++ b/legacy/packages/core/src/message.ts @@ -50,6 +50,15 @@ export interface Msg { name?: string } +export interface LegacyMsg { + content: string | null + role: Msg.Role + function_call?: Msg.Call.Function + tool_calls?: Msg.Call.Tool[] + tool_call_id?: string + name?: string +} + /** Narrowed OpenAI Message types. */ export namespace Msg { /** Possible roles for a message. */ @@ -102,8 +111,14 @@ export namespace Msg { export type Assistant = { role: 'assistant' name?: string - content?: string - refusal?: string + content: string + } + + /** Message with refusal reason from the assistant. */ + export type Refusal = { + role: 'assistant' + name?: string + refusal: string } /** Message with arguments to call a function. */ @@ -193,6 +208,27 @@ export namespace Msg { } } + /** + * Create an assistant refusal message. Cleans indentation and newlines by + * default. + */ + export function refusal( + refusal: string, + opts?: { + /** Custom name for the message. */ + name?: string + /** Whether to clean extra newlines and indentation. Defaults to true. */ + cleanRefusal?: boolean + } + ): Msg.Refusal { + const { name, cleanRefusal = true } = opts ?? {} + return { + role: 'assistant', + refusal: cleanRefusal ? cleanStringForModel(refusal) : refusal, + ...(name ? { name } : {}) + } + } + /** Create a function call message with argumets. 
*/ export function funcCall( function_call: { @@ -257,7 +293,7 @@ export namespace Msg { // @TODO response: any // response: ChatModel.EnrichedResponse - ): Msg.Assistant | Msg.FuncCall | Msg.ToolCall { + ): Msg.Assistant | Msg.Refusal | Msg.FuncCall | Msg.ToolCall { const msg = response.choices[0].message as Msg return narrowResponseMessage(msg) } @@ -265,13 +301,15 @@ export namespace Msg { /** Narrow a message received from the API. It only responds with role=assistant */ export function narrowResponseMessage( msg: Msg - ): Msg.Assistant | Msg.FuncCall | Msg.ToolCall { + ): Msg.Assistant | Msg.Refusal | Msg.FuncCall | Msg.ToolCall { if (msg.content === null && msg.tool_calls != null) { return Msg.toolCall(msg.tool_calls) } else if (msg.content === null && msg.function_call != null) { return Msg.funcCall(msg.function_call) } else if (msg.content !== null && msg.content !== undefined) { return Msg.assistant(msg.content) + } else if (msg.refusal != null) { + return Msg.refusal(msg.refusal) + } else { // @TODO: probably don't want to error here console.log('Invalid message', msg) @@ -291,6 +329,10 @@ export namespace Msg { export function isAssistant(message: Msg): message is Msg.Assistant { return message.role === 'assistant' && message.content !== null } + /** Check if a message is an assistant refusal message. */ + export function isRefusal(message: Msg): message is Msg.Refusal { + return message.role === 'assistant' && message.refusal != null + } /** Check if a message is a function call message with arguments. 
*/ export function isFuncCall(message: Msg): message is Msg.FuncCall { return message.role === 'assistant' && message.function_call != null } @@ -312,6 +354,7 @@ export namespace Msg { export function narrow(message: Msg.System): Msg.System export function narrow(message: Msg.User): Msg.User export function narrow(message: Msg.Assistant): Msg.Assistant + export function narrow(message: Msg.Refusal): Msg.Refusal export function narrow(message: Msg.FuncCall): Msg.FuncCall export function narrow(message: Msg.FuncResult): Msg.FuncResult export function narrow(message: Msg.ToolCall): Msg.ToolCall @@ -322,6 +365,7 @@ export namespace Msg { | Msg.System | Msg.User | Msg.Assistant + | Msg.Refusal | Msg.FuncCall | Msg.FuncResult | Msg.ToolCall @@ -335,6 +379,9 @@ export namespace Msg { if (isAssistant(message)) { return message } + if (isRefusal(message)) { + return message + } if (isFuncCall(message)) { return message } diff --git a/legacy/packages/core/src/types.ts b/legacy/packages/core/src/types.ts index 8c532a8e..b25a6f82 100644 --- a/legacy/packages/core/src/types.ts +++ b/legacy/packages/core/src/types.ts @@ -3,7 +3,7 @@ import type { z } from 'zod' import type { AIFunctionSet } from './ai-function-set' import type { AIFunctionsProvider } from './fns' -import type { Msg } from './message' +import type { LegacyMsg, Msg } from './message' export type { Msg } from './message' export type { Schema } from './schema' @@ -130,6 +130,10 @@ export interface ChatParams { user?: string } +export type LegacyChatParams = Simplify< + Omit & { messages: LegacyMsg[] } +> + export interface ResponseFormatJSONSchema { /** * The name of the response format. Must be a-z, A-Z, 0-9, or contain @@ -158,10 +162,21 @@ export interface ResponseFormatJSONSchema { strict?: boolean } +/** + * OpenAI has changed some of their types, so instead of trying to support all + * possible types, for these params, just relax them for now. 
+ */ +export type RelaxedChatParams = Simplify< + Omit & { + messages: object[] + response_format?: { type: 'text' | 'json_object' | string } + } +> + /** An OpenAI-compatible chat completions API */ export type ChatFn = ( - params: Simplify> -) => Promise<{ message: Msg }> + params: Simplify> +) => Promise<{ message: Msg | LegacyMsg }> export type AIChainResult = string | Record