From b051e07d0b783920bc25f81f473b9005b4b43bdf Mon Sep 17 00:00:00 2001 From: Travis Fischer Date: Fri, 23 Jun 2023 16:20:06 -0700 Subject: [PATCH] feat: add midjourney tool --- legacy/examples/midjourney.ts | 46 +++++++++++++ legacy/package.json | 1 + legacy/pnpm-lock.yaml | 3 + legacy/src/tools/index.ts | 1 + legacy/src/tools/metaphor.ts | 9 +-- legacy/src/tools/midjourney.ts | 115 +++++++++++++++++++++++++++++++++ 6 files changed, 167 insertions(+), 8 deletions(-) create mode 100644 legacy/examples/midjourney.ts create mode 100644 legacy/src/tools/midjourney.ts diff --git a/legacy/examples/midjourney.ts b/legacy/examples/midjourney.ts new file mode 100644 index 00000000..966d2fdd --- /dev/null +++ b/legacy/examples/midjourney.ts @@ -0,0 +1,46 @@ +import { Midjourney } from '@agentic/midjourney-fetch' +import 'dotenv/config' +import { OpenAIClient } from 'openai-fetch' +import { z } from 'zod' + +import { Agentic, MidjourneyImagineTool } from '@/index' + +async function main() { + const openai = new OpenAIClient({ apiKey: process.env.OPENAI_API_KEY! }) + const agentic = new Agentic({ openai }) + console.log({ + channelId: process.env.MIDJOURNEY_CHANNEL_ID!, + serverId: process.env.MIDJOURNEY_SERVER_ID!, + token: process.env.MIDJOURNEY_TOKEN! + // applicationId: process.env.MIDJOURNEY_APPLICATION_ID!, + // version: process.env.MIDJOURNEY_VERSION!, + // id: process.env.MIDJOURNEY_ID! + }) + const midjourney = new Midjourney({ + channelId: process.env.MIDJOURNEY_CHANNEL_ID!, + serverId: process.env.MIDJOURNEY_SERVER_ID!, + token: process.env.MIDJOURNEY_TOKEN! + // applicationId: process.env.MIDJOURNEY_APPLICATION_ID!, + // version: process.env.MIDJOURNEY_VERSION!, + // id: process.env.MIDJOURNEY_ID! + }) + + const topic = process.argv[2] || 'san francisco' + + const res = await agentic + .gpt3(`Generate 2 creative images of {{topic}}`) + .modelParams({ temperature: 1.0 }) + .tools([new MidjourneyImagineTool({ midjourney })]) + .input( + z.object({ + topic: z.string() + }) + ) + .call({ + topic + }) + + console.log(`\n\n\n${res}\n\n\n`) +} + +main() diff --git a/legacy/package.json b/legacy/package.json index 88cbb81c..01e9d53c 100644 --- a/legacy/package.json +++ b/legacy/package.json @@ -39,6 +39,7 @@ "test-cov": "c8 ava" }, "dependencies": { + "@agentic/midjourney-fetch": "^1.0.1", "@anthropic-ai/sdk": "^0.4.4", "@inquirer/checkbox": "^1.3.2", "@inquirer/editor": "^1.2.1", diff --git a/legacy/pnpm-lock.yaml b/legacy/pnpm-lock.yaml index 650c04d6..208d9779 100644 --- a/legacy/pnpm-lock.yaml +++ b/legacy/pnpm-lock.yaml @@ -5,6 +5,9 @@ settings: excludeLinksFromLockfile: false dependencies: + '@agentic/midjourney-fetch': + specifier: ^1.0.1 + version: link:../../temp/midjourney-fetch '@anthropic-ai/sdk': specifier: ^0.4.4 version: 0.4.4 diff --git a/legacy/src/tools/index.ts b/legacy/src/tools/index.ts index 75318a36..004efcbd 100644 --- a/legacy/src/tools/index.ts +++ b/legacy/src/tools/index.ts @@ -1,6 +1,7 @@ export * from './calculator' export * from './diffbot' export * from './metaphor' +export * from './midjourney' export * from './novu' export * from './search-and-crawl' export * from './serpapi' diff --git a/legacy/src/tools/metaphor.ts b/legacy/src/tools/metaphor.ts index a963bf1a..8c85c3aa 100644 --- a/legacy/src/tools/metaphor.ts +++ b/legacy/src/tools/metaphor.ts @@ -6,14 +6,7 @@ import { BaseTask } from '@/task' export const MetaphorInputSchema = z.object({ query: z.string(), - numResults: z.number().optional(), - useQueryExpansion: z.boolean().optional(), - includeDomains: z.array(z.string()).optional(), - excludeDomains: z.array(z.string()).optional(), - startCrawlDate: z.string().optional(), - endCrawlDate: z.string().optional(), - startPublishedDate: z.string().optional(), - endPublishedDate: z.string().optional() + numResults: z.number().optional() }) export const MetaphorOutputSchema = z.object({ diff --git a/legacy/src/tools/midjourney.ts b/legacy/src/tools/midjourney.ts new file mode 100644 index 00000000..e40c0b5f --- /dev/null +++ b/legacy/src/tools/midjourney.ts @@ -0,0 +1,115 @@ +import { type Midjourney } from '@agentic/midjourney-fetch' +import pMap from 'p-map' +import { z } from 'zod' + +import * as types from '@/types' +import { BaseTask } from '@/task' + +export const MidjourneyInputSchema = z.object({ + prompt: z + .string() + .describe( + 'Simple, short, comma-separated list of phrases which describe the image you want to generate' + ), + numImages: z.number().default(1).optional() +}) +export type MidjourneyInput = z.infer + +export const MidjourneyImageSchema = z.object({ + id: z.string(), + url: z.string(), + components: z.array( + z + .object({ + type: z.number(), + components: z.array( + z.object({ + type: z.number(), + custom_id: z.string(), + style: z.number(), + label: z.string() + }) + ) + }) + .deepPartial() + ) +}) +export const MidjourneyOutputSchema = z.object({ + images: z.array(MidjourneyImageSchema) +}) +export type MidjourneyOutput = z.infer + +export class MidjourneyImagineTool extends BaseTask< + MidjourneyInput, + MidjourneyOutput +> { + protected _midjourneyClient: Midjourney + + constructor( + opts: { + midjourney: Midjourney + } & types.BaseTaskOptions + ) { + super(opts) + + this._midjourneyClient = opts.midjourney + } + + public override get inputSchema() { + return MidjourneyInputSchema + } + + public override get outputSchema() { + return MidjourneyOutputSchema + } + + public override get nameForModel(): string { + return 'midjourneyImagine' + } + + public override get descForModel(): string { + return 'Creates one or more images from a prompt using the Midjourney API. Useful for generating images on the fly.' + } + + protected override async _call( + ctx: types.TaskCallContext + ): Promise { + const numImages = ctx.input!.numImages || 1 + + const images = ( + await pMap( + new Array(numImages).fill(0), + async () => { + try { + const message = await this._midjourneyClient.imagine( + ctx.input!.prompt + ) + + if (!message) { + throw new Error('Midjourney API failed to return a message') + } + + const attachment = message.attachments?.[0] + if (!attachment) { + throw new Error('Midjourney API returned invalid message') + } + + return { + id: message.id, + url: attachment.url, + components: message.components + } + } catch (err) { + this._logger.error(err, 'Midjourney API error') + return null + } + }, + { + concurrency: 1 + } + ) + ).filter(Boolean) + + return this.outputSchema.parse({ images }) + } +}