diff --git a/apps/gateway/src/app.ts b/apps/gateway/src/app.ts index 036de28a..768911a3 100644 --- a/apps/gateway/src/app.ts +++ b/apps/gateway/src/app.ts @@ -6,6 +6,7 @@ import { responseTime, sentry } from '@agentic/platform-hono' +import { parseToolIdentifier } from '@agentic/platform-validators' import { Hono } from 'hono' import type { GatewayHonoEnv, McpToolCallResponse } from './lib/types' @@ -14,6 +15,7 @@ import { createHttpResponseFromMcpToolCallResponse } from './lib/create-http-res import { fetchCache } from './lib/fetch-cache' import { getRequestCacheKey } from './lib/get-request-cache-key' import { resolveOriginRequest } from './lib/resolve-origin-request' +import { handleMcpRequest } from './mcp' export const app = new Hono() @@ -47,6 +49,15 @@ app.all(async (ctx) => { ctx.set('cache', caches.default) ctx.set('client', createAgenticClient(ctx)) + const requestUrl = new URL(ctx.req.url) + const { pathname } = requestUrl + const requestedToolIdentifier = pathname.replace(/^\//, '').replace(/\/$/, '') + const { toolName } = parseToolIdentifier(requestedToolIdentifier) + + if (toolName === 'mcp') { + return handleMcpRequest(ctx) + } + const resolvedOriginRequest = await resolveOriginRequest(ctx) const originStartTime = Date.now() diff --git a/apps/gateway/src/lib/consumer-mcp-server.ts b/apps/gateway/src/lib/consumer-mcp-server.ts new file mode 100644 index 00000000..1e3cf19b --- /dev/null +++ b/apps/gateway/src/lib/consumer-mcp-server.ts @@ -0,0 +1,244 @@ +import type { AdminDeployment, PricingPlan } from '@agentic/platform-types' +import { assert } from '@agentic/platform-core' +import { parseDeploymentIdentifier } from '@agentic/platform-validators' +import { Server } from '@modelcontextprotocol/sdk/server/index.js' +import { + CallToolRequestSchema, + ListToolsRequestSchema +} from '@modelcontextprotocol/sdk/types.js' +import contentType from 'fast-content-type-parse' + +import type { DurableMcpClient } from './durable-mcp-client' +import type { + AdminConsumer, + AgenticMcpRequestMetadata, + GatewayHonoContext, + McpToolCallResponse +} from './types' +import { cfValidateJsonSchema } from './cf-validate-json-schema' +import { createRequestForOpenAPIOperation } from './create-request-for-openapi-operation' +import { fetchCache } from './fetch-cache' +import { getRequestCacheKey } from './get-request-cache-key' +import { updateOriginRequest } from './update-origin-request' + +export type ConsumerMcpServerOptions = { + sessionId: string + deployment: AdminDeployment + consumer?: AdminConsumer + pricingPlan?: PricingPlan +} + +export function createConsumerMcpServer( + ctx: GatewayHonoContext, + { sessionId, deployment, consumer, pricingPlan }: ConsumerMcpServerOptions +) { + const { originAdapter } = deployment + const { projectIdentifier } = parseDeploymentIdentifier(deployment.identifier) + + const server = new Server( + { name: projectIdentifier, version: deployment.version ?? '0.0.0' }, + { + capabilities: { + // TODO: add support for more capabilities + tools: {} + } + } + ) + + const tools = deployment.tools + .map((tool) => { + const toolConfig = deployment.toolConfigs.find( + (toolConfig) => toolConfig.name === tool.name + ) + + if (toolConfig) { + const pricingPlanToolConfig = pricingPlan + ? toolConfig.pricingPlanConfig?.[pricingPlan.slug] + : undefined + + if (pricingPlanToolConfig?.enabled === false) { + // Tool is disabled / hidden for the customer's current pricing plan + return undefined + } + + if (!pricingPlanToolConfig?.enabled && !toolConfig.enabled) { + // Tool is disabled / hidden for all pricing plans + return undefined + } + } + + return tool + }) + .filter(Boolean) + + server.setRequestHandler(ListToolsRequestSchema, async () => ({ tools })) + + server.setRequestHandler(CallToolRequestSchema, async (request) => { + const { name, arguments: args, _meta } = request.params + + const tool = tools.find((tool) => tool.name === name) + assert(tool, 404, `Unknown tool: ${name}`) + + // TODO: Implement tool config logic + // const toolConfig = deployment.toolConfigs.find( + // (toolConfig) => toolConfig.name === tool.name + // ) + + if (originAdapter.type === 'raw') { + // TODO + assert(false, 500, 'Raw origin adapter not implemented') + } else { + // Validate incoming request params against the tool's input schema. + const toolCallArgs = cfValidateJsonSchema>({ + schema: tool.inputSchema, + data: args, + errorMessage: `Invalid request parameters for tool "${tool.name}"`, + strictAdditionalProperties: true + }) + + if (originAdapter.type === 'openapi') { + const operation = originAdapter.toolToOperationMap[tool.name] + assert(operation, 404, `Tool "${tool.name}" not found in OpenAPI spec`) + assert(toolCallArgs, 500) + + const originRequest = await createRequestForOpenAPIOperation({ + toolCallArgs, + operation, + deployment + }) + + updateOriginRequest(originRequest, { consumer, deployment }) + + const cacheKey = await getRequestCacheKey(ctx, originRequest) + + // TODO: transform origin 5XX errors to 502 errors... + // TODO: fetch origin request and transform response + const originResponse = await fetchCache(ctx, { + cacheKey, + fetchResponse: () => fetch(originRequest) + }) + + const { type: mimeType } = contentType.safeParse( + originResponse.headers.get('content-type') || + 'application/octet-stream' + ) + + if (tool.outputSchema) { + assert( + mimeType.includes('json'), + 502, + `Tool "${tool.name}" requires a JSON response, but the origin returned content type "${mimeType}"` + ) + const res: any = await originResponse.json() + + const toolCallResponseContent = cfValidateJsonSchema({ + schema: tool.outputSchema, + data: res as Record, + coerce: false, + // TODO: double-check MCP schema on whether additional properties are allowed + strictAdditionalProperties: true, + errorMessage: `Invalid tool response for tool "${tool.name}"`, + errorStatusCode: 502 + }) + + return { + structuredContent: toolCallResponseContent, + isError: originResponse.status >= 400 + } + } else { + const result: McpToolCallResponse = { + isError: originResponse.status >= 400 + } + + if (mimeType.includes('json')) { + result.structuredContent = await originResponse.json() + } else if (mimeType.includes('text')) { + result.content = [ + { + type: 'text', + text: await originResponse.text() + } + ] + } else { + const resBody = await originResponse.arrayBuffer() + const resBodyBase64 = Buffer.from(resBody).toString('base64') + const type = mimeType.includes('image') + ? 'image' + : mimeType.includes('audio') + ? 'audio' + : 'resource' + + // TODO: this needs work + result.content = [ + { + type, + mimeType, + ...(type === 'resource' + ? { + blob: resBodyBase64 + } + : { + data: resBodyBase64 + }) + } + ] + } + + return result + } + } else if (originAdapter.type === 'mcp') { + const id: DurableObjectId = ctx.env.DO_MCP_CLIENT.idFromName(sessionId) + const originMcpClient: DurableObjectStub = + ctx.env.DO_MCP_CLIENT.get(id) + + await originMcpClient.init({ + url: deployment.originUrl, + name: originAdapter.serverInfo.name, + version: originAdapter.serverInfo.version + }) + + const { projectIdentifier } = parseDeploymentIdentifier( + deployment.identifier, + { errorStatusCode: 500 } + ) + + const originMcpRequestMetadata = { + agenticProxySecret: deployment._secret, + sessionId, + // ip, + isCustomerSubscriptionActive: !!consumer?.isStripeSubscriptionActive, + customerId: consumer?.id, + customerSubscriptionPlan: consumer?.plan, + customerSubscriptionStatus: consumer?.stripeStatus, + userId: consumer?.user.id, + userEmail: consumer?.user.email, + userUsername: consumer?.user.username, + userName: consumer?.user.name, + userCreatedAt: consumer?.user.createdAt, + userUpdatedAt: consumer?.user.updatedAt, + deploymentId: deployment.id, + deploymentIdentifier: deployment.identifier, + projectId: deployment.projectId, + projectIdentifier + } as AgenticMcpRequestMetadata + + // TODO: add timeout support to the origin tool call? + // TODO: add response caching for MCP tool calls + const toolCallResponseString = await originMcpClient.callTool({ + name: tool.name, + args: toolCallArgs, + metadata: originMcpRequestMetadata! + }) + const toolCallResponse = JSON.parse( + toolCallResponseString + ) as McpToolCallResponse + + return toolCallResponse + } else { + assert(false, 500) + } + } + }) + + return server +} diff --git a/apps/gateway/src/lib/create-request-for-openapi-operation.ts b/apps/gateway/src/lib/create-request-for-openapi-operation.ts index 6312f61f..dab12775 100644 --- a/apps/gateway/src/lib/create-request-for-openapi-operation.ts +++ b/apps/gateway/src/lib/create-request-for-openapi-operation.ts @@ -4,21 +4,19 @@ import type { } from '@agentic/platform-types' import { assert } from '@agentic/platform-core' -import type { GatewayHonoContext, ToolCallArgs } from './types' +import type { ToolCallArgs } from './types' -export async function createRequestForOpenAPIOperation( - ctx: GatewayHonoContext, - { - toolCallArgs, - operation, - deployment - }: { - toolCallArgs: ToolCallArgs - operation: OpenAPIToolOperation - deployment: AdminDeployment - } -): Promise { - const request = ctx.req.raw +export async function createRequestForOpenAPIOperation({ + toolCallArgs, + operation, + deployment, + request +}: { + toolCallArgs: ToolCallArgs + operation: OpenAPIToolOperation + deployment: AdminDeployment + request?: Request +}): Promise { assert(toolCallArgs, 500, 'Tool args are required') assert( deployment.originAdapter.type === 'openapi', @@ -43,13 +41,16 @@ export async function createRequestForOpenAPIOperation( ) const headers: Record = {} - for (const [key, value] of request.headers.entries()) { - headers[key] = value + if (request) { + // TODO: do we want to expose these? especially authorization? + for (const [key, value] of request.headers.entries()) { + headers[key] = value + } } if (headerParams.length > 0) { for (const [key] of headerParams) { - headers[key] = (request.headers.get(key) as string) ?? toolCallArgs[key] + headers[key] = (request?.headers.get(key) as string) ?? toolCallArgs[key] } } diff --git a/apps/gateway/src/lib/resolve-origin-request.ts b/apps/gateway/src/lib/resolve-origin-request.ts index 9d1daba3..2bf57fcc 100644 --- a/apps/gateway/src/lib/resolve-origin-request.ts +++ b/apps/gateway/src/lib/resolve-origin-request.ts @@ -149,9 +149,8 @@ export async function resolveOriginRequest( if (pricingPlan && pricingPlanToolConfig) { assert( - pricingPlanToolConfig.enabled && - pricingPlanToolConfig.enabled === undefined && - toolConfig.enabled, + pricingPlanToolConfig.enabled || + (pricingPlanToolConfig.enabled === undefined && toolConfig.enabled), 403, `Tool "${tool.name}" is not enabled for pricing plan "${pricingPlan.slug}"` ) @@ -202,7 +201,8 @@ export async function resolveOriginRequest( assert(operation, 404, `Tool "${tool.name}" not found in OpenAPI spec`) assert(toolCallArgs, 500) - originRequest = await createRequestForOpenAPIOperation(ctx, { + originRequest = await createRequestForOpenAPIOperation({ + request: ctx.req.raw, toolCallArgs, operation, deployment diff --git a/apps/gateway/src/mcp.ts b/apps/gateway/src/mcp.ts index e938e643..ee77362a 100644 --- a/apps/gateway/src/mcp.ts +++ b/apps/gateway/src/mcp.ts @@ -1,11 +1,10 @@ // import { McpServer } from '@modelcontextprotocol/sdk/server/mcp.js' // import { StreamableHTTPClientTransport } from '@modelcontextprotocol/sdk/client/streamableHttp.js' import { assert, JsonRpcError } from '@agentic/platform-core' -import { parseDeploymentIdentifier } from '@agentic/platform-validators' -import { McpServer } from '@modelcontextprotocol/sdk/server/mcp.js' +import { StreamableHTTPServerTransport } from '@modelcontextprotocol/sdk/server/streamableHttp.js' import { InitializeRequestSchema, - // isJSONRPCError, + isJSONRPCError, isJSONRPCNotification, isJSONRPCRequest, isJSONRPCResponse, @@ -14,6 +13,7 @@ import { } from '@modelcontextprotocol/sdk/types.js' import type { GatewayHonoContext } from './lib/types' +import { createConsumerMcpServer } from './lib/consumer-mcp-server' import { resolveMcpEdgeRequest } from './lib/resolve-mcp-edge-request' // import { DurableMcpServer } from './lib/durable-mcp-server' @@ -177,24 +177,36 @@ export async function handleMcpRequest(ctx: GatewayHonoContext) { // const durableMcpServer = ctx.env.DO_MCP_SERVER.get(id) // const isInitialized = await durableMcpServer.isInitialized() - if (!isInitializationRequest && !isInitialized) { - // A session id that was never initialized was provided - throw new JsonRpcError({ - message: 'Session not found', - statusCode: 404, - jsonRpcErrorCode: -32_001, - jsonRpcId: null - }) - } + // if (!isInitializationRequest && !isInitialized) { + // // A session id that was never initialized was provided + // throw new JsonRpcError({ + // message: 'Session not found', + // statusCode: 404, + // jsonRpcErrorCode: -32_001, + // jsonRpcId: null + // }) + // } const { deployment, consumer, pricingPlan } = await resolveMcpEdgeRequest(ctx) - const { projectIdentifier } = parseDeploymentIdentifier(deployment.identifier) - - const server = new McpServer({ - name: projectIdentifier, - version: deployment.version ?? '0.0.0' + const server = createConsumerMcpServer(ctx, { + sessionId, + deployment, + consumer, + pricingPlan }) + const transport = new StreamableHTTPServerTransport({ + sessionIdGenerator: () => { + return ctx.env.DO_MCP_SERVER.newUniqueId().toString() + }, + onsessioninitialized: (sessionId) => { + // TODO: improve this + // eslint-disable-next-line no-console + console.log(`Session initialized: ${sessionId}`) + } + }) + await server.connect(transport) + // if (isInitializationRequest) { // await durableMcpServer.init({ // deployment, @@ -212,6 +224,36 @@ export async function handleMcpRequest(ctx: GatewayHonoContext) { const writer = writable.getWriter() const encoder = new TextEncoder() + // Keep track of the request ids that we have sent to the server + // so that we can close the connection once we have received + // all the responses + const requestIds = new Set() + + // eslint-disable-next-line unicorn/prefer-add-event-listener + transport.onmessage = async (message) => { + // validate that the message is a valid JSONRPC message + const result = JSONRPCMessageSchema.safeParse(message) + if (!result.success) { + // TODO: add a warning here + return + } + + // If the message is a response or an error, remove the id from the set of + // request ids + if (isJSONRPCResponse(result.data) || isJSONRPCError(result.data)) { + requestIds.delete(result.data.id) + } + + // Send the message as an SSE event + const messageText = `event: message\ndata: ${JSON.stringify(result.data)}\n\n` + await writer.write(encoder.encode(messageText)) + + // If we have received all the responses, close the connection + if (!requestIds.size) { + ctx.executionCtx.waitUntil(transport.close()) + } + } + // If there are no requests, we send the messages downstream and // acknowledge the request with a 202 since we don't expect any responses // back through this connection. @@ -219,10 +261,7 @@ export async function handleMcpRequest(ctx: GatewayHonoContext) { (msg) => isJSONRPCNotification(msg) || isJSONRPCResponse(msg) ) if (hasOnlyNotificationsOrResponses) { - // TODO - // for (const message of messages) { - // ws.send(JSON.stringify(message)) - // } + await Promise.all(messages.map((message) => transport.send(message))) return new Response(null, { status: 202 @@ -233,9 +272,10 @@ export async function handleMcpRequest(ctx: GatewayHonoContext) { if (isJSONRPCRequest(message)) { // Add each request id that we send off to a set so that we can keep // track of which requests we still need a response for. - // requestIds.add(message.id) + requestIds.add(message.id) } - // ws.send(JSON.stringify(message)) + + await transport.send(message) } // Return the streamable http response. diff --git a/packages/types/src/tools.ts b/packages/types/src/tools.ts index b9a5472c..d4bbef60 100644 --- a/packages/types/src/tools.ts +++ b/packages/types/src/tools.ts @@ -48,9 +48,11 @@ export const pricingPlanToolConfigSchema = z /** * Whether this tool should be enabled for customers on a given pricing plan. * - * @default true + * If `undefined`, will use the tool's default enabled state. + * + * @default undefined */ - enabled: z.boolean().optional().default(true), + enabled: z.boolean().optional(), /** * Overrides whether to report default `requests` usage for metered billing