feat: add task caching

old-agentic-v1^2
Travis Fischer 2023-06-28 17:22:52 -07:00
parent c634b3222b
commit 034ea7a6c2
3 changed files with 91 additions and 4 deletions

View file

@@ -1,4 +1,5 @@
import pRetry, { FailedAttemptError } from 'p-retry'
import QuickLRU from 'quick-lru'
import { ZodType } from 'zod'
import * as errors from './errors'
@@ -34,12 +35,14 @@ export abstract class BaseTask<
protected _timeoutMs?: number
protected _retryConfig: types.RetryConfig
protected _cacheConfig: types.CacheConfig<TInput, TOutput>
private _preHooks: Array<{
protected _preHooks: Array<{
hook: types.TaskBeforeCallHook<TInput>
priority: number
}> = []
private _postHooks: Array<{
protected _postHooks: Array<{
hook: types.TaskAfterCallHook<TInput, TOutput>
priority: number
}> = []
@@ -53,6 +56,19 @@
strategy: 'default'
}
this._cacheConfig = {
cacheStrategy: 'default',
cacheKey: (input: TInput) => JSON.stringify(input),
...options.cacheConfig
}
if (
this._cacheConfig.cacheStrategy === 'default' &&
!this._cacheConfig.cache
) {
this._cacheConfig.cache = new QuickLRU<string, TOutput>({ maxSize: 1000 })
}
this._id =
options.id ?? this._agentic?.idGeneratorFn() ?? defaultIDGeneratorFn()
}
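
For context (not part of this commit): the block above gives every task a default cache config, a JSON.stringify cache key plus a 1000-entry QuickLRU, unless the caller supplies one. A sketch of overriding it at construction time, assuming CalculatorTool (from the test file below) forwards BaseTaskOptions.cacheConfig to BaseTask; the whitespace-normalizing cacheKey is a made-up example:

import QuickLRU from 'quick-lru'

const agentic = createTestAgenticRuntime()
const tool = new CalculatorTool({
  agentic,
  cacheConfig: {
    cacheStrategy: 'default',
    // Normalize whitespace so '2 * 3' and '2*3' map to the same cache entry
    // (the default key is JSON.stringify of the raw input).
    cacheKey: (input) =>
      JSON.stringify({ expression: input.expression.replace(/\s+/g, '') }),
    // Replace the default 1000-entry QuickLRU with a smaller one.
    cache: new QuickLRU<string, number>({ maxSize: 100 })
  }
})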
@@ -115,7 +131,7 @@ export abstract class BaseTask<
priority = 0
): this {
this._postHooks.push({ hook, priority })
this._postHooks.sort((a, b) => b.priority - a.priority) // two elements that compare equal will remain in their original order (>= ECMAScript 2019)
this._postHooks.sort((a, b) => b.priority - a.priority)
return this
}
@@ -142,6 +158,10 @@
throw new Error(`clone not implemented for task "${this.nameForModel}"`)
}
/**
* Adds an after call hook to confirm or refine the output of this task with
* human feedback.
*/
public withHumanFeedback<V extends HumanFeedbackType>(
options: HumanFeedbackOptions<V, TOutput> = {}
): this {
@@ -194,6 +214,11 @@
return this
}
public cacheConfig(cacheConfig: types.CacheConfig<TInput, TOutput>): this {
this._cacheConfig = cacheConfig
return this
}
/**
* Calls this task with the given `input` and returns the result only.
*/
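
For illustration only (not part of the diff): the fluent cacheConfig() setter added above replaces the task's entire cache config rather than merging with the constructor defaults, so a cacheKey has to be supplied again when swapping the backend. A sketch, assuming a constructed task instance `tool` like the one in the test file below:

import QuickLRU from 'quick-lru'

// Disable caching entirely for this instance; the call path only touches the
// cache when `_cacheConfig.cache` is truthy.
tool.cacheConfig({ cacheStrategy: 'none', cache: false })

// Swap in a smaller backend; cacheKey is passed explicitly because the setter
// does not re-apply the constructor's JSON.stringify default.
tool.cacheConfig({
  cacheKey: (input) => JSON.stringify(input),
  cache: new QuickLRU<string, number>({ maxSize: 10 })
})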
@@ -227,6 +252,13 @@
input = safeInput.data
}
const maybeCacheKey = this._cacheConfig.cache
? this._cacheConfig.cacheKey?.(input)
: undefined
const cacheKey = maybeCacheKey
? await Promise.resolve(maybeCacheKey)
: undefined
const ctx: types.TaskCallContext<TInput> = {
input,
attemptNumber: 0,
@@ -235,7 +267,8 @@
taskId: this.id,
callId: this._agentic!.idGeneratorFn(),
parentTaskId: parentCtx?.metadata.taskId,
parentCallId: parentCtx?.metadata.callId
parentCallId: parentCtx?.metadata.callId,
cacheStatus: 'miss'
}
}
@@ -260,6 +293,22 @@
}
}
if (cacheKey && this._cacheConfig.cache) {
const cachedValue = await Promise.resolve(
this._cacheConfig.cache.get(cacheKey)
)
if (cachedValue) {
ctx.metadata.success = true
ctx.metadata.cacheStatus = 'hit'
return {
result: cachedValue,
metadata: ctx.metadata
}
}
}
const result = await pRetry(
async () => {
let result = await this._call(ctx)
@@ -326,6 +375,10 @@
ctx.metadata.numRetries = ctx.attemptNumber
ctx.metadata.error = undefined
if (cacheKey && this._cacheConfig.cache) {
await Promise.resolve(this._cacheConfig.cache.set(cacheKey, result))
}
// ctx.tracker.setOutput(stringifyForDebugging(result, { maxLength: 100 }))
return {
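
Condensed, for readers skimming this file's hunks: the change computes a cache key before the call, returns early on a hit, and writes the successful result back after the pRetry block. A self-contained sketch of that flow with simplified names (not the library's exact code):

import QuickLRU from 'quick-lru'

async function callWithCache<TInput, TOutput>(
  input: TInput,
  run: (input: TInput) => Promise<TOutput>,
  cache = new QuickLRU<string, TOutput>({ maxSize: 1000 }),
  cacheKey: (input: TInput) => string = (i) => JSON.stringify(i)
): Promise<{ result: TOutput; cacheStatus: 'hit' | 'miss' }> {
  const key = cacheKey(input)

  // 1. Return early on a cache hit and report it in the metadata.
  const cached = cache.get(key)
  if (cached !== undefined) {
    return { result: cached, cacheStatus: 'hit' }
  }

  // 2. On a miss, run the task (pRetry wraps this step in the real code)...
  const result = await run(input)

  // 3. ...and store the successful result before returning it.
  cache.set(key, result)
  return { result, cacheStatus: 'miss' }
}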

View file

@@ -1,6 +1,7 @@
import * as anthropic from '@anthropic-ai/sdk'
import * as openai from 'openai-fetch'
import ky from 'ky'
import type { CacheStorage } from 'p-memoize'
import type { Options as RetryOptions } from 'p-retry'
import type { JsonObject, Jsonifiable } from 'type-fest'
import { SafeParseReturnType, ZodType, ZodTypeAny, output, z } from 'zod'
@@ -39,6 +40,7 @@ export interface BaseTaskOptions {
timeoutMs?: number
retryConfig?: RetryConfig
cacheConfig?: CacheConfig<any, any>
id?: string
// TODO
@@ -55,6 +57,8 @@
inputSchema?: ZodType<TInput>
outputSchema?: ZodType<TOutput>
cacheConfig?: CacheConfig<TInput, TOutput>
provider?: string
model?: string
modelParams?: TModelParams
@@ -120,6 +124,20 @@ export interface RetryConfig extends RetryOptions {
strategy?: string
}
export type MaybePromise<T> = T | Promise<T>
export type CacheStatus = 'miss' | 'hit'
export type CacheStrategy = 'default' | 'none'
export interface CacheConfig<
TInput extends TaskInput,
TOutput extends TaskOutput,
TCacheKey = any
> {
cacheStrategy?: CacheStrategy
cacheKey?: (input: TInput) => MaybePromise<TCacheKey | undefined>
cache?: CacheStorage<TCacheKey, TOutput> | false
}
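
A note on the cache slot: anything matching p-memoize's CacheStorage shape can back it; QuickLRU is just the default chosen in the task constructor. A minimal sketch of plugging in a different backend; the Map-based store and the expression-based cacheKey are illustrative assumptions, not part of this commit:

import type { CacheStorage } from 'p-memoize'

// A plain Map already provides the has/get/set/delete methods that
// CacheStorage describes, so it works as an unbounded in-process cache;
// use QuickLRU (the default) or a Keyv-style adapter when eviction or
// persistence matters.
const mapCache: CacheStorage<string, number> = new Map<string, number>()

// Used as the `cache` slot of the CacheConfig interface defined above,
// with a simple string cacheKey instead of the JSON.stringify default.
const cacheConfig = {
  cacheKey: (input: { expression: string }) => input.expression,
  cache: mapCache
}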
export interface TaskResponseMetadata extends Record<string, any> {
// task info
taskName: string
@@ -132,6 +150,7 @@ export interface TaskResponseMetadata extends Record<string, any> {
numRetries?: number
callId?: string
parentCallId?: string
cacheStatus?: CacheStatus
// human feedback info
feedback?: FeedbackTypeToMetadata<HumanFeedbackType>

View file

@@ -24,7 +24,22 @@ test('CalculatorTool', async (t) => {
t.deepEqual(metadata, {
success: true,
taskName: 'calculator',
cacheStatus: 'miss',
numRetries: 0,
error: undefined
})
})
test.only('CalculatorTool - caching', async (t) => {
const agentic = createTestAgenticRuntime()
const tool = new CalculatorTool({ agentic })
const res = await tool.callWithMetadata({ expression: '2 * 3' })
t.is(res.result, 6)
t.is(res.metadata.cacheStatus, 'miss')
expectTypeOf(res.result).toMatchTypeOf<number>()
const res2 = await tool.callWithMetadata({ expression: '2 * 3' })
t.is(res2.result, 6)
t.is(res2.metadata.cacheStatus, 'hit')
})