diff --git a/src/task.ts b/src/task.ts index ef3f3b2..a8eb3ee 100644 --- a/src/task.ts +++ b/src/task.ts @@ -1,4 +1,5 @@ import pRetry, { FailedAttemptError } from 'p-retry' +import QuickLRU from 'quick-lru' import { ZodType } from 'zod' import * as errors from './errors' @@ -34,12 +35,14 @@ export abstract class BaseTask< protected _timeoutMs?: number protected _retryConfig: types.RetryConfig + protected _cacheConfig: types.CacheConfig - private _preHooks: Array<{ + protected _preHooks: Array<{ hook: types.TaskBeforeCallHook priority: number }> = [] - private _postHooks: Array<{ + + protected _postHooks: Array<{ hook: types.TaskAfterCallHook priority: number }> = [] @@ -53,6 +56,19 @@ export abstract class BaseTask< strategy: 'default' } + this._cacheConfig = { + cacheStrategy: 'default', + cacheKey: (input: TInput) => JSON.stringify(input), + ...options.cacheConfig + } + + if ( + this._cacheConfig.cacheStrategy === 'default' && + !this._cacheConfig.cache + ) { + this._cacheConfig.cache = new QuickLRU({ maxSize: 1000 }) + } + this._id = options.id ?? this._agentic?.idGeneratorFn() ?? defaultIDGeneratorFn() } @@ -115,7 +131,7 @@ export abstract class BaseTask< priority = 0 ): this { this._postHooks.push({ hook, priority }) - this._postHooks.sort((a, b) => b.priority - a.priority) // two elements that compare equal will remain in their original order (>= ECMAScript 2019) + this._postHooks.sort((a, b) => b.priority - a.priority) return this } @@ -142,6 +158,10 @@ export abstract class BaseTask< throw new Error(`clone not implemented for task "${this.nameForModel}"`) } + /** + * Adds an after call hook to confirm or refine the output of this task with + * human feedback. + */ public withHumanFeedback( options: HumanFeedbackOptions = {} ): this { @@ -194,6 +214,11 @@ export abstract class BaseTask< return this } + public cacheConfig(cacheConfig: types.CacheConfig): this { + this._cacheConfig = cacheConfig + return this + } + /** * Calls this task with the given `input` and returns the result only. */ @@ -227,6 +252,13 @@ export abstract class BaseTask< input = safeInput.data } + const maybeCacheKey = this._cacheConfig.cache + ? this._cacheConfig.cacheKey?.(input) + : undefined + const cacheKey = maybeCacheKey + ? await Promise.resolve(maybeCacheKey) + : undefined + const ctx: types.TaskCallContext = { input, attemptNumber: 0, @@ -235,7 +267,8 @@ export abstract class BaseTask< taskId: this.id, callId: this._agentic!.idGeneratorFn(), parentTaskId: parentCtx?.metadata.taskId, - parentCallId: parentCtx?.metadata.callId + parentCallId: parentCtx?.metadata.callId, + cacheStatus: 'miss' } } @@ -260,6 +293,22 @@ export abstract class BaseTask< } } + if (cacheKey && this._cacheConfig.cache) { + const cachedValue = await Promise.resolve( + this._cacheConfig.cache.get(cacheKey) + ) + + if (cachedValue) { + ctx.metadata.success = true + ctx.metadata.cacheStatus = 'hit' + + return { + result: cachedValue, + metadata: ctx.metadata + } + } + } + const result = await pRetry( async () => { let result = await this._call(ctx) @@ -326,6 +375,10 @@ export abstract class BaseTask< ctx.metadata.numRetries = ctx.attemptNumber ctx.metadata.error = undefined + if (cacheKey && this._cacheConfig.cache) { + await Promise.resolve(this._cacheConfig.cache.set(cacheKey, result)) + } + // ctx.tracker.setOutput(stringifyForDebugging(result, { maxLength: 100 })) return { diff --git a/src/types.ts b/src/types.ts index e6bc995..23d53ce 100644 --- a/src/types.ts +++ b/src/types.ts @@ -1,6 +1,7 @@ import * as anthropic from '@anthropic-ai/sdk' import * as openai from 'openai-fetch' import ky from 'ky' +import type { CacheStorage } from 'p-memoize' import type { Options as RetryOptions } from 'p-retry' import type { JsonObject, Jsonifiable } from 'type-fest' import { SafeParseReturnType, ZodType, ZodTypeAny, output, z } from 'zod' @@ -39,6 +40,7 @@ export interface BaseTaskOptions { timeoutMs?: number retryConfig?: RetryConfig + cacheConfig?: CacheConfig id?: string // TODO @@ -55,6 +57,8 @@ export interface BaseLLMOptions< inputSchema?: ZodType outputSchema?: ZodType + cacheConfig?: CacheConfig + provider?: string model?: string modelParams?: TModelParams @@ -120,6 +124,20 @@ export interface RetryConfig extends RetryOptions { strategy?: string } +export type MaybePromise = T | Promise +export type CacheStatus = 'miss' | 'hit' +export type CacheStrategy = 'default' | 'none' + +export interface CacheConfig< + TInput extends TaskInput, + TOutput extends TaskOutput, + TCacheKey = any +> { + cacheStrategy?: CacheStrategy + cacheKey?: (input: TInput) => MaybePromise + cache?: CacheStorage | false +} + export interface TaskResponseMetadata extends Record { // task info taskName: string @@ -132,6 +150,7 @@ export interface TaskResponseMetadata extends Record { numRetries?: number callId?: string parentCallId?: string + cacheStatus?: CacheStatus // human feedback info feedback?: FeedbackTypeToMetadata diff --git a/test/calculator.test.ts b/test/calculator.test.ts index 0828f44..89638f6 100644 --- a/test/calculator.test.ts +++ b/test/calculator.test.ts @@ -24,7 +24,22 @@ test('CalculatorTool', async (t) => { t.deepEqual(metadata, { success: true, taskName: 'calculator', + cacheStatus: 'miss', numRetries: 0, error: undefined }) }) + +test.only('CalculatorTool - caching', async (t) => { + const agentic = createTestAgenticRuntime() + const tool = new CalculatorTool({ agentic }) + + const res = await tool.callWithMetadata({ expression: '2 * 3' }) + t.is(res.result, 6) + t.is(res.metadata.cacheStatus, 'miss') + expectTypeOf(res.result).toMatchTypeOf() + + const res2 = await tool.callWithMetadata({ expression: '2 * 3' }) + t.is(res2.result, 6) + t.is(res2.metadata.cacheStatus, 'hit') +})