Merge pull request #931 from moonstream-to/metatx-live-at

Field `live_at` for `call_requests`
pull/932/head^2
Sergei Sumarokov 2024-02-01 13:16:57 +03:00 zatwierdzone przez GitHub
commit 2e8ae13bb3
Nie znaleziono w bazie danych klucza dla tego podpisu
ID klucza GPG: B5690EEEBB952194
7 zmienionych plików z 196 dodań i 35 usunięć

Wyświetl plik

@ -0,0 +1,28 @@
"""Live at for metatx
Revision ID: 6d07739cb13e
Revises: 71e888082a6d
Create Date: 2023-12-06 14:33:04.814144
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '6d07739cb13e'
down_revision = '71e888082a6d'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.add_column('call_requests', sa.Column('live_at', sa.DateTime(timezone=True), nullable=True))
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column('call_requests', 'live_at')
# ### end Alembic commands ###

Wyświetl plik

@ -2,10 +2,10 @@ import argparse
import json import json
import logging import logging
import uuid import uuid
from datetime import timedelta from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional, Tuple from typing import Any, Dict, List, Optional, Tuple
from sqlalchemy import func, text from sqlalchemy import func, or_, text
from sqlalchemy.dialects.postgresql import insert from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.engine import Row from sqlalchemy.engine import Row
from sqlalchemy.exc import IntegrityError, NoResultFound from sqlalchemy.exc import IntegrityError, NoResultFound
@ -101,6 +101,7 @@ def parse_call_request_response(
request_id=str(obj[0].request_id), request_id=str(obj[0].request_id),
parameters=obj[0].parameters, parameters=obj[0].parameters,
expires_at=obj[0].expires_at, expires_at=obj[0].expires_at,
live_at=obj[0].live_at,
created_at=obj[0].created_at, created_at=obj[0].created_at,
updated_at=obj[0].updated_at, updated_at=obj[0].updated_at,
) )
@ -326,13 +327,14 @@ def delete_registered_contract(
return (registered_contract, blockchain) return (registered_contract, blockchain)
def request_calls( def create_request_calls(
db_session: Session, db_session: Session,
metatx_requester_id: uuid.UUID, metatx_requester_id: uuid.UUID,
registered_contract_id: Optional[uuid.UUID], registered_contract_id: Optional[uuid.UUID],
contract_address: Optional[str], contract_address: Optional[str],
call_specs: List[data.CallSpecification], call_specs: List[data.CallSpecification],
ttl_days: Optional[int] = None, ttl_days: Optional[int] = None,
live_at: Optional[int] = None,
) -> int: ) -> int:
""" """
Batch creates call requests for the given registered contract. Batch creates call requests for the given registered contract.
@ -350,6 +352,11 @@ def request_calls(
if ttl_days <= 0: if ttl_days <= 0:
raise ValueError("ttl_days must be positive") raise ValueError("ttl_days must be positive")
if live_at is not None:
assert live_at == int(live_at)
if live_at <= 0:
raise ValueError("live_at must be positive")
# Check that the moonstream_user_id matches a RegisteredContract with the given id or address # Check that the moonstream_user_id matches a RegisteredContract with the given id or address
query = db_session.query(RegisteredContract).filter( query = db_session.query(RegisteredContract).filter(
RegisteredContract.metatx_requester_id == metatx_requester_id RegisteredContract.metatx_requester_id == metatx_requester_id
@ -406,6 +413,7 @@ def request_calls(
request_id=specification.request_id, request_id=specification.request_id,
parameters=specification.parameters, parameters=specification.parameters,
expires_at=expires_at, expires_at=expires_at,
live_at=datetime.fromtimestamp(live_at) if live_at is not None else None,
) )
db_session.add(request) db_session.add(request)
@ -422,7 +430,7 @@ def request_calls(
return len(call_specs) return len(call_specs)
def get_call_requests( def get_call_request(
db_session: Session, db_session: Session,
request_id: uuid.UUID, request_id: uuid.UUID,
) -> Tuple[CallRequest, RegisteredContract]: ) -> Tuple[CallRequest, RegisteredContract]:
@ -472,9 +480,14 @@ def list_call_requests(
limit: int = 10, limit: int = 10,
offset: Optional[int] = None, offset: Optional[int] = None,
show_expired: bool = False, show_expired: bool = False,
show_before_live_at: bool = False,
metatx_requester_id: Optional[uuid.UUID] = None,
) -> List[Row[Tuple[CallRequest, RegisteredContract, CallRequestType]]]: ) -> List[Row[Tuple[CallRequest, RegisteredContract, CallRequestType]]]:
""" """
List call requests for the given moonstream_user_id List call requests.
Argument moonstream_user_id took from authorization workflow. And if it is specified
then user has access to call_requests before live_at param.
""" """
if caller is None: if caller is None:
raise ValueError("caller must be specified") raise ValueError("caller must be specified")
@ -507,6 +520,21 @@ def list_call_requests(
CallRequest.expires_at > func.now(), CallRequest.expires_at > func.now(),
) )
# If user id not specified, do not show call_requests before live_at.
# Otherwise check show_before_live_at argument from query parameter
if metatx_requester_id is not None:
query = query.filter(
CallRequest.metatx_requester_id == metatx_requester_id,
)
if not show_before_live_at:
query = query.filter(
or_(CallRequest.live_at < func.now(), CallRequest.live_at == None)
)
else:
query = query.filter(
or_(CallRequest.live_at < func.now(), CallRequest.live_at == None)
)
if offset is not None: if offset is not None:
query = query.offset(offset) query = query.offset(offset)
@ -633,7 +661,7 @@ def handle_request_calls(args: argparse.Namespace) -> None:
try: try:
with db.yield_db_session_ctx() as db_session: with db.yield_db_session_ctx() as db_session:
request_calls( create_request_calls(
db_session=db_session, db_session=db_session,
moonstream_user_id=args.moonstream_user_id, moonstream_user_id=args.moonstream_user_id,
registered_contract_id=args.registered_contract_id, registered_contract_id=args.registered_contract_id,

Wyświetl plik

@ -284,6 +284,7 @@ class CreateCallRequestsAPIRequest(BaseModel):
contract_address: Optional[str] = None contract_address: Optional[str] = None
specifications: List[CallSpecification] = Field(default_factory=list) specifications: List[CallSpecification] = Field(default_factory=list)
ttl_days: Optional[int] = None ttl_days: Optional[int] = None
live_at: Optional[int] = None
# Solution found thanks to https://github.com/pydantic/pydantic/issues/506 # Solution found thanks to https://github.com/pydantic/pydantic/issues/506
@root_validator @root_validator
@ -306,6 +307,7 @@ class CallRequestResponse(BaseModel):
request_id: str request_id: str
parameters: Dict[str, Any] parameters: Dict[str, Any]
expires_at: Optional[datetime] = None expires_at: Optional[datetime] = None
live_at: Optional[datetime] = None
created_at: datetime created_at: datetime
updated_at: datetime updated_at: datetime

Wyświetl plik

@ -6,8 +6,9 @@ from uuid import UUID
from bugout.data import BugoutResource, BugoutResources, BugoutUser from bugout.data import BugoutResource, BugoutResources, BugoutUser
from bugout.exceptions import BugoutResponseException from bugout.exceptions import BugoutResponseException
from fastapi import HTTPException, Request, Response from fastapi import Header, HTTPException, Request, Response
from pydantic import AnyHttpUrl, parse_obj_as from pydantic import AnyHttpUrl, parse_obj_as
from starlette.datastructures import Headers
from starlette.middleware.base import BaseHTTPMiddleware from starlette.middleware.base import BaseHTTPMiddleware
from starlette.middleware.cors import CORSMiddleware from starlette.middleware.cors import CORSMiddleware
from starlette.responses import Response from starlette.responses import Response
@ -26,14 +27,97 @@ from .settings import (
BUGOUT_REQUEST_TIMEOUT_SECONDS, BUGOUT_REQUEST_TIMEOUT_SECONDS,
BUGOUT_RESOURCE_TYPE_APPLICATION_CONFIG, BUGOUT_RESOURCE_TYPE_APPLICATION_CONFIG,
MOONSTREAM_ADMIN_ACCESS_TOKEN, MOONSTREAM_ADMIN_ACCESS_TOKEN,
MOONSTREAM_APPLICATION_ID,
MOONSTREAM_ADMIN_ID, MOONSTREAM_ADMIN_ID,
MOONSTREAM_APPLICATION_ID,
) )
from .settings import bugout_client as bc from .settings import bugout_client as bc
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class InvalidAuthHeaderFormat(Exception):
"""
Raised when authorization header not pass validation.
"""
class BugoutUnverifiedAuth(Exception):
"""
Raised when attempted access by unverified Brood account.
"""
class BugoutAuthWrongApp(Exception):
"""
Raised when user does not belong to this application.
"""
def parse_auth_header(auth_header: str) -> Tuple[str, str]:
"""
Returns: auth_format and user_token passed in authorization header.
"""
auth_list = auth_header.split()
if len(auth_list) != 2:
raise InvalidAuthHeaderFormat("Wrong authorization header")
return auth_list[0], auth_list[1]
def bugout_auth(token: str) -> BugoutUser:
"""
Extended bugout.get_user with additional checks.
"""
user: BugoutUser = bc.get_user(token)
if not user.verified:
raise BugoutUnverifiedAuth("Only verified accounts can have access")
if str(user.application_id) != str(MOONSTREAM_APPLICATION_ID):
raise BugoutAuthWrongApp("User does not belong to this application")
return user
async def user_for_auth_header(
authorization: str = Header(None),
) -> Optional[BugoutUser]:
"""
Fetch Bugout user if authorization token provided.
"""
user: Optional[BugoutUser] = None
if authorization is not None:
user_token: str = ""
try:
_, user_token = parse_auth_header(auth_header=authorization)
except InvalidAuthHeaderFormat:
raise EngineHTTPException(
status_code=403, detail="Wrong authorization header"
)
except Exception as e:
logger.error(f"Error processing Brood response: {str(e)}")
raise EngineHTTPException(status_code=500, detail="Internal server error")
if user_token != "":
try:
user: BugoutUser = bugout_auth(token=user_token)
except BugoutUnverifiedAuth:
logger.info(f"Attempted access by unverified Brood account: {user.id}")
raise EngineHTTPException(
status_code=403,
detail="Only verified accounts can have access",
)
except BugoutAuthWrongApp:
raise EngineHTTPException(
status_code=403, detail="User does not belong to this application"
)
except BugoutResponseException as e:
raise HTTPException(status_code=e.status_code, detail=e.detail)
except Exception as e:
logger.error(f"Error processing Brood response: {str(e)}")
raise HTTPException(status_code=500, detail="Internal server error")
return user
class BroodAuthMiddleware(BaseHTTPMiddleware): class BroodAuthMiddleware(BaseHTTPMiddleware):
""" """
Checks the authorization header on the request. If it represents a verified Brood user, Checks the authorization header on the request. If it represents a verified Brood user,
@ -59,30 +143,33 @@ class BroodAuthMiddleware(BaseHTTPMiddleware):
if path in self.whitelist.keys() and self.whitelist[path] == method: if path in self.whitelist.keys() and self.whitelist[path] == method:
return await call_next(request) return await call_next(request)
authorization_header = request.headers.get("authorization") authorization = request.headers.get("authorization")
if authorization_header is None: if authorization is None:
return Response( return Response(
status_code=403, content="No authorization header passed with request" status_code=403,
content="No authorization header passed with request",
) )
user_token_list = authorization_header.split()
if len(user_token_list) != 2:
return Response(status_code=403, content="Wrong authorization header")
user_token: str = user_token_list[-1]
try: try:
user: BugoutUser = bc.get_user(user_token) _, user_token = parse_auth_header(auth_header=authorization)
if not user.verified: except InvalidAuthHeaderFormat:
logger.info( return Response(status_code=403, content="Wrong authorization header")
f"Attempted journal access by unverified Brood account: {user.id}" except Exception as e:
) logger.error(f"Error processing Brood response: {str(e)}")
return Response( return Response(status_code=500, content="Internal server error")
status_code=403,
content="Only verified accounts can access journals", try:
) user: BugoutUser = bugout_auth(token=user_token)
if str(user.application_id) != str(MOONSTREAM_APPLICATION_ID): except BugoutUnverifiedAuth:
return Response( logger.info(f"Attempted access by unverified Brood account: {user.id}")
status_code=403, content="User does not belong to this application" return Response(
) status_code=403,
content="Only verified accounts can have access",
)
except BugoutAuthWrongApp:
return Response(
status_code=403, content="User does not belong to this application"
)
except BugoutResponseException as e: except BugoutResponseException as e:
return Response(status_code=e.status_code, content=e.detail) return Response(status_code=e.status_code, content=e.detail)
except Exception as e: except Exception as e:

Wyświetl plik

@ -317,6 +317,7 @@ class CallRequest(Base):
parameters = Column(JSONB, nullable=False) parameters = Column(JSONB, nullable=False)
expires_at = Column(DateTime(timezone=True), nullable=True, index=True) expires_at = Column(DateTime(timezone=True), nullable=True, index=True)
live_at = Column(DateTime(timezone=True), nullable=True)
created_at = Column( created_at = Column(
DateTime(timezone=True), server_default=utcnow(), nullable=False DateTime(timezone=True), server_default=utcnow(), nullable=False

Wyświetl plik

@ -9,12 +9,18 @@ import logging
from typing import Dict, List, Optional from typing import Dict, List, Optional
from uuid import UUID from uuid import UUID
from bugout.data import BugoutUser
from fastapi import Body, Depends, FastAPI, Path, Query, Request from fastapi import Body, Depends, FastAPI, Path, Query, Request
from sqlalchemy.exc import NoResultFound from sqlalchemy.exc import NoResultFound
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from .. import contracts_actions, data, db from .. import contracts_actions, data, db
from ..middleware import BroodAuthMiddleware, BugoutCORSMiddleware, EngineHTTPException from ..middleware import (
BroodAuthMiddleware,
BugoutCORSMiddleware,
EngineHTTPException,
user_for_auth_header,
)
from ..settings import DOCS_TARGET_PATH from ..settings import DOCS_TARGET_PATH
from ..version import VERSION from ..version import VERSION
@ -40,7 +46,7 @@ whitelist_paths = {
"/metatx/blockchains": "GET", "/metatx/blockchains": "GET",
"/metatx/contracts/types": "GET", "/metatx/contracts/types": "GET",
"/metatx/requests/types": "GET", "/metatx/requests/types": "GET",
"/metatx/requests": "GET", "/metatx/requests": "GET", # Controls by custom authentication check
} }
app = FastAPI( app = FastAPI(
@ -278,14 +284,20 @@ async def call_request_types_route(
return call_request_types return call_request_types
@app.get("/requests", tags=["requests"], response_model=List[data.CallRequestResponse]) @app.get(
"/requests",
tags=["requests"],
response_model=List[data.CallRequestResponse],
)
async def list_requests_route( async def list_requests_route(
contract_id: Optional[UUID] = Query(None), contract_id: Optional[UUID] = Query(None),
contract_address: Optional[str] = Query(None), contract_address: Optional[str] = Query(None),
caller: str = Query(...), caller: str = Query(...),
limit: int = Query(100), limit: int = Query(100),
offset: Optional[int] = Query(None), offset: Optional[int] = Query(None),
show_expired: Optional[bool] = Query(False), show_expired: bool = Query(False),
show_before_live_at: bool = Query(False),
user: Optional[BugoutUser] = Depends(user_for_auth_header),
db_session: Session = Depends(db.yield_db_read_only_session), db_session: Session = Depends(db.yield_db_read_only_session),
) -> List[data.CallRequestResponse]: ) -> List[data.CallRequestResponse]:
""" """
@ -302,6 +314,8 @@ async def list_requests_route(
limit=limit, limit=limit,
offset=offset, offset=offset,
show_expired=show_expired, show_expired=show_expired,
show_before_live_at=show_before_live_at,
metatx_requester_id=user.id if user is not None else None,
) )
except ValueError as e: except ValueError as e:
logger.error(repr(e)) logger.error(repr(e))
@ -326,7 +340,7 @@ async def get_request(
At least one of `contract_id` or `contract_address` must be provided as query parameters. At least one of `contract_id` or `contract_address` must be provided as query parameters.
""" """
try: try:
request = contracts_actions.get_call_requests( request = contracts_actions.get_call_request(
db_session=db_session, db_session=db_session,
request_id=request_id, request_id=request_id,
) )
@ -354,13 +368,14 @@ async def create_requests(
At least one of `contract_id` or `contract_address` must be provided in the request body. At least one of `contract_id` or `contract_address` must be provided in the request body.
""" """
try: try:
num_requests = contracts_actions.request_calls( num_requests = contracts_actions.create_request_calls(
db_session=db_session, db_session=db_session,
metatx_requester_id=request.state.user.id, metatx_requester_id=request.state.user.id,
registered_contract_id=data.contract_id, registered_contract_id=data.contract_id,
contract_address=data.contract_address, contract_address=data.contract_address,
call_specs=data.specifications, call_specs=data.specifications,
ttl_days=data.ttl_days, ttl_days=data.ttl_days,
live_at=data.live_at,
) )
except contracts_actions.InvalidAddressFormat as err: except contracts_actions.InvalidAddressFormat as err:
raise EngineHTTPException( raise EngineHTTPException(

Wyświetl plik

@ -1 +1 @@
0.0.7 0.0.8