feat: refactoring openai completions

old-agentic-v1^2
Travis Fischer 2023-05-26 12:56:20 -07:00
rodzic 4174d7cdbf
commit ab9de9a725
3 zmienionych plików z 192 dodań i 156 usunięć

Wyświetl plik

@ -1,7 +1,15 @@
import { jsonrepair } from 'jsonrepair'
import Mustache from 'mustache'
import { dedent } from 'ts-dedent'
import { ZodRawShape, ZodTypeAny, z } from 'zod' import { ZodRawShape, ZodTypeAny, z } from 'zod'
import { printNode, zodToTs } from 'zod-to-ts'
import * as types from './types' import * as types from './types'
import { BaseTaskCallBuilder } from './task' import { BaseTaskCallBuilder } from './task'
import {
extractJSONArrayFromString,
extractJSONObjectFromString
} from './utils'
export abstract class BaseLLMCallBuilder< export abstract class BaseLLMCallBuilder<
TInput extends ZodRawShape | ZodTypeAny = ZodTypeAny, TInput extends ZodRawShape | ZodTypeAny = ZodTypeAny,
@ -59,10 +67,11 @@ export abstract class BaseLLMCallBuilder<
// }): Promise<TOutput> // }): Promise<TOutput>
} }
export abstract class ChatModelBuilder< export abstract class BaseChatModelBuilder<
TInput extends ZodRawShape | ZodTypeAny = ZodTypeAny, TInput extends ZodRawShape | ZodTypeAny = ZodTypeAny,
TOutput extends ZodRawShape | ZodTypeAny = z.ZodType<string>, TOutput extends ZodRawShape | ZodTypeAny = z.ZodType<string>,
TModelParams extends Record<string, any> = Record<string, any> TModelParams extends Record<string, any> = Record<string, any>,
TChatCompletionResponse extends Record<string, any> = Record<string, any>
> extends BaseLLMCallBuilder<TInput, TOutput, TModelParams> { > extends BaseLLMCallBuilder<TInput, TOutput, TModelParams> {
_messages: types.ChatMessage[] _messages: types.ChatMessage[]
@ -71,4 +80,153 @@ export abstract class ChatModelBuilder<
this._messages = options.messages this._messages = options.messages
} }
/**
 * Provider-specific hook that performs the actual chat completion request.
 *
 * `call()` awaits this method with the fully prepared message list and then
 * reads `message.content` from the resolved value as the raw model output.
 *
 * @param messages - The chat messages to send to the model.
 * @returns The completion message together with the provider's raw response.
 */
protected abstract _createChatCompletion(
messages: types.ChatMessage[]
): Promise<types.BaseChatCompletionResponse<TChatCompletionResponse>>
/**
 * Renders the configured chat messages, requests a completion via
 * `_createChatCompletion`, and coerces the raw text reply into the shape
 * described by the output schema.
 *
 * Steps:
 * 1. If an input schema is set, `input` is parsed with it.
 * 2. Each message template is dedented, rendered with Mustache against
 *    `input`, trimmed, and dropped if it ends up empty.
 * 3. Every configured example is appended as a `system` message.
 * 4. If an output schema is set, a `system` message describing the expected
 *    output format is appended.
 * 5. The completion's `message.content` is converted according to the output
 *    schema (array / object / boolean / number) and validated with it.
 *
 * @param input - Values used to render the message templates.
 * @returns The schema-parsed output, or the raw completion content when no
 *   output schema is configured.
 * @throws Error when the model reply cannot be read as a boolean or number
 *   for a boolean/number output schema. Errors raised by the schemas'
 *   `parse`, by `jsonrepair`, or by `JSON.parse` propagate unchanged.
 */
