kopia lustrzana https://github.com/transitive-bullshit/chatgpt-api
Merge pull request #26 from transitive-bullshit/feature/hook-priorities
WIP feat: add support for hook priorities and return valuesold-agentic-v1^2
commit
63f9536eb9
|
@ -1,3 +1,4 @@
|
|||
export const DEFAULT_OPENAI_MODEL = 'gpt-3.5-turbo'
|
||||
export const DEFAULT_ANTHROPIC_MODEL = 'claude-instant-v1'
|
||||
export const DEFAULT_BOT_NAME = 'Agentic Bot'
|
||||
export const SKIP_HOOKS = Symbol('SKIP_HOOKS')
|
||||
|
|
|
@ -4,6 +4,7 @@ import { ZodType } from 'zod'
|
|||
import * as errors from './errors'
|
||||
import * as types from './types'
|
||||
import type { Agentic } from './agentic'
|
||||
import { SKIP_HOOKS } from './constants'
|
||||
import {
|
||||
HumanFeedbackMechanismCLI,
|
||||
HumanFeedbackOptions,
|
||||
|
@ -34,8 +35,14 @@ export abstract class BaseTask<
|
|||
protected _timeoutMs?: number
|
||||
protected _retryConfig: types.RetryConfig
|
||||
|
||||
private _preHooks: Array<types.TaskBeforeCallHook<TInput>> = []
|
||||
private _postHooks: Array<types.TaskAfterCallHook<TInput, TOutput>> = []
|
||||
private _preHooks: Array<{
|
||||
hook: types.TaskBeforeCallHook<TInput>
|
||||
priority: number
|
||||
}> = []
|
||||
private _postHooks: Array<{
|
||||
hook: types.TaskAfterCallHook<TInput, TOutput>
|
||||
priority: number
|
||||
}> = []
|
||||
|
||||
constructor(options: types.BaseTaskOptions = {}) {
|
||||
this._agentic = options.agentic ?? globalThis.__agentic?.deref()
|
||||
|
@ -82,15 +89,33 @@ export abstract class BaseTask<
|
|||
return ''
|
||||
}
|
||||
|
||||
public addBeforeCallHook(hook: types.TaskBeforeCallHook<TInput>): this {
|
||||
this._preHooks.push(hook)
|
||||
/**
|
||||
* Adds a hook to be called before the task is invoked.
|
||||
*
|
||||
* @param hook - function to be called before the task is invoked
|
||||
* @param priority - priority of the hook; higher priority hooks are called first
|
||||
*/
|
||||
public addBeforeCallHook(
|
||||
hook: types.TaskBeforeCallHook<TInput>,
|
||||
priority = 0
|
||||
): this {
|
||||
this._preHooks.push({ hook, priority })
|
||||
this._preHooks.sort((a, b) => b.priority - a.priority) // two elements that compare equal will remain in their original order (>= ECMAScript 2019)
|
||||
return this
|
||||
}
|
||||
|
||||
/**
|
||||
* Adds a hook to be called after the task is invoked.
|
||||
*
|
||||
* @param hook - function to be called after the task is invoked
|
||||
* @param priority - priority of the hook; higher priority hooks are called first
|
||||
*/
|
||||
public addAfterCallHook(
|
||||
hook: types.TaskAfterCallHook<TInput, TOutput>
|
||||
hook: types.TaskAfterCallHook<TInput, TOutput>,
|
||||
priority = 0
|
||||
): this {
|
||||
this._postHooks.push(hook)
|
||||
this._postHooks.push({ hook, priority })
|
||||
this._postHooks.sort((a, b) => b.priority - a.priority) // two elements that compare equal will remain in their original order (>= ECMAScript 2019)
|
||||
return this
|
||||
}
|
||||
|
||||
|
@ -202,16 +227,36 @@ export abstract class BaseTask<
|
|||
}
|
||||
}
|
||||
|
||||
for (const preHook of this._preHooks) {
|
||||
await preHook(ctx)
|
||||
for (const { hook: preHook } of this._preHooks) {
|
||||
const preHookResult = await preHook(ctx)
|
||||
if (preHookResult === SKIP_HOOKS) {
|
||||
break
|
||||
} else if (preHookResult !== undefined) {
|
||||
const output = this.outputSchema?.safeParse(preHookResult)
|
||||
if (!output?.success) {
|
||||
throw new Error(`Invalid preHook output: ${output?.error.message}`)
|
||||
}
|
||||
|
||||
ctx.metadata.success = true
|
||||
ctx.metadata.numRetries = ctx.attemptNumber
|
||||
ctx.metadata.error = undefined
|
||||
|
||||
return {
|
||||
result: output.data,
|
||||
metadata: ctx.metadata
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const result = await pRetry(
|
||||
async () => {
|
||||
const result = await this._call(ctx)
|
||||
|
||||
for (const postHook of this._postHooks) {
|
||||
await postHook(result, ctx)
|
||||
for (const { hook: postHook } of this._postHooks) {
|
||||
const postHookResult = await postHook(result, ctx)
|
||||
if (postHookResult === SKIP_HOOKS) {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
|
|
|
@ -6,6 +6,7 @@ import type { JsonObject, Jsonifiable } from 'type-fest'
|
|||
import { SafeParseReturnType, ZodType, ZodTypeAny, output, z } from 'zod'
|
||||
|
||||
import type { Agentic } from './agentic'
|
||||
import { SKIP_HOOKS } from './constants'
|
||||
import type {
|
||||
FeedbackTypeToMetadata,
|
||||
HumanFeedbackType
|
||||
|
@ -173,11 +174,21 @@ export declare class CancelablePromise<T> extends Promise<T> {
|
|||
|
||||
// export type ProgressFunction = (partialResponse: ChatMessage) => void
|
||||
|
||||
export type TaskBeforeCallHook<TInput extends TaskInput = void> = (
|
||||
export type TaskBeforeCallHook<
|
||||
TInput extends TaskInput = void,
|
||||
TOutput extends TaskOutput = string
|
||||
> = (
|
||||
ctx: TaskCallContext<TInput>
|
||||
) => void | Promise<void>
|
||||
) =>
|
||||
| void
|
||||
| TOutput
|
||||
| typeof SKIP_HOOKS
|
||||
| Promise<void | TOutput | typeof SKIP_HOOKS>
|
||||
|
||||
export type TaskAfterCallHook<
|
||||
TInput extends TaskInput = void,
|
||||
TOutput extends TaskOutput = string
|
||||
> = (output: TOutput, ctx: TaskCallContext<TInput>) => void | Promise<void>
|
||||
> = (
|
||||
output: TOutput,
|
||||
ctx: TaskCallContext<TInput>
|
||||
) => void | typeof SKIP_HOOKS | Promise<void | typeof SKIP_HOOKS>
|
||||
|
|
Ładowanie…
Reference in New Issue