feat: separate out functions, generate minified JSON, improve Boolean handling

old-agentic-v1^2
Philipp Burckhardt 2023-06-19 15:03:43 -04:00
rodzic eeb2315879
commit e4d929c11f
Nie znaleziono w bazie danych klucza dla tego podpisu
ID klucza GPG: A2C3BCA4F31D1DDD
1 zmienionych plików z 158 dodań i 124 usunięć

Wyświetl plik

@ -21,6 +21,113 @@ import {
getNumTokensForChatMessages
} from './llm-utils'
/**
 * Lookup table from normalized (lower-cased, trimmed) textual answers to the
 * boolean value they stand for.
 */
const BOOLEAN_OUTPUTS = {
  // Spelled-out literals and their one-letter abbreviations.
  'true': true,
  'false': false,
  't': true,
  'f': false,

  // Affirmative / negative words and their one-letter abbreviations.
  'yes': true,
  'no': false,
  'y': true,
  'n': false,

  // Numeric flags.
  '1': true,
  '0': false
}
/**
 * Parses raw text into a JSON array value.
 *
 * The text is first passed to `extractJSONArrayFromString`; when that yields
 * `null`/`undefined` the full text is used instead. The chosen text is run
 * through `jsonrepair` and then `JSON.parse`. The parsed value is returned
 * as-is; this function does not itself verify that it is an array.
 *
 * @param output - Raw text to parse.
 * @returns The parsed JSON value.
 * @throws errors.OutputValidationError when `jsonrepair` raises a
 *   `JSONRepairError` or `JSON.parse` raises a `SyntaxError` (the original
 *   error is attached as `cause`). Any other error is rethrown unchanged.
 */
function parseArrayOutput(output: string): Array<any> {
  let parsed: any

  try {
    const extracted = extractJSONArrayFromString(output)
    const repaired = jsonrepair(extracted ?? output)
    parsed = JSON.parse(repaired)
  } catch (err: any) {
    if (err instanceof JSONRepairError) {
      throw new errors.OutputValidationError(err.message, { cause: err })
    }

    if (err instanceof SyntaxError) {
      throw new errors.OutputValidationError(
        `Invalid JSON array: ${err.message}`,
        { cause: err }
      )
    }

    // Not a parsing problem: let it propagate untouched.
    throw err
  }

  return parsed
}
/**
 * Parses raw text into a JSON object value.
 *
 * The text is first passed to `extractJSONObjectFromString`; when that yields
 * `null`/`undefined` the full text is used instead. The chosen text is run
 * through `jsonrepair` and then `JSON.parse`. If the parsed value turns out
 * to be an array, its first element is returned instead (which is
 * `undefined` for an empty array).
 *
 * @param output - Raw text to parse.
 * @returns The parsed value, or the first element when it is an array.
 * @throws errors.OutputValidationError when `jsonrepair` raises a
 *   `JSONRepairError` or `JSON.parse` raises a `SyntaxError` (the original
 *   error is attached as `cause`). Any other error is rethrown unchanged.
 */
function parseObjectOutput(output) {
  let parsed: any

  try {
    const extracted = extractJSONObjectFromString(output)
    const repaired = jsonrepair(extracted ?? output)
    parsed = JSON.parse(repaired)
  } catch (err: any) {
    if (err instanceof JSONRepairError) {
      throw new errors.OutputValidationError(err.message, { cause: err })
    }

    if (err instanceof SyntaxError) {
      throw new errors.OutputValidationError(
        `Invalid JSON object: ${err.message}`,
        { cause: err }
      )
    }

    // Not a parsing problem: let it propagate untouched.
    throw err
  }

  // TODO
  // An array where an object was expected: fall back to its first element.
  return Array.isArray(parsed) ? parsed[0] : parsed
}
/**
 * Converts a textual answer into a boolean using the `BOOLEAN_OUTPUTS` table.
 *
 * The text is lower-cased, trimmed, and stripped of any trailing run of
 * `.`, `!` or `?` before the lookup.
 *
 * @param output - Raw text to interpret.
 * @returns The boolean mapped to the normalized text.
 * @throws errors.OutputValidationError when the normalized text is not one of
 *   the table's own keys.
 */
