From 2c967478fba95e35049c9dd7865ed711c1987c96 Mon Sep 17 00:00:00 2001 From: Travis Fischer Date: Fri, 23 Jun 2023 17:37:49 -0700 Subject: [PATCH] feat: add Replicate stable diffusion tool --- legacy/examples/replicate.ts | 37 +++++++++++ legacy/package.json | 1 + legacy/pnpm-lock.yaml | 8 +++ legacy/src/tools/index.ts | 1 + legacy/src/tools/replicate.ts | 113 ++++++++++++++++++++++++++++++++++ 5 files changed, 160 insertions(+) create mode 100644 legacy/examples/replicate.ts create mode 100644 legacy/src/tools/replicate.ts diff --git a/legacy/examples/replicate.ts b/legacy/examples/replicate.ts new file mode 100644 index 00000000..15f4d1b9 --- /dev/null +++ b/legacy/examples/replicate.ts @@ -0,0 +1,37 @@ +import 'dotenv/config' +import { OpenAIClient } from 'openai-fetch' +import { z } from 'zod' + +import { Agentic, ReplicateStableDiffusionTool } from '@/index' + +async function main() { + const openai = new OpenAIClient({ + apiKey: process.env.OPENAI_API_KEY!, + fetchOptions: { + timeout: false + } + }) + const agentic = new Agentic({ openai }) + + const topic = process.argv[2] || 'san francisco' + + const res = await agentic + .gpt4( + `Generate {{numImages}} images of {{topic}}. Use prompts that are artistic and creative.` + ) + .modelParams({ temperature: 1.0 }) + .tools([new ReplicateStableDiffusionTool()]) + .input( + z.object({ + numImages: z.number().default(5), + topic: z.string() + }) + ) + .call({ + topic + }) + + console.log(`\n\n\n${res}\n\n\n`) +} + +main() diff --git a/legacy/package.json b/legacy/package.json index 01e9d53c..d2c901a7 100644 --- a/legacy/package.json +++ b/legacy/package.json @@ -63,6 +63,7 @@ "pino": "^8.14.1", "pino-pretty": "^10.0.0", "quick-lru": "^6.1.1", + "replicate": "^0.12.3", "ts-dedent": "^2.2.0", "uuid": "^9.0.0", "zod": "^3.21.4", diff --git a/legacy/pnpm-lock.yaml b/legacy/pnpm-lock.yaml index 208d9779..605a5ce0 100644 --- a/legacy/pnpm-lock.yaml +++ b/legacy/pnpm-lock.yaml @@ -77,6 +77,9 @@ dependencies: quick-lru: specifier: ^6.1.1 version: 6.1.1 + replicate: + specifier: ^0.12.3 + version: 0.12.3 ts-dedent: specifier: ^2.2.0 version: 2.2.0 @@ -3884,6 +3887,11 @@ packages: functions-have-names: 1.2.3 dev: true + /replicate@0.12.3: + resolution: {integrity: sha512-HVWKPoVhWVTONlWk+lUXmq9Vy2J8MxBJMtDBQq3dA5uq71ZzKTh0xvJfvzW4+VLBjhBeL7tkdua6hZJmKfzAPQ==} + engines: {git: '>=2.11.0', node: '>=16.6.0', npm: '>=7.19.0', yarn: '>=1.7.0'} + dev: false + /require-directory@2.1.1: resolution: {integrity: sha512-fGxEI7+wsG9xrvdjsrlmL22OMTTiHRwAMroiEeMgq8gzoLC/PQr7RsRDSTLUg/bZAZtF+TVIkHc6/4RIKrui+Q==} engines: {node: '>=0.10.0'} diff --git a/legacy/src/tools/index.ts b/legacy/src/tools/index.ts index 004efcbd..968fd210 100644 --- a/legacy/src/tools/index.ts +++ b/legacy/src/tools/index.ts @@ -3,6 +3,7 @@ export * from './diffbot' export * from './metaphor' export * from './midjourney' export * from './novu' +export * from './replicate' export * from './search-and-crawl' export * from './serpapi' export * from './weather' diff --git a/legacy/src/tools/replicate.ts b/legacy/src/tools/replicate.ts new file mode 100644 index 00000000..de8966cc --- /dev/null +++ b/legacy/src/tools/replicate.ts @@ -0,0 +1,113 @@ +import pMap from 'p-map' +import Replicate from 'replicate' +import { z } from 'zod' + +import * as types from '@/types' +import { getEnv } from '@/env' +import { BaseTask } from '@/task' + +const REPLICATE_SD_MODEL = + 'stability-ai/stable-diffusion:db21e45d3f7023abc2a46ee38a23973f6dce16bb082a930b0c49861f96d1e5bf' + +const ReplicateInputSchema = z.object({ + images: z.array( + z.object({ + prompt: z + .string() + .describe( + 'Simple, short, comma-separated list of phrases which describes the image you want to generate' + ), + negativePrompt: z + .string() + .optional() + .describe( + 'Simple, short, comma-separated list of phrases which describes qualities of the image you do NOT want to generate. Example: low quality, blurry, pixelated' + ) + // seed: z.number().int().optional() + }) + ) +}) +type ReplicateInput = z.infer + +const ReplicateOutputSchema = z.object({ + images: z.array(z.string()) +}) +type ReplicateOutput = z.infer + +export class ReplicateStableDiffusionTool extends BaseTask< + ReplicateInput, + ReplicateOutput +> { + protected _replicateClient: Replicate + + constructor( + opts: { + replicate?: Replicate + } & types.BaseTaskOptions = {} + ) { + super(opts) + + this._replicateClient = + opts.replicate || + new Replicate({ + auth: getEnv('REPLICATE_API_KEY')! + }) + } + + public override get inputSchema() { + return ReplicateInputSchema + } + + public override get outputSchema() { + return ReplicateOutputSchema + } + + public override get nameForModel(): string { + return 'replicateStableDiffusion' + } + + public override get descForModel(): string { + return 'Creates one or more images from a prompt using the Replicate stable diffusion API. Useful for generating images on the fly.' + } + + protected override async _call( + ctx: types.TaskCallContext + ): Promise { + const images = ( + await pMap( + ctx.input!.images, + async (image) => { + try { + const input = { + prompt: image.prompt + } + + if (image.negativePrompt) { + input['negative_prompt'] = image.negativePrompt + } + + console.log('>>> replicate', image) + const output: any = await this._replicateClient.run( + REPLICATE_SD_MODEL, + { + input + } + ) + + console.log('<<< replicate', image, output) + return output + } catch (err) { + this._logger.error(err, 'Replicate API error') + return [] + } + }, + { + concurrency: 5 + } + ) + ).flat() + + // console.log('replicate output', images) + return this.outputSchema.parse({ images }) + } +}