kopia lustrzana https://github.com/transitive-bullshit/chatgpt-api
feat: refactoring tractors
rodzic
b97a210602
commit
a47acd376d
|
@ -9,13 +9,12 @@ import {
|
||||||
import { parseToolIdentifier } from '@agentic/platform-validators'
|
import { parseToolIdentifier } from '@agentic/platform-validators'
|
||||||
import { Hono } from 'hono'
|
import { Hono } from 'hono'
|
||||||
|
|
||||||
import type { GatewayHonoEnv, McpToolCallResponse } from './lib/types'
|
import type { GatewayHonoEnv } from './lib/types'
|
||||||
import { createAgenticClient } from './lib/agentic-client'
|
import { createAgenticClient } from './lib/agentic-client'
|
||||||
import { createHttpResponseFromMcpToolCallResponse } from './lib/create-http-response-from-mcp-tool-call-response'
|
import { createHttpResponseFromMcpToolCallResponse } from './lib/create-http-response-from-mcp-tool-call-response'
|
||||||
import { fetchCache } from './lib/fetch-cache'
|
import { resolveHttpEdgeRequest } from './lib/resolve-http-edge-request'
|
||||||
import { getRequestCacheKey } from './lib/get-request-cache-key'
|
|
||||||
import { resolveMcpEdgeRequest } from './lib/resolve-mcp-edge-request'
|
import { resolveMcpEdgeRequest } from './lib/resolve-mcp-edge-request'
|
||||||
import { resolveOriginRequest } from './lib/resolve-origin-request'
|
import { resolveOriginToolCall } from './lib/resolve-origin-tool-call'
|
||||||
import { DurableMcpServer } from './worker'
|
import { DurableMcpServer } from './worker'
|
||||||
|
|
||||||
export const app = new Hono<GatewayHonoEnv>()
|
export const app = new Hono<GatewayHonoEnv>()
|
||||||
|
@ -67,69 +66,37 @@ app.all(async (ctx) => {
|
||||||
}).fetch(ctx.req.raw, ctx.env, executionCtx)
|
}).fetch(ctx.req.raw, ctx.env, executionCtx)
|
||||||
}
|
}
|
||||||
|
|
||||||
const resolvedOriginRequest = await resolveOriginRequest(ctx)
|
const resolvedEdgeRequest = await resolveHttpEdgeRequest(ctx)
|
||||||
|
|
||||||
const originStartTime = Date.now()
|
const originStartTime = Date.now()
|
||||||
|
|
||||||
|
const resolvedOriginToolCallResult = await resolveOriginToolCall({
|
||||||
|
tool: resolvedEdgeRequest.tool,
|
||||||
|
args: resolvedEdgeRequest.toolCallArgs,
|
||||||
|
deployment: resolvedEdgeRequest.deployment,
|
||||||
|
consumer: resolvedEdgeRequest.consumer,
|
||||||
|
pricingPlan: resolvedEdgeRequest.pricingPlan,
|
||||||
|
sessionId: ctx.get('sessionId')!,
|
||||||
|
ip: ctx.get('ip'),
|
||||||
|
env: ctx.env,
|
||||||
|
waitUntil: ctx.executionCtx.waitUntil
|
||||||
|
})
|
||||||
|
|
||||||
let originResponse: Response | undefined
|
let originResponse: Response | undefined
|
||||||
|
if (resolvedOriginToolCallResult.originResponse) {
|
||||||
switch (resolvedOriginRequest.deployment.originAdapter.type) {
|
originResponse = resolvedOriginToolCallResult.originResponse
|
||||||
case 'openapi':
|
} else {
|
||||||
case 'raw': {
|
originResponse = await createHttpResponseFromMcpToolCallResponse(ctx, {
|
||||||
assert(
|
tool: resolvedEdgeRequest.tool,
|
||||||
resolvedOriginRequest.originRequest,
|
deployment: resolvedEdgeRequest.deployment,
|
||||||
500,
|
toolCallResponse: resolvedOriginToolCallResult.toolCallResponse
|
||||||
'Origin request is required'
|
})
|
||||||
)
|
|
||||||
|
|
||||||
const cacheKey = await getRequestCacheKey(
|
|
||||||
ctx,
|
|
||||||
resolvedOriginRequest.originRequest
|
|
||||||
)
|
|
||||||
|
|
||||||
// TODO: transform origin 5XX errors to 502 errors...
|
|
||||||
originResponse = await fetchCache(ctx, {
|
|
||||||
cacheKey,
|
|
||||||
fetchResponse: () => fetch(resolvedOriginRequest.originRequest!)
|
|
||||||
})
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
case 'mcp': {
|
|
||||||
assert(
|
|
||||||
resolvedOriginRequest.toolCallArgs,
|
|
||||||
500,
|
|
||||||
'Tool args are required for MCP origin requests'
|
|
||||||
)
|
|
||||||
assert(
|
|
||||||
resolvedOriginRequest.originMcpClient,
|
|
||||||
500,
|
|
||||||
'MCP client is required for MCP origin requests'
|
|
||||||
)
|
|
||||||
|
|
||||||
// TODO: add timeout support to the origin tool call?
|
|
||||||
// TODO: add response caching for MCP tool calls
|
|
||||||
const toolCallResponseString =
|
|
||||||
await resolvedOriginRequest.originMcpClient.callTool({
|
|
||||||
name: resolvedOriginRequest.tool.name,
|
|
||||||
args: resolvedOriginRequest.toolCallArgs,
|
|
||||||
metadata: resolvedOriginRequest.originMcpRequestMetadata!
|
|
||||||
})
|
|
||||||
const toolCallResponse = JSON.parse(
|
|
||||||
toolCallResponseString
|
|
||||||
) as McpToolCallResponse
|
|
||||||
|
|
||||||
originResponse = await createHttpResponseFromMcpToolCallResponse(ctx, {
|
|
||||||
tool: resolvedOriginRequest.tool,
|
|
||||||
deployment: resolvedOriginRequest.deployment,
|
|
||||||
toolCallResponse
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
assert(originResponse, 500, 'Origin response is required')
|
assert(originResponse, 500, 'Origin response is required')
|
||||||
const res = new Response(originResponse.body, originResponse)
|
const res = new Response(originResponse.body, originResponse)
|
||||||
|
|
||||||
// Record the time it took for both the origin and gateway to respond
|
// Record the time it took for the origin to respond.
|
||||||
const now = Date.now()
|
const now = Date.now()
|
||||||
const originTimespan = now - originStartTime
|
const originTimespan = now - originStartTime
|
||||||
res.headers.set('x-origin-response-time', `${originTimespan}ms`)
|
res.headers.set('x-origin-response-time', `${originTimespan}ms`)
|
||||||
|
|
|
@ -18,14 +18,14 @@ export function cfValidateJsonSchema<T = unknown>({
|
||||||
data,
|
data,
|
||||||
coerce = false,
|
coerce = false,
|
||||||
strictAdditionalProperties = false,
|
strictAdditionalProperties = false,
|
||||||
errorMessage,
|
errorPrefix,
|
||||||
errorStatusCode = 400
|
errorStatusCode = 400
|
||||||
}: {
|
}: {
|
||||||
schema: any
|
schema: any
|
||||||
data: unknown
|
data: unknown
|
||||||
coerce?: boolean
|
coerce?: boolean
|
||||||
strictAdditionalProperties?: boolean
|
strictAdditionalProperties?: boolean
|
||||||
errorMessage?: string
|
errorPrefix?: string
|
||||||
errorStatusCode?: number
|
errorStatusCode?: number
|
||||||
}): T {
|
}): T {
|
||||||
assert(schema, 400, '`schema` is required')
|
assert(schema, 400, '`schema` is required')
|
||||||
|
@ -37,7 +37,7 @@ export function cfValidateJsonSchema<T = unknown>({
|
||||||
if (isSchemaObject && !isDataObject) {
|
if (isSchemaObject && !isDataObject) {
|
||||||
throw new HttpError({
|
throw new HttpError({
|
||||||
statusCode: 400,
|
statusCode: 400,
|
||||||
message: `${errorMessage ? errorMessage + ': ' : ''}Data must be an object according to its schema.`
|
message: `${errorPrefix ? errorPrefix + ': ' : ''}Data must be an object according to its schema.`
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -50,7 +50,7 @@ export function cfValidateJsonSchema<T = unknown>({
|
||||||
if (missingRequiredFields.length > 0) {
|
if (missingRequiredFields.length > 0) {
|
||||||
throw new HttpError({
|
throw new HttpError({
|
||||||
statusCode: errorStatusCode,
|
statusCode: errorStatusCode,
|
||||||
message: `${errorMessage ? errorMessage + ': ' : ''}Missing required ${plur('parameter', missingRequiredFields.length)}: ${missingRequiredFields.map((field) => `"${field}"`).join(', ')}`
|
message: `${errorPrefix ? errorPrefix + ': ' : ''}Missing required ${plur('parameter', missingRequiredFields.length)}: ${missingRequiredFields.map((field) => `"${field}"`).join(', ')}`
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -70,7 +70,7 @@ export function cfValidateJsonSchema<T = unknown>({
|
||||||
if (extraProperties.length > 0) {
|
if (extraProperties.length > 0) {
|
||||||
throw new HttpError({
|
throw new HttpError({
|
||||||
statusCode: errorStatusCode,
|
statusCode: errorStatusCode,
|
||||||
message: `${errorMessage ? errorMessage + ': ' : ''}Unexpected additional ${plur('parameter', extraProperties.length)}: ${extraProperties.map((property) => `"${property}"`).join(', ')}`
|
message: `${errorPrefix ? errorPrefix + ': ' : ''}Unexpected additional ${plur('parameter', extraProperties.length)}: ${extraProperties.map((property) => `"${property}"`).join(', ')}`
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -93,7 +93,7 @@ export function cfValidateJsonSchema<T = unknown>({
|
||||||
}
|
}
|
||||||
|
|
||||||
const finalErrorMessage = `${
|
const finalErrorMessage = `${
|
||||||
errorMessage ? errorMessage + ': ' : ''
|
errorPrefix ? errorPrefix + ': ' : ''
|
||||||
}${result.errors
|
}${result.errors
|
||||||
.map(({ keyword, error }) => `keyword "${keyword}" error ${error}`)
|
.map(({ keyword, error }) => `keyword "${keyword}" error ${error}`)
|
||||||
.join(' ')}`
|
.join(' ')}`
|
||||||
|
|
|
@ -24,6 +24,7 @@ export async function createHttpResponseFromMcpToolCallResponse(
|
||||||
assert(
|
assert(
|
||||||
!toolCallResponse.isError,
|
!toolCallResponse.isError,
|
||||||
502,
|
502,
|
||||||
|
// TODO: add content or structuredContent to the error message
|
||||||
`MCP tool "${tool.name}" returned an error.`
|
`MCP tool "${tool.name}" returned an error.`
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -41,7 +42,7 @@ export async function createHttpResponseFromMcpToolCallResponse(
|
||||||
coerce: false,
|
coerce: false,
|
||||||
// TODO: double-check MCP schema on whether additional properties are allowed
|
// TODO: double-check MCP schema on whether additional properties are allowed
|
||||||
strictAdditionalProperties: true,
|
strictAdditionalProperties: true,
|
||||||
errorMessage: `Invalid tool response for tool "${tool.name}"`,
|
errorPrefix: `Invalid tool response for tool "${tool.name}"`,
|
||||||
errorStatusCode: 502
|
errorStatusCode: 502
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
|
@ -10,17 +10,9 @@ import {
|
||||||
import { McpAgent } from 'agents/mcp'
|
import { McpAgent } from 'agents/mcp'
|
||||||
|
|
||||||
import type { RawEnv } from './env'
|
import type { RawEnv } from './env'
|
||||||
import type {
|
import type { AdminConsumer } from './types'
|
||||||
AdminConsumer,
|
import { resolveOriginToolCall } from './resolve-origin-tool-call'
|
||||||
AgenticMcpRequestMetadata,
|
|
||||||
McpToolCallResponse
|
|
||||||
} from './types'
|
|
||||||
import { cfValidateJsonSchema } from './cf-validate-json-schema'
|
|
||||||
import { createRequestForOpenAPIOperation } from './create-request-for-openapi-operation'
|
|
||||||
import { transformHttpResponseToMcpToolCallResponse } from './transform-http-response-to-mcp-tool-call-response'
|
import { transformHttpResponseToMcpToolCallResponse } from './transform-http-response-to-mcp-tool-call-response'
|
||||||
// import { fetchCache } from './fetch-cache'
|
|
||||||
// import { getRequestCacheKey } from './get-request-cache-key'
|
|
||||||
import { updateOriginRequest } from './update-origin-request'
|
|
||||||
|
|
||||||
// type State = { counter: number }
|
// type State = { counter: number }
|
||||||
|
|
||||||
|
@ -31,6 +23,7 @@ export class DurableMcpServer extends McpAgent<
|
||||||
deployment: AdminDeployment
|
deployment: AdminDeployment
|
||||||
consumer?: AdminConsumer
|
consumer?: AdminConsumer
|
||||||
pricingPlan?: PricingPlan
|
pricingPlan?: PricingPlan
|
||||||
|
ip?: string
|
||||||
}
|
}
|
||||||
> {
|
> {
|
||||||
protected _serverP = Promise.withResolvers<Server>()
|
protected _serverP = Promise.withResolvers<Server>()
|
||||||
|
@ -41,8 +34,7 @@ export class DurableMcpServer extends McpAgent<
|
||||||
// }
|
// }
|
||||||
|
|
||||||
override async init() {
|
override async init() {
|
||||||
const { consumer, deployment, pricingPlan } = this.props
|
const { consumer, deployment, pricingPlan, ip } = this.props
|
||||||
const { originAdapter } = deployment
|
|
||||||
const { projectIdentifier } = parseDeploymentIdentifier(
|
const { projectIdentifier } = parseDeploymentIdentifier(
|
||||||
deployment.identifier
|
deployment.identifier
|
||||||
)
|
)
|
||||||
|
@ -102,106 +94,31 @@ export class DurableMcpServer extends McpAgent<
|
||||||
// TODO: caching
|
// TODO: caching
|
||||||
// TODO: usage tracking / reporting
|
// TODO: usage tracking / reporting
|
||||||
|
|
||||||
if (originAdapter.type === 'raw') {
|
const sessionId = this.ctx.id.toString()
|
||||||
// TODO
|
const { toolCallArgs, originRequest, originResponse, toolCallResponse } =
|
||||||
assert(false, 500, 'Raw origin adapter not implemented')
|
await resolveOriginToolCall({
|
||||||
} else {
|
tool,
|
||||||
// Validate incoming request params against the tool's input schema.
|
args,
|
||||||
const toolCallArgs = cfValidateJsonSchema<Record<string, any>>({
|
deployment,
|
||||||
schema: tool.inputSchema,
|
consumer,
|
||||||
data: args,
|
pricingPlan,
|
||||||
errorMessage: `Invalid request parameters for tool "${tool.name}"`,
|
sessionId,
|
||||||
strictAdditionalProperties: true
|
env: this.env,
|
||||||
|
ip,
|
||||||
|
waitUntil: this.ctx.waitUntil
|
||||||
})
|
})
|
||||||
|
|
||||||
if (originAdapter.type === 'openapi') {
|
if (originResponse) {
|
||||||
const operation = originAdapter.toolToOperationMap[tool.name]
|
return transformHttpResponseToMcpToolCallResponse({
|
||||||
assert(
|
originRequest,
|
||||||
operation,
|
originResponse,
|
||||||
404,
|
tool,
|
||||||
`Tool "${tool.name}" not found in OpenAPI spec`
|
toolCallArgs
|
||||||
)
|
})
|
||||||
assert(toolCallArgs, 500)
|
} else if (toolCallResponse) {
|
||||||
|
return toolCallResponse
|
||||||
const originRequest = await createRequestForOpenAPIOperation({
|
} else {
|
||||||
toolCallArgs,
|
assert(false, 500)
|
||||||
operation,
|
|
||||||
deployment
|
|
||||||
})
|
|
||||||
|
|
||||||
updateOriginRequest(originRequest, { consumer, deployment })
|
|
||||||
|
|
||||||
// TODO: re-add caching support
|
|
||||||
// const cacheKey = await getRequestCacheKey(ctx, originRequest)
|
|
||||||
|
|
||||||
// // TODO: transform origin 5XX errors to 502 errors...
|
|
||||||
// // TODO: fetch origin request and transform response
|
|
||||||
// const originResponse = await fetchCache(ctx, {
|
|
||||||
// cacheKey,
|
|
||||||
// fetchResponse: () => fetch(originRequest)
|
|
||||||
// })
|
|
||||||
|
|
||||||
const originResponse = await fetch(originRequest)
|
|
||||||
|
|
||||||
return transformHttpResponseToMcpToolCallResponse({
|
|
||||||
originRequest,
|
|
||||||
originResponse,
|
|
||||||
tool,
|
|
||||||
toolCallArgs
|
|
||||||
})
|
|
||||||
} else if (originAdapter.type === 'mcp') {
|
|
||||||
const sessionId = this.ctx.id.toString()
|
|
||||||
const id: DurableObjectId =
|
|
||||||
this.env.DO_MCP_CLIENT.idFromName(sessionId)
|
|
||||||
const originMcpClient = this.env.DO_MCP_CLIENT.get(id)
|
|
||||||
|
|
||||||
await originMcpClient.init({
|
|
||||||
url: deployment.originUrl,
|
|
||||||
name: originAdapter.serverInfo.name,
|
|
||||||
version: originAdapter.serverInfo.version
|
|
||||||
})
|
|
||||||
|
|
||||||
const { projectIdentifier } = parseDeploymentIdentifier(
|
|
||||||
deployment.identifier,
|
|
||||||
{ errorStatusCode: 500 }
|
|
||||||
)
|
|
||||||
|
|
||||||
const originMcpRequestMetadata = {
|
|
||||||
agenticProxySecret: deployment._secret,
|
|
||||||
sessionId,
|
|
||||||
// ip,
|
|
||||||
isCustomerSubscriptionActive:
|
|
||||||
!!consumer?.isStripeSubscriptionActive,
|
|
||||||
customerId: consumer?.id,
|
|
||||||
customerSubscriptionPlan: consumer?.plan,
|
|
||||||
customerSubscriptionStatus: consumer?.stripeStatus,
|
|
||||||
userId: consumer?.user.id,
|
|
||||||
userEmail: consumer?.user.email,
|
|
||||||
userUsername: consumer?.user.username,
|
|
||||||
userName: consumer?.user.name,
|
|
||||||
userCreatedAt: consumer?.user.createdAt,
|
|
||||||
userUpdatedAt: consumer?.user.updatedAt,
|
|
||||||
deploymentId: deployment.id,
|
|
||||||
deploymentIdentifier: deployment.identifier,
|
|
||||||
projectId: deployment.projectId,
|
|
||||||
projectIdentifier
|
|
||||||
} as AgenticMcpRequestMetadata
|
|
||||||
|
|
||||||
// TODO: add timeout support to the origin tool call?
|
|
||||||
// TODO: add response caching for MCP tool calls
|
|
||||||
const toolCallResponseString = await originMcpClient.callTool({
|
|
||||||
name: tool.name,
|
|
||||||
args: toolCallArgs,
|
|
||||||
metadata: originMcpRequestMetadata!
|
|
||||||
})
|
|
||||||
const toolCallResponse = JSON.parse(
|
|
||||||
toolCallResponseString
|
|
||||||
) as McpToolCallResponse
|
|
||||||
|
|
||||||
return toolCallResponse
|
|
||||||
} else {
|
|
||||||
assert(false, 500)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,27 +1,21 @@
|
||||||
import { assert } from '@agentic/platform-core'
|
import { assert } from '@agentic/platform-core'
|
||||||
|
|
||||||
import type { GatewayHonoContext } from './types'
|
|
||||||
|
|
||||||
// https://developers.cloudflare.com/durable-objects/examples/build-a-rate-limiter/
|
// https://developers.cloudflare.com/durable-objects/examples/build-a-rate-limiter/
|
||||||
// https://github.com/rhinobase/hono-rate-limiter/blob/main/packages/cloudflare/src/stores/DurableObjectStore.ts
|
// https://github.com/rhinobase/hono-rate-limiter/blob/main/packages/cloudflare/src/stores/DurableObjectStore.ts
|
||||||
// https://github.com/rhinobase/hono-rate-limiter/blob/main/packages/core/src/core.ts
|
// https://github.com/rhinobase/hono-rate-limiter/blob/main/packages/core/src/core.ts
|
||||||
|
|
||||||
export async function enforceRateLimit(
|
export async function enforceRateLimit({
|
||||||
ctx: GatewayHonoContext,
|
id,
|
||||||
{
|
interval,
|
||||||
id,
|
maxPerInterval
|
||||||
interval,
|
}: {
|
||||||
maxPerInterval
|
id?: string
|
||||||
}: {
|
interval: number
|
||||||
id?: string
|
maxPerInterval: number
|
||||||
interval: number
|
}) {
|
||||||
maxPerInterval: number
|
|
||||||
}
|
|
||||||
) {
|
|
||||||
assert(id, 400, 'Unauthenticated requests must have a valid IP address')
|
assert(id, 400, 'Unauthenticated requests must have a valid IP address')
|
||||||
|
|
||||||
// TODO
|
// TODO
|
||||||
assert(ctx, 500, 'not implemented')
|
|
||||||
assert(id, 500, 'not implemented')
|
assert(id, 500, 'not implemented')
|
||||||
assert(interval > 0, 500, 'not implemented')
|
assert(interval > 0, 500, 'not implemented')
|
||||||
assert(maxPerInterval >= 0, 500, 'not implemented')
|
assert(maxPerInterval >= 0, 500, 'not implemented')
|
||||||
|
|
|
@ -1,17 +1,13 @@
|
||||||
import type { GatewayHonoContext } from './types'
|
export async function fetchCache({
|
||||||
|
cacheKey,
|
||||||
export async function fetchCache(
|
fetchResponse,
|
||||||
ctx: GatewayHonoContext,
|
waitUntil
|
||||||
{
|
}: {
|
||||||
cacheKey,
|
cacheKey?: Request
|
||||||
fetchResponse
|
fetchResponse: () => Promise<Response>
|
||||||
}: {
|
waitUntil: (promise: Promise<any>) => void
|
||||||
cacheKey?: Request
|
}): Promise<Response> {
|
||||||
fetchResponse: () => Promise<Response>
|
const cache = caches.default
|
||||||
}
|
|
||||||
): Promise<Response> {
|
|
||||||
const cache = ctx.get('cache')
|
|
||||||
const logger = ctx.get('logger')
|
|
||||||
let response: Response | undefined
|
let response: Response | undefined
|
||||||
|
|
||||||
if (cacheKey) {
|
if (cacheKey) {
|
||||||
|
@ -25,9 +21,10 @@ export async function fetchCache(
|
||||||
if (cacheKey) {
|
if (cacheKey) {
|
||||||
if (response.headers.has('Cache-Control')) {
|
if (response.headers.has('Cache-Control')) {
|
||||||
// Note that cloudflare's `cache` should respect response headers.
|
// Note that cloudflare's `cache` should respect response headers.
|
||||||
ctx.executionCtx.waitUntil(
|
waitUntil(
|
||||||
cache.put(cacheKey, response.clone()).catch((err) => {
|
cache.put(cacheKey, response.clone()).catch((err) => {
|
||||||
logger.warn('cache put error', cacheKey, err)
|
// eslint-disable-next-line no-console
|
||||||
|
console.warn('cache put error', cacheKey, err)
|
||||||
})
|
})
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,14 +1,12 @@
|
||||||
import { hashObject, sha256 } from '@agentic/platform-core'
|
import { hashObject, sha256 } from '@agentic/platform-core'
|
||||||
import contentType from 'fast-content-type-parse'
|
import contentType from 'fast-content-type-parse'
|
||||||
|
|
||||||
import type { GatewayHonoContext } from './types'
|
|
||||||
import { normalizeUrl } from './normalize-url'
|
import { normalizeUrl } from './normalize-url'
|
||||||
|
|
||||||
// TODO: what is a reasonable upper bound for hashing the POST body size?
|
// TODO: what is a reasonable upper bound for hashing the POST body size?
|
||||||
const MAX_POST_BODY_SIZE_BYTES = 10_000
|
const MAX_POST_BODY_SIZE_BYTES = 10_000
|
||||||
|
|
||||||
export async function getRequestCacheKey(
|
export async function getRequestCacheKey(
|
||||||
ctx: GatewayHonoContext,
|
|
||||||
request: Request
|
request: Request
|
||||||
): Promise<Request | undefined> {
|
): Promise<Request | undefined> {
|
||||||
try {
|
try {
|
||||||
|
@ -85,8 +83,13 @@ export async function getRequestCacheKey(
|
||||||
|
|
||||||
return normalizeRequestHeaders(new Request(request))
|
return normalizeRequestHeaders(new Request(request))
|
||||||
} catch (err) {
|
} catch (err) {
|
||||||
const logger = ctx.get('logger')
|
// eslint-disable-next-line no-console
|
||||||
logger.error('error computing cache key', request.method, request.url, err)
|
console.warn(
|
||||||
|
'warning: failed to compute cache key',
|
||||||
|
request.method,
|
||||||
|
request.url,
|
||||||
|
err
|
||||||
|
)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -21,23 +21,30 @@ export async function getToolArgsFromRequest(
|
||||||
`Internal logic error for origin adapter type "${deployment.originAdapter.type}"`
|
`Internal logic error for origin adapter type "${deployment.originAdapter.type}"`
|
||||||
)
|
)
|
||||||
|
|
||||||
let incomingRequestArgsRaw: Record<string, any> = {}
|
|
||||||
let coerceRequestArgs = false
|
|
||||||
|
|
||||||
if (request.method === 'GET') {
|
if (request.method === 'GET') {
|
||||||
// Args will be coerced to match their expected types via
|
// Args will be coerced to match their expected types via
|
||||||
// `cfValidateJsonSchemaObject` since all values will be strings.
|
// `cfValidateJsonSchemaObject` since all values will be strings.
|
||||||
incomingRequestArgsRaw = Object.fromEntries(
|
const incomingRequestArgsRaw = Object.fromEntries(
|
||||||
new URL(request.url).searchParams.entries()
|
new URL(request.url).searchParams.entries()
|
||||||
)
|
)
|
||||||
coerceRequestArgs = true
|
|
||||||
|
// Validate incoming request params against the tool's input schema.
|
||||||
|
const incomingRequestArgs = cfValidateJsonSchema<Record<string, any>>({
|
||||||
|
schema: tool.inputSchema,
|
||||||
|
data: incomingRequestArgsRaw,
|
||||||
|
errorPrefix: `Invalid request parameters for tool "${tool.name}"`,
|
||||||
|
coerce: true,
|
||||||
|
strictAdditionalProperties: true
|
||||||
|
})
|
||||||
|
|
||||||
|
return incomingRequestArgs
|
||||||
} else if (request.method === 'POST') {
|
} else if (request.method === 'POST') {
|
||||||
incomingRequestArgsRaw = (await request.clone().json()) as Record<
|
const incomingRequestArgsRaw = (await request.clone().json()) as Record<
|
||||||
string,
|
string,
|
||||||
any
|
any
|
||||||
>
|
>
|
||||||
|
|
||||||
// TODO: Support empty params for POST requests
|
// TODO: Proper support for empty params with POST requests
|
||||||
assert(incomingRequestArgsRaw, 400, 'Invalid empty request body')
|
assert(incomingRequestArgsRaw, 400, 'Invalid empty request body')
|
||||||
assert(
|
assert(
|
||||||
typeof incomingRequestArgsRaw === 'object',
|
typeof incomingRequestArgsRaw === 'object',
|
||||||
|
@ -45,16 +52,8 @@ export async function getToolArgsFromRequest(
|
||||||
'Invalid request body'
|
'Invalid request body'
|
||||||
)
|
)
|
||||||
assert(!Array.isArray(incomingRequestArgsRaw), 400, 'Invalid request body')
|
assert(!Array.isArray(incomingRequestArgsRaw), 400, 'Invalid request body')
|
||||||
|
return incomingRequestArgsRaw
|
||||||
|
} else {
|
||||||
|
assert(false, 405, `HTTP method "${request.method}" not allowed`)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Validate incoming request params against the tool's input schema.
|
|
||||||
const incomingRequestArgs = cfValidateJsonSchema<Record<string, any>>({
|
|
||||||
schema: tool.inputSchema,
|
|
||||||
data: incomingRequestArgsRaw,
|
|
||||||
errorMessage: `Invalid request parameters for tool "${tool.name}"`,
|
|
||||||
coerce: coerceRequestArgs,
|
|
||||||
strictAdditionalProperties: true
|
|
||||||
})
|
|
||||||
|
|
||||||
return incomingRequestArgs
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,120 @@
|
||||||
|
import type {
|
||||||
|
AdminDeployment,
|
||||||
|
PricingPlan,
|
||||||
|
Tool
|
||||||
|
} from '@agentic/platform-types'
|
||||||
|
import { assert } from '@agentic/platform-core'
|
||||||
|
import { parseToolIdentifier } from '@agentic/platform-validators'
|
||||||
|
|
||||||
|
import type { AdminConsumer, GatewayHonoContext, ToolCallArgs } from './types'
|
||||||
|
import { getAdminConsumer } from './get-admin-consumer'
|
||||||
|
import { getAdminDeployment } from './get-admin-deployment'
|
||||||
|
import { getTool } from './get-tool'
|
||||||
|
import { getToolArgsFromRequest } from './get-tool-args-from-request'
|
||||||
|
|
||||||
|
export type ResolvedHttpEdgeRequest = {
|
||||||
|
deployment: AdminDeployment
|
||||||
|
consumer?: AdminConsumer
|
||||||
|
pricingPlan?: PricingPlan
|
||||||
|
|
||||||
|
tool: Tool
|
||||||
|
toolCallArgs: ToolCallArgs
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Resolves an input HTTP request to a specific deployment, tool call, and
|
||||||
|
* billing subscription.
|
||||||
|
*
|
||||||
|
* Also ensures that the request is valid, enforces rate limits, and adds proxy-
|
||||||
|
* specific headers to the origin request.
|
||||||
|
*/
|
||||||
|
export async function resolveHttpEdgeRequest(
|
||||||
|
ctx: GatewayHonoContext
|
||||||
|
): Promise<ResolvedHttpEdgeRequest> {
|
||||||
|
const logger = ctx.get('logger')
|
||||||
|
const ip = ctx.get('ip')
|
||||||
|
|
||||||
|
const { method } = ctx.req
|
||||||
|
const requestUrl = new URL(ctx.req.url)
|
||||||
|
const { pathname } = requestUrl
|
||||||
|
const requestedToolIdentifier = pathname.replace(/^\//, '').replace(/\/$/, '')
|
||||||
|
const { toolName, deploymentIdentifier } = parseToolIdentifier(
|
||||||
|
requestedToolIdentifier
|
||||||
|
)
|
||||||
|
|
||||||
|
const deployment = await getAdminDeployment(ctx, deploymentIdentifier)
|
||||||
|
|
||||||
|
const tool = getTool({
|
||||||
|
method,
|
||||||
|
deployment,
|
||||||
|
toolName
|
||||||
|
})
|
||||||
|
|
||||||
|
logger.debug('request', {
|
||||||
|
method,
|
||||||
|
pathname,
|
||||||
|
deploymentIdentifier: deployment.identifier,
|
||||||
|
toolName,
|
||||||
|
tool
|
||||||
|
})
|
||||||
|
|
||||||
|
let pricingPlan: PricingPlan | undefined
|
||||||
|
let consumer: AdminConsumer | undefined
|
||||||
|
|
||||||
|
const token = (ctx.req.header('authorization') || '')
|
||||||
|
.replace(/^Bearer /i, '')
|
||||||
|
.trim()
|
||||||
|
|
||||||
|
if (token) {
|
||||||
|
consumer = await getAdminConsumer(ctx, token)
|
||||||
|
assert(consumer, 401, `Invalid auth token "${token}"`)
|
||||||
|
assert(
|
||||||
|
consumer.isStripeSubscriptionActive,
|
||||||
|
402,
|
||||||
|
`Auth token "${token}" does not have an active subscription`
|
||||||
|
)
|
||||||
|
assert(
|
||||||
|
consumer.projectId === deployment.projectId,
|
||||||
|
403,
|
||||||
|
`Auth token "${token}" is not authorized for project "${deployment.projectId}"`
|
||||||
|
)
|
||||||
|
|
||||||
|
// TODO: Ensure that consumer.plan is compatible with the target deployment?
|
||||||
|
// TODO: This could definitely cause issues when changing pricing plans.
|
||||||
|
|
||||||
|
pricingPlan = deployment.pricingPlans.find(
|
||||||
|
(pricingPlan) => consumer!.plan === pricingPlan.slug
|
||||||
|
)
|
||||||
|
|
||||||
|
// assert(
|
||||||
|
// pricingPlan,
|
||||||
|
// 403,
|
||||||
|
// `Auth token "${token}" unable to find matching pricing plan for project "${deployment.project}"`
|
||||||
|
// )
|
||||||
|
|
||||||
|
if (!ctx.get('sessionId')) {
|
||||||
|
ctx.set('sessionId', `${consumer.id}:${deployment.id}`)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// For unauthenticated requests, default to a free pricing plan if available.
|
||||||
|
pricingPlan = deployment.pricingPlans.find((plan) => plan.slug === 'free')
|
||||||
|
|
||||||
|
if (!ctx.get('sessionId')) {
|
||||||
|
assert(ip, 500, 'IP address is required for unauthenticated requests')
|
||||||
|
ctx.set('sessionId', `${ip}:${deployment.projectId}`)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
assert(ctx.get('sessionId'), 500, 'Internal error: sessionId should be set')
|
||||||
|
|
||||||
|
// Parse tool call args from the request body.
|
||||||
|
const toolCallArgs = await getToolArgsFromRequest(ctx, { tool, deployment })
|
||||||
|
|
||||||
|
return {
|
||||||
|
deployment,
|
||||||
|
consumer,
|
||||||
|
pricingPlan,
|
||||||
|
tool,
|
||||||
|
toolCallArgs
|
||||||
|
}
|
||||||
|
}
|
|
@ -6,11 +6,16 @@ import type { AdminConsumer, GatewayHonoContext } from './types'
|
||||||
import { getAdminConsumer } from './get-admin-consumer'
|
import { getAdminConsumer } from './get-admin-consumer'
|
||||||
import { getAdminDeployment } from './get-admin-deployment'
|
import { getAdminDeployment } from './get-admin-deployment'
|
||||||
|
|
||||||
export async function resolveMcpEdgeRequest(ctx: GatewayHonoContext): Promise<{
|
export type ResolvedMcpEdgeRequest = {
|
||||||
deployment: AdminDeployment
|
deployment: AdminDeployment
|
||||||
consumer?: AdminConsumer
|
consumer?: AdminConsumer
|
||||||
pricingPlan?: PricingPlan
|
pricingPlan?: PricingPlan
|
||||||
}> {
|
ip?: string
|
||||||
|
}
|
||||||
|
|
||||||
|
export async function resolveMcpEdgeRequest(
|
||||||
|
ctx: GatewayHonoContext
|
||||||
|
): Promise<ResolvedMcpEdgeRequest> {
|
||||||
const requestUrl = new URL(ctx.req.url)
|
const requestUrl = new URL(ctx.req.url)
|
||||||
const { pathname } = requestUrl
|
const { pathname } = requestUrl
|
||||||
const requestedDeploymentIdentifier = pathname
|
const requestedDeploymentIdentifier = pathname
|
||||||
|
@ -60,6 +65,7 @@ export async function resolveMcpEdgeRequest(ctx: GatewayHonoContext): Promise<{
|
||||||
return {
|
return {
|
||||||
deployment,
|
deployment,
|
||||||
consumer,
|
consumer,
|
||||||
pricingPlan
|
pricingPlan,
|
||||||
|
ip: ctx.get('ip')
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,266 +0,0 @@
|
||||||
import type { PricingPlan, RateLimit } from '@agentic/platform-types'
|
|
||||||
import { assert } from '@agentic/platform-core'
|
|
||||||
import {
|
|
||||||
parseDeploymentIdentifier,
|
|
||||||
parseToolIdentifier
|
|
||||||
} from '@agentic/platform-validators'
|
|
||||||
|
|
||||||
import type { DurableMcpClient } from './durable-mcp-client'
|
|
||||||
import type {
|
|
||||||
AdminConsumer,
|
|
||||||
AgenticMcpRequestMetadata,
|
|
||||||
GatewayHonoContext,
|
|
||||||
ResolvedOriginRequest,
|
|
||||||
ToolCallArgs
|
|
||||||
} from './types'
|
|
||||||
import { createRequestForOpenAPIOperation } from './create-request-for-openapi-operation'
|
|
||||||
import { enforceRateLimit } from './enforce-rate-limit'
|
|
||||||
import { getAdminConsumer } from './get-admin-consumer'
|
|
||||||
import { getAdminDeployment } from './get-admin-deployment'
|
|
||||||
import { getTool } from './get-tool'
|
|
||||||
import { getToolArgsFromRequest } from './get-tool-args-from-request'
|
|
||||||
import { updateOriginRequest } from './update-origin-request'
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Resolves an input HTTP request to a specific deployment, tool call, and
|
|
||||||
* billing subscription.
|
|
||||||
*
|
|
||||||
* Also ensures that the request is valid, enforces rate limits, and adds proxy-
|
|
||||||
* specific headers to the origin request.
|
|
||||||
*/
|
|
||||||
export async function resolveOriginRequest(
|
|
||||||
ctx: GatewayHonoContext
|
|
||||||
): Promise<ResolvedOriginRequest> {
|
|
||||||
const logger = ctx.get('logger')
|
|
||||||
const ip = ctx.get('ip')
|
|
||||||
|
|
||||||
const { method } = ctx.req
|
|
||||||
const requestUrl = new URL(ctx.req.url)
|
|
||||||
const { pathname } = requestUrl
|
|
||||||
const requestedToolIdentifier = pathname.replace(/^\//, '').replace(/\/$/, '')
|
|
||||||
const { toolName, deploymentIdentifier } = parseToolIdentifier(
|
|
||||||
requestedToolIdentifier
|
|
||||||
)
|
|
||||||
|
|
||||||
const deployment = await getAdminDeployment(ctx, deploymentIdentifier)
|
|
||||||
|
|
||||||
const tool = getTool({
|
|
||||||
method,
|
|
||||||
deployment,
|
|
||||||
toolName
|
|
||||||
})
|
|
||||||
|
|
||||||
logger.debug('request', {
|
|
||||||
method,
|
|
||||||
pathname,
|
|
||||||
deploymentIdentifier: deployment.identifier,
|
|
||||||
toolName,
|
|
||||||
tool
|
|
||||||
})
|
|
||||||
|
|
||||||
let pricingPlan: PricingPlan | undefined
|
|
||||||
let consumer: AdminConsumer | undefined
|
|
||||||
let reportUsage = ctx.get('reportUsage') ?? true
|
|
||||||
|
|
||||||
const token = (ctx.req.header('authorization') || '')
|
|
||||||
.replace(/^Bearer /i, '')
|
|
||||||
.trim()
|
|
||||||
|
|
||||||
if (token) {
|
|
||||||
consumer = await getAdminConsumer(ctx, token)
|
|
||||||
assert(consumer, 401, `Invalid auth token "${token}"`)
|
|
||||||
assert(
|
|
||||||
consumer.isStripeSubscriptionActive,
|
|
||||||
402,
|
|
||||||
`Auth token "${token}" does not have an active subscription`
|
|
||||||
)
|
|
||||||
assert(
|
|
||||||
consumer.projectId === deployment.projectId,
|
|
||||||
403,
|
|
||||||
`Auth token "${token}" is not authorized for project "${deployment.projectId}"`
|
|
||||||
)
|
|
||||||
|
|
||||||
// TODO: Ensure that consumer.plan is compatible with the target deployment?
|
|
||||||
// TODO: This could definitely cause issues when changing pricing plans.
|
|
||||||
|
|
||||||
pricingPlan = deployment.pricingPlans.find(
|
|
||||||
(pricingPlan) => consumer!.plan === pricingPlan.slug
|
|
||||||
)
|
|
||||||
|
|
||||||
// assert(
|
|
||||||
// pricingPlan,
|
|
||||||
// 403,
|
|
||||||
// `Auth token "${token}" unable to find matching pricing plan for project "${deployment.project}"`
|
|
||||||
// )
|
|
||||||
|
|
||||||
if (!ctx.get('sessionId')) {
|
|
||||||
ctx.set('sessionId', `${consumer.id}:${deployment.id}`)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// For unauthenticated requests, default to a free pricing plan if available.
|
|
||||||
pricingPlan = deployment.pricingPlans.find((plan) => plan.slug === 'free')
|
|
||||||
|
|
||||||
if (!ctx.get('sessionId')) {
|
|
||||||
assert(ip, 500, 'IP address is required for unauthenticated requests')
|
|
||||||
ctx.set('sessionId', `${ip}:${deployment.projectId}`)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
let rateLimit: RateLimit | undefined | null
|
|
||||||
|
|
||||||
// Resolve rate limit and whether to report `requests` usage based on the
|
|
||||||
// customer's pricing plan and deployment config.
|
|
||||||
if (pricingPlan) {
|
|
||||||
const requestsLineItem = pricingPlan.lineItems.find(
|
|
||||||
(lineItem) => lineItem.slug === 'requests'
|
|
||||||
)
|
|
||||||
|
|
||||||
if (requestsLineItem) {
|
|
||||||
assert(
|
|
||||||
requestsLineItem.slug === 'requests',
|
|
||||||
403,
|
|
||||||
`Invalid pricing plan "${pricingPlan.slug}" for project "${deployment.project}"`
|
|
||||||
)
|
|
||||||
|
|
||||||
rateLimit = requestsLineItem.rateLimit
|
|
||||||
} else {
|
|
||||||
// No `requests` line-item, so we don't report usage for this tool.
|
|
||||||
reportUsage = false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
const toolConfig = deployment.toolConfigs.find(
|
|
||||||
(toolConfig) => toolConfig.name === tool.name
|
|
||||||
)
|
|
||||||
|
|
||||||
if (toolConfig) {
|
|
||||||
if (toolConfig.reportUsage !== undefined) {
|
|
||||||
reportUsage &&= !!toolConfig.reportUsage
|
|
||||||
}
|
|
||||||
|
|
||||||
if (toolConfig.rateLimit !== undefined) {
|
|
||||||
// TODO: Improve RateLimitInput vs RateLimit types
|
|
||||||
rateLimit = toolConfig.rateLimit as RateLimit
|
|
||||||
}
|
|
||||||
|
|
||||||
const pricingPlanToolConfig = pricingPlan
|
|
||||||
? toolConfig.pricingPlanConfig?.[pricingPlan.slug]
|
|
||||||
: undefined
|
|
||||||
|
|
||||||
if (pricingPlan && pricingPlanToolConfig) {
|
|
||||||
assert(
|
|
||||||
pricingPlanToolConfig.enabled ||
|
|
||||||
(pricingPlanToolConfig.enabled === undefined && toolConfig.enabled),
|
|
||||||
403,
|
|
||||||
`Tool "${tool.name}" is not enabled for pricing plan "${pricingPlan.slug}"`
|
|
||||||
)
|
|
||||||
|
|
||||||
if (pricingPlanToolConfig.reportUsage !== undefined) {
|
|
||||||
reportUsage &&= !!pricingPlanToolConfig.reportUsage
|
|
||||||
}
|
|
||||||
|
|
||||||
if (pricingPlanToolConfig.rateLimit !== undefined) {
|
|
||||||
// TODO: Improve RateLimitInput vs RateLimit types
|
|
||||||
rateLimit = pricingPlanToolConfig.rateLimit as RateLimit
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
assert(toolConfig.enabled, 403, `Tool "${tool.name}" is not enabled`)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
ctx.set('reportUsage', reportUsage)
|
|
||||||
|
|
||||||
if (rateLimit) {
|
|
||||||
await enforceRateLimit(ctx, {
|
|
||||||
id: consumer?.id ?? ip,
|
|
||||||
interval: rateLimit.interval,
|
|
||||||
maxPerInterval: rateLimit.maxPerInterval
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
const { originAdapter } = deployment
|
|
||||||
let toolCallArgs: ToolCallArgs | undefined
|
|
||||||
let originRequest: Request | undefined
|
|
||||||
let originMcpClient: DurableObjectStub<DurableMcpClient> | undefined
|
|
||||||
let originMcpRequestMetadata: AgenticMcpRequestMetadata | undefined
|
|
||||||
|
|
||||||
if (originAdapter.type === 'raw') {
|
|
||||||
const originRequestUrl = `${deployment.originUrl}/${toolName}${requestUrl.search}`
|
|
||||||
originRequest = new Request(originRequestUrl, ctx.req.raw)
|
|
||||||
} else {
|
|
||||||
// Parse tool call args from the request body for both OpenAPI and MCP
|
|
||||||
// origin adapters.
|
|
||||||
toolCallArgs = await getToolArgsFromRequest(ctx, {
|
|
||||||
tool,
|
|
||||||
deployment
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
if (originAdapter.type === 'openapi') {
|
|
||||||
const operation = originAdapter.toolToOperationMap[tool.name]
|
|
||||||
assert(operation, 404, `Tool "${tool.name}" not found in OpenAPI spec`)
|
|
||||||
assert(toolCallArgs, 500)
|
|
||||||
|
|
||||||
originRequest = await createRequestForOpenAPIOperation({
|
|
||||||
request: ctx.req.raw,
|
|
||||||
toolCallArgs,
|
|
||||||
operation,
|
|
||||||
deployment
|
|
||||||
})
|
|
||||||
} else if (originAdapter.type === 'mcp') {
|
|
||||||
const sessionId = ctx.get('sessionId')
|
|
||||||
assert(sessionId, 500, 'Session ID is required for MCP origin requests')
|
|
||||||
|
|
||||||
const id: DurableObjectId = ctx.env.DO_MCP_CLIENT.idFromName(sessionId)
|
|
||||||
originMcpClient = ctx.env.DO_MCP_CLIENT.get(id)
|
|
||||||
|
|
||||||
await originMcpClient.init({
|
|
||||||
url: deployment.originUrl,
|
|
||||||
name: originAdapter.serverInfo.name,
|
|
||||||
version: originAdapter.serverInfo.version
|
|
||||||
})
|
|
||||||
|
|
||||||
const { projectIdentifier } = parseDeploymentIdentifier(
|
|
||||||
deployment.identifier,
|
|
||||||
{ errorStatusCode: 500 }
|
|
||||||
)
|
|
||||||
|
|
||||||
originMcpRequestMetadata = {
|
|
||||||
agenticProxySecret: deployment._secret,
|
|
||||||
sessionId,
|
|
||||||
ip,
|
|
||||||
isCustomerSubscriptionActive: !!consumer?.isStripeSubscriptionActive,
|
|
||||||
customerId: consumer?.id,
|
|
||||||
customerSubscriptionPlan: consumer?.plan,
|
|
||||||
customerSubscriptionStatus: consumer?.stripeStatus,
|
|
||||||
userId: consumer?.user.id,
|
|
||||||
userEmail: consumer?.user.email,
|
|
||||||
userUsername: consumer?.user.username,
|
|
||||||
userName: consumer?.user.name,
|
|
||||||
userCreatedAt: consumer?.user.createdAt,
|
|
||||||
userUpdatedAt: consumer?.user.updatedAt,
|
|
||||||
deploymentId: deployment.id,
|
|
||||||
deploymentIdentifier: deployment.identifier,
|
|
||||||
projectId: deployment.projectId,
|
|
||||||
projectIdentifier
|
|
||||||
} as AgenticMcpRequestMetadata
|
|
||||||
}
|
|
||||||
|
|
||||||
if (originRequest) {
|
|
||||||
logger.info('originRequestUrl', originRequest.url)
|
|
||||||
updateOriginRequest(originRequest, { consumer, deployment })
|
|
||||||
}
|
|
||||||
|
|
||||||
assert(ctx.get('sessionId'), 500, 'Internal error: sessionId should be set')
|
|
||||||
|
|
||||||
return {
|
|
||||||
deployment,
|
|
||||||
consumer,
|
|
||||||
tool,
|
|
||||||
pricingPlan,
|
|
||||||
toolCallArgs,
|
|
||||||
originRequest,
|
|
||||||
originMcpClient,
|
|
||||||
originMcpRequestMetadata
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -0,0 +1,240 @@
|
||||||
|
import type {
|
||||||
|
AdminDeployment,
|
||||||
|
PricingPlan,
|
||||||
|
RateLimit,
|
||||||
|
Tool
|
||||||
|
} from '@agentic/platform-types'
|
||||||
|
import { assert } from '@agentic/platform-core'
|
||||||
|
import { parseDeploymentIdentifier } from '@agentic/platform-validators'
|
||||||
|
|
||||||
|
import type { RawEnv } from './env'
|
||||||
|
import type {
|
||||||
|
AdminConsumer,
|
||||||
|
AgenticMcpRequestMetadata,
|
||||||
|
McpToolCallResponse,
|
||||||
|
ToolCallArgs
|
||||||
|
} from './types'
|
||||||
|
import { cfValidateJsonSchema } from './cf-validate-json-schema'
|
||||||
|
import { createRequestForOpenAPIOperation } from './create-request-for-openapi-operation'
|
||||||
|
import { enforceRateLimit } from './enforce-rate-limit'
|
||||||
|
import { fetchCache } from './fetch-cache'
|
||||||
|
import { getRequestCacheKey } from './get-request-cache-key'
|
||||||
|
import { updateOriginRequest } from './update-origin-request'
|
||||||
|
|
||||||
|
// type State = { counter: number }
|
||||||
|
|
||||||
|
export type ResolvedOriginToolCallResult = {
|
||||||
|
toolCallArgs: ToolCallArgs
|
||||||
|
originRequest?: Request
|
||||||
|
originResponse?: Response
|
||||||
|
toolCallResponse?: McpToolCallResponse
|
||||||
|
} & (
|
||||||
|
| {
|
||||||
|
originRequest: Request
|
||||||
|
originResponse: Response
|
||||||
|
toolCallResponse?: never
|
||||||
|
}
|
||||||
|
| {
|
||||||
|
originRequest?: never
|
||||||
|
originResponse?: never
|
||||||
|
toolCallResponse: McpToolCallResponse
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
export async function resolveOriginToolCall({
|
||||||
|
tool,
|
||||||
|
args,
|
||||||
|
deployment,
|
||||||
|
consumer,
|
||||||
|
pricingPlan,
|
||||||
|
sessionId,
|
||||||
|
env,
|
||||||
|
ip,
|
||||||
|
waitUntil
|
||||||
|
}: {
|
||||||
|
tool: Tool
|
||||||
|
args?: ToolCallArgs
|
||||||
|
deployment: AdminDeployment
|
||||||
|
consumer?: AdminConsumer
|
||||||
|
pricingPlan?: PricingPlan
|
||||||
|
sessionId: string
|
||||||
|
env: RawEnv
|
||||||
|
ip?: string
|
||||||
|
waitUntil: (promise: Promise<any>) => void
|
||||||
|
}): Promise<ResolvedOriginToolCallResult> {
|
||||||
|
// TODO: rate-limiting
|
||||||
|
// TODO: caching
|
||||||
|
// TODO: usage tracking / reporting
|
||||||
|
// TODO: all of this per-request logic should maybe be moved to a diff method
|
||||||
|
// since it's not specific to tool calls. eg, other MCP requests may still
|
||||||
|
// need to be rate-limited / cached / tracked / etc.
|
||||||
|
|
||||||
|
const { originAdapter } = deployment
|
||||||
|
let rateLimit: RateLimit | undefined | null
|
||||||
|
let reportUsage = true
|
||||||
|
|
||||||
|
// Resolve rate limit and whether to report `requests` usage based on the
|
||||||
|
// customer's pricing plan and deployment config.
|
||||||
|
if (pricingPlan) {
|
||||||
|
const requestsLineItem = pricingPlan.lineItems.find(
|
||||||
|
(lineItem) => lineItem.slug === 'requests'
|
||||||
|
)
|
||||||
|
|
||||||
|
if (requestsLineItem) {
|
||||||
|
assert(
|
||||||
|
requestsLineItem.slug === 'requests',
|
||||||
|
403,
|
||||||
|
`Invalid pricing plan "${pricingPlan.slug}" for project "${deployment.project}"`
|
||||||
|
)
|
||||||
|
|
||||||
|
rateLimit = requestsLineItem.rateLimit
|
||||||
|
} else {
|
||||||
|
// No `requests` line-item, so we don't report usage for this tool.
|
||||||
|
reportUsage = false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const toolConfig = deployment.toolConfigs.find(
|
||||||
|
(toolConfig) => toolConfig.name === tool.name
|
||||||
|
)
|
||||||
|
|
||||||
|
if (toolConfig) {
|
||||||
|
if (toolConfig.reportUsage !== undefined) {
|
||||||
|
reportUsage &&= !!toolConfig.reportUsage
|
||||||
|
}
|
||||||
|
|
||||||
|
if (toolConfig.rateLimit !== undefined) {
|
||||||
|
// TODO: Improve RateLimitInput vs RateLimit types
|
||||||
|
rateLimit = toolConfig.rateLimit as RateLimit
|
||||||
|
}
|
||||||
|
|
||||||
|
const pricingPlanToolConfig = pricingPlan
|
||||||
|
? toolConfig.pricingPlanConfig?.[pricingPlan.slug]
|
||||||
|
: undefined
|
||||||
|
|
||||||
|
if (pricingPlan && pricingPlanToolConfig) {
|
||||||
|
assert(
|
||||||
|
pricingPlanToolConfig.enabled ||
|
||||||
|
(pricingPlanToolConfig.enabled === undefined && toolConfig.enabled),
|
||||||
|
403,
|
||||||
|
`Tool "${tool.name}" is not enabled for pricing plan "${pricingPlan.slug}"`
|
||||||
|
)
|
||||||
|
|
||||||
|
if (pricingPlanToolConfig.reportUsage !== undefined) {
|
||||||
|
reportUsage &&= !!pricingPlanToolConfig.reportUsage
|
||||||
|
}
|
||||||
|
|
||||||
|
if (pricingPlanToolConfig.rateLimit !== undefined) {
|
||||||
|
// TODO: Improve RateLimitInput vs RateLimit types
|
||||||
|
rateLimit = pricingPlanToolConfig.rateLimit as RateLimit
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
assert(toolConfig.enabled, 403, `Tool "${tool.name}" is not enabled`)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (rateLimit) {
|
||||||
|
await enforceRateLimit({
|
||||||
|
id: consumer?.id ?? ip,
|
||||||
|
interval: rateLimit.interval,
|
||||||
|
maxPerInterval: rateLimit.maxPerInterval
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
if (originAdapter.type === 'raw') {
|
||||||
|
// TODO
|
||||||
|
assert(false, 500, 'Raw origin adapter not implemented')
|
||||||
|
} else {
|
||||||
|
// Validate incoming request params against the tool's input schema.
|
||||||
|
const toolCallArgs = cfValidateJsonSchema<Record<string, any>>({
|
||||||
|
schema: tool.inputSchema,
|
||||||
|
data: args,
|
||||||
|
errorPrefix: `Invalid request parameters for tool "${tool.name}"`,
|
||||||
|
strictAdditionalProperties: true
|
||||||
|
})
|
||||||
|
|
||||||
|
if (originAdapter.type === 'openapi') {
|
||||||
|
const operation = originAdapter.toolToOperationMap[tool.name]
|
||||||
|
assert(operation, 404, `Tool "${tool.name}" not found in OpenAPI spec`)
|
||||||
|
assert(toolCallArgs, 500)
|
||||||
|
|
||||||
|
const originRequest = await createRequestForOpenAPIOperation({
|
||||||
|
toolCallArgs,
|
||||||
|
operation,
|
||||||
|
deployment
|
||||||
|
})
|
||||||
|
|
||||||
|
updateOriginRequest(originRequest, { consumer, deployment })
|
||||||
|
|
||||||
|
const cacheKey = await getRequestCacheKey(originRequest)
|
||||||
|
|
||||||
|
// TODO: transform origin 5XX errors to 502 errors...
|
||||||
|
const originResponse = await fetchCache({
|
||||||
|
cacheKey,
|
||||||
|
fetchResponse: () => fetch(originRequest),
|
||||||
|
waitUntil
|
||||||
|
})
|
||||||
|
|
||||||
|
// non-cached version
|
||||||
|
// const originResponse = await fetch(originRequest)
|
||||||
|
|
||||||
|
return {
|
||||||
|
toolCallArgs,
|
||||||
|
originRequest,
|
||||||
|
originResponse
|
||||||
|
}
|
||||||
|
} else if (originAdapter.type === 'mcp') {
|
||||||
|
const id: DurableObjectId = env.DO_MCP_CLIENT.idFromName(sessionId)
|
||||||
|
const originMcpClient = env.DO_MCP_CLIENT.get(id)
|
||||||
|
|
||||||
|
await originMcpClient.init({
|
||||||
|
url: deployment.originUrl,
|
||||||
|
name: originAdapter.serverInfo.name,
|
||||||
|
version: originAdapter.serverInfo.version
|
||||||
|
})
|
||||||
|
|
||||||
|
const { projectIdentifier } = parseDeploymentIdentifier(
|
||||||
|
deployment.identifier,
|
||||||
|
{ errorStatusCode: 500 }
|
||||||
|
)
|
||||||
|
|
||||||
|
const originMcpRequestMetadata = {
|
||||||
|
agenticProxySecret: deployment._secret,
|
||||||
|
sessionId,
|
||||||
|
// ip,
|
||||||
|
isCustomerSubscriptionActive: !!consumer?.isStripeSubscriptionActive,
|
||||||
|
customerId: consumer?.id,
|
||||||
|
customerSubscriptionPlan: consumer?.plan,
|
||||||
|
customerSubscriptionStatus: consumer?.stripeStatus,
|
||||||
|
userId: consumer?.user.id,
|
||||||
|
userEmail: consumer?.user.email,
|
||||||
|
userUsername: consumer?.user.username,
|
||||||
|
userName: consumer?.user.name,
|
||||||
|
userCreatedAt: consumer?.user.createdAt,
|
||||||
|
userUpdatedAt: consumer?.user.updatedAt,
|
||||||
|
deploymentId: deployment.id,
|
||||||
|
deploymentIdentifier: deployment.identifier,
|
||||||
|
projectId: deployment.projectId,
|
||||||
|
projectIdentifier
|
||||||
|
} as AgenticMcpRequestMetadata
|
||||||
|
|
||||||
|
// TODO: add timeout support to the origin tool call?
|
||||||
|
// TODO: add response caching for origin MCP tool calls
|
||||||
|
const toolCallResponseString = await originMcpClient.callTool({
|
||||||
|
name: tool.name,
|
||||||
|
args: toolCallArgs,
|
||||||
|
metadata: originMcpRequestMetadata!
|
||||||
|
})
|
||||||
|
const toolCallResponse = JSON.parse(
|
||||||
|
toolCallResponseString
|
||||||
|
) as McpToolCallResponse
|
||||||
|
|
||||||
|
return {
|
||||||
|
toolCallArgs,
|
||||||
|
toolCallResponse
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
assert(false, 500)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -74,7 +74,7 @@ export async function transformHttpResponseToMcpToolCallResponse({
|
||||||
coerce: false,
|
coerce: false,
|
||||||
// TODO: double-check MCP schema on whether additional properties are allowed
|
// TODO: double-check MCP schema on whether additional properties are allowed
|
||||||
strictAdditionalProperties: true,
|
strictAdditionalProperties: true,
|
||||||
errorMessage: `Invalid tool response for tool "${tool.name}"`,
|
errorPrefix: `Invalid tool response for tool "${tool.name}"`,
|
||||||
errorStatusCode: 502
|
errorStatusCode: 502
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
|
@ -3,18 +3,11 @@ import type {
|
||||||
DefaultHonoBindings,
|
DefaultHonoBindings,
|
||||||
DefaultHonoVariables
|
DefaultHonoVariables
|
||||||
} from '@agentic/platform-hono'
|
} from '@agentic/platform-hono'
|
||||||
import type {
|
import type { Consumer, User } from '@agentic/platform-types'
|
||||||
AdminDeployment,
|
|
||||||
Consumer,
|
|
||||||
PricingPlan,
|
|
||||||
Tool,
|
|
||||||
User
|
|
||||||
} from '@agentic/platform-types'
|
|
||||||
import type { Client as McpClient } from '@modelcontextprotocol/sdk/client/index.js'
|
import type { Client as McpClient } from '@modelcontextprotocol/sdk/client/index.js'
|
||||||
import type { Context } from 'hono'
|
import type { Context } from 'hono'
|
||||||
import type { Simplify } from 'type-fest'
|
import type { Simplify } from 'type-fest'
|
||||||
|
|
||||||
import type { DurableMcpClient } from './durable-mcp-client'
|
|
||||||
import type { Env } from './env'
|
import type { Env } from './env'
|
||||||
|
|
||||||
export type McpToolCallResponse = Simplify<
|
export type McpToolCallResponse = Simplify<
|
||||||
|
@ -48,19 +41,6 @@ export type GatewayHonoContext = Context<GatewayHonoEnv>
|
||||||
// TODO: better type here
|
// TODO: better type here
|
||||||
export type ToolCallArgs = Record<string, any>
|
export type ToolCallArgs = Record<string, any>
|
||||||
|
|
||||||
export type ResolvedOriginRequest = {
|
|
||||||
deployment: AdminDeployment
|
|
||||||
tool: Tool
|
|
||||||
|
|
||||||
consumer?: AdminConsumer
|
|
||||||
pricingPlan?: PricingPlan
|
|
||||||
|
|
||||||
toolCallArgs?: ToolCallArgs
|
|
||||||
originRequest?: Request
|
|
||||||
originMcpClient?: DurableObjectStub<DurableMcpClient>
|
|
||||||
originMcpRequestMetadata?: AgenticMcpRequestMetadata
|
|
||||||
}
|
|
||||||
|
|
||||||
export type AgenticMcpRequestMetadata = {
|
export type AgenticMcpRequestMetadata = {
|
||||||
agenticProxySecret: string
|
agenticProxySecret: string
|
||||||
sessionId: string
|
sessionId: string
|
||||||
|
|
Ładowanie…
Reference in New Issue