kopia lustrzana https://github.com/transitive-bullshit/chatgpt-api
feat: add Replicate stable diffusion tool
rodzic
cd4a2a9236
commit
ef7e350ddc
|
@ -0,0 +1,37 @@
|
|||
import 'dotenv/config'
|
||||
import { OpenAIClient } from 'openai-fetch'
|
||||
import { z } from 'zod'
|
||||
|
||||
import { Agentic, ReplicateStableDiffusionTool } from '@/index'
|
||||
|
||||
async function main() {
|
||||
const openai = new OpenAIClient({
|
||||
apiKey: process.env.OPENAI_API_KEY!,
|
||||
fetchOptions: {
|
||||
timeout: false
|
||||
}
|
||||
})
|
||||
const agentic = new Agentic({ openai })
|
||||
|
||||
const topic = process.argv[2] || 'san francisco'
|
||||
|
||||
const res = await agentic
|
||||
.gpt4(
|
||||
`Generate {{numImages}} images of {{topic}}. Use prompts that are artistic and creative.`
|
||||
)
|
||||
.modelParams({ temperature: 1.0 })
|
||||
.tools([new ReplicateStableDiffusionTool()])
|
||||
.input(
|
||||
z.object({
|
||||
numImages: z.number().default(5),
|
||||
topic: z.string()
|
||||
})
|
||||
)
|
||||
.call({
|
||||
topic
|
||||
})
|
||||
|
||||
console.log(`\n\n\n${res}\n\n\n`)
|
||||
}
|
||||
|
||||
main()
|
|
@ -63,6 +63,7 @@
|
|||
"pino": "^8.14.1",
|
||||
"pino-pretty": "^10.0.0",
|
||||
"quick-lru": "^6.1.1",
|
||||
"replicate": "^0.12.3",
|
||||
"ts-dedent": "^2.2.0",
|
||||
"uuid": "^9.0.0",
|
||||
"zod": "^3.21.4",
|
||||
|
|
|
@ -77,6 +77,9 @@ dependencies:
|
|||
quick-lru:
|
||||
specifier: ^6.1.1
|
||||
version: 6.1.1
|
||||
replicate:
|
||||
specifier: ^0.12.3
|
||||
version: 0.12.3
|
||||
ts-dedent:
|
||||
specifier: ^2.2.0
|
||||
version: 2.2.0
|
||||
|
@ -3884,6 +3887,11 @@ packages:
|
|||
functions-have-names: 1.2.3
|
||||
dev: true
|
||||
|
||||
/replicate@0.12.3:
|
||||
resolution: {integrity: sha512-HVWKPoVhWVTONlWk+lUXmq9Vy2J8MxBJMtDBQq3dA5uq71ZzKTh0xvJfvzW4+VLBjhBeL7tkdua6hZJmKfzAPQ==}
|
||||
engines: {git: '>=2.11.0', node: '>=16.6.0', npm: '>=7.19.0', yarn: '>=1.7.0'}
|
||||
dev: false
|
||||
|
||||
/require-directory@2.1.1:
|
||||
resolution: {integrity: sha512-fGxEI7+wsG9xrvdjsrlmL22OMTTiHRwAMroiEeMgq8gzoLC/PQr7RsRDSTLUg/bZAZtF+TVIkHc6/4RIKrui+Q==}
|
||||
engines: {node: '>=0.10.0'}
|
||||
|
|
|
@ -3,6 +3,7 @@ export * from './diffbot'
|
|||
export * from './metaphor'
|
||||
export * from './midjourney'
|
||||
export * from './novu'
|
||||
export * from './replicate'
|
||||
export * from './search-and-crawl'
|
||||
export * from './serpapi'
|
||||
export * from './weather'
|
||||
|
|
|
@ -0,0 +1,113 @@
|
|||
import pMap from 'p-map'
|
||||
import Replicate from 'replicate'
|
||||
import { z } from 'zod'
|
||||
|
||||
import * as types from '@/types'
|
||||
import { getEnv } from '@/env'
|
||||
import { BaseTask } from '@/task'
|
||||
|
||||
const REPLICATE_SD_MODEL =
|
||||
'stability-ai/stable-diffusion:db21e45d3f7023abc2a46ee38a23973f6dce16bb082a930b0c49861f96d1e5bf'
|
||||
|
||||
const ReplicateInputSchema = z.object({
|
||||
images: z.array(
|
||||
z.object({
|
||||
prompt: z
|
||||
.string()
|
||||
.describe(
|
||||
'Simple, short, comma-separated list of phrases which describes the image you want to generate'
|
||||
),
|
||||
negativePrompt: z
|
||||
.string()
|
||||
.optional()
|
||||
.describe(
|
||||
'Simple, short, comma-separated list of phrases which describes qualities of the image you do NOT want to generate. Example: low quality, blurry, pixelated'
|
||||
)
|
||||
// seed: z.number().int().optional()
|
||||
})
|
||||
)
|
||||
})
|
||||
type ReplicateInput = z.infer<typeof ReplicateInputSchema>
|
||||
|
||||
const ReplicateOutputSchema = z.object({
|
||||
images: z.array(z.string())
|
||||
})
|
||||
type ReplicateOutput = z.infer<typeof ReplicateOutputSchema>
|
||||
|
||||
export class ReplicateStableDiffusionTool extends BaseTask<
|
||||
ReplicateInput,
|
||||
ReplicateOutput
|
||||
> {
|
||||
protected _replicateClient: Replicate
|
||||
|
||||
constructor(
|
||||
opts: {
|
||||
replicate?: Replicate
|
||||
} & types.BaseTaskOptions = {}
|
||||
) {
|
||||
super(opts)
|
||||
|
||||
this._replicateClient =
|
||||
opts.replicate ||
|
||||
new Replicate({
|
||||
auth: getEnv('REPLICATE_API_KEY')!
|
||||
})
|
||||
}
|
||||
|
||||
public override get inputSchema() {
|
||||
return ReplicateInputSchema
|
||||
}
|
||||
|
||||
public override get outputSchema() {
|
||||
return ReplicateOutputSchema
|
||||
}
|
||||
|
||||
public override get nameForModel(): string {
|
||||
return 'replicateStableDiffusion'
|
||||
}
|
||||
|
||||
public override get descForModel(): string {
|
||||
return 'Creates one or more images from a prompt using the Replicate stable diffusion API. Useful for generating images on the fly.'
|
||||
}
|
||||
|
||||
protected override async _call(
|
||||
ctx: types.TaskCallContext<ReplicateInput>
|
||||
): Promise<ReplicateOutput> {
|
||||
const images = (
|
||||
await pMap(
|
||||
ctx.input!.images,
|
||||
async (image) => {
|
||||
try {
|
||||
const input = {
|
||||
prompt: image.prompt
|
||||
}
|
||||
|
||||
if (image.negativePrompt) {
|
||||
input['negative_prompt'] = image.negativePrompt
|
||||
}
|
||||
|
||||
console.log('>>> replicate', image)
|
||||
const output: any = await this._replicateClient.run(
|
||||
REPLICATE_SD_MODEL,
|
||||
{
|
||||
input
|
||||
}
|
||||
)
|
||||
|
||||
console.log('<<< replicate', image, output)
|
||||
return output
|
||||
} catch (err) {
|
||||
this._logger.error(err, 'Replicate API error')
|
||||
return []
|
||||
}
|
||||
},
|
||||
{
|
||||
concurrency: 5
|
||||
}
|
||||
)
|
||||
).flat()
|
||||
|
||||
// console.log('replicate output', images)
|
||||
return this.outputSchema.parse({ images })
|
||||
}
|
||||
}
|
Ładowanie…
Reference in New Issue