kopia lustrzana https://github.com/transitive-bullshit/chatgpt-api
feat: add task caching
rodzic
c634b3222b
commit
034ea7a6c2
61
src/task.ts
61
src/task.ts
|
@ -1,4 +1,5 @@
|
|||
import pRetry, { FailedAttemptError } from 'p-retry'
|
||||
import QuickLRU from 'quick-lru'
|
||||
import { ZodType } from 'zod'
|
||||
|
||||
import * as errors from './errors'
|
||||
|
@ -34,12 +35,14 @@ export abstract class BaseTask<
|
|||
|
||||
protected _timeoutMs?: number
|
||||
protected _retryConfig: types.RetryConfig
|
||||
protected _cacheConfig: types.CacheConfig<TInput, TOutput>
|
||||
|
||||
private _preHooks: Array<{
|
||||
protected _preHooks: Array<{
|
||||
hook: types.TaskBeforeCallHook<TInput>
|
||||
priority: number
|
||||
}> = []
|
||||
private _postHooks: Array<{
|
||||
|
||||
protected _postHooks: Array<{
|
||||
hook: types.TaskAfterCallHook<TInput, TOutput>
|
||||
priority: number
|
||||
}> = []
|
||||
|
@ -53,6 +56,19 @@ export abstract class BaseTask<
|
|||
strategy: 'default'
|
||||
}
|
||||
|
||||
this._cacheConfig = {
|
||||
cacheStrategy: 'default',
|
||||
cacheKey: (input: TInput) => JSON.stringify(input),
|
||||
...options.cacheConfig
|
||||
}
|
||||
|
||||
if (
|
||||
this._cacheConfig.cacheStrategy === 'default' &&
|
||||
!this._cacheConfig.cache
|
||||
) {
|
||||
this._cacheConfig.cache = new QuickLRU<string, TOutput>({ maxSize: 1000 })
|
||||
}
|
||||
|
||||
this._id =
|
||||
options.id ?? this._agentic?.idGeneratorFn() ?? defaultIDGeneratorFn()
|
||||
}
|
||||
|
@ -115,7 +131,7 @@ export abstract class BaseTask<
|
|||
priority = 0
|
||||
): this {
|
||||
this._postHooks.push({ hook, priority })
|
||||
this._postHooks.sort((a, b) => b.priority - a.priority) // two elements that compare equal will remain in their original order (>= ECMAScript 2019)
|
||||
this._postHooks.sort((a, b) => b.priority - a.priority)
|
||||
return this
|
||||
}
|
||||
|
||||
|
@ -142,6 +158,10 @@ export abstract class BaseTask<
|
|||
throw new Error(`clone not implemented for task "${this.nameForModel}"`)
|
||||
}
|
||||
|
||||
/**
|
||||
* Adds an after call hook to confirm or refine the output of this task with
|
||||
* human feedback.
|
||||
*/
|
||||
public withHumanFeedback<V extends HumanFeedbackType>(
|
||||
options: HumanFeedbackOptions<V, TOutput> = {}
|
||||
): this {
|
||||
|
@ -194,6 +214,11 @@ export abstract class BaseTask<
|
|||
return this
|
||||
}
|
||||
|
||||
/**
 * Sets the cache configuration used by this task.
 *
 * The given object replaces the current configuration as-is: it is not
 * merged with the previous one, and the defaults that the constructor fills
 * in (the `JSON.stringify`-based `cacheKey` and the in-memory LRU cache) are
 * not re-applied here.
 *
 * @param cacheConfig - The complete cache configuration to use from now on.
 * @returns This task, so calls can be chained.
 */
