diff --git a/package.json b/package.json
index 2701624..04f547e 100644
--- a/package.json
+++ b/package.json
@@ -28,6 +28,7 @@
     "openai-fetch": "^1.2.0",
     "p-map": "^6.0.0",
     "parse-json": "^7.0.0",
+    "type-fest": "^3.9.0",
     "zod": "^3.21.4",
     "zod-validation-error": "^1.3.0"
   },
diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml
index ad15e5b..11cc221 100644
--- a/pnpm-lock.yaml
+++ b/pnpm-lock.yaml
@@ -16,6 +16,9 @@ dependencies:
   parse-json:
     specifier: ^7.0.0
     version: 7.0.0
+  type-fest:
+    specifier: ^3.9.0
+    version: 3.9.0
   zod:
     specifier: ^3.21.4
     version: 3.21.4
diff --git a/src/llm.ts b/src/llm.ts
index f43ceb8..cd5c2e3 100644
--- a/src/llm.ts
+++ b/src/llm.ts
@@ -1,7 +1,10 @@
+import type { SetRequired } from 'type-fest'
 import { ZodRawShape, ZodTypeAny, z } from 'zod'
 
 import * as types from './types'
 
+const defaultOpenAIModel = 'gpt-3.5-turbo'
+
 export class Agentic {
   _client: types.openai.OpenAIClient
   _verbosity: number
@@ -10,21 +13,19 @@ export class Agentic {
     'provider' | 'model' | 'modelParams' | 'timeoutMs' | 'retryConfig'
   >
 
-  constructor(
-    client: types.openai.OpenAIClient,
-    opts: {
-      verbosity?: number
-      defaults?: Pick<
-        types.BaseLLMOptions,
-        'provider' | 'model' | 'modelParams' | 'timeoutMs' | 'retryConfig'
-      >
-    } = {}
-  ) {
-    this._client = client
+  constructor(opts: {
+    openai: types.openai.OpenAIClient
+    verbosity?: number
+    defaults?: Pick<
+      types.BaseLLMOptions,
+      'provider' | 'model' | 'modelParams' | 'timeoutMs' | 'retryConfig'
+    >
+  }) {
+    this._client = opts.openai
     this._verbosity = opts.verbosity ?? 0
     this._defaults = {
       provider: 'openai',
-      model: 'gpt-3.5-turbo',
+      model: defaultOpenAIModel,
       modelParams: {},
       timeoutMs: 30000,
       retryConfig: {
@@ -54,7 +55,7 @@ export class Agentic {
       options = promptOrChatCompletionParams
 
       if (!options.messages) {
-        throw new Error()
+        throw new Error('messages must be provided')
       }
     }
 
@@ -68,7 +69,7 @@ export class Agentic {
 
 export abstract class BaseLLMCallBuilder<
   TInput extends ZodRawShape | ZodTypeAny = ZodTypeAny,
-  TOutput extends ZodRawShape | ZodTypeAny = ZodTypeAny,
+  TOutput extends ZodRawShape | ZodTypeAny = z.ZodType<string>,
   TModelParams extends Record<string, any> = Record<string, any>
 > {
   _options: types.BaseLLMOptions<TInput, TOutput, TModelParams>
@@ -118,7 +119,7 @@ export abstract class BaseLLMCallBuilder<
 
 export abstract class ChatModelBuilder<
   TInput extends ZodRawShape | ZodTypeAny = ZodTypeAny,
-  TOutput extends ZodRawShape | ZodTypeAny = ZodTypeAny,
+  TOutput extends ZodRawShape | ZodTypeAny = z.ZodType<string>,
   TModelParams extends Record<string, any> = Record<string, any>
 > extends BaseLLMCallBuilder<TInput, TOutput, TModelParams> {
   _messages: types.ChatMessage[]
@@ -132,11 +133,11 @@ export abstract class ChatModelBuilder<
 
 export class OpenAIChatModelBuilder<
   TInput extends ZodRawShape | ZodTypeAny = ZodTypeAny,
-  TOutput extends ZodRawShape | ZodTypeAny = ZodTypeAny
+  TOutput extends ZodRawShape | ZodTypeAny = z.ZodType<string>
 > extends ChatModelBuilder<
   TInput,
   TOutput,
-  Omit<types.openai.ChatCompletionParams, 'messages'>
+  SetRequired<Omit<types.openai.ChatCompletionParams, 'messages'>, 'model'>
 > {
   _client: types.openai.OpenAIClient
 
@@ -150,6 +151,7 @@ export class OpenAIChatModelBuilder<
   ) {
     super({
       provider: 'openai',
+      model: defaultOpenAIModel,
       ...options
     })
 
@@ -159,8 +161,26 @@ export class OpenAIChatModelBuilder<
   override async call(
     input?: types.ParsedData<TInput>
   ): Promise<types.ParsedData<TOutput>> {
-    // this._options.output?.describe
-    // TODO
-    return true as types.ParsedData<TOutput>
+    // TODO: construct messages
+
+    const completion = await this._client.createChatCompletion({
+      model: defaultOpenAIModel, // TODO: this shouldn't be necessary
+      ...this._options.modelParams,
+      messages: this._messages
+    })
+
+    if (this._options.output) {
+      const schema =
+        this._options.output instanceof z.ZodType
+          ? this._options.output
+          : z.object(this._options.output)
+
+      // TODO: convert string => object if necessary
+      // TODO: handle errors, retry logic, and self-healing
+
+      return schema.parse(completion.message.content)
+    } else {
+      return completion.message.content as any
+    }
   }
 }
diff --git a/src/temp.ts b/src/temp.ts
index c735131..3f455e4 100644
--- a/src/temp.ts
+++ b/src/temp.ts
@@ -5,9 +5,10 @@ import { z } from 'zod'
 import { Agentic } from './llm'
 
 dotenv.config()
+
 async function main() {
   const openai = new OpenAIClient({ apiKey: process.env.OPENAI_API_KEY! })
-  const $ = new Agentic(openai)
+  const $ = new Agentic({ openai })
 
   const ex0 = await $.gpt4(`give me a single boolean value`)
     .output(z.boolean())
@@ -14,6 +15,13 @@ async function main() {
     .call()
 
   console.log(ex0)
+
+  const ex1 = await $.gpt4(`give me fake data conforming to this schema`)
+    .output(z.object({ foo: z.string(), bar: z.number() }))
+    // .retry({ attempts: 3 })
+    .call()
+
+  console.log(ex1)
 }
 
 main()
diff --git a/src/types.ts b/src/types.ts
index d0c1720..d9ef18c 100644
--- a/src/types.ts
+++ b/src/types.ts
@@ -3,7 +3,6 @@ import {
   SafeParseReturnType,
   ZodObject,
   ZodRawShape,
-  ZodSchema,
   ZodTypeAny,
   output,
   z
@@ -27,7 +26,7 @@ export type SafeParsedData =
 
 export interface BaseLLMOptions<
   TInput extends ZodRawShape | ZodTypeAny = ZodTypeAny,
-  TOutput extends ZodRawShape | ZodTypeAny = ZodTypeAny,
+  TOutput extends ZodRawShape | ZodTypeAny = z.ZodType<string>,
   TModelParams extends Record<string, any> = Record<string, any>
 > {
   provider?: string
@@ -43,7 +42,7 @@ export interface LLMOptions<
   TInput extends ZodRawShape | ZodTypeAny = ZodTypeAny,
-  TOutput extends ZodRawShape | ZodTypeAny = ZodTypeAny,
+  TOutput extends ZodRawShape | ZodTypeAny = z.ZodType<string>,
   TModelParams extends Record<string, any> = Record<string, any>
 > extends BaseLLMOptions<TInput, TOutput, TModelParams> {
   promptTemplate?: string
   promptPrefix?: string
   promptSuffix?: string
@@ -50,6 +49,6 @@ export interface LLMOptions<
 }
 
-export type ChatMessageRole = 'user' | 'system' | 'assistant' | 'tool'
+export type ChatMessageRole = 'user' | 'system' | 'assistant'
 
 export interface ChatMessage {
   role: ChatMessageRole
@@ -57,8 +56,8 @@ export interface ChatMessage {
 }
 
 export interface ChatModelOptions<
   TInput extends ZodRawShape | ZodTypeAny = ZodTypeAny,
-  TOutput extends ZodRawShape | ZodTypeAny = ZodTypeAny,
+  TOutput extends ZodRawShape | ZodTypeAny = z.ZodType<string>,
   TModelParams extends Record<string, any> = Record<string, any>
 > extends BaseLLMOptions<TInput, TOutput, TModelParams> {
   messages: ChatMessage[]
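
The `// TODO: convert string => object if necessary` left in `OpenAIChatModelBuilder.call` could be addressed by attempting `JSON.parse` on the completion text before validating it with the zod schema. A minimal sketch under that assumption (`parseStructuredOutput` is a hypothetical helper, not part of this diff):

```ts
import { z } from 'zod'

// Hypothetical helper: coerce a raw LLM completion string into the value the
// zod schema expects, then validate it.
export function parseStructuredOutput<TSchema extends z.ZodTypeAny>(
  schema: TSchema,
  content: string
): z.infer<TSchema> {
  let value: unknown = content

  // The completion arrives as plain text; for anything other than a string
  // schema, try JSON.parse first (covers objects, arrays, booleans, numbers)
  // and fall back to the raw text if parsing fails.
  if (!(schema instanceof z.ZodString)) {
    try {
      value = JSON.parse(content.trim())
    } catch {
      // leave `value` as the raw text; schema.parse will report the mismatch
    }
  }

  return schema.parse(value)
}
```

Inside `call`, `schema.parse(completion.message.content)` would then become `parseStructuredOutput(schema, completion.message.content)`, and the retry/self-healing TODO could wrap that same call.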