feat: relax chatFn types

pull/658/head
Travis Fischer 2024-08-07 05:53:39 -05:00
rodzic f4b79d69b5
commit f2169f1011
4 zmienionych plików z 77 dodań i 10 usunięć

Wyświetl plik

@ -16,6 +16,7 @@ async function main() {
})
const chain = createAIChain({
name: 'search_news',
chatFn: chatModel.run.bind(chatModel),
tools: [perigon.functions.pick('search_news_stories'), serper],
params: {

Wyświetl plik

@ -80,6 +80,8 @@ export function createAIChain<Result extends types.AIChainResult = string>({
const functionSet = new AIFunctionSet(tools)
const schema = rawSchema ? asSchema(rawSchema, { strict }) : undefined
// TODO: support custom stopping criteria (like setting a flag in a tool call)
const defaultParams: Partial<types.ChatParams> | undefined =
schema && !functionSet.size
? {
@ -190,10 +192,10 @@ export function createAIChain<Result extends types.AIChainResult = string>({
throw new AbortError(
'Function calls are not supported; expected tool call'
)
} else if (Msg.isRefusal(message)) {
throw new AbortError(`Model refusal: ${message.refusal}`)
} else if (Msg.isAssistant(message)) {
if (message.refusal) {
throw new AbortError(`Model refusal: ${message.refusal}`)
} else if (schema && schema.validate) {
if (schema && schema.validate) {
const result = schema.validate(message.content)
if (result.success) {
@ -212,6 +214,8 @@ export function createAIChain<Result extends types.AIChainResult = string>({
throw err
}
console.warn(`Chain "${name}" error:`, err.message)
messages.push(
Msg.user(
`There was an error validating the response. Please check the error message and try again.\nError:\n${getErrorMessage(err)}`

Wyświetl plik

@ -50,6 +50,15 @@ export interface Msg {
name?: string
}
export interface LegacyMsg {
content: string | null
role: Msg.Role
function_call?: Msg.Call.Function
tool_calls?: Msg.Call.Tool[]
tool_call_id?: string
name?: string
}
/** Narrowed OpenAI Message types. */
export namespace Msg {
/** Possible roles for a message. */
@ -102,8 +111,14 @@ export namespace Msg {
export type Assistant = {
role: 'assistant'
name?: string
content?: string
refusal?: string
content: string
}
/** Message with refusal reason from the assistant. */
export type Refusal = {
role: 'assistant'
name?: string
refusal: string
}
/** Message with arguments to call a function. */
@ -193,6 +208,27 @@ export namespace Msg {
}
}
/**
* Create an assistant refusal message. Cleans indentation and newlines by
* default.
*/
export function refusal(
refusal: string,
opts?: {
/** Custom name for the message. */
name?: string
/** Whether to clean extra newlines and indentation. Defaults to true. */
cleanRefusal?: boolean
}
): Msg.Refusal {
const { name, cleanRefusal = true } = opts ?? {}
return {
role: 'assistant',
refusal: cleanRefusal ? cleanStringForModel(refusal) : refusal,
...(name ? { name } : {})
}
}
/** Create a function call message with argumets. */
export function funcCall(
function_call: {
@ -257,7 +293,7 @@ export namespace Msg {
// @TODO
response: any
// response: ChatModel.EnrichedResponse
): Msg.Assistant | Msg.FuncCall | Msg.ToolCall {
): Msg.Assistant | Msg.Refusal | Msg.FuncCall | Msg.ToolCall {
const msg = response.choices[0].message as Msg
return narrowResponseMessage(msg)
}
@ -265,13 +301,15 @@ export namespace Msg {
/** Narrow a message received from the API. It only responds with role=assistant */
export function narrowResponseMessage(
msg: Msg
): Msg.Assistant | Msg.FuncCall | Msg.ToolCall {
): Msg.Assistant | Msg.Refusal | Msg.FuncCall | Msg.ToolCall {
if (msg.content === null && msg.tool_calls != null) {
return Msg.toolCall(msg.tool_calls)
} else if (msg.content === null && msg.function_call != null) {
return Msg.funcCall(msg.function_call)
} else if (msg.content !== null && msg.content !== undefined) {
return Msg.assistant(msg.content)
} else if (msg.refusal != null) {
return Msg.refusal(msg.refusal)
} else {
// @TODO: probably don't want to error here
console.log('Invalid message', msg)
@ -291,6 +329,10 @@ export namespace Msg {
export function isAssistant(message: Msg): message is Msg.Assistant {
return message.role === 'assistant' && message.content !== null
}
/** Check if a message is an assistant refusal message. */
export function isRefusal(message: Msg): message is Msg.Refusal {
return message.role === 'assistant' && message.refusal !== null
}
/** Check if a message is a function call message with arguments. */
export function isFuncCall(message: Msg): message is Msg.FuncCall {
return message.role === 'assistant' && message.function_call != null
@ -312,6 +354,7 @@ export namespace Msg {
export function narrow(message: Msg.System): Msg.System
export function narrow(message: Msg.User): Msg.User
export function narrow(message: Msg.Assistant): Msg.Assistant
export function narrow(message: Msg.Assistant): Msg.Refusal
export function narrow(message: Msg.FuncCall): Msg.FuncCall
export function narrow(message: Msg.FuncResult): Msg.FuncResult
export function narrow(message: Msg.ToolCall): Msg.ToolCall
@ -322,6 +365,7 @@ export namespace Msg {
| Msg.System
| Msg.User
| Msg.Assistant
| Msg.Refusal
| Msg.FuncCall
| Msg.FuncResult
| Msg.ToolCall
@ -335,6 +379,9 @@ export namespace Msg {
if (isAssistant(message)) {
return message
}
if (isRefusal(message)) {
return message
}
if (isFuncCall(message)) {
return message
}

Wyświetl plik

@ -3,7 +3,7 @@ import type { z } from 'zod'
import type { AIFunctionSet } from './ai-function-set'
import type { AIFunctionsProvider } from './fns'
import type { Msg } from './message'
import type { LegacyMsg, Msg } from './message'
export type { Msg } from './message'
export type { Schema } from './schema'
@ -130,6 +130,10 @@ export interface ChatParams {
user?: string
}
export type LegacyChatParams = Simplify<
Omit<ChatParams, 'messages'> & { messages: LegacyMsg[] }
>
export interface ResponseFormatJSONSchema {
/**
* The name of the response format. Must be a-z, A-Z, 0-9, or contain
@ -158,10 +162,21 @@ export interface ResponseFormatJSONSchema {
strict?: boolean
}
/**
* OpenAI has changed some of their types, so instead of trying to support all
* possible types, for these params, just relax them for now.
*/
export type RelaxedChatParams = Simplify<
Omit<ChatParams, 'messages' | 'response_format'> & {
messages: object[]
response_format?: { type: 'text' | 'json_object' | string }
}
>
/** An OpenAI-compatible chat completions API */
export type ChatFn = (
params: Simplify<SetOptional<ChatParams, 'model'>>
) => Promise<{ message: Msg }>
params: Simplify<SetOptional<RelaxedChatParams, 'model'>>
) => Promise<{ message: Msg | LegacyMsg }>
export type AIChainResult = string | Record<string, any>