function parseBooleanOutput(output): boolean {
  const normalized = output
    .toLowerCase()
    .trim()
    .replace(/[.!?]+$/, '')

  // Only the table's own keys count. A plain `BOOLEAN_OUTPUTS[normalized]`
  // read would also resolve inherited properties (e.g. "constructor" or
  // "__proto__") and hand back a non-boolean value instead of rejecting.
  if (Object.prototype.hasOwnProperty.call(BOOLEAN_OUTPUTS, normalized)) {
    return BOOLEAN_OUTPUTS[normalized]
  }

  throw new errors.OutputValidationError(
    `Invalid boolean output: ${normalized}`
  )
}
/**
 * Converts a textual answer into a number.
 *
 * The text is trimmed and then read with `parseInt` when the schema's `isInt`
 * flag is truthy, otherwise with `parseFloat`.
 *
 * @param output - Raw text to interpret.
 * @param outputSchema - Number schema; only its `isInt` flag is consulted here.
 * @returns The parsed number.
 * @throws errors.OutputValidationError when the parse result is `NaN`.
 */
function parseNumberOutput(output, outputSchema: z.ZodNumber): number {
  const text = output.trim()

  let value: number
  if (outputSchema.isInt) {
    value = parseInt(text)
  } else {
    value = parseFloat(text)
  }

  if (!isNaN(value)) {
    return value
  }

  throw new errors.OutputValidationError(`Invalid number output: ${text}`)
}
/**
 * Coerces a raw output into the shape expected by `outputSchema` and then
 * validates it with the schema's `safeParse`.
 *
 * Array, object, boolean and number schemas are pre-processed by
 * `parseArrayOutput`, `parseObjectOutput`, `parseBooleanOutput` and
 * `parseNumberOutput` respectively; for any other schema the raw output is
 * validated unchanged.
 *
 * @param output - Raw output to coerce and validate.
 * @param outputSchema - Schema that selects the coercion and performs the
 *   validation.
 * @returns The `data` of a successful `safeParse` result.
 * @throws errors.ZodOutputValidationError when `safeParse` reports failure.
 *   Errors thrown by the type-specific parsers propagate unchanged.
 */
