Used Redis cache for CORS origins cache

pull/821/head
kompotkot 2023-06-22 10:55:29 +00:00
rodzic 872c1f6e76
commit aca575052b
4 zmienionych plików z 28 dodań i 29 usunięć

Wyświetl plik

@ -3,7 +3,7 @@ from enum import Enum
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
from uuid import UUID from uuid import UUID
from pydantic import BaseModel, Field, root_validator, validator from pydantic import AnyHttpUrl, BaseModel, Field, root_validator, validator
from web3 import Web3 from web3 import Web3
@ -24,7 +24,7 @@ class NowResponse(BaseModel):
class CORSResponse(BaseModel): class CORSResponse(BaseModel):
cors: str cors: List[AnyHttpUrl] = Field(default_factory=list)
class SignerListResponse(BaseModel): class SignerListResponse(BaseModel):

Wyświetl plik

@ -10,7 +10,7 @@ from pydantic import AnyHttpUrl, parse_obj_as
from starlette.middleware.base import BaseHTTPMiddleware from starlette.middleware.base import BaseHTTPMiddleware
from starlette.middleware.cors import CORSMiddleware from starlette.middleware.cors import CORSMiddleware
from starlette.responses import Response from starlette.responses import Response
from starlette.types import ASGIApp, Receive, Scope, Send from starlette.types import ASGIApp
from web3 import Web3 from web3 import Web3
from .auth import ( from .auth import (
@ -18,7 +18,7 @@ from .auth import (
MoonstreamAuthorizationVerificationError, MoonstreamAuthorizationVerificationError,
verify, verify,
) )
from .rc import rc_client, yield_rc_async_session from .rc import REDIS_CONFIG_CORS_KEY, rc_client
from .settings import ( from .settings import (
ALLOW_ORIGINS, ALLOW_ORIGINS,
BUGOUT_REQUEST_TIMEOUT_SECONDS, BUGOUT_REQUEST_TIMEOUT_SECONDS,
@ -299,8 +299,7 @@ def fetch_application_settings_cors_origins(token: str) -> Set[str]:
def set_cors_origins_cache(allow_origins: Set[str]) -> None: def set_cors_origins_cache(allow_origins: Set[str]) -> None:
try: try:
allow_origins_str = ",".join(list(allow_origins)) rc_client.sadd(REDIS_CONFIG_CORS_KEY, *allow_origins)
rc_client.set("cors", allow_origins_str)
except Exception: except Exception:
logger.warning("Unable to set CORS origins at Redis cache") logger.warning("Unable to set CORS origins at Redis cache")
finally: finally:
@ -331,11 +330,11 @@ class BugoutCORSMiddleware(CORSMiddleware):
expose_headers: Sequence[str] = (), expose_headers: Sequence[str] = (),
max_age: int = 600, max_age: int = 600,
): ):
self.allow_origins = fetch_and_set_cors_origins_cache() application_configs_allowed_origins = fetch_and_set_cors_origins_cache()
super().__init__( super().__init__(
app=app, app=app,
allow_origins=self.allow_origins, allow_origins=application_configs_allowed_origins,
allow_methods=allow_methods, allow_methods=allow_methods,
allow_headers=allow_headers, allow_headers=allow_headers,
allow_credentials=allow_credentials, allow_credentials=allow_credentials,
@ -344,25 +343,23 @@ class BugoutCORSMiddleware(CORSMiddleware):
max_age=max_age, max_age=max_age,
) )
async def __call__(self, scope: Scope, receive: Receive, send: Send): def is_allowed_origin(self, origin: str) -> bool:
if self.allow_all_origins:
return True
if self.allow_origin_regex is not None and self.allow_origin_regex.fullmatch(
origin
):
return True
try: try:
async with yield_rc_async_session() as rc: is_allowed_origin = rc_client.sismember(REDIS_CONFIG_CORS_KEY, origin)
corse_rc = await rc.get("cors") return is_allowed_origin
if corse_rc is not None:
self.allow_origins = corse_rc.split(",")
else:
allow_origins = fetch_application_settings_cors_origins(
token=MOONSTREAM_ADMIN_ACCESS_TOKEN
)
rc.set("cors", ",".join(allow_origins))
self.allow_origins = list(allow_origins)
except Exception as err: except Exception as err:
logger.warning( logger.warning(
f"Unable to get CORS origins from Redis cache, using default from environment variable, err: {str(err)}" f"Unable to fetch CORS origins from Redis cache, err: {str(err)}"
) )
allow_origins = fetch_application_settings_cors_origins( finally:
token=MOONSTREAM_ADMIN_ACCESS_TOKEN rc_client.close()
)
self.allow_origins = list(allow_origins)
await super().__call__(scope, receive, send) return origin in self.allow_origins

Wyświetl plik

@ -6,6 +6,8 @@ from redis import asyncio as aioredis
from .settings import ENGINE_REDIS_URI from .settings import ENGINE_REDIS_URI
REDIS_CONFIG_CORS_KEY = "configs:cors:engineapi"
def create_redis_client() -> Redis: def create_redis_client() -> Redis:
rc_pool = ConnectionPool.from_url( rc_pool = ConnectionPool.from_url(

Wyświetl plik

@ -71,7 +71,7 @@ app.add_middleware(
@app.get("/cors", response_model=data.CORSResponse) @app.get("/cors", response_model=data.CORSResponse)
async def get_cors( async def get_cors(
request: Request, request: Request,
): ) -> data.CORSResponse:
try: try:
resources = bc.list_resources( resources = bc.list_resources(
token=request.state.token, token=request.state.token,
@ -89,7 +89,7 @@ async def get_cors(
logger.error(repr(err)) logger.error(repr(err))
raise EngineHTTPException(status_code=500) raise EngineHTTPException(status_code=500)
return data.CORSResponse(cors=",".join(list(resource_origins_set))) return data.CORSResponse(cors=list(resource_origins_set))
@app.put("/cors", response_model=data.CORSResponse) @app.put("/cors", response_model=data.CORSResponse)
@ -97,7 +97,7 @@ async def update_cors(
request: Request, request: Request,
background_tasks: BackgroundTasks, background_tasks: BackgroundTasks,
new_origins: List[AnyHttpUrl] = Body(...), new_origins: List[AnyHttpUrl] = Body(...),
): ) -> data.CORSResponse:
new_origins = set(new_origins) new_origins = set(new_origins)
try: try:
@ -158,4 +158,4 @@ async def update_cors(
fetch_and_set_cors_origins_cache, fetch_and_set_cors_origins_cache,
) )
return data.CORSResponse(cors=",".join(target_resource.resource_data["origins"])) return data.CORSResponse(cors=target_resource.resource_data["origins"])