kopia lustrzana https://github.com/transitive-bullshit/chatgpt-api
rodzic
9d54530880
commit
7596404ccc
|
@ -143,7 +143,10 @@ export abstract class BaseChatModel<
|
||||||
messages: types.ChatMessage[]
|
messages: types.ChatMessage[]
|
||||||
): Promise<types.BaseChatCompletionResponse<TChatCompletionResponse>>
|
): Promise<types.BaseChatCompletionResponse<TChatCompletionResponse>>
|
||||||
|
|
||||||
public async buildMessages(input?: types.ParsedData<TInput>) {
|
public async buildMessages(
|
||||||
|
input?: types.ParsedData<TInput>,
|
||||||
|
ctx?: types.TaskCallContext
|
||||||
|
) {
|
||||||
if (this._inputSchema) {
|
if (this._inputSchema) {
|
||||||
const inputSchema =
|
const inputSchema =
|
||||||
this._inputSchema instanceof z.ZodType
|
this._inputSchema instanceof z.ZodType
|
||||||
|
@ -211,20 +214,29 @@ export abstract class BaseChatModel<
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (ctx?.retryMessage) {
|
||||||
|
messages.push({
|
||||||
|
role: 'system',
|
||||||
|
content: ctx.retryMessage
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// TODO: filter/compress messages based on token counts
|
// TODO: filter/compress messages based on token counts
|
||||||
|
|
||||||
return messages
|
return messages
|
||||||
}
|
}
|
||||||
|
|
||||||
protected override async _call(
|
protected override async _call(
|
||||||
input?: types.ParsedData<TInput>
|
ctx: types.TaskCallContext<TInput, TOutput, types.LLMTaskResponseMetadata>
|
||||||
): Promise<types.TaskResponse<TOutput>> {
|
): Promise<types.ParsedData<TOutput>> {
|
||||||
const messages = await this.buildMessages(input)
|
const messages = await this.buildMessages(ctx.input, ctx)
|
||||||
|
|
||||||
console.log('>>>')
|
console.log('>>>')
|
||||||
console.log(messages)
|
console.log(messages)
|
||||||
|
|
||||||
const completion = await this._createChatCompletion(messages)
|
const completion = await this._createChatCompletion(messages)
|
||||||
|
ctx.metadata.completion = completion
|
||||||
|
|
||||||
let output: any = completion.message.content
|
let output: any = completion.message.content
|
||||||
|
|
||||||
console.log('===')
|
console.log('===')
|
||||||
|
@ -246,7 +258,7 @@ export abstract class BaseChatModel<
|
||||||
throw new errors.OutputValidationError(err.message, { cause: err })
|
throw new errors.OutputValidationError(err.message, { cause: err })
|
||||||
} else if (err instanceof SyntaxError) {
|
} else if (err instanceof SyntaxError) {
|
||||||
throw new errors.OutputValidationError(
|
throw new errors.OutputValidationError(
|
||||||
`Invalid JSON: ${err.message}`,
|
`Invalid JSON array: ${err.message}`,
|
||||||
{ cause: err }
|
{ cause: err }
|
||||||
)
|
)
|
||||||
} else {
|
} else {
|
||||||
|
@ -262,7 +274,7 @@ export abstract class BaseChatModel<
|
||||||
throw new errors.OutputValidationError(err.message, { cause: err })
|
throw new errors.OutputValidationError(err.message, { cause: err })
|
||||||
} else if (err instanceof SyntaxError) {
|
} else if (err instanceof SyntaxError) {
|
||||||
throw new errors.OutputValidationError(
|
throw new errors.OutputValidationError(
|
||||||
`Invalid JSON: ${err.message}`,
|
`Invalid JSON object: ${err.message}`,
|
||||||
{ cause: err }
|
{ cause: err }
|
||||||
)
|
)
|
||||||
} else {
|
} else {
|
||||||
|
@ -310,23 +322,9 @@ export abstract class BaseChatModel<
|
||||||
throw new errors.ZodOutputValidationError(safeResult.error)
|
throw new errors.ZodOutputValidationError(safeResult.error)
|
||||||
}
|
}
|
||||||
|
|
||||||
return {
|
return safeResult.data
|
||||||
result: safeResult.data,
|
|
||||||
metadata: {
|
|
||||||
input,
|
|
||||||
messages,
|
|
||||||
completion
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
return {
|
return output
|
||||||
result: output,
|
|
||||||
metadata: {
|
|
||||||
input,
|
|
||||||
messages,
|
|
||||||
completion
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
41
src/task.ts
41
src/task.ts
|
@ -1,4 +1,4 @@
|
||||||
import pRetry from 'p-retry'
|
import pRetry, { FailedAttemptError } from 'p-retry'
|
||||||
import { ZodRawShape, ZodTypeAny } from 'zod'
|
import { ZodRawShape, ZodTypeAny } from 'zod'
|
||||||
|
|
||||||
import * as errors from '@/errors'
|
import * as errors from '@/errors'
|
||||||
|
@ -65,35 +65,40 @@ export abstract class BaseTask<
|
||||||
public async callWithMetadata(
|
public async callWithMetadata(
|
||||||
input?: types.ParsedData<TInput>
|
input?: types.ParsedData<TInput>
|
||||||
): Promise<types.TaskResponse<TOutput>> {
|
): Promise<types.TaskResponse<TOutput>> {
|
||||||
const metadata: types.TaskResponseMetadata = {
|
const ctx: types.TaskCallContext<TInput, TOutput> = {
|
||||||
input,
|
input,
|
||||||
numRetries: 0
|
attemptNumber: 0,
|
||||||
|
metadata: {}
|
||||||
}
|
}
|
||||||
|
|
||||||
do {
|
const result = await pRetry(() => this._call(ctx), {
|
||||||
try {
|
...this._retryConfig,
|
||||||
const response = await this._call(input)
|
onFailedAttempt: async (err: FailedAttemptError) => {
|
||||||
return response
|
if (this._retryConfig.onFailedAttempt) {
|
||||||
} catch (err: any) {
|
await Promise.resolve(this._retryConfig.onFailedAttempt(err))
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx.attemptNumber = err.attemptNumber + 1
|
||||||
|
|
||||||
if (err instanceof errors.ZodOutputValidationError) {
|
if (err instanceof errors.ZodOutputValidationError) {
|
||||||
// TODO
|
ctx.retryMessage = err.message
|
||||||
|
} else if (err instanceof errors.OutputValidationError) {
|
||||||
|
ctx.retryMessage = err.message
|
||||||
} else {
|
} else {
|
||||||
throw err
|
throw err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
})
|
||||||
|
|
||||||
// TODO: handle errors, retry logic, and self-healing
|
return {
|
||||||
metadata.numRetries = (metadata.numRetries ?? 0) + 1
|
result,
|
||||||
if (metadata.numRetries > this._retryConfig.retries) {
|
metadata: ctx.metadata
|
||||||
}
|
}
|
||||||
|
|
||||||
// eslint-disable-next-line no-constant-condition
|
|
||||||
} while (true)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
protected abstract _call(
|
protected abstract _call(
|
||||||
input?: types.ParsedData<TInput>
|
ctx: types.TaskCallContext<TInput, TOutput>
|
||||||
): Promise<types.TaskResponse<TOutput>>
|
): Promise<types.ParsedData<TOutput>>
|
||||||
|
|
||||||
// TODO
|
// TODO
|
||||||
// abstract stream({
|
// abstract stream({
|
||||||
|
|
19
src/types.ts
19
src/types.ts
|
@ -1,5 +1,6 @@
|
||||||
import * as anthropic from '@anthropic-ai/sdk'
|
import * as anthropic from '@anthropic-ai/sdk'
|
||||||
import * as openai from 'openai-fetch'
|
import * as openai from 'openai-fetch'
|
||||||
|
import type { Options as RetryOptions } from 'p-retry'
|
||||||
import {
|
import {
|
||||||
SafeParseReturnType,
|
SafeParseReturnType,
|
||||||
ZodObject,
|
ZodObject,
|
||||||
|
@ -101,8 +102,7 @@ export interface LLMExample {
|
||||||
output: string
|
output: string
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface RetryConfig {
|
export interface RetryConfig extends RetryOptions {
|
||||||
retries: number
|
|
||||||
strategy: string
|
strategy: string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -132,15 +132,28 @@ export interface TaskResponseMetadata extends Record<string, any> {
|
||||||
export interface LLMTaskResponseMetadata<
|
export interface LLMTaskResponseMetadata<
|
||||||
TChatCompletionResponse extends Record<string, any> = Record<string, any>
|
TChatCompletionResponse extends Record<string, any> = Record<string, any>
|
||||||
> extends TaskResponseMetadata {
|
> extends TaskResponseMetadata {
|
||||||
|
messages?: ChatMessage[]
|
||||||
completion?: TChatCompletionResponse
|
completion?: TChatCompletionResponse
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface TaskResponse<
|
export interface TaskResponse<
|
||||||
TOutput extends ZodRawShape | ZodTypeAny = z.ZodType<string>,
|
TOutput extends ZodRawShape | ZodTypeAny = z.ZodType<string>,
|
||||||
TMetadata extends Record<string, any> = Record<string, any>
|
TMetadata extends TaskResponseMetadata = TaskResponseMetadata
|
||||||
> {
|
> {
|
||||||
result: ParsedData<TOutput>
|
result: ParsedData<TOutput>
|
||||||
metadata: TMetadata
|
metadata: TMetadata
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export interface TaskCallContext<
|
||||||
|
TInput extends ZodRawShape | ZodTypeAny = ZodTypeAny,
|
||||||
|
TOutput extends ZodRawShape | ZodTypeAny = z.ZodType<string>,
|
||||||
|
TMetadata extends TaskResponseMetadata = TaskResponseMetadata
|
||||||
|
> {
|
||||||
|
input?: ParsedData<TInput>
|
||||||
|
retryMessage?: string
|
||||||
|
|
||||||
|
attemptNumber: number
|
||||||
|
metadata: Partial<TMetadata>
|
||||||
|
}
|
||||||
|
|
||||||
// export type ProgressFunction = (partialResponse: ChatMessage) => void
|
// export type ProgressFunction = (partialResponse: ChatMessage) => void
|
||||||
|
|
Ładowanie…
Reference in New Issue