kopia lustrzana https://github.com/transitive-bullshit/chatgpt-api
feat: relax chatFn types
rodzic
965d235f8c
commit
eae1356112
|
@ -16,6 +16,7 @@ async function main() {
|
||||||
})
|
})
|
||||||
|
|
||||||
const chain = createAIChain({
|
const chain = createAIChain({
|
||||||
|
name: 'search_news',
|
||||||
chatFn: chatModel.run.bind(chatModel),
|
chatFn: chatModel.run.bind(chatModel),
|
||||||
tools: [perigon.functions.pick('search_news_stories'), serper],
|
tools: [perigon.functions.pick('search_news_stories'), serper],
|
||||||
params: {
|
params: {
|
||||||
|
|
|
@ -80,6 +80,8 @@ export function createAIChain<Result extends types.AIChainResult = string>({
|
||||||
const functionSet = new AIFunctionSet(tools)
|
const functionSet = new AIFunctionSet(tools)
|
||||||
const schema = rawSchema ? asSchema(rawSchema, { strict }) : undefined
|
const schema = rawSchema ? asSchema(rawSchema, { strict }) : undefined
|
||||||
|
|
||||||
|
// TODO: support custom stopping criteria (like setting a flag in a tool call)
|
||||||
|
|
||||||
const defaultParams: Partial<types.ChatParams> | undefined =
|
const defaultParams: Partial<types.ChatParams> | undefined =
|
||||||
schema && !functionSet.size
|
schema && !functionSet.size
|
||||||
? {
|
? {
|
||||||
|
@ -190,10 +192,10 @@ export function createAIChain<Result extends types.AIChainResult = string>({
|
||||||
throw new AbortError(
|
throw new AbortError(
|
||||||
'Function calls are not supported; expected tool call'
|
'Function calls are not supported; expected tool call'
|
||||||
)
|
)
|
||||||
} else if (Msg.isAssistant(message)) {
|
} else if (Msg.isRefusal(message)) {
|
||||||
if (message.refusal) {
|
|
||||||
throw new AbortError(`Model refusal: ${message.refusal}`)
|
throw new AbortError(`Model refusal: ${message.refusal}`)
|
||||||
} else if (schema && schema.validate) {
|
} else if (Msg.isAssistant(message)) {
|
||||||
|
if (schema && schema.validate) {
|
||||||
const result = schema.validate(message.content)
|
const result = schema.validate(message.content)
|
||||||
|
|
||||||
if (result.success) {
|
if (result.success) {
|
||||||
|
@ -212,6 +214,8 @@ export function createAIChain<Result extends types.AIChainResult = string>({
|
||||||
throw err
|
throw err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
console.warn(`Chain "${name}" error:`, err.message)
|
||||||
|
|
||||||
messages.push(
|
messages.push(
|
||||||
Msg.user(
|
Msg.user(
|
||||||
`There was an error validating the response. Please check the error message and try again.\nError:\n${getErrorMessage(err)}`
|
`There was an error validating the response. Please check the error message and try again.\nError:\n${getErrorMessage(err)}`
|
||||||
|
|
|
@ -50,6 +50,15 @@ export interface Msg {
|
||||||
name?: string
|
name?: string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export interface LegacyMsg {
|
||||||
|
content: string | null
|
||||||
|
role: Msg.Role
|
||||||
|
function_call?: Msg.Call.Function
|
||||||
|
tool_calls?: Msg.Call.Tool[]
|
||||||
|
tool_call_id?: string
|
||||||
|
name?: string
|
||||||
|
}
|
||||||
|
|
||||||
/** Narrowed OpenAI Message types. */
|
/** Narrowed OpenAI Message types. */
|
||||||
export namespace Msg {
|
export namespace Msg {
|
||||||
/** Possible roles for a message. */
|
/** Possible roles for a message. */
|
||||||
|
@ -102,8 +111,14 @@ export namespace Msg {
|
||||||
export type Assistant = {
|
export type Assistant = {
|
||||||
role: 'assistant'
|
role: 'assistant'
|
||||||
name?: string
|
name?: string
|
||||||
content?: string
|
content: string
|
||||||
refusal?: string
|
}
|
||||||
|
|
||||||
|
/** Message with refusal reason from the assistant. */
|
||||||
|
export type Refusal = {
|
||||||
|
role: 'assistant'
|
||||||
|
name?: string
|
||||||
|
refusal: string
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Message with arguments to call a function. */
|
/** Message with arguments to call a function. */
|
||||||
|
@ -193,6 +208,27 @@ export namespace Msg {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Create an assistant refusal message. Cleans indentation and newlines by
|
||||||
|
* default.
|
||||||
|
*/
|
||||||
|
export function refusal(
|
||||||
|
refusal: string,
|
||||||
|
opts?: {
|
||||||
|
/** Custom name for the message. */
|
||||||
|
name?: string
|
||||||
|
/** Whether to clean extra newlines and indentation. Defaults to true. */
|
||||||
|
cleanRefusal?: boolean
|
||||||
|
}
|
||||||
|
): Msg.Refusal {
|
||||||
|
const { name, cleanRefusal = true } = opts ?? {}
|
||||||
|
return {
|
||||||
|
role: 'assistant',
|
||||||
|
refusal: cleanRefusal ? cleanStringForModel(refusal) : refusal,
|
||||||
|
...(name ? { name } : {})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/** Create a function call message with argumets. */
|
/** Create a function call message with argumets. */
|
||||||
export function funcCall(
|
export function funcCall(
|
||||||
function_call: {
|
function_call: {
|
||||||
|
@ -257,7 +293,7 @@ export namespace Msg {
|
||||||
// @TODO
|
// @TODO
|
||||||
response: any
|
response: any
|
||||||
// response: ChatModel.EnrichedResponse
|
// response: ChatModel.EnrichedResponse
|
||||||
): Msg.Assistant | Msg.FuncCall | Msg.ToolCall {
|
): Msg.Assistant | Msg.Refusal | Msg.FuncCall | Msg.ToolCall {
|
||||||
const msg = response.choices[0].message as Msg
|
const msg = response.choices[0].message as Msg
|
||||||
return narrowResponseMessage(msg)
|
return narrowResponseMessage(msg)
|
||||||
}
|
}
|
||||||
|
@ -265,13 +301,15 @@ export namespace Msg {
|
||||||
/** Narrow a message received from the API. It only responds with role=assistant */
|
/** Narrow a message received from the API. It only responds with role=assistant */
|
||||||
export function narrowResponseMessage(
|
export function narrowResponseMessage(
|
||||||
msg: Msg
|
msg: Msg
|
||||||
): Msg.Assistant | Msg.FuncCall | Msg.ToolCall {
|
): Msg.Assistant | Msg.Refusal | Msg.FuncCall | Msg.ToolCall {
|
||||||
if (msg.content === null && msg.tool_calls != null) {
|
if (msg.content === null && msg.tool_calls != null) {
|
||||||
return Msg.toolCall(msg.tool_calls)
|
return Msg.toolCall(msg.tool_calls)
|
||||||
} else if (msg.content === null && msg.function_call != null) {
|
} else if (msg.content === null && msg.function_call != null) {
|
||||||
return Msg.funcCall(msg.function_call)
|
return Msg.funcCall(msg.function_call)
|
||||||
} else if (msg.content !== null && msg.content !== undefined) {
|
} else if (msg.content !== null && msg.content !== undefined) {
|
||||||
return Msg.assistant(msg.content)
|
return Msg.assistant(msg.content)
|
||||||
|
} else if (msg.refusal != null) {
|
||||||
|
return Msg.refusal(msg.refusal)
|
||||||
} else {
|
} else {
|
||||||
// @TODO: probably don't want to error here
|
// @TODO: probably don't want to error here
|
||||||
console.log('Invalid message', msg)
|
console.log('Invalid message', msg)
|
||||||
|
@ -291,6 +329,10 @@ export namespace Msg {
|
||||||
export function isAssistant(message: Msg): message is Msg.Assistant {
|
export function isAssistant(message: Msg): message is Msg.Assistant {
|
||||||
return message.role === 'assistant' && message.content !== null
|
return message.role === 'assistant' && message.content !== null
|
||||||
}
|
}
|
||||||
|
/** Check if a message is an assistant refusal message. */
|
||||||
|
export function isRefusal(message: Msg): message is Msg.Refusal {
|
||||||
|
return message.role === 'assistant' && message.refusal !== null
|
||||||
|
}
|
||||||
/** Check if a message is a function call message with arguments. */
|
/** Check if a message is a function call message with arguments. */
|
||||||
export function isFuncCall(message: Msg): message is Msg.FuncCall {
|
export function isFuncCall(message: Msg): message is Msg.FuncCall {
|
||||||
return message.role === 'assistant' && message.function_call != null
|
return message.role === 'assistant' && message.function_call != null
|
||||||
|
@ -312,6 +354,7 @@ export namespace Msg {
|
||||||
export function narrow(message: Msg.System): Msg.System
|
export function narrow(message: Msg.System): Msg.System
|
||||||
export function narrow(message: Msg.User): Msg.User
|
export function narrow(message: Msg.User): Msg.User
|
||||||
export function narrow(message: Msg.Assistant): Msg.Assistant
|
export function narrow(message: Msg.Assistant): Msg.Assistant
|
||||||
|
export function narrow(message: Msg.Assistant): Msg.Refusal
|
||||||
export function narrow(message: Msg.FuncCall): Msg.FuncCall
|
export function narrow(message: Msg.FuncCall): Msg.FuncCall
|
||||||
export function narrow(message: Msg.FuncResult): Msg.FuncResult
|
export function narrow(message: Msg.FuncResult): Msg.FuncResult
|
||||||
export function narrow(message: Msg.ToolCall): Msg.ToolCall
|
export function narrow(message: Msg.ToolCall): Msg.ToolCall
|
||||||
|
@ -322,6 +365,7 @@ export namespace Msg {
|
||||||
| Msg.System
|
| Msg.System
|
||||||
| Msg.User
|
| Msg.User
|
||||||
| Msg.Assistant
|
| Msg.Assistant
|
||||||
|
| Msg.Refusal
|
||||||
| Msg.FuncCall
|
| Msg.FuncCall
|
||||||
| Msg.FuncResult
|
| Msg.FuncResult
|
||||||
| Msg.ToolCall
|
| Msg.ToolCall
|
||||||
|
@ -335,6 +379,9 @@ export namespace Msg {
|
||||||
if (isAssistant(message)) {
|
if (isAssistant(message)) {
|
||||||
return message
|
return message
|
||||||
}
|
}
|
||||||
|
if (isRefusal(message)) {
|
||||||
|
return message
|
||||||
|
}
|
||||||
if (isFuncCall(message)) {
|
if (isFuncCall(message)) {
|
||||||
return message
|
return message
|
||||||
}
|
}
|
||||||
|
|
|
@ -3,7 +3,7 @@ import type { z } from 'zod'
|
||||||
|
|
||||||
import type { AIFunctionSet } from './ai-function-set'
|
import type { AIFunctionSet } from './ai-function-set'
|
||||||
import type { AIFunctionsProvider } from './fns'
|
import type { AIFunctionsProvider } from './fns'
|
||||||
import type { Msg } from './message'
|
import type { LegacyMsg, Msg } from './message'
|
||||||
|
|
||||||
export type { Msg } from './message'
|
export type { Msg } from './message'
|
||||||
export type { Schema } from './schema'
|
export type { Schema } from './schema'
|
||||||
|
@ -130,6 +130,10 @@ export interface ChatParams {
|
||||||
user?: string
|
user?: string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export type LegacyChatParams = Simplify<
|
||||||
|
Omit<ChatParams, 'messages'> & { messages: LegacyMsg[] }
|
||||||
|
>
|
||||||
|
|
||||||
export interface ResponseFormatJSONSchema {
|
export interface ResponseFormatJSONSchema {
|
||||||
/**
|
/**
|
||||||
* The name of the response format. Must be a-z, A-Z, 0-9, or contain
|
* The name of the response format. Must be a-z, A-Z, 0-9, or contain
|
||||||
|
@ -158,10 +162,21 @@ export interface ResponseFormatJSONSchema {
|
||||||
strict?: boolean
|
strict?: boolean
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* OpenAI has changed some of their types, so instead of trying to support all
|
||||||
|
* possible types, for these params, just relax them for now.
|
||||||
|
*/
|
||||||
|
export type RelaxedChatParams = Simplify<
|
||||||
|
Omit<ChatParams, 'messages' | 'response_format'> & {
|
||||||
|
messages: object[]
|
||||||
|
response_format?: { type: 'text' | 'json_object' | string }
|
||||||
|
}
|
||||||
|
>
|
||||||
|
|
||||||
/** An OpenAI-compatible chat completions API */
|
/** An OpenAI-compatible chat completions API */
|
||||||
export type ChatFn = (
|
export type ChatFn = (
|
||||||
params: Simplify<SetOptional<ChatParams, 'model'>>
|
params: Simplify<SetOptional<RelaxedChatParams, 'model'>>
|
||||||
) => Promise<{ message: Msg }>
|
) => Promise<{ message: Msg | LegacyMsg }>
|
||||||
|
|
||||||
export type AIChainResult = string | Record<string, any>
|
export type AIChainResult = string | Record<string, any>
|
||||||
|
|
||||||
|
|
Ładowanie…
Reference in New Issue