public cacheConfig(cacheConfig: types.CacheConfig<TInput, TOutput>): this {
  this._cacheConfig = cacheConfig

  return this
}
|
||||
|
||||
/**
|
||||
* Calls this task with the given `input` and returns the result only.
|
||||
*/
|
||||
|
@ -227,6 +252,13 @@ export abstract class BaseTask<
|
|||
input = safeInput.data
|
||||
}
|
||||
|
||||
const maybeCacheKey = this._cacheConfig.cache
|
||||
? this._cacheConfig.cacheKey?.(input)
|
||||
: undefined
|
||||
const cacheKey = maybeCacheKey
|
||||
? await Promise.resolve(maybeCacheKey)
|
||||
: undefined
|
||||
|
||||
const ctx: types.TaskCallContext<TInput> = {
|
||||
input,
|
||||
attemptNumber: 0,
|
||||
|
@ -235,7 +267,8 @@ export abstract class BaseTask<
|
|||
taskId: this.id,
|
||||
callId: this._agentic!.idGeneratorFn(),
|
||||
parentTaskId: parentCtx?.metadata.taskId,
|
||||
parentCallId: parentCtx?.metadata.callId
|
||||
parentCallId: parentCtx?.metadata.callId,
|
||||
cacheStatus: 'miss'
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -260,6 +293,22 @@ export abstract class BaseTask<
|
|||
}
|
||||
}
|
||||
|
||||
if (cacheKey && this._cacheConfig.cache) {
|
||||
const cachedValue = await Promise.resolve(
|
||||
this._cacheConfig.cache.get(cacheKey)
|
||||
)
|
||||
|
||||
if (cachedValue) {
|
||||
ctx.metadata.success = true
|
||||
ctx.metadata.cacheStatus = 'hit'
|
||||
|
||||
return {
|
||||
result: cachedValue,
|
||||
metadata: ctx.metadata
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const result = await pRetry(
|
||||
async () => {
|
||||
let result = await this._call(ctx)
|
||||
|
@ -326,6 +375,10 @@ export abstract class BaseTask<
|
|||
ctx.metadata.numRetries = ctx.attemptNumber
|
||||
ctx.metadata.error = undefined
|
||||
|
||||
if (cacheKey && this._cacheConfig.cache) {
|
||||
await Promise.resolve(this._cacheConfig.cache.set(cacheKey, result))
|
||||
}
|
||||
|
||||
// ctx.tracker.setOutput(stringifyForDebugging(result, { maxLength: 100 }))
|
||||
|
||||
return {
|
||||
|
|
19
src/types.ts
19
src/types.ts
|
@ -1,6 +1,7 @@
|
|||
import * as anthropic from '@anthropic-ai/sdk'
|
||||
import * as openai from 'openai-fetch'
|
||||
import ky from 'ky'
|
||||
import type { CacheStorage } from 'p-memoize'
|
||||
import type { Options as RetryOptions } from 'p-retry'
|
||||
import type { JsonObject, Jsonifiable } from 'type-fest'
|
||||
import { SafeParseReturnType, ZodType, ZodTypeAny, output, z } from 'zod'
|
||||
|
@ -39,6 +40,7 @@ export interface BaseTaskOptions {
|
|||
|
||||
timeoutMs?: number
|
||||
retryConfig?: RetryConfig
|
||||
cacheConfig?: CacheConfig<any, any>
|
||||
id?: string
|
||||
|
||||
// TODO
|
||||
|
@ -55,6 +57,8 @@ export interface BaseLLMOptions<
|
|||
inputSchema?: ZodType<TInput>
|
||||
outputSchema?: ZodType<TOutput>
|
||||
|
||||
cacheConfig?: CacheConfig<TInput, TOutput>
|
||||
|
||||
provider?: string
|
||||
model?: string
|
||||
modelParams?: TModelParams
|
||||
|
@ -120,6 +124,20 @@ export interface RetryConfig extends RetryOptions {
|
|||
strategy?: string
|
||||
}
|
||||
|
||||
/** A value that is either available immediately or delivered via a promise. */
export type MaybePromise<T> = Promise<T> | T

/**
 * Whether a task call was answered from the cache (`'hit'`) or had to be
 * computed (`'miss'`).
 */
export type CacheStatus = 'hit' | 'miss'

/**
 * Caching strategy for a task. With `'default'`, the task constructor creates
 * an in-memory cache when none is supplied; with `'none'` it does not.
 */
export type CacheStrategy = 'none' | 'default'
|
||||
|
||||
/**
 * Configuration for caching the results of a task.
 *
 * @typeParam TInput - Input type of the task being cached.
 * @typeParam TOutput - Output type of the task; this is what gets stored.
 * @typeParam TCacheKey - Type of the keys used to address cache entries.
 */
export interface CacheConfig<
  TInput extends TaskInput,
  TOutput extends TaskOutput,
  TCacheKey = any
> {
  /**
   * Strategy selector. The task constructor falls back to `'default'` and,
   * for that strategy, creates an in-memory LRU cache when `cache` is not set.
   */
  cacheStrategy?: CacheStrategy

  /**
   * Derives the cache key for a call from its input; may be asynchronous.
   * The task only invokes it when a `cache` is configured, and a call whose
   * key resolves to a falsy value bypasses the cache entirely.
   */
  cacheKey?: (input: TInput) => MaybePromise<TCacheKey | undefined>

  /**
   * Storage backend holding cached outputs. Its `get` / `set` results are
   * awaited by the task, so both sync and async stores work. When this is
   * `false` or unset at call time, no lookup or write is performed.
   */
  cache?: false | CacheStorage<TCacheKey, TOutput>
}
|
||||
|
||||
export interface TaskResponseMetadata extends Record<string, any> {
|
||||
// task info
|
||||
taskName: string
|
||||
|
@ -132,6 +150,7 @@ export interface TaskResponseMetadata extends Record<string, any> {
|
|||
numRetries?: number
|
||||
callId?: string
|
||||
parentCallId?: string
|
||||
cacheStatus?: CacheStatus
|
||||
|
||||
// human feedback info
|
||||
feedback?: FeedbackTypeToMetadata<HumanFeedbackType>
|
||||
|
|
|
@ -24,7 +24,22 @@ test('CalculatorTool', async (t) => {
|
|||
t.deepEqual(metadata, {
|
||||
success: true,
|
||||
taskName: 'calculator',
|
||||
cacheStatus: 'miss',
|
||||
numRetries: 0,
|
||||
error: undefined
|
||||
})
|
||||
})
|
||||
|
||||
// Verifies that repeating a call with identical input is served from the
// task cache. Deliberately a plain `test(...)`: an exclusive `.only` test
// would make the runner skip every other test in this file.
test('CalculatorTool - caching', async (t) => {
  const agentic = createTestAgenticRuntime()
  const tool = new CalculatorTool({ agentic })

  // First call: nothing is cached yet, so the result is computed.
  const res = await tool.callWithMetadata({ expression: '2 * 3' })
  t.is(res.result, 6)
  t.is(res.metadata.cacheStatus, 'miss')
  expectTypeOf(res.result).toMatchTypeOf<number>()

  // Second call with the same input: the same result must come from the cache.
  const res2 = await tool.callWithMetadata({ expression: '2 * 3' })
  t.is(res2.result, 6)
  t.is(res2.metadata.cacheStatus, 'hit')
})
|
||||
|
|
Ładowanie…
Reference in New Issue