Travis Fischer 2023-06-09 23:11:10 -07:00
rodzic 7596404ccc
commit 72c499449d
6 zmienionych plików z 29 dodań i 28 usunięć

Wyświetl plik

@ -227,7 +227,7 @@ export abstract class BaseChatModel<
} }
protected override async _call( protected override async _call(
ctx: types.TaskCallContext<TInput, TOutput, types.LLMTaskResponseMetadata> ctx: types.TaskCallContext<TInput, types.LLMTaskResponseMetadata>
): Promise<types.ParsedData<TOutput>> { ): Promise<types.ParsedData<TOutput>> {
const messages = await this.buildMessages(ctx.input, ctx) const messages = await this.buildMessages(ctx.input, ctx)

Wyświetl plik

@ -1,5 +1,5 @@
import pRetry, { FailedAttemptError } from 'p-retry' import pRetry, { FailedAttemptError } from 'p-retry'
import { ZodRawShape, ZodTypeAny } from 'zod' import { ZodRawShape, ZodTypeAny, z } from 'zod'
import * as errors from '@/errors' import * as errors from '@/errors'
import * as types from '@/types' import * as types from '@/types'
@ -65,7 +65,21 @@ export abstract class BaseTask<
public async callWithMetadata( public async callWithMetadata(
input?: types.ParsedData<TInput> input?: types.ParsedData<TInput>
): Promise<types.TaskResponse<TOutput>> { ): Promise<types.TaskResponse<TOutput>> {
const ctx: types.TaskCallContext<TInput, TOutput> = { if (this.inputSchema) {
const inputSchema =
this.inputSchema instanceof z.ZodType
? this.inputSchema
: z.object(this.inputSchema)
const safeInput = inputSchema.safeParse(input)
if (!safeInput.success) {
throw new Error(`Invalid input: ${safeInput.error.message}`)
}
input = safeInput.data
}
const ctx: types.TaskCallContext<TInput> = {
input, input,
attemptNumber: 0, attemptNumber: 0,
metadata: {} metadata: {}
@ -97,7 +111,7 @@ export abstract class BaseTask<
} }
protected abstract _call( protected abstract _call(
ctx: types.TaskCallContext<TInput, TOutput> ctx: types.TaskCallContext<TInput>
): Promise<types.ParsedData<TOutput>> ): Promise<types.ParsedData<TOutput>>
// TODO // TODO

Wyświetl plik

@ -151,13 +151,13 @@ export function getContextSizeForModel(model: string): number {
} }
} }
export const calculateMaxTokens = async ({ export async function calculateMaxTokens({
prompt, prompt,
modelName modelName
}: { }: {
prompt: string prompt: string
modelName: string modelName: string
}) => { }) {
// fallback to approximate calculation if tiktoken is not available // fallback to approximate calculation if tiktoken is not available
let numTokens = Math.ceil(prompt.length / 4) let numTokens = Math.ceil(prompt.length / 4)

Wyświetl plik

@ -59,19 +59,14 @@ export class MetaphorSearchTool extends BaseTask<
} }
protected override async _call( protected override async _call(
input: MetaphorSearchToolInput ctx: types.TaskCallContext<typeof MetaphorSearchToolInputSchema>
): Promise<types.TaskResponse<typeof MetaphorSearchToolOutputSchema>> { ): Promise<MetaphorSearchToolOutput> {
// TODO: handle errors gracefully // TODO: test required inputs
input = this.inputSchema.parse(input)
const result = await this._metaphorClient.search({ const result = await this._metaphorClient.search({
query: input.query, query: ctx.input!.query,
numResults: input.numResults numResults: ctx.input!.numResults
}) })
return { return result
result,
metadata: {}
}
} }
} }

Wyświetl plik

@ -64,15 +64,8 @@ export class NovuNotificationTool extends BaseTask<
} }
protected override async _call( protected override async _call(
input: NovuNotificationToolInput ctx: types.TaskCallContext<typeof NovuNotificationToolInputSchema>
): Promise<types.TaskResponse<typeof NovuNotificationToolOutputSchema>> { ): Promise<NovuNotificationToolOutput> {
// TODO: handle errors gracefully return this._novuClient.triggerEvent(ctx.input!)
input = this.inputSchema.parse(input)
const result = await this._novuClient.triggerEvent(input)
return {
result,
metadata: {}
}
} }
} }

Wyświetl plik

@ -146,7 +146,6 @@ export interface TaskResponse<
export interface TaskCallContext< export interface TaskCallContext<
TInput extends ZodRawShape | ZodTypeAny = ZodTypeAny, TInput extends ZodRawShape | ZodTypeAny = ZodTypeAny,
TOutput extends ZodRawShape | ZodTypeAny = z.ZodType<string>,
TMetadata extends TaskResponseMetadata = TaskResponseMetadata TMetadata extends TaskResponseMetadata = TaskResponseMetadata
> { > {
input?: ParsedData<TInput> input?: ParsedData<TInput>