Used Redis cache for CORS origins cache

pull/821/head
kompotkot 2023-06-22 10:55:29 +00:00
rodzic 872c1f6e76
commit aca575052b
4 zmienionych plików z 28 dodań i 29 usunięć

Wyświetl plik

@ -3,7 +3,7 @@ from enum import Enum
from typing import Any, Dict, List, Optional
from uuid import UUID
from pydantic import BaseModel, Field, root_validator, validator
from pydantic import AnyHttpUrl, BaseModel, Field, root_validator, validator
from web3 import Web3
@ -24,7 +24,7 @@ class NowResponse(BaseModel):
class CORSResponse(BaseModel):
cors: str
cors: List[AnyHttpUrl] = Field(default_factory=list)
class SignerListResponse(BaseModel):

Wyświetl plik

@ -10,7 +10,7 @@ from pydantic import AnyHttpUrl, parse_obj_as
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.middleware.cors import CORSMiddleware
from starlette.responses import Response
from starlette.types import ASGIApp, Receive, Scope, Send
from starlette.types import ASGIApp
from web3 import Web3
from .auth import (
@ -18,7 +18,7 @@ from .auth import (
MoonstreamAuthorizationVerificationError,
verify,
)
from .rc import rc_client, yield_rc_async_session
from .rc import REDIS_CONFIG_CORS_KEY, rc_client
from .settings import (
ALLOW_ORIGINS,
BUGOUT_REQUEST_TIMEOUT_SECONDS,
@ -299,8 +299,7 @@ def fetch_application_settings_cors_origins(token: str) -> Set[str]:
def set_cors_origins_cache(allow_origins: Set[str]) -> None:
try:
allow_origins_str = ",".join(list(allow_origins))
rc_client.set("cors", allow_origins_str)
rc_client.sadd(REDIS_CONFIG_CORS_KEY, *allow_origins)
except Exception:
logger.warning("Unable to set CORS origins at Redis cache")
finally:
@ -331,11 +330,11 @@ class BugoutCORSMiddleware(CORSMiddleware):
expose_headers: Sequence[str] = (),
max_age: int = 600,
):
self.allow_origins = fetch_and_set_cors_origins_cache()
application_configs_allowed_origins = fetch_and_set_cors_origins_cache()
super().__init__(
app=app,
allow_origins=self.allow_origins,
allow_origins=application_configs_allowed_origins,
allow_methods=allow_methods,
allow_headers=allow_headers,
allow_credentials=allow_credentials,
@ -344,25 +343,23 @@ class BugoutCORSMiddleware(CORSMiddleware):
max_age=max_age,
)
async def __call__(self, scope: Scope, receive: Receive, send: Send):
def is_allowed_origin(self, origin: str) -> bool:
if self.allow_all_origins:
return True
if self.allow_origin_regex is not None and self.allow_origin_regex.fullmatch(
origin
):
return True
try:
async with yield_rc_async_session() as rc:
corse_rc = await rc.get("cors")
if corse_rc is not None:
self.allow_origins = corse_rc.split(",")
else:
allow_origins = fetch_application_settings_cors_origins(
token=MOONSTREAM_ADMIN_ACCESS_TOKEN
)
rc.set("cors", ",".join(allow_origins))
self.allow_origins = list(allow_origins)
is_allowed_origin = rc_client.sismember(REDIS_CONFIG_CORS_KEY, origin)
return is_allowed_origin
except Exception as err:
logger.warning(
f"Unable to get CORS origins from Redis cache, using default from environment variable, err: {str(err)}"
f"Unable to fetch CORS origins from Redis cache, err: {str(err)}"
)
allow_origins = fetch_application_settings_cors_origins(
token=MOONSTREAM_ADMIN_ACCESS_TOKEN
)
self.allow_origins = list(allow_origins)
finally:
rc_client.close()
await super().__call__(scope, receive, send)
return origin in self.allow_origins

Wyświetl plik

@ -6,6 +6,8 @@ from redis import asyncio as aioredis
from .settings import ENGINE_REDIS_URI
REDIS_CONFIG_CORS_KEY = "configs:cors:engineapi"
def create_redis_client() -> Redis:
rc_pool = ConnectionPool.from_url(

Wyświetl plik

@ -71,7 +71,7 @@ app.add_middleware(
@app.get("/cors", response_model=data.CORSResponse)
async def get_cors(
request: Request,
):
) -> data.CORSResponse:
try:
resources = bc.list_resources(
token=request.state.token,
@ -89,7 +89,7 @@ async def get_cors(
logger.error(repr(err))
raise EngineHTTPException(status_code=500)
return data.CORSResponse(cors=",".join(list(resource_origins_set)))
return data.CORSResponse(cors=list(resource_origins_set))
@app.put("/cors", response_model=data.CORSResponse)
@ -97,7 +97,7 @@ async def update_cors(
request: Request,
background_tasks: BackgroundTasks,
new_origins: List[AnyHttpUrl] = Body(...),
):
) -> data.CORSResponse:
new_origins = set(new_origins)
try:
@ -158,4 +158,4 @@ async def update_cors(
fetch_and_set_cors_origins_cache,
)
return data.CORSResponse(cors=",".join(target_resource.resource_data["origins"]))
return data.CORSResponse(cors=target_resource.resource_data["origins"])