kopia lustrzana https://github.com/bugout-dev/moonstream
Used Redis cache for CORS origins cache
rodzic
872c1f6e76
commit
aca575052b
|
|
@ -3,7 +3,7 @@ from enum import Enum
|
|||
from typing import Any, Dict, List, Optional
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel, Field, root_validator, validator
|
||||
from pydantic import AnyHttpUrl, BaseModel, Field, root_validator, validator
|
||||
from web3 import Web3
|
||||
|
||||
|
||||
|
|
@ -24,7 +24,7 @@ class NowResponse(BaseModel):
|
|||
|
||||
|
||||
class CORSResponse(BaseModel):
|
||||
cors: str
|
||||
cors: List[AnyHttpUrl] = Field(default_factory=list)
|
||||
|
||||
|
||||
class SignerListResponse(BaseModel):
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ from pydantic import AnyHttpUrl, parse_obj_as
|
|||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.middleware.cors import CORSMiddleware
|
||||
from starlette.responses import Response
|
||||
from starlette.types import ASGIApp, Receive, Scope, Send
|
||||
from starlette.types import ASGIApp
|
||||
from web3 import Web3
|
||||
|
||||
from .auth import (
|
||||
|
|
@ -18,7 +18,7 @@ from .auth import (
|
|||
MoonstreamAuthorizationVerificationError,
|
||||
verify,
|
||||
)
|
||||
from .rc import rc_client, yield_rc_async_session
|
||||
from .rc import REDIS_CONFIG_CORS_KEY, rc_client
|
||||
from .settings import (
|
||||
ALLOW_ORIGINS,
|
||||
BUGOUT_REQUEST_TIMEOUT_SECONDS,
|
||||
|
|
@ -299,8 +299,7 @@ def fetch_application_settings_cors_origins(token: str) -> Set[str]:
|
|||
|
||||
def set_cors_origins_cache(allow_origins: Set[str]) -> None:
|
||||
try:
|
||||
allow_origins_str = ",".join(list(allow_origins))
|
||||
rc_client.set("cors", allow_origins_str)
|
||||
rc_client.sadd(REDIS_CONFIG_CORS_KEY, *allow_origins)
|
||||
except Exception:
|
||||
logger.warning("Unable to set CORS origins at Redis cache")
|
||||
finally:
|
||||
|
|
@ -331,11 +330,11 @@ class BugoutCORSMiddleware(CORSMiddleware):
|
|||
expose_headers: Sequence[str] = (),
|
||||
max_age: int = 600,
|
||||
):
|
||||
self.allow_origins = fetch_and_set_cors_origins_cache()
|
||||
application_configs_allowed_origins = fetch_and_set_cors_origins_cache()
|
||||
|
||||
super().__init__(
|
||||
app=app,
|
||||
allow_origins=self.allow_origins,
|
||||
allow_origins=application_configs_allowed_origins,
|
||||
allow_methods=allow_methods,
|
||||
allow_headers=allow_headers,
|
||||
allow_credentials=allow_credentials,
|
||||
|
|
@ -344,25 +343,23 @@ class BugoutCORSMiddleware(CORSMiddleware):
|
|||
max_age=max_age,
|
||||
)
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send):
|
||||
def is_allowed_origin(self, origin: str) -> bool:
|
||||
if self.allow_all_origins:
|
||||
return True
|
||||
|
||||
if self.allow_origin_regex is not None and self.allow_origin_regex.fullmatch(
|
||||
origin
|
||||
):
|
||||
return True
|
||||
|
||||
try:
|
||||
async with yield_rc_async_session() as rc:
|
||||
corse_rc = await rc.get("cors")
|
||||
if corse_rc is not None:
|
||||
self.allow_origins = corse_rc.split(",")
|
||||
else:
|
||||
allow_origins = fetch_application_settings_cors_origins(
|
||||
token=MOONSTREAM_ADMIN_ACCESS_TOKEN
|
||||
)
|
||||
rc.set("cors", ",".join(allow_origins))
|
||||
self.allow_origins = list(allow_origins)
|
||||
is_allowed_origin = rc_client.sismember(REDIS_CONFIG_CORS_KEY, origin)
|
||||
return is_allowed_origin
|
||||
except Exception as err:
|
||||
logger.warning(
|
||||
f"Unable to get CORS origins from Redis cache, using default from environment variable, err: {str(err)}"
|
||||
f"Unable to fetch CORS origins from Redis cache, err: {str(err)}"
|
||||
)
|
||||
allow_origins = fetch_application_settings_cors_origins(
|
||||
token=MOONSTREAM_ADMIN_ACCESS_TOKEN
|
||||
)
|
||||
self.allow_origins = list(allow_origins)
|
||||
finally:
|
||||
rc_client.close()
|
||||
|
||||
await super().__call__(scope, receive, send)
|
||||
return origin in self.allow_origins
|
||||
|
|
|
|||
|
|
@ -6,6 +6,8 @@ from redis import asyncio as aioredis
|
|||
|
||||
from .settings import ENGINE_REDIS_URI
|
||||
|
||||
REDIS_CONFIG_CORS_KEY = "configs:cors:engineapi"
|
||||
|
||||
|
||||
def create_redis_client() -> Redis:
|
||||
rc_pool = ConnectionPool.from_url(
|
||||
|
|
|
|||
|
|
@ -71,7 +71,7 @@ app.add_middleware(
|
|||
@app.get("/cors", response_model=data.CORSResponse)
|
||||
async def get_cors(
|
||||
request: Request,
|
||||
):
|
||||
) -> data.CORSResponse:
|
||||
try:
|
||||
resources = bc.list_resources(
|
||||
token=request.state.token,
|
||||
|
|
@ -89,7 +89,7 @@ async def get_cors(
|
|||
logger.error(repr(err))
|
||||
raise EngineHTTPException(status_code=500)
|
||||
|
||||
return data.CORSResponse(cors=",".join(list(resource_origins_set)))
|
||||
return data.CORSResponse(cors=list(resource_origins_set))
|
||||
|
||||
|
||||
@app.put("/cors", response_model=data.CORSResponse)
|
||||
|
|
@ -97,7 +97,7 @@ async def update_cors(
|
|||
request: Request,
|
||||
background_tasks: BackgroundTasks,
|
||||
new_origins: List[AnyHttpUrl] = Body(...),
|
||||
):
|
||||
) -> data.CORSResponse:
|
||||
new_origins = set(new_origins)
|
||||
|
||||
try:
|
||||
|
|
@ -158,4 +158,4 @@ async def update_cors(
|
|||
fetch_and_set_cors_origins_cache,
|
||||
)
|
||||
|
||||
return data.CORSResponse(cors=",".join(target_resource.resource_data["origins"]))
|
||||
return data.CORSResponse(cors=target_resource.resource_data["origins"])
|
||||
|
|
|
|||
Ładowanie…
Reference in New Issue