override async call(
  input?: types.ParsedData<TInput>
): Promise<types.ParsedData<TOutput>> {
  if (this._inputSchema) {
    const inputSchema =
      this._inputSchema instanceof z.ZodType
        ? this._inputSchema
        : z.object(this._inputSchema)

    // TODO: handle errors gracefully
    input = inputSchema.parse(input)
  }

  // Normalise the output schema once; a raw shape is wrapped in z.object().
  const outputSchema = this._outputSchema
    ? this._outputSchema instanceof z.ZodType
      ? this._outputSchema
      : z.object(this._outputSchema)
    : undefined

  // TODO: validate input message variables against input schema
  const messages = this._messages
    .map((message) => {
      return {
        ...message,
        content: message.content
          ? Mustache.render(dedent(message.content), input).trim()
          : ''
      }
    })
    // Messages whose rendered content is empty are not sent.
    .filter((message) => message.content)

  if (this._examples?.length) {
    // TODO: smarter example selection
    for (const example of this._examples) {
      messages.push({
        role: 'system',
        content: `Example input: ${example.input}\n\nExample output: ${example.output}`
      })
    }
  }

  if (outputSchema) {
    const { node } = zodToTs(outputSchema)

    // NOTE(review): 152 is presumably the numeric TypeScript SyntaxKind of
    // the `string` keyword for the compiler version in use — TODO confirm.
    if (node.kind === 152) {
      // handle raw strings differently
      messages.push({
        role: 'system',
        content: dedent`Output a raw string only, without any additional text.`
      })
    } else {
      // NOTE(review): as written, the first replace swaps one leading space
      // for one space; confirm the intended indentation widths.
      const tsTypeString = printNode(node, {
        removeComments: false,
        // TODO: this doesn't seem to actually work, so we're doing it manually below
        omitTrailingSemicolon: true,
        noEmitHelpers: true
      })
        .replace(/^ /gm, ' ')
        .replace(/;$/gm, '')

      // `dedent` strips the indentation shared by the continuation lines.
      messages.push({
        role: 'system',
        content: dedent`Output JSON only in the following TypeScript format:
          \`\`\`ts
          ${tsTypeString}
          \`\`\``
      })
    }
  }

  // TODO: filter/compress messages based on token counts

  console.log('>>>')
  console.log(messages)

  const completion = await this._createChatCompletion(messages)
  let output: any = completion.message.content

  console.log('===')
  console.log(output)
  console.log('<<<')

  if (!outputSchema) {
    return output
  }

  if (outputSchema instanceof z.ZodArray) {
    // Prefer the extracted JSON array; fall back to the whole reply.
    const trimmedOutput = extractJSONArrayFromString(output)
    output = JSON.parse(jsonrepair(trimmedOutput ?? output))
  } else if (outputSchema instanceof z.ZodObject) {
    // Prefer the extracted JSON object; fall back to the whole reply.
    const trimmedOutput = extractJSONObjectFromString(output)
    output = JSON.parse(jsonrepair(trimmedOutput ?? output))
  } else if (outputSchema instanceof z.ZodBoolean) {
    output = output.toLowerCase().trim()

    const booleanOutputs: Record<string, boolean> = {
      true: true,
      false: false,
      yes: true,
      no: false,
      1: true,
      0: false
    }

    // Own-property check so replies such as "constructor" or "tostring"
    // cannot resolve to members inherited from Object.prototype.
    if (Object.prototype.hasOwnProperty.call(booleanOutputs, output)) {
      output = booleanOutputs[output]
    } else {
      // TODO
      throw new Error(`invalid boolean output: ${output}`)
    }
  } else if (outputSchema instanceof z.ZodNumber) {
    output = output.trim()

    // Explicit radix: integer replies are always read as decimal.
    const numberOutput = outputSchema.isInt
      ? parseInt(output, 10)
      : parseFloat(output)

    if (isNaN(numberOutput)) {
      // TODO
      throw new Error(`invalid number output: ${output}`)
    }

    output = numberOutput
  }

  // TODO: handle errors, retry logic, and self-healing
  return outputSchema.parse(output)
}
} }

Wyświetl plik

