Travis Fischer 2023-06-09 22:42:20 -07:00
rodzic 9d54530880
commit 7596404ccc
3 zmienionych plików z 59 dodań i 43 usunięć

Wyświetl plik

@ -143,7 +143,10 @@ export abstract class BaseChatModel<
messages: types.ChatMessage[] messages: types.ChatMessage[]
): Promise<types.BaseChatCompletionResponse<TChatCompletionResponse>> ): Promise<types.BaseChatCompletionResponse<TChatCompletionResponse>>
public async buildMessages(input?: types.ParsedData<TInput>) { public async buildMessages(
input?: types.ParsedData<TInput>,
ctx?: types.TaskCallContext
) {
if (this._inputSchema) { if (this._inputSchema) {
const inputSchema = const inputSchema =
this._inputSchema instanceof z.ZodType this._inputSchema instanceof z.ZodType
@ -211,20 +214,29 @@ export abstract class BaseChatModel<
} }
} }
if (ctx?.retryMessage) {
messages.push({
role: 'system',
content: ctx.retryMessage
})
}
// TODO: filter/compress messages based on token counts // TODO: filter/compress messages based on token counts
return messages return messages
} }
protected override async _call( protected override async _call(
input?: types.ParsedData<TInput> ctx: types.TaskCallContext<TInput, TOutput, types.LLMTaskResponseMetadata>
): Promise<types.TaskResponse<TOutput>> { ): Promise<types.ParsedData<TOutput>> {
const messages = await this.buildMessages(input) const messages = await this.buildMessages(ctx.input, ctx)
console.log('>>>') console.log('>>>')
console.log(messages) console.log(messages)
const completion = await this._createChatCompletion(messages) const completion = await this._createChatCompletion(messages)
ctx.metadata.completion = completion
let output: any = completion.message.content let output: any = completion.message.content
console.log('===') console.log('===')
@ -246,7 +258,7 @@ export abstract class BaseChatModel<
throw new errors.OutputValidationError(err.message, { cause: err }) throw new errors.OutputValidationError(err.message, { cause: err })
} else if (err instanceof SyntaxError) { } else if (err instanceof SyntaxError) {
throw new errors.OutputValidationError( throw new errors.OutputValidationError(
`Invalid JSON: ${err.message}`, `Invalid JSON array: ${err.message}`,
{ cause: err } { cause: err }
) )
} else { } else {
@ -262,7 +274,7 @@ export abstract class BaseChatModel<
throw new errors.OutputValidationError(err.message, { cause: err }) throw new errors.OutputValidationError(err.message, { cause: err })
} else if (err instanceof SyntaxError) { } else if (err instanceof SyntaxError) {
throw new errors.OutputValidationError( throw new errors.OutputValidationError(
`Invalid JSON: ${err.message}`, `Invalid JSON object: ${err.message}`,
{ cause: err } { cause: err }
) )
} else { } else {
@ -310,23 +322,9 @@ export abstract class BaseChatModel<
throw new errors.ZodOutputValidationError(safeResult.error) throw new errors.ZodOutputValidationError(safeResult.error)
} }
return { return safeResult.data
result: safeResult.data,
metadata: {
input,
messages,
completion
}
}
} else { } else {
return { return output
result: output,
metadata: {
input,
messages,
completion
}
}
} }
} }

Wyświetl plik

@ -1,4 +1,4 @@
import pRetry from 'p-retry' import pRetry, { FailedAttemptError } from 'p-retry'
import { ZodRawShape, ZodTypeAny } from 'zod' import { ZodRawShape, ZodTypeAny } from 'zod'
import * as errors from '@/errors' import * as errors from '@/errors'
@ -65,35 +65,40 @@ export abstract class BaseTask<
public async callWithMetadata( public async callWithMetadata(
input?: types.ParsedData<TInput> input?: types.ParsedData<TInput>
): Promise<types.TaskResponse<TOutput>> { ): Promise<types.TaskResponse<TOutput>> {
const metadata: types.TaskResponseMetadata = { const ctx: types.TaskCallContext<TInput, TOutput> = {
input, input,
numRetries: 0 attemptNumber: 0,
metadata: {}
} }
do { const result = await pRetry(() => this._call(ctx), {
try { ...this._retryConfig,
const response = await this._call(input) onFailedAttempt: async (err: FailedAttemptError) => {
return response if (this._retryConfig.onFailedAttempt) {
} catch (err: any) { await Promise.resolve(this._retryConfig.onFailedAttempt(err))
}
ctx.attemptNumber = err.attemptNumber + 1
if (err instanceof errors.ZodOutputValidationError) { if (err instanceof errors.ZodOutputValidationError) {
// TODO ctx.retryMessage = err.message
} else if (err instanceof errors.OutputValidationError) {
ctx.retryMessage = err.message
} else { } else {
throw err throw err
} }
} }
})
// TODO: handle errors, retry logic, and self-healing return {
metadata.numRetries = (metadata.numRetries ?? 0) + 1 result,
if (metadata.numRetries > this._retryConfig.retries) { metadata: ctx.metadata
} }
// eslint-disable-next-line no-constant-condition
} while (true)
} }
protected abstract _call( protected abstract _call(
input?: types.ParsedData<TInput> ctx: types.TaskCallContext<TInput, TOutput>
): Promise<types.TaskResponse<TOutput>> ): Promise<types.ParsedData<TOutput>>
// TODO // TODO
// abstract stream({ // abstract stream({

Wyświetl plik

@ -1,5 +1,6 @@
import * as anthropic from '@anthropic-ai/sdk' import * as anthropic from '@anthropic-ai/sdk'
import * as openai from 'openai-fetch' import * as openai from 'openai-fetch'
import type { Options as RetryOptions } from 'p-retry'
import { import {
SafeParseReturnType, SafeParseReturnType,
ZodObject, ZodObject,
@ -101,8 +102,7 @@ export interface LLMExample {
output: string output: string
} }
export interface RetryConfig { export interface RetryConfig extends RetryOptions {
retries: number
strategy: string strategy: string
} }
@ -132,15 +132,28 @@ export interface TaskResponseMetadata extends Record<string, any> {
export interface LLMTaskResponseMetadata< export interface LLMTaskResponseMetadata<
TChatCompletionResponse extends Record<string, any> = Record<string, any> TChatCompletionResponse extends Record<string, any> = Record<string, any>
> extends TaskResponseMetadata { > extends TaskResponseMetadata {
messages?: ChatMessage[]
completion?: TChatCompletionResponse completion?: TChatCompletionResponse
} }
export interface TaskResponse< export interface TaskResponse<
TOutput extends ZodRawShape | ZodTypeAny = z.ZodType<string>, TOutput extends ZodRawShape | ZodTypeAny = z.ZodType<string>,
TMetadata extends Record<string, any> = Record<string, any> TMetadata extends TaskResponseMetadata = TaskResponseMetadata
> { > {
result: ParsedData<TOutput> result: ParsedData<TOutput>
metadata: TMetadata metadata: TMetadata
} }
export interface TaskCallContext<
TInput extends ZodRawShape | ZodTypeAny = ZodTypeAny,
TOutput extends ZodRawShape | ZodTypeAny = z.ZodType<string>,
TMetadata extends TaskResponseMetadata = TaskResponseMetadata
> {
input?: ParsedData<TInput>
retryMessage?: string
attemptNumber: number
metadata: Partial<TMetadata>
}
// export type ProgressFunction = (partialResponse: ChatMessage) => void // export type ProgressFunction = (partialResponse: ChatMessage) => void