feat: add task caching

old-agentic-v1^2
Travis Fischer 2023-06-28 17:22:52 -07:00
parent c634b3222b
commit 034ea7a6c2
3 changed files with 91 additions and 4 deletions

View file

@@ -1,4 +1,5 @@
import pRetry, { FailedAttemptError } from 'p-retry'
import QuickLRU from 'quick-lru'
import { ZodType } from 'zod'
import * as errors from './errors'
@@ -34,12 +35,14 @@ export abstract class BaseTask<
protected _timeoutMs?: number
protected _retryConfig: types.RetryConfig
protected _cacheConfig: types.CacheConfig<TInput, TOutput>
private _preHooks: Array<{
protected _preHooks: Array<{
hook: types.TaskBeforeCallHook<TInput>
priority: number
}> = []
private _postHooks: Array<{
protected _postHooks: Array<{
hook: types.TaskAfterCallHook<TInput, TOutput>
priority: number
}> = []
@@ -53,6 +56,19 @@
strategy: 'default'
}
this._cacheConfig = {
cacheStrategy: 'default',
cacheKey: (input: TInput) => JSON.stringify(input),
...options.cacheConfig
}
if (
this._cacheConfig.cacheStrategy === 'default' &&
!this._cacheConfig.cache
) {
this._cacheConfig.cache = new QuickLRU<string, TOutput>({ maxSize: 1000 })
}
this._id =
options.id ?? this._agentic?.idGeneratorFn() ?? defaultIDGeneratorFn()
}
@@ -115,7 +131,7 @@ export abstract class BaseTask<
priority = 0
): this {
this._postHooks.push({ hook, priority })
this._postHooks.sort((a, b) => b.priority - a.priority) // two elements that compare equal will remain in their original order (>= ECMAScript 2019)
this._postHooks.sort((a, b) => b.priority - a.priority)
return this
}
@@ -142,6 +158,10 @@ export abstract class BaseTask<
throw new Error(`clone not implemented for task "${this.nameForModel}"`)
}
/**
* Adds an after call hook to confirm or refine the output of this task with
* human feedback.
*/
public withHumanFeedback<V extends HumanFeedbackType>(
options: HumanFeedbackOptions<V, TOutput> = {}
): this {
@@ -194,6 +214,11 @@ export abstract class BaseTask<
return this
}
public cacheConfig(cacheConfig: types.CacheConfig<TInput, TOutput>): this {
this._cacheConfig = cacheConfig
return this
}
/**
* Calls this task with the given `input` and returns the result only.
*/
@@ -227,6 +252,13 @@ export abstract class BaseTask<
input = safeInput.data
}
const maybeCacheKey = this._cacheConfig.cache
? this._cacheConfig.cacheKey?.(input)
: undefined
const cacheKey = maybeCacheKey
? await Promise.resolve(maybeCacheKey)
: undefined
const ctx: types.TaskCallContext<TInput> = {
input,
attemptNumber: 0,
@@ -235,7 +267,8 @@
taskId: this.id,
callId: this._agentic!.idGeneratorFn(),
parentTaskId: parentCtx?.metadata.taskId,
parentCallId: parentCtx?.metadata.callId
parentCallId: parentCtx?.metadata.callId,
cacheStatus: 'miss'
}
}
@@ -260,6 +293,22 @@
}
}
if (cacheKey && this._cacheConfig.cache) {
const cachedValue = await Promise.resolve(
this._cacheConfig.cache.get(cacheKey)
)
if (cachedValue) {
ctx.metadata.success = true
ctx.metadata.cacheStatus = 'hit'
return {
result: cachedValue,
metadata: ctx.metadata
}
}
}
const result = await pRetry(
async () => {
let result = await this._call(ctx)
@@ -326,6 +375,10 @@ export abstract class BaseTask<
ctx.metadata.numRetries = ctx.attemptNumber
ctx.metadata.error = undefined
if (cacheKey && this._cacheConfig.cache) {
await Promise.resolve(this._cacheConfig.cache.set(cacheKey, result))
}
// ctx.tracker.setOutput(stringifyForDebugging(result, { maxLength: 100 }))
return {
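
Taken together, the task changes above add an opt-in result cache to BaseTask: a cache key is derived from the validated input, looked up before the pRetry loop, and written back once the call succeeds, with metadata.cacheStatus reporting 'hit' or 'miss'. A rough usage sketch, not part of this commit (the key normalization is purely illustrative and an agentic runtime instance is assumed to be in scope), using the CalculatorTool exercised by the tests below:

import QuickLRU from 'quick-lru'

// Sketch: .cacheConfig() replaces the config built in the constructor, so the
// cache store has to be supplied again when overriding it.
const calculator = new CalculatorTool({ agentic }).cacheConfig({
  cacheStrategy: 'default',
  // Illustrative key: strip whitespace so '2 * 3' and '2*3' share one entry.
  cacheKey: (input) => input.expression.replace(/\s+/g, ''),
  cache: new QuickLRU<string, number>({ maxSize: 100 })
})

const first = await calculator.callWithMetadata({ expression: '2 * 3' })
// first.metadata.cacheStatus === 'miss'
const second = await calculator.callWithMetadata({ expression: '2*3' })
// second.metadata.cacheStatus === 'hit'; second.result is served from the cache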

View file

@@ -1,6 +1,7 @@
import * as anthropic from '@anthropic-ai/sdk'
import * as openai from 'openai-fetch'
import ky from 'ky'
import type { CacheStorage } from 'p-memoize'
import type { Options as RetryOptions } from 'p-retry'
import type { JsonObject, Jsonifiable } from 'type-fest'
import { SafeParseReturnType, ZodType, ZodTypeAny, output, z } from 'zod'
@@ -39,6 +40,7 @@ export interface BaseTaskOptions {
timeoutMs?: number
retryConfig?: RetryConfig
cacheConfig?: CacheConfig<any, any>
id?: string
// TODO
@@ -55,6 +57,8 @@ export interface BaseLLMOptions<
inputSchema?: ZodType<TInput>
outputSchema?: ZodType<TOutput>
cacheConfig?: CacheConfig<TInput, TOutput>
provider?: string
model?: string
modelParams?: TModelParams
@@ -120,6 +124,20 @@ export interface RetryConfig extends RetryOptions {
strategy?: string
}
export type MaybePromise<T> = T | Promise<T>
export type CacheStatus = 'miss' | 'hit'
export type CacheStrategy = 'default' | 'none'
export interface CacheConfig<
TInput extends TaskInput,
TOutput extends TaskOutput,
TCacheKey = any
> {
cacheStrategy?: CacheStrategy
cacheKey?: (input: TInput) => MaybePromise<TCacheKey | undefined>
cache?: CacheStorage<TCacheKey, TOutput> | false
}
export interface TaskResponseMetadata extends Record<string, any> {
// task info
taskName: string
@@ -132,6 +150,7 @@ export interface TaskResponseMetadata extends Record<string, any> {
numRetries?: number
callId?: string
parentCallId?: string
cacheStatus?: CacheStatus
// human feedback info
feedback?: FeedbackTypeToMetadata<HumanFeedbackType>
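
The CacheConfig added here accepts any store matching p-memoize's CacheStorage shape (has / get / set / delete, each of which may be sync or async), so the default QuickLRU instance can be swapped for a custom backend. A minimal sketch, not part of this commit (TtlCache and its expiry policy are hypothetical):

import type { CacheStorage } from 'p-memoize'

// Hypothetical Map-backed store with per-entry expiry; a Redis- or disk-backed
// store would follow the same shape, returning promises instead of plain values.
class TtlCache<V> implements CacheStorage<string, V> {
  private store = new Map<string, { value: V; expiresAt: number }>()

  constructor(private readonly ttlMs: number) {}

  has(key: string) {
    const entry = this.store.get(key)
    return !!entry && entry.expiresAt > Date.now()
  }

  get(key: string) {
    const entry = this.store.get(key)
    return entry && entry.expiresAt > Date.now() ? entry.value : undefined
  }

  set(key: string, value: V) {
    this.store.set(key, { value, expiresAt: Date.now() + this.ttlMs })
  }

  delete(key: string) {
    this.store.delete(key)
  }

  clear() {
    this.store.clear()
  }
}

A plain Map (p-memoize's own default) or the QuickLRU instance used above satisfies the same shape, since every method may return either a value or a promise.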

View file

@@ -24,7 +24,22 @@ test('CalculatorTool', async (t) => {
t.deepEqual(metadata, {
success: true,
taskName: 'calculator',
cacheStatus: 'miss',
numRetries: 0,
error: undefined
})
})
test.only('CalculatorTool - caching', async (t) => {
const agentic = createTestAgenticRuntime()
const tool = new CalculatorTool({ agentic })
const res = await tool.callWithMetadata({ expression: '2 * 3' })
t.is(res.result, 6)
t.is(res.metadata.cacheStatus, 'miss')
expectTypeOf(res.result).toMatchTypeOf<number>()
const res2 = await tool.callWithMetadata({ expression: '2 * 3' })
t.is(res2.result, 6)
t.is(res2.metadata.cacheStatus, 'hit')
})
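
Since BaseTask only consults the cache when cacheConfig.cache is set, caching can also be switched off per task. A small sketch (an assumption based on CacheConfig allowing cache: false; not code from this commit):

// Hypothetical opt-out: with no cache store, no key is computed and every call
// reports cacheStatus 'miss'.
const uncached = new CalculatorTool({ agentic }).cacheConfig({
  cacheStrategy: 'none',
  cache: false
})
const res = await uncached.callWithMetadata({ expression: '2 * 3' })
// res.metadata.cacheStatus === 'miss' on every call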