chatgpt-api/src/llm.ts

267 wiersze
7.1 KiB
TypeScript
Czysty Zwykły widok Historia

2023-05-24 06:15:59 +00:00
import { jsonrepair } from 'jsonrepair'
2023-05-24 03:22:50 +00:00
import Mustache from 'mustache'
2023-05-24 05:28:38 +00:00
import { dedent } from 'ts-dedent'
2023-05-04 19:48:28 +00:00
import type { SetRequired } from 'type-fest'
2023-05-04 03:54:32 +00:00
import { ZodRawShape, ZodTypeAny, z } from 'zod'
2023-05-24 03:22:50 +00:00
import { printNode, zodToTs } from 'zod-to-ts'
import * as types from './types'
2023-05-24 06:15:59 +00:00
import {
extractJSONArrayFromString,
extractJSONObjectFromString
} from './utils'
2023-05-04 19:48:28 +00:00
/** OpenAI chat model used as the fallback when no other model is configured. */
const defaultOpenAIModel = 'gpt-3.5-turbo'
/**
 * Entry point for building LLM calls against OpenAI.
 *
 * Holds the OpenAI client plus a set of default call options (provider,
 * model, model params, timeout, retry config) that every builder created by
 * this instance starts from.
 */
export class Agentic {
  _client: types.openai.OpenAIClient
  _verbosity: number
  _defaults: Pick<
    types.BaseLLMOptions,
    'provider' | 'model' | 'modelParams' | 'timeoutMs' | 'retryConfig'
  >

  /**
   * @param opts.openai - client handed to every builder this instance creates
   * @param opts.verbosity - stored as `_verbosity`; defaults to 0
   * @param opts.defaults - overrides for the built-in defaults (provider
   *   'openai', `defaultOpenAIModel`, empty `modelParams`, `timeoutMs` 30000,
   *   retry `{ attempts: 3, strategy: 'heal' }`). A partial `retryConfig` is
   *   merged into the default retry config instead of replacing it.
   */
  constructor(opts: {
    openai: types.openai.OpenAIClient
    verbosity?: number
    defaults?: Pick<
      types.BaseLLMOptions,
      'provider' | 'model' | 'modelParams' | 'timeoutMs' | 'retryConfig'
    >
  }) {
    this._client = opts.openai
    this._verbosity = opts.verbosity ?? 0
    this._defaults = {
      provider: 'openai',
      model: defaultOpenAIModel,
      modelParams: {},
      timeoutMs: 30000,
      ...opts.defaults,
      // Placed after `...opts.defaults` so a caller's (partial) retryConfig is
      // merged with, rather than substituted for, the default attempts/strategy.
      retryConfig: {
        attempts: 3,
        strategy: 'heal',
        ...opts.defaults?.retryConfig
      }
    }
  }

  /**
   * Creates a chat builder targeting GPT-4.
   *
   * @param promptOrChatCompletionParams - either a single user prompt, or chat
   *   completion params that include `messages`. All other params (e.g.
   *   `temperature`, `max_tokens`, or an explicit `model`, which overrides
   *   'gpt-4') are forwarded to the API via the builder's `modelParams`.
   * @throws Error if a params object without `messages` is given
   */
  gpt4(
    promptOrChatCompletionParams: string | types.openai.ChatCompletionParams
  ) {
    let options: Omit<types.openai.ChatCompletionParams, 'model'>
    if (typeof promptOrChatCompletionParams === 'string') {
      options = {
        messages: [
          {
            role: 'user',
            content: promptOrChatCompletionParams
          }
        ]
      }
    } else {
      options = promptOrChatCompletionParams
      if (!options.messages) {
        throw new Error('messages must be provided')
      }
    }

    // The builder only forwards `modelParams` to the API, so request params
    // must live there; spreading them at the top level left them unused.
    const { messages, ...params } = options
    const modelParams = {
      ...this._defaults.modelParams,
      model: 'gpt-4',
      ...params
    }

    return new OpenAIChatModelBuilder(this._client, {
      ...(this._defaults as any), // TODO
      model: modelParams.model,
      modelParams,
      messages
    })
  }
}
2023-05-04 03:54:32 +00:00
/**
 * Fluent base class describing a single LLM call.
 *
 * Every configuration method writes into `_options` in place and hands back
 * the same instance (re-typed when the input/output schema changes), so calls
 * can be chained. Concrete subclasses implement `call`.
 */
