From ce2f3afc4177d266320cddc5559c4650e4662a30 Mon Sep 17 00:00:00 2001 From: Travis Fischer Date: Mon, 9 Jun 2025 04:05:42 +0700 Subject: [PATCH] feat: mcp gateway work wip --- apps/gateway/src/app.ts | 13 +- apps/gateway/src/lib/durable-mcp-client.ts | 32 ++-- apps/gateway/src/lib/durable-mcp-server.ts | 160 +++++++++++++----- .../gateway/src/lib/resolve-origin-request.ts | 48 +++++- apps/gateway/src/lib/types.ts | 47 ++++- apps/gateway/src/lib/update-origin-request.ts | 9 +- readme.md | 12 +- 7 files changed, 233 insertions(+), 88 deletions(-) diff --git a/apps/gateway/src/app.ts b/apps/gateway/src/app.ts index c98e9ebe..036de28a 100644 --- a/apps/gateway/src/app.ts +++ b/apps/gateway/src/app.ts @@ -81,7 +81,7 @@ app.all(async (ctx) => { 'Tool args are required for MCP origin requests' ) assert( - resolvedOriginRequest.mcpClient, + resolvedOriginRequest.originMcpClient, 500, 'MCP client is required for MCP origin requests' ) @@ -89,9 +89,10 @@ app.all(async (ctx) => { // TODO: add timeout support to the origin tool call? // TODO: add response caching for MCP tool calls const toolCallResponseString = - await resolvedOriginRequest.mcpClient.callTool({ + await resolvedOriginRequest.originMcpClient.callTool({ name: resolvedOriginRequest.tool.name, - args: resolvedOriginRequest.toolCallArgs + args: resolvedOriginRequest.toolCallArgs, + metadata: resolvedOriginRequest.originMcpRequestMetadata! }) const toolCallResponse = JSON.parse( toolCallResponseString @@ -124,12 +125,6 @@ app.all(async (ctx) => { res.headers.delete('server-timing') res.headers.delete('reporting-endpoints') - // const id: DurableObjectId = env.DO_RATE_LIMITER.idFromName('foo') - // const stub = env.DO_RATE_LIMITER.get(id) - // const greeting = await stub.sayHello('world') - - // return new Response(greeting) - return res // TODO: move this `finally` block to a middleware handler diff --git a/apps/gateway/src/lib/durable-mcp-client.ts b/apps/gateway/src/lib/durable-mcp-client.ts index 6c5d3678..67ba666b 100644 --- a/apps/gateway/src/lib/durable-mcp-client.ts +++ b/apps/gateway/src/lib/durable-mcp-client.ts @@ -4,6 +4,7 @@ import { StreamableHTTPClientTransport } from '@modelcontextprotocol/sdk/client/ import { DurableObject } from 'cloudflare:workers' import type { RawEnv } from './env' +import type { AgenticMcpRequestMetadata } from './types' export type DurableMcpClientInfo = { url: string @@ -12,23 +13,25 @@ export type DurableMcpClientInfo = { } // TODO: not sure if there's a better way to handle re-using client connections -// across requests. Maybe we use one DurableObject per customer<>originUrl connection? +// across requests. Maybe we use one DurableObject per unique +// customer<>DurableMcpClientInfo connection? +// Currently using `sessionId` export class DurableMcpClient extends DurableObject { protected client?: McpClient protected clientConnectionP?: Promise async init(mcpClientInfo: DurableMcpClientInfo) { - const durableMcpClientInfo = + const existingMcpClientInfo = await this.ctx.storage.get('mcp-client-info') - if (!durableMcpClientInfo) { + if (!existingMcpClientInfo) { await this.ctx.storage.put('mcp-client-info', mcpClientInfo) } else { assert( - mcpClientInfo.url === durableMcpClientInfo.url, + mcpClientInfo.url === existingMcpClientInfo.url, 500, - `DurableMcpClientInfo url mismatch: "${mcpClientInfo.url}" vs "${durableMcpClientInfo.url}"` + `DurableMcpClientInfo url mismatch: "${mcpClientInfo.url}" vs "${existingMcpClientInfo.url}"` ) } @@ -39,17 +42,13 @@ export class DurableMcpClient extends DurableObject { return !!(await this.ctx.storage.get('mcp-client-info')) } - async ensureClientConnection(durableMcpClientInfo?: DurableMcpClientInfo) { + async ensureClientConnection(mcpClientInfo?: DurableMcpClientInfo) { if (this.clientConnectionP) return this.clientConnectionP - durableMcpClientInfo ??= + mcpClientInfo ??= await this.ctx.storage.get('mcp-client-info') - assert( - durableMcpClientInfo, - 500, - 'DurableMcpClient has not been initialized' - ) - const { name, version, url } = durableMcpClientInfo + assert(mcpClientInfo, 500, 'DurableMcpClient has not been initialized') + const { name, version, url } = mcpClientInfo this.client = new McpClient({ name, @@ -63,16 +62,19 @@ export class DurableMcpClient extends DurableObject { async callTool({ name, - args + args, + metadata }: { name: string args: Record + metadata: AgenticMcpRequestMetadata }): Promise { await this.ensureClientConnection() const toolCallResponse = await this.client!.callTool({ name, - arguments: args + arguments: args, + _meta: { agentic: metadata } }) // TODO: The `McpToolCallResponse` type is seemingly too complex for the CF diff --git a/apps/gateway/src/lib/durable-mcp-server.ts b/apps/gateway/src/lib/durable-mcp-server.ts index 7234a39d..71d07758 100644 --- a/apps/gateway/src/lib/durable-mcp-server.ts +++ b/apps/gateway/src/lib/durable-mcp-server.ts @@ -1,64 +1,136 @@ import type { AdminDeployment, PricingPlan } from '@agentic/platform-types' import type { JSONRPCRequest } from '@modelcontextprotocol/sdk/types.js' +// import type { JSONRPCRequest } from '@modelcontextprotocol/sdk/types.js' import { assert } from '@agentic/platform-core' -// import { parseDeploymentIdentifier } from '@agentic/platform-validators' -// import { McpServer } from '@modelcontextprotocol/sdk/server/mcp.js' +import { parseDeploymentIdentifier } from '@agentic/platform-validators' +import { McpServer } from '@modelcontextprotocol/sdk/server/mcp.js' +import { StreamableHTTPServerTransport } from '@modelcontextprotocol/sdk/server/streamableHttp.js' import { DurableObject } from 'cloudflare:workers' import type { RawEnv } from './env' import type { AdminConsumer } from './types' +export type DurableMcpServerInfo = { + deployment: AdminDeployment + consumer?: AdminConsumer + pricingPlan?: PricingPlan +} + export class DurableMcpServer extends DurableObject { - // TODO: store this in storage? - protected _initData?: { - deployment: AdminDeployment - consumer?: AdminConsumer - pricingPlan?: PricingPlan - } + protected server?: McpServer + protected serverTransport?: StreamableHTTPServerTransport + protected serverConnectionP?: Promise - async init({ - deployment, - consumer, - pricingPlan - }: { - deployment: AdminDeployment - consumer?: AdminConsumer - pricingPlan?: PricingPlan - }) { - // const parsedDeploymentIdentifier = parseDeploymentIdentifier( - // deployment.identifier - // ) - // assert( - // parsedDeploymentIdentifier, - // 500, - // `Invalid deployment identifier "${deployment.identifier}"` - // ) - // const { projectIdentifier } = parsedDeploymentIdentifier + async init(mcpServerInfo: DurableMcpServerInfo) { + const existingMcpServerInfo = + await this.ctx.storage.get('mcp-server-info') - // const server = new McpServer({ - // name: projectIdentifier, - // version: deployment.version ?? '0.0.0' - // }) - // const transport = new StreamableHTTPServerTransport({}) - // server.addTransport(transport) - - this._initData = { - deployment, - consumer, - pricingPlan + if (!existingMcpServerInfo) { + await this.ctx.storage.put('mcp-server-info', mcpServerInfo) + } else { + assert( + mcpServerInfo.deployment.id === existingMcpServerInfo.deployment.id, + 500, + `DurableMcpServerInfo deployment id mismatch: "${mcpServerInfo.deployment.id}" vs "${existingMcpServerInfo.deployment.id}"` + ) } + + return this.ensureServerConnection(mcpServerInfo) } - async isInitialized() { - return this._initData + async isInitialized(): Promise { + return !!(await this.ctx.storage.get('mcp-server-info')) } - async sayHello(name: string): Promise { - assert(this._initData, 500, 'Server not initialized') - return `Hello, ${name}!` + async ensureServerConnection(mcpServerInfo?: DurableMcpServerInfo) { + if (this.serverConnectionP) return this.serverConnectionP + + mcpServerInfo ??= + await this.ctx.storage.get('mcp-server-info') + assert(mcpServerInfo, 500, 'DurableMcpServer has not been initialized') + const { deployment } = mcpServerInfo + + const parsedDeploymentIdentifier = parseDeploymentIdentifier( + deployment.identifier + ) + assert( + parsedDeploymentIdentifier, + 500, + `Invalid deployment identifier "${deployment.identifier}"` + ) + const { projectIdentifier } = parsedDeploymentIdentifier + + this.server = new McpServer({ + name: projectIdentifier, + version: deployment.version ?? '0.0.0' + }) + + for (const tool of deployment.tools) { + this.server.registerTool( + tool.name, + { + description: tool.description, + inputSchema: tool.inputSchema as any, // TODO: investigate types + outputSchema: tool.outputSchema as any, // TODO: investigate types + annotations: tool.annotations + }, + (_args: Record) => { + assert(false, 500, `Tool call not implemented: ${tool.name}`) + + // TODO??? + return { + content: [], + _meta: { + toolName: tool.name + } + } + } + ) + } + + const transport = new StreamableHTTPServerTransport({ + sessionIdGenerator: () => { + // TODO: improve this + return crypto.randomUUID() + }, + onsessioninitialized: (sessionId) => { + // TODO: improve this + // eslint-disable-next-line no-console + console.log(`Session initialized: ${sessionId}`) + } + }) + this.serverConnectionP = this.server.connect(transport) + + return this.serverConnectionP } - async onRequest(request: JSONRPCRequest) { - const { method, params } = request + // async fetch(request: Request) { + // await this.ensureServerConnection() + // const { readable, writable } = new TransformStream() + // const writer = writable.getWriter() + // const encoder = new TextEncoder() + + // const response = new Response(readable, { + // headers: { + // 'Content-Type': 'text/event-stream', + // 'Cache-Control': 'no-cache', + // Connection: 'keep-alive' + // // 'mcp-session-id': sessionId + // } + // }) + + // await this.serverTransport!.handleRequest(request, response) + // } + + async onRequest(message: JSONRPCRequest) { + await this.ensureServerConnection() + + // We need to map every incoming message to the connection that it came in on + // so that we can send relevant responses and notifications back on the same connection + // if (isJSONRPCRequest(message)) { + // this._requestIdToConnectionId.set(message.id.toString(), connection.id); + // } + + this.serverTransport!.onmessage?.(message) } } diff --git a/apps/gateway/src/lib/resolve-origin-request.ts b/apps/gateway/src/lib/resolve-origin-request.ts index d0a3ee2b..657fb530 100644 --- a/apps/gateway/src/lib/resolve-origin-request.ts +++ b/apps/gateway/src/lib/resolve-origin-request.ts @@ -1,10 +1,14 @@ import type { PricingPlan, RateLimit } from '@agentic/platform-types' import { assert } from '@agentic/platform-core' -import { parseToolIdentifier } from '@agentic/platform-validators' +import { + parseDeploymentIdentifier, + parseToolIdentifier +} from '@agentic/platform-validators' import type { DurableMcpClient } from './durable-mcp-client' import type { AdminConsumer, + AgenticMcpRequestMetadata, GatewayHonoContext, ResolvedOriginRequest, ToolCallArgs @@ -183,8 +187,10 @@ export async function resolveOriginRequest( } const { originAdapter } = deployment - let originRequest: Request | undefined let toolCallArgs: ToolCallArgs | undefined + let originRequest: Request | undefined + let originMcpClient: DurableObjectStub | undefined + let originMcpRequestMetadata: AgenticMcpRequestMetadata | undefined if (originAdapter.type === 'raw') { const originRequestUrl = `${deployment.originUrl}/${toolName}${requestUrl.search}` @@ -198,7 +204,6 @@ export async function resolveOriginRequest( }) } - let mcpClient: DurableObjectStub | undefined if (originAdapter.type === 'openapi') { const operation = originAdapter.toolToOperationMap[tool.name] assert(operation, 404, `Tool "${tool.name}" not found in OpenAPI spec`) @@ -214,13 +219,43 @@ export async function resolveOriginRequest( assert(sessionId, 500, 'Session ID is required for MCP origin requests') const id: DurableObjectId = ctx.env.DO_MCP_CLIENT.idFromName(sessionId) - mcpClient = ctx.env.DO_MCP_CLIENT.get(id) + originMcpClient = ctx.env.DO_MCP_CLIENT.get(id) - await mcpClient.init({ + await originMcpClient.init({ url: deployment.originUrl, name: originAdapter.serverInfo.name, version: originAdapter.serverInfo.version }) + + const parsedDeploymentIdentifier = parseDeploymentIdentifier( + deployment.identifier + ) + assert( + parsedDeploymentIdentifier, + 500, + `Internal error: deployment identifier "${deployment.identifier}" is invalid` + ) + const { projectIdentifier } = parsedDeploymentIdentifier + + originMcpRequestMetadata = { + agenticProxySecret: deployment._secret, + sessionId, + ip, + isCustomerSubscriptionActive: !!consumer?.isStripeSubscriptionActive, + customerId: consumer?.id, + customerSubscriptionPlan: consumer?.plan, + customerSubscriptionStatus: consumer?.stripeStatus, + userId: consumer?.user.id, + userEmail: consumer?.user.email, + userUsername: consumer?.user.username, + userName: consumer?.user.name, + userCreatedAt: consumer?.user.createdAt, + userUpdatedAt: consumer?.user.updatedAt, + deploymentId: deployment.id, + deploymentIdentifier: deployment.identifier, + projectId: deployment.projectId, + projectIdentifier + } as AgenticMcpRequestMetadata } if (originRequest) { @@ -237,6 +272,7 @@ export async function resolveOriginRequest( pricingPlan, toolCallArgs, originRequest, - mcpClient + originMcpClient, + originMcpRequestMetadata } } diff --git a/apps/gateway/src/lib/types.ts b/apps/gateway/src/lib/types.ts index 5b390fff..1e0fa449 100644 --- a/apps/gateway/src/lib/types.ts +++ b/apps/gateway/src/lib/types.ts @@ -57,5 +57,50 @@ export type ResolvedOriginRequest = { toolCallArgs?: ToolCallArgs originRequest?: Request - mcpClient?: DurableObjectStub + originMcpClient?: DurableObjectStub + originMcpRequestMetadata?: AgenticMcpRequestMetadata } + +export type AgenticMcpRequestMetadata = { + agenticProxySecret: string + sessionId: string + isCustomerSubscriptionActive: boolean + + customerId?: string + customerSubscriptionStatus?: string + customerSubscriptionPlan?: string + + userId?: string + userEmail?: string + userUsername?: string + userName?: string + userCreatedAt?: string + userUpdatedAt?: string + + deploymentId: string + deploymentIdentifier: string + projectId: string + projectIdentifier: string + + ip?: string +} & ( + | { + // If the customer has an active subscription, these fields are guaranteed + // to be present in the metadata. + isCustomerSubscriptionActive: true + + customerId: string + customerSubscriptionStatus: string + + userId: string + userEmail: string + userUsername: string + userCreatedAt: string + userUpdatedAt: string + } + | { + // If the customer does not have an active subscription, then the customer + // fields may or may not be present. + isCustomerSubscriptionActive: false + } +) diff --git a/apps/gateway/src/lib/update-origin-request.ts b/apps/gateway/src/lib/update-origin-request.ts index 247d5c59..13284d8c 100644 --- a/apps/gateway/src/lib/update-origin-request.ts +++ b/apps/gateway/src/lib/update-origin-request.ts @@ -53,17 +53,18 @@ export function updateOriginRequest( originRequest.headers.delete('x-forwarded-for') if (consumer) { - originRequest.headers.set('x-agentic-consumer', consumer.id) - originRequest.headers.set('x-agentic-user', consumer.user.id) + originRequest.headers.set('x-agentic-customer-id', consumer.id) originRequest.headers.set( - 'x-agentic-is-subscription-active', + 'x-agentic-is-customer-subscription-active', consumer.isStripeSubscriptionActive.toString() ) originRequest.headers.set( - 'x-agentic-subscription-status', + 'x-agentic-customer-subscription-status', consumer.stripeStatus ) + + originRequest.headers.set('x-agentic-user', consumer.user.id) originRequest.headers.set('x-agentic-user-email', consumer.user.email) originRequest.headers.set('x-agentic-user-username', consumer.user.username) originRequest.headers.set( diff --git a/readme.md b/readme.md index 7b36a5a4..0e484412 100644 --- a/readme.md +++ b/readme.md @@ -31,6 +31,7 @@ - auth - custom auth pages for `openauth` - API gateway + - **do I just ditch the public REST interface and focus on MCP?** - enforce rate limits - how to handle binary bodies and responses? - add support for `immutable` in `toolConfigs` @@ -39,16 +40,9 @@ - how do I use consumer auth tokens with this flow? - how does oauth work with this flow? - **Origin MCP servers** - - CF durable object stability across requests - - REST => MCP: getDurableObject(`consumer auth token or deployment + IP`) containing MCP client connection - - MCP => MCP: getDurableObject(`mcp-session-id`) - - **do I just ditch the public REST interface and focus on MCP?** - how to guarantee that the request is coming from agentic? - - like `x-agentic-proxy-secret` or signed requests but for MCP servers - - or do this once at the connection level? - - how to pass agentic gateway context to the origin server? - - instead of headers, maybe optional `agenticContext` param? - - how does this work with mcp auth? + - `_meta` for tool calls + - _still need a way of doing this for initial connection requests_ - mcp auth provider support - SSE support? (no; post-mvp if at all; only support [streamable http](https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#streamable-http) like smithery does, or maybe support both?) - caching for MCP tool call responses