@ -1,25 +1,18 @@
import { jsonrepair } from 'jsonrepair'
import Mustache from 'mustache'
import { dedent } from 'ts-dedent'
import type { SetRequired } from 'type-fest' import type { SetRequired } from 'type-fest'
import { ZodRawShape, ZodTypeAny, z } from 'zod' import { ZodRawShape, ZodTypeAny, z } from 'zod'
import { printNode, zodToTs } from 'zod-to-ts'
import * as types from './types' import * as types from './types'
import { defaultOpenAIModel } from './constants' import { defaultOpenAIModel } from './constants'
import { ChatModelBuilder } from './llm' import { BaseChatModelBuilder } from './llm'
import {
extractJSONArrayFromString,
extractJSONObjectFromString
} from './utils'
export class OpenAIChatModelBuilder< export class OpenAIChatModelBuilder<
TInput extends ZodRawShape | ZodTypeAny = ZodTypeAny, TInput extends ZodRawShape | ZodTypeAny = ZodTypeAny,
TOutput extends ZodRawShape | ZodTypeAny = z.ZodType<string> TOutput extends ZodRawShape | ZodTypeAny = z.ZodType<string>
> extends ChatModelBuilder< > extends BaseChatModelBuilder<
TInput, TInput,
TOutput, TOutput,
SetRequired<Omit<types.openai.ChatCompletionParams, 'messages'>, 'model'> SetRequired<Omit<types.openai.ChatCompletionParams, 'messages'>, 'model'>,
types.openai.ChatCompletionResponse
> { > {
_client: types.openai.OpenAIClient _client: types.openai.OpenAIClient
@ -40,151 +33,20 @@ export class OpenAIChatModelBuilder<
this._client = client this._client = client
} }
override async call( protected override async _createChatCompletion(
input?: types.ParsedData<TInput> messages: types.ChatMessage[]
): Promise<types.ParsedData<TOutput>> { ): Promise<
if (this._inputSchema) { types.BaseChatCompletionResponse<types.openai.ChatCompletionResponse>
const inputSchema = > {
this._inputSchema instanceof z.ZodType const response = await this._client.createChatCompletion({
? this._inputSchema model: this._model,
: z.object(this._inputSchema) ...this._modelParams,
// TODO: handle errors gracefully
input = inputSchema.parse(input)
}
// TODO: validate input message variables against input schema
const messages = this._messages
.map((message) => {
return {
...message,
content: message.content
? Mustache.render(dedent(message.content), input).trim()
: ''
}
})
.filter((message) => message.content)
if (this._examples?.length) {
// TODO: smarter example selection
for (const example of this._examples) {
messages.push({
role: 'system',
content: `Example input: ${example.input}\n\nExample output: ${example.output}`
})
}
}
if (this._outputSchema) {
const outputSchema =
this._outputSchema instanceof z.ZodType
? this._outputSchema
: z.object(this._outputSchema)
const { node } = zodToTs(outputSchema)
if (node.kind === 152) {
// handle raw strings differently
messages.push({
role: 'system',
content: dedent`Output a raw string only, without any additional text.`
})
} else {
const tsTypeString = printNode(node, {
removeComments: false,
// TODO: this doesn't seem to actually work, so we're doing it manually below
omitTrailingSemicolon: true,
noEmitHelpers: true
})
.replace(/^ /gm, ' ')
.replace(/;$/gm, '')
messages.push({
role: 'system',
content: dedent`Output JSON only in the following TypeScript format:
\`\`\`ts
${tsTypeString}
\`\`\``
})
}
}
// TODO: filter/compress messages based on token counts
console.log('>>>')
console.log(messages)
const completion = await this._client.createChatCompletion({
model: defaultOpenAIModel, // TODO: this shouldn't be necessary but TS is complaining
...this._outputSchema,
messages messages
}) })
if (this._outputSchema) { return {
const outputSchema = message: response.message,
this._outputSchema instanceof z.ZodType response: response.response
? this._outputSchema
: z.object(this._outputSchema)
let output: any = completion.message.content
console.log('===')
console.log(output)
console.log('<<<')
if (outputSchema instanceof z.ZodArray) {
try {
const trimmedOutput = extractJSONArrayFromString(output)
output = JSON.parse(jsonrepair(trimmedOutput ?? output))
} catch (err) {
// TODO
throw err
}
} else if (outputSchema instanceof z.ZodObject) {
try {
const trimmedOutput = extractJSONObjectFromString(output)
output = JSON.parse(jsonrepair(trimmedOutput ?? output))
} catch (err) {
// TODO
throw err
}
} else if (outputSchema instanceof z.ZodBoolean) {
output = output.toLowerCase().trim()
const booleanOutputs = {
true: true,
false: false,
yes: true,
no: false,
1: true,
0: false
}
const booleanOutput = booleanOutputs[output]
if (booleanOutput !== undefined) {
output = booleanOutput
} else {
// TODO
throw new Error(`invalid boolean output: ${output}`)
}
} else if (outputSchema instanceof z.ZodNumber) {
output = output.trim()
const numberOutput = outputSchema.isInt
? parseInt(output)
: parseFloat(output)
if (isNaN(numberOutput)) {
// TODO
throw new Error(`invalid number output: ${output}`)
} else {
output = numberOutput
}
}
// TODO: handle errors, retry logic, and self-healing
return outputSchema.parse(output)
} else {
return completion.message.content as any
} }
} }
} }

Wyświetl plik

@ -61,7 +61,13 @@ export interface LLMOptions<
promptSuffix?: string promptSuffix?: string
} }
export type ChatMessageRole = 'user' | 'system' | 'assistant' // export type ChatMessageRole = 'user' | 'system' | 'assistant'
/**
 * Zod schema for the role of a chat message: one of the literals
 * `'user'`, `'system'`, or `'assistant'`.
 */
export const ChatMessageRoleSchema = z.union([
z.literal('user'),
z.literal('system'),
z.literal('assistant')
])
/** Role of a chat message, inferred from `ChatMessageRoleSchema`. */
export type ChatMessageRole = z.infer<typeof ChatMessageRoleSchema>
export interface ChatMessage { export interface ChatMessage {
role: ChatMessageRole role: ChatMessageRole
@ -76,6 +82,16 @@ export interface ChatModelOptions<
messages: ChatMessage[] messages: ChatMessage[]
} }
/**
 * Result of a chat completion: the completion message plus the raw response
 * from the LLM provider.
 *
 * @typeParam TChatCompletionResponse - Type of the provider's raw response.
 */
export interface BaseChatCompletionResponse<
TChatCompletionResponse extends Record<string, any> = Record<string, any>
> {
/** The completion message. */
message: ChatMessage
/** The raw response from the LLM provider. */
response: TChatCompletionResponse
}
export interface LLMExample { export interface LLMExample {
input: string input: string
output: string output: string