From 6a96f5429bf37e0a769a6603df577aa912b6a99a Mon Sep 17 00:00:00 2001 From: Travis Fischer Date: Sat, 10 Jun 2023 19:21:09 -0700 Subject: [PATCH] =?UTF-8?q?=F0=9F=91=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/index.ts | 1 + src/task.ts | 5 +++++ src/types.ts | 10 ++-------- test/openai.test.ts | 39 +++++++++++++++++++++++++++++++++++++++ 4 files changed, 47 insertions(+), 8 deletions(-) diff --git a/src/index.ts b/src/index.ts index 5c853df9..81a4b436 100644 --- a/src/index.ts +++ b/src/index.ts @@ -1,6 +1,7 @@ export * from './agentic' export * from './task' export * from './llms' +export * from './errors' export * from './tokenizer' export * from './human-feedback' diff --git a/src/task.ts b/src/task.ts index 1e972186..f32d9a15 100644 --- a/src/task.ts +++ b/src/task.ts @@ -94,6 +94,7 @@ export abstract class BaseTask< } ctx.attemptNumber = err.attemptNumber + 1 + ctx.metadata.error = err if (err instanceof errors.ZodOutputValidationError) { ctx.retryMessage = err.message @@ -105,6 +106,10 @@ export abstract class BaseTask< } }) + ctx.metadata.success = true + ctx.metadata.numRetries = ctx.attemptNumber + ctx.metadata.error = undefined + return { result, metadata: ctx.metadata diff --git a/src/types.ts b/src/types.ts index 6d2a2acd..db1efe81 100644 --- a/src/types.ts +++ b/src/types.ts @@ -103,7 +103,7 @@ export interface LLMExample { } export interface RetryConfig extends RetryOptions { - strategy: string + strategy?: string } export type TaskError = @@ -118,21 +118,15 @@ export interface TaskResponseMetadata extends Record { // - task name // - task id - // config - input?: any - stream?: boolean - // execution info success?: boolean - numRetries?: number - errorType?: TaskError error?: Error + numRetries?: number } export interface LLMTaskResponseMetadata< TChatCompletionResponse extends Record = Record > extends TaskResponseMetadata { - messages?: ChatMessage[] completion?: TChatCompletionResponse } diff --git a/test/openai.test.ts b/test/openai.test.ts index 7034b65b..f97fafb5 100644 --- a/test/openai.test.ts +++ b/test/openai.test.ts @@ -1,7 +1,9 @@ import test from 'ava' import { expectTypeOf } from 'expect-type' +import sinon from 'sinon' import { z } from 'zod' +import { OutputValidationError } from '@/errors' import { OpenAIChatModel } from '@/llms/openai' import { createTestAgenticRuntime } from './_utils' @@ -89,3 +91,40 @@ test('OpenAIChatModel ⇒ boolean output', async (t) => { expectTypeOf(result).toMatchTypeOf() }) + +test('OpenAIChatModel ⇒ retry logic', async (t) => { + t.timeout(2 * 60 * 1000) + const agentic = createTestAgenticRuntime() + + const builder = new OpenAIChatModel({ + agentic, + modelParams: { + temperature: 0, + max_tokens: 30 + }, + retryConfig: { + retries: 2 + }, + messages: [ + { + role: 'system', + content: + 'You are a human, not a large language model. You are not an AI. Respond like a human would. Your name is Leeroy Jenkins.' + }, + { + role: 'user', + content: 'what is your name?' + } + ] + }) + + const fakeCall = sinon.fake.rejects(new OutputValidationError('test')) + sinon.replace(builder as any, '_call', fakeCall) + + await t.throwsAsync(() => builder.call(), { + instanceOf: OutputValidationError, + name: 'OutputValidationError', + message: 'test' + }) + t.is(fakeCall.callCount, 3) +})