diff --git a/engineapi/alembic/versions/6d07739cb13e_live_at_for_metatx.py b/engineapi/alembic/versions/6d07739cb13e_live_at_for_metatx.py new file mode 100644 index 00000000..498705a4 --- /dev/null +++ b/engineapi/alembic/versions/6d07739cb13e_live_at_for_metatx.py @@ -0,0 +1,28 @@ +"""Live at for metatx + +Revision ID: 6d07739cb13e +Revises: 71e888082a6d +Create Date: 2023-12-06 14:33:04.814144 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '6d07739cb13e' +down_revision = '71e888082a6d' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.add_column('call_requests', sa.Column('live_at', sa.DateTime(timezone=True), nullable=True)) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column('call_requests', 'live_at') + # ### end Alembic commands ### diff --git a/engineapi/engineapi/contracts_actions.py b/engineapi/engineapi/contracts_actions.py index e5a632c1..e1c39e29 100644 --- a/engineapi/engineapi/contracts_actions.py +++ b/engineapi/engineapi/contracts_actions.py @@ -2,10 +2,10 @@ import argparse import json import logging import uuid -from datetime import timedelta +from datetime import datetime, timedelta from typing import Any, Dict, List, Optional, Tuple -from sqlalchemy import func, text +from sqlalchemy import func, or_, text from sqlalchemy.dialects.postgresql import insert from sqlalchemy.engine import Row from sqlalchemy.exc import IntegrityError, NoResultFound @@ -101,6 +101,7 @@ def parse_call_request_response( request_id=str(obj[0].request_id), parameters=obj[0].parameters, expires_at=obj[0].expires_at, + live_at=obj[0].live_at, created_at=obj[0].created_at, updated_at=obj[0].updated_at, ) @@ -326,13 +327,14 @@ def delete_registered_contract( return (registered_contract, blockchain) -def request_calls( +def create_request_calls( db_session: Session, metatx_requester_id: uuid.UUID, registered_contract_id: Optional[uuid.UUID], contract_address: Optional[str], call_specs: List[data.CallSpecification], ttl_days: Optional[int] = None, + live_at: Optional[int] = None, ) -> int: """ Batch creates call requests for the given registered contract. @@ -350,6 +352,11 @@ def request_calls( if ttl_days <= 0: raise ValueError("ttl_days must be positive") + if live_at is not None: + assert live_at == int(live_at) + if live_at <= 0: + raise ValueError("live_at must be positive") + # Check that the moonstream_user_id matches a RegisteredContract with the given id or address query = db_session.query(RegisteredContract).filter( RegisteredContract.metatx_requester_id == metatx_requester_id @@ -406,6 +413,7 @@ def request_calls( request_id=specification.request_id, parameters=specification.parameters, expires_at=expires_at, + live_at=datetime.fromtimestamp(live_at) if live_at is not None else None, ) db_session.add(request) @@ -422,7 +430,7 @@ def request_calls( return len(call_specs) -def get_call_requests( +def get_call_request( db_session: Session, request_id: uuid.UUID, ) -> Tuple[CallRequest, RegisteredContract]: @@ -472,9 +480,14 @@ def list_call_requests( limit: int = 10, offset: Optional[int] = None, show_expired: bool = False, + show_before_live_at: bool = False, + metatx_requester_id: Optional[uuid.UUID] = None, ) -> List[Row[Tuple[CallRequest, RegisteredContract, CallRequestType]]]: """ - List call requests for the given moonstream_user_id + List call requests. + + Argument moonstream_user_id took from authorization workflow. And if it is specified + then user has access to call_requests before live_at param. """ if caller is None: raise ValueError("caller must be specified") @@ -507,6 +520,21 @@ def list_call_requests( CallRequest.expires_at > func.now(), ) + # If user id not specified, do not show call_requests before live_at. + # Otherwise check show_before_live_at argument from query parameter + if metatx_requester_id is not None: + query = query.filter( + CallRequest.metatx_requester_id == metatx_requester_id, + ) + if not show_before_live_at: + query = query.filter( + or_(CallRequest.live_at < func.now(), CallRequest.live_at == None) + ) + else: + query = query.filter( + or_(CallRequest.live_at < func.now(), CallRequest.live_at == None) + ) + if offset is not None: query = query.offset(offset) @@ -633,7 +661,7 @@ def handle_request_calls(args: argparse.Namespace) -> None: try: with db.yield_db_session_ctx() as db_session: - request_calls( + create_request_calls( db_session=db_session, moonstream_user_id=args.moonstream_user_id, registered_contract_id=args.registered_contract_id, diff --git a/engineapi/engineapi/data.py b/engineapi/engineapi/data.py index 9830427e..e812df49 100644 --- a/engineapi/engineapi/data.py +++ b/engineapi/engineapi/data.py @@ -284,6 +284,7 @@ class CreateCallRequestsAPIRequest(BaseModel): contract_address: Optional[str] = None specifications: List[CallSpecification] = Field(default_factory=list) ttl_days: Optional[int] = None + live_at: Optional[int] = None # Solution found thanks to https://github.com/pydantic/pydantic/issues/506 @root_validator @@ -306,6 +307,7 @@ class CallRequestResponse(BaseModel): request_id: str parameters: Dict[str, Any] expires_at: Optional[datetime] = None + live_at: Optional[datetime] = None created_at: datetime updated_at: datetime diff --git a/engineapi/engineapi/middleware.py b/engineapi/engineapi/middleware.py index f2bb83b0..5986b417 100644 --- a/engineapi/engineapi/middleware.py +++ b/engineapi/engineapi/middleware.py @@ -6,8 +6,9 @@ from uuid import UUID from bugout.data import BugoutResource, BugoutResources, BugoutUser from bugout.exceptions import BugoutResponseException -from fastapi import HTTPException, Request, Response +from fastapi import Header, HTTPException, Request, Response from pydantic import AnyHttpUrl, parse_obj_as +from starlette.datastructures import Headers from starlette.middleware.base import BaseHTTPMiddleware from starlette.middleware.cors import CORSMiddleware from starlette.responses import Response @@ -26,14 +27,97 @@ from .settings import ( BUGOUT_REQUEST_TIMEOUT_SECONDS, BUGOUT_RESOURCE_TYPE_APPLICATION_CONFIG, MOONSTREAM_ADMIN_ACCESS_TOKEN, - MOONSTREAM_APPLICATION_ID, MOONSTREAM_ADMIN_ID, + MOONSTREAM_APPLICATION_ID, ) from .settings import bugout_client as bc logger = logging.getLogger(__name__) +class InvalidAuthHeaderFormat(Exception): + """ + Raised when authorization header not pass validation. + """ + + +class BugoutUnverifiedAuth(Exception): + """ + Raised when attempted access by unverified Brood account. + """ + + +class BugoutAuthWrongApp(Exception): + """ + Raised when user does not belong to this application. + """ + + +def parse_auth_header(auth_header: str) -> Tuple[str, str]: + """ + Returns: auth_format and user_token passed in authorization header. + """ + auth_list = auth_header.split() + if len(auth_list) != 2: + raise InvalidAuthHeaderFormat("Wrong authorization header") + + return auth_list[0], auth_list[1] + + +def bugout_auth(token: str) -> BugoutUser: + """ + Extended bugout.get_user with additional checks. + """ + user: BugoutUser = bc.get_user(token) + if not user.verified: + raise BugoutUnverifiedAuth("Only verified accounts can have access") + if str(user.application_id) != str(MOONSTREAM_APPLICATION_ID): + raise BugoutAuthWrongApp("User does not belong to this application") + + return user + + +async def user_for_auth_header( + authorization: str = Header(None), +) -> Optional[BugoutUser]: + """ + Fetch Bugout user if authorization token provided. + """ + user: Optional[BugoutUser] = None + if authorization is not None: + user_token: str = "" + try: + _, user_token = parse_auth_header(auth_header=authorization) + except InvalidAuthHeaderFormat: + raise EngineHTTPException( + status_code=403, detail="Wrong authorization header" + ) + except Exception as e: + logger.error(f"Error processing Brood response: {str(e)}") + raise EngineHTTPException(status_code=500, detail="Internal server error") + + if user_token != "": + try: + user: BugoutUser = bugout_auth(token=user_token) + except BugoutUnverifiedAuth: + logger.info(f"Attempted access by unverified Brood account: {user.id}") + raise EngineHTTPException( + status_code=403, + detail="Only verified accounts can have access", + ) + except BugoutAuthWrongApp: + raise EngineHTTPException( + status_code=403, detail="User does not belong to this application" + ) + except BugoutResponseException as e: + raise HTTPException(status_code=e.status_code, detail=e.detail) + except Exception as e: + logger.error(f"Error processing Brood response: {str(e)}") + raise HTTPException(status_code=500, detail="Internal server error") + + return user + + class BroodAuthMiddleware(BaseHTTPMiddleware): """ Checks the authorization header on the request. If it represents a verified Brood user, @@ -59,30 +143,33 @@ class BroodAuthMiddleware(BaseHTTPMiddleware): if path in self.whitelist.keys() and self.whitelist[path] == method: return await call_next(request) - authorization_header = request.headers.get("authorization") - if authorization_header is None: + authorization = request.headers.get("authorization") + if authorization is None: return Response( - status_code=403, content="No authorization header passed with request" + status_code=403, + content="No authorization header passed with request", ) - user_token_list = authorization_header.split() - if len(user_token_list) != 2: - return Response(status_code=403, content="Wrong authorization header") - user_token: str = user_token_list[-1] try: - user: BugoutUser = bc.get_user(user_token) - if not user.verified: - logger.info( - f"Attempted journal access by unverified Brood account: {user.id}" - ) - return Response( - status_code=403, - content="Only verified accounts can access journals", - ) - if str(user.application_id) != str(MOONSTREAM_APPLICATION_ID): - return Response( - status_code=403, content="User does not belong to this application" - ) + _, user_token = parse_auth_header(auth_header=authorization) + except InvalidAuthHeaderFormat: + return Response(status_code=403, content="Wrong authorization header") + except Exception as e: + logger.error(f"Error processing Brood response: {str(e)}") + return Response(status_code=500, content="Internal server error") + + try: + user: BugoutUser = bugout_auth(token=user_token) + except BugoutUnverifiedAuth: + logger.info(f"Attempted access by unverified Brood account: {user.id}") + return Response( + status_code=403, + content="Only verified accounts can have access", + ) + except BugoutAuthWrongApp: + return Response( + status_code=403, content="User does not belong to this application" + ) except BugoutResponseException as e: return Response(status_code=e.status_code, content=e.detail) except Exception as e: diff --git a/engineapi/engineapi/models.py b/engineapi/engineapi/models.py index 9e384f27..3a3d6909 100644 --- a/engineapi/engineapi/models.py +++ b/engineapi/engineapi/models.py @@ -317,6 +317,7 @@ class CallRequest(Base): parameters = Column(JSONB, nullable=False) expires_at = Column(DateTime(timezone=True), nullable=True, index=True) + live_at = Column(DateTime(timezone=True), nullable=True) created_at = Column( DateTime(timezone=True), server_default=utcnow(), nullable=False diff --git a/engineapi/engineapi/routes/metatx.py b/engineapi/engineapi/routes/metatx.py index 3af6bac9..4ac03a4d 100644 --- a/engineapi/engineapi/routes/metatx.py +++ b/engineapi/engineapi/routes/metatx.py @@ -9,12 +9,18 @@ import logging from typing import Dict, List, Optional from uuid import UUID +from bugout.data import BugoutUser from fastapi import Body, Depends, FastAPI, Path, Query, Request from sqlalchemy.exc import NoResultFound from sqlalchemy.orm import Session from .. import contracts_actions, data, db -from ..middleware import BroodAuthMiddleware, BugoutCORSMiddleware, EngineHTTPException +from ..middleware import ( + BroodAuthMiddleware, + BugoutCORSMiddleware, + EngineHTTPException, + user_for_auth_header, +) from ..settings import DOCS_TARGET_PATH from ..version import VERSION @@ -40,7 +46,7 @@ whitelist_paths = { "/metatx/blockchains": "GET", "/metatx/contracts/types": "GET", "/metatx/requests/types": "GET", - "/metatx/requests": "GET", + "/metatx/requests": "GET", # Controls by custom authentication check } app = FastAPI( @@ -278,14 +284,20 @@ async def call_request_types_route( return call_request_types -@app.get("/requests", tags=["requests"], response_model=List[data.CallRequestResponse]) +@app.get( + "/requests", + tags=["requests"], + response_model=List[data.CallRequestResponse], +) async def list_requests_route( contract_id: Optional[UUID] = Query(None), contract_address: Optional[str] = Query(None), caller: str = Query(...), limit: int = Query(100), offset: Optional[int] = Query(None), - show_expired: Optional[bool] = Query(False), + show_expired: bool = Query(False), + show_before_live_at: bool = Query(False), + user: Optional[BugoutUser] = Depends(user_for_auth_header), db_session: Session = Depends(db.yield_db_read_only_session), ) -> List[data.CallRequestResponse]: """ @@ -302,6 +314,8 @@ async def list_requests_route( limit=limit, offset=offset, show_expired=show_expired, + show_before_live_at=show_before_live_at, + metatx_requester_id=user.id if user is not None else None, ) except ValueError as e: logger.error(repr(e)) @@ -326,7 +340,7 @@ async def get_request( At least one of `contract_id` or `contract_address` must be provided as query parameters. """ try: - request = contracts_actions.get_call_requests( + request = contracts_actions.get_call_request( db_session=db_session, request_id=request_id, ) @@ -354,13 +368,14 @@ async def create_requests( At least one of `contract_id` or `contract_address` must be provided in the request body. """ try: - num_requests = contracts_actions.request_calls( + num_requests = contracts_actions.create_request_calls( db_session=db_session, metatx_requester_id=request.state.user.id, registered_contract_id=data.contract_id, contract_address=data.contract_address, call_specs=data.specifications, ttl_days=data.ttl_days, + live_at=data.live_at, ) except contracts_actions.InvalidAddressFormat as err: raise EngineHTTPException( diff --git a/engineapi/engineapi/version.txt b/engineapi/engineapi/version.txt index 5a5831ab..d169b2f2 100644 --- a/engineapi/engineapi/version.txt +++ b/engineapi/engineapi/version.txt @@ -1 +1 @@ -0.0.7 +0.0.8