kopia lustrzana https://github.com/transitive-bullshit/chatgpt-api
feat: add support for hook priorities and hook return values
rodzic
dbd8603a74
commit
603e37f49d
|
@ -1,3 +1,4 @@
|
||||||
export const DEFAULT_OPENAI_MODEL = 'gpt-3.5-turbo'
|
export const DEFAULT_OPENAI_MODEL = 'gpt-3.5-turbo'
|
||||||
export const DEFAULT_ANTHROPIC_MODEL = 'claude-instant-v1'
|
export const DEFAULT_ANTHROPIC_MODEL = 'claude-instant-v1'
|
||||||
export const DEFAULT_BOT_NAME = 'Agentic Bot'
|
export const DEFAULT_BOT_NAME = 'Agentic Bot'
|
||||||
|
export const SKIP_HOOKS = Symbol('SKIP_HOOKS')
|
||||||
|
|
|
@ -4,6 +4,7 @@ import { ZodType } from 'zod'
|
||||||
import * as errors from './errors'
|
import * as errors from './errors'
|
||||||
import * as types from './types'
|
import * as types from './types'
|
||||||
import type { Agentic } from './agentic'
|
import type { Agentic } from './agentic'
|
||||||
|
import { SKIP_HOOKS } from './constants'
|
||||||
import {
|
import {
|
||||||
HumanFeedbackMechanismCLI,
|
HumanFeedbackMechanismCLI,
|
||||||
HumanFeedbackOptions,
|
HumanFeedbackOptions,
|
||||||
|
@ -34,8 +35,16 @@ export abstract class BaseTask<
|
||||||
protected _timeoutMs?: number
|
protected _timeoutMs?: number
|
||||||
protected _retryConfig: types.RetryConfig
|
protected _retryConfig: types.RetryConfig
|
||||||
|
|
||||||
private _preHooks: Array<types.TaskBeforeCallHook<TInput>> = []
|
private _preHooks: Array<{
|
||||||
private _postHooks: Array<types.TaskAfterCallHook<TInput, TOutput>> = []
|
hook: types.TaskBeforeCallHook<TInput>
|
||||||
|
priority: number
|
||||||
|
name: string
|
||||||
|
}> = []
|
||||||
|
private _postHooks: Array<{
|
||||||
|
hook: types.TaskAfterCallHook<TInput, TOutput>
|
||||||
|
priority: number
|
||||||
|
name: string
|
||||||
|
}> = []
|
||||||
|
|
||||||
constructor(options: types.BaseTaskOptions = {}) {
|
constructor(options: types.BaseTaskOptions = {}) {
|
||||||
this._agentic = options.agentic ?? globalThis.__agentic?.deref()
|
this._agentic = options.agentic ?? globalThis.__agentic?.deref()
|
||||||
|
@ -82,15 +91,95 @@ export abstract class BaseTask<
|
||||||
return ''
|
return ''
|
||||||
}
|
}
|
||||||
|
|
||||||
public addBeforeCallHook(hook: types.TaskBeforeCallHook<TInput>): this {
|
/**
|
||||||
this._preHooks.push(hook)
|
* Adds a hook to be called before the task is invoked.
|
||||||
|
*
|
||||||
|
* @param hook - function to be called before the task is invoked
|
||||||
|
* @param options - options for the hook; `priority` is used to determine the order in which hooks are called, with higher priority hooks being called first, and `name` is used to identify the hook
|
||||||
|
*/
|
||||||
|
public addBeforeCallHook(
|
||||||
|
hook: types.TaskBeforeCallHook<TInput>,
|
||||||
|
{ priority = 0, name }: { priority?: number; name?: string } = {}
|
||||||
|
): this {
|
||||||
|
const hookName = name ?? `preHook_${this._preHooks.length}`
|
||||||
|
this._preHooks.push({ hook, priority, name: hookName })
|
||||||
|
this._preHooks.sort((a, b) => b.priority - a.priority)
|
||||||
return this
|
return this
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Adds a hook to be called after the task is invoked.
|
||||||
|
*
|
||||||
|
* @param hook - function to be called after the task is invoked
|
||||||
|
* @param options - options for the hook; `priority` is used to determine the order in which hooks are called, with higher priority hooks being called first, and `name` is used to identify the hook
|
||||||
|
*/
|
||||||
public addAfterCallHook(
|
public addAfterCallHook(
|
||||||
hook: types.TaskAfterCallHook<TInput, TOutput>
|
hook: types.TaskAfterCallHook<TInput, TOutput>,
|
||||||
|
{ priority = 0, name }: { priority?: number; name?: string } = {}
|
||||||
): this {
|
): this {
|
||||||
this._postHooks.push(hook)
|
const hookName = name ?? `postHook_${this._postHooks.length}`
|
||||||
|
this._postHooks.push({ hook, priority, name: hookName })
|
||||||
|
this._postHooks.sort((a, b) => b.priority - a.priority)
|
||||||
|
return this
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Changes the priority of a before call hook.
|
||||||
|
*
|
||||||
|
* @param hookType - `before`
|
||||||
|
* @param hookOrName - hook or the name of the hook to change the priority of
|
||||||
|
* @param newPriority - new priority of the hook
|
||||||
|
*/
|
||||||
|
public changeHookPriority(
|
||||||
|
hookType: 'before',
|
||||||
|
hookOrName: types.TaskBeforeCallHook<TInput> | string,
|
||||||
|
newPriority: number
|
||||||
|
): this
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Changes the priority of a after call hook.
|
||||||
|
*
|
||||||
|
* @param hookType - `after`
|
||||||
|
* @param hookOrName - hook or the name of the hook to change the priority of
|
||||||
|
* @param newPriority - new priority of the hook
|
||||||
|
*/
|
||||||
|
public changeHookPriority(
|
||||||
|
hookType: 'after',
|
||||||
|
hookOrName: types.TaskAfterCallHook<TInput, TOutput> | string,
|
||||||
|
newPriority: number
|
||||||
|
): this
|
||||||
|
|
||||||
|
public changeHookPriority(
|
||||||
|
hookType: 'before' | 'after',
|
||||||
|
hookOrName:
|
||||||
|
| types.TaskBeforeCallHook<TInput>
|
||||||
|
| types.TaskAfterCallHook<TInput, TOutput>
|
||||||
|
| string,
|
||||||
|
newPriority: number
|
||||||
|
): this {
|
||||||
|
const hooks = hookType === 'before' ? this._preHooks : this._postHooks
|
||||||
|
|
||||||
|
if (typeof hookOrName === 'string') {
|
||||||
|
const hookObj = hooks.find((h) => h.name === hookOrName)
|
||||||
|
if (!hookObj) {
|
||||||
|
throw new Error(
|
||||||
|
`Could not find a ${hookType}-call hook named "${hookOrName}" to change its priority`
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
hookObj.priority = newPriority
|
||||||
|
} else {
|
||||||
|
const hookObj = hooks.find((h) => h.hook === hookOrName)
|
||||||
|
if (!hookObj) {
|
||||||
|
throw new Error(
|
||||||
|
`Could not find the provided ${hookType}-call hook to change its priority`
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
hookObj.priority = newPriority
|
||||||
|
}
|
||||||
|
|
||||||
|
hooks.sort((a, b) => b.priority - a.priority)
|
||||||
return this
|
return this
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -140,10 +229,13 @@ export abstract class BaseTask<
|
||||||
options
|
options
|
||||||
})
|
})
|
||||||
|
|
||||||
this.addAfterCallHook(async (output, ctx) => {
|
this.addAfterCallHook(
|
||||||
const feedback = await feedbackMechanism.interact(output)
|
async (output, ctx) => {
|
||||||
ctx.metadata = { ...ctx.metadata, feedback }
|
const feedback = await feedbackMechanism.interact(output)
|
||||||
})
|
ctx.metadata = { ...ctx.metadata, feedback }
|
||||||
|
},
|
||||||
|
{ name: 'humanFeedback' }
|
||||||
|
)
|
||||||
|
|
||||||
return this
|
return this
|
||||||
}
|
}
|
||||||
|
@ -194,16 +286,32 @@ export abstract class BaseTask<
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for (const preHook of this._preHooks) {
|
for (const { hook: preHook } of this._preHooks) {
|
||||||
await preHook(ctx)
|
const preHookResult = await preHook(ctx)
|
||||||
|
if (preHookResult === SKIP_HOOKS) {
|
||||||
|
break
|
||||||
|
} else if (preHookResult !== undefined) {
|
||||||
|
const output = this.outputSchema?.safeParse(preHookResult)
|
||||||
|
if (!output?.success) {
|
||||||
|
throw new Error(`Invalid preHook output: ${output?.error.message}`)
|
||||||
|
}
|
||||||
|
|
||||||
|
return {
|
||||||
|
result: output.data,
|
||||||
|
metadata: ctx.metadata
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
const result = await pRetry(
|
const result = await pRetry(
|
||||||
async () => {
|
async () => {
|
||||||
const result = await this._call(ctx)
|
const result = await this._call(ctx)
|
||||||
|
|
||||||
for (const postHook of this._postHooks) {
|
for (const { hook: postHook } of this._postHooks) {
|
||||||
await postHook(result, ctx)
|
const postHookResult = await postHook(result, ctx)
|
||||||
|
if (postHookResult === SKIP_HOOKS) {
|
||||||
|
break
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
|
@ -6,6 +6,7 @@ import type { JsonObject, Jsonifiable } from 'type-fest'
|
||||||
import { SafeParseReturnType, ZodType, ZodTypeAny, output, z } from 'zod'
|
import { SafeParseReturnType, ZodType, ZodTypeAny, output, z } from 'zod'
|
||||||
|
|
||||||
import type { Agentic } from './agentic'
|
import type { Agentic } from './agentic'
|
||||||
|
import { SKIP_HOOKS } from './constants'
|
||||||
import type {
|
import type {
|
||||||
FeedbackTypeToMetadata,
|
FeedbackTypeToMetadata,
|
||||||
HumanFeedbackType
|
HumanFeedbackType
|
||||||
|
@ -155,11 +156,21 @@ export declare class CancelablePromise<T> extends Promise<T> {
|
||||||
|
|
||||||
// export type ProgressFunction = (partialResponse: ChatMessage) => void
|
// export type ProgressFunction = (partialResponse: ChatMessage) => void
|
||||||
|
|
||||||
export type TaskBeforeCallHook<TInput extends TaskInput = void> = (
|
export type TaskBeforeCallHook<
|
||||||
|
TInput extends TaskInput = void,
|
||||||
|
TOutput extends TaskOutput = string
|
||||||
|
> = (
|
||||||
ctx: TaskCallContext<TInput>
|
ctx: TaskCallContext<TInput>
|
||||||
) => void | Promise<void>
|
) =>
|
||||||
|
| void
|
||||||
|
| TOutput
|
||||||
|
| typeof SKIP_HOOKS
|
||||||
|
| Promise<void | TOutput | typeof SKIP_HOOKS>
|
||||||
|
|
||||||
export type TaskAfterCallHook<
|
export type TaskAfterCallHook<
|
||||||
TInput extends TaskInput = void,
|
TInput extends TaskInput = void,
|
||||||
TOutput extends TaskOutput = string
|
TOutput extends TaskOutput = string
|
||||||
> = (output: TOutput, ctx: TaskCallContext<TInput>) => void | Promise<void>
|
> = (
|
||||||
|
output: TOutput,
|
||||||
|
ctx: TaskCallContext<TInput>
|
||||||
|
) => void | typeof SKIP_HOOKS | Promise<void | typeof SKIP_HOOKS>
|
||||||
|
|
Ładowanie…
Reference in New Issue