diff --git a/apps/gateway/src/app.ts b/apps/gateway/src/app.ts index f33ce716..743270b1 100644 --- a/apps/gateway/src/app.ts +++ b/apps/gateway/src/app.ts @@ -9,13 +9,12 @@ import { import { parseToolIdentifier } from '@agentic/platform-validators' import { Hono } from 'hono' -import type { GatewayHonoEnv, McpToolCallResponse } from './lib/types' +import type { GatewayHonoEnv } from './lib/types' import { createAgenticClient } from './lib/agentic-client' import { createHttpResponseFromMcpToolCallResponse } from './lib/create-http-response-from-mcp-tool-call-response' -import { fetchCache } from './lib/fetch-cache' -import { getRequestCacheKey } from './lib/get-request-cache-key' +import { resolveHttpEdgeRequest } from './lib/resolve-http-edge-request' import { resolveMcpEdgeRequest } from './lib/resolve-mcp-edge-request' -import { resolveOriginRequest } from './lib/resolve-origin-request' +import { resolveOriginToolCall } from './lib/resolve-origin-tool-call' import { DurableMcpServer } from './worker' export const app = new Hono() @@ -67,69 +66,37 @@ app.all(async (ctx) => { }).fetch(ctx.req.raw, ctx.env, executionCtx) } - const resolvedOriginRequest = await resolveOriginRequest(ctx) + const resolvedEdgeRequest = await resolveHttpEdgeRequest(ctx) const originStartTime = Date.now() + + const resolvedOriginToolCallResult = await resolveOriginToolCall({ + tool: resolvedEdgeRequest.tool, + args: resolvedEdgeRequest.toolCallArgs, + deployment: resolvedEdgeRequest.deployment, + consumer: resolvedEdgeRequest.consumer, + pricingPlan: resolvedEdgeRequest.pricingPlan, + sessionId: ctx.get('sessionId')!, + ip: ctx.get('ip'), + env: ctx.env, + waitUntil: ctx.executionCtx.waitUntil + }) + let originResponse: Response | undefined - - switch (resolvedOriginRequest.deployment.originAdapter.type) { - case 'openapi': - case 'raw': { - assert( - resolvedOriginRequest.originRequest, - 500, - 'Origin request is required' - ) - - const cacheKey = await getRequestCacheKey( - ctx, - 
resolvedOriginRequest.originRequest - ) - - // TODO: transform origin 5XX errors to 502 errors... - originResponse = await fetchCache(ctx, { - cacheKey, - fetchResponse: () => fetch(resolvedOriginRequest.originRequest!) - }) - break - } - - case 'mcp': { - assert( - resolvedOriginRequest.toolCallArgs, - 500, - 'Tool args are required for MCP origin requests' - ) - assert( - resolvedOriginRequest.originMcpClient, - 500, - 'MCP client is required for MCP origin requests' - ) - - // TODO: add timeout support to the origin tool call? - // TODO: add response caching for MCP tool calls - const toolCallResponseString = - await resolvedOriginRequest.originMcpClient.callTool({ - name: resolvedOriginRequest.tool.name, - args: resolvedOriginRequest.toolCallArgs, - metadata: resolvedOriginRequest.originMcpRequestMetadata! - }) - const toolCallResponse = JSON.parse( - toolCallResponseString - ) as McpToolCallResponse - - originResponse = await createHttpResponseFromMcpToolCallResponse(ctx, { - tool: resolvedOriginRequest.tool, - deployment: resolvedOriginRequest.deployment, - toolCallResponse - }) - } + if (resolvedOriginToolCallResult.originResponse) { + originResponse = resolvedOriginToolCallResult.originResponse + } else { + originResponse = await createHttpResponseFromMcpToolCallResponse(ctx, { + tool: resolvedEdgeRequest.tool, + deployment: resolvedEdgeRequest.deployment, + toolCallResponse: resolvedOriginToolCallResult.toolCallResponse + }) } assert(originResponse, 500, 'Origin response is required') const res = new Response(originResponse.body, originResponse) - // Record the time it took for both the origin and gateway to respond + // Record the time it took for the origin to respond. 
const now = Date.now() const originTimespan = now - originStartTime res.headers.set('x-origin-response-time', `${originTimespan}ms`) diff --git a/apps/gateway/src/lib/cf-validate-json-schema.ts b/apps/gateway/src/lib/cf-validate-json-schema.ts index 62f99d64..a8030957 100644 --- a/apps/gateway/src/lib/cf-validate-json-schema.ts +++ b/apps/gateway/src/lib/cf-validate-json-schema.ts @@ -18,14 +18,14 @@ export function cfValidateJsonSchema({ data, coerce = false, strictAdditionalProperties = false, - errorMessage, + errorPrefix, errorStatusCode = 400 }: { schema: any data: unknown coerce?: boolean strictAdditionalProperties?: boolean - errorMessage?: string + errorPrefix?: string errorStatusCode?: number }): T { assert(schema, 400, '`schema` is required') @@ -37,7 +37,7 @@ export function cfValidateJsonSchema({ if (isSchemaObject && !isDataObject) { throw new HttpError({ statusCode: 400, - message: `${errorMessage ? errorMessage + ': ' : ''}Data must be an object according to its schema.` + message: `${errorPrefix ? errorPrefix + ': ' : ''}Data must be an object according to its schema.` }) } @@ -50,7 +50,7 @@ export function cfValidateJsonSchema({ if (missingRequiredFields.length > 0) { throw new HttpError({ statusCode: errorStatusCode, - message: `${errorMessage ? errorMessage + ': ' : ''}Missing required ${plur('parameter', missingRequiredFields.length)}: ${missingRequiredFields.map((field) => `"${field}"`).join(', ')}` + message: `${errorPrefix ? errorPrefix + ': ' : ''}Missing required ${plur('parameter', missingRequiredFields.length)}: ${missingRequiredFields.map((field) => `"${field}"`).join(', ')}` }) } } @@ -70,7 +70,7 @@ export function cfValidateJsonSchema({ if (extraProperties.length > 0) { throw new HttpError({ statusCode: errorStatusCode, - message: `${errorMessage ? errorMessage + ': ' : ''}Unexpected additional ${plur('parameter', extraProperties.length)}: ${extraProperties.map((property) => `"${property}"`).join(', ')}` + message: `${errorPrefix ? 
errorPrefix + ': ' : ''}Unexpected additional ${plur('parameter', extraProperties.length)}: ${extraProperties.map((property) => `"${property}"`).join(', ')}` }) } } @@ -93,7 +93,7 @@ export function cfValidateJsonSchema({ } const finalErrorMessage = `${ - errorMessage ? errorMessage + ': ' : '' + errorPrefix ? errorPrefix + ': ' : '' }${result.errors .map(({ keyword, error }) => `keyword "${keyword}" error ${error}`) .join(' ')}` diff --git a/apps/gateway/src/lib/create-http-response-from-mcp-tool-call-response.ts b/apps/gateway/src/lib/create-http-response-from-mcp-tool-call-response.ts index eaa5d7fa..c4a08254 100644 --- a/apps/gateway/src/lib/create-http-response-from-mcp-tool-call-response.ts +++ b/apps/gateway/src/lib/create-http-response-from-mcp-tool-call-response.ts @@ -24,6 +24,7 @@ export async function createHttpResponseFromMcpToolCallResponse( assert( !toolCallResponse.isError, 502, + // TODO: add content or structuredContent to the error message `MCP tool "${tool.name}" returned an error.` ) @@ -41,7 +42,7 @@ export async function createHttpResponseFromMcpToolCallResponse( coerce: false, // TODO: double-check MCP schema on whether additional properties are allowed strictAdditionalProperties: true, - errorMessage: `Invalid tool response for tool "${tool.name}"`, + errorPrefix: `Invalid tool response for tool "${tool.name}"`, errorStatusCode: 502 }) diff --git a/apps/gateway/src/lib/durable-mcp-server.ts b/apps/gateway/src/lib/durable-mcp-server.ts index c8178973..5aea4645 100644 --- a/apps/gateway/src/lib/durable-mcp-server.ts +++ b/apps/gateway/src/lib/durable-mcp-server.ts @@ -10,17 +10,9 @@ import { import { McpAgent } from 'agents/mcp' import type { RawEnv } from './env' -import type { - AdminConsumer, - AgenticMcpRequestMetadata, - McpToolCallResponse -} from './types' -import { cfValidateJsonSchema } from './cf-validate-json-schema' -import { createRequestForOpenAPIOperation } from './create-request-for-openapi-operation' +import type { 
AdminConsumer } from './types' +import { resolveOriginToolCall } from './resolve-origin-tool-call' import { transformHttpResponseToMcpToolCallResponse } from './transform-http-response-to-mcp-tool-call-response' -// import { fetchCache } from './fetch-cache' -// import { getRequestCacheKey } from './get-request-cache-key' -import { updateOriginRequest } from './update-origin-request' // type State = { counter: number } @@ -31,6 +23,7 @@ export class DurableMcpServer extends McpAgent< deployment: AdminDeployment consumer?: AdminConsumer pricingPlan?: PricingPlan + ip?: string } > { protected _serverP = Promise.withResolvers() @@ -41,8 +34,7 @@ export class DurableMcpServer extends McpAgent< // } override async init() { - const { consumer, deployment, pricingPlan } = this.props - const { originAdapter } = deployment + const { consumer, deployment, pricingPlan, ip } = this.props const { projectIdentifier } = parseDeploymentIdentifier( deployment.identifier ) @@ -102,106 +94,31 @@ export class DurableMcpServer extends McpAgent< // TODO: caching // TODO: usage tracking / reporting - if (originAdapter.type === 'raw') { - // TODO - assert(false, 500, 'Raw origin adapter not implemented') - } else { - // Validate incoming request params against the tool's input schema. 
- const toolCallArgs = cfValidateJsonSchema>({ - schema: tool.inputSchema, - data: args, - errorMessage: `Invalid request parameters for tool "${tool.name}"`, - strictAdditionalProperties: true + const sessionId = this.ctx.id.toString() + const { toolCallArgs, originRequest, originResponse, toolCallResponse } = + await resolveOriginToolCall({ + tool, + args, + deployment, + consumer, + pricingPlan, + sessionId, + env: this.env, + ip, + waitUntil: this.ctx.waitUntil }) - if (originAdapter.type === 'openapi') { - const operation = originAdapter.toolToOperationMap[tool.name] - assert( - operation, - 404, - `Tool "${tool.name}" not found in OpenAPI spec` - ) - assert(toolCallArgs, 500) - - const originRequest = await createRequestForOpenAPIOperation({ - toolCallArgs, - operation, - deployment - }) - - updateOriginRequest(originRequest, { consumer, deployment }) - - // TODO: re-add caching support - // const cacheKey = await getRequestCacheKey(ctx, originRequest) - - // // TODO: transform origin 5XX errors to 502 errors... 
- // // TODO: fetch origin request and transform response - // const originResponse = await fetchCache(ctx, { - // cacheKey, - // fetchResponse: () => fetch(originRequest) - // }) - - const originResponse = await fetch(originRequest) - - return transformHttpResponseToMcpToolCallResponse({ - originRequest, - originResponse, - tool, - toolCallArgs - }) - } else if (originAdapter.type === 'mcp') { - const sessionId = this.ctx.id.toString() - const id: DurableObjectId = - this.env.DO_MCP_CLIENT.idFromName(sessionId) - const originMcpClient = this.env.DO_MCP_CLIENT.get(id) - - await originMcpClient.init({ - url: deployment.originUrl, - name: originAdapter.serverInfo.name, - version: originAdapter.serverInfo.version - }) - - const { projectIdentifier } = parseDeploymentIdentifier( - deployment.identifier, - { errorStatusCode: 500 } - ) - - const originMcpRequestMetadata = { - agenticProxySecret: deployment._secret, - sessionId, - // ip, - isCustomerSubscriptionActive: - !!consumer?.isStripeSubscriptionActive, - customerId: consumer?.id, - customerSubscriptionPlan: consumer?.plan, - customerSubscriptionStatus: consumer?.stripeStatus, - userId: consumer?.user.id, - userEmail: consumer?.user.email, - userUsername: consumer?.user.username, - userName: consumer?.user.name, - userCreatedAt: consumer?.user.createdAt, - userUpdatedAt: consumer?.user.updatedAt, - deploymentId: deployment.id, - deploymentIdentifier: deployment.identifier, - projectId: deployment.projectId, - projectIdentifier - } as AgenticMcpRequestMetadata - - // TODO: add timeout support to the origin tool call? - // TODO: add response caching for MCP tool calls - const toolCallResponseString = await originMcpClient.callTool({ - name: tool.name, - args: toolCallArgs, - metadata: originMcpRequestMetadata! 
- }) - const toolCallResponse = JSON.parse( - toolCallResponseString - ) as McpToolCallResponse - - return toolCallResponse - } else { - assert(false, 500) - } + if (originResponse) { + return transformHttpResponseToMcpToolCallResponse({ + originRequest, + originResponse, + tool, + toolCallArgs + }) + } else if (toolCallResponse) { + return toolCallResponse + } else { + assert(false, 500) } }) } diff --git a/apps/gateway/src/lib/enforce-rate-limit.ts b/apps/gateway/src/lib/enforce-rate-limit.ts index 43aa940e..29069c49 100644 --- a/apps/gateway/src/lib/enforce-rate-limit.ts +++ b/apps/gateway/src/lib/enforce-rate-limit.ts @@ -1,27 +1,21 @@ import { assert } from '@agentic/platform-core' -import type { GatewayHonoContext } from './types' - // https://developers.cloudflare.com/durable-objects/examples/build-a-rate-limiter/ // https://github.com/rhinobase/hono-rate-limiter/blob/main/packages/cloudflare/src/stores/DurableObjectStore.ts // https://github.com/rhinobase/hono-rate-limiter/blob/main/packages/core/src/core.ts -export async function enforceRateLimit( - ctx: GatewayHonoContext, - { - id, - interval, - maxPerInterval - }: { - id?: string - interval: number - maxPerInterval: number - } -) { +export async function enforceRateLimit({ + id, + interval, + maxPerInterval +}: { + id?: string + interval: number + maxPerInterval: number +}) { assert(id, 400, 'Unauthenticated requests must have a valid IP address') // TODO - assert(ctx, 500, 'not implemented') assert(id, 500, 'not implemented') assert(interval > 0, 500, 'not implemented') assert(maxPerInterval >= 0, 500, 'not implemented') diff --git a/apps/gateway/src/lib/fetch-cache.ts b/apps/gateway/src/lib/fetch-cache.ts index 89040faf..541b69e5 100644 --- a/apps/gateway/src/lib/fetch-cache.ts +++ b/apps/gateway/src/lib/fetch-cache.ts @@ -1,17 +1,13 @@ -import type { GatewayHonoContext } from './types' - -export async function fetchCache( - ctx: GatewayHonoContext, - { - cacheKey, - fetchResponse - }: { - cacheKey?: 
Request - fetchResponse: () => Promise - } -): Promise { - const cache = ctx.get('cache') - const logger = ctx.get('logger') +export async function fetchCache({ + cacheKey, + fetchResponse, + waitUntil +}: { + cacheKey?: Request + fetchResponse: () => Promise + waitUntil: (promise: Promise) => void +}): Promise { + const cache = caches.default let response: Response | undefined if (cacheKey) { @@ -25,9 +21,10 @@ export async function fetchCache( if (cacheKey) { if (response.headers.has('Cache-Control')) { // Note that cloudflare's `cache` should respect response headers. - ctx.executionCtx.waitUntil( + waitUntil( cache.put(cacheKey, response.clone()).catch((err) => { - logger.warn('cache put error', cacheKey, err) + // eslint-disable-next-line no-console + console.warn('cache put error', cacheKey, err) }) ) } diff --git a/apps/gateway/src/lib/get-request-cache-key.ts b/apps/gateway/src/lib/get-request-cache-key.ts index ce376570..20fe8cf1 100644 --- a/apps/gateway/src/lib/get-request-cache-key.ts +++ b/apps/gateway/src/lib/get-request-cache-key.ts @@ -1,14 +1,12 @@ import { hashObject, sha256 } from '@agentic/platform-core' import contentType from 'fast-content-type-parse' -import type { GatewayHonoContext } from './types' import { normalizeUrl } from './normalize-url' // TODO: what is a reasonable upper bound for hashing the POST body size? 
const MAX_POST_BODY_SIZE_BYTES = 10_000 export async function getRequestCacheKey( - ctx: GatewayHonoContext, request: Request ): Promise { try { @@ -85,8 +83,13 @@ export async function getRequestCacheKey( return normalizeRequestHeaders(new Request(request)) } catch (err) { - const logger = ctx.get('logger') - logger.error('error computing cache key', request.method, request.url, err) + // eslint-disable-next-line no-console + console.warn( + 'warning: failed to compute cache key', + request.method, + request.url, + err + ) return } } diff --git a/apps/gateway/src/lib/get-tool-args-from-request.ts b/apps/gateway/src/lib/get-tool-args-from-request.ts index d95c3c29..6eebb395 100644 --- a/apps/gateway/src/lib/get-tool-args-from-request.ts +++ b/apps/gateway/src/lib/get-tool-args-from-request.ts @@ -21,23 +21,30 @@ export async function getToolArgsFromRequest( `Internal logic error for origin adapter type "${deployment.originAdapter.type}"` ) - let incomingRequestArgsRaw: Record = {} - let coerceRequestArgs = false - if (request.method === 'GET') { // Args will be coerced to match their expected types via // `cfValidateJsonSchemaObject` since all values will be strings. - incomingRequestArgsRaw = Object.fromEntries( + const incomingRequestArgsRaw = Object.fromEntries( new URL(request.url).searchParams.entries() ) - coerceRequestArgs = true + + // Validate incoming request params against the tool's input schema. 
+ const incomingRequestArgs = cfValidateJsonSchema>({ + schema: tool.inputSchema, + data: incomingRequestArgsRaw, + errorPrefix: `Invalid request parameters for tool "${tool.name}"`, + coerce: true, + strictAdditionalProperties: true + }) + + return incomingRequestArgs } else if (request.method === 'POST') { - incomingRequestArgsRaw = (await request.clone().json()) as Record< + const incomingRequestArgsRaw = (await request.clone().json()) as Record< string, any > - // TODO: Support empty params for POST requests + // TODO: Proper support for empty params with POST requests assert(incomingRequestArgsRaw, 400, 'Invalid empty request body') assert( typeof incomingRequestArgsRaw === 'object', @@ -45,16 +52,8 @@ export async function getToolArgsFromRequest( 'Invalid request body' ) assert(!Array.isArray(incomingRequestArgsRaw), 400, 'Invalid request body') + return incomingRequestArgsRaw + } else { + assert(false, 405, `HTTP method "${request.method}" not allowed`) } - - // Validate incoming request params against the tool's input schema. 
- const incomingRequestArgs = cfValidateJsonSchema>({ - schema: tool.inputSchema, - data: incomingRequestArgsRaw, - errorMessage: `Invalid request parameters for tool "${tool.name}"`, - coerce: coerceRequestArgs, - strictAdditionalProperties: true - }) - - return incomingRequestArgs } diff --git a/apps/gateway/src/lib/resolve-http-edge-request.ts b/apps/gateway/src/lib/resolve-http-edge-request.ts new file mode 100644 index 00000000..907660e5 --- /dev/null +++ b/apps/gateway/src/lib/resolve-http-edge-request.ts @@ -0,0 +1,120 @@ +import type { + AdminDeployment, + PricingPlan, + Tool +} from '@agentic/platform-types' +import { assert } from '@agentic/platform-core' +import { parseToolIdentifier } from '@agentic/platform-validators' + +import type { AdminConsumer, GatewayHonoContext, ToolCallArgs } from './types' +import { getAdminConsumer } from './get-admin-consumer' +import { getAdminDeployment } from './get-admin-deployment' +import { getTool } from './get-tool' +import { getToolArgsFromRequest } from './get-tool-args-from-request' + +export type ResolvedHttpEdgeRequest = { + deployment: AdminDeployment + consumer?: AdminConsumer + pricingPlan?: PricingPlan + + tool: Tool + toolCallArgs: ToolCallArgs +} + +/** + * Resolves an input HTTP request to a specific deployment, tool call, and + * billing subscription. + * + * Also ensures that the request is valid, enforces rate limits, and adds proxy- + * specific headers to the origin request. 
+ */ +export async function resolveHttpEdgeRequest( + ctx: GatewayHonoContext +): Promise { + const logger = ctx.get('logger') + const ip = ctx.get('ip') + + const { method } = ctx.req + const requestUrl = new URL(ctx.req.url) + const { pathname } = requestUrl + const requestedToolIdentifier = pathname.replace(/^\//, '').replace(/\/$/, '') + const { toolName, deploymentIdentifier } = parseToolIdentifier( + requestedToolIdentifier + ) + + const deployment = await getAdminDeployment(ctx, deploymentIdentifier) + + const tool = getTool({ + method, + deployment, + toolName + }) + + logger.debug('request', { + method, + pathname, + deploymentIdentifier: deployment.identifier, + toolName, + tool + }) + + let pricingPlan: PricingPlan | undefined + let consumer: AdminConsumer | undefined + + const token = (ctx.req.header('authorization') || '') + .replace(/^Bearer /i, '') + .trim() + + if (token) { + consumer = await getAdminConsumer(ctx, token) + assert(consumer, 401, `Invalid auth token "${token}"`) + assert( + consumer.isStripeSubscriptionActive, + 402, + `Auth token "${token}" does not have an active subscription` + ) + assert( + consumer.projectId === deployment.projectId, + 403, + `Auth token "${token}" is not authorized for project "${deployment.projectId}"` + ) + + // TODO: Ensure that consumer.plan is compatible with the target deployment? + // TODO: This could definitely cause issues when changing pricing plans. + + pricingPlan = deployment.pricingPlans.find( + (pricingPlan) => consumer!.plan === pricingPlan.slug + ) + + // assert( + // pricingPlan, + // 403, + // `Auth token "${token}" unable to find matching pricing plan for project "${deployment.project}"` + // ) + + if (!ctx.get('sessionId')) { + ctx.set('sessionId', `${consumer.id}:${deployment.id}`) + } + } else { + // For unauthenticated requests, default to a free pricing plan if available. 
+ pricingPlan = deployment.pricingPlans.find((plan) => plan.slug === 'free') + + if (!ctx.get('sessionId')) { + assert(ip, 500, 'IP address is required for unauthenticated requests') + ctx.set('sessionId', `${ip}:${deployment.projectId}`) + } + } + + assert(ctx.get('sessionId'), 500, 'Internal error: sessionId should be set') + + // Parse tool call args from the request body. + const toolCallArgs = await getToolArgsFromRequest(ctx, { tool, deployment }) + + return { + deployment, + consumer, + pricingPlan, + tool, + toolCallArgs + } +} diff --git a/apps/gateway/src/lib/resolve-mcp-edge-request.ts b/apps/gateway/src/lib/resolve-mcp-edge-request.ts index 2e2e3e63..958da975 100644 --- a/apps/gateway/src/lib/resolve-mcp-edge-request.ts +++ b/apps/gateway/src/lib/resolve-mcp-edge-request.ts @@ -6,11 +6,16 @@ import type { AdminConsumer, GatewayHonoContext } from './types' import { getAdminConsumer } from './get-admin-consumer' import { getAdminDeployment } from './get-admin-deployment' -export async function resolveMcpEdgeRequest(ctx: GatewayHonoContext): Promise<{ +export type ResolvedMcpEdgeRequest = { deployment: AdminDeployment consumer?: AdminConsumer pricingPlan?: PricingPlan -}> { + ip?: string +} + +export async function resolveMcpEdgeRequest( + ctx: GatewayHonoContext +): Promise { const requestUrl = new URL(ctx.req.url) const { pathname } = requestUrl const requestedDeploymentIdentifier = pathname @@ -60,6 +65,7 @@ export async function resolveMcpEdgeRequest(ctx: GatewayHonoContext): Promise<{ return { deployment, consumer, - pricingPlan + pricingPlan, + ip: ctx.get('ip') } } diff --git a/apps/gateway/src/lib/resolve-origin-request.ts b/apps/gateway/src/lib/resolve-origin-request.ts deleted file mode 100644 index 2bf57fcc..00000000 --- a/apps/gateway/src/lib/resolve-origin-request.ts +++ /dev/null @@ -1,266 +0,0 @@ -import type { PricingPlan, RateLimit } from '@agentic/platform-types' -import { assert } from '@agentic/platform-core' -import { - 
parseDeploymentIdentifier, - parseToolIdentifier -} from '@agentic/platform-validators' - -import type { DurableMcpClient } from './durable-mcp-client' -import type { - AdminConsumer, - AgenticMcpRequestMetadata, - GatewayHonoContext, - ResolvedOriginRequest, - ToolCallArgs -} from './types' -import { createRequestForOpenAPIOperation } from './create-request-for-openapi-operation' -import { enforceRateLimit } from './enforce-rate-limit' -import { getAdminConsumer } from './get-admin-consumer' -import { getAdminDeployment } from './get-admin-deployment' -import { getTool } from './get-tool' -import { getToolArgsFromRequest } from './get-tool-args-from-request' -import { updateOriginRequest } from './update-origin-request' - -/** - * Resolves an input HTTP request to a specific deployment, tool call, and - * billing subscription. - * - * Also ensures that the request is valid, enforces rate limits, and adds proxy- - * specific headers to the origin request. - */ -export async function resolveOriginRequest( - ctx: GatewayHonoContext -): Promise { - const logger = ctx.get('logger') - const ip = ctx.get('ip') - - const { method } = ctx.req - const requestUrl = new URL(ctx.req.url) - const { pathname } = requestUrl - const requestedToolIdentifier = pathname.replace(/^\//, '').replace(/\/$/, '') - const { toolName, deploymentIdentifier } = parseToolIdentifier( - requestedToolIdentifier - ) - - const deployment = await getAdminDeployment(ctx, deploymentIdentifier) - - const tool = getTool({ - method, - deployment, - toolName - }) - - logger.debug('request', { - method, - pathname, - deploymentIdentifier: deployment.identifier, - toolName, - tool - }) - - let pricingPlan: PricingPlan | undefined - let consumer: AdminConsumer | undefined - let reportUsage = ctx.get('reportUsage') ?? 
true - - const token = (ctx.req.header('authorization') || '') - .replace(/^Bearer /i, '') - .trim() - - if (token) { - consumer = await getAdminConsumer(ctx, token) - assert(consumer, 401, `Invalid auth token "${token}"`) - assert( - consumer.isStripeSubscriptionActive, - 402, - `Auth token "${token}" does not have an active subscription` - ) - assert( - consumer.projectId === deployment.projectId, - 403, - `Auth token "${token}" is not authorized for project "${deployment.projectId}"` - ) - - // TODO: Ensure that consumer.plan is compatible with the target deployment? - // TODO: This could definitely cause issues when changing pricing plans. - - pricingPlan = deployment.pricingPlans.find( - (pricingPlan) => consumer!.plan === pricingPlan.slug - ) - - // assert( - // pricingPlan, - // 403, - // `Auth token "${token}" unable to find matching pricing plan for project "${deployment.project}"` - // ) - - if (!ctx.get('sessionId')) { - ctx.set('sessionId', `${consumer.id}:${deployment.id}`) - } - } else { - // For unauthenticated requests, default to a free pricing plan if available. - pricingPlan = deployment.pricingPlans.find((plan) => plan.slug === 'free') - - if (!ctx.get('sessionId')) { - assert(ip, 500, 'IP address is required for unauthenticated requests') - ctx.set('sessionId', `${ip}:${deployment.projectId}`) - } - } - - let rateLimit: RateLimit | undefined | null - - // Resolve rate limit and whether to report `requests` usage based on the - // customer's pricing plan and deployment config. - if (pricingPlan) { - const requestsLineItem = pricingPlan.lineItems.find( - (lineItem) => lineItem.slug === 'requests' - ) - - if (requestsLineItem) { - assert( - requestsLineItem.slug === 'requests', - 403, - `Invalid pricing plan "${pricingPlan.slug}" for project "${deployment.project}"` - ) - - rateLimit = requestsLineItem.rateLimit - } else { - // No `requests` line-item, so we don't report usage for this tool. 
- reportUsage = false - } - } - - const toolConfig = deployment.toolConfigs.find( - (toolConfig) => toolConfig.name === tool.name - ) - - if (toolConfig) { - if (toolConfig.reportUsage !== undefined) { - reportUsage &&= !!toolConfig.reportUsage - } - - if (toolConfig.rateLimit !== undefined) { - // TODO: Improve RateLimitInput vs RateLimit types - rateLimit = toolConfig.rateLimit as RateLimit - } - - const pricingPlanToolConfig = pricingPlan - ? toolConfig.pricingPlanConfig?.[pricingPlan.slug] - : undefined - - if (pricingPlan && pricingPlanToolConfig) { - assert( - pricingPlanToolConfig.enabled || - (pricingPlanToolConfig.enabled === undefined && toolConfig.enabled), - 403, - `Tool "${tool.name}" is not enabled for pricing plan "${pricingPlan.slug}"` - ) - - if (pricingPlanToolConfig.reportUsage !== undefined) { - reportUsage &&= !!pricingPlanToolConfig.reportUsage - } - - if (pricingPlanToolConfig.rateLimit !== undefined) { - // TODO: Improve RateLimitInput vs RateLimit types - rateLimit = pricingPlanToolConfig.rateLimit as RateLimit - } - } else { - assert(toolConfig.enabled, 403, `Tool "${tool.name}" is not enabled`) - } - } - - ctx.set('reportUsage', reportUsage) - - if (rateLimit) { - await enforceRateLimit(ctx, { - id: consumer?.id ?? ip, - interval: rateLimit.interval, - maxPerInterval: rateLimit.maxPerInterval - }) - } - - const { originAdapter } = deployment - let toolCallArgs: ToolCallArgs | undefined - let originRequest: Request | undefined - let originMcpClient: DurableObjectStub | undefined - let originMcpRequestMetadata: AgenticMcpRequestMetadata | undefined - - if (originAdapter.type === 'raw') { - const originRequestUrl = `${deployment.originUrl}/${toolName}${requestUrl.search}` - originRequest = new Request(originRequestUrl, ctx.req.raw) - } else { - // Parse tool call args from the request body for both OpenAPI and MCP - // origin adapters. 
- toolCallArgs = await getToolArgsFromRequest(ctx, { - tool, - deployment - }) - } - - if (originAdapter.type === 'openapi') { - const operation = originAdapter.toolToOperationMap[tool.name] - assert(operation, 404, `Tool "${tool.name}" not found in OpenAPI spec`) - assert(toolCallArgs, 500) - - originRequest = await createRequestForOpenAPIOperation({ - request: ctx.req.raw, - toolCallArgs, - operation, - deployment - }) - } else if (originAdapter.type === 'mcp') { - const sessionId = ctx.get('sessionId') - assert(sessionId, 500, 'Session ID is required for MCP origin requests') - - const id: DurableObjectId = ctx.env.DO_MCP_CLIENT.idFromName(sessionId) - originMcpClient = ctx.env.DO_MCP_CLIENT.get(id) - - await originMcpClient.init({ - url: deployment.originUrl, - name: originAdapter.serverInfo.name, - version: originAdapter.serverInfo.version - }) - - const { projectIdentifier } = parseDeploymentIdentifier( - deployment.identifier, - { errorStatusCode: 500 } - ) - - originMcpRequestMetadata = { - agenticProxySecret: deployment._secret, - sessionId, - ip, - isCustomerSubscriptionActive: !!consumer?.isStripeSubscriptionActive, - customerId: consumer?.id, - customerSubscriptionPlan: consumer?.plan, - customerSubscriptionStatus: consumer?.stripeStatus, - userId: consumer?.user.id, - userEmail: consumer?.user.email, - userUsername: consumer?.user.username, - userName: consumer?.user.name, - userCreatedAt: consumer?.user.createdAt, - userUpdatedAt: consumer?.user.updatedAt, - deploymentId: deployment.id, - deploymentIdentifier: deployment.identifier, - projectId: deployment.projectId, - projectIdentifier - } as AgenticMcpRequestMetadata - } - - if (originRequest) { - logger.info('originRequestUrl', originRequest.url) - updateOriginRequest(originRequest, { consumer, deployment }) - } - - assert(ctx.get('sessionId'), 500, 'Internal error: sessionId should be set') - - return { - deployment, - consumer, - tool, - pricingPlan, - toolCallArgs, - originRequest, - 
import type {
  AdminDeployment,
  PricingPlan,
  RateLimit,
  Tool
} from '@agentic/platform-types'
import { assert } from '@agentic/platform-core'
import { parseDeploymentIdentifier } from '@agentic/platform-validators'

import type { RawEnv } from './env'
import type {
  AdminConsumer,
  AgenticMcpRequestMetadata,
  McpToolCallResponse,
  ToolCallArgs
} from './types'
import { cfValidateJsonSchema } from './cf-validate-json-schema'
import { createRequestForOpenAPIOperation } from './create-request-for-openapi-operation'
import { enforceRateLimit } from './enforce-rate-limit'
import { fetchCache } from './fetch-cache'
import { getRequestCacheKey } from './get-request-cache-key'
import { updateOriginRequest } from './update-origin-request'

// type State = { counter: number }

/**
 * Result of resolving a tool call against a deployment's origin.
 *
 * `toolCallArgs` is always present. The remaining fields form two mutually
 * exclusive shapes:
 *
 * - HTTP origins (`openapi`): `originRequest` and `originResponse` are set and
 *   `toolCallResponse` is absent.
 * - MCP origins (`mcp`): `toolCallResponse` is set and `originRequest` /
 *   `originResponse` are absent.
 *
 * Callers can therefore branch on `originResponse` (or `toolCallResponse`) to
 * narrow the result.
 */
export type ResolvedOriginToolCallResult = {
  toolCallArgs: ToolCallArgs
  originRequest?: Request
  originResponse?: Response
  toolCallResponse?: McpToolCallResponse
} & (
  | {
      originRequest: Request
      originResponse: Response
      toolCallResponse?: never
    }
  | {
      originRequest?: never
      originResponse?: never
      toolCallResponse: McpToolCallResponse
    }
)

/**
 * Resolves a single tool call against the deployment's origin server.
 *
 * Steps, in order:
 *
 * 1. Derive the effective rate limit and the `reportUsage` flag from the
 *    pricing plan's `requests` line-item, then the deployment's per-tool
 *    config, then that tool's per-pricing-plan config (later sources override
 *    the rate limit; `reportUsage` can only be switched off, never back on).
 * 2. Assert that the tool is enabled for the caller (403 otherwise).
 * 3. Enforce the rate limit, if there is one.
 * 4. Validate `args` against the tool's input schema.
 * 5. Dispatch to the origin: a cached HTTP fetch for `openapi` adapters, or a
 *    tool call through the session's `DO_MCP_CLIENT` durable object for `mcp`
 *    adapters.
 *
 * @param tool - The tool being invoked.
 * @param args - Raw, not-yet-validated tool call arguments.
 * @param deployment - The deployment which owns the tool.
 * @param consumer - The authenticated consumer, if any.
 * @param pricingPlan - The consumer's pricing plan, if any.
 * @param sessionId - Session id; names the MCP client durable object and is
 *   the last-resort rate-limit key.
 * @param env - Worker environment bindings (`DO_MCP_CLIENT` is read here).
 * @param ip - Client IP address, if known.
 * @param waitUntil - Handed to `fetchCache` alongside the cache key and the
 *   fetch callback.
 * @returns The validated args plus either the origin HTTP request/response
 *   pair or the parsed MCP tool call response.
 * @throws Via `assert`: 403 when the tool is disabled or the pricing plan is
 *   invalid, 404 when an OpenAPI tool has no matching operation, and 500 for
 *   unsupported origin adapters (`raw` is not implemented). Errors raised by
 *   the helpers called here are not handled and propagate to the caller.
 */
export async function resolveOriginToolCall({
  tool,
  args,
  deployment,
  consumer,
  pricingPlan,
  sessionId,
  env,
  ip,
  waitUntil
}: {
  tool: Tool
  args?: ToolCallArgs
  deployment: AdminDeployment
  consumer?: AdminConsumer
  pricingPlan?: PricingPlan
  sessionId: string
  env: RawEnv
  ip?: string
  waitUntil: (promise: Promise<unknown>) => void
}): Promise<ResolvedOriginToolCallResult> {
  // TODO: rate-limiting
  // TODO: caching
  // TODO: usage tracking / reporting
  // TODO: all of this per-request logic should maybe be moved to a different
  // method since it's not specific to tool calls. eg, other MCP requests may
  // still need to be rate-limited / cached / tracked / etc.

  const { originAdapter } = deployment
  let rateLimit: RateLimit | undefined | null

  // NOTE: `reportUsage` is resolved here but not consumed yet; see the usage
  // tracking TODO above.
  let reportUsage = true

  // Resolve rate limit and whether to report `requests` usage based on the
  // customer's pricing plan and deployment config.
  if (pricingPlan) {
    const requestsLineItem = pricingPlan.lineItems.find(
      (lineItem) => lineItem.slug === 'requests'
    )

    if (requestsLineItem) {
      // Always true at runtime given the `find` predicate above; presumably
      // kept so the assertion narrows the line-item type for the `rateLimit`
      // access below.
      assert(
        requestsLineItem.slug === 'requests',
        403,
        `Invalid pricing plan "${pricingPlan.slug}" for project "${deployment.project}"`
      )

      rateLimit = requestsLineItem.rateLimit
    } else {
      // No `requests` line-item, so we don't report usage for this tool.
      reportUsage = false
    }
  }

  const toolConfig = deployment.toolConfigs.find(
    (toolConfig) => toolConfig.name === tool.name
  )

  if (toolConfig) {
    // Tool-level config overrides the pricing plan's defaults.
    if (toolConfig.reportUsage !== undefined) {
      reportUsage &&= !!toolConfig.reportUsage
    }

    if (toolConfig.rateLimit !== undefined) {
      // TODO: Improve RateLimitInput vs RateLimit types
      rateLimit = toolConfig.rateLimit as RateLimit
    }

    const pricingPlanToolConfig = pricingPlan
      ? toolConfig.pricingPlanConfig?.[pricingPlan.slug]
      : undefined

    if (pricingPlan && pricingPlanToolConfig) {
      // A plan-specific `enabled` wins; when it is unset, fall back to the
      // tool-level `enabled` flag.
      assert(
        pricingPlanToolConfig.enabled ||
          (pricingPlanToolConfig.enabled === undefined && toolConfig.enabled),
        403,
        `Tool "${tool.name}" is not enabled for pricing plan "${pricingPlan.slug}"`
      )

      // Plan-specific tool config overrides the tool-level config.
      if (pricingPlanToolConfig.reportUsage !== undefined) {
        reportUsage &&= !!pricingPlanToolConfig.reportUsage
      }

      if (pricingPlanToolConfig.rateLimit !== undefined) {
        // TODO: Improve RateLimitInput vs RateLimit types
        rateLimit = pricingPlanToolConfig.rateLimit as RateLimit
      }
    } else {
      assert(toolConfig.enabled, 403, `Tool "${tool.name}" is not enabled`)
    }
  }

  if (rateLimit) {
    await enforceRateLimit({
      // Prefer the most specific caller identity available. `sessionId` is
      // always defined, so anonymous requests without a resolved IP are still
      // keyed on a real value instead of `undefined`.
      id: consumer?.id ?? ip ?? sessionId,
      interval: rateLimit.interval,
      maxPerInterval: rateLimit.maxPerInterval
    })
  }

  if (originAdapter.type === 'raw') {
    // TODO
    assert(false, 500, 'Raw origin adapter not implemented')
  } else {
    // Validate incoming request params against the tool's input schema.
    const toolCallArgs = cfValidateJsonSchema<ToolCallArgs>({
      schema: tool.inputSchema,
      data: args,
      errorPrefix: `Invalid request parameters for tool "${tool.name}"`,
      strictAdditionalProperties: true
    })

    if (originAdapter.type === 'openapi') {
      const operation = originAdapter.toolToOperationMap[tool.name]
      assert(operation, 404, `Tool "${tool.name}" not found in OpenAPI spec`)
      assert(toolCallArgs, 500)

      const originRequest = await createRequestForOpenAPIOperation({
        toolCallArgs,
        operation,
        deployment
      })

      // Called before the cache key is computed, so the key is derived from
      // the request as it stands after this call.
      updateOriginRequest(originRequest, { consumer, deployment })

      const cacheKey = await getRequestCacheKey(originRequest)

      // TODO: transform origin 5XX errors to 502 errors...
      const originResponse = await fetchCache({
        cacheKey,
        fetchResponse: () => fetch(originRequest),
        waitUntil
      })

      // non-cached version
      // const originResponse = await fetch(originRequest)

      return {
        toolCallArgs,
        originRequest,
        originResponse
      }
    } else if (originAdapter.type === 'mcp') {
      // One MCP client durable object per session.
      const id: DurableObjectId = env.DO_MCP_CLIENT.idFromName(sessionId)
      const originMcpClient = env.DO_MCP_CLIENT.get(id)

      await originMcpClient.init({
        url: deployment.originUrl,
        name: originAdapter.serverInfo.name,
        version: originAdapter.serverInfo.version
      })

      const { projectIdentifier } = parseDeploymentIdentifier(
        deployment.identifier,
        { errorStatusCode: 500 }
      )

      // Metadata forwarded to the origin MCP server with the tool call.
      const originMcpRequestMetadata = {
        agenticProxySecret: deployment._secret,
        sessionId,
        // ip,
        isCustomerSubscriptionActive: !!consumer?.isStripeSubscriptionActive,
        customerId: consumer?.id,
        customerSubscriptionPlan: consumer?.plan,
        customerSubscriptionStatus: consumer?.stripeStatus,
        userId: consumer?.user.id,
        userEmail: consumer?.user.email,
        userUsername: consumer?.user.username,
        userName: consumer?.user.name,
        userCreatedAt: consumer?.user.createdAt,
        userUpdatedAt: consumer?.user.updatedAt,
        deploymentId: deployment.id,
        deploymentIdentifier: deployment.identifier,
        projectId: deployment.projectId,
        projectIdentifier
      } as AgenticMcpRequestMetadata

      // TODO: add timeout support to the origin tool call?
      // TODO: add response caching for origin MCP tool calls
      const toolCallResponseString = await originMcpClient.callTool({
        name: tool.name,
        args: toolCallArgs,
        metadata: originMcpRequestMetadata
      })

      // NOTE(review): the parsed payload is only type-asserted here, not
      // schema-validated.
      const toolCallResponse = JSON.parse(
        toolCallResponseString
      ) as McpToolCallResponse

      return {
        toolCallArgs,
        toolCallResponse
      }
    } else {
      // Unknown origin adapter type.
      assert(false, 500)
    }
  }
}
originMcpClient?: DurableObjectStub - originMcpRequestMetadata?: AgenticMcpRequestMetadata -} - export type AgenticMcpRequestMetadata = { agenticProxySecret: string sessionId: string