Modified CORS middleware with Redis cache workflow

pull/821/head
kompotkot 2023-06-17 21:04:05 +00:00
rodzic 7bbb0be6df
commit 360f4d8286
7 zmienionych plików z 239 dodań i 4 usunięć

Wyświetl plik

@ -23,6 +23,10 @@ class NowResponse(BaseModel):
epoch_time: float
class CORSResponse(BaseModel):
cors: str
class SignerListResponse(BaseModel):
instances: List[Any] = Field(default_factory=list)

Wyświetl plik

@ -1,12 +1,18 @@
import base64
import functools
import json
import logging
from typing import Any, Awaitable, Callable, Dict, Optional
from typing import Any, Awaitable, Callable, Dict, List, Optional, Sequence, Set, cast
from uuid import UUID
from bugout.data import BugoutUser
from bugout.data import BugoutResources, BugoutUser
from bugout.exceptions import BugoutResponseException
from fastapi import HTTPException, Request, Response
from pydantic import AnyHttpUrl, parse_obj_as
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.middleware.cors import CORSMiddleware
from starlette.responses import Response
from starlette.types import ASGIApp, Receive, Scope, Send
from web3 import Web3
from .auth import (
@ -14,7 +20,15 @@ from .auth import (
MoonstreamAuthorizationVerificationError,
verify,
)
from .settings import bugout_client as bc, MOONSTREAM_APPLICATION_ID
from .rc import rc_client, yield_rc_async_session
from .settings import (
ALLOW_ORIGINS,
BUGOUT_REQUEST_TIMEOUT_SECONDS,
BUGOUT_RESOURCE_TYPE_APPLICATION_CONFIG,
MOONSTREAM_ADMIN_ACCESS_TOKEN,
MOONSTREAM_APPLICATION_ID,
)
from .settings import bugout_client as bc
logger = logging.getLogger(__name__)
@ -199,3 +213,129 @@ class ExtractBearerTokenMiddleware(BaseHTTPMiddleware):
request.state.token = user_token
return await call_next(request)
def fetch_application_settings_cors(token: str):
"""
Fetch application config resources with CORS origins setting.
If there are no such resources create new one with default origins from environment variable.
"""
resources: BugoutResources
try:
resources = bc.list_resources(
token=token,
params={
"application_id": MOONSTREAM_APPLICATION_ID,
"type": BUGOUT_RESOURCE_TYPE_APPLICATION_CONFIG,
"setting": "cors",
},
timeout=BUGOUT_REQUEST_TIMEOUT_SECONDS,
)
except Exception as err:
raise Exception(
f"Error fetching bugout resources with CORS origins: {str(err)}"
)
if len(resources.resources) == 0:
moonstream_admin_user = bc.get_user(
token=MOONSTREAM_ADMIN_ACCESS_TOKEN,
)
resource = bc.create_resource(
token=MOONSTREAM_ADMIN_ACCESS_TOKEN,
application_id=MOONSTREAM_APPLICATION_ID,
resource_data={
"type": BUGOUT_RESOURCE_TYPE_APPLICATION_CONFIG,
"setting": "cors",
"user_id": str(moonstream_admin_user.id),
"cors": ALLOW_ORIGINS,
},
)
resources.resources.append(resource)
logger.info(
"Created resource with default CORS origins setting by moonstream admin user"
)
return resources
def parse_origins_from_resources(resources: BugoutResources) -> set:
if len(resources.resources) == 0:
return ALLOW_ORIGINS
resource_cors_origins = set()
for resource in resources.resources:
origins = resource.resource_data.get("cors", [])
for o in origins:
try:
parse_obj_as(AnyHttpUrl, o)
resource_cors_origins.add(o)
except Exception:
logger.info(f"Unable to parse origin: {o} as URL")
continue
for o in ALLOW_ORIGINS:
if o not in resource_cors_origins:
resource_cors_origins.add(o)
return resource_cors_origins
def initialize_origins() -> set:
allow_origins: set = set(ALLOW_ORIGINS)
try:
resources = fetch_application_settings_cors(token=MOONSTREAM_ADMIN_ACCESS_TOKEN)
resource_origins = parse_origins_from_resources(resources=resources)
try:
origins_str = ",".join(list(resource_origins))
rc_client.set("cors", origins_str)
except Exception:
logger.warning("Unable to set CORS origins at Redis cache")
finally:
rc_client.close()
allow_origins = resource_origins
except Exception as err:
logger.error(
f"Unable to get CORS origins from Brood resources application config, err: {str(err)}"
)
return allow_origins
class BugoutCORSMiddleware(CORSMiddleware):
"""
Modified CORSMiddleware from starlette.middleware.cors.py to work with Redis cache.
"""
def __init__(
self,
app: ASGIApp,
allow_methods: Sequence[str] = ("GET",),
allow_headers: Sequence[str] = (),
allow_credentials: bool = False,
expose_headers: Sequence[str] = (),
max_age: int = 600,
):
self.allow_origins = initialize_origins()
super().__init__(
app=app,
allow_origins=self.allow_origins,
allow_methods=allow_methods,
allow_headers=allow_headers,
allow_credentials=allow_credentials,
allow_origin_regex=None,
expose_headers=expose_headers,
max_age=max_age,
)
async def __call__(self, scope: Scope, receive: Receive, send: Send):
try:
allow_origins = initialize_origins()
self.allow_origins = allow_origins
except Exception:
logger.info(
"Unable to parse CORS configs, used default CORS origins by middleware"
)
await super().__call__(scope, receive, send)

Wyświetl plik

@ -0,0 +1,40 @@
import os
from contextlib import asynccontextmanager
from redis import ConnectionPool, Redis
from redis import asyncio as aioredis
from .settings import ENGINE_REDIS_URI
def create_redis_client() -> Redis:
rc_pool = ConnectionPool.from_url(
url=ENGINE_REDIS_URI,
max_connections=10,
decode_responses=True
)
return Redis(connection_pool=rc_pool)
rc_client = create_redis_client()
def create_async_redis_client() -> Redis:
rc_pool_async: ConnectionPool = aioredis.ConnectionPool.from_url(
url=ENGINE_REDIS_URI,
max_connections=10,
decode_responses=True
)
return aioredis.Redis(connection_pool=rc_pool_async)
rc_client_async = create_async_redis_client()
@asynccontextmanager
async def yield_rc_async_session():
try:
yield rc_client_async
finally:
await rc_client_async.close()

Wyświetl plik

@ -1,5 +1,6 @@
import os
import warnings
from typing import List
from web3 import Web3, HTTPProvider
from web3.middleware import geth_poa_middleware
@ -21,7 +22,12 @@ if RAW_ORIGINS is None:
raise ValueError(
"ENGINE_CORS_ALLOWED_ORIGINS environment variable must be set (comma-separated list of CORS allowed origins)"
)
ORIGINS = RAW_ORIGINS.split(",")
ALLOW_ORIGINS: List[str] = RAW_ORIGINS.split(",")
BUGOUT_RESOURCE_TYPE_APPLICATION_CONFIG = "application-config"
BUGOUT_REQUEST_TIMEOUT_SECONDS = 5
ENGINE_REDIS_URI = os.environ.get("ENGINE_REDIS_URI")
# Open API documentation path
DOCS_TARGET_PATH = os.environ.get("DOCS_TARGET_PATH", "docs")

Wyświetl plik

@ -0,0 +1,43 @@
import unittest
import uuid
from datetime import datetime
from bugout.data import BugoutResource, BugoutResources, BugoutUser
from pydantic import AnyHttpUrl, parse_obj_as
from .middleware import parse_origins_from_resources
from .settings import BUGOUT_RESOURCE_TYPE_APPLICATION_CONFIG
TEST_ALLOW_ORIGINS = ["http://localhost:3000", "http://localhost:4000", "wrong one"]
class TestInit(unittest.TestCase):
def setUp(self):
utc_now = datetime.utcnow()
self.resources: BugoutResources = BugoutResources(
resources=[
BugoutResource(
id=uuid.uuid4(),
application_id=str(uuid.uuid4()),
resource_data={
"type": BUGOUT_RESOURCE_TYPE_APPLICATION_CONFIG,
"setting": "cors",
"user_id": str(uuid.uuid4()),
"cors": TEST_ALLOW_ORIGINS,
},
created_at=utc_now,
updated_at=utc_now,
)
]
)
def test_parse_origins_from_resources(self):
cnt = 0
for o in TEST_ALLOW_ORIGINS:
try:
parse_obj_as(AnyHttpUrl, o)
cnt += 1
except Exception:
continue
cors_origins = parse_origins_from_resources(self.resources)
self.assertEqual(cnt, len(cors_origins))

Wyświetl plik

@ -9,6 +9,7 @@ export ENGINE_DB_URI="postgresql://<username>:<password>@<db_host>:<db_port>/<db
export ENGINE_DB_URI_READ_ONLY="postgresql://<username>:<password>@<db_host>:<db_port>/<db_name>"
export MOONSTREAM_ADMIN_ACCESS_TOKEN="<admin access token>"
export MOONSTREAM_APPLICATION_ID="<moonstream application id>"
export ENGINE_REDIS_URI="redis://localhost:6380"
# Web3 Provider URIs
export MOONSTREAM_ETHEREUM_WEB3_PROVIDER_URI="<JSON_RPC_API_URL>"

Wyświetl plik

@ -17,6 +17,7 @@ setup(
"eip712==0.1.0",
"eth-typing>=2.3.0",
"fastapi",
"redis",
"psycopg2-binary",
"pydantic",
"sqlalchemy",