kopia lustrzana https://github.com/bugout-dev/moonstream
				
				
				
			Used Redis cache for CORS origins cache
							rodzic
							
								
									872c1f6e76
								
							
						
					
					
						commit
						aca575052b
					
				|  | @ -3,7 +3,7 @@ from enum import Enum | ||||||
| from typing import Any, Dict, List, Optional | from typing import Any, Dict, List, Optional | ||||||
| from uuid import UUID | from uuid import UUID | ||||||
| 
 | 
 | ||||||
| from pydantic import BaseModel, Field, root_validator, validator | from pydantic import AnyHttpUrl, BaseModel, Field, root_validator, validator | ||||||
| from web3 import Web3 | from web3 import Web3 | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | @ -24,7 +24,7 @@ class NowResponse(BaseModel): | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class CORSResponse(BaseModel): | class CORSResponse(BaseModel): | ||||||
|     cors: str |     cors: List[AnyHttpUrl] = Field(default_factory=list) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class SignerListResponse(BaseModel): | class SignerListResponse(BaseModel): | ||||||
|  |  | ||||||
|  | @ -10,7 +10,7 @@ from pydantic import AnyHttpUrl, parse_obj_as | ||||||
| from starlette.middleware.base import BaseHTTPMiddleware | from starlette.middleware.base import BaseHTTPMiddleware | ||||||
| from starlette.middleware.cors import CORSMiddleware | from starlette.middleware.cors import CORSMiddleware | ||||||
| from starlette.responses import Response | from starlette.responses import Response | ||||||
| from starlette.types import ASGIApp, Receive, Scope, Send | from starlette.types import ASGIApp | ||||||
| from web3 import Web3 | from web3 import Web3 | ||||||
| 
 | 
 | ||||||
