Travis Fischer 2023-06-09 22:42:20 -07:00
rodzic 9d54530880
commit 7596404ccc
3 zmienionych plików z 59 dodań i 43 usunięć

Wyświetl plik

@ -143,7 +143,10 @@ export abstract class BaseChatModel<
messages: types.ChatMessage[]
): Promise<types.BaseChatCompletionResponse<TChatCompletionResponse>>
public async buildMessages(input?: types.ParsedData<TInput>) {
public async buildMessages(
input?: types.ParsedData<TInput>,
ctx?: types.TaskCallContext
) {
if (this._inputSchema) {
const inputSchema =
this._inputSchema instanceof z.ZodType
@ -211,20 +214,29 @@ export abstract class BaseChatModel<
}
}
if (ctx?.retryMessage) {
messages.push({
role: 'system',
content: ctx.retryMessage
})
}
// TODO: filter/compress messages based on token counts
return messages
}
protected override async _call(
input?: types.ParsedData<TInput>
): Promise<types.TaskResponse<TOutput>> {
const messages = await this.buildMessages(input)
ctx: types.TaskCallContext<TInput, TOutput, types.LLMTaskResponseMetadata>
): Promise<types.ParsedData<TOutput>> {
const messages = await this.buildMessages(ctx.input, ctx)
console.log('>>>')
console.log(messages)
const completion = await this._createChatCompletion(messages)
ctx.metadata.completion = completion
let output: any = completion.message.content
console.log('===')
@ -246,7 +258,7 @@ export abstract class BaseChatModel<
throw new errors.OutputValidationError(err.message, { cause: err })
} else if (err instanceof SyntaxError) {
throw new errors.OutputValidationError(
`Invalid JSON: ${err.message}`,
`Invalid JSON array: ${err.message}`,
{ cause: err }
)
} else {
@ -262,7 +274,7 @@ export abstract class BaseChatModel<
throw new errors.OutputValidationError(err.message, { cause: err })
} else if (err instanceof SyntaxError) {
throw new errors.OutputValidationError(
`Invalid JSON: ${err.message}`,
`Invalid JSON object: ${err.message}`,
{ cause: err }
)
} else {
@ -310,23 +322,9 @@ export abstract class BaseChatModel<
throw new errors.ZodOutputValidationError(safeResult.error)
}
return {
result: safeResult.data,
metadata: {
input,
messages,
completion
}
}
return safeResult.data
} else {
return {
result: output,
metadata: {
input,
messages,
completion
}
}
return output
}
}

Wyświetl plik

@ -1,4 +1,4 @@
import pRetry from 'p-retry'
import pRetry, { FailedAttemptError } from 'p-retry'
import { ZodRawShape, ZodTypeAny } from 'zod'
import * as errors from '@/errors'
@ -65,35 +65,40 @@ export abstract class BaseTask<
public async callWithMetadata(
input?: types.ParsedData<TInput>
): Promise<types.TaskResponse<TOutput>> {
const metadata: types.TaskResponseMetadata = {
const ctx: types.TaskCallContext<TInput, TOutput> = {
input,
numRetries: 0
attemptNumber: 0,
metadata: {}
}
do {
try {
const response = await this._call(input)
return response
} catch (err: any) {
const result = await pRetry(() => this._call(ctx), {
...this._retryConfig,
onFailedAttempt: async (err: FailedAttemptError) => {
if (this._retryConfig.onFailedAttempt) {
await Promise.resolve(this._retryConfig.onFailedAttempt(err))
}
ctx.attemptNumber = err.attemptNumber + 1
if (err instanceof errors.ZodOutputValidationError) {
// TODO
ctx.retryMessage = err.message
} else if (err instanceof errors.OutputValidationError) {
ctx.retryMessage = err.message
} else {
throw err
}
}
})
// TODO: handle errors, retry logic, and self-healing
metadata.numRetries = (metadata.numRetries ?? 0) + 1
if (metadata.numRetries > this._retryConfig.retries) {
}
// eslint-disable-next-line no-constant-condition
} while (true)
return {
result,
metadata: ctx.metadata
}
}
protected abstract _call(
input?: types.ParsedData<TInput>
): Promise<types.TaskResponse<TOutput>>
ctx: types.TaskCallContext<TInput, TOutput>
): Promise<types.ParsedData<TOutput>>
// TODO
// abstract stream({

Wyświetl plik

@ -1,5 +1,6 @@
import * as anthropic from '@anthropic-ai/sdk'
import * as openai from 'openai-fetch'
import type { Options as RetryOptions } from 'p-retry'
import {
SafeParseReturnType,
ZodObject,
@ -101,8 +102,7 @@ export interface LLMExample {
output: string
}
export interface RetryConfig {
retries: number
export interface RetryConfig extends RetryOptions {
strategy: string
}
@ -132,15 +132,28 @@ export interface TaskResponseMetadata extends Record<string, any> {
export interface LLMTaskResponseMetadata<
TChatCompletionResponse extends Record<string, any> = Record<string, any>
> extends TaskResponseMetadata {
messages?: ChatMessage[]
completion?: TChatCompletionResponse
}
export interface TaskResponse<
TOutput extends ZodRawShape | ZodTypeAny = z.ZodType<string>,
TMetadata extends Record<string, any> = Record<string, any>
TMetadata extends TaskResponseMetadata = TaskResponseMetadata
> {
result: ParsedData<TOutput>
metadata: TMetadata
}
export interface TaskCallContext<
TInput extends ZodRawShape | ZodTypeAny = ZodTypeAny,
TOutput extends ZodRawShape | ZodTypeAny = z.ZodType<string>,
TMetadata extends TaskResponseMetadata = TaskResponseMetadata
> {
input?: ParsedData<TInput>
retryMessage?: string
attemptNumber: number
metadata: Partial<TMetadata>
}
// export type ProgressFunction = (partialResponse: ChatMessage) => void