export abstract class BaseLLMCallBuilder<
  TInput extends ZodRawShape | ZodTypeAny = ZodTypeAny,
  TOutput extends ZodRawShape | ZodTypeAny = z.ZodType<string>,
  TModelParams extends Record<string, any> = Record<string, any>
> {
  _options: types.BaseLLMOptions<TInput, TOutput, TModelParams>

  constructor(options: types.BaseLLMOptions<TInput, TOutput, TModelParams>) {
    this._options = options
  }

  /** Sets the input schema and returns this builder typed for that input. */
  input<U extends ZodRawShape | ZodTypeAny = TInput>(
    inputSchema: U
  ): BaseLLMCallBuilder<U, TOutput, TModelParams> {
    const retyped = this as unknown as BaseLLMCallBuilder<
      U,
      TOutput,
      TModelParams
    >
    retyped._options.input = inputSchema
    return retyped
  }

  /** Sets the output schema and returns this builder typed for that output. */
  output<U extends ZodRawShape | ZodTypeAny = TOutput>(
    outputSchema: U
  ): BaseLLMCallBuilder<TInput, U, TModelParams> {
    const retyped = this as unknown as BaseLLMCallBuilder<
      TInput,
      U,
      TModelParams
    >
    retyped._options.output = outputSchema
    return retyped
  }

  /** Replaces the configured examples with `examples`. */
  examples(examples: types.LLMExample[]): this {
    this._options.examples = examples
    return this
  }

  /** Replaces the whole retry configuration (no merging with prior values). */
  retry(retryConfig: types.LLMRetryConfig): this {
    this._options.retryConfig = retryConfig
    return this
  }

  /** Runs the call for `input`; implemented by concrete subclasses. */
  abstract call(
    input?: types.ParsedData<TInput>
  ): Promise<types.ParsedData<TOutput>>

  // TODO
  // abstract stream({
  //   input: TInput,
  //   onProgress: types.ProgressFunction
  // }): Promise<TOutput>
}
/**
 * Base builder for chat-style LLM calls. On top of the shared call options it
 * keeps the list of chat messages that make up the prompt.
 */
export abstract class ChatModelBuilder<
  TInput extends ZodRawShape | ZodTypeAny = ZodTypeAny,
  TOutput extends ZodRawShape | ZodTypeAny = z.ZodType<string>,
  TModelParams extends Record<string, any> = Record<string, any>
> extends BaseLLMCallBuilder<TInput, TOutput, TModelParams> {
  /** Chat messages supplied via the constructor options. */
  _messages: types.ChatMessage[]

  constructor(options: types.ChatModelOptions<TInput, TOutput, TModelParams>) {
    super(options)
    const { messages } = options
    this._messages = messages
  }
}
2023-05-04 03:54:32 +00:00
/**
 * Chat builder backed by the OpenAI chat completion API.
 *
 * `call()`:
 *  1. validates `input` against the input schema, if one is set;
 *  2. renders each message as a Mustache template against `input`, dropping
 *     messages whose rendered content is empty;
 *  3. when an output schema is set, appends a system message describing the
 *     expected JSON shape as a TypeScript type;
 *  4. sends the messages; with an output schema it extracts/repairs the JSON
 *     in the reply and validates it, otherwise it returns the raw text.
 */
export class OpenAIChatModelBuilder<
  TInput extends ZodRawShape | ZodTypeAny = ZodTypeAny,
  TOutput extends ZodRawShape | ZodTypeAny = z.ZodType<string>
