kopia lustrzana https://github.com/transitive-bullshit/chatgpt-api
rodzic
7596404ccc
commit
72c499449d
|
@ -227,7 +227,7 @@ export abstract class BaseChatModel<
|
||||||
}
|
}
|
||||||
|
|
||||||
protected override async _call(
|
protected override async _call(
|
||||||
ctx: types.TaskCallContext<TInput, TOutput, types.LLMTaskResponseMetadata>
|
ctx: types.TaskCallContext<TInput, types.LLMTaskResponseMetadata>
|
||||||
): Promise<types.ParsedData<TOutput>> {
|
): Promise<types.ParsedData<TOutput>> {
|
||||||
const messages = await this.buildMessages(ctx.input, ctx)
|
const messages = await this.buildMessages(ctx.input, ctx)
|
||||||
|
|
||||||
|
|
20
src/task.ts
20
src/task.ts
|
@ -1,5 +1,5 @@
|
||||||
import pRetry, { FailedAttemptError } from 'p-retry'
|
import pRetry, { FailedAttemptError } from 'p-retry'
|
||||||
import { ZodRawShape, ZodTypeAny } from 'zod'
|
import { ZodRawShape, ZodTypeAny, z } from 'zod'
|
||||||
|
|
||||||
import * as errors from '@/errors'
|
import * as errors from '@/errors'
|
||||||
import * as types from '@/types'
|
import * as types from '@/types'
|
||||||
|
@ -65,7 +65,21 @@ export abstract class BaseTask<
|
||||||
public async callWithMetadata(
|
public async callWithMetadata(
|
||||||
input?: types.ParsedData<TInput>
|
input?: types.ParsedData<TInput>
|
||||||
): Promise<types.TaskResponse<TOutput>> {
|
): Promise<types.TaskResponse<TOutput>> {
|
||||||
const ctx: types.TaskCallContext<TInput, TOutput> = {
|
if (this.inputSchema) {
|
||||||
|
const inputSchema =
|
||||||
|
this.inputSchema instanceof z.ZodType
|
||||||
|
? this.inputSchema
|
||||||
|
: z.object(this.inputSchema)
|
||||||
|
|
||||||
|
const safeInput = inputSchema.safeParse(input)
|
||||||
|
if (!safeInput.success) {
|
||||||
|
throw new Error(`Invalid input: ${safeInput.error.message}`)
|
||||||
|
}
|
||||||
|
|
||||||
|
input = safeInput.data
|
||||||
|
}
|
||||||
|
|
||||||
|
const ctx: types.TaskCallContext<TInput> = {
|
||||||
input,
|
input,
|
||||||
attemptNumber: 0,
|
attemptNumber: 0,
|
||||||
metadata: {}
|
metadata: {}
|
||||||
|
@ -97,7 +111,7 @@ export abstract class BaseTask<
|
||||||
}
|
}
|
||||||
|
|
||||||
protected abstract _call(
|
protected abstract _call(
|
||||||
ctx: types.TaskCallContext<TInput, TOutput>
|
ctx: types.TaskCallContext<TInput>
|
||||||
): Promise<types.ParsedData<TOutput>>
|
): Promise<types.ParsedData<TOutput>>
|
||||||
|
|
||||||
// TODO
|
// TODO
|
||||||
|
|
|
@ -151,13 +151,13 @@ export function getContextSizeForModel(model: string): number {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
export const calculateMaxTokens = async ({
|
export async function calculateMaxTokens({
|
||||||
prompt,
|
prompt,
|
||||||
modelName
|
modelName
|
||||||
}: {
|
}: {
|
||||||
prompt: string
|
prompt: string
|
||||||
modelName: string
|
modelName: string
|
||||||
}) => {
|
}) {
|
||||||
// fallback to approximate calculation if tiktoken is not available
|
// fallback to approximate calculation if tiktoken is not available
|
||||||
let numTokens = Math.ceil(prompt.length / 4)
|
let numTokens = Math.ceil(prompt.length / 4)
|
||||||
|
|
||||||
|
|
|
@ -59,19 +59,14 @@ export class MetaphorSearchTool extends BaseTask<
|
||||||
}
|
}
|
||||||
|
|
||||||
protected override async _call(
|
protected override async _call(
|
||||||
input: MetaphorSearchToolInput
|
ctx: types.TaskCallContext<typeof MetaphorSearchToolInputSchema>
|
||||||
): Promise<types.TaskResponse<typeof MetaphorSearchToolOutputSchema>> {
|
): Promise<MetaphorSearchToolOutput> {
|
||||||
// TODO: handle errors gracefully
|
// TODO: test required inputs
|
||||||
input = this.inputSchema.parse(input)
|
|
||||||
|
|
||||||
const result = await this._metaphorClient.search({
|
const result = await this._metaphorClient.search({
|
||||||
query: input.query,
|
query: ctx.input!.query,
|
||||||
numResults: input.numResults
|
numResults: ctx.input!.numResults
|
||||||
})
|
})
|
||||||
|
|
||||||
return {
|
return result
|
||||||
result,
|
|
||||||
metadata: {}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -64,15 +64,8 @@ export class NovuNotificationTool extends BaseTask<
|
||||||
}
|
}
|
||||||
|
|
||||||
protected override async _call(
|
protected override async _call(
|
||||||
input: NovuNotificationToolInput
|
ctx: types.TaskCallContext<typeof NovuNotificationToolInputSchema>
|
||||||
): Promise<types.TaskResponse<typeof NovuNotificationToolOutputSchema>> {
|
): Promise<NovuNotificationToolOutput> {
|
||||||
// TODO: handle errors gracefully
|
return this._novuClient.triggerEvent(ctx.input!)
|
||||||
input = this.inputSchema.parse(input)
|
|
||||||
|
|
||||||
const result = await this._novuClient.triggerEvent(input)
|
|
||||||
return {
|
|
||||||
result,
|
|
||||||
metadata: {}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -146,7 +146,6 @@ export interface TaskResponse<
|
||||||
|
|
||||||
export interface TaskCallContext<
|
export interface TaskCallContext<
|
||||||
TInput extends ZodRawShape | ZodTypeAny = ZodTypeAny,
|
TInput extends ZodRawShape | ZodTypeAny = ZodTypeAny,
|
||||||
TOutput extends ZodRawShape | ZodTypeAny = z.ZodType<string>,
|
|
||||||
TMetadata extends TaskResponseMetadata = TaskResponseMetadata
|
TMetadata extends TaskResponseMetadata = TaskResponseMetadata
|
||||||
> {
|
> {
|
||||||
input?: ParsedData<TInput>
|
input?: ParsedData<TInput>
|
||||||
|
|
Ładowanie…
Reference in New Issue