diff --git a/engineapi/engineapi/data.py b/engineapi/engineapi/data.py index b005ca9d..c9f7e2a2 100644 --- a/engineapi/engineapi/data.py +++ b/engineapi/engineapi/data.py @@ -23,6 +23,10 @@ class NowResponse(BaseModel): epoch_time: float +class CORSResponse(BaseModel): + cors: str + + class SignerListResponse(BaseModel): instances: List[Any] = Field(default_factory=list) diff --git a/engineapi/engineapi/middleware.py b/engineapi/engineapi/middleware.py index 9e4cc570..d062d909 100644 --- a/engineapi/engineapi/middleware.py +++ b/engineapi/engineapi/middleware.py @@ -1,12 +1,18 @@ import base64 +import functools import json import logging -from typing import Any, Awaitable, Callable, Dict, Optional +from typing import Any, Awaitable, Callable, Dict, List, Optional, Sequence, Set, cast +from uuid import UUID -from bugout.data import BugoutUser +from bugout.data import BugoutResources, BugoutUser from bugout.exceptions import BugoutResponseException from fastapi import HTTPException, Request, Response +from pydantic import AnyHttpUrl, parse_obj_as from starlette.middleware.base import BaseHTTPMiddleware +from starlette.middleware.cors import CORSMiddleware +from starlette.responses import Response +from starlette.types import ASGIApp, Receive, Scope, Send from web3 import Web3 from .auth import ( @@ -14,7 +20,15 @@ from .auth import ( MoonstreamAuthorizationVerificationError, verify, ) -from .settings import bugout_client as bc, MOONSTREAM_APPLICATION_ID +from .rc import rc_client, yield_rc_async_session +from .settings import ( + ALLOW_ORIGINS, + BUGOUT_REQUEST_TIMEOUT_SECONDS, + BUGOUT_RESOURCE_TYPE_APPLICATION_CONFIG, + MOONSTREAM_ADMIN_ACCESS_TOKEN, + MOONSTREAM_APPLICATION_ID, +) +from .settings import bugout_client as bc logger = logging.getLogger(__name__) @@ -199,3 +213,129 @@ class ExtractBearerTokenMiddleware(BaseHTTPMiddleware): request.state.token = user_token return await call_next(request) + + +def fetch_application_settings_cors(token: str): + """ + Fetch application config resources with CORS origins setting. + If there are no such resources create new one with default origins from environment variable. + """ + resources: BugoutResources + try: + resources = bc.list_resources( + token=token, + params={ + "application_id": MOONSTREAM_APPLICATION_ID, + "type": BUGOUT_RESOURCE_TYPE_APPLICATION_CONFIG, + "setting": "cors", + }, + timeout=BUGOUT_REQUEST_TIMEOUT_SECONDS, + ) + + except Exception as err: + raise Exception( + f"Error fetching bugout resources with CORS origins: {str(err)}" + ) + + if len(resources.resources) == 0: + moonstream_admin_user = bc.get_user( + token=MOONSTREAM_ADMIN_ACCESS_TOKEN, + ) + resource = bc.create_resource( + token=MOONSTREAM_ADMIN_ACCESS_TOKEN, + application_id=MOONSTREAM_APPLICATION_ID, + resource_data={ + "type": BUGOUT_RESOURCE_TYPE_APPLICATION_CONFIG, + "setting": "cors", + "user_id": str(moonstream_admin_user.id), + "cors": ALLOW_ORIGINS, + }, + ) + resources.resources.append(resource) + logger.info( + "Created resource with default CORS origins setting by moonstream admin user" + ) + + return resources + + +def parse_origins_from_resources(resources: BugoutResources) -> set: + if len(resources.resources) == 0: + return ALLOW_ORIGINS + + resource_cors_origins = set() + for resource in resources.resources: + origins = resource.resource_data.get("cors", []) + for o in origins: + try: + parse_obj_as(AnyHttpUrl, o) + resource_cors_origins.add(o) + except Exception: + logger.info(f"Unable to parse origin: {o} as URL") + continue + + for o in ALLOW_ORIGINS: + if o not in resource_cors_origins: + resource_cors_origins.add(o) + + return resource_cors_origins + + +def initialize_origins() -> set: + allow_origins: set = set(ALLOW_ORIGINS) + try: + resources = fetch_application_settings_cors(token=MOONSTREAM_ADMIN_ACCESS_TOKEN) + resource_origins = parse_origins_from_resources(resources=resources) + try: + origins_str = ",".join(list(resource_origins)) + rc_client.set("cors", origins_str) + except Exception: + logger.warning("Unable to set CORS origins at Redis cache") + finally: + rc_client.close() + allow_origins = resource_origins + except Exception as err: + logger.error( + f"Unable to get CORS origins from Brood resources application config, err: {str(err)}" + ) + + return allow_origins + + +class BugoutCORSMiddleware(CORSMiddleware): + """ + Modified CORSMiddleware from starlette.middleware.cors.py to work with Redis cache. + """ + + def __init__( + self, + app: ASGIApp, + allow_methods: Sequence[str] = ("GET",), + allow_headers: Sequence[str] = (), + allow_credentials: bool = False, + expose_headers: Sequence[str] = (), + max_age: int = 600, + ): + self.allow_origins = initialize_origins() + + super().__init__( + app=app, + allow_origins=self.allow_origins, + allow_methods=allow_methods, + allow_headers=allow_headers, + allow_credentials=allow_credentials, + allow_origin_regex=None, + expose_headers=expose_headers, + max_age=max_age, + ) + + async def __call__(self, scope: Scope, receive: Receive, send: Send): + try: + allow_origins = initialize_origins() + self.allow_origins = allow_origins + except Exception: + logger.info( + "Unable to parse CORS configs, used default CORS origins by middleware" + ) + + await super().__call__(scope, receive, send) diff --git a/engineapi/engineapi/rc.py b/engineapi/engineapi/rc.py new file mode 100644 index 00000000..b8a4e362 --- /dev/null +++ b/engineapi/engineapi/rc.py @@ -0,0 +1,40 @@ +import os +from contextlib import asynccontextmanager + +from redis import ConnectionPool, Redis +from redis import asyncio as aioredis + +from .settings import ENGINE_REDIS_URI + + +def create_redis_client() -> Redis: + rc_pool = ConnectionPool.from_url( + url=ENGINE_REDIS_URI, + max_connections=10, + decode_responses=True + ) + return Redis(connection_pool=rc_pool) + + +rc_client = create_redis_client() + + +def create_async_redis_client() -> Redis: + rc_pool_async: ConnectionPool = aioredis.ConnectionPool.from_url( + url=ENGINE_REDIS_URI, + max_connections=10, + decode_responses=True + ) + + return aioredis.Redis(connection_pool=rc_pool_async) + + +rc_client_async = create_async_redis_client() + + +@asynccontextmanager +async def yield_rc_async_session(): + try: + yield rc_client_async + finally: + await rc_client_async.close() diff --git a/engineapi/engineapi/settings.py b/engineapi/engineapi/settings.py index be2de443..41e2e152 100644 --- a/engineapi/engineapi/settings.py +++ b/engineapi/engineapi/settings.py @@ -1,5 +1,6 @@ import os import warnings +from typing import List from web3 import Web3, HTTPProvider from web3.middleware import geth_poa_middleware @@ -21,7 +22,12 @@ if RAW_ORIGINS is None: raise ValueError( "ENGINE_CORS_ALLOWED_ORIGINS environment variable must be set (comma-separated list of CORS allowed origins)" ) -ORIGINS = RAW_ORIGINS.split(",") +ALLOW_ORIGINS: List[str] = RAW_ORIGINS.split(",") + +BUGOUT_RESOURCE_TYPE_APPLICATION_CONFIG = "application-config" +BUGOUT_REQUEST_TIMEOUT_SECONDS = 5 + +ENGINE_REDIS_URI = os.environ.get("ENGINE_REDIS_URI") # Open API documentation path DOCS_TARGET_PATH = os.environ.get("DOCS_TARGET_PATH", "docs") diff --git a/engineapi/engineapi/test_middleware.py b/engineapi/engineapi/test_middleware.py new file mode 100644 index 00000000..7483fd5c --- /dev/null +++ b/engineapi/engineapi/test_middleware.py @@ -0,0 +1,43 @@ +import unittest +import uuid +from datetime import datetime + +from bugout.data import BugoutResource, BugoutResources, BugoutUser +from pydantic import AnyHttpUrl, parse_obj_as + +from .middleware import parse_origins_from_resources +from .settings import BUGOUT_RESOURCE_TYPE_APPLICATION_CONFIG + +TEST_ALLOW_ORIGINS = ["http://localhost:3000", "http://localhost:4000", "wrong one"] + + +class TestInit(unittest.TestCase): + def setUp(self): + utc_now = datetime.utcnow() + self.resources: BugoutResources = BugoutResources( + resources=[ + BugoutResource( + id=uuid.uuid4(), + application_id=str(uuid.uuid4()), + resource_data={ + "type": BUGOUT_RESOURCE_TYPE_APPLICATION_CONFIG, + "setting": "cors", + "user_id": str(uuid.uuid4()), + "cors": TEST_ALLOW_ORIGINS, + }, + created_at=utc_now, + updated_at=utc_now, + ) + ] + ) + + def test_parse_origins_from_resources(self): + cnt = 0 + for o in TEST_ALLOW_ORIGINS: + try: + parse_obj_as(AnyHttpUrl, o) + cnt += 1 + except Exception: + continue + cors_origins = parse_origins_from_resources(self.resources) + self.assertEqual(cnt, len(cors_origins)) diff --git a/engineapi/sample.env b/engineapi/sample.env index 78a4b668..cfd716e3 100644 --- a/engineapi/sample.env +++ b/engineapi/sample.env @@ -9,6 +9,7 @@ export ENGINE_DB_URI="postgresql://:@:/=2.3.0", "fastapi", + "redis", "psycopg2-binary", "pydantic", "sqlalchemy",