kopia lustrzana https://github.com/transitive-bullshit/chatgpt-api
rodzic
9d54530880
commit
7596404ccc
|
@ -143,7 +143,10 @@ export abstract class BaseChatModel<
|
|||
messages: types.ChatMessage[]
|
||||
): Promise<types.BaseChatCompletionResponse<TChatCompletionResponse>>
|
||||
|
||||
public async buildMessages(input?: types.ParsedData<TInput>) {
|
||||
public async buildMessages(
|
||||
input?: types.ParsedData<TInput>,
|
||||
ctx?: types.TaskCallContext
|
||||
) {
|
||||
if (this._inputSchema) {
|
||||
const inputSchema =
|
||||
this._inputSchema instanceof z.ZodType
|
||||
|
@ -211,20 +214,29 @@ export abstract class BaseChatModel<
|
|||
}
|
||||
}
|
||||
|
||||
if (ctx?.retryMessage) {
|
||||
messages.push({
|
||||
role: 'system',
|
||||
content: ctx.retryMessage
|
||||
})
|
||||
}
|
||||
|
||||
// TODO: filter/compress messages based on token counts
|
||||
|
||||
return messages
|
||||
}
|
||||
|
||||
protected override async _call(
|
||||
input?: types.ParsedData<TInput>
|
||||
): Promise<types.TaskResponse<TOutput>> {
|
||||
const messages = await this.buildMessages(input)
|
||||
ctx: types.TaskCallContext<TInput, TOutput, types.LLMTaskResponseMetadata>
|
||||
): Promise<types.ParsedData<TOutput>> {
|
||||
const messages = await this.buildMessages(ctx.input, ctx)
|
||||
|
||||
console.log('>>>')
|
||||
console.log(messages)
|
||||
|
||||
const completion = await this._createChatCompletion(messages)
|
||||
ctx.metadata.completion = completion
|
||||
|
||||
let output: any = completion.message.content
|
||||
|
||||
console.log('===')
|
||||
|
@ -246,7 +258,7 @@ export abstract class BaseChatModel<
|
|||
throw new errors.OutputValidationError(err.message, { cause: err })
|
||||
} else if (err instanceof SyntaxError) {
|
||||
throw new errors.OutputValidationError(
|
||||
`Invalid JSON: ${err.message}`,
|
||||
`Invalid JSON array: ${err.message}`,
|
||||
{ cause: err }
|
||||
)
|
||||
} else {
|
||||
|
@ -262,7 +274,7 @@ export abstract class BaseChatModel<
|
|||
throw new errors.OutputValidationError(err.message, { cause: err })
|
||||
} else if (err instanceof SyntaxError) {
|
||||
throw new errors.OutputValidationError(
|
||||
`Invalid JSON: ${err.message}`,
|
||||
`Invalid JSON object: ${err.message}`,
|
||||
{ cause: err }
|
||||
)
|
||||
} else {
|
||||
|
@ -310,23 +322,9 @@ export abstract class BaseChatModel<
|
|||
throw new errors.ZodOutputValidationError(safeResult.error)
|
||||
}
|
||||
|
||||
return {
|
||||
result: safeResult.data,
|
||||
metadata: {
|
||||
input,
|
||||
messages,
|
||||
completion
|
||||
}
|
||||
}
|
||||
return safeResult.data
|
||||
} else {
|
||||
return {
|
||||
result: output,
|
||||
metadata: {
|
||||
input,
|
||||
messages,
|
||||
completion
|
||||
}
|
||||
}
|
||||
return output
|
||||
}
|
||||
}
|
||||
|
||||
|
|
41
src/task.ts
41
src/task.ts
|
@ -1,4 +1,4 @@
|
|||
import pRetry from 'p-retry'
|
||||
import pRetry, { FailedAttemptError } from 'p-retry'
|
||||
import { ZodRawShape, ZodTypeAny } from 'zod'
|
||||
|
||||
import * as errors from '@/errors'
|
||||
|
@ -65,35 +65,40 @@ export abstract class BaseTask<
|
|||
public async callWithMetadata(
|
||||
input?: types.ParsedData<TInput>
|
||||
): Promise<types.TaskResponse<TOutput>> {
|
||||
const metadata: types.TaskResponseMetadata = {
|
||||
const ctx: types.TaskCallContext<TInput, TOutput> = {
|
||||
input,
|
||||
numRetries: 0
|
||||
attemptNumber: 0,
|
||||
metadata: {}
|
||||
}
|
||||
|
||||
do {
|
||||
try {
|
||||
const response = await this._call(input)
|
||||
return response
|
||||
} catch (err: any) {
|
||||
const result = await pRetry(() => this._call(ctx), {
|
||||
...this._retryConfig,
|
||||
onFailedAttempt: async (err: FailedAttemptError) => {
|
||||
if (this._retryConfig.onFailedAttempt) {
|
||||
await Promise.resolve(this._retryConfig.onFailedAttempt(err))
|
||||
}
|
||||
|
||||
ctx.attemptNumber = err.attemptNumber + 1
|
||||
|
||||
if (err instanceof errors.ZodOutputValidationError) {
|
||||
// TODO
|
||||
ctx.retryMessage = err.message
|
||||
} else if (err instanceof errors.OutputValidationError) {
|
||||
ctx.retryMessage = err.message
|
||||
} else {
|
||||
throw err
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
// TODO: handle errors, retry logic, and self-healing
|
||||
metadata.numRetries = (metadata.numRetries ?? 0) + 1
|
||||
if (metadata.numRetries > this._retryConfig.retries) {
|
||||
}
|
||||
|
||||
// eslint-disable-next-line no-constant-condition
|
||||
} while (true)
|
||||
return {
|
||||
result,
|
||||
metadata: ctx.metadata
|
||||
}
|
||||
}
|
||||
|
||||
protected abstract _call(
|
||||
input?: types.ParsedData<TInput>
|
||||
): Promise<types.TaskResponse<TOutput>>
|
||||
ctx: types.TaskCallContext<TInput, TOutput>
|
||||
): Promise<types.ParsedData<TOutput>>
|
||||
|
||||
// TODO
|
||||
// abstract stream({
|
||||
|
|
19
src/types.ts
19
src/types.ts
|
@ -1,5 +1,6 @@
|
|||
import * as anthropic from '@anthropic-ai/sdk'
|
||||
import * as openai from 'openai-fetch'
|
||||
import type { Options as RetryOptions } from 'p-retry'
|
||||
import {
|
||||
SafeParseReturnType,
|
||||
ZodObject,
|
||||
|
@ -101,8 +102,7 @@ export interface LLMExample {
|
|||
output: string
|
||||
}
|
||||
|
||||
export interface RetryConfig {
|
||||
retries: number
|
||||
export interface RetryConfig extends RetryOptions {
|
||||
strategy: string
|
||||
}
|
||||
|
||||
|
@ -132,15 +132,28 @@ export interface TaskResponseMetadata extends Record<string, any> {
|
|||
export interface LLMTaskResponseMetadata<
|
||||
TChatCompletionResponse extends Record<string, any> = Record<string, any>
|
||||
> extends TaskResponseMetadata {
|
||||
messages?: ChatMessage[]
|
||||
completion?: TChatCompletionResponse
|
||||
}
|
||||
|
||||
export interface TaskResponse<
|
||||
TOutput extends ZodRawShape | ZodTypeAny = z.ZodType<string>,
|
||||
TMetadata extends Record<string, any> = Record<string, any>
|
||||
TMetadata extends TaskResponseMetadata = TaskResponseMetadata
|
||||
> {
|
||||
result: ParsedData<TOutput>
|
||||
metadata: TMetadata
|
||||
}
|
||||
|
||||
export interface TaskCallContext<
|
||||
TInput extends ZodRawShape | ZodTypeAny = ZodTypeAny,
|
||||
TOutput extends ZodRawShape | ZodTypeAny = z.ZodType<string>,
|
||||
TMetadata extends TaskResponseMetadata = TaskResponseMetadata
|
||||
> {
|
||||
input?: ParsedData<TInput>
|
||||
retryMessage?: string
|
||||
|
||||
attemptNumber: number
|
||||
metadata: Partial<TMetadata>
|
||||
}
|
||||
|
||||
// export type ProgressFunction = (partialResponse: ChatMessage) => void
|
||||
|
|
Ładowanie…
Reference in New Issue