diff --git a/legacy/src/task.ts b/legacy/src/task.ts index d3d8ce73..0054b6c4 100644 --- a/legacy/src/task.ts +++ b/legacy/src/task.ts @@ -29,10 +29,8 @@ export abstract class BaseTask< protected _timeoutMs?: number protected _retryConfig: types.RetryConfig - private _preHooks: Array<(ctx: types.TaskCallContext) => void> = [] - private _postHooks: Array< - (output: TOutput, ctx: types.TaskCallContext) => void - > = [] + private _preHooks: Array> = [] + private _postHooks: Array> = [] constructor(options: types.BaseTaskOptions = {}) { this._agentic = options.agentic ?? globalThis.__agentic?.deref() @@ -79,16 +77,12 @@ export abstract class BaseTask< return '' } - public addBeforeCallHook( - hook: (ctx: types.TaskCallContext) => Promise - ): this { + public addBeforeCallHook(hook: types.BeforeCallHook): this { this._preHooks.push(hook) return this } - public addAfterCallHook( - hook: (output: TOutput, ctx: types.TaskCallContext) => Promise - ): this { + public addAfterCallHook(hook: types.AfterCallHook): this { this._postHooks.push(hook) return this } diff --git a/legacy/src/types.ts b/legacy/src/types.ts index 181bee20..f2541004 100644 --- a/legacy/src/types.ts +++ b/legacy/src/types.ts @@ -152,3 +152,12 @@ export declare class CancelablePromise extends Promise { } // export type ProgressFunction = (partialResponse: ChatMessage) => void + +export type BeforeCallHook = ( + ctx: TaskCallContext +) => void | Promise + +export type AfterCallHook< + TInput extends TaskInput = void, + TOutput extends TaskOutput = string +> = (output: TOutput, ctx: TaskCallContext) => void | Promise