feat: mcp edge server work

pull/715/head
Travis Fischer 2025-06-09 23:49:08 +07:00
rodzic 8ae5dec653
commit b4d6af670e
6 zmienionych plików z 344 dodań i 46 usunięć

Wyświetl plik

@ -6,6 +6,7 @@ import {
responseTime, responseTime,
sentry sentry
} from '@agentic/platform-hono' } from '@agentic/platform-hono'
import { parseToolIdentifier } from '@agentic/platform-validators'
import { Hono } from 'hono' import { Hono } from 'hono'
import type { GatewayHonoEnv, McpToolCallResponse } from './lib/types' import type { GatewayHonoEnv, McpToolCallResponse } from './lib/types'
@ -14,6 +15,7 @@ import { createHttpResponseFromMcpToolCallResponse } from './lib/create-http-res
import { fetchCache } from './lib/fetch-cache' import { fetchCache } from './lib/fetch-cache'
import { getRequestCacheKey } from './lib/get-request-cache-key' import { getRequestCacheKey } from './lib/get-request-cache-key'
import { resolveOriginRequest } from './lib/resolve-origin-request' import { resolveOriginRequest } from './lib/resolve-origin-request'
import { handleMcpRequest } from './mcp'
export const app = new Hono<GatewayHonoEnv>() export const app = new Hono<GatewayHonoEnv>()
@ -47,6 +49,15 @@ app.all(async (ctx) => {
ctx.set('cache', caches.default) ctx.set('cache', caches.default)
ctx.set('client', createAgenticClient(ctx)) ctx.set('client', createAgenticClient(ctx))
const requestUrl = new URL(ctx.req.url)
const { pathname } = requestUrl
const requestedToolIdentifier = pathname.replace(/^\//, '').replace(/\/$/, '')
const { toolName } = parseToolIdentifier(requestedToolIdentifier)
if (toolName === 'mcp') {
return handleMcpRequest(ctx)
}
const resolvedOriginRequest = await resolveOriginRequest(ctx) const resolvedOriginRequest = await resolveOriginRequest(ctx)
const originStartTime = Date.now() const originStartTime = Date.now()

Wyświetl plik

@ -0,0 +1,244 @@
import type { AdminDeployment, PricingPlan } from '@agentic/platform-types'
import { assert } from '@agentic/platform-core'
import { parseDeploymentIdentifier } from '@agentic/platform-validators'
import { Server } from '@modelcontextprotocol/sdk/server/index.js'
import {
CallToolRequestSchema,
ListToolsRequestSchema
} from '@modelcontextprotocol/sdk/types.js'
import contentType from 'fast-content-type-parse'
import type { DurableMcpClient } from './durable-mcp-client'
import type {
AdminConsumer,
AgenticMcpRequestMetadata,
GatewayHonoContext,
McpToolCallResponse
} from './types'
import { cfValidateJsonSchema } from './cf-validate-json-schema'
import { createRequestForOpenAPIOperation } from './create-request-for-openapi-operation'
import { fetchCache } from './fetch-cache'
import { getRequestCacheKey } from './get-request-cache-key'
import { updateOriginRequest } from './update-origin-request'
export type ConsumerMcpServerOptions = {
sessionId: string
deployment: AdminDeployment
consumer?: AdminConsumer
pricingPlan?: PricingPlan
}
export function createConsumerMcpServer(
ctx: GatewayHonoContext,
{ sessionId, deployment, consumer, pricingPlan }: ConsumerMcpServerOptions
) {
const { originAdapter } = deployment
const { projectIdentifier } = parseDeploymentIdentifier(deployment.identifier)
const server = new Server(
{ name: projectIdentifier, version: deployment.version ?? '0.0.0' },
{
capabilities: {
// TODO: add support for more capabilities
tools: {}
}
}
)
const tools = deployment.tools
.map((tool) => {
const toolConfig = deployment.toolConfigs.find(
(toolConfig) => toolConfig.name === tool.name
)
if (toolConfig) {
const pricingPlanToolConfig = pricingPlan
? toolConfig.pricingPlanConfig?.[pricingPlan.slug]
: undefined
if (pricingPlanToolConfig?.enabled === false) {
// Tool is disabled / hidden for the customer's current pricing plan
return undefined
}
if (!pricingPlanToolConfig?.enabled && !toolConfig.enabled) {
// Tool is disabled / hidden for all pricing plans
return undefined
}
}
return tool
})
.filter(Boolean)
server.setRequestHandler(ListToolsRequestSchema, async () => ({ tools }))
server.setRequestHandler(CallToolRequestSchema, async (request) => {
const { name, arguments: args, _meta } = request.params
const tool = tools.find((tool) => tool.name === name)
assert(tool, 404, `Unknown tool: ${name}`)
// TODO: Implement tool config logic
// const toolConfig = deployment.toolConfigs.find(
// (toolConfig) => toolConfig.name === tool.name
// )
if (originAdapter.type === 'raw') {
// TODO
assert(false, 500, 'Raw origin adapter not implemented')
} else {
// Validate incoming request params against the tool's input schema.
const toolCallArgs = cfValidateJsonSchema<Record<string, any>>({
schema: tool.inputSchema,
data: args,
errorMessage: `Invalid request parameters for tool "${tool.name}"`,
strictAdditionalProperties: true
})
if (originAdapter.type === 'openapi') {
const operation = originAdapter.toolToOperationMap[tool.name]
assert(operation, 404, `Tool "${tool.name}" not found in OpenAPI spec`)
assert(toolCallArgs, 500)
const originRequest = await createRequestForOpenAPIOperation({
toolCallArgs,
operation,
deployment
})
updateOriginRequest(originRequest, { consumer, deployment })
const cacheKey = await getRequestCacheKey(ctx, originRequest)
// TODO: transform origin 5XX errors to 502 errors...
// TODO: fetch origin request and transform response
const originResponse = await fetchCache(ctx, {
cacheKey,
fetchResponse: () => fetch(originRequest)
})
const { type: mimeType } = contentType.safeParse(
originResponse.headers.get('content-type') ||
'application/octet-stream'
)
if (tool.outputSchema) {
assert(
mimeType.includes('json'),
502,
`Tool "${tool.name}" requires a JSON response, but the origin returned content type "${mimeType}"`
)
const res: any = await originResponse.json()
const toolCallResponseContent = cfValidateJsonSchema({
schema: tool.outputSchema,
data: res as Record<string, unknown>,
coerce: false,
// TODO: double-check MCP schema on whether additional properties are allowed
strictAdditionalProperties: true,
errorMessage: `Invalid tool response for tool "${tool.name}"`,
errorStatusCode: 502
})
return {
structuredContent: toolCallResponseContent,
isError: originResponse.status >= 400
}
} else {
const result: McpToolCallResponse = {
isError: originResponse.status >= 400
}
if (mimeType.includes('json')) {
result.structuredContent = await originResponse.json()
} else if (mimeType.includes('text')) {
result.content = [
{
type: 'text',
text: await originResponse.text()
}
]
} else {
const resBody = await originResponse.arrayBuffer()
const resBodyBase64 = Buffer.from(resBody).toString('base64')
const type = mimeType.includes('image')
? 'image'
: mimeType.includes('audio')
? 'audio'
: 'resource'
// TODO: this needs work
result.content = [
{
type,
mimeType,
...(type === 'resource'
? {
blob: resBodyBase64
}
: {
data: resBodyBase64
})
}
]
}
return result
}
} else if (originAdapter.type === 'mcp') {
const id: DurableObjectId = ctx.env.DO_MCP_CLIENT.idFromName(sessionId)
const originMcpClient: DurableObjectStub<DurableMcpClient> =
ctx.env.DO_MCP_CLIENT.get(id)
await originMcpClient.init({
url: deployment.originUrl,
name: originAdapter.serverInfo.name,
version: originAdapter.serverInfo.version
})
const { projectIdentifier } = parseDeploymentIdentifier(
deployment.identifier,
{ errorStatusCode: 500 }
)
const originMcpRequestMetadata = {
agenticProxySecret: deployment._secret,
sessionId,
// ip,
isCustomerSubscriptionActive: !!consumer?.isStripeSubscriptionActive,
customerId: consumer?.id,
customerSubscriptionPlan: consumer?.plan,
customerSubscriptionStatus: consumer?.stripeStatus,
userId: consumer?.user.id,
userEmail: consumer?.user.email,
userUsername: consumer?.user.username,
userName: consumer?.user.name,
userCreatedAt: consumer?.user.createdAt,
userUpdatedAt: consumer?.user.updatedAt,
deploymentId: deployment.id,
deploymentIdentifier: deployment.identifier,
projectId: deployment.projectId,
projectIdentifier
} as AgenticMcpRequestMetadata
// TODO: add timeout support to the origin tool call?
// TODO: add response caching for MCP tool calls
const toolCallResponseString = await originMcpClient.callTool({
name: tool.name,
args: toolCallArgs,
metadata: originMcpRequestMetadata!
})
const toolCallResponse = JSON.parse(
toolCallResponseString
) as McpToolCallResponse
return toolCallResponse
} else {
assert(false, 500)
}
}
})
return server
}

Wyświetl plik

@ -4,21 +4,19 @@ import type {
} from '@agentic/platform-types' } from '@agentic/platform-types'
import { assert } from '@agentic/platform-core' import { assert } from '@agentic/platform-core'
import type { GatewayHonoContext, ToolCallArgs } from './types' import type { ToolCallArgs } from './types'
export async function createRequestForOpenAPIOperation( export async function createRequestForOpenAPIOperation({
ctx: GatewayHonoContext, toolCallArgs,
{ operation,
toolCallArgs, deployment,
operation, request
deployment }: {
}: { toolCallArgs: ToolCallArgs
toolCallArgs: ToolCallArgs operation: OpenAPIToolOperation
operation: OpenAPIToolOperation deployment: AdminDeployment
deployment: AdminDeployment request?: Request
} }): Promise<Request> {
): Promise<Request> {
const request = ctx.req.raw
assert(toolCallArgs, 500, 'Tool args are required') assert(toolCallArgs, 500, 'Tool args are required')
assert( assert(
deployment.originAdapter.type === 'openapi', deployment.originAdapter.type === 'openapi',
@ -43,13 +41,16 @@ export async function createRequestForOpenAPIOperation(
) )
const headers: Record<string, string> = {} const headers: Record<string, string> = {}
for (const [key, value] of request.headers.entries()) { if (request) {
headers[key] = value // TODO: do we want to expose these? especially authorization?
for (const [key, value] of request.headers.entries()) {
headers[key] = value
}
} }
if (headerParams.length > 0) { if (headerParams.length > 0) {
for (const [key] of headerParams) { for (const [key] of headerParams) {
headers[key] = (request.headers.get(key) as string) ?? toolCallArgs[key] headers[key] = (request?.headers.get(key) as string) ?? toolCallArgs[key]
} }
} }

Wyświetl plik

@ -149,9 +149,8 @@ export async function resolveOriginRequest(
if (pricingPlan && pricingPlanToolConfig) { if (pricingPlan && pricingPlanToolConfig) {
assert( assert(
pricingPlanToolConfig.enabled && pricingPlanToolConfig.enabled ||
pricingPlanToolConfig.enabled === undefined && (pricingPlanToolConfig.enabled === undefined && toolConfig.enabled),
toolConfig.enabled,
403, 403,
`Tool "${tool.name}" is not enabled for pricing plan "${pricingPlan.slug}"` `Tool "${tool.name}" is not enabled for pricing plan "${pricingPlan.slug}"`
) )
@ -202,7 +201,8 @@ export async function resolveOriginRequest(
assert(operation, 404, `Tool "${tool.name}" not found in OpenAPI spec`) assert(operation, 404, `Tool "${tool.name}" not found in OpenAPI spec`)
assert(toolCallArgs, 500) assert(toolCallArgs, 500)
originRequest = await createRequestForOpenAPIOperation(ctx, { originRequest = await createRequestForOpenAPIOperation({
request: ctx.req.raw,
toolCallArgs, toolCallArgs,
operation, operation,
deployment deployment

Wyświetl plik

@ -1,11 +1,10 @@
// import { McpServer } from '@modelcontextprotocol/sdk/server/mcp.js' // import { McpServer } from '@modelcontextprotocol/sdk/server/mcp.js'
// import { StreamableHTTPClientTransport } from '@modelcontextprotocol/sdk/client/streamableHttp.js' // import { StreamableHTTPClientTransport } from '@modelcontextprotocol/sdk/client/streamableHttp.js'
import { assert, JsonRpcError } from '@agentic/platform-core' import { assert, JsonRpcError } from '@agentic/platform-core'
import { parseDeploymentIdentifier } from '@agentic/platform-validators' import { StreamableHTTPServerTransport } from '@modelcontextprotocol/sdk/server/streamableHttp.js'
import { McpServer } from '@modelcontextprotocol/sdk/server/mcp.js'
import { import {
InitializeRequestSchema, InitializeRequestSchema,
// isJSONRPCError, isJSONRPCError,
isJSONRPCNotification, isJSONRPCNotification,
isJSONRPCRequest, isJSONRPCRequest,
isJSONRPCResponse, isJSONRPCResponse,
@ -14,6 +13,7 @@ import {
} from '@modelcontextprotocol/sdk/types.js' } from '@modelcontextprotocol/sdk/types.js'
import type { GatewayHonoContext } from './lib/types' import type { GatewayHonoContext } from './lib/types'
import { createConsumerMcpServer } from './lib/consumer-mcp-server'
import { resolveMcpEdgeRequest } from './lib/resolve-mcp-edge-request' import { resolveMcpEdgeRequest } from './lib/resolve-mcp-edge-request'
// import { DurableMcpServer } from './lib/durable-mcp-server' // import { DurableMcpServer } from './lib/durable-mcp-server'
@ -177,24 +177,36 @@ export async function handleMcpRequest(ctx: GatewayHonoContext) {
// const durableMcpServer = ctx.env.DO_MCP_SERVER.get(id) // const durableMcpServer = ctx.env.DO_MCP_SERVER.get(id)
// const isInitialized = await durableMcpServer.isInitialized() // const isInitialized = await durableMcpServer.isInitialized()
if (!isInitializationRequest && !isInitialized) { // if (!isInitializationRequest && !isInitialized) {
// A session id that was never initialized was provided // // A session id that was never initialized was provided
throw new JsonRpcError({ // throw new JsonRpcError({
message: 'Session not found', // message: 'Session not found',
statusCode: 404, // statusCode: 404,
jsonRpcErrorCode: -32_001, // jsonRpcErrorCode: -32_001,
jsonRpcId: null // jsonRpcId: null
}) // })
} // }
const { deployment, consumer, pricingPlan } = await resolveMcpEdgeRequest(ctx) const { deployment, consumer, pricingPlan } = await resolveMcpEdgeRequest(ctx)
const { projectIdentifier } = parseDeploymentIdentifier(deployment.identifier) const server = createConsumerMcpServer(ctx, {
sessionId,
const server = new McpServer({ deployment,
name: projectIdentifier, consumer,
version: deployment.version ?? '0.0.0' pricingPlan
}) })
const transport = new StreamableHTTPServerTransport({
sessionIdGenerator: () => {
return ctx.env.DO_MCP_SERVER.newUniqueId().toString()
},
onsessioninitialized: (sessionId) => {
// TODO: improve this
// eslint-disable-next-line no-console
console.log(`Session initialized: ${sessionId}`)
}
})
await server.connect(transport)
// if (isInitializationRequest) { // if (isInitializationRequest) {
// await durableMcpServer.init({ // await durableMcpServer.init({
// deployment, // deployment,
@ -212,6 +224,36 @@ export async function handleMcpRequest(ctx: GatewayHonoContext) {
const writer = writable.getWriter() const writer = writable.getWriter()
const encoder = new TextEncoder() const encoder = new TextEncoder()
// Keep track of the request ids that we have sent to the server
// so that we can close the connection once we have received
// all the responses
const requestIds = new Set<string | number>()
// eslint-disable-next-line unicorn/prefer-add-event-listener
transport.onmessage = async (message) => {
// validate that the message is a valid JSONRPC message
const result = JSONRPCMessageSchema.safeParse(message)
if (!result.success) {
// TODO: add a warning here
return
}
// If the message is a response or an error, remove the id from the set of
// request ids
if (isJSONRPCResponse(result.data) || isJSONRPCError(result.data)) {
requestIds.delete(result.data.id)
}
// Send the message as an SSE event
const messageText = `event: message\ndata: ${JSON.stringify(result.data)}\n\n`
await writer.write(encoder.encode(messageText))
// If we have received all the responses, close the connection
if (!requestIds.size) {
ctx.executionCtx.waitUntil(transport.close())
}
}
// If there are no requests, we send the messages downstream and // If there are no requests, we send the messages downstream and
// acknowledge the request with a 202 since we don't expect any responses // acknowledge the request with a 202 since we don't expect any responses
// back through this connection. // back through this connection.
@ -219,10 +261,7 @@ export async function handleMcpRequest(ctx: GatewayHonoContext) {
(msg) => isJSONRPCNotification(msg) || isJSONRPCResponse(msg) (msg) => isJSONRPCNotification(msg) || isJSONRPCResponse(msg)
) )
if (hasOnlyNotificationsOrResponses) { if (hasOnlyNotificationsOrResponses) {
// TODO await Promise.all(messages.map((message) => transport.send(message)))
// for (const message of messages) {
// ws.send(JSON.stringify(message))
// }
return new Response(null, { return new Response(null, {
status: 202 status: 202
@ -233,9 +272,10 @@ export async function handleMcpRequest(ctx: GatewayHonoContext) {
if (isJSONRPCRequest(message)) { if (isJSONRPCRequest(message)) {
// Add each request id that we send off to a set so that we can keep // Add each request id that we send off to a set so that we can keep
// track of which requests we still need a response for. // track of which requests we still need a response for.
// requestIds.add(message.id) requestIds.add(message.id)
} }
// ws.send(JSON.stringify(message))
await transport.send(message)
} }
// Return the streamable http response. // Return the streamable http response.

Wyświetl plik

@ -48,9 +48,11 @@ export const pricingPlanToolConfigSchema = z
/** /**
* Whether this tool should be enabled for customers on a given pricing plan. * Whether this tool should be enabled for customers on a given pricing plan.
* *
* @default true * If `undefined`, will use the tool's default enabled state.
*
* @default undefined
*/ */
enabled: z.boolean().optional().default(true), enabled: z.boolean().optional(),
/** /**
* Overrides whether to report default `requests` usage for metered billing * Overrides whether to report default `requests` usage for metered billing