function parseOutput(output: any, outputSchema: ZodType<any>) {
  let candidate = output

  if (outputSchema instanceof z.ZodArray) {
    candidate = parseArrayOutput(output)
  } else if (outputSchema instanceof z.ZodObject) {
    candidate = parseObjectOutput(output)
  } else if (outputSchema instanceof z.ZodBoolean) {
    candidate = parseBooleanOutput(output)
  } else if (outputSchema instanceof z.ZodNumber) {
    candidate = parseNumberOutput(output, outputSchema)
  }

  // TODO: fix typescript issue here with recursive types
  const validation = (outputSchema.safeParse as any)(candidate)
  if (validation.success) {
    return validation.data
  }

  throw new errors.ZodOutputValidationError(validation.error)
}
export abstract class BaseChatCompletion<
TInput extends types.TaskInput = void,
TOutput extends types.TaskOutput = string,
@ -141,47 +248,12 @@ export abstract class BaseChatCompletion<
}
}
if (this._outputSchema) {
// TODO: replace zod-to-ts with zod-to-json-schema?
const { node } = zodToTs(this._outputSchema)
if (node.kind === 152) {
// handle raw strings differently
messages.push({
role: 'system',
content: dedent`Output a raw string only, without any additional text.`
})
} else {
const tsTypeString = printNode(node, {
removeComments: false,
// TODO: this doesn't seem to actually work, so we're doing it manually below
omitTrailingSemicolon: true,
noEmitHelpers: true
})
.replace(/^ {4}/gm, ' ')
.replace(/;$/gm, '')
const label =
this._outputSchema instanceof z.ZodArray
? 'JSON array'
: this._outputSchema instanceof z.ZodObject
? 'JSON object'
: this._outputSchema instanceof z.ZodNumber
? 'number'
: this._outputSchema instanceof z.ZodString
? 'string'
: this._outputSchema instanceof z.ZodBoolean
? 'boolean'
: 'JSON value'
messages.push({
role: 'system',
content: dedent`Do not output code. Output a single ${label} in the following TypeScript format:
\`\`\`ts
${tsTypeString}
\`\`\``
})
}
const outputMsg = this.outputMessage()
if (outputMsg !== null) {
messages.push({
role: 'system',
content: outputMsg
})
}
if (ctx?.retryMessage) {
@ -196,6 +268,50 @@ export abstract class BaseChatCompletion<
return messages
}
/**
 * Builds the instruction text that tells the model which output format to
 * produce, derived from this instance's output schema.
 *
 * @returns `null` when no output schema is set; a "raw string only"
 *   instruction when the schema's generated type node has kind `152`;
 *   otherwise an instruction that names the expected kind of value and embeds
 *   the schema printed as a TypeScript type.
 */
public outputMessage(): string | null {
  const schema = this._outputSchema
  if (!schema) {
    return null
  }

  // TODO: replace zod-to-ts with zod-to-json-schema?
  const { node: typeNode } = zodToTs(schema)

  // NOTE(review): 152 is presumably the TypeScript SyntaxKind value for the
  // `string` keyword; confirm it matches the TypeScript version in use.
  if (typeNode.kind === 152) {
    // Handle raw strings differently:
    return dedent`Output a raw string only, without any additional text.`
  }

  const printedType = printNode(typeNode, {
    removeComments: false,
    // TODO: this doesn't seem to actually work, so we're doing it manually below
    omitTrailingSemicolon: true,
    noEmitHelpers: true
  })
  const tsTypeString = printedType
    .replace(/^ {4}/gm, ' ')
    .replace(/;$/gm, '')

  // Human-readable name for the kind of value the model should emit.
  const label = (() => {
    if (schema instanceof z.ZodArray) return 'JSON array (minified)'
    if (schema instanceof z.ZodObject) return 'JSON object (minified)'
    if (schema instanceof z.ZodNumber) return 'number'
    if (schema instanceof z.ZodString) return 'string'
    if (schema instanceof z.ZodBoolean) return 'boolean'
    return 'JSON value'
  })()

  // The template's continuation lines are kept exactly as in the original so
  // the resulting prompt text is unchanged.
  return dedent`Do not output code. Output a single ${label} in the following TypeScript format:
\`\`\`ts
${tsTypeString}
\`\`\``
}
protected override async _call(
ctx: types.TaskCallContext<TInput, types.LLMTaskResponseMetadata>
): Promise<TOutput> {
@ -356,89 +472,7 @@ export abstract class BaseChatCompletion<
// console.log('<<<')
if (this._outputSchema) {
const outputSchema = this._outputSchema
if (outputSchema instanceof z.ZodArray) {
try {
const trimmedOutput = extractJSONArrayFromString(output)
output = JSON.parse(jsonrepair(trimmedOutput ?? output))
} catch (err: any) {
if (err instanceof JSONRepairError) {
throw new errors.OutputValidationError(err.message, { cause: err })
} else if (err instanceof SyntaxError) {
throw new errors.OutputValidationError(
`Invalid JSON array: ${err.message}`,
{ cause: err }
)
} else {
throw err
}
}
} else if (outputSchema instanceof z.ZodObject) {
try {
const trimmedOutput = extractJSONObjectFromString(output)
output = JSON.parse(jsonrepair(trimmedOutput ?? output))
if (Array.isArray(output)) {
// TODO
output = output[0]
}
} catch (err: any) {
if (err instanceof JSONRepairError) {
throw new errors.OutputValidationError(err.message, { cause: err })
} else if (err instanceof SyntaxError) {
throw new errors.OutputValidationError(
`Invalid JSON object: ${err.message}`,
{ cause: err }
)
} else {
throw err
}
}
} else if (outputSchema instanceof z.ZodBoolean) {
output = output.toLowerCase().trim()
const booleanOutputs = {
true: true,
false: false,
yes: true,
no: false,
1: true,
0: false
}
const booleanOutput = booleanOutputs[output]
if (booleanOutput !== undefined) {
output = booleanOutput
} else {
throw new errors.OutputValidationError(
`Invalid boolean output: ${output}`
)
}
} else if (outputSchema instanceof z.ZodNumber) {
output = output.trim()
const numberOutput = outputSchema.isInt
? parseInt(output)
: parseFloat(output)
if (isNaN(numberOutput)) {
throw new errors.OutputValidationError(
`Invalid number output: ${output}`
)
} else {
output = numberOutput
}
}
// TODO: fix typescript issue here with recursive types
const safeResult = (outputSchema.safeParse as any)(output)
if (!safeResult.success) {
throw new errors.ZodOutputValidationError(safeResult.error)
}
return safeResult.data
return parseOutput(output, this._outputSchema)
} else {
return output
}