kopia lustrzana https://github.com/transitive-bullshit/chatgpt-api
feat: add DurableMcpClient to persist origin mcp client connections across API gateway requests
rodzic
b58958c0b0
commit
2f22119876
|
@ -6,11 +6,9 @@ import {
|
|||
responseTime,
|
||||
sentry
|
||||
} from '@agentic/platform-hono'
|
||||
import { Client as McpClient } from '@modelcontextprotocol/sdk/client/index.js'
|
||||
import { StreamableHTTPClientTransport } from '@modelcontextprotocol/sdk/client/streamableHttp.js'
|
||||
import { Hono } from 'hono'
|
||||
|
||||
import type { GatewayHonoEnv } from './lib/types'
|
||||
import type { GatewayHonoEnv, McpToolCallResponse } from './lib/types'
|
||||
import { createAgenticClient } from './lib/agentic-client'
|
||||
import { createHttpResponseFromMcpToolCallResponse } from './lib/create-http-response-from-mcp-tool-call-response'
|
||||
import { fetchCache } from './lib/fetch-cache'
|
||||
|
@ -31,10 +29,10 @@ app.use(sentry())
|
|||
app.use(
|
||||
cors({
|
||||
origin: '*',
|
||||
allowHeaders: ['Content-Type', 'Authorization'],
|
||||
allowMethods: ['POST', 'GET', 'OPTIONS'],
|
||||
exposeHeaders: ['Content-Length'],
|
||||
maxAge: 600,
|
||||
allowHeaders: ['Content-Type', 'Authorization', 'mcp-session-id'],
|
||||
allowMethods: ['GET', 'POST', 'DELETE', 'OPTIONS'],
|
||||
exposeHeaders: ['Content-Length', 'mcp-session-id'],
|
||||
maxAge: 86_400,
|
||||
credentials: true
|
||||
})
|
||||
)
|
||||
|
@ -82,25 +80,22 @@ app.all(async (ctx) => {
|
|||
500,
|
||||
'Tool args are required for MCP origin requests'
|
||||
)
|
||||
|
||||
const transport = new StreamableHTTPClientTransport(
|
||||
new URL(resolvedOriginRequest.deployment.originUrl)
|
||||
assert(
|
||||
resolvedOriginRequest.mcpClient,
|
||||
500,
|
||||
'MCP client is required for MCP origin requests'
|
||||
)
|
||||
const client = new McpClient({
|
||||
name: resolvedOriginRequest.deployment.originAdapter.serverInfo.name,
|
||||
version:
|
||||
resolvedOriginRequest.deployment.originAdapter.serverInfo.version
|
||||
})
|
||||
|
||||
// TODO: re-use client connection across requests
|
||||
await client.connect(transport)
|
||||
|
||||
// TODO: add timeout support to the origin tool call?
|
||||
// TODO: add response caching for MCP tool calls
|
||||
const toolCallResponse = await client.callTool({
|
||||
name: resolvedOriginRequest.tool.name,
|
||||
arguments: resolvedOriginRequest.toolCallArgs
|
||||
})
|
||||
const toolCallResponseString =
|
||||
await resolvedOriginRequest.mcpClient.callTool({
|
||||
name: resolvedOriginRequest.tool.name,
|
||||
args: resolvedOriginRequest.toolCallArgs
|
||||
})
|
||||
const toolCallResponse = JSON.parse(
|
||||
toolCallResponseString
|
||||
) as McpToolCallResponse
|
||||
|
||||
originResponse = await createHttpResponseFromMcpToolCallResponse(ctx, {
|
||||
tool: resolvedOriginRequest.tool,
|
||||
|
|
|
@ -0,0 +1,83 @@
|
|||
import { assert } from '@agentic/platform-core'
|
||||
import { Client as McpClient } from '@modelcontextprotocol/sdk/client/index.js'
|
||||
import { StreamableHTTPClientTransport } from '@modelcontextprotocol/sdk/client/streamableHttp.js'
|
||||
import { DurableObject } from 'cloudflare:workers'
|
||||
|
||||
import type { RawEnv } from './env'
|
||||
|
||||
export type DurableMcpClientInfo = {
|
||||
url: string
|
||||
name: string
|
||||
version: string
|
||||
}
|
||||
|
||||
// TODO: not sure if there's a better way to handle re-using client connections
|
||||
// across requests.
|
||||
|
||||
export class DurableMcpClient extends DurableObject<RawEnv> {
|
||||
protected client?: McpClient
|
||||
protected clientConnectionP?: Promise<void>
|
||||
|
||||
async init(mcpClientInfo: DurableMcpClientInfo) {
|
||||
const durableMcpClientInfo =
|
||||
await this.ctx.storage.get<DurableMcpClientInfo>('mcp-client-info')
|
||||
|
||||
if (!durableMcpClientInfo) {
|
||||
await this.ctx.storage.put('mcp-client-info', mcpClientInfo)
|
||||
} else {
|
||||
assert(
|
||||
mcpClientInfo.url === durableMcpClientInfo.url,
|
||||
500,
|
||||
`DurableMcpClientInfo url mismatch: "${mcpClientInfo.url}" vs "${durableMcpClientInfo.url}"`
|
||||
)
|
||||
}
|
||||
|
||||
return this.ensureClientConnection(mcpClientInfo)
|
||||
}
|
||||
|
||||
async isInitialized(): Promise<boolean> {
|
||||
return !!(await this.ctx.storage.get('mcp-client-info'))
|
||||
}
|
||||
|
||||
async ensureClientConnection(durableMcpClientInfo?: DurableMcpClientInfo) {
|
||||
if (this.clientConnectionP) return this.clientConnectionP
|
||||
|
||||
durableMcpClientInfo ??=
|
||||
await this.ctx.storage.get<DurableMcpClientInfo>('mcp-client-info')
|
||||
assert(
|
||||
durableMcpClientInfo,
|
||||
500,
|
||||
'DurableMcpClient has not been initialized'
|
||||
)
|
||||
const { name, version, url } = durableMcpClientInfo
|
||||
|
||||
this.client = new McpClient({
|
||||
name,
|
||||
version
|
||||
})
|
||||
|
||||
const transport = new StreamableHTTPClientTransport(new URL(url))
|
||||
this.clientConnectionP = this.client.connect(transport)
|
||||
await this.clientConnectionP
|
||||
}
|
||||
|
||||
async callTool({
|
||||
name,
|
||||
args
|
||||
}: {
|
||||
name: string
|
||||
args: Record<string, unknown>
|
||||
}): Promise<string> {
|
||||
await this.ensureClientConnection()
|
||||
|
||||
const toolCallResponse = await this.client!.callTool({
|
||||
name,
|
||||
arguments: args
|
||||
})
|
||||
|
||||
// TODO: The `McpToolCallResponse` type is too complex for the CF
|
||||
// serialization type inference to handle, so bypass it by serializing to
|
||||
// a string and parse on the other end.
|
||||
return JSON.stringify(toolCallResponse)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,64 @@
|
|||
import type { AdminDeployment, PricingPlan } from '@agentic/platform-types'
|
||||
import type { JSONRPCRequest } from '@modelcontextprotocol/sdk/types.js'
|
||||
import { assert } from '@agentic/platform-core'
|
||||
// import { parseDeploymentIdentifier } from '@agentic/platform-validators'
|
||||
// import { McpServer } from '@modelcontextprotocol/sdk/server/mcp.js'
|
||||
import { DurableObject } from 'cloudflare:workers'
|
||||
|
||||
import type { RawEnv } from './env'
|
||||
import type { AdminConsumer } from './types'
|
||||
|
||||
export class DurableMcpServer extends DurableObject<RawEnv> {
|
||||
// TODO: store this in storage?
|
||||
protected _initData?: {
|
||||
deployment: AdminDeployment
|
||||
consumer?: AdminConsumer
|
||||
pricingPlan?: PricingPlan
|
||||
}
|
||||
|
||||
async init({
|
||||
deployment,
|
||||
consumer,
|
||||
pricingPlan
|
||||
}: {
|
||||
deployment: AdminDeployment
|
||||
consumer?: AdminConsumer
|
||||
pricingPlan?: PricingPlan
|
||||
}) {
|
||||
// const parsedDeploymentIdentifier = parseDeploymentIdentifier(
|
||||
// deployment.identifier
|
||||
// )
|
||||
// assert(
|
||||
// parsedDeploymentIdentifier,
|
||||
// 500,
|
||||
// `Invalid deployment identifier "${deployment.identifier}"`
|
||||
// )
|
||||
// const { projectIdentifier } = parsedDeploymentIdentifier
|
||||
|
||||
// const server = new McpServer({
|
||||
// name: projectIdentifier,
|
||||
// version: deployment.version ?? '0.0.0'
|
||||
// })
|
||||
// const transport = new StreamableHTTPServerTransport({})
|
||||
// server.addTransport(transport)
|
||||
|
||||
this._initData = {
|
||||
deployment,
|
||||
consumer,
|
||||
pricingPlan
|
||||
}
|
||||
}
|
||||
|
||||
async isInitialized() {
|
||||
return this._initData
|
||||
}
|
||||
|
||||
async sayHello(name: string): Promise<string> {
|
||||
assert(this._initData, 500, 'Server not initialized')
|
||||
return `Hello, ${name}!`
|
||||
}
|
||||
|
||||
async onRequest(request: JSONRPCRequest) {
|
||||
const { method, params } = request
|
||||
}
|
||||
}
|
|
@ -1,9 +1,9 @@
|
|||
import { DurableObject } from 'cloudflare:workers'
|
||||
|
||||
import type { RawEnv } from './lib/env'
|
||||
import type { RawEnv } from './env'
|
||||
|
||||
/** A Durable Object's behavior is defined in an exported Javascript class */
|
||||
export class DurableObjectRateLimiter extends DurableObject<RawEnv> {
|
||||
export class DurableRateLimiter extends DurableObject<RawEnv> {
|
||||
/**
|
||||
* The constructor is invoked once upon creation of the Durable Object, i.e. the first call to
|
||||
* `DurableObjectStub::get` for a given identifier (no-op constructors can be omitted)
|
|
@ -11,15 +11,11 @@ export async function enforceRateLimit(
|
|||
{
|
||||
id,
|
||||
interval,
|
||||
maxPerInterval,
|
||||
method,
|
||||
pathname
|
||||
maxPerInterval
|
||||
}: {
|
||||
id?: string
|
||||
interval: number
|
||||
maxPerInterval: number
|
||||
method: string
|
||||
pathname: string
|
||||
}
|
||||
) {
|
||||
assert(id, 400, 'Unauthenticated requests must have a valid IP address')
|
||||
|
@ -29,6 +25,4 @@ export async function enforceRateLimit(
|
|||
assert(id, 500, 'not implemented')
|
||||
assert(interval > 0, 500, 'not implemented')
|
||||
assert(maxPerInterval >= 0, 500, 'not implemented')
|
||||
assert(method, 500, 'not implemented')
|
||||
assert(pathname, 500, 'not implemented')
|
||||
}
|
||||
|
|
|
@ -6,19 +6,44 @@ import {
|
|||
} from '@agentic/platform-hono'
|
||||
import { z } from 'zod'
|
||||
|
||||
import type { DurableMcpClient } from './durable-mcp-client'
|
||||
import type { DurableMcpServer } from './durable-mcp-server'
|
||||
import type { DurableRateLimiter } from './durable-rate-limiter'
|
||||
|
||||
export const envSchema = baseEnvSchema
|
||||
.extend({
|
||||
AGENTIC_API_BASE_URL: z.string().url(),
|
||||
AGENTIC_API_KEY: z.string().nonempty(),
|
||||
|
||||
DO_RATE_LIMITER: z.custom<DurableObjectNamespace>(
|
||||
(ns) => ns && typeof ns === 'object'
|
||||
DO_RATE_LIMITER: z.custom<DurableObjectNamespace<DurableRateLimiter>>(
|
||||
(ns) => isDurableObjectNamespace(ns)
|
||||
),
|
||||
|
||||
DO_MCP_SERVER: z.custom<DurableObjectNamespace<DurableMcpServer>>((ns) =>
|
||||
isDurableObjectNamespace(ns)
|
||||
),
|
||||
|
||||
DO_MCP_CLIENT: z.custom<DurableObjectNamespace<DurableMcpClient>>((ns) =>
|
||||
isDurableObjectNamespace(ns)
|
||||
)
|
||||
})
|
||||
.strip()
|
||||
export type RawEnv = z.infer<typeof envSchema>
|
||||
export type Env = Simplify<ReturnType<typeof parseEnv>>
|
||||
|
||||
export function isDurableObjectNamespace(
|
||||
namespace: unknown
|
||||
): namespace is DurableObjectNamespace {
|
||||
return (
|
||||
typeof namespace === 'object' &&
|
||||
namespace !== null &&
|
||||
'newUniqueId' in namespace &&
|
||||
typeof namespace.newUniqueId === 'function' &&
|
||||
'idFromName' in namespace &&
|
||||
typeof namespace.idFromName === 'function'
|
||||
)
|
||||
}
|
||||
|
||||
export function parseEnv(inputEnv: Record<string, unknown>) {
|
||||
const baseEnv = parseBaseEnv({
|
||||
SERVICE: 'gateway',
|
||||
|
|
|
@ -0,0 +1,73 @@
|
|||
import type { AdminDeployment, PricingPlan } from '@agentic/platform-types'
|
||||
import { assert } from '@agentic/platform-core'
|
||||
import { parseDeploymentIdentifier } from '@agentic/platform-validators'
|
||||
|
||||
import type { AdminConsumer, GatewayHonoContext } from './types'
|
||||
import { getAdminConsumer } from './get-admin-consumer'
|
||||
import { getAdminDeployment } from './get-admin-deployment'
|
||||
|
||||
export async function resolveMcpEdgeRequest(ctx: GatewayHonoContext): Promise<{
|
||||
deployment: AdminDeployment
|
||||
consumer?: AdminConsumer
|
||||
pricingPlan?: PricingPlan
|
||||
}> {
|
||||
const requestUrl = new URL(ctx.req.url)
|
||||
const { pathname } = requestUrl
|
||||
const requestedDeploymentIdentifier = pathname
|
||||
.replace(/^\//, '')
|
||||
.replace(/\/$/, '')
|
||||
const parsedDeploymentIdentifier = parseDeploymentIdentifier(
|
||||
requestedDeploymentIdentifier
|
||||
)
|
||||
assert(
|
||||
parsedDeploymentIdentifier,
|
||||
404,
|
||||
`Invalid deployment identifier "${requestedDeploymentIdentifier}"`
|
||||
)
|
||||
|
||||
const deployment = await getAdminDeployment(
|
||||
ctx,
|
||||
parsedDeploymentIdentifier.deploymentIdentifier
|
||||
)
|
||||
|
||||
const apiKey = ctx.req.query('apiKey')?.trim()
|
||||
let consumer: AdminConsumer | undefined
|
||||
let pricingPlan: PricingPlan | undefined
|
||||
|
||||
if (apiKey) {
|
||||
consumer = await getAdminConsumer(ctx, apiKey)
|
||||
assert(consumer, 401, `Invalid api key "${apiKey}"`)
|
||||
assert(
|
||||
consumer.isStripeSubscriptionActive,
|
||||
402,
|
||||
`API key "${apiKey}" subscription is not active`
|
||||
)
|
||||
assert(
|
||||
consumer.projectId === deployment.projectId,
|
||||
403,
|
||||
`API key "${apiKey}" is not authorized for project "${deployment.projectId}"`
|
||||
)
|
||||
|
||||
// TODO: Ensure that consumer.plan is compatible with the target deployment?
|
||||
// TODO: This could definitely cause issues when changing pricing plans.
|
||||
|
||||
pricingPlan = deployment.pricingPlans.find(
|
||||
(pricingPlan) => consumer!.plan === pricingPlan.slug
|
||||
)
|
||||
|
||||
// assert(
|
||||
// pricingPlan,
|
||||
// 403,
|
||||
// `API key "${apiKey}" unable to find matching pricing plan for project "${deployment.project}"`
|
||||
// )
|
||||
} else {
|
||||
// For unauthenticated requests, default to a free pricing plan if available.
|
||||
pricingPlan = deployment.pricingPlans.find((plan) => plan.slug === 'free')
|
||||
}
|
||||
|
||||
return {
|
||||
deployment,
|
||||
consumer,
|
||||
pricingPlan
|
||||
}
|
||||
}
|
|
@ -2,6 +2,7 @@ import type { PricingPlan, RateLimit } from '@agentic/platform-types'
|
|||
import { assert } from '@agentic/platform-core'
|
||||
import { parseToolIdentifier } from '@agentic/platform-validators'
|
||||
|
||||
import type { DurableMcpClient } from './durable-mcp-client'
|
||||
import type {
|
||||
AdminConsumer,
|
||||
GatewayHonoContext,
|
||||
|
@ -27,15 +28,12 @@ export async function resolveOriginRequest(
|
|||
ctx: GatewayHonoContext
|
||||
): Promise<ResolvedOriginRequest> {
|
||||
const logger = ctx.get('logger')
|
||||
// cf-connecting-ip should always be present, but if not we can fallback to XFF.
|
||||
const ip =
|
||||
ctx.req.header('cf-connecting-ip') ||
|
||||
ctx.req.header('x-forwarded-for') ||
|
||||
undefined
|
||||
const ip = ctx.get('ip')
|
||||
|
||||
const { method } = ctx.req
|
||||
const requestUrl = new URL(ctx.req.url)
|
||||
const { pathname } = requestUrl
|
||||
const requestedToolIdentifier = pathname.replace(/^\//, '')
|
||||
const requestedToolIdentifier = pathname.replace(/^\//, '').replace(/\/$/, '')
|
||||
const parsedToolIdentifier = parseToolIdentifier(requestedToolIdentifier)
|
||||
assert(
|
||||
parsedToolIdentifier,
|
||||
|
@ -65,7 +63,7 @@ export async function resolveOriginRequest(
|
|||
|
||||
let pricingPlan: PricingPlan | undefined
|
||||
let consumer: AdminConsumer | undefined
|
||||
let reportUsage = true
|
||||
let reportUsage = ctx.get('reportUsage') ?? true
|
||||
|
||||
const token = (ctx.req.header('authorization') || '')
|
||||
.replace(/^Bearer /i, '')
|
||||
|
@ -97,9 +95,18 @@ export async function resolveOriginRequest(
|
|||
// 403,
|
||||
// `Auth token "${token}" unable to find matching pricing plan for project "${deployment.project}"`
|
||||
// )
|
||||
|
||||
if (!ctx.get('sessionId')) {
|
||||
ctx.set('sessionId', `${consumer.id}:${deployment.id}`)
|
||||
}
|
||||
} else {
|
||||
// For unauthenticated requests, default to a free pricing plan if available.
|
||||
pricingPlan = deployment.pricingPlans.find((plan) => plan.slug === 'free')
|
||||
|
||||
if (!ctx.get('sessionId')) {
|
||||
assert(ip, 500, 'IP address is required for unauthenticated requests')
|
||||
ctx.set('sessionId', `${ip}:${deployment.projectId}`)
|
||||
}
|
||||
}
|
||||
|
||||
let rateLimit: RateLimit | undefined | null
|
||||
|
@ -165,13 +172,13 @@ export async function resolveOriginRequest(
|
|||
}
|
||||
}
|
||||
|
||||
ctx.set('reportUsage', reportUsage)
|
||||
|
||||
if (rateLimit) {
|
||||
await enforceRateLimit(ctx, {
|
||||
id: consumer?.id ?? ip,
|
||||
interval: rateLimit.interval,
|
||||
maxPerInterval: rateLimit.maxPerInterval,
|
||||
method,
|
||||
pathname
|
||||
maxPerInterval: rateLimit.maxPerInterval
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -191,6 +198,7 @@ export async function resolveOriginRequest(
|
|||
})
|
||||
}
|
||||
|
||||
let mcpClient: DurableObjectStub<DurableMcpClient> | undefined
|
||||
if (originAdapter.type === 'openapi') {
|
||||
const operation = originAdapter.toolToOperationMap[tool.name]
|
||||
assert(operation, 404, `Tool "${tool.name}" not found in OpenAPI spec`)
|
||||
|
@ -201,6 +209,18 @@ export async function resolveOriginRequest(
|
|||
operation,
|
||||
deployment
|
||||
})
|
||||
} else if (originAdapter.type === 'mcp') {
|
||||
const sessionId = ctx.get('sessionId')
|
||||
assert(sessionId, 500, 'Session ID is required for MCP origin requests')
|
||||
|
||||
const id: DurableObjectId = ctx.env.DO_MCP_CLIENT.idFromName(sessionId)
|
||||
mcpClient = ctx.env.DO_MCP_CLIENT.get(id)
|
||||
|
||||
await mcpClient.init({
|
||||
url: deployment.originUrl,
|
||||
name: originAdapter.serverInfo.name,
|
||||
version: originAdapter.serverInfo.version
|
||||
})
|
||||
}
|
||||
|
||||
if (originRequest) {
|
||||
|
@ -209,14 +229,12 @@ export async function resolveOriginRequest(
|
|||
}
|
||||
|
||||
return {
|
||||
originRequest,
|
||||
toolCallArgs,
|
||||
deployment,
|
||||
consumer,
|
||||
tool,
|
||||
ip,
|
||||
method,
|
||||
pricingPlanSlug: pricingPlan?.slug,
|
||||
reportUsage
|
||||
pricingPlan,
|
||||
toolCallArgs,
|
||||
originRequest,
|
||||
mcpClient
|
||||
}
|
||||
}
|
||||
|
|
|
@ -6,6 +6,7 @@ import type {
|
|||
import type {
|
||||
AdminDeployment,
|
||||
Consumer,
|
||||
PricingPlan,
|
||||
Tool,
|
||||
User
|
||||
} from '@agentic/platform-types'
|
||||
|
@ -13,6 +14,7 @@ import type { Client as McpClient } from '@modelcontextprotocol/sdk/client/index
|
|||
import type { Context } from 'hono'
|
||||
import type { Simplify } from 'type-fest'
|
||||
|
||||
import type { DurableMcpClient } from './durable-mcp-client'
|
||||
import type { Env } from './env'
|
||||
|
||||
export type McpToolCallResponse = Simplify<
|
||||
|
@ -29,6 +31,8 @@ export type GatewayHonoVariables = Simplify<
|
|||
DefaultHonoVariables & {
|
||||
client: AgenticApiClient
|
||||
cache: Cache
|
||||
sessionId?: string
|
||||
reportUsage?: boolean
|
||||
}
|
||||
>
|
||||
|
||||
|
@ -47,13 +51,11 @@ export type ToolCallArgs = Record<string, any>
|
|||
export type ResolvedOriginRequest = {
|
||||
deployment: AdminDeployment
|
||||
tool: Tool
|
||||
method: string
|
||||
reportUsage: boolean
|
||||
|
||||
consumer?: AdminConsumer
|
||||
ip?: string
|
||||
pricingPlanSlug?: string
|
||||
pricingPlan?: PricingPlan
|
||||
|
||||
originRequest?: Request
|
||||
toolCallArgs?: ToolCallArgs
|
||||
originRequest?: Request
|
||||
mcpClient?: DurableObjectStub<DurableMcpClient>
|
||||
}
|
||||
|
|
|
@ -1,7 +1,18 @@
|
|||
// import { McpServer } from '@modelcontextprotocol/sdk/server/mcp.js'
|
||||
// import { StreamableHTTPClientTransport } from '@modelcontextprotocol/sdk/client/streamableHttp.js'
|
||||
import { assert, JsonRpcError } from '@agentic/platform-core'
|
||||
import {
|
||||
InitializeRequestSchema,
|
||||
isJSONRPCError,
|
||||
isJSONRPCNotification,
|
||||
isJSONRPCRequest,
|
||||
isJSONRPCResponse,
|
||||
type JSONRPCMessage,
|
||||
JSONRPCMessageSchema
|
||||
} from '@modelcontextprotocol/sdk/types.js'
|
||||
|
||||
import type { GatewayHonoContext, ResolvedOriginRequest } from './lib/types'
|
||||
import type { GatewayHonoContext } from './lib/types'
|
||||
import { resolveMcpEdgeRequest } from './lib/resolve-mcp-edge-request'
|
||||
|
||||
// TODO: https://github.com/modelcontextprotocol/servers/blob/8fb7bbdab73eddb42aba72e8eab81102efe1d544/src/everything/sse.ts
|
||||
// TODO: https://github.com/cloudflare/agents
|
||||
|
@ -10,17 +21,222 @@ import type { GatewayHonoContext, ResolvedOriginRequest } from './lib/types'
|
|||
// string,
|
||||
// StreamableHTTPClientTransport
|
||||
// >()
|
||||
// const server = new McpServer({
|
||||
// name: 'weather',
|
||||
// version: '1.0.0',
|
||||
// capabilities: {
|
||||
// resources: {},
|
||||
// tools: {}
|
||||
// }
|
||||
// })
|
||||
|
||||
export async function handleMCPRequest(
|
||||
_ctx: GatewayHonoContext,
|
||||
_resolvedOriginRequest: ResolvedOriginRequest
|
||||
) {
|
||||
// const server = new McpServer({
|
||||
// name: 'weather',
|
||||
// version: '1.0.0',
|
||||
// capabilities: {
|
||||
// resources: {},
|
||||
// tools: {}
|
||||
// }
|
||||
// })
|
||||
const MAXIMUM_MESSAGE_SIZE_BYTES = 4 * 1024 * 1024 // 4MB
|
||||
|
||||
export async function handleMcpRequest(ctx: GatewayHonoContext) {
|
||||
const request = ctx.req.raw
|
||||
ctx.set('isJsonRpcRequest', true)
|
||||
|
||||
if (request.method !== 'POST') {
|
||||
// We don't yet support GET or DELETE requests
|
||||
throw new JsonRpcError({
|
||||
message: 'Method not allowed',
|
||||
statusCode: 405,
|
||||
jsonRpcErrorCode: -32_000,
|
||||
jsonRpcId: null
|
||||
})
|
||||
}
|
||||
|
||||
// validate the Accept header
|
||||
const acceptHeader = request.headers.get('accept')
|
||||
|
||||
// The client MUST include an Accept header, listing both application/json and text/event-stream as supported content types.
|
||||
if (
|
||||
!acceptHeader?.includes('application/json') ||
|
||||
!acceptHeader.includes('text/event-stream')
|
||||
) {
|
||||
throw new JsonRpcError({
|
||||
message:
|
||||
'Not Acceptable: Client must accept both "application/json" and "text/event-stream"',
|
||||
statusCode: 406,
|
||||
jsonRpcErrorCode: -32_000,
|
||||
jsonRpcId: null
|
||||
})
|
||||
}
|
||||
|
||||
const ct = request.headers.get('content-type')
|
||||
if (!ct?.includes('application/json')) {
|
||||
throw new JsonRpcError({
|
||||
message:
|
||||
'Unsupported Media Type: Content-Type must be "application/json"',
|
||||
statusCode: 415,
|
||||
jsonRpcErrorCode: -32_000,
|
||||
jsonRpcId: null
|
||||
})
|
||||
}
|
||||
|
||||
// Check content length against maximum allowed size
|
||||
const contentLength = Number.parseInt(
|
||||
request.headers.get('content-length') ?? '0',
|
||||
10
|
||||
)
|
||||
if (contentLength > MAXIMUM_MESSAGE_SIZE_BYTES) {
|
||||
throw new JsonRpcError({
|
||||
message: `Request body too large. Maximum size is ${MAXIMUM_MESSAGE_SIZE_BYTES} bytes`,
|
||||
statusCode: 413,
|
||||
jsonRpcErrorCode: -32_000,
|
||||
jsonRpcId: null
|
||||
})
|
||||
}
|
||||
|
||||
let sessionId = request.headers.get('mcp-session-id')
|
||||
let rawMessage: unknown
|
||||
|
||||
try {
|
||||
rawMessage = await request.json()
|
||||
} catch {
|
||||
throw new JsonRpcError({
|
||||
message: 'Parse error: Invalid JSON',
|
||||
statusCode: 400,
|
||||
jsonRpcErrorCode: -32_700,
|
||||
jsonRpcId: null
|
||||
})
|
||||
}
|
||||
|
||||
// Make sure the message is an array to simplify logic
|
||||
const rawMessages = Array.isArray(rawMessage) ? rawMessage : [rawMessage]
|
||||
|
||||
// Try to parse each message as JSON RPC. Fail if any message is invalid
|
||||
const messages: JSONRPCMessage[] = rawMessages.map((msg) => {
|
||||
const parsed = JSONRPCMessageSchema.safeParse(msg)
|
||||
if (!parsed.success) {
|
||||
throw new JsonRpcError({
|
||||
message: 'Parse error: Invalid JSON-RPC message',
|
||||
statusCode: 400,
|
||||
jsonRpcErrorCode: -32_700,
|
||||
jsonRpcId: null
|
||||
})
|
||||
}
|
||||
return parsed.data
|
||||
})
|
||||
|
||||
// Before we pass the messages to the agent, there's another error condition
|
||||
// we need to enforce. Check if this is an initialization request
|
||||
// https://spec.modelcontextprotocol.io/specification/2025-03-26/basic/lifecycle/
|
||||
const isInitializationRequest = messages.some(
|
||||
(msg) => InitializeRequestSchema.safeParse(msg).success
|
||||
)
|
||||
|
||||
if (isInitializationRequest && sessionId) {
|
||||
throw new JsonRpcError({
|
||||
message:
|
||||
'Invalid Request: Initialization requests must not include a sessionId',
|
||||
statusCode: 400,
|
||||
jsonRpcErrorCode: -32_600,
|
||||
jsonRpcId: null
|
||||
})
|
||||
}
|
||||
|
||||
// The initialization request must be the only request in the batch
|
||||
if (isInitializationRequest && messages.length > 1) {
|
||||
throw new JsonRpcError({
|
||||
message: 'Invalid Request: Only one initialization request is allowed',
|
||||
statusCode: 400,
|
||||
jsonRpcErrorCode: -32_600,
|
||||
jsonRpcId: null
|
||||
})
|
||||
}
|
||||
|
||||
// If an Mcp-Session-Id is returned by the server during initialization,
|
||||
// clients using the Streamable HTTP transport MUST include it
|
||||
// in the Mcp-Session-Id header on all of their subsequent HTTP requests.
|
||||
if (!isInitializationRequest && !sessionId) {
|
||||
throw new JsonRpcError({
|
||||
message: 'Bad Request: Mcp-Session-Id header is required',
|
||||
statusCode: 400,
|
||||
jsonRpcErrorCode: -32_000,
|
||||
jsonRpcId: null
|
||||
})
|
||||
}
|
||||
|
||||
// If we don't have a sessionId, we are serving an initialization request
|
||||
// and need to generate a new sessionId
|
||||
sessionId = sessionId ?? ctx.env.DO_MCP_SERVER.newUniqueId().toString()
|
||||
assert(
|
||||
!ctx.get('sessionId'),
|
||||
500,
|
||||
'Session ID should be set by MCP handler for MCP edge requests'
|
||||
)
|
||||
ctx.set('sessionId', sessionId)
|
||||
|
||||
// Fetch the durable mcp server for this session
|
||||
const id = ctx.env.DO_MCP_SERVER.idFromName(`streamable-http:${sessionId}`)
|
||||
const durableMcpServer = ctx.env.DO_MCP_SERVER.get(id)
|
||||
const isInitialized = await durableMcpServer.isInitialized()
|
||||
|
||||
if (!isInitializationRequest && !isInitialized) {
|
||||
// A session id that was never initialized was provided
|
||||
throw new JsonRpcError({
|
||||
message: 'Session not found',
|
||||
statusCode: 404,
|
||||
jsonRpcErrorCode: -32_001,
|
||||
jsonRpcId: null
|
||||
})
|
||||
}
|
||||
|
||||
if (isInitializationRequest) {
|
||||
const { deployment, consumer, pricingPlan } =
|
||||
await resolveMcpEdgeRequest(ctx)
|
||||
|
||||
await durableMcpServer.init({
|
||||
deployment,
|
||||
consumer,
|
||||
pricingPlan
|
||||
})
|
||||
}
|
||||
|
||||
// We've validated and initialized the request! Now it's time to actually
|
||||
// handle the JSON RPC messages in the request and respond with an SSE
|
||||
// stream.
|
||||
|
||||
// Create a Transform Stream for SSE
|
||||
const { readable, writable } = new TransformStream()
|
||||
const writer = writable.getWriter()
|
||||
const encoder = new TextEncoder()
|
||||
|
||||
// If there are no requests, we send the messages downstream and
|
||||
// acknowledge the request with a 202 since we don't expect any responses
|
||||
// back through this connection.
|
||||
const hasOnlyNotificationsOrResponses = messages.every(
|
||||
(msg) => isJSONRPCNotification(msg) || isJSONRPCResponse(msg)
|
||||
)
|
||||
if (hasOnlyNotificationsOrResponses) {
|
||||
// TODO
|
||||
// for (const message of messages) {
|
||||
// ws.send(JSON.stringify(message))
|
||||
// }
|
||||
|
||||
return new Response(null, {
|
||||
status: 202
|
||||
})
|
||||
}
|
||||
|
||||
for (const message of messages) {
|
||||
if (isJSONRPCRequest(message)) {
|
||||
// Add each request id that we send off to a set so that we can keep
|
||||
// track of which requests we still need a response for.
|
||||
// requestIds.add(message.id)
|
||||
}
|
||||
// ws.send(JSON.stringify(message))
|
||||
}
|
||||
|
||||
// Return the streamable http response.
|
||||
return new Response(readable, {
|
||||
headers: {
|
||||
'Content-Type': 'text/event-stream',
|
||||
'Cache-Control': 'no-cache',
|
||||
Connection: 'keep-alive',
|
||||
'mcp-session-id': sessionId
|
||||
},
|
||||
status: 200
|
||||
})
|
||||
}
|
||||
|
|
|
@ -0,0 +1,12 @@
|
|||
export function isDurableObjectNamespace(
|
||||
namespace: unknown
|
||||
): namespace is DurableObjectNamespace {
|
||||
return (
|
||||
typeof namespace === 'object' &&
|
||||
namespace !== null &&
|
||||
'newUniqueId' in namespace &&
|
||||
typeof namespace.newUniqueId === 'function' &&
|
||||
'idFromName' in namespace &&
|
||||
typeof namespace.idFromName === 'function'
|
||||
)
|
||||
}
|
|
@ -2,8 +2,11 @@ import { app } from './app'
|
|||
import { type Env, parseEnv } from './lib/env'
|
||||
|
||||
// Export Durable Objects for cloudflare
|
||||
export { DurableObjectRateLimiter } from './durable-object'
|
||||
export { DurableMcpClient } from './lib/durable-mcp-client'
|
||||
export { DurableMcpServer } from './lib/durable-mcp-server'
|
||||
export { DurableRateLimiter } from './lib/durable-rate-limiter'
|
||||
|
||||
// Main worker entrypoint
|
||||
export default {
|
||||
async fetch(
|
||||
request: Request,
|
||||
|
@ -12,6 +15,7 @@ export default {
|
|||
): Promise<Response> {
|
||||
let parsedEnv: Env
|
||||
|
||||
// Validate the environment
|
||||
try {
|
||||
parsedEnv = parseEnv(env)
|
||||
} catch (err: any) {
|
||||
|
@ -23,6 +27,7 @@ export default {
|
|||
})
|
||||
}
|
||||
|
||||
// Handle the request with `hono`
|
||||
return app.fetch(request, parsedEnv, ctx)
|
||||
}
|
||||
} satisfies ExportedHandler<Env>
|
||||
|
|
|
@ -11,14 +11,26 @@
|
|||
"migrations": [
|
||||
{
|
||||
"tag": "v1",
|
||||
"new_sqlite_classes": ["DurableObjectRateLimiter"]
|
||||
"new_sqlite_classes": [
|
||||
"DurableRateLimiter",
|
||||
"DurableMcpServer",
|
||||
"DurableMcpClient"
|
||||
]
|
||||
}
|
||||
],
|
||||
"durable_objects": {
|
||||
"bindings": [
|
||||
{
|
||||
"class_name": "DurableObjectRateLimiter",
|
||||
"class_name": "DurableRateLimiter",
|
||||
"name": "DO_RATE_LIMITER"
|
||||
},
|
||||
{
|
||||
"class_name": "DurableMcpServer",
|
||||
"name": "DO_MCP_SERVER"
|
||||
},
|
||||
{
|
||||
"class_name": "DurableMcpClient",
|
||||
"name": "DO_MCP_CLIENT"
|
||||
}
|
||||
]
|
||||
},
|
||||
|
|
|
@ -19,12 +19,12 @@ export class HttpError extends BaseError {
|
|||
readonly statusCode: ContentfulStatusCode
|
||||
|
||||
constructor({
|
||||
statusCode = 500,
|
||||
message,
|
||||
statusCode = 500,
|
||||
cause
|
||||
}: {
|
||||
statusCode?: ContentfulStatusCode
|
||||
message: string
|
||||
statusCode?: ContentfulStatusCode
|
||||
cause?: unknown
|
||||
}) {
|
||||
super({ message, cause })
|
||||
|
@ -33,6 +33,30 @@ export class HttpError extends BaseError {
|
|||
}
|
||||
}
|
||||
|
||||
export class JsonRpcError extends HttpError {
|
||||
readonly jsonRpcErrorCode: number
|
||||
readonly jsonRpcId: string | number | null
|
||||
|
||||
constructor({
|
||||
message,
|
||||
jsonRpcErrorCode,
|
||||
jsonRpcId = null,
|
||||
statusCode,
|
||||
cause
|
||||
}: {
|
||||
message: string
|
||||
jsonRpcErrorCode: number
|
||||
jsonRpcId?: string | number | null
|
||||
statusCode?: ContentfulStatusCode
|
||||
cause?: unknown
|
||||
}) {
|
||||
super({ message, cause, statusCode })
|
||||
|
||||
this.jsonRpcErrorCode = jsonRpcErrorCode
|
||||
this.jsonRpcId = jsonRpcId
|
||||
}
|
||||
}
|
||||
|
||||
export class ZodValidationError extends HttpError {
|
||||
constructor({
|
||||
statusCode,
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
import type { Context } from 'hono'
|
||||
import type { HTTPResponseError } from 'hono/types'
|
||||
import type { ContentfulStatusCode } from 'hono/utils/http-status'
|
||||
import { HttpError } from '@agentic/platform-core'
|
||||
import { HttpError, JsonRpcError } from '@agentic/platform-core'
|
||||
import { captureException } from '@sentry/core'
|
||||
import { HTTPException } from 'hono/http-exception'
|
||||
import { HTTPError } from 'ky'
|
||||
|
@ -13,6 +13,9 @@ export function errorHandler(
|
|||
const isProd = ctx.env?.isProd ?? true
|
||||
const logger = ctx.get('logger') ?? console
|
||||
const requestId = ctx.get('requestId')
|
||||
let isJsonRpcRequest = !!ctx.get('isJsonRpcRequest')
|
||||
let jsonRpcId: string | number | null = null
|
||||
let jsonRpcErrorCode: number | undefined
|
||||
|
||||
let message = 'Internal Server Error'
|
||||
let status: ContentfulStatusCode = 500
|
||||
|
@ -26,6 +29,12 @@ export function errorHandler(
|
|||
} else if (err instanceof HTTPError) {
|
||||
message = err.message
|
||||
status = err.response.status as ContentfulStatusCode
|
||||
} else if (err instanceof JsonRpcError) {
|
||||
message = err.message
|
||||
status = err.statusCode
|
||||
jsonRpcId = err.jsonRpcId
|
||||
jsonRpcErrorCode = err.jsonRpcErrorCode
|
||||
isJsonRpcRequest = true
|
||||
} else if (!isProd) {
|
||||
message = err.message ?? message
|
||||
}
|
||||
|
@ -37,5 +46,44 @@ export function errorHandler(
|
|||
logger.warn(status, err)
|
||||
}
|
||||
|
||||
return ctx.json({ error: message, requestId }, status)
|
||||
if (isJsonRpcRequest) {
|
||||
if (jsonRpcErrorCode === undefined) {
|
||||
jsonRpcErrorCode = httpStatusCodeToJsonRpcErrorCode(status)
|
||||
}
|
||||
|
||||
return ctx.json(
|
||||
{
|
||||
jsonrpc: '2.0',
|
||||
error: {
|
||||
message,
|
||||
code: jsonRpcErrorCode
|
||||
},
|
||||
id: jsonRpcId
|
||||
},
|
||||
status
|
||||
)
|
||||
} else {
|
||||
return ctx.json({ error: message, requestId }, status)
|
||||
}
|
||||
}
|
||||
|
||||
/** Error codes defined by the JSON-RPC specification. */
|
||||
export declare enum JsonRpcErrorCodes {
|
||||
ConnectionClosed = -32_000,
|
||||
RequestTimeout = -32_001,
|
||||
ParseError = -32_700,
|
||||
InvalidRequest = -32_600,
|
||||
MethodNotFound = -32_601,
|
||||
InvalidParams = -32_602,
|
||||
InternalError = -32_603
|
||||
}
|
||||
|
||||
export function httpStatusCodeToJsonRpcErrorCode(
|
||||
statusCode: ContentfulStatusCode
|
||||
): number {
|
||||
if (statusCode >= 400 && statusCode < 500) {
|
||||
return JsonRpcErrorCodes.InvalidRequest
|
||||
}
|
||||
|
||||
return JsonRpcErrorCodes.InternalError
|
||||
}
|
||||
|
|
|
@ -23,6 +23,12 @@ export const init = createMiddleware<DefaultHonoEnv>(
|
|||
const logger = new ConsoleLogger(ctx.env, { requestId })
|
||||
ctx.set('logger', logger)
|
||||
|
||||
ctx.set('isJsonRpcRequest', false)
|
||||
|
||||
const ip =
|
||||
ctx.req.header('cf-connecting-ip') || ctx.req.header('x-forwarded-for')
|
||||
ctx.set('ip', ip)
|
||||
|
||||
await next()
|
||||
}
|
||||
)
|
||||
|
|
|
@ -15,6 +15,8 @@ export type DefaultHonoVariables = {
|
|||
sentry: Sentry
|
||||
requestId: string
|
||||
logger: Logger
|
||||
isJsonRpcRequest?: boolean
|
||||
ip?: string
|
||||
}
|
||||
|
||||
export type DefaultHonoBindings = Simplify<
|
||||
|
|
|
@ -56,7 +56,6 @@
|
|||
- resources
|
||||
- prompts
|
||||
- other MCP features?
|
||||
- allow config name to be `project-name` or `@namespace/project-name`?
|
||||
|
||||
## TODO Post-MVP
|
||||
|
||||
|
@ -88,6 +87,7 @@
|
|||
- additional transactional emails
|
||||
- consider `projectName` and `projectSlug` or `projectIdentifier`?
|
||||
- handle or validate against dynamic MCP origin tools
|
||||
- allow config name to be `project-name` or `@namespace/project-name`?
|
||||
|
||||
## License
|
||||
|
||||
|
|
Ładowanie…
Reference in New Issue