feat: refactoring openai completions

old-agentic-v1^2
Travis Fischer 2023-05-26 12:56:20 -07:00
rodzic 4174d7cdbf
commit ab9de9a725
3 zmienionych plików z 192 dodań i 156 usunięć

Wyświetl plik

@ -1,7 +1,15 @@
import { jsonrepair } from 'jsonrepair'
import Mustache from 'mustache'
import { dedent } from 'ts-dedent'
import { ZodRawShape, ZodTypeAny, z } from 'zod'
import { printNode, zodToTs } from 'zod-to-ts'
import * as types from './types'
import { BaseTaskCallBuilder } from './task'
import {
extractJSONArrayFromString,
extractJSONObjectFromString
} from './utils'
export abstract class BaseLLMCallBuilder<
TInput extends ZodRawShape | ZodTypeAny = ZodTypeAny,
@ -59,10 +67,11 @@ export abstract class BaseLLMCallBuilder<
// }): Promise<TOutput>
}
export abstract class ChatModelBuilder<
export abstract class BaseChatModelBuilder<
TInput extends ZodRawShape | ZodTypeAny = ZodTypeAny,
TOutput extends ZodRawShape | ZodTypeAny = z.ZodType<string>,
TModelParams extends Record<string, any> = Record<string, any>
TModelParams extends Record<string, any> = Record<string, any>,
TChatCompletionResponse extends Record<string, any> = Record<string, any>
> extends BaseLLMCallBuilder<TInput, TOutput, TModelParams> {
_messages: types.ChatMessage[]
@ -71,4 +80,153 @@ export abstract class ChatModelBuilder<
this._messages = options.messages
}
/**
 * Provider-specific hook that performs the actual chat completion request.
 *
 * `call` invokes this with the fully assembled message list (rendered
 * prompt messages plus any example / output-format system messages) and
 * reads `message.content` from the resolved value.
 *
 * @param messages - The chat messages to send to the model.
 * @returns The completion message together with the provider's raw response.
 */
protected abstract _createChatCompletion(
messages: types.ChatMessage[]
): Promise<types.BaseChatCompletionResponse<TChatCompletionResponse>>
/**
 * Runs the chat completion for this builder and returns the parsed result.
 *
 * The method:
 *  1. validates `input` against the input schema (when one is set),
 *  2. renders every message template with Mustache, dropping messages whose
 *     rendered content is empty,
 *  3. appends one system message per example and a system message describing
 *     the expected output format (when an output schema is set),
 *  4. requests a completion via `_createChatCompletion`,
 *  5. coerces the raw text to the shape the output schema expects (JSON
 *     array/object, boolean, number) and validates it with zod.
 *
 * @param input - Template variables for the message contents.
 * @returns The validated output when an output schema is set; otherwise the
 * raw completion message content.
 * @throws Errors raised by zod's `parse` for invalid input/output, errors from
 * JSON repair/parsing, and `Error` for output that cannot be read as the
 * boolean or number the schema asks for.
 */
override async call(
  input?: types.ParsedData<TInput>
): Promise<types.ParsedData<TOutput>> {
  if (this._inputSchema) {
    const inputSchema =
      this._inputSchema instanceof z.ZodType
        ? this._inputSchema
        : z.object(this._inputSchema)

    // TODO: handle errors gracefully
    input = inputSchema.parse(input)
  }

  // TODO: validate input message variables against input schema
  const messages = this._messages
    .map((message) => {
      return {
        ...message,
        content: message.content
          ? Mustache.render(dedent(message.content), input).trim()
          : ''
      }
    })
    .filter((message) => message.content)

  if (this._examples?.length) {
    // TODO: smarter example selection
    for (const example of this._examples) {
      messages.push({
        role: 'system',
        content: `Example input: ${example.input}\n\nExample output: ${example.output}`
      })
    }
  }

  // Normalise the output schema once: a raw shape is wrapped in `z.object`.
  // The same value drives both the format instructions and the parsing below.
  const outputSchema = this._outputSchema
    ? this._outputSchema instanceof z.ZodType
      ? this._outputSchema
      : z.object(this._outputSchema)
    : undefined

  if (outputSchema) {
    const { node } = zodToTs(outputSchema)

    // NOTE(review): 152 is presumably the numeric TypeScript `SyntaxKind` of
    // the `string` keyword; that number can differ between TypeScript
    // versions — confirm against the installed compiler.
    if (node.kind === 152) {
      // handle raw strings differently
      messages.push({
        role: 'system',
        content: dedent`Output a raw string only, without any additional text.`
      })
    } else {
      // NOTE(review): as written, the first replace swaps one leading space
      // for one space (a no-op); confirm the intended indentation widths.
      const tsTypeString = printNode(node, {
        removeComments: false,
        // TODO: this doesn't seem to actually work, so we're doing it manually below
        omitTrailingSemicolon: true,
        noEmitHelpers: true
      })
        .replace(/^ /gm, ' ')
        .replace(/;$/gm, '')

      messages.push({
        role: 'system',
        content: dedent`Output JSON only in the following TypeScript format:
\`\`\`ts
${tsTypeString}
\`\`\``
      })
    }
  }

  // TODO: filter/compress messages based on token counts

  // Debug tracing of the request and the raw response.
  console.log('>>>')
  console.log(messages)

  const completion = await this._createChatCompletion(messages)
  let output: any = completion.message.content

  console.log('===')
  console.log(output)
  console.log('<<<')

  if (!outputSchema) {
    return output
  }

  if (outputSchema instanceof z.ZodArray) {
    // Prefer the extracted JSON array; fall back to the full text when the
    // extractor returns null/undefined. Parse errors propagate to the caller.
    const trimmedOutput = extractJSONArrayFromString(output)
    output = JSON.parse(jsonrepair(trimmedOutput ?? output))
  } else if (outputSchema instanceof z.ZodObject) {
    // Same strategy as for arrays, but extracting a JSON object.
    const trimmedOutput = extractJSONObjectFromString(output)
    output = JSON.parse(jsonrepair(trimmedOutput ?? output))
  } else if (outputSchema instanceof z.ZodBoolean) {
    // Coerce to a string first so missing content produces the descriptive
    // error below instead of a TypeError on `.toLowerCase()`.
    output = String(output ?? '')
      .toLowerCase()
      .trim()

    // A Map (rather than an object literal) so that model output such as
    // "constructor" or "__proto__" cannot match an inherited property.
    const booleanOutputs = new Map<string, boolean>([
      ['true', true],
      ['false', false],
      ['yes', true],
      ['no', false],
      ['1', true],
      ['0', false]
    ])

    const booleanOutput = booleanOutputs.get(output)
    if (booleanOutput === undefined) {
      // TODO
      throw new Error(`invalid boolean output: ${output}`)
    }

    output = booleanOutput
  } else if (outputSchema instanceof z.ZodNumber) {
    // Same string coercion as the boolean branch, for the same reason.
    output = String(output ?? '').trim()

    const numberOutput = outputSchema.isInt
      ? parseInt(output)
      : parseFloat(output)

    if (isNaN(numberOutput)) {
      // TODO
      throw new Error(`invalid number output: ${output}`)
    }

    output = numberOutput
  }

  // TODO: handle errors, retry logic, and self-healing
  return outputSchema.parse(output)
}
}

Wyświetl plik

@ -1,25 +1,18 @@
import { jsonrepair } from 'jsonrepair'
import Mustache from 'mustache'
import { dedent } from 'ts-dedent'
import type { SetRequired } from 'type-fest'
import { ZodRawShape, ZodTypeAny, z } from 'zod'
import { printNode, zodToTs } from 'zod-to-ts'
import * as types from './types'
import { defaultOpenAIModel } from './constants'
import { ChatModelBuilder } from './llm'
import {
extractJSONArrayFromString,
extractJSONObjectFromString
} from './utils'
import { BaseChatModelBuilder } from './llm'
export class OpenAIChatModelBuilder<
TInput extends ZodRawShape | ZodTypeAny = ZodTypeAny,
TOutput extends ZodRawShape | ZodTypeAny = z.ZodType<string>
> extends ChatModelBuilder<
> extends BaseChatModelBuilder<
TInput,
TOutput,
SetRequired<Omit<types.openai.ChatCompletionParams, 'messages'>, 'model'>
SetRequired<Omit<types.openai.ChatCompletionParams, 'messages'>, 'model'>,
types.openai.ChatCompletionResponse
> {
_client: types.openai.OpenAIClient
@ -40,151 +33,20 @@ export class OpenAIChatModelBuilder<
this._client = client
}
override async call(
input?: types.ParsedData<TInput>
): Promise<types.ParsedData<TOutput>> {
if (this._inputSchema) {
const inputSchema =
this._inputSchema instanceof z.ZodType
? this._inputSchema
: z.object(this._inputSchema)
// TODO: handle errors gracefully
input = inputSchema.parse(input)
}
// TODO: validate input message variables against input schema
const messages = this._messages
.map((message) => {
return {
...message,
content: message.content
? Mustache.render(dedent(message.content), input).trim()
: ''
}
})
.filter((message) => message.content)
if (this._examples?.length) {
// TODO: smarter example selection
for (const example of this._examples) {
messages.push({
role: 'system',
content: `Example input: ${example.input}\n\nExample output: ${example.output}`
})
}
}
if (this._outputSchema) {
const outputSchema =
this._outputSchema instanceof z.ZodType
? this._outputSchema
: z.object(this._outputSchema)
const { node } = zodToTs(outputSchema)
if (node.kind === 152) {
// handle raw strings differently
messages.push({
role: 'system',
content: dedent`Output a raw string only, without any additional text.`
})
} else {
const tsTypeString = printNode(node, {
removeComments: false,
// TODO: this doesn't seem to actually work, so we're doing it manually below
omitTrailingSemicolon: true,
noEmitHelpers: true
})
.replace(/^ /gm, ' ')
.replace(/;$/gm, '')
messages.push({
role: 'system',
content: dedent`Output JSON only in the following TypeScript format:
\`\`\`ts
${tsTypeString}
\`\`\``
})
}
}
// TODO: filter/compress messages based on token counts
console.log('>>>')
console.log(messages)
const completion = await this._client.createChatCompletion({
model: defaultOpenAIModel, // TODO: this shouldn't be necessary but TS is complaining
...this._outputSchema,
protected override async _createChatCompletion(
messages: types.ChatMessage[]
): Promise<
types.BaseChatCompletionResponse<types.openai.ChatCompletionResponse>
> {
const response = await this._client.createChatCompletion({
model: this._model,
...this._modelParams,
messages
})
if (this._outputSchema) {
const outputSchema =
this._outputSchema instanceof z.ZodType
? this._outputSchema
: z.object(this._outputSchema)
let output: any = completion.message.content
console.log('===')
console.log(output)
console.log('<<<')
if (outputSchema instanceof z.ZodArray) {
try {
const trimmedOutput = extractJSONArrayFromString(output)
output = JSON.parse(jsonrepair(trimmedOutput ?? output))
} catch (err) {
// TODO
throw err
}
} else if (outputSchema instanceof z.ZodObject) {
try {
const trimmedOutput = extractJSONObjectFromString(output)
output = JSON.parse(jsonrepair(trimmedOutput ?? output))
} catch (err) {
// TODO
throw err
}
} else if (outputSchema instanceof z.ZodBoolean) {
output = output.toLowerCase().trim()
const booleanOutputs = {
true: true,
false: false,
yes: true,
no: false,
1: true,
0: false
}
const booleanOutput = booleanOutputs[output]
if (booleanOutput !== undefined) {
output = booleanOutput
} else {
// TODO
throw new Error(`invalid boolean output: ${output}`)
}
} else if (outputSchema instanceof z.ZodNumber) {
output = output.trim()
const numberOutput = outputSchema.isInt
? parseInt(output)
: parseFloat(output)
if (isNaN(numberOutput)) {
// TODO
throw new Error(`invalid number output: ${output}`)
} else {
output = numberOutput
}
}
// TODO: handle errors, retry logic, and self-healing
return outputSchema.parse(output)
} else {
return completion.message.content as any
return {
message: response.message,
response: response.response
}
}
}

Wyświetl plik

@ -61,7 +61,13 @@ export interface LLMOptions<
promptSuffix?: string
}
export type ChatMessageRole = 'user' | 'system' | 'assistant'
// export type ChatMessageRole = 'user' | 'system' | 'assistant'
/**
 * Zod schema accepting exactly the chat message roles `'user'`, `'system'`
 * and `'assistant'`.
 */
export const ChatMessageRoleSchema = z.union([
z.literal('user'),
z.literal('system'),
z.literal('assistant')
])
/**
 * Static role type inferred from `ChatMessageRoleSchema`, so the type and the
 * runtime schema are defined in one place.
 */
export type ChatMessageRole = z.infer<typeof ChatMessageRoleSchema>
export interface ChatMessage {
role: ChatMessageRole
@ -76,6 +82,16 @@ export interface ChatModelOptions<
messages: ChatMessage[]
}
/**
 * Result of a chat completion request: the completion message paired with
 * the provider's raw response object.
 *
 * @typeParam TChatCompletionResponse - Shape of the provider-specific raw
 * response; defaults to an arbitrary string-keyed record.
 */
export interface BaseChatCompletionResponse<
TChatCompletionResponse extends Record<string, any> = Record<string, any>
> {
/** The completion message. */
message: ChatMessage
/** The raw response from the LLM provider. */
response: TChatCompletionResponse
}
export interface LLMExample {
input: string
output: string