| from .auth import ( | from .auth import ( | ||||||
|  | @ -18,7 +18,7 @@ from .auth import ( | ||||||
|     MoonstreamAuthorizationVerificationError, |     MoonstreamAuthorizationVerificationError, | ||||||
|     verify, |     verify, | ||||||
| ) | ) | ||||||
| from .rc import rc_client, yield_rc_async_session | from .rc import REDIS_CONFIG_CORS_KEY, rc_client | ||||||
| from .settings import ( | from .settings import ( | ||||||
|     ALLOW_ORIGINS, |     ALLOW_ORIGINS, | ||||||
|     BUGOUT_REQUEST_TIMEOUT_SECONDS, |     BUGOUT_REQUEST_TIMEOUT_SECONDS, | ||||||
|  | @ -299,8 +299,7 @@ def fetch_application_settings_cors_origins(token: str) -> Set[str]: | ||||||
| 
 | 
 | ||||||
| def set_cors_origins_cache(allow_origins: Set[str]) -> None: | def set_cors_origins_cache(allow_origins: Set[str]) -> None: | ||||||
|     try: |     try: | ||||||
|         allow_origins_str = ",".join(list(allow_origins)) |         rc_client.sadd(REDIS_CONFIG_CORS_KEY, *allow_origins) | ||||||
|         rc_client.set("cors", allow_origins_str) |  | ||||||
|     except Exception: |     except Exception: | ||||||
|         logger.warning("Unable to set CORS origins at Redis cache") |         logger.warning("Unable to set CORS origins at Redis cache") | ||||||
|     finally: |     finally: | ||||||
|  | @ -331,11 +330,11 @@ class BugoutCORSMiddleware(CORSMiddleware): | ||||||
|         expose_headers: Sequence[str] = (), |         expose_headers: Sequence[str] = (), | ||||||
|         max_age: int = 600, |         max_age: int = 600, | ||||||
|     ): |     ): | ||||||
|         self.allow_origins = fetch_and_set_cors_origins_cache() |         application_configs_allowed_origins = fetch_and_set_cors_origins_cache() | ||||||
| 
 | 
 | ||||||
|         super().__init__( |         super().__init__( | ||||||
|             app=app, |             app=app, | ||||||
|             allow_origins=self.allow_origins, |             allow_origins=application_configs_allowed_origins, | ||||||
|             allow_methods=allow_methods, |             allow_methods=allow_methods, | ||||||
|             allow_headers=allow_headers, |             allow_headers=allow_headers, | ||||||
|             allow_credentials=allow_credentials, |             allow_credentials=allow_credentials, | ||||||
|  | @ -344,25 +343,23 @@ class BugoutCORSMiddleware(CORSMiddleware): | ||||||
|             max_age=max_age, |             max_age=max_age, | ||||||
|         ) |         ) | ||||||
| 
 | 
 | ||||||
|     async def __call__(self, scope: Scope, receive: Receive, send: Send): |     def is_allowed_origin(self, origin: str) -> bool: | ||||||
|  |         if self.allow_all_origins: | ||||||
|  |             return True | ||||||
|  | 
 | ||||||
|  |         if self.allow_origin_regex is not None and self.allow_origin_regex.fullmatch( | ||||||
|  |             origin | ||||||
|  |         ): | ||||||
|  |             return True | ||||||
|  | 
 | ||||||
|         try: |         try: | ||||||
|             async with yield_rc_async_session() as rc: |             is_allowed_origin = rc_client.sismember(REDIS_CONFIG_CORS_KEY, origin) | ||||||
|                 corse_rc = await rc.get("cors") |             return is_allowed_origin | ||||||
|                 if corse_rc is not None: |  | ||||||
|                     self.allow_origins = corse_rc.split(",") |  | ||||||
|                 else: |  | ||||||
|                     allow_origins = fetch_application_settings_cors_origins( |  | ||||||
|                         token=MOONSTREAM_ADMIN_ACCESS_TOKEN |  | ||||||
|                     ) |  | ||||||
|                     rc.set("cors", ",".join(allow_origins)) |  | ||||||
|                     self.allow_origins = list(allow_origins) |  | ||||||
|         except Exception as err: |         except Exception as err: | ||||||
|             logger.warning( |             logger.warning( | ||||||
|                 f"Unable to get CORS origins from Redis cache, using default from environment variable, err: {str(err)}" |                 f"Unable to fetch CORS origins from Redis cache, err: {str(err)}" | ||||||
|             ) |             ) | ||||||
|             allow_origins = fetch_application_settings_cors_origins( |         finally: | ||||||
|                 token=MOONSTREAM_ADMIN_ACCESS_TOKEN |             rc_client.close() | ||||||
|             ) |  | ||||||
|             self.allow_origins = list(allow_origins) |  | ||||||
| 
 | 
 | ||||||
|         await super().__call__(scope, receive, send) |         return origin in self.allow_origins | ||||||
|  |  | ||||||
|  | @ -6,6 +6,8 @@ from redis import asyncio as aioredis | ||||||
| 
 | 
 | ||||||
| from .settings import ENGINE_REDIS_URI | from .settings import ENGINE_REDIS_URI | ||||||
| 
 | 
 | ||||||
|  | REDIS_CONFIG_CORS_KEY = "configs:cors:engineapi" | ||||||
|  | 
 | ||||||
| 
 | 
 | ||||||
| def create_redis_client() -> Redis: | def create_redis_client() -> Redis: | ||||||
|     rc_pool = ConnectionPool.from_url( |     rc_pool = ConnectionPool.from_url( | ||||||
|  |  | ||||||
|  | @ -71,7 +71,7 @@ app.add_middleware( | ||||||
| @app.get("/cors", response_model=data.CORSResponse) | @app.get("/cors", response_model=data.CORSResponse) | ||||||
| async def get_cors( | async def get_cors( | ||||||
|     request: Request, |     request: Request, | ||||||
| ): | ) -> data.CORSResponse: | ||||||
|     try: |     try: | ||||||
|         resources = bc.list_resources( |         resources = bc.list_resources( | ||||||
|             token=request.state.token, |             token=request.state.token, | ||||||
|  | @ -89,7 +89,7 @@ async def get_cors( | ||||||
|         logger.error(repr(err)) |         logger.error(repr(err)) | ||||||
|         raise EngineHTTPException(status_code=500) |         raise EngineHTTPException(status_code=500) | ||||||
| 
 | 
 | ||||||
|     return data.CORSResponse(cors=",".join(list(resource_origins_set))) |     return data.CORSResponse(cors=list(resource_origins_set)) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @app.put("/cors", response_model=data.CORSResponse) | @app.put("/cors", response_model=data.CORSResponse) | ||||||
|  | @ -97,7 +97,7 @@ async def update_cors( | ||||||
|     request: Request, |     request: Request, | ||||||
|     background_tasks: BackgroundTasks, |     background_tasks: BackgroundTasks, | ||||||
|     new_origins: List[AnyHttpUrl] = Body(...), |     new_origins: List[AnyHttpUrl] = Body(...), | ||||||
| ): | ) -> data.CORSResponse: | ||||||
|     new_origins = set(new_origins) |     new_origins = set(new_origins) | ||||||
| 
 | 
 | ||||||
|     try: |     try: | ||||||
|  | @ -158,4 +158,4 @@ async def update_cors( | ||||||
|         fetch_and_set_cors_origins_cache, |         fetch_and_set_cors_origins_cache, | ||||||
|     ) |     ) | ||||||
| 
 | 
 | ||||||
|     return data.CORSResponse(cors=",".join(target_resource.resource_data["origins"])) |     return data.CORSResponse(cors=target_resource.resource_data["origins"]) | ||||||
|  |  | ||||||
		Ładowanie…
	
		Reference in New Issue
	
	 kompotkot
						kompotkot