kopia lustrzana https://github.com/transitive-bullshit/chatgpt-api
feat: refactoring openai completions
rodzic
4174d7cdbf
commit
ab9de9a725
162
src/llm.ts
162
src/llm.ts
|
@ -1,7 +1,15 @@
|
|||
import { jsonrepair } from 'jsonrepair'
|
||||
import Mustache from 'mustache'
|
||||
import { dedent } from 'ts-dedent'
|
||||
import { ZodRawShape, ZodTypeAny, z } from 'zod'
|
||||
import { printNode, zodToTs } from 'zod-to-ts'
|
||||
|
||||
import * as types from './types'
|
||||
import { BaseTaskCallBuilder } from './task'
|
||||
import {
|
||||
extractJSONArrayFromString,
|
||||
extractJSONObjectFromString
|
||||
} from './utils'
|
||||
|
||||
export abstract class BaseLLMCallBuilder<
|
||||
TInput extends ZodRawShape | ZodTypeAny = ZodTypeAny,
|
||||
|
@ -59,10 +67,11 @@ export abstract class BaseLLMCallBuilder<
|
|||
// }): Promise<TOutput>
|
||||
}
|
||||
|
||||
export abstract class ChatModelBuilder<
|
||||
export abstract class BaseChatModelBuilder<
|
||||
TInput extends ZodRawShape | ZodTypeAny = ZodTypeAny,
|
||||
TOutput extends ZodRawShape | ZodTypeAny = z.ZodType<string>,
|
||||
TModelParams extends Record<string, any> = Record<string, any>
|
||||
TModelParams extends Record<string, any> = Record<string, any>,
|
||||
TChatCompletionResponse extends Record<string, any> = Record<string, any>
|
||||
> extends BaseLLMCallBuilder<TInput, TOutput, TModelParams> {
|
||||
_messages: types.ChatMessage[]
|
||||
|
||||
|
@ -71,4 +80,153 @@ export abstract class ChatModelBuilder<
|
|||
|
||||
this._messages = options.messages
|
||||
}
|
||||
|
||||
protected abstract _createChatCompletion(
|
||||
messages: types.ChatMessage[]
|
||||
): Promise<types.BaseChatCompletionResponse<TChatCompletionResponse>>
|
||||
|
||||
override async call(
|
||||
input?: types.ParsedData<TInput>
|
||||
): Promise<types.ParsedData<TOutput>> {
|
||||
if (this._inputSchema) {
|
||||
const inputSchema =
|
||||
this._inputSchema instanceof z.ZodType
|
||||
? this._inputSchema
|
||||
: z.object(this._inputSchema)
|
||||
|
||||
// TODO: handle errors gracefully
|
||||
input = inputSchema.parse(input)
|
||||
}
|
||||
|
||||
// TODO: validate input message variables against input schema
|
||||
|
||||
const messages = this._messages
|
||||
.map((message) => {
|
||||
return {
|
||||
...message,
|
||||
content: message.content
|
||||
? Mustache.render(dedent(message.content), input).trim()
|
||||
: ''
|
||||
}
|
||||
})
|
||||
.filter((message) => message.content)
|
||||
|
||||
if (this._examples?.length) {
|
||||
// TODO: smarter example selection
|
||||
for (const example of this._examples) {
|
||||
messages.push({
|
||||
role: 'system',
|
||||
content: `Example input: ${example.input}\n\nExample output: ${example.output}`
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
if (this._outputSchema) {
|
||||
const outputSchema =
|
||||
this._outputSchema instanceof z.ZodType
|
||||
? this._outputSchema
|
||||
: z.object(this._outputSchema)
|
||||
|
||||
const { node } = zodToTs(outputSchema)
|
||||
|
||||
if (node.kind === 152) {
|
||||
// handle raw strings differently
|
||||
messages.push({
|
||||
role: 'system',
|
||||
content: dedent`Output a raw string only, without any additional text.`
|
||||
})
|
||||
} else {
|
||||
const tsTypeString = printNode(node, {
|
||||
removeComments: false,
|
||||
// TODO: this doesn't seem to actually work, so we're doing it manually below
|
||||
omitTrailingSemicolon: true,
|
||||
noEmitHelpers: true
|
||||
})
|
||||
.replace(/^ /gm, ' ')
|
||||
.replace(/;$/gm, '')
|
||||
|
||||
messages.push({
|
||||
role: 'system',
|
||||
content: dedent`Output JSON only in the following TypeScript format:
|
||||
\`\`\`ts
|
||||
${tsTypeString}
|
||||
\`\`\``
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: filter/compress messages based on token counts
|
||||
|
||||
console.log('>>>')
|
||||
console.log(messages)
|
||||
|
||||
const completion = await this._createChatCompletion(messages)
|
||||
let output: any = completion.message.content
|
||||
|
||||
console.log('===')
|
||||
console.log(output)
|
||||
console.log('<<<')
|
||||
|
||||
if (this._outputSchema) {
|
||||
const outputSchema =
|
||||
this._outputSchema instanceof z.ZodType
|
||||
? this._outputSchema
|
||||
: z.object(this._outputSchema)
|
||||
|
||||
if (outputSchema instanceof z.ZodArray) {
|
||||
try {
|
||||
const trimmedOutput = extractJSONArrayFromString(output)
|
||||
output = JSON.parse(jsonrepair(trimmedOutput ?? output))
|
||||
} catch (err) {
|
||||
// TODO
|
||||
throw err
|
||||
}
|
||||
} else if (outputSchema instanceof z.ZodObject) {
|
||||
try {
|
||||
const trimmedOutput = extractJSONObjectFromString(output)
|
||||
output = JSON.parse(jsonrepair(trimmedOutput ?? output))
|
||||
} catch (err) {
|
||||
// TODO
|
||||
throw err
|
||||
}
|
||||
} else if (outputSchema instanceof z.ZodBoolean) {
|
||||
output = output.toLowerCase().trim()
|
||||
const booleanOutputs = {
|
||||
true: true,
|
||||
false: false,
|
||||
yes: true,
|
||||
no: false,
|
||||
1: true,
|
||||
0: false
|
||||
}
|
||||
|
||||
const booleanOutput = booleanOutputs[output]
|
||||
if (booleanOutput !== undefined) {
|
||||
output = booleanOutput
|
||||
} else {
|
||||
// TODO
|
||||
throw new Error(`invalid boolean output: ${output}`)
|
||||
}
|
||||
} else if (outputSchema instanceof z.ZodNumber) {
|
||||
output = output.trim()
|
||||
|
||||
const numberOutput = outputSchema.isInt
|
||||
? parseInt(output)
|
||||
: parseFloat(output)
|
||||
|
||||
if (isNaN(numberOutput)) {
|
||||
// TODO
|
||||
throw new Error(`invalid number output: ${output}`)
|
||||
} else {
|
||||
output = numberOutput
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: handle errors, retry logic, and self-healing
|
||||
|
||||
return outputSchema.parse(output)
|
||||
} else {
|
||||
return output
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
168
src/openai.ts
168
src/openai.ts
|
@ -1,25 +1,18 @@
|
|||
import { jsonrepair } from 'jsonrepair'
|
||||
import Mustache from 'mustache'
|
||||
import { dedent } from 'ts-dedent'
|
||||
import type { SetRequired } from 'type-fest'
|
||||
import { ZodRawShape, ZodTypeAny, z } from 'zod'
|
||||
import { printNode, zodToTs } from 'zod-to-ts'
|
||||
|
||||
import * as types from './types'
|
||||
import { defaultOpenAIModel } from './constants'
|
||||
import { ChatModelBuilder } from './llm'
|
||||
import {
|
||||
extractJSONArrayFromString,
|
||||
extractJSONObjectFromString
|
||||
} from './utils'
|
||||
import { BaseChatModelBuilder } from './llm'
|
||||
|
||||
export class OpenAIChatModelBuilder<
|
||||
TInput extends ZodRawShape | ZodTypeAny = ZodTypeAny,
|
||||
TOutput extends ZodRawShape | ZodTypeAny = z.ZodType<string>
|
||||
> extends ChatModelBuilder<
|
||||
> extends BaseChatModelBuilder<
|
||||
TInput,
|
||||
TOutput,
|
||||
SetRequired<Omit<types.openai.ChatCompletionParams, 'messages'>, 'model'>
|
||||
SetRequired<Omit<types.openai.ChatCompletionParams, 'messages'>, 'model'>,
|
||||
types.openai.ChatCompletionResponse
|
||||
> {
|
||||
_client: types.openai.OpenAIClient
|
||||
|
||||
|
@ -40,151 +33,20 @@ export class OpenAIChatModelBuilder<
|
|||
this._client = client
|
||||
}
|
||||
|
||||
override async call(
|
||||
input?: types.ParsedData<TInput>
|
||||
): Promise<types.ParsedData<TOutput>> {
|
||||
if (this._inputSchema) {
|
||||
const inputSchema =
|
||||
this._inputSchema instanceof z.ZodType
|
||||
? this._inputSchema
|
||||
: z.object(this._inputSchema)
|
||||
|
||||
// TODO: handle errors gracefully
|
||||
input = inputSchema.parse(input)
|
||||
}
|
||||
|
||||
// TODO: validate input message variables against input schema
|
||||
|
||||
const messages = this._messages
|
||||
.map((message) => {
|
||||
return {
|
||||
...message,
|
||||
content: message.content
|
||||
? Mustache.render(dedent(message.content), input).trim()
|
||||
: ''
|
||||
}
|
||||
})
|
||||
.filter((message) => message.content)
|
||||
|
||||
if (this._examples?.length) {
|
||||
// TODO: smarter example selection
|
||||
for (const example of this._examples) {
|
||||
messages.push({
|
||||
role: 'system',
|
||||
content: `Example input: ${example.input}\n\nExample output: ${example.output}`
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
if (this._outputSchema) {
|
||||
const outputSchema =
|
||||
this._outputSchema instanceof z.ZodType
|
||||
? this._outputSchema
|
||||
: z.object(this._outputSchema)
|
||||
|
||||
const { node } = zodToTs(outputSchema)
|
||||
|
||||
if (node.kind === 152) {
|
||||
// handle raw strings differently
|
||||
messages.push({
|
||||
role: 'system',
|
||||
content: dedent`Output a raw string only, without any additional text.`
|
||||
})
|
||||
} else {
|
||||
const tsTypeString = printNode(node, {
|
||||
removeComments: false,
|
||||
// TODO: this doesn't seem to actually work, so we're doing it manually below
|
||||
omitTrailingSemicolon: true,
|
||||
noEmitHelpers: true
|
||||
})
|
||||
.replace(/^ /gm, ' ')
|
||||
.replace(/;$/gm, '')
|
||||
|
||||
messages.push({
|
||||
role: 'system',
|
||||
content: dedent`Output JSON only in the following TypeScript format:
|
||||
\`\`\`ts
|
||||
${tsTypeString}
|
||||
\`\`\``
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: filter/compress messages based on token counts
|
||||
|
||||
console.log('>>>')
|
||||
console.log(messages)
|
||||
const completion = await this._client.createChatCompletion({
|
||||
model: defaultOpenAIModel, // TODO: this shouldn't be necessary but TS is complaining
|
||||
...this._outputSchema,
|
||||
protected override async _createChatCompletion(
|
||||
messages: types.ChatMessage[]
|
||||
): Promise<
|
||||
types.BaseChatCompletionResponse<types.openai.ChatCompletionResponse>
|
||||
> {
|
||||
const response = await this._client.createChatCompletion({
|
||||
model: this._model,
|
||||
...this._modelParams,
|
||||
messages
|
||||
})
|
||||
|
||||
if (this._outputSchema) {
|
||||
const outputSchema =
|
||||
this._outputSchema instanceof z.ZodType
|
||||
? this._outputSchema
|
||||
: z.object(this._outputSchema)
|
||||
|
||||
let output: any = completion.message.content
|
||||
console.log('===')
|
||||
console.log(output)
|
||||
console.log('<<<')
|
||||
|
||||
if (outputSchema instanceof z.ZodArray) {
|
||||
try {
|
||||
const trimmedOutput = extractJSONArrayFromString(output)
|
||||
output = JSON.parse(jsonrepair(trimmedOutput ?? output))
|
||||
} catch (err) {
|
||||
// TODO
|
||||
throw err
|
||||
}
|
||||
} else if (outputSchema instanceof z.ZodObject) {
|
||||
try {
|
||||
const trimmedOutput = extractJSONObjectFromString(output)
|
||||
output = JSON.parse(jsonrepair(trimmedOutput ?? output))
|
||||
} catch (err) {
|
||||
// TODO
|
||||
throw err
|
||||
}
|
||||
} else if (outputSchema instanceof z.ZodBoolean) {
|
||||
output = output.toLowerCase().trim()
|
||||
const booleanOutputs = {
|
||||
true: true,
|
||||
false: false,
|
||||
yes: true,
|
||||
no: false,
|
||||
1: true,
|
||||
0: false
|
||||
}
|
||||
|
||||
const booleanOutput = booleanOutputs[output]
|
||||
if (booleanOutput !== undefined) {
|
||||
output = booleanOutput
|
||||
} else {
|
||||
// TODO
|
||||
throw new Error(`invalid boolean output: ${output}`)
|
||||
}
|
||||
} else if (outputSchema instanceof z.ZodNumber) {
|
||||
output = output.trim()
|
||||
|
||||
const numberOutput = outputSchema.isInt
|
||||
? parseInt(output)
|
||||
: parseFloat(output)
|
||||
|
||||
if (isNaN(numberOutput)) {
|
||||
// TODO
|
||||
throw new Error(`invalid number output: ${output}`)
|
||||
} else {
|
||||
output = numberOutput
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: handle errors, retry logic, and self-healing
|
||||
|
||||
return outputSchema.parse(output)
|
||||
} else {
|
||||
return completion.message.content as any
|
||||
return {
|
||||
message: response.message,
|
||||
response: response.response
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
18
src/types.ts
18
src/types.ts
|
@ -61,7 +61,13 @@ export interface LLMOptions<
|
|||
promptSuffix?: string
|
||||
}
|
||||
|
||||
export type ChatMessageRole = 'user' | 'system' | 'assistant'
|
||||
// export type ChatMessageRole = 'user' | 'system' | 'assistant'
|
||||
export const ChatMessageRoleSchema = z.union([
|
||||
z.literal('user'),
|
||||
z.literal('system'),
|
||||
z.literal('assistant')
|
||||
])
|
||||
export type ChatMessageRole = z.infer<typeof ChatMessageRoleSchema>
|
||||
|
||||
export interface ChatMessage {
|
||||
role: ChatMessageRole
|
||||
|
@ -76,6 +82,16 @@ export interface ChatModelOptions<
|
|||
messages: ChatMessage[]
|
||||
}
|
||||
|
||||
export interface BaseChatCompletionResponse<
|
||||
TChatCompletionResponse extends Record<string, any> = Record<string, any>
|
||||
> {
|
||||
/** The completion message. */
|
||||
message: ChatMessage
|
||||
|
||||
/** The raw response from the LLM provider. */
|
||||
response: TChatCompletionResponse
|
||||
}
|
||||
|
||||
export interface LLMExample {
|
||||
input: string
|
||||
output: string
|
||||
|
|
Ładowanie…
Reference in New Issue