Switched middleware to depends oauth2_scheme in metatx

pull/932/head
kompotkot 2023-12-07 15:04:16 +00:00
rodzic f3b4e3e502
commit 065ff03476
3 zmienionych plików z 69 dodań i 58 usunięć

Wyświetl plik

@ -1,7 +1,7 @@
"""Tx hash for call requests
Revision ID: 7191eb70e99e
Revises: 4f05d212ea49
Revises: 6d07739cb13e
Create Date: 2023-10-04 11:23:12.516797
"""
@ -11,7 +11,7 @@ import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '7191eb70e99e'
down_revision = '4f05d212ea49'
down_revision = '6d07739cb13e'
branch_labels = None
depends_on = None

Wyświetl plik

@ -8,7 +8,8 @@ from bugout.data import BugoutResource, BugoutResources, BugoutUser
from bugout.exceptions import BugoutResponseException
from eip712.messages import EIP712Message, _hash_eip191_message
from eth_account.messages import encode_defunct
from fastapi import Header, HTTPException, Request, Response
from fastapi import Depends, Header, HTTPException, Request, Response
from fastapi.security import OAuth2PasswordBearer
from hexbytes import HexBytes
from pydantic import AnyHttpUrl, parse_obj_as
from starlette.middleware.base import BaseHTTPMiddleware
@ -39,6 +40,8 @@ from .settings import bugout_client as bc
logger = logging.getLogger(__name__)
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
class InvalidAuthHeaderFormat(Exception):
"""
@ -82,7 +85,44 @@ def bugout_auth(token: str) -> BugoutUser:
return user
async def user_for_auth_header(
def brood_auth(token: UUID) -> BugoutUser:
try:
user: BugoutUser = bugout_auth(token=token)
except BugoutUnverifiedAuth:
logger.info(f"Attempted access by unverified Brood account: {user.id}")
raise EngineHTTPException(
status_code=403,
detail="Only verified accounts can have access",
)
except BugoutAuthWrongApp:
raise EngineHTTPException(
status_code=403,
detail="User does not belong to this application",
)
except BugoutResponseException as e:
raise EngineHTTPException(
status_code=e.status_code,
detail=e.detail,
)
except Exception as e:
logger.error(f"Error processing Brood response: {str(e)}")
raise EngineHTTPException(
status_code=500,
detail="Internal server error",
)
return user
async def request_user_auth(
token: UUID = Depends(oauth2_scheme),
) -> BugoutUser:
user = brood_auth(token=token)
return user
async def request_none_or_user_auth(
authorization: str = Header(None),
) -> Optional[BugoutUser]:
"""
@ -90,9 +130,9 @@ async def user_for_auth_header(
"""
user: Optional[BugoutUser] = None
if authorization is not None:
user_token: str = ""
token: str = ""
try:
_, user_token = parse_auth_header(auth_header=authorization)
_, token = parse_auth_header(auth_header=authorization)
except InvalidAuthHeaderFormat:
raise EngineHTTPException(
status_code=403, detail="Wrong authorization header"
@ -101,24 +141,8 @@ async def user_for_auth_header(
logger.error(f"Error parsing auth header: {str(e)}")
raise EngineHTTPException(status_code=500, detail="Internal server error")
if user_token != "":
try:
user: BugoutUser = bugout_auth(token=user_token)
except BugoutUnverifiedAuth:
logger.info(f"Attempted access by unverified Brood account: {user.id}")
raise EngineHTTPException(
status_code=403,
detail="Only verified accounts can have access",
)
except BugoutAuthWrongApp:
raise EngineHTTPException(
status_code=403, detail="User does not belong to this application"
)
except BugoutResponseException as e:
raise HTTPException(status_code=e.status_code, detail=e.detail)
except Exception as e:
logger.error(f"Error processing Brood response: {str(e)}")
raise HTTPException(status_code=500, detail="Internal server error")
if token != "":
user = brood_auth(token=token)
return user

Wyświetl plik

@ -16,11 +16,11 @@ from sqlalchemy.orm import Session
from .. import contracts_actions, data, db
from ..middleware import (
BroodAuthMiddleware,
BugoutCORSMiddleware,
EngineHTTPException,
metatx_sign_header,
user_for_auth_header,
request_none_or_user_auth,
request_user_auth,
)
from ..settings import DOCS_TARGET_PATH
from ..version import VERSION
@ -41,16 +41,6 @@ tags_metadata = [
]
whitelist_paths = {
"/metatx/openapi.json": "GET",
f"/metatx/{DOCS_TARGET_PATH}": "GET",
"/metatx/blockchains": "GET",
"/metatx/contracts/types": "GET",
"/metatx/requests/types": "GET",
"/metatx/requests": "GET", # Controls by custom authentication check
"/metatx/requests/complete": "POST", # Controls by metatx authentication check
}
app = FastAPI(
title=TITLE,
description=DESCRIPTION,
@ -61,9 +51,6 @@ app = FastAPI(
redoc_url=f"/{DOCS_TARGET_PATH}",
)
app.add_middleware(BroodAuthMiddleware, whitelist=whitelist_paths)
app.add_middleware(
BugoutCORSMiddleware,
allow_credentials=True,
@ -97,11 +84,11 @@ async def blockchains_route(
response_model=List[data.RegisteredContractResponse],
)
async def list_registered_contracts_route(
request: Request,
blockchain: Optional[str] = Query(None),
address: Optional[str] = Query(None),
limit: int = Query(10),
offset: Optional[int] = Query(None),
user: BugoutUser = Depends(request_user_auth),
db_session: Session = Depends(db.yield_db_read_only_session),
) -> List[data.RegisteredContractResponse]:
"""
@ -111,7 +98,7 @@ async def list_registered_contracts_route(
registered_contracts_with_blockchain = (
contracts_actions.lookup_registered_contracts(
db_session=db_session,
metatx_requester_id=request.state.user.id,
metatx_requester_id=user.id,
blockchain=blockchain,
address=address,
limit=limit,
@ -134,8 +121,8 @@ async def list_registered_contracts_route(
response_model=data.RegisteredContractResponse,
)
async def get_registered_contract_route(
request: Request,
contract_id: UUID = Path(...),
user: BugoutUser = Depends(request_user_auth),
db_session: Session = Depends(db.yield_db_read_only_session),
) -> List[data.RegisteredContractResponse]:
"""
@ -144,7 +131,7 @@ async def get_registered_contract_route(
try:
contract_with_blockchain = contracts_actions.get_registered_contract(
db_session=db_session,
metatx_requester_id=request.state.user.id,
metatx_requester_id=user.id,
contract_id=contract_id,
)
except NoResultFound:
@ -165,8 +152,8 @@ async def get_registered_contract_route(
"/contracts", tags=["contracts"], response_model=data.RegisteredContractResponse
)
async def register_contract_route(
request: Request,
contract: data.RegisterContractRequest = Body(...),
user: BugoutUser = Depends(request_user_auth),
db_session: Session = Depends(db.yield_db_session),
) -> data.RegisteredContractResponse:
"""
@ -175,7 +162,7 @@ async def register_contract_route(
try:
contract_with_blockchain = contracts_actions.register_contract(
db_session=db_session,
metatx_requester_id=request.state.user.id,
metatx_requester_id=user.id,
blockchain_name=contract.blockchain,
address=contract.address,
title=contract.title,
@ -206,15 +193,15 @@ async def register_contract_route(
response_model=data.RegisteredContractResponse,
)
async def update_contract_route(
request: Request,
contract_id: UUID = Path(...),
update_info: data.UpdateContractRequest = Body(...),
user: BugoutUser = Depends(request_user_auth),
db_session: Session = Depends(db.yield_db_session),
) -> data.RegisteredContractResponse:
try:
contract_with_blockchain = contracts_actions.update_registered_contract(
db_session=db_session,
metatx_requester_id=request.state.user.id,
metatx_requester_id=user.id,
contract_id=contract_id,
title=update_info.title,
description=update_info.description,
@ -241,8 +228,8 @@ async def update_contract_route(
response_model=data.RegisteredContractResponse,
)
async def delete_contract_route(
request: Request,
contract_id: UUID = Path(...),
user: BugoutUser = Depends(request_user_auth),
db_session: Session = Depends(db.yield_db_session),
) -> data.RegisteredContractResponse:
"""
@ -251,7 +238,7 @@ async def delete_contract_route(
try:
deleted_contract_with_blockchain = contracts_actions.delete_registered_contract(
db_session=db_session,
metatx_requester_id=request.state.user.id,
metatx_requester_id=user.id,
registered_contract_id=contract_id,
)
except Exception as err:
@ -299,7 +286,7 @@ async def list_requests_route(
offset: Optional[int] = Query(None),
show_expired: bool = Query(False),
show_before_live_at: bool = Query(False),
user: Optional[BugoutUser] = Depends(user_for_auth_header),
user: Optional[BugoutUser] = Depends(request_none_or_user_auth),
db_session: Session = Depends(db.yield_db_read_only_session),
) -> List[data.CallRequestResponse]:
"""
@ -334,6 +321,7 @@ async def list_requests_route(
)
async def get_request(
request_id: UUID = Path(...),
_: BugoutUser = Depends(request_user_auth),
db_session: Session = Depends(db.yield_db_read_only_session),
) -> List[data.CallRequestResponse]:
"""
@ -360,8 +348,8 @@ async def get_request(
@app.post("/requests", tags=["requests"], response_model=int)
async def create_requests(
request: Request,
data: data.CreateCallRequestsAPIRequest = Body(...),
user: BugoutUser = Depends(request_user_auth),
db_session: Session = Depends(db.yield_db_session),
) -> int:
"""
@ -372,7 +360,7 @@ async def create_requests(
try:
num_requests = contracts_actions.create_request_calls(
db_session=db_session,
metatx_requester_id=request.state.user.id,
metatx_requester_id=user.id,
registered_contract_id=data.contract_id,
contract_address=data.contract_address,
call_specs=data.specifications,
@ -413,8 +401,8 @@ async def create_requests(
@app.delete("/requests", tags=["requests"], response_model=int)
async def delete_requests(
request: Request,
request_ids: List[UUID] = Body(...),
user: BugoutUser = Depends(request_user_auth),
db_session: Session = Depends(db.yield_db_session),
) -> int:
"""
@ -423,7 +411,7 @@ async def delete_requests(
try:
deleted_requests = contracts_actions.delete_requests(
db_session=db_session,
metatx_requester_id=request.state.user.id,
metatx_requester_id=user.id,
request_ids=request_ids,
)
except Exception as err:
@ -433,11 +421,10 @@ async def delete_requests(
return deleted_requests
# @app.post("/requests/{request_id}/complete", tags=["requests"])
@app.post("/requests/complete", tags=["requests"])
@app.post("/requests/{request_id}/complete", tags=["requests"])
async def complete_call_request_route(
tx_hash: str = Form(...),
call_request_id: UUID = Form(...),
request_id: UUID = Path(...),
message=Depends(metatx_sign_header),
db_session: Session = Depends(db.yield_db_session),
):
@ -448,7 +435,7 @@ async def complete_call_request_route(
request = contracts_actions.complete_call_request(
db_session=db_session,
tx_hash=tx_hash,
call_request_id=call_request_id,
call_request_id=request_id,
caller=message["caller"],
)
except contracts_actions.CallRequestNotFound: