kopia lustrzana https://github.com/bugout-dev/moonstream
Modified CORS middleware with Redis cache workflow
rodzic
7bbb0be6df
commit
360f4d8286
|
@ -23,6 +23,10 @@ class NowResponse(BaseModel):
|
|||
epoch_time: float
|
||||
|
||||
|
||||
class CORSResponse(BaseModel):
|
||||
cors: str
|
||||
|
||||
|
||||
class SignerListResponse(BaseModel):
|
||||
instances: List[Any] = Field(default_factory=list)
|
||||
|
||||
|
|
|
@ -1,12 +1,18 @@
|
|||
import base64
|
||||
import functools
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Awaitable, Callable, Dict, Optional
|
||||
from typing import Any, Awaitable, Callable, Dict, List, Optional, Sequence, Set, cast
|
||||
from uuid import UUID
|
||||
|
||||
from bugout.data import BugoutUser
|
||||
from bugout.data import BugoutResources, BugoutUser
|
||||
from bugout.exceptions import BugoutResponseException
|
||||
from fastapi import HTTPException, Request, Response
|
||||
from pydantic import AnyHttpUrl, parse_obj_as
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.middleware.cors import CORSMiddleware
|
||||
from starlette.responses import Response
|
||||
from starlette.types import ASGIApp, Receive, Scope, Send
|
||||
from web3 import Web3
|
||||
|
||||
from .auth import (
|
||||
|
@ -14,7 +20,15 @@ from .auth import (
|
|||
MoonstreamAuthorizationVerificationError,
|
||||
verify,
|
||||
)
|
||||
from .settings import bugout_client as bc, MOONSTREAM_APPLICATION_ID
|
||||
from .rc import rc_client, yield_rc_async_session
|
||||
from .settings import (
|
||||
ALLOW_ORIGINS,
|
||||
BUGOUT_REQUEST_TIMEOUT_SECONDS,
|
||||
BUGOUT_RESOURCE_TYPE_APPLICATION_CONFIG,
|
||||
MOONSTREAM_ADMIN_ACCESS_TOKEN,
|
||||
MOONSTREAM_APPLICATION_ID,
|
||||
)
|
||||
from .settings import bugout_client as bc
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -199,3 +213,129 @@ class ExtractBearerTokenMiddleware(BaseHTTPMiddleware):
|
|||
request.state.token = user_token
|
||||
|
||||
return await call_next(request)
|
||||
|
||||
|
||||
def fetch_application_settings_cors(token: str):
|
||||
"""
|
||||
Fetch application config resources with CORS origins setting.
|
||||
If there are no such resources create new one with default origins from environment variable.
|
||||
"""
|
||||
resources: BugoutResources
|
||||
try:
|
||||
resources = bc.list_resources(
|
||||
token=token,
|
||||
params={
|
||||
"application_id": MOONSTREAM_APPLICATION_ID,
|
||||
"type": BUGOUT_RESOURCE_TYPE_APPLICATION_CONFIG,
|
||||
"setting": "cors",
|
||||
},
|
||||
timeout=BUGOUT_REQUEST_TIMEOUT_SECONDS,
|
||||
)
|
||||
|
||||
except Exception as err:
|
||||
raise Exception(
|
||||
f"Error fetching bugout resources with CORS origins: {str(err)}"
|
||||
)
|
||||
|
||||
if len(resources.resources) == 0:
|
||||
moonstream_admin_user = bc.get_user(
|
||||
token=MOONSTREAM_ADMIN_ACCESS_TOKEN,
|
||||
)
|
||||
resource = bc.create_resource(
|
||||
token=MOONSTREAM_ADMIN_ACCESS_TOKEN,
|
||||
application_id=MOONSTREAM_APPLICATION_ID,
|
||||
resource_data={
|
||||
"type": BUGOUT_RESOURCE_TYPE_APPLICATION_CONFIG,
|
||||
"setting": "cors",
|
||||
"user_id": str(moonstream_admin_user.id),
|
||||
"cors": ALLOW_ORIGINS,
|
||||
},
|
||||
)
|
||||
resources.resources.append(resource)
|
||||
logger.info(
|
||||
"Created resource with default CORS origins setting by moonstream admin user"
|
||||
)
|
||||
|
||||
return resources
|
||||
|
||||
|
||||
def parse_origins_from_resources(resources: BugoutResources) -> set:
|
||||
if len(resources.resources) == 0:
|
||||
return ALLOW_ORIGINS
|
||||
|
||||
resource_cors_origins = set()
|
||||
for resource in resources.resources:
|
||||
origins = resource.resource_data.get("cors", [])
|
||||
for o in origins:
|
||||
try:
|
||||
parse_obj_as(AnyHttpUrl, o)
|
||||
resource_cors_origins.add(o)
|
||||
except Exception:
|
||||
logger.info(f"Unable to parse origin: {o} as URL")
|
||||
continue
|
||||
|
||||
for o in ALLOW_ORIGINS:
|
||||
if o not in resource_cors_origins:
|
||||
resource_cors_origins.add(o)
|
||||
|
||||
return resource_cors_origins
|
||||
|
||||
|
||||
def initialize_origins() -> set:
|
||||
allow_origins: set = set(ALLOW_ORIGINS)
|
||||
try:
|
||||
resources = fetch_application_settings_cors(token=MOONSTREAM_ADMIN_ACCESS_TOKEN)
|
||||
resource_origins = parse_origins_from_resources(resources=resources)
|
||||
try:
|
||||
origins_str = ",".join(list(resource_origins))
|
||||
rc_client.set("cors", origins_str)
|
||||
except Exception:
|
||||
logger.warning("Unable to set CORS origins at Redis cache")
|
||||
finally:
|
||||
rc_client.close()
|
||||
allow_origins = resource_origins
|
||||
except Exception as err:
|
||||
logger.error(
|
||||
f"Unable to get CORS origins from Brood resources application config, err: {str(err)}"
|
||||
)
|
||||
|
||||
return allow_origins
|
||||
|
||||
|
||||
class BugoutCORSMiddleware(CORSMiddleware):
|
||||
"""
|
||||
Modified CORSMiddleware from starlette.middleware.cors.py to work with Redis cache.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
app: ASGIApp,
|
||||
allow_methods: Sequence[str] = ("GET",),
|
||||
allow_headers: Sequence[str] = (),
|
||||
allow_credentials: bool = False,
|
||||
expose_headers: Sequence[str] = (),
|
||||
max_age: int = 600,
|
||||
):
|
||||
self.allow_origins = initialize_origins()
|
||||
|
||||
super().__init__(
|
||||
app=app,
|
||||
allow_origins=self.allow_origins,
|
||||
allow_methods=allow_methods,
|
||||
allow_headers=allow_headers,
|
||||
allow_credentials=allow_credentials,
|
||||
allow_origin_regex=None,
|
||||
expose_headers=expose_headers,
|
||||
max_age=max_age,
|
||||
)
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send):
|
||||
try:
|
||||
allow_origins = initialize_origins()
|
||||
self.allow_origins = allow_origins
|
||||
except Exception:
|
||||
logger.info(
|
||||
"Unable to parse CORS configs, used default CORS origins by middleware"
|
||||
)
|
||||
|
||||
await super().__call__(scope, receive, send)
|
||||
|
|
|
@ -0,0 +1,40 @@
|
|||
import os
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from redis import ConnectionPool, Redis
|
||||
from redis import asyncio as aioredis
|
||||
|
||||
from .settings import ENGINE_REDIS_URI
|
||||
|
||||
|
||||
def create_redis_client() -> Redis:
|
||||
rc_pool = ConnectionPool.from_url(
|
||||
url=ENGINE_REDIS_URI,
|
||||
max_connections=10,
|
||||
decode_responses=True
|
||||
)
|
||||
return Redis(connection_pool=rc_pool)
|
||||
|
||||
|
||||
rc_client = create_redis_client()
|
||||
|
||||
|
||||
def create_async_redis_client() -> Redis:
|
||||
rc_pool_async: ConnectionPool = aioredis.ConnectionPool.from_url(
|
||||
url=ENGINE_REDIS_URI,
|
||||
max_connections=10,
|
||||
decode_responses=True
|
||||
)
|
||||
|
||||
return aioredis.Redis(connection_pool=rc_pool_async)
|
||||
|
||||
|
||||
rc_client_async = create_async_redis_client()
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def yield_rc_async_session():
|
||||
try:
|
||||
yield rc_client_async
|
||||
finally:
|
||||
await rc_client_async.close()
|
|
@ -1,5 +1,6 @@
|
|||
import os
|
||||
import warnings
|
||||
from typing import List
|
||||
|
||||
from web3 import Web3, HTTPProvider
|
||||
from web3.middleware import geth_poa_middleware
|
||||
|
@ -21,7 +22,12 @@ if RAW_ORIGINS is None:
|
|||
raise ValueError(
|
||||
"ENGINE_CORS_ALLOWED_ORIGINS environment variable must be set (comma-separated list of CORS allowed origins)"
|
||||
)
|
||||
ORIGINS = RAW_ORIGINS.split(",")
|
||||
ALLOW_ORIGINS: List[str] = RAW_ORIGINS.split(",")
|
||||
|
||||
BUGOUT_RESOURCE_TYPE_APPLICATION_CONFIG = "application-config"
|
||||
BUGOUT_REQUEST_TIMEOUT_SECONDS = 5
|
||||
|
||||
ENGINE_REDIS_URI = os.environ.get("ENGINE_REDIS_URI")
|
||||
|
||||
# Open API documentation path
|
||||
DOCS_TARGET_PATH = os.environ.get("DOCS_TARGET_PATH", "docs")
|
||||
|
|
|
@ -0,0 +1,43 @@
|
|||
import unittest
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from bugout.data import BugoutResource, BugoutResources, BugoutUser
|
||||
from pydantic import AnyHttpUrl, parse_obj_as
|
||||
|
||||
from .middleware import parse_origins_from_resources
|
||||
from .settings import BUGOUT_RESOURCE_TYPE_APPLICATION_CONFIG
|
||||
|
||||
TEST_ALLOW_ORIGINS = ["http://localhost:3000", "http://localhost:4000", "wrong one"]
|
||||
|
||||
|
||||
class TestInit(unittest.TestCase):
|
||||
def setUp(self):
|
||||
utc_now = datetime.utcnow()
|
||||
self.resources: BugoutResources = BugoutResources(
|
||||
resources=[
|
||||
BugoutResource(
|
||||
id=uuid.uuid4(),
|
||||
application_id=str(uuid.uuid4()),
|
||||
resource_data={
|
||||
"type": BUGOUT_RESOURCE_TYPE_APPLICATION_CONFIG,
|
||||
"setting": "cors",
|
||||
"user_id": str(uuid.uuid4()),
|
||||
"cors": TEST_ALLOW_ORIGINS,
|
||||
},
|
||||
created_at=utc_now,
|
||||
updated_at=utc_now,
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
def test_parse_origins_from_resources(self):
|
||||
cnt = 0
|
||||
for o in TEST_ALLOW_ORIGINS:
|
||||
try:
|
||||
parse_obj_as(AnyHttpUrl, o)
|
||||
cnt += 1
|
||||
except Exception:
|
||||
continue
|
||||
cors_origins = parse_origins_from_resources(self.resources)
|
||||
self.assertEqual(cnt, len(cors_origins))
|
|
@ -9,6 +9,7 @@ export ENGINE_DB_URI="postgresql://<username>:<password>@<db_host>:<db_port>/<db
|
|||
export ENGINE_DB_URI_READ_ONLY="postgresql://<username>:<password>@<db_host>:<db_port>/<db_name>"
|
||||
export MOONSTREAM_ADMIN_ACCESS_TOKEN="<admin access token>"
|
||||
export MOONSTREAM_APPLICATION_ID="<moonstream application id>"
|
||||
export ENGINE_REDIS_URI="redis://localhost:6380"
|
||||
|
||||
# Web3 Provider URIs
|
||||
export MOONSTREAM_ETHEREUM_WEB3_PROVIDER_URI="<JSON_RPC_API_URL>"
|
||||
|
|
|
@ -17,6 +17,7 @@ setup(
|
|||
"eip712==0.1.0",
|
||||
"eth-typing>=2.3.0",
|
||||
"fastapi",
|
||||
"redis",
|
||||
"psycopg2-binary",
|
||||
"pydantic",
|
||||
"sqlalchemy",
|
||||
|
|
Ładowanie…
Reference in New Issue