kopia lustrzana https://github.com/transitive-bullshit/chatgpt-api
feat: separate out functions, generate minified JSON, improve Boolean handling
rodzic
eeb2315879
commit
e4d929c11f
282
src/llms/chat.ts
282
src/llms/chat.ts
|
@ -21,6 +21,113 @@ import {
|
|||
getNumTokensForChatMessages
|
||||
} from './llm-utils'
|
||||
|
||||
const BOOLEAN_OUTPUTS = {
|
||||
true: true,
|
||||
false: false,
|
||||
t: true,
|
||||
f: false,
|
||||
yes: true,
|
||||
no: false,
|
||||
y: true,
|
||||
n: false,
|
||||
'1': true,
|
||||
'0': false
|
||||
}
|
||||
|
||||
function parseArrayOutput(output: string): Array<any> {
|
||||
try {
|
||||
const trimmedOutput = extractJSONArrayFromString(output)
|
||||
const parsedOutput = JSON.parse(jsonrepair(trimmedOutput ?? output))
|
||||
return parsedOutput
|
||||
} catch (err: any) {
|
||||
if (err instanceof JSONRepairError) {
|
||||
throw new errors.OutputValidationError(err.message, { cause: err })
|
||||
} else if (err instanceof SyntaxError) {
|
||||
throw new errors.OutputValidationError(
|
||||
`Invalid JSON array: ${err.message}`,
|
||||
{ cause: err }
|
||||
)
|
||||
} else {
|
||||
throw err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
function parseObjectOutput(output) {
|
||||
try {
|
||||
const trimmedOutput = extractJSONObjectFromString(output)
|
||||
output = JSON.parse(jsonrepair(trimmedOutput ?? output))
|
||||
|
||||
if (Array.isArray(output)) {
|
||||
// TODO
|
||||
output = output[0]
|
||||
}
|
||||
|
||||
return output
|
||||
} catch (err: any) {
|
||||
if (err instanceof JSONRepairError) {
|
||||
throw new errors.OutputValidationError(err.message, { cause: err })
|
||||
} else if (err instanceof SyntaxError) {
|
||||
throw new errors.OutputValidationError(
|
||||
`Invalid JSON object: ${err.message}`,
|
||||
{ cause: err }
|
||||
)
|
||||
} else {
|
||||
throw err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
function parseBooleanOutput(output): boolean {
|
||||
output = output
|
||||
.toLowerCase()
|
||||
.trim()
|
||||
.replace(/[.!?]+$/, '')
|
||||
|
||||
const booleanOutput = BOOLEAN_OUTPUTS[output]
|
||||
|
||||
if (booleanOutput !== undefined) {
|
||||
return booleanOutput
|
||||
} else {
|
||||
throw new errors.OutputValidationError(`Invalid boolean output: ${output}`)
|
||||
}
|
||||
}
|
||||
|
||||
function parseNumberOutput(output, outputSchema: z.ZodNumber): number {
|
||||
output = output.trim()
|
||||
|
||||
const numberOutput = outputSchema.isInt
|
||||
? parseInt(output)
|
||||
: parseFloat(output)
|
||||
|
||||
if (isNaN(numberOutput)) {
|
||||
throw new errors.OutputValidationError(`Invalid number output: ${output}`)
|
||||
}
|
||||
|
||||
return numberOutput
|
||||
}
|
||||
|
||||
function parseOutput(output: any, outputSchema: ZodType<any>) {
|
||||
if (outputSchema instanceof z.ZodArray) {
|
||||
output = parseArrayOutput(output)
|
||||
} else if (outputSchema instanceof z.ZodObject) {
|
||||
output = parseObjectOutput(output)
|
||||
} else if (outputSchema instanceof z.ZodBoolean) {
|
||||
output = parseBooleanOutput(output)
|
||||
} else if (outputSchema instanceof z.ZodNumber) {
|
||||
output = parseNumberOutput(output, outputSchema)
|
||||
}
|
||||
|
||||
// TODO: fix typescript issue here with recursive types
|
||||
const safeResult = (outputSchema.safeParse as any)(output)
|
||||
|
||||
if (!safeResult.success) {
|
||||
throw new errors.ZodOutputValidationError(safeResult.error)
|
||||
}
|
||||
|
||||
return safeResult.data
|
||||
}
|
||||
|
||||
export abstract class BaseChatCompletion<
|
||||
TInput extends types.TaskInput = void,
|
||||
TOutput extends types.TaskOutput = string,
|
||||
|
@ -141,47 +248,12 @@ export abstract class BaseChatCompletion<
|
|||
}
|
||||
}
|
||||
|
||||
if (this._outputSchema) {
|
||||
// TODO: replace zod-to-ts with zod-to-json-schema?
|
||||
const { node } = zodToTs(this._outputSchema)
|
||||
|
||||
if (node.kind === 152) {
|
||||
// handle raw strings differently
|
||||
messages.push({
|
||||
role: 'system',
|
||||
content: dedent`Output a raw string only, without any additional text.`
|
||||
})
|
||||
} else {
|
||||
const tsTypeString = printNode(node, {
|
||||
removeComments: false,
|
||||
// TODO: this doesn't seem to actually work, so we're doing it manually below
|
||||
omitTrailingSemicolon: true,
|
||||
noEmitHelpers: true
|
||||
})
|
||||
.replace(/^ {4}/gm, ' ')
|
||||
.replace(/;$/gm, '')
|
||||
|
||||
const label =
|
||||
this._outputSchema instanceof z.ZodArray
|
||||
? 'JSON array'
|
||||
: this._outputSchema instanceof z.ZodObject
|
||||
? 'JSON object'
|
||||
: this._outputSchema instanceof z.ZodNumber
|
||||
? 'number'
|
||||
: this._outputSchema instanceof z.ZodString
|
||||
? 'string'
|
||||
: this._outputSchema instanceof z.ZodBoolean
|
||||
? 'boolean'
|
||||
: 'JSON value'
|
||||
|
||||
messages.push({
|
||||
role: 'system',
|
||||
content: dedent`Do not output code. Output a single ${label} in the following TypeScript format:
|
||||
\`\`\`ts
|
||||
${tsTypeString}
|
||||
\`\`\``
|
||||
})
|
||||
}
|
||||
const outputMsg = this.outputMessage()
|
||||
if (outputMsg !== null) {
|
||||
messages.push({
|
||||
role: 'system',
|
||||
content: outputMsg
|
||||
})
|
||||
}
|
||||
|
||||
if (ctx?.retryMessage) {
|
||||
|
@ -196,6 +268,50 @@ export abstract class BaseChatCompletion<
|
|||
return messages
|
||||
}
|
||||
|
||||
public outputMessage(): string | null {
|
||||
const outputSchema = this._outputSchema
|
||||
|
||||
if (!outputSchema) {
|
||||
return null
|
||||
}
|
||||
|
||||
// TODO: replace zod-to-ts with zod-to-json-schema?
|
||||
const { node } = zodToTs(outputSchema)
|
||||
|
||||
if (node.kind === 152) {
|
||||
// Handle raw strings differently:
|
||||
return dedent`Output a raw string only, without any additional text.`
|
||||
}
|
||||
|
||||
const tsTypeString = printNode(node, {
|
||||
removeComments: false,
|
||||
// TODO: this doesn't seem to actually work, so we're doing it manually below
|
||||
omitTrailingSemicolon: true,
|
||||
noEmitHelpers: true
|
||||
})
|
||||
.replace(/^ {4}/gm, ' ')
|
||||
.replace(/;$/gm, '')
|
||||
let label: string
|
||||
if (outputSchema instanceof z.ZodArray) {
|
||||
label = 'JSON array (minified)'
|
||||
} else if (outputSchema instanceof z.ZodObject) {
|
||||
label = 'JSON object (minified)'
|
||||
} else if (outputSchema instanceof z.ZodNumber) {
|
||||
label = 'number'
|
||||
} else if (outputSchema instanceof z.ZodString) {
|
||||
label = 'string'
|
||||
} else if (outputSchema instanceof z.ZodBoolean) {
|
||||
label = 'boolean'
|
||||
} else {
|
||||
label = 'JSON value'
|
||||
}
|
||||
|
||||
return dedent`Do not output code. Output a single ${label} in the following TypeScript format:
|
||||
\`\`\`ts
|
||||
${tsTypeString}
|
||||
\`\`\``
|
||||
}
|
||||
|
||||
protected override async _call(
|
||||
ctx: types.TaskCallContext<TInput, types.LLMTaskResponseMetadata>
|
||||
): Promise<TOutput> {
|
||||
|
@ -356,89 +472,7 @@ export abstract class BaseChatCompletion<
|
|||
// console.log('<<<')
|
||||
|
||||
if (this._outputSchema) {
|
||||
const outputSchema = this._outputSchema
|
||||
|
||||
if (outputSchema instanceof z.ZodArray) {
|
||||
try {
|
||||
const trimmedOutput = extractJSONArrayFromString(output)
|
||||
output = JSON.parse(jsonrepair(trimmedOutput ?? output))
|
||||
} catch (err: any) {
|
||||
if (err instanceof JSONRepairError) {
|
||||
throw new errors.OutputValidationError(err.message, { cause: err })
|
||||
} else if (err instanceof SyntaxError) {
|
||||
throw new errors.OutputValidationError(
|
||||
`Invalid JSON array: ${err.message}`,
|
||||
{ cause: err }
|
||||
)
|
||||
} else {
|
||||
throw err
|
||||
}
|
||||
}
|
||||
} else if (outputSchema instanceof z.ZodObject) {
|
||||
try {
|
||||
const trimmedOutput = extractJSONObjectFromString(output)
|
||||
output = JSON.parse(jsonrepair(trimmedOutput ?? output))
|
||||
|
||||
if (Array.isArray(output)) {
|
||||
// TODO
|
||||
output = output[0]
|
||||
}
|
||||
} catch (err: any) {
|
||||
if (err instanceof JSONRepairError) {
|
||||
throw new errors.OutputValidationError(err.message, { cause: err })
|
||||
} else if (err instanceof SyntaxError) {
|
||||
throw new errors.OutputValidationError(
|
||||
`Invalid JSON object: ${err.message}`,
|
||||
{ cause: err }
|
||||
)
|
||||
} else {
|
||||
throw err
|
||||
}
|
||||
}
|
||||
} else if (outputSchema instanceof z.ZodBoolean) {
|
||||
output = output.toLowerCase().trim()
|
||||
const booleanOutputs = {
|
||||
true: true,
|
||||
false: false,
|
||||
yes: true,
|
||||
no: false,
|
||||
1: true,
|
||||
0: false
|
||||
}
|
||||
|
||||
const booleanOutput = booleanOutputs[output]
|
||||
|
||||
if (booleanOutput !== undefined) {
|
||||
output = booleanOutput
|
||||
} else {
|
||||
throw new errors.OutputValidationError(
|
||||
`Invalid boolean output: ${output}`
|
||||
)
|
||||
}
|
||||
} else if (outputSchema instanceof z.ZodNumber) {
|
||||
output = output.trim()
|
||||
|
||||
const numberOutput = outputSchema.isInt
|
||||
? parseInt(output)
|
||||
: parseFloat(output)
|
||||
|
||||
if (isNaN(numberOutput)) {
|
||||
throw new errors.OutputValidationError(
|
||||
`Invalid number output: ${output}`
|
||||
)
|
||||
} else {
|
||||
output = numberOutput
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: fix typescript issue here with recursive types
|
||||
const safeResult = (outputSchema.safeParse as any)(output)
|
||||
|
||||
if (!safeResult.success) {
|
||||
throw new errors.ZodOutputValidationError(safeResult.error)
|
||||
}
|
||||
|
||||
return safeResult.data
|
||||
return parseOutput(output, this._outputSchema)
|
||||
} else {
|
||||
return output
|
||||
}
|
||||
|
|
Ładowanie…
Reference in New Issue