kopia lustrzana https://github.com/transitive-bullshit/chatgpt-api
feat: minor improvements to types and task
rodzic
cff74009c5
commit
113fdebee3
|
@ -4,8 +4,8 @@ export type Metadata = Record<string, unknown>;
|
|||
|
||||
|
||||
export abstract class BaseTask<
|
||||
TInput extends ZodRawShape | ZodTypeAny = ZodTypeAny,
|
||||
TOutput extends ZodRawShape | ZodTypeAny = ZodTypeAny
|
||||
TInput extends ZodTypeAny = ZodTypeAny,
|
||||
TOutput extends ZodTypeAny = ZodTypeAny
|
||||
> {
|
||||
|
||||
// ...
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
import { ZodRawShape, ZodTypeAny } from 'zod'
|
||||
import { ZodTypeAny } from 'zod'
|
||||
|
||||
import { Agentic } from '@/agentic'
|
||||
import { BaseTask } from '@/task'
|
||||
|
@ -37,8 +37,8 @@ export class HumanFeedbackMechanismCLI extends HumanFeedbackMechanism {
|
|||
}
|
||||
|
||||
export function withHumanFeedback<
|
||||
TInput extends ZodRawShape | ZodTypeAny = ZodTypeAny,
|
||||
TOutput extends ZodRawShape | ZodTypeAny = ZodTypeAny
|
||||
TInput extends ZodTypeAny = ZodTypeAny,
|
||||
TOutput extends ZodTypeAny = ZodTypeAny
|
||||
>(
|
||||
task: BaseTask<TInput, TOutput>,
|
||||
options: HumanFeedbackOptions = {
|
||||
|
|
|
@ -86,4 +86,19 @@ export class AnthropicChatModel<
|
|||
response
|
||||
}
|
||||
}
|
||||
|
||||
public override clone(): AnthropicChatModel<TInput, TOutput> {
|
||||
return new AnthropicChatModel<TInput, TOutput>({
|
||||
agentic: this._agentic,
|
||||
timeoutMs: this._timeoutMs,
|
||||
retryConfig: this._retryConfig,
|
||||
inputSchema: this._inputSchema,
|
||||
outputSchema: this._outputSchema,
|
||||
provider: this._provider,
|
||||
model: this._model,
|
||||
examples: this._examples,
|
||||
messages: this._messages,
|
||||
...this._modelParams
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -2,7 +2,7 @@ import { JSONRepairError, jsonrepair } from 'jsonrepair'
|
|||
import pMap from 'p-map'
|
||||
import { dedent } from 'ts-dedent'
|
||||
import { type SetRequired } from 'type-fest'
|
||||
import { ZodRawShape, ZodTypeAny, z } from 'zod'
|
||||
import { ZodTypeAny, z } from 'zod'
|
||||
import { printNode, zodToTs } from 'zod-to-ts'
|
||||
|
||||
import * as errors from '@/errors'
|
||||
|
@ -20,8 +20,8 @@ import {
|
|||
} from '@/utils'
|
||||
|
||||
export abstract class BaseLLM<
|
||||
TInput extends ZodRawShape | ZodTypeAny = z.ZodVoid,
|
||||
TOutput extends ZodRawShape | ZodTypeAny = z.ZodType<string>,
|
||||
TInput extends ZodTypeAny = z.ZodVoid,
|
||||
TOutput extends ZodTypeAny = z.ZodType<string>,
|
||||
TModelParams extends Record<string, any> = Record<string, any>
|
||||
> extends BaseTask<TInput, TOutput> {
|
||||
protected _inputSchema: TInput | undefined
|
||||
|
@ -50,7 +50,7 @@ export abstract class BaseLLM<
|
|||
this._examples = options.examples
|
||||
}
|
||||
|
||||
input<U extends ZodRawShape | ZodTypeAny = TInput>(
|
||||
input<U extends ZodTypeAny = TInput>(
|
||||
inputSchema: U
|
||||
): BaseLLM<U, TOutput, TModelParams> {
|
||||
;(this as unknown as BaseLLM<U, TOutput, TModelParams>)._inputSchema =
|
||||
|
@ -58,7 +58,7 @@ export abstract class BaseLLM<
|
|||
return this as unknown as BaseLLM<U, TOutput, TModelParams>
|
||||
}
|
||||
|
||||
output<U extends ZodRawShape | ZodTypeAny = TOutput>(
|
||||
output<U extends ZodTypeAny = TOutput>(
|
||||
outputSchema: U
|
||||
): BaseLLM<TInput, U, TModelParams> {
|
||||
;(this as unknown as BaseLLM<TInput, U, TModelParams>)._outputSchema =
|
||||
|
@ -125,8 +125,8 @@ export abstract class BaseLLM<
|
|||
}
|
||||
|
||||
export abstract class BaseChatModel<
|
||||
TInput extends ZodRawShape | ZodTypeAny = ZodTypeAny,
|
||||
TOutput extends ZodRawShape | ZodTypeAny = z.ZodType<string>,
|
||||
TInput extends ZodTypeAny = ZodTypeAny,
|
||||
TOutput extends ZodTypeAny = z.ZodType<string>,
|
||||
TModelParams extends Record<string, any> = Record<string, any>,
|
||||
TChatCompletionResponse extends Record<string, any> = Record<string, any>
|
||||
> extends BaseLLM<TInput, TOutput, TModelParams> {
|
||||
|
@ -152,13 +152,8 @@ export abstract class BaseChatModel<
|
|||
ctx?: types.TaskCallContext
|
||||
) {
|
||||
if (this._inputSchema) {
|
||||
const inputSchema =
|
||||
this._inputSchema instanceof z.ZodType
|
||||
? this._inputSchema
|
||||
: z.object(this._inputSchema)
|
||||
|
||||
// TODO: handle errors gracefully
|
||||
input = inputSchema.parse(input)
|
||||
input = this.inputSchema.parse(input)
|
||||
}
|
||||
|
||||
// TODO: validate input message variables against input schema
|
||||
|
@ -185,12 +180,7 @@ export abstract class BaseChatModel<
|
|||
}
|
||||
|
||||
if (this._outputSchema) {
|
||||
const outputSchema =
|
||||
this._outputSchema instanceof z.ZodType
|
||||
? this._outputSchema
|
||||
: z.object(this._outputSchema)
|
||||
|
||||
const { node } = zodToTs(outputSchema)
|
||||
const { node } = zodToTs(this._outputSchema)
|
||||
|
||||
if (node.kind === 152) {
|
||||
// handle raw strings differently
|
||||
|
@ -248,10 +238,7 @@ export abstract class BaseChatModel<
|
|||
console.log('<<<')
|
||||
|
||||
if (this._outputSchema) {
|
||||
const outputSchema =
|
||||
this._outputSchema instanceof z.ZodType
|
||||
? this._outputSchema
|
||||
: z.object(this._outputSchema)
|
||||
const outputSchema = this._outputSchema
|
||||
|
||||
if (outputSchema instanceof z.ZodArray) {
|
||||
try {
|
||||
|
|
|
@ -50,4 +50,19 @@ export class OpenAIChatModel<
|
|||
messages
|
||||
})
|
||||
}
|
||||
|
||||
public override clone(): OpenAIChatModel<TInput, TOutput> {
|
||||
return new OpenAIChatModel<TInput, TOutput>({
|
||||
agentic: this._agentic,
|
||||
timeoutMs: this._timeoutMs,
|
||||
retryConfig: this._retryConfig,
|
||||
inputSchema: this._inputSchema,
|
||||
outputSchema: this._outputSchema,
|
||||
provider: this._provider,
|
||||
model: this._model,
|
||||
examples: this._examples,
|
||||
messages: this._messages,
|
||||
...this._modelParams
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
31
src/task.ts
31
src/task.ts
|
@ -1,5 +1,5 @@
|
|||
import pRetry, { FailedAttemptError } from 'p-retry'
|
||||
import { ZodRawShape, ZodTypeAny, z } from 'zod'
|
||||
import { ZodTypeAny } from 'zod'
|
||||
|
||||
import * as errors from '@/errors'
|
||||
import * as types from '@/types'
|
||||
|
@ -19,8 +19,8 @@ import { Agentic } from '@/agentic'
|
|||
* - Invoking sub-agents
|
||||
*/
|
||||
export abstract class BaseTask<
|
||||
TInput extends ZodRawShape | ZodTypeAny = ZodTypeAny,
|
||||
TOutput extends ZodRawShape | ZodTypeAny = ZodTypeAny
|
||||
TInput extends ZodTypeAny = ZodTypeAny,
|
||||
TOutput extends ZodTypeAny = ZodTypeAny
|
||||
> {
|
||||
protected _agentic: Agentic
|
||||
protected _id: string
|
||||
|
@ -51,6 +51,24 @@ export abstract class BaseTask<
|
|||
|
||||
public abstract get name(): string
|
||||
|
||||
public serialize(): types.SerializedTask {
|
||||
return {
|
||||
_taskName: this.name
|
||||
// inputSchema: this.inputSchema.serialize()
|
||||
}
|
||||
}
|
||||
|
||||
// public abstract deserialize<
|
||||
// TInput extends ZodTypeAny = ZodTypeAny,
|
||||
// TOutput extends ZodTypeAny = ZodTypeAny
|
||||
// >(data: types.SerializedTask): BaseTask<TInput, TOutput>
|
||||
|
||||
// TODO: is this really necessary?
|
||||
public clone(): BaseTask<TInput, TOutput> {
|
||||
// TODO: override in subclass if needed
|
||||
throw new Error(`clone not implemented for task "${this.name}"`)
|
||||
}
|
||||
|
||||
public retryConfig(retryConfig: types.RetryConfig) {
|
||||
this._retryConfig = retryConfig
|
||||
return this
|
||||
|
@ -67,12 +85,7 @@ export abstract class BaseTask<
|
|||
input?: types.ParsedData<TInput>
|
||||
): Promise<types.TaskResponse<TOutput>> {
|
||||
if (this.inputSchema) {
|
||||
const inputSchema =
|
||||
this.inputSchema instanceof z.ZodType
|
||||
? this.inputSchema
|
||||
: z.object(this.inputSchema)
|
||||
|
||||
const safeInput = inputSchema.safeParse(input)
|
||||
const safeInput = this.inputSchema.safeParse(input)
|
||||
|
||||
if (!safeInput.success) {
|
||||
throw new Error(`Invalid input: ${safeInput.error.message}`)
|
||||
|
|
|
@ -38,13 +38,15 @@ export class MetaphorSearchTool extends BaseTask<
|
|||
|
||||
constructor({
|
||||
agentic,
|
||||
metaphorClient = new MetaphorClient()
|
||||
metaphorClient = new MetaphorClient(),
|
||||
...rest
|
||||
}: {
|
||||
agentic: Agentic
|
||||
metaphorClient?: MetaphorClient
|
||||
}) {
|
||||
} & types.BaseTaskOptions) {
|
||||
super({
|
||||
agentic
|
||||
agentic,
|
||||
...rest
|
||||
})
|
||||
|
||||
this._metaphorClient = metaphorClient
|
||||
|
@ -66,11 +68,9 @@ export class MetaphorSearchTool extends BaseTask<
|
|||
ctx: types.TaskCallContext<typeof MetaphorSearchToolInputSchema>
|
||||
): Promise<MetaphorSearchToolOutput> {
|
||||
// TODO: test required inputs
|
||||
const result = await this._metaphorClient.search({
|
||||
return this._metaphorClient.search({
|
||||
query: ctx.input!.query,
|
||||
numResults: ctx.input!.numResults
|
||||
})
|
||||
|
||||
return result
|
||||
}
|
||||
}
|
||||
|
|
48
src/types.ts
48
src/types.ts
|
@ -1,33 +1,21 @@
|
|||
import * as anthropic from '@anthropic-ai/sdk'
|
||||
import * as openai from 'openai-fetch'
|
||||
import type { Options as RetryOptions } from 'p-retry'
|
||||
import {
|
||||
SafeParseReturnType,
|
||||
ZodObject,
|
||||
ZodRawShape,
|
||||
ZodTypeAny,
|
||||
output,
|
||||
z
|
||||
} from 'zod'
|
||||
import type { JsonObject } from 'type-fest'
|
||||
import { SafeParseReturnType, ZodTypeAny, output, z } from 'zod'
|
||||
|
||||
import type { Agentic } from './agentic'
|
||||
|
||||
export { openai }
|
||||
export { anthropic }
|
||||
|
||||
export type ParsedData<T extends ZodRawShape | ZodTypeAny> =
|
||||
T extends ZodTypeAny
|
||||
? output<T>
|
||||
: T extends ZodRawShape
|
||||
? output<ZodObject<T>>
|
||||
: never
|
||||
export type ParsedData<T extends ZodTypeAny> = T extends ZodTypeAny
|
||||
? output<T>
|
||||
: never
|
||||
|
||||
export type SafeParsedData<T extends ZodRawShape | ZodTypeAny> =
|
||||
T extends ZodTypeAny
|
||||
? SafeParseReturnType<z.infer<T>, ParsedData<T>>
|
||||
: T extends ZodRawShape
|
||||
? SafeParseReturnType<ZodObject<T>, ParsedData<T>>
|
||||
: never
|
||||
export type SafeParsedData<T extends ZodTypeAny> = T extends ZodTypeAny
|
||||
? SafeParseReturnType<z.infer<T>, ParsedData<T>>
|
||||
: never
|
||||
|
||||
export interface BaseTaskOptions {
|
||||
agentic: Agentic
|
||||
|
@ -43,8 +31,8 @@ export interface BaseTaskOptions {
|
|||
}
|
||||
|
||||
export interface BaseLLMOptions<
|
||||
TInput extends ZodRawShape | ZodTypeAny = ZodTypeAny,
|
||||
TOutput extends ZodRawShape | ZodTypeAny = z.ZodType<string>,
|
||||
TInput extends ZodTypeAny = ZodTypeAny,
|
||||
TOutput extends ZodTypeAny = z.ZodType<string>,
|
||||
TModelParams extends Record<string, any> = Record<string, any>
|
||||
> extends BaseTaskOptions {
|
||||
inputSchema?: TInput
|
||||
|
@ -57,8 +45,8 @@ export interface BaseLLMOptions<
|
|||
}
|
||||
|
||||
export interface LLMOptions<
|
||||
TInput extends ZodRawShape | ZodTypeAny = ZodTypeAny,
|
||||
TOutput extends ZodRawShape | ZodTypeAny = z.ZodType<string>,
|
||||
TInput extends ZodTypeAny = ZodTypeAny,
|
||||
TOutput extends ZodTypeAny = z.ZodType<string>,
|
||||
TModelParams extends Record<string, any> = Record<string, any>
|
||||
> extends BaseLLMOptions<TInput, TOutput, TModelParams> {
|
||||
promptTemplate?: string
|
||||
|
@ -81,8 +69,8 @@ export interface ChatMessage {
|
|||
}
|
||||
|
||||
export interface ChatModelOptions<
|
||||
TInput extends ZodRawShape | ZodTypeAny = ZodTypeAny,
|
||||
TOutput extends ZodRawShape | ZodTypeAny = z.ZodType<string>,
|
||||
TInput extends ZodTypeAny = ZodTypeAny,
|
||||
TOutput extends ZodTypeAny = z.ZodType<string>,
|
||||
TModelParams extends Record<string, any> = Record<string, any>
|
||||
> extends BaseLLMOptions<TInput, TOutput, TModelParams> {
|
||||
messages: ChatMessage[]
|
||||
|
@ -132,7 +120,7 @@ export interface LLMTaskResponseMetadata<
|
|||
}
|
||||
|
||||
export interface TaskResponse<
|
||||
TOutput extends ZodRawShape | ZodTypeAny = z.ZodType<string>,
|
||||
TOutput extends ZodTypeAny = z.ZodType<string>,
|
||||
TMetadata extends TaskResponseMetadata = TaskResponseMetadata
|
||||
> {
|
||||
result: ParsedData<TOutput>
|
||||
|
@ -140,7 +128,7 @@ export interface TaskResponse<
|
|||
}
|
||||
|
||||
export interface TaskCallContext<
|
||||
TInput extends ZodRawShape | ZodTypeAny = ZodTypeAny,
|
||||
TInput extends ZodTypeAny = ZodTypeAny,
|
||||
TMetadata extends TaskResponseMetadata = TaskResponseMetadata
|
||||
> {
|
||||
input?: ParsedData<TInput>
|
||||
|
@ -152,4 +140,8 @@ export interface TaskCallContext<
|
|||
|
||||
export type IDGeneratorFunction = () => string
|
||||
|
||||
export interface SerializedTask extends JsonObject {
|
||||
_taskName: string
|
||||
}
|
||||
|
||||
// export type ProgressFunction = (partialResponse: ChatMessage) => void
|
||||
|
|
Ładowanie…
Reference in New Issue