kopia lustrzana https://github.com/transitive-bullshit/chatgpt-api
feat: relax chatFn types
rodzic
f4b79d69b5
commit
f2169f1011
|
@ -16,6 +16,7 @@ async function main() {
|
|||
})
|
||||
|
||||
const chain = createAIChain({
|
||||
name: 'search_news',
|
||||
chatFn: chatModel.run.bind(chatModel),
|
||||
tools: [perigon.functions.pick('search_news_stories'), serper],
|
||||
params: {
|
||||
|
|
|
@ -80,6 +80,8 @@ export function createAIChain<Result extends types.AIChainResult = string>({
|
|||
const functionSet = new AIFunctionSet(tools)
|
||||
const schema = rawSchema ? asSchema(rawSchema, { strict }) : undefined
|
||||
|
||||
// TODO: support custom stopping criteria (like setting a flag in a tool call)
|
||||
|
||||
const defaultParams: Partial<types.ChatParams> | undefined =
|
||||
schema && !functionSet.size
|
||||
? {
|
||||
|
@ -190,10 +192,10 @@ export function createAIChain<Result extends types.AIChainResult = string>({
|
|||
throw new AbortError(
|
||||
'Function calls are not supported; expected tool call'
|
||||
)
|
||||
} else if (Msg.isRefusal(message)) {
|
||||
throw new AbortError(`Model refusal: ${message.refusal}`)
|
||||
} else if (Msg.isAssistant(message)) {
|
||||
if (message.refusal) {
|
||||
throw new AbortError(`Model refusal: ${message.refusal}`)
|
||||
} else if (schema && schema.validate) {
|
||||
if (schema && schema.validate) {
|
||||
const result = schema.validate(message.content)
|
||||
|
||||
if (result.success) {
|
||||
|
@ -212,6 +214,8 @@ export function createAIChain<Result extends types.AIChainResult = string>({
|
|||
throw err
|
||||
}
|
||||
|
||||
console.warn(`Chain "${name}" error:`, err.message)
|
||||
|
||||
messages.push(
|
||||
Msg.user(
|
||||
`There was an error validating the response. Please check the error message and try again.\nError:\n${getErrorMessage(err)}`
|
||||
|
|
|
@ -50,6 +50,15 @@ export interface Msg {
|
|||
name?: string
|
||||
}
|
||||
|
||||
export interface LegacyMsg {
|
||||
content: string | null
|
||||
role: Msg.Role
|
||||
function_call?: Msg.Call.Function
|
||||
tool_calls?: Msg.Call.Tool[]
|
||||
tool_call_id?: string
|
||||
name?: string
|
||||
}
|
||||
|
||||
/** Narrowed OpenAI Message types. */
|
||||
export namespace Msg {
|
||||
/** Possible roles for a message. */
|
||||
|
@ -102,8 +111,14 @@ export namespace Msg {
|
|||
export type Assistant = {
|
||||
role: 'assistant'
|
||||
name?: string
|
||||
content?: string
|
||||
refusal?: string
|
||||
content: string
|
||||
}
|
||||
|
||||
/** Message with refusal reason from the assistant. */
|
||||
export type Refusal = {
|
||||
role: 'assistant'
|
||||
name?: string
|
||||
refusal: string
|
||||
}
|
||||
|
||||
/** Message with arguments to call a function. */
|
||||
|
@ -193,6 +208,27 @@ export namespace Msg {
|
|||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Create an assistant refusal message. Cleans indentation and newlines by
|
||||
* default.
|
||||
*/
|
||||
export function refusal(
|
||||
refusal: string,
|
||||
opts?: {
|
||||
/** Custom name for the message. */
|
||||
name?: string
|
||||
/** Whether to clean extra newlines and indentation. Defaults to true. */
|
||||
cleanRefusal?: boolean
|
||||
}
|
||||
): Msg.Refusal {
|
||||
const { name, cleanRefusal = true } = opts ?? {}
|
||||
return {
|
||||
role: 'assistant',
|
||||
refusal: cleanRefusal ? cleanStringForModel(refusal) : refusal,
|
||||
...(name ? { name } : {})
|
||||
}
|
||||
}
|
||||
|
||||
/** Create a function call message with argumets. */
|
||||
export function funcCall(
|
||||
function_call: {
|
||||
|
@ -257,7 +293,7 @@ export namespace Msg {
|
|||
// @TODO
|
||||
response: any
|
||||
// response: ChatModel.EnrichedResponse
|
||||
): Msg.Assistant | Msg.FuncCall | Msg.ToolCall {
|
||||
): Msg.Assistant | Msg.Refusal | Msg.FuncCall | Msg.ToolCall {
|
||||
const msg = response.choices[0].message as Msg
|
||||
return narrowResponseMessage(msg)
|
||||
}
|
||||
|
@ -265,13 +301,15 @@ export namespace Msg {
|
|||
/** Narrow a message received from the API. It only responds with role=assistant */
|
||||
export function narrowResponseMessage(
|
||||
msg: Msg
|
||||
): Msg.Assistant | Msg.FuncCall | Msg.ToolCall {
|
||||
): Msg.Assistant | Msg.Refusal | Msg.FuncCall | Msg.ToolCall {
|
||||
if (msg.content === null && msg.tool_calls != null) {
|
||||
return Msg.toolCall(msg.tool_calls)
|
||||
} else if (msg.content === null && msg.function_call != null) {
|
||||
return Msg.funcCall(msg.function_call)
|
||||
} else if (msg.content !== null && msg.content !== undefined) {
|
||||
return Msg.assistant(msg.content)
|
||||
} else if (msg.refusal != null) {
|
||||
return Msg.refusal(msg.refusal)
|
||||
} else {
|
||||
// @TODO: probably don't want to error here
|
||||
console.log('Invalid message', msg)
|
||||
|
@ -291,6 +329,10 @@ export namespace Msg {
|
|||
export function isAssistant(message: Msg): message is Msg.Assistant {
|
||||
return message.role === 'assistant' && message.content !== null
|
||||
}
|
||||
/** Check if a message is an assistant refusal message. */
|
||||
export function isRefusal(message: Msg): message is Msg.Refusal {
|
||||
return message.role === 'assistant' && message.refusal !== null
|
||||
}
|
||||
/** Check if a message is a function call message with arguments. */
|
||||
export function isFuncCall(message: Msg): message is Msg.FuncCall {
|
||||
return message.role === 'assistant' && message.function_call != null
|
||||
|
@ -312,6 +354,7 @@ export namespace Msg {
|
|||
export function narrow(message: Msg.System): Msg.System
|
||||
export function narrow(message: Msg.User): Msg.User
|
||||
export function narrow(message: Msg.Assistant): Msg.Assistant
|
||||
export function narrow(message: Msg.Assistant): Msg.Refusal
|
||||
export function narrow(message: Msg.FuncCall): Msg.FuncCall
|
||||
export function narrow(message: Msg.FuncResult): Msg.FuncResult
|
||||
export function narrow(message: Msg.ToolCall): Msg.ToolCall
|
||||
|
@ -322,6 +365,7 @@ export namespace Msg {
|
|||
| Msg.System
|
||||
| Msg.User
|
||||
| Msg.Assistant
|
||||
| Msg.Refusal
|
||||
| Msg.FuncCall
|
||||
| Msg.FuncResult
|
||||
| Msg.ToolCall
|
||||
|
@ -335,6 +379,9 @@ export namespace Msg {
|
|||
if (isAssistant(message)) {
|
||||
return message
|
||||
}
|
||||
if (isRefusal(message)) {
|
||||
return message
|
||||
}
|
||||
if (isFuncCall(message)) {
|
||||
return message
|
||||
}
|
||||
|
|
|
@ -3,7 +3,7 @@ import type { z } from 'zod'
|
|||
|
||||
import type { AIFunctionSet } from './ai-function-set'
|
||||
import type { AIFunctionsProvider } from './fns'
|
||||
import type { Msg } from './message'
|
||||
import type { LegacyMsg, Msg } from './message'
|
||||
|
||||
export type { Msg } from './message'
|
||||
export type { Schema } from './schema'
|
||||
|
@ -130,6 +130,10 @@ export interface ChatParams {
|
|||
user?: string
|
||||
}
|
||||
|
||||
export type LegacyChatParams = Simplify<
|
||||
Omit<ChatParams, 'messages'> & { messages: LegacyMsg[] }
|
||||
>
|
||||
|
||||
export interface ResponseFormatJSONSchema {
|
||||
/**
|
||||
* The name of the response format. Must be a-z, A-Z, 0-9, or contain
|
||||
|
@ -158,10 +162,21 @@ export interface ResponseFormatJSONSchema {
|
|||
strict?: boolean
|
||||
}
|
||||
|
||||
/**
|
||||
* OpenAI has changed some of their types, so instead of trying to support all
|
||||
* possible types, for these params, just relax them for now.
|
||||
*/
|
||||
export type RelaxedChatParams = Simplify<
|
||||
Omit<ChatParams, 'messages' | 'response_format'> & {
|
||||
messages: object[]
|
||||
response_format?: { type: 'text' | 'json_object' | string }
|
||||
}
|
||||
>
|
||||
|
||||
/** An OpenAI-compatible chat completions API */
|
||||
export type ChatFn = (
|
||||
params: Simplify<SetOptional<ChatParams, 'model'>>
|
||||
) => Promise<{ message: Msg }>
|
||||
params: Simplify<SetOptional<RelaxedChatParams, 'model'>>
|
||||
) => Promise<{ message: Msg | LegacyMsg }>
|
||||
|
||||
export type AIChainResult = string | Record<string, any>
|
||||
|
||||
|
|
Ładowanie…
Reference in New Issue