kopia lustrzana https://github.com/transitive-bullshit/chatgpt-api
feat: add task caching
rodzic
c634b3222b
commit
034ea7a6c2
61
src/task.ts
61
src/task.ts
|
@ -1,4 +1,5 @@
|
||||||
import pRetry, { FailedAttemptError } from 'p-retry'
|
import pRetry, { FailedAttemptError } from 'p-retry'
|
||||||
|
import QuickLRU from 'quick-lru'
|
||||||
import { ZodType } from 'zod'
|
import { ZodType } from 'zod'
|
||||||
|
|
||||||
import * as errors from './errors'
|
import * as errors from './errors'
|
||||||
|
@ -34,12 +35,14 @@ export abstract class BaseTask<
|
||||||
|
|
||||||
protected _timeoutMs?: number
|
protected _timeoutMs?: number
|
||||||
protected _retryConfig: types.RetryConfig
|
protected _retryConfig: types.RetryConfig
|
||||||
|
protected _cacheConfig: types.CacheConfig<TInput, TOutput>
|
||||||
|
|
||||||
private _preHooks: Array<{
|
protected _preHooks: Array<{
|
||||||
hook: types.TaskBeforeCallHook<TInput>
|
hook: types.TaskBeforeCallHook<TInput>
|
||||||
priority: number
|
priority: number
|
||||||
}> = []
|
}> = []
|
||||||
private _postHooks: Array<{
|
|
||||||
|
protected _postHooks: Array<{
|
||||||
hook: types.TaskAfterCallHook<TInput, TOutput>
|
hook: types.TaskAfterCallHook<TInput, TOutput>
|
||||||
priority: number
|
priority: number
|
||||||
}> = []
|
}> = []
|
||||||
|
@ -53,6 +56,19 @@ export abstract class BaseTask<
|
||||||
strategy: 'default'
|
strategy: 'default'
|
||||||
}
|
}
|
||||||
|
|
||||||
|
this._cacheConfig = {
|
||||||
|
cacheStrategy: 'default',
|
||||||
|
cacheKey: (input: TInput) => JSON.stringify(input),
|
||||||
|
...options.cacheConfig
|
||||||
|
}
|
||||||
|
|
||||||
|
if (
|
||||||
|
this._cacheConfig.cacheStrategy === 'default' &&
|
||||||
|
!this._cacheConfig.cache
|
||||||
|
) {
|
||||||
|
this._cacheConfig.cache = new QuickLRU<string, TOutput>({ maxSize: 1000 })
|
||||||
|
}
|
||||||
|
|
||||||
this._id =
|
this._id =
|
||||||
options.id ?? this._agentic?.idGeneratorFn() ?? defaultIDGeneratorFn()
|
options.id ?? this._agentic?.idGeneratorFn() ?? defaultIDGeneratorFn()
|
||||||
}
|
}
|
||||||
|
@ -115,7 +131,7 @@ export abstract class BaseTask<
|
||||||
priority = 0
|
priority = 0
|
||||||
): this {
|
): this {
|
||||||
this._postHooks.push({ hook, priority })
|
this._postHooks.push({ hook, priority })
|
||||||
this._postHooks.sort((a, b) => b.priority - a.priority) // two elements that compare equal will remain in their original order (>= ECMAScript 2019)
|
this._postHooks.sort((a, b) => b.priority - a.priority)
|
||||||
return this
|
return this
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -142,6 +158,10 @@ export abstract class BaseTask<
|
||||||
throw new Error(`clone not implemented for task "${this.nameForModel}"`)
|
throw new Error(`clone not implemented for task "${this.nameForModel}"`)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Adds an after call hook to confirm or refine the output of this task with
|
||||||
|
* human feedback.
|
||||||
|
*/
|
||||||
public withHumanFeedback<V extends HumanFeedbackType>(
|
public withHumanFeedback<V extends HumanFeedbackType>(
|
||||||
options: HumanFeedbackOptions<V, TOutput> = {}
|
options: HumanFeedbackOptions<V, TOutput> = {}
|
||||||
): this {
|
): this {
|
||||||
|
@ -194,6 +214,11 @@ export abstract class BaseTask<
|
||||||
return this
|
return this
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public cacheConfig(cacheConfig: types.CacheConfig<TInput, TOutput>): this {
|
||||||
|
this._cacheConfig = cacheConfig
|
||||||
|
return this
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Calls this task with the given `input` and returns the result only.
|
* Calls this task with the given `input` and returns the result only.
|
||||||
*/
|
*/
|
||||||
|
@ -227,6 +252,13 @@ export abstract class BaseTask<
|
||||||
input = safeInput.data
|
input = safeInput.data
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const maybeCacheKey = this._cacheConfig.cache
|
||||||
|
? this._cacheConfig.cacheKey?.(input)
|
||||||
|
: undefined
|
||||||
|
const cacheKey = maybeCacheKey
|
||||||
|
? await Promise.resolve(maybeCacheKey)
|
||||||
|
: undefined
|
||||||
|
|
||||||
const ctx: types.TaskCallContext<TInput> = {
|
const ctx: types.TaskCallContext<TInput> = {
|
||||||
input,
|
input,
|
||||||
attemptNumber: 0,
|
attemptNumber: 0,
|
||||||
|
@ -235,7 +267,8 @@ export abstract class BaseTask<
|
||||||
taskId: this.id,
|
taskId: this.id,
|
||||||
callId: this._agentic!.idGeneratorFn(),
|
callId: this._agentic!.idGeneratorFn(),
|
||||||
parentTaskId: parentCtx?.metadata.taskId,
|
parentTaskId: parentCtx?.metadata.taskId,
|
||||||
parentCallId: parentCtx?.metadata.callId
|
parentCallId: parentCtx?.metadata.callId,
|
||||||
|
cacheStatus: 'miss'
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -260,6 +293,22 @@ export abstract class BaseTask<
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (cacheKey && this._cacheConfig.cache) {
|
||||||
|
const cachedValue = await Promise.resolve(
|
||||||
|
this._cacheConfig.cache.get(cacheKey)
|
||||||
|
)
|
||||||
|
|
||||||
|
if (cachedValue) {
|
||||||
|
ctx.metadata.success = true
|
||||||
|
ctx.metadata.cacheStatus = 'hit'
|
||||||
|
|
||||||
|
return {
|
||||||
|
result: cachedValue,
|
||||||
|
metadata: ctx.metadata
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
const result = await pRetry(
|
const result = await pRetry(
|
||||||
async () => {
|
async () => {
|
||||||
let result = await this._call(ctx)
|
let result = await this._call(ctx)
|
||||||
|
@ -326,6 +375,10 @@ export abstract class BaseTask<
|
||||||
ctx.metadata.numRetries = ctx.attemptNumber
|
ctx.metadata.numRetries = ctx.attemptNumber
|
||||||
ctx.metadata.error = undefined
|
ctx.metadata.error = undefined
|
||||||
|
|
||||||
|
if (cacheKey && this._cacheConfig.cache) {
|
||||||
|
await Promise.resolve(this._cacheConfig.cache.set(cacheKey, result))
|
||||||
|
}
|
||||||
|
|
||||||
// ctx.tracker.setOutput(stringifyForDebugging(result, { maxLength: 100 }))
|
// ctx.tracker.setOutput(stringifyForDebugging(result, { maxLength: 100 }))
|
||||||
|
|
||||||
return {
|
return {
|
||||||
|
|
19
src/types.ts
19
src/types.ts
|
@ -1,6 +1,7 @@
|
||||||
import * as anthropic from '@anthropic-ai/sdk'
|
import * as anthropic from '@anthropic-ai/sdk'
|
||||||
import * as openai from 'openai-fetch'
|
import * as openai from 'openai-fetch'
|
||||||
import ky from 'ky'
|
import ky from 'ky'
|
||||||
|
import type { CacheStorage } from 'p-memoize'
|
||||||
import type { Options as RetryOptions } from 'p-retry'
|
import type { Options as RetryOptions } from 'p-retry'
|
||||||
import type { JsonObject, Jsonifiable } from 'type-fest'
|
import type { JsonObject, Jsonifiable } from 'type-fest'
|
||||||
import { SafeParseReturnType, ZodType, ZodTypeAny, output, z } from 'zod'
|
import { SafeParseReturnType, ZodType, ZodTypeAny, output, z } from 'zod'
|
||||||
|
@ -39,6 +40,7 @@ export interface BaseTaskOptions {
|
||||||
|
|
||||||
timeoutMs?: number
|
timeoutMs?: number
|
||||||
retryConfig?: RetryConfig
|
retryConfig?: RetryConfig
|
||||||
|
cacheConfig?: CacheConfig<any, any>
|
||||||
id?: string
|
id?: string
|
||||||
|
|
||||||
// TODO
|
// TODO
|
||||||
|
@ -55,6 +57,8 @@ export interface BaseLLMOptions<
|
||||||
inputSchema?: ZodType<TInput>
|
inputSchema?: ZodType<TInput>
|
||||||
outputSchema?: ZodType<TOutput>
|
outputSchema?: ZodType<TOutput>
|
||||||
|
|
||||||
|
cacheConfig?: CacheConfig<TInput, TOutput>
|
||||||
|
|
||||||
provider?: string
|
provider?: string
|
||||||
model?: string
|
model?: string
|
||||||
modelParams?: TModelParams
|
modelParams?: TModelParams
|
||||||
|
@ -120,6 +124,20 @@ export interface RetryConfig extends RetryOptions {
|
||||||
strategy?: string
|
strategy?: string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export type MaybePromise<T> = T | Promise<T>
|
||||||
|
export type CacheStatus = 'miss' | 'hit'
|
||||||
|
export type CacheStrategy = 'default' | 'none'
|
||||||
|
|
||||||
|
export interface CacheConfig<
|
||||||
|
TInput extends TaskInput,
|
||||||
|
TOutput extends TaskOutput,
|
||||||
|
TCacheKey = any
|
||||||
|
> {
|
||||||
|
cacheStrategy?: CacheStrategy
|
||||||
|
cacheKey?: (input: TInput) => MaybePromise<TCacheKey | undefined>
|
||||||
|
cache?: CacheStorage<TCacheKey, TOutput> | false
|
||||||
|
}
|
||||||
|
|
||||||
export interface TaskResponseMetadata extends Record<string, any> {
|
export interface TaskResponseMetadata extends Record<string, any> {
|
||||||
// task info
|
// task info
|
||||||
taskName: string
|
taskName: string
|
||||||
|
@ -132,6 +150,7 @@ export interface TaskResponseMetadata extends Record<string, any> {
|
||||||
numRetries?: number
|
numRetries?: number
|
||||||
callId?: string
|
callId?: string
|
||||||
parentCallId?: string
|
parentCallId?: string
|
||||||
|
cacheStatus?: CacheStatus
|
||||||
|
|
||||||
// human feedback info
|
// human feedback info
|
||||||
feedback?: FeedbackTypeToMetadata<HumanFeedbackType>
|
feedback?: FeedbackTypeToMetadata<HumanFeedbackType>
|
||||||
|
|
|
@ -24,7 +24,22 @@ test('CalculatorTool', async (t) => {
|
||||||
t.deepEqual(metadata, {
|
t.deepEqual(metadata, {
|
||||||
success: true,
|
success: true,
|
||||||
taskName: 'calculator',
|
taskName: 'calculator',
|
||||||
|
cacheStatus: 'miss',
|
||||||
numRetries: 0,
|
numRetries: 0,
|
||||||
error: undefined
|
error: undefined
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
|
test.only('CalculatorTool - caching', async (t) => {
|
||||||
|
const agentic = createTestAgenticRuntime()
|
||||||
|
const tool = new CalculatorTool({ agentic })
|
||||||
|
|
||||||
|
const res = await tool.callWithMetadata({ expression: '2 * 3' })
|
||||||
|
t.is(res.result, 6)
|
||||||
|
t.is(res.metadata.cacheStatus, 'miss')
|
||||||
|
expectTypeOf(res.result).toMatchTypeOf<number>()
|
||||||
|
|
||||||
|
const res2 = await tool.callWithMetadata({ expression: '2 * 3' })
|
||||||
|
t.is(res2.result, 6)
|
||||||
|
t.is(res2.metadata.cacheStatus, 'hit')
|
||||||
|
})
|
||||||
|
|
Ładowanie…
Reference in New Issue