feat: add Replicate stable diffusion tool

Travis Fischer 2023-06-23 17:37:49 -07:00
rodzic cd4a2a9236
commit ef7e350ddc
5 zmienionych plików z 160 dodań i 0 usunięć

37
examples/replicate.ts vendored 100644
Wyświetl plik

@ -0,0 +1,37 @@
import 'dotenv/config'
import { OpenAIClient } from 'openai-fetch'
import { z } from 'zod'
import { Agentic, ReplicateStableDiffusionTool } from '@/index'
async function main() {
const openai = new OpenAIClient({
apiKey: process.env.OPENAI_API_KEY!,
fetchOptions: {
timeout: false
}
})
const agentic = new Agentic({ openai })
const topic = process.argv[2] || 'san francisco'
const res = await agentic
.gpt4(
`Generate {{numImages}} images of {{topic}}. Use prompts that are artistic and creative.`
)
.modelParams({ temperature: 1.0 })
.tools([new ReplicateStableDiffusionTool()])
.input(
z.object({
numImages: z.number().default(5),
topic: z.string()
})
)
.call({
topic
})
console.log(`\n\n\n${res}\n\n\n`)
}
main()

Wyświetl plik

@ -63,6 +63,7 @@
"pino": "^8.14.1",
"pino-pretty": "^10.0.0",
"quick-lru": "^6.1.1",
"replicate": "^0.12.3",
"ts-dedent": "^2.2.0",
"uuid": "^9.0.0",
"zod": "^3.21.4",

Wyświetl plik

@ -77,6 +77,9 @@ dependencies:
quick-lru:
specifier: ^6.1.1
version: 6.1.1
replicate:
specifier: ^0.12.3
version: 0.12.3
ts-dedent:
specifier: ^2.2.0
version: 2.2.0
@ -3884,6 +3887,11 @@ packages:
functions-have-names: 1.2.3
dev: true
/replicate@0.12.3:
resolution: {integrity: sha512-HVWKPoVhWVTONlWk+lUXmq9Vy2J8MxBJMtDBQq3dA5uq71ZzKTh0xvJfvzW4+VLBjhBeL7tkdua6hZJmKfzAPQ==}
engines: {git: '>=2.11.0', node: '>=16.6.0', npm: '>=7.19.0', yarn: '>=1.7.0'}
dev: false
/require-directory@2.1.1:
resolution: {integrity: sha512-fGxEI7+wsG9xrvdjsrlmL22OMTTiHRwAMroiEeMgq8gzoLC/PQr7RsRDSTLUg/bZAZtF+TVIkHc6/4RIKrui+Q==}
engines: {node: '>=0.10.0'}

Wyświetl plik

@ -3,6 +3,7 @@ export * from './diffbot'
export * from './metaphor'
export * from './midjourney'
export * from './novu'
export * from './replicate'
export * from './search-and-crawl'
export * from './serpapi'
export * from './weather'

Wyświetl plik

@ -0,0 +1,113 @@
import pMap from 'p-map'
import Replicate from 'replicate'
import { z } from 'zod'
import * as types from '@/types'
import { getEnv } from '@/env'
import { BaseTask } from '@/task'
const REPLICATE_SD_MODEL =
'stability-ai/stable-diffusion:db21e45d3f7023abc2a46ee38a23973f6dce16bb082a930b0c49861f96d1e5bf'
const ReplicateInputSchema = z.object({
images: z.array(
z.object({
prompt: z
.string()
.describe(
'Simple, short, comma-separated list of phrases which describes the image you want to generate'
),
negativePrompt: z
.string()
.optional()
.describe(
'Simple, short, comma-separated list of phrases which describes qualities of the image you do NOT want to generate. Example: low quality, blurry, pixelated'
)
// seed: z.number().int().optional()
})
)
})
type ReplicateInput = z.infer<typeof ReplicateInputSchema>
const ReplicateOutputSchema = z.object({
images: z.array(z.string())
})
type ReplicateOutput = z.infer<typeof ReplicateOutputSchema>
export class ReplicateStableDiffusionTool extends BaseTask<
ReplicateInput,
ReplicateOutput
> {
protected _replicateClient: Replicate
constructor(
opts: {
replicate?: Replicate
} & types.BaseTaskOptions = {}
) {
super(opts)
this._replicateClient =
opts.replicate ||
new Replicate({
auth: getEnv('REPLICATE_API_KEY')!
})
}
public override get inputSchema() {
return ReplicateInputSchema
}
public override get outputSchema() {
return ReplicateOutputSchema
}
public override get nameForModel(): string {
return 'replicateStableDiffusion'
}
public override get descForModel(): string {
return 'Creates one or more images from a prompt using the Replicate stable diffusion API. Useful for generating images on the fly.'
}
protected override async _call(
ctx: types.TaskCallContext<ReplicateInput>
): Promise<ReplicateOutput> {
const images = (
await pMap(
ctx.input!.images,
async (image) => {
try {
const input = {
prompt: image.prompt
}
if (image.negativePrompt) {
input['negative_prompt'] = image.negativePrompt
}
console.log('>>> replicate', image)
const output: any = await this._replicateClient.run(
REPLICATE_SD_MODEL,
{
input
}
)
console.log('<<< replicate', image, output)
return output
} catch (err) {
this._logger.error(err, 'Replicate API error')
return []
}
},
{
concurrency: 5
}
)
).flat()
// console.log('replicate output', images)
return this.outputSchema.parse({ images })
}
}