chatgpt-api/src/task.ts

175 wiersze
4.4 KiB
TypeScript
Czysty Zwykły widok Historia

2023-06-10 05:42:20 +00:00
import pRetry, { FailedAttemptError } from 'p-retry'
import { ZodType } from 'zod'
2023-05-26 19:06:39 +00:00
import * as errors from './errors'
import * as types from './types'
import type { Agentic } from './agentic'
import { defaultIDGeneratorFn, isValidTaskIdentifier } from './utils'
2023-05-26 19:06:39 +00:00
2023-05-27 00:36:52 +00:00
/**
 * A `Task` is an async function call that may be non-deterministic. It has
 * structured input and structured output. Invoking a task is equivalent to
 * sampling from a probability distribution.
 *
 * Examples of tasks include:
 * - LLM calls
 * - Chain of LLM calls
 * - Retrieval task
 * - API calls
 * - Native function calls
 * - Invoking sub-agents
 *
 * Subclasses provide `inputSchema`, `outputSchema` and the core `_call`
 * logic; `BaseTask` handles input validation, per-call metadata and retries
 * (via `p-retry`).
 */
export abstract class BaseTask<
  TInput extends void | types.JsonObject = void,
  TOutput extends types.JsonValue = string
> {
  protected _agentic: Agentic
  protected _id: string
  protected _timeoutMs?: number
  protected _retryConfig: types.RetryConfig

  /**
   * @param options - Optional task configuration.
   *   - `agentic` falls back to `globalThis.__agentic?.deref()` when omitted.
   *   - `retryConfig` defaults to `{ retries: 3, strategy: 'default' }`.
   *   - `id` defaults to the agentic instance's `idGeneratorFn()`, or to
   *     `defaultIDGeneratorFn()` when no agentic instance is available.
   */
  constructor(options: types.BaseTaskOptions = {}) {
    this._agentic = options.agentic ?? globalThis.__agentic?.deref()
    this._timeoutMs = options.timeoutMs
    this._retryConfig = options.retryConfig ?? {
      retries: 3,
      strategy: 'default'
    }
    this._id =
      options.id ?? this._agentic?.idGeneratorFn() ?? defaultIDGeneratorFn()
  }

  /** The `Agentic` instance this task is bound to. */
  public get agentic(): Agentic {
    return this._agentic
  }

  public set agentic(agentic: Agentic) {
    this._agentic = agentic
  }

  /** Identifier of this task instance (fixed at construction time). */
  public get id(): string {
    return this._id
  }

  /** Zod schema used to validate the input passed to `call*`. */
  public abstract get inputSchema(): ZodType<TInput>

  /** Zod schema describing the task's output. */
  public abstract get outputSchema(): ZodType<TOutput>

  /**
   * Name exposed to the model: the subclass name with its first character
   * lower-cased (e.g. `SearchTask` -> `searchTask`).
   */
  public get nameForModel(): string {
    const name = this.constructor.name
    // `charAt` yields '' for an empty (anonymous class) name instead of
    // throwing on `undefined.toLowerCase()`; `validate()` can then report the
    // bad identifier via `isValidTaskIdentifier`.
    return name.charAt(0).toLowerCase() + name.slice(1)
  }

  /** Human-readable name: the subclass name as written. */
  public get nameForHuman(): string {
    return this.constructor.name
  }

  /** Description shown to the model; empty unless a subclass overrides it. */
  public get descForModel(): string {
    return ''
  }

  /**
   * Checks that this task can be invoked.
   *
   * @throws Error if no `agentic` instance is set, or if `nameForModel` is
   *   not accepted by `isValidTaskIdentifier`.
   */
  public validate() {
    if (!this._agentic) {
      throw new Error(
        `Task "${this.nameForHuman}" is missing a required "agentic" instance`
      )
    }

    const nameForModel = this.nameForModel
    if (!isValidTaskIdentifier(nameForModel)) {
      throw new Error(`Task field nameForModel "${nameForModel}" is invalid`)
    }
  }

  // TODO: is this really necessary?
  public clone(): BaseTask<TInput, TOutput> {
    // TODO: override in subclass if needed
    throw new Error(`clone not implemented for task "${this.nameForModel}"`)
  }

  /**
   * Replaces this task's retry configuration (spread into the `p-retry`
   * options on every call).
   *
   * @returns `this`, for chaining.
   */
  public retryConfig(retryConfig: types.RetryConfig): this {
    this._retryConfig = retryConfig
    return this
  }

  /**
   * Calls this task with the given `input` and returns the result only.
   *
   * @throws See `callWithMetadata`.
   */
  public async call(input?: TInput): Promise<TOutput> {
    const res = await this.callWithMetadata(input)
    return res.result
  }

  /**
   * Calls this task with the given `input` and returns the result along with
   * metadata.
   *
   * The input is validated against `inputSchema` first. `_call` is then run
   * under `p-retry`. Only `ZodOutputValidationError` / `OutputValidationError`
   * failures are retried; their message is stored in `ctx.retryMessage` before
   * the next attempt. Any other error aborts immediately.
   *
   * @throws Error from `validate()`, or `Invalid input: ...` when the input
   *   fails `inputSchema`; rethrows non-retryable errors from `_call`, and the
   *   last error once retries are exhausted.
   */
  public async callWithMetadata(
    input?: TInput
  ): Promise<types.TaskResponse<TOutput>> {
    this.validate()

    if (this.inputSchema) {
      const safeInput = this.inputSchema.safeParse(input)

      if (!safeInput.success) {
        throw new Error(`Invalid input: ${safeInput.error.message}`)
      }

      input = safeInput.data
    }

    const ctx: types.TaskCallContext<TInput> = {
      input,
      // 0-based index of the attempt currently being run
      attemptNumber: 0,
      metadata: {
        taskName: this.nameForModel,
        taskId: this.id,
        // `validate()` above has already thrown if `_agentic` is unset
        callId: this._agentic!.idGeneratorFn()
      }
    }

    const result = await pRetry(() => this._call(ctx), {
      ...this._retryConfig,
      onFailedAttempt: async (err: FailedAttemptError) => {
        if (this._retryConfig.onFailedAttempt) {
          await Promise.resolve(this._retryConfig.onFailedAttempt(err))
        }

        // TODO: log this task error

        // p-retry's `attemptNumber` is 1-based and refers to the attempt that
        // just failed, so it equals the 0-based index of the next attempt
        // (and the number of failures so far).
        ctx.attemptNumber = err.attemptNumber
        ctx.metadata.error = err

        if (
          err instanceof errors.ZodOutputValidationError ||
          err instanceof errors.OutputValidationError
        ) {
          // feedback for the next `_call` attempt
          ctx.retryMessage = err.message
        } else {
          // throwing from `onFailedAttempt` makes p-retry stop and reject
          throw err
        }
      }
    })

    ctx.metadata.success = true
    // number of failed attempts that preceded the successful one
    ctx.metadata.numRetries = ctx.attemptNumber
    ctx.metadata.error = undefined

    return {
      result,
      metadata: ctx.metadata
    }
  }

  /**
   * Subclasses must implement the core `_call` logic for this task.
   *
   * `ctx.attemptNumber` is 0 on the first attempt; on retries,
   * `ctx.retryMessage` holds the previous output-validation error message.
   */
  protected abstract _call(ctx: types.TaskCallContext<TInput>): Promise<TOutput>

  // TODO
  // abstract stream({
  //   input: TInput,
  //   onProgress: types.ProgressFunction
  // }): Promise<TOutput>
}