diff --git a/apps/gateway/src/app.ts b/apps/gateway/src/app.ts index 6264d3c1..c98e9ebe 100644 --- a/apps/gateway/src/app.ts +++ b/apps/gateway/src/app.ts @@ -6,11 +6,9 @@ import { responseTime, sentry } from '@agentic/platform-hono' -import { Client as McpClient } from '@modelcontextprotocol/sdk/client/index.js' -import { StreamableHTTPClientTransport } from '@modelcontextprotocol/sdk/client/streamableHttp.js' import { Hono } from 'hono' -import type { GatewayHonoEnv } from './lib/types' +import type { GatewayHonoEnv, McpToolCallResponse } from './lib/types' import { createAgenticClient } from './lib/agentic-client' import { createHttpResponseFromMcpToolCallResponse } from './lib/create-http-response-from-mcp-tool-call-response' import { fetchCache } from './lib/fetch-cache' @@ -31,10 +29,10 @@ app.use(sentry()) app.use( cors({ origin: '*', - allowHeaders: ['Content-Type', 'Authorization'], - allowMethods: ['POST', 'GET', 'OPTIONS'], - exposeHeaders: ['Content-Length'], - maxAge: 600, + allowHeaders: ['Content-Type', 'Authorization', 'mcp-session-id'], + allowMethods: ['GET', 'POST', 'DELETE', 'OPTIONS'], + exposeHeaders: ['Content-Length', 'mcp-session-id'], + maxAge: 86_400, credentials: true }) ) @@ -82,25 +80,22 @@ app.all(async (ctx) => { 500, 'Tool args are required for MCP origin requests' ) - - const transport = new StreamableHTTPClientTransport( - new URL(resolvedOriginRequest.deployment.originUrl) + assert( + resolvedOriginRequest.mcpClient, + 500, + 'MCP client is required for MCP origin requests' ) - const client = new McpClient({ - name: resolvedOriginRequest.deployment.originAdapter.serverInfo.name, - version: - resolvedOriginRequest.deployment.originAdapter.serverInfo.version - }) - - // TODO: re-use client connection across requests - await client.connect(transport) // TODO: add timeout support to the origin tool call? 
// TODO: add response caching for MCP tool calls - const toolCallResponse = await client.callTool({ - name: resolvedOriginRequest.tool.name, - arguments: resolvedOriginRequest.toolCallArgs - }) + const toolCallResponseString = + await resolvedOriginRequest.mcpClient.callTool({ + name: resolvedOriginRequest.tool.name, + args: resolvedOriginRequest.toolCallArgs + }) + const toolCallResponse = JSON.parse( + toolCallResponseString + ) as McpToolCallResponse originResponse = await createHttpResponseFromMcpToolCallResponse(ctx, { tool: resolvedOriginRequest.tool, diff --git a/apps/gateway/src/lib/durable-mcp-client.ts b/apps/gateway/src/lib/durable-mcp-client.ts new file mode 100644 index 00000000..75b6bacb --- /dev/null +++ b/apps/gateway/src/lib/durable-mcp-client.ts @@ -0,0 +1,83 @@ +import { assert } from '@agentic/platform-core' +import { Client as McpClient } from '@modelcontextprotocol/sdk/client/index.js' +import { StreamableHTTPClientTransport } from '@modelcontextprotocol/sdk/client/streamableHttp.js' +import { DurableObject } from 'cloudflare:workers' + +import type { RawEnv } from './env' + +export type DurableMcpClientInfo = { + url: string + name: string + version: string +} + +// TODO: not sure if there's a better way to handle re-using client connections +// across requests. 
/**
 * Durable Object that owns one MCP client connection to an origin MCP server
 * so the connection can be reused across gateway requests.
 *
 * The connection parameters are persisted under the `mcp-client-info` storage
 * key; the live `McpClient` and its connection promise are held in memory.
 */
export class DurableMcpClient extends DurableObject<RawEnv> {
  protected client?: McpClient
  protected clientConnectionP?: Promise<void>

  /**
   * Records the connection parameters on first use and connects to the origin.
   *
   * @param mcpClientInfo - Origin URL plus the client name/version to announce.
   * @throws 500 if this object was previously initialized with a different URL.
   */
  async init(mcpClientInfo: DurableMcpClientInfo): Promise<void> {
    const storedInfo =
      await this.ctx.storage.get<DurableMcpClientInfo>('mcp-client-info')

    if (storedInfo) {
      assert(
        mcpClientInfo.url === storedInfo.url,
        500,
        `DurableMcpClientInfo url mismatch: "${mcpClientInfo.url}" vs "${storedInfo.url}"`
      )
    } else {
      await this.ctx.storage.put('mcp-client-info', mcpClientInfo)
    }

    return this.ensureClientConnection(mcpClientInfo)
  }

  /** Returns whether connection parameters have been persisted. */
  async isInitialized(): Promise<boolean> {
    return !!(await this.ctx.storage.get('mcp-client-info'))
  }

  /**
   * Connects to the origin MCP server if there is no connection (attempt) yet.
   *
   * @param durableMcpClientInfo - Optional connection parameters; when omitted
   * they are loaded from storage.
   * @throws 500 if no connection parameters are available, or whatever the
   * underlying connection attempt rejects with.
   */
  async ensureClientConnection(
    durableMcpClientInfo?: DurableMcpClientInfo
  ): Promise<void> {
    if (!this.clientConnectionP) {
      // Publish the in-flight promise before the first `await` so concurrent
      // callers share one connection attempt instead of racing to create two.
      const connecting = this.#connect(durableMcpClientInfo)
      this.clientConnectionP = connecting

      // A failed attempt must not be cached, otherwise every later call on
      // this instance would keep rejecting with the same stale error.
      // Callers still observe the rejection through `connecting` itself.
      connecting.catch(() => {
        if (this.clientConnectionP === connecting) {
          this.clientConnectionP = undefined
        }
      })
    }

    return this.clientConnectionP
  }

  /**
   * Calls a tool on the origin MCP server.
   *
   * @returns The tool call response serialized as a JSON string.
   */
  async callTool({
    name,
    args
  }: {
    name: string
    args: Record<string, unknown>
  }): Promise<string> {
    await this.ensureClientConnection()

    const { client } = this
    assert(client, 500, 'DurableMcpClient is not connected')

    const toolCallResponse = await client.callTool({
      name,
      arguments: args
    })

    // TODO: The `McpToolCallResponse` type is too complex for the CF
    // serialization type inference to handle, so bypass it by serializing to
    // a string and parse on the other end.
    return JSON.stringify(toolCallResponse)
  }

  /**
   * Performs a single connection attempt. `this.client` is only assigned once
   * the connection has been established.
   */
  async #connect(durableMcpClientInfo?: DurableMcpClientInfo): Promise<void> {
    durableMcpClientInfo ??=
      await this.ctx.storage.get<DurableMcpClientInfo>('mcp-client-info')
    assert(
      durableMcpClientInfo,
      500,
      'DurableMcpClient has not been initialized'
    )
    const { name, version, url } = durableMcpClientInfo

    const client = new McpClient({ name, version })
    const transport = new StreamableHTTPClientTransport(new URL(url))
    await client.connect(transport)

    this.client = client
  }
}
/**
 * Durable Object backing one MCP edge session.
 *
 * NOTE(review): the session data lives only in the `_initData` instance
 * field, so it presumably does not survive the object being evicted from
 * memory — confirm and consider persisting it (see the TODO below).
 */
export class DurableMcpServer extends DurableObject<RawEnv> {
  // TODO: store this in storage?
  protected _initData?: {
    deployment: AdminDeployment
    consumer?: AdminConsumer
    pricingPlan?: PricingPlan
  }

  /** Stores the resolved deployment, consumer and pricing plan for the session. */
  async init({
    deployment,
    consumer,
    pricingPlan
  }: {
    deployment: AdminDeployment
    consumer?: AdminConsumer
    pricingPlan?: PricingPlan
  }): Promise<void> {
    // const parsedDeploymentIdentifier = parseDeploymentIdentifier(
    //   deployment.identifier
    // )
    // assert(
    //   parsedDeploymentIdentifier,
    //   500,
    //   `Invalid deployment identifier "${deployment.identifier}"`
    // )
    // const { projectIdentifier } = parsedDeploymentIdentifier

    // const server = new McpServer({
    //   name: projectIdentifier,
    //   version: deployment.version ?? '0.0.0'
    // })
    // const transport = new StreamableHTTPServerTransport({})
    // server.addTransport(transport)

    this._initData = {
      deployment,
      consumer,
      pricingPlan
    }
  }

  /**
   * Returns whether `init` has been called on this instance.
   *
   * Returns a plain boolean so the full deployment/consumer payload is not
   * sent back to the caller just to answer a yes/no question.
   */
  async isInitialized(): Promise<boolean> {
    return !!this._initData
  }

  /** Placeholder method; throws a 500 if the server was not initialized. */
  async sayHello(name: string): Promise<string> {
    assert(this._initData, 500, 'Server not initialized')
    return `Hello, ${name}!`
  }

  /** Entry point for JSON-RPC requests; not implemented yet. */
  async onRequest(request: JSONRPCRequest): Promise<void> {
    // TODO: dispatch on `request.method` / `request.params`.
  }
}
the first call to * `DurableObjectStub::get` for a given identifier (no-op constructors can be omitted) diff --git a/apps/gateway/src/lib/enforce-rate-limit.ts b/apps/gateway/src/lib/enforce-rate-limit.ts index e42b46d0..43aa940e 100644 --- a/apps/gateway/src/lib/enforce-rate-limit.ts +++ b/apps/gateway/src/lib/enforce-rate-limit.ts @@ -11,15 +11,11 @@ export async function enforceRateLimit( { id, interval, - maxPerInterval, - method, - pathname + maxPerInterval }: { id?: string interval: number maxPerInterval: number - method: string - pathname: string } ) { assert(id, 400, 'Unauthenticated requests must have a valid IP address') @@ -29,6 +25,4 @@ export async function enforceRateLimit( assert(id, 500, 'not implemented') assert(interval > 0, 500, 'not implemented') assert(maxPerInterval >= 0, 500, 'not implemented') - assert(method, 500, 'not implemented') - assert(pathname, 500, 'not implemented') } diff --git a/apps/gateway/src/lib/env.ts b/apps/gateway/src/lib/env.ts index e43489ca..dbb24b5b 100644 --- a/apps/gateway/src/lib/env.ts +++ b/apps/gateway/src/lib/env.ts @@ -6,19 +6,44 @@ import { } from '@agentic/platform-hono' import { z } from 'zod' +import type { DurableMcpClient } from './durable-mcp-client' +import type { DurableMcpServer } from './durable-mcp-server' +import type { DurableRateLimiter } from './durable-rate-limiter' + export const envSchema = baseEnvSchema .extend({ AGENTIC_API_BASE_URL: z.string().url(), AGENTIC_API_KEY: z.string().nonempty(), - DO_RATE_LIMITER: z.custom( - (ns) => ns && typeof ns === 'object' + DO_RATE_LIMITER: z.custom>( + (ns) => isDurableObjectNamespace(ns) + ), + + DO_MCP_SERVER: z.custom>((ns) => + isDurableObjectNamespace(ns) + ), + + DO_MCP_CLIENT: z.custom>((ns) => + isDurableObjectNamespace(ns) ) }) .strip() export type RawEnv = z.infer export type Env = Simplify> +export function isDurableObjectNamespace( + namespace: unknown +): namespace is DurableObjectNamespace { + return ( + typeof namespace === 'object' && + 
/**
 * Resolves the deployment — and, when an `apiKey` query parameter is present,
 * the consumer and its pricing plan — targeted by an MCP edge request.
 *
 * The deployment identifier is the request pathname with one leading and one
 * trailing slash removed.
 *
 * @throws 404 if the pathname is not a valid deployment identifier.
 * @throws 401 / 402 / 403 if the API key is unknown, its subscription is not
 * active, or it belongs to a different project.
 */
export async function resolveMcpEdgeRequest(ctx: GatewayHonoContext): Promise<{
  deployment: AdminDeployment
  consumer?: AdminConsumer
  pricingPlan?: PricingPlan
}> {
  const identifier = new URL(ctx.req.url).pathname
    .replace(/^\//, '')
    .replace(/\/$/, '')

  const parsed = parseDeploymentIdentifier(identifier)
  assert(parsed, 404, `Invalid deployment identifier "${identifier}"`)

  const deployment = await getAdminDeployment(ctx, parsed.deploymentIdentifier)
  const apiKey = ctx.req.query('apiKey')?.trim()

  if (!apiKey) {
    // For unauthenticated requests, default to a free pricing plan if available.
    return {
      deployment,
      consumer: undefined,
      pricingPlan: deployment.pricingPlans.find((plan) => plan.slug === 'free')
    }
  }

  // NOTE(review): the messages below echo the raw API key; confirm that is
  // acceptable wherever these errors end up being logged or reported.
  const consumer = await getAdminConsumer(ctx, apiKey)
  assert(consumer, 401, `Invalid api key "${apiKey}"`)
  assert(
    consumer.isStripeSubscriptionActive,
    402,
    `API key "${apiKey}" subscription is not active`
  )
  assert(
    consumer.projectId === deployment.projectId,
    403,
    `API key "${apiKey}" is not authorized for project "${deployment.projectId}"`
  )

  // TODO: Ensure that consumer.plan is compatible with the target deployment?
  // TODO: This could definitely cause issues when changing pricing plans.
  const { plan: consumerPlan } = consumer
  const pricingPlan = deployment.pricingPlans.find(
    (candidate) => consumerPlan === candidate.slug
  )

  // assert(
  //   pricingPlan,
  //   403,
  //   `API key "${apiKey}" unable to find matching pricing plan for project "${deployment.project}"`
  // )

  return { deployment, consumer, pricingPlan }
}
- const ip = - ctx.req.header('cf-connecting-ip') || - ctx.req.header('x-forwarded-for') || - undefined + const ip = ctx.get('ip') + const { method } = ctx.req const requestUrl = new URL(ctx.req.url) const { pathname } = requestUrl - const requestedToolIdentifier = pathname.replace(/^\//, '') + const requestedToolIdentifier = pathname.replace(/^\//, '').replace(/\/$/, '') const parsedToolIdentifier = parseToolIdentifier(requestedToolIdentifier) assert( parsedToolIdentifier, @@ -65,7 +63,7 @@ export async function resolveOriginRequest( let pricingPlan: PricingPlan | undefined let consumer: AdminConsumer | undefined - let reportUsage = true + let reportUsage = ctx.get('reportUsage') ?? true const token = (ctx.req.header('authorization') || '') .replace(/^Bearer /i, '') @@ -97,9 +95,18 @@ export async function resolveOriginRequest( // 403, // `Auth token "${token}" unable to find matching pricing plan for project "${deployment.project}"` // ) + + if (!ctx.get('sessionId')) { + ctx.set('sessionId', `${consumer.id}:${deployment.id}`) + } } else { // For unauthenticated requests, default to a free pricing plan if available. pricingPlan = deployment.pricingPlans.find((plan) => plan.slug === 'free') + + if (!ctx.get('sessionId')) { + assert(ip, 500, 'IP address is required for unauthenticated requests') + ctx.set('sessionId', `${ip}:${deployment.projectId}`) + } } let rateLimit: RateLimit | undefined | null @@ -165,13 +172,13 @@ export async function resolveOriginRequest( } } + ctx.set('reportUsage', reportUsage) + if (rateLimit) { await enforceRateLimit(ctx, { id: consumer?.id ?? 
ip, interval: rateLimit.interval, - maxPerInterval: rateLimit.maxPerInterval, - method, - pathname + maxPerInterval: rateLimit.maxPerInterval }) } @@ -191,6 +198,7 @@ export async function resolveOriginRequest( }) } + let mcpClient: DurableObjectStub | undefined if (originAdapter.type === 'openapi') { const operation = originAdapter.toolToOperationMap[tool.name] assert(operation, 404, `Tool "${tool.name}" not found in OpenAPI spec`) @@ -201,6 +209,18 @@ export async function resolveOriginRequest( operation, deployment }) + } else if (originAdapter.type === 'mcp') { + const sessionId = ctx.get('sessionId') + assert(sessionId, 500, 'Session ID is required for MCP origin requests') + + const id: DurableObjectId = ctx.env.DO_MCP_CLIENT.idFromName(sessionId) + mcpClient = ctx.env.DO_MCP_CLIENT.get(id) + + await mcpClient.init({ + url: deployment.originUrl, + name: originAdapter.serverInfo.name, + version: originAdapter.serverInfo.version + }) } if (originRequest) { @@ -209,14 +229,12 @@ export async function resolveOriginRequest( } return { - originRequest, - toolCallArgs, deployment, consumer, tool, - ip, - method, - pricingPlanSlug: pricingPlan?.slug, - reportUsage + pricingPlan, + toolCallArgs, + originRequest, + mcpClient } } diff --git a/apps/gateway/src/lib/types.ts b/apps/gateway/src/lib/types.ts index da6fca47..5b390fff 100644 --- a/apps/gateway/src/lib/types.ts +++ b/apps/gateway/src/lib/types.ts @@ -6,6 +6,7 @@ import type { import type { AdminDeployment, Consumer, + PricingPlan, Tool, User } from '@agentic/platform-types' @@ -13,6 +14,7 @@ import type { Client as McpClient } from '@modelcontextprotocol/sdk/client/index import type { Context } from 'hono' import type { Simplify } from 'type-fest' +import type { DurableMcpClient } from './durable-mcp-client' import type { Env } from './env' export type McpToolCallResponse = Simplify< @@ -29,6 +31,8 @@ export type GatewayHonoVariables = Simplify< DefaultHonoVariables & { client: AgenticApiClient cache: Cache + 
sessionId?: string + reportUsage?: boolean } > @@ -47,13 +51,11 @@ export type ToolCallArgs = Record export type ResolvedOriginRequest = { deployment: AdminDeployment tool: Tool - method: string - reportUsage: boolean consumer?: AdminConsumer - ip?: string - pricingPlanSlug?: string + pricingPlan?: PricingPlan - originRequest?: Request toolCallArgs?: ToolCallArgs + originRequest?: Request + mcpClient?: DurableObjectStub } diff --git a/apps/gateway/src/mcp.ts b/apps/gateway/src/mcp.ts index 59f6f203..f6bd1774 100644 --- a/apps/gateway/src/mcp.ts +++ b/apps/gateway/src/mcp.ts @@ -1,7 +1,18 @@ // import { McpServer } from '@modelcontextprotocol/sdk/server/mcp.js' // import { StreamableHTTPClientTransport } from '@modelcontextprotocol/sdk/client/streamableHttp.js' +import { assert, JsonRpcError } from '@agentic/platform-core' +import { + InitializeRequestSchema, + isJSONRPCError, + isJSONRPCNotification, + isJSONRPCRequest, + isJSONRPCResponse, + type JSONRPCMessage, + JSONRPCMessageSchema +} from '@modelcontextprotocol/sdk/types.js' -import type { GatewayHonoContext, ResolvedOriginRequest } from './lib/types' +import type { GatewayHonoContext } from './lib/types' +import { resolveMcpEdgeRequest } from './lib/resolve-mcp-edge-request' // TODO: https://github.com/modelcontextprotocol/servers/blob/8fb7bbdab73eddb42aba72e8eab81102efe1d544/src/everything/sse.ts // TODO: https://github.com/cloudflare/agents @@ -10,17 +21,222 @@ import type { GatewayHonoContext, ResolvedOriginRequest } from './lib/types' // string, // StreamableHTTPClientTransport // >() +// const server = new McpServer({ +// name: 'weather', +// version: '1.0.0', +// capabilities: { +// resources: {}, +// tools: {} +// } +// }) -export async function handleMCPRequest( - _ctx: GatewayHonoContext, - _resolvedOriginRequest: ResolvedOriginRequest -) { - // const server = new McpServer({ - // name: 'weather', - // version: '1.0.0', - // capabilities: { - // resources: {}, - // tools: {} - // } - // }) +const 
const MAXIMUM_MESSAGE_SIZE_BYTES = 4 * 1024 * 1024 // 4MB

/**
 * Handles an MCP Streamable HTTP request at the gateway edge.
 *
 * Validates the HTTP envelope (method, `Accept`, `Content-Type`, body size),
 * parses the body as one JSON-RPC message or a batch, enforces the session
 * rules around initialization, and resolves the Durable Object backing the
 * session. Every failure is thrown as a `JsonRpcError`.
 *
 * @returns A 202 response when the batch contains only notifications or
 * responses, otherwise a `text/event-stream` response carrying the
 * `mcp-session-id` header.
 */
export async function handleMcpRequest(ctx: GatewayHonoContext) {
  const request = ctx.req.raw
  ctx.set('isJsonRpcRequest', true)

  if (request.method !== 'POST') {
    // We don't yet support GET or DELETE requests
    throw new JsonRpcError({
      message: 'Method not allowed',
      statusCode: 405,
      jsonRpcErrorCode: -32_000,
      jsonRpcId: null
    })
  }

  // The client MUST include an Accept header, listing both application/json
  // and text/event-stream as supported content types.
  const acceptHeader = request.headers.get('accept')
  if (
    !acceptHeader?.includes('application/json') ||
    !acceptHeader.includes('text/event-stream')
  ) {
    throw new JsonRpcError({
      message:
        'Not Acceptable: Client must accept both "application/json" and "text/event-stream"',
      statusCode: 406,
      jsonRpcErrorCode: -32_000,
      jsonRpcId: null
    })
  }

  const ct = request.headers.get('content-type')
  if (!ct?.includes('application/json')) {
    throw new JsonRpcError({
      message:
        'Unsupported Media Type: Content-Type must be "application/json"',
      statusCode: 415,
      jsonRpcErrorCode: -32_000,
      jsonRpcId: null
    })
  }

  const createTooLargeError = () =>
    new JsonRpcError({
      message: `Request body too large. Maximum size is ${MAXIMUM_MESSAGE_SIZE_BYTES} bytes`,
      statusCode: 413,
      jsonRpcErrorCode: -32_000,
      jsonRpcId: null
    })
  const createInvalidJsonError = () =>
    new JsonRpcError({
      message: 'Parse error: Invalid JSON',
      statusCode: 400,
      jsonRpcErrorCode: -32_700,
      jsonRpcId: null
    })

  // Cheap early rejection based on the declared size. The header is
  // client-controlled and may be missing or non-numeric (NaN compares false),
  // so the real limit is enforced on the bytes actually read below.
  const declaredLength = Number.parseInt(
    request.headers.get('content-length') ?? '0',
    10
  )
  if (declaredLength > MAXIMUM_MESSAGE_SIZE_BYTES) {
    throw createTooLargeError()
  }

  let sessionId = request.headers.get('mcp-session-id')

  let bodyBytes: ArrayBuffer
  try {
    bodyBytes = await request.arrayBuffer()
  } catch {
    throw createInvalidJsonError()
  }

  if (bodyBytes.byteLength > MAXIMUM_MESSAGE_SIZE_BYTES) {
    throw createTooLargeError()
  }

  let rawMessage: unknown
  try {
    rawMessage = JSON.parse(new TextDecoder().decode(bodyBytes))
  } catch {
    throw createInvalidJsonError()
  }

  // Make sure the message is an array to simplify logic
  const rawMessages: unknown[] = Array.isArray(rawMessage)
    ? rawMessage
    : [rawMessage]

  // Try to parse each message as JSON RPC. Fail if any message is invalid
  const messages: JSONRPCMessage[] = rawMessages.map((msg) => {
    const parsed = JSONRPCMessageSchema.safeParse(msg)
    if (!parsed.success) {
      throw new JsonRpcError({
        message: 'Parse error: Invalid JSON-RPC message',
        statusCode: 400,
        jsonRpcErrorCode: -32_700,
        jsonRpcId: null
      })
    }
    return parsed.data
  })

  // Before we pass the messages to the agent, there's another error condition
  // we need to enforce. Check if this is an initialization request
  // https://spec.modelcontextprotocol.io/specification/2025-03-26/basic/lifecycle/
  const isInitializationRequest = messages.some(
    (msg) => InitializeRequestSchema.safeParse(msg).success
  )

  if (isInitializationRequest && sessionId) {
    throw new JsonRpcError({
      message:
        'Invalid Request: Initialization requests must not include a sessionId',
      statusCode: 400,
      jsonRpcErrorCode: -32_600,
      jsonRpcId: null
    })
  }

  // The initialization request must be the only request in the batch
  if (isInitializationRequest && messages.length > 1) {
    throw new JsonRpcError({
      message: 'Invalid Request: Only one initialization request is allowed',
      statusCode: 400,
      jsonRpcErrorCode: -32_600,
      jsonRpcId: null
    })
  }

  // If an Mcp-Session-Id is returned by the server during initialization,
  // clients using the Streamable HTTP transport MUST include it
  // in the Mcp-Session-Id header on all of their subsequent HTTP requests.
  if (!isInitializationRequest && !sessionId) {
    throw new JsonRpcError({
      message: 'Bad Request: Mcp-Session-Id header is required',
      statusCode: 400,
      jsonRpcErrorCode: -32_000,
      jsonRpcId: null
    })
  }

  // If we don't have a sessionId, we are serving an initialization request
  // and need to generate a new sessionId
  sessionId = sessionId ?? ctx.env.DO_MCP_SERVER.newUniqueId().toString()
  assert(
    !ctx.get('sessionId'),
    500,
    'Session ID should be set by MCP handler for MCP edge requests'
  )
  ctx.set('sessionId', sessionId)

  // Fetch the durable mcp server for this session
  const id = ctx.env.DO_MCP_SERVER.idFromName(`streamable-http:${sessionId}`)
  const durableMcpServer = ctx.env.DO_MCP_SERVER.get(id)
  const isInitialized = await durableMcpServer.isInitialized()

  if (!isInitializationRequest && !isInitialized) {
    // A session id that was never initialized was provided
    throw new JsonRpcError({
      message: 'Session not found',
      statusCode: 404,
      jsonRpcErrorCode: -32_001,
      jsonRpcId: null
    })
  }

  if (isInitializationRequest) {
    const { deployment, consumer, pricingPlan } =
      await resolveMcpEdgeRequest(ctx)

    await durableMcpServer.init({
      deployment,
      consumer,
      pricingPlan
    })
  }

  // We've validated and initialized the request! Now it's time to actually
  // handle the JSON RPC messages in the request.

  // If there are no requests, we send the messages downstream and
  // acknowledge the request with a 202 since we don't expect any responses
  // back through this connection.
  const hasOnlyNotificationsOrResponses = messages.every(
    (msg) => isJSONRPCNotification(msg) || isJSONRPCResponse(msg)
  )
  if (hasOnlyNotificationsOrResponses) {
    // TODO
    // for (const message of messages) {
    //   ws.send(JSON.stringify(message))
    // }

    return new Response(null, {
      status: 202
    })
  }

  // Only allocate the SSE stream once we know a streamed response is needed.
  // NOTE(review): nothing writes to or closes this stream yet, so the
  // response body stays open until the message dispatch below is implemented.
  const { readable } = new TransformStream()

  for (const message of messages) {
    if (isJSONRPCRequest(message)) {
      // Add each request id that we send off to a set so that we can keep
      // track of which requests we still need a response for.
      // requestIds.add(message.id)
    }
    // ws.send(JSON.stringify(message))
  }

  // Return the streamable http response.
  return new Response(readable, {
    headers: {
      'Content-Type': 'text/event-stream',
      'Cache-Control': 'no-cache',
      Connection: 'keep-alive',
      'mcp-session-id': sessionId
    },
    status: 200
  })
}
/**
 * An `HttpError` that additionally carries the JSON-RPC error code and the id
 * of the JSON-RPC request it relates to (`null` when no id is given).
 */
export class JsonRpcError extends HttpError {
  readonly jsonRpcErrorCode: number
  readonly jsonRpcId: string | number | null

  constructor(opts: {
    message: string
    jsonRpcErrorCode: number
    jsonRpcId?: string | number | null
    statusCode?: ContentfulStatusCode
    cause?: unknown
  }) {
    // An undefined `statusCode` is forwarded as-is so the base class decides
    // the default.
    super({
      message: opts.message,
      cause: opts.cause,
      statusCode: opts.statusCode
    })

    this.jsonRpcErrorCode = opts.jsonRpcErrorCode
    this.jsonRpcId = opts.jsonRpcId ?? null
  }
}
HttpError, JsonRpcError } from '@agentic/platform-core' import { captureException } from '@sentry/core' import { HTTPException } from 'hono/http-exception' import { HTTPError } from 'ky' @@ -13,6 +13,9 @@ export function errorHandler( const isProd = ctx.env?.isProd ?? true const logger = ctx.get('logger') ?? console const requestId = ctx.get('requestId') + let isJsonRpcRequest = !!ctx.get('isJsonRpcRequest') + let jsonRpcId: string | number | null = null + let jsonRpcErrorCode: number | undefined let message = 'Internal Server Error' let status: ContentfulStatusCode = 500 @@ -26,6 +29,12 @@ export function errorHandler( } else if (err instanceof HTTPError) { message = err.message status = err.response.status as ContentfulStatusCode + } else if (err instanceof JsonRpcError) { + message = err.message + status = err.statusCode + jsonRpcId = err.jsonRpcId + jsonRpcErrorCode = err.jsonRpcErrorCode + isJsonRpcRequest = true } else if (!isProd) { message = err.message ?? message } @@ -37,5 +46,44 @@ export function errorHandler( logger.warn(status, err) } - return ctx.json({ error: message, requestId }, status) + if (isJsonRpcRequest) { + if (jsonRpcErrorCode === undefined) { + jsonRpcErrorCode = httpStatusCodeToJsonRpcErrorCode(status) + } + + return ctx.json( + { + jsonrpc: '2.0', + error: { + message, + code: jsonRpcErrorCode + }, + id: jsonRpcId + }, + status + ) + } else { + return ctx.json({ error: message, requestId }, status) + } +} + +/** Error codes defined by the JSON-RPC specification. 
/**
 * JSON-RPC error codes used when reporting gateway errors.
 *
 * `ParseError` through `InternalError` are the codes predefined by JSON-RPC
 * 2.0; `ConnectionClosed` and `RequestTimeout` fall in that spec's
 * implementation-defined server-error range (-32000 to -32099).
 *
 * This is a runtime object rather than an ambient `declare enum`: an ambient
 * enum emits no JavaScript, so reading `JsonRpcErrorCodes.InvalidRequest`
 * would throw a `ReferenceError` when the error handler runs.
 */
export const JsonRpcErrorCodes = {
  ConnectionClosed: -32_000,
  RequestTimeout: -32_001,
  ParseError: -32_700,
  InvalidRequest: -32_600,
  MethodNotFound: -32_601,
  InvalidParams: -32_602,
  InternalError: -32_603
} as const

/** Union of the numeric codes in `JsonRpcErrorCodes`. */
export type JsonRpcErrorCodes =
  (typeof JsonRpcErrorCodes)[keyof typeof JsonRpcErrorCodes]

/**
 * Maps an HTTP status code to a JSON-RPC error code.
 *
 * @param statusCode - HTTP status of the failed request.
 * @returns `InvalidRequest` for 4xx statuses, `InternalError` for everything
 * else.
 */
export function httpStatusCodeToJsonRpcErrorCode(
  statusCode: ContentfulStatusCode
): number {
  const isClientError = statusCode >= 400 && statusCode < 500

  return isClientError
    ? JsonRpcErrorCodes.InvalidRequest
    : JsonRpcErrorCodes.InternalError
}