kopia lustrzana https://github.com/transitive-bullshit/chatgpt-api
feat: add midjourney tool
rodzic
fc033e669c
commit
cd4a2a9236
|
@ -0,0 +1,46 @@
|
||||||
|
import { Midjourney } from '@agentic/midjourney-fetch'
|
||||||
|
import 'dotenv/config'
|
||||||
|
import { OpenAIClient } from 'openai-fetch'
|
||||||
|
import { z } from 'zod'
|
||||||
|
|
||||||
|
import { Agentic, MidjourneyImagineTool } from '@/index'
|
||||||
|
|
||||||
|
async function main() {
|
||||||
|
const openai = new OpenAIClient({ apiKey: process.env.OPENAI_API_KEY! })
|
||||||
|
const agentic = new Agentic({ openai })
|
||||||
|
console.log({
|
||||||
|
channelId: process.env.MIDJOURNEY_CHANNEL_ID!,
|
||||||
|
serverId: process.env.MIDJOURNEY_SERVER_ID!,
|
||||||
|
token: process.env.MIDJOURNEY_TOKEN!
|
||||||
|
// applicationId: process.env.MIDJOURNEY_APPLICATION_ID!,
|
||||||
|
// version: process.env.MIDJOURNEY_VERSION!,
|
||||||
|
// id: process.env.MIDJOURNEY_ID!
|
||||||
|
})
|
||||||
|
const midjourney = new Midjourney({
|
||||||
|
channelId: process.env.MIDJOURNEY_CHANNEL_ID!,
|
||||||
|
serverId: process.env.MIDJOURNEY_SERVER_ID!,
|
||||||
|
token: process.env.MIDJOURNEY_TOKEN!
|
||||||
|
// applicationId: process.env.MIDJOURNEY_APPLICATION_ID!,
|
||||||
|
// version: process.env.MIDJOURNEY_VERSION!,
|
||||||
|
// id: process.env.MIDJOURNEY_ID!
|
||||||
|
})
|
||||||
|
|
||||||
|
const topic = process.argv[2] || 'san francisco'
|
||||||
|
|
||||||
|
const res = await agentic
|
||||||
|
.gpt3(`Generate 2 creative images of {{topic}}`)
|
||||||
|
.modelParams({ temperature: 1.0 })
|
||||||
|
.tools([new MidjourneyImagineTool({ midjourney })])
|
||||||
|
.input(
|
||||||
|
z.object({
|
||||||
|
topic: z.string()
|
||||||
|
})
|
||||||
|
)
|
||||||
|
.call({
|
||||||
|
topic
|
||||||
|
})
|
||||||
|
|
||||||
|
console.log(`\n\n\n${res}\n\n\n`)
|
||||||
|
}
|
||||||
|
|
||||||
|
main()
|
|
@ -39,6 +39,7 @@
|
||||||
"test-cov": "c8 ava"
|
"test-cov": "c8 ava"
|
||||||
},
|
},
|
||||||
"dependencies": {
|
"dependencies": {
|
||||||
|
"@agentic/midjourney-fetch": "^1.0.1",
|
||||||
"@anthropic-ai/sdk": "^0.4.4",
|
"@anthropic-ai/sdk": "^0.4.4",
|
||||||
"@inquirer/checkbox": "^1.3.2",
|
"@inquirer/checkbox": "^1.3.2",
|
||||||
"@inquirer/editor": "^1.2.1",
|
"@inquirer/editor": "^1.2.1",
|
||||||
|
|
|
@ -5,6 +5,9 @@ settings:
|
||||||
excludeLinksFromLockfile: false
|
excludeLinksFromLockfile: false
|
||||||
|
|
||||||
dependencies:
|
dependencies:
|
||||||
|
'@agentic/midjourney-fetch':
|
||||||
|
specifier: ^1.0.1
|
||||||
|
version: link:../../temp/midjourney-fetch
|
||||||
'@anthropic-ai/sdk':
|
'@anthropic-ai/sdk':
|
||||||
specifier: ^0.4.4
|
specifier: ^0.4.4
|
||||||
version: 0.4.4
|
version: 0.4.4
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
export * from './calculator'
|
export * from './calculator'
|
||||||
export * from './diffbot'
|
export * from './diffbot'
|
||||||
export * from './metaphor'
|
export * from './metaphor'
|
||||||
|
export * from './midjourney'
|
||||||
export * from './novu'
|
export * from './novu'
|
||||||
export * from './search-and-crawl'
|
export * from './search-and-crawl'
|
||||||
export * from './serpapi'
|
export * from './serpapi'
|
||||||
|
|
|
@ -6,14 +6,7 @@ import { BaseTask } from '@/task'
|
||||||
|
|
||||||
export const MetaphorInputSchema = z.object({
|
export const MetaphorInputSchema = z.object({
|
||||||
query: z.string(),
|
query: z.string(),
|
||||||
numResults: z.number().optional(),
|
numResults: z.number().optional()
|
||||||
useQueryExpansion: z.boolean().optional(),
|
|
||||||
includeDomains: z.array(z.string()).optional(),
|
|
||||||
excludeDomains: z.array(z.string()).optional(),
|
|
||||||
startCrawlDate: z.string().optional(),
|
|
||||||
endCrawlDate: z.string().optional(),
|
|
||||||
startPublishedDate: z.string().optional(),
|
|
||||||
endPublishedDate: z.string().optional()
|
|
||||||
})
|
})
|
||||||
|
|
||||||
export const MetaphorOutputSchema = z.object({
|
export const MetaphorOutputSchema = z.object({
|
||||||
|
|
|
@ -0,0 +1,115 @@
|
||||||
|
import { type Midjourney } from '@agentic/midjourney-fetch'
|
||||||
|
import pMap from 'p-map'
|
||||||
|
import { z } from 'zod'
|
||||||
|
|
||||||
|
import * as types from '@/types'
|
||||||
|
import { BaseTask } from '@/task'
|
||||||
|
|
||||||
|
export const MidjourneyInputSchema = z.object({
|
||||||
|
prompt: z
|
||||||
|
.string()
|
||||||
|
.describe(
|
||||||
|
'Simple, short, comma-separated list of phrases which describe the image you want to generate'
|
||||||
|
),
|
||||||
|
numImages: z.number().default(1).optional()
|
||||||
|
})
|
||||||
|
export type MidjourneyInput = z.infer<typeof MidjourneyInputSchema>
|
||||||
|
|
||||||
|
export const MidjourneyImageSchema = z.object({
|
||||||
|
id: z.string(),
|
||||||
|
url: z.string(),
|
||||||
|
components: z.array(
|
||||||
|
z
|
||||||
|
.object({
|
||||||
|
type: z.number(),
|
||||||
|
components: z.array(
|
||||||
|
z.object({
|
||||||
|
type: z.number(),
|
||||||
|
custom_id: z.string(),
|
||||||
|
style: z.number(),
|
||||||
|
label: z.string()
|
||||||
|
})
|
||||||
|
)
|
||||||
|
})
|
||||||
|
.deepPartial()
|
||||||
|
)
|
||||||
|
})
|
||||||
|
export const MidjourneyOutputSchema = z.object({
|
||||||
|
images: z.array(MidjourneyImageSchema)
|
||||||
|
})
|
||||||
|
export type MidjourneyOutput = z.infer<typeof MidjourneyOutputSchema>
|
||||||
|
|
||||||
|
export class MidjourneyImagineTool extends BaseTask<
|
||||||
|
MidjourneyInput,
|
||||||
|
MidjourneyOutput
|
||||||
|
> {
|
||||||
|
protected _midjourneyClient: Midjourney
|
||||||
|
|
||||||
|
constructor(
|
||||||
|
opts: {
|
||||||
|
midjourney: Midjourney
|
||||||
|
} & types.BaseTaskOptions
|
||||||
|
) {
|
||||||
|
super(opts)
|
||||||
|
|
||||||
|
this._midjourneyClient = opts.midjourney
|
||||||
|
}
|
||||||
|
|
||||||
|
public override get inputSchema() {
|
||||||
|
return MidjourneyInputSchema
|
||||||
|
}
|
||||||
|
|
||||||
|
public override get outputSchema() {
|
||||||
|
return MidjourneyOutputSchema
|
||||||
|
}
|
||||||
|
|
||||||
|
public override get nameForModel(): string {
|
||||||
|
return 'midjourneyImagine'
|
||||||
|
}
|
||||||
|
|
||||||
|
public override get descForModel(): string {
|
||||||
|
return 'Creates one or more images from a prompt using the Midjourney API. Useful for generating images on the fly.'
|
||||||
|
}
|
||||||
|
|
||||||
|
protected override async _call(
|
||||||
|
ctx: types.TaskCallContext<MidjourneyInput>
|
||||||
|
): Promise<MidjourneyOutput> {
|
||||||
|
const numImages = ctx.input!.numImages || 1
|
||||||
|
|
||||||
|
const images = (
|
||||||
|
await pMap(
|
||||||
|
new Array(numImages).fill(0),
|
||||||
|
async () => {
|
||||||
|
try {
|
||||||
|
const message = await this._midjourneyClient.imagine(
|
||||||
|
ctx.input!.prompt
|
||||||
|
)
|
||||||
|
|
||||||
|
if (!message) {
|
||||||
|
throw new Error('Midjourney API failed to return a message')
|
||||||
|
}
|
||||||
|
|
||||||
|
const attachment = message.attachments?.[0]
|
||||||
|
if (!attachment) {
|
||||||
|
throw new Error('Midjourney API returned invalid message')
|
||||||
|
}
|
||||||
|
|
||||||
|
return {
|
||||||
|
id: message.id,
|
||||||
|
url: attachment.url,
|
||||||
|
components: message.components
|
||||||
|
}
|
||||||
|
} catch (err) {
|
||||||
|
this._logger.error(err, 'Midjourney API error')
|
||||||
|
return null
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
concurrency: 1
|
||||||
|
}
|
||||||
|
)
|
||||||
|
).filter(Boolean)
|
||||||
|
|
||||||
|
return this.outputSchema.parse({ images })
|
||||||
|
}
|
||||||
|
}
|
Ładowanie…
Reference in New Issue