> extends ChatModelBuilder<
  TInput,
  TOutput,
  SetRequired<Omit<types.openai.ChatCompletionParams, 'messages'>, 'model'>
> {
  _client: types.openai.OpenAIClient

  constructor(
    client: types.openai.OpenAIClient,
    options: types.ChatModelOptions<
      TInput,
      TOutput,
      Omit<types.openai.ChatCompletionParams, 'messages'>
    >
  ) {
    super({
      provider: 'openai',
      model: defaultOpenAIModel,
      ...options
    })
    this._client = client
  }

  override async call(
    input?: types.ParsedData<TInput>
  ): Promise<types.ParsedData<TOutput>> {
    if (this._options.input) {
      const inputSchema = this._toZodSchema(this._options.input)
      // TODO: handle errors gracefully
      input = inputSchema.parse(input)
    }

    // TODO: validate input message variables against input schema
    const messages = this._renderMessages(input)

    const outputSchema = this._options.output
      ? this._toZodSchema(this._options.output)
      : undefined

    if (outputSchema) {
      messages.push({
        role: 'system',
        content: dedent`Output JSON only in the following format:
          \`\`\`ts
          ${this._printOutputType(outputSchema)}
          \`\`\``
      })
    }

    // TODO: filter/compress messages based on token counts

    const completion = await this._client.createChatCompletion({
      // Honor the builder's configured model (e.g. 'gpt-4' set by
      // `Agentic.gpt4`); an explicit `modelParams.model` still takes precedence.
      model: this._options.model ?? defaultOpenAIModel,
      ...this._options.modelParams,
      messages
    })

    if (!outputSchema) {
      return completion.message.content as any
    }

    // Models often wrap JSON in prose or code fences: for array/object schemas
    // pull out the JSON payload and repair minor syntax errors before parsing.
    let output: any = completion.message.content
    const extractJSON =
      outputSchema instanceof z.ZodArray
        ? extractJSONArrayFromString
        : outputSchema instanceof z.ZodObject
        ? extractJSONObjectFromString
        : undefined
    if (extractJSON) {
      output = JSON.parse(jsonrepair(extractJSON(output) ?? output))
    }

    // TODO: handle errors, retry logic, and self-healing
    return outputSchema.parse(output)
  }

  /** Normalizes either a zod schema or a raw zod object shape into a schema. */
  _toZodSchema(schema: ZodRawShape | ZodTypeAny): ZodTypeAny {
    return schema instanceof z.ZodType ? schema : z.object(schema)
  }

  /**
   * Renders every message's content as a Mustache template against `input`
   * (after dedenting and trimming) and drops messages that end up empty.
   */
  _renderMessages(input?: types.ParsedData<TInput>) {
    return this._messages
      .map((message) => ({
        ...message,
        content: message.content
          ? Mustache.render(dedent(message.content), input, undefined, {
              // Mustache HTML-escapes `{{var}}` by default, which would turn
              // quotes, `&`, `<`, etc. in prompt values into HTML entities.
              // Prompts are plain text, so interpolate values verbatim.
              escape: (value: unknown) => String(value)
            }).trim()
          : ''
      }))
      .filter((message) => message.content)
  }

  /** Prints a zod schema as a compact TypeScript type for the prompt. */
  _printOutputType(outputSchema: ZodTypeAny): string {
    const { node } = zodToTs(outputSchema)
    return (
      printNode(node, {
        removeComments: true,
        // TODO: this doesn't seem to actually work, so we're doing it manually below
        omitTrailingSemicolon: true,
        noEmitHelpers: true
      })
        // The TS printer indents 4 spaces per level; halve it to save tokens.
        .replace(/^(?: {4})+/gm, (indent) => ' '.repeat(indent.length / 2))
        .replace(/;$/gm, '')
    )
  }
}