Merge pull request #931 from moonstream-to/metatx-live-at

Field `live_at` for `call_requests`
pull/932/head^2
Sergei Sumarokov 2024-02-01 13:16:57 +03:00 zatwierdzone przez GitHub
commit 2e8ae13bb3
Nie znaleziono w bazie danych klucza dla tego podpisu
ID klucza GPG: B5690EEEBB952194
7 zmienionych plików z 196 dodań i 35 usunięć

Wyświetl plik

@ -0,0 +1,28 @@
"""Live at for metatx
Revision ID: 6d07739cb13e
Revises: 71e888082a6d
Create Date: 2023-12-06 14:33:04.814144
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '6d07739cb13e'
down_revision = '71e888082a6d'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.add_column('call_requests', sa.Column('live_at', sa.DateTime(timezone=True), nullable=True))
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column('call_requests', 'live_at')
# ### end Alembic commands ###

Wyświetl plik

@ -2,10 +2,10 @@ import argparse
import json
import logging
import uuid
from datetime import timedelta
from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional, Tuple
from sqlalchemy import func, text
from sqlalchemy import func, or_, text
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.engine import Row
from sqlalchemy.exc import IntegrityError, NoResultFound
@ -101,6 +101,7 @@ def parse_call_request_response(
request_id=str(obj[0].request_id),
parameters=obj[0].parameters,
expires_at=obj[0].expires_at,
live_at=obj[0].live_at,
created_at=obj[0].created_at,
updated_at=obj[0].updated_at,
)
@ -326,13 +327,14 @@ def delete_registered_contract(
return (registered_contract, blockchain)
def request_calls(
def create_request_calls(
db_session: Session,
metatx_requester_id: uuid.UUID,
registered_contract_id: Optional[uuid.UUID],
contract_address: Optional[str],
call_specs: List[data.CallSpecification],
ttl_days: Optional[int] = None,
live_at: Optional[int] = None,
) -> int:
"""
Batch creates call requests for the given registered contract.
@ -350,6 +352,11 @@ def request_calls(
if ttl_days <= 0:
raise ValueError("ttl_days must be positive")
if live_at is not None:
assert live_at == int(live_at)
if live_at <= 0:
raise ValueError("live_at must be positive")
# Check that the moonstream_user_id matches a RegisteredContract with the given id or address
query = db_session.query(RegisteredContract).filter(
RegisteredContract.metatx_requester_id == metatx_requester_id
@ -406,6 +413,7 @@ def request_calls(
request_id=specification.request_id,
parameters=specification.parameters,
expires_at=expires_at,
live_at=datetime.fromtimestamp(live_at) if live_at is not None else None,
)
db_session.add(request)
@ -422,7 +430,7 @@ def request_calls(
return len(call_specs)
def get_call_requests(
def get_call_request(
db_session: Session,
request_id: uuid.UUID,
) -> Tuple[CallRequest, RegisteredContract]:
@ -472,9 +480,14 @@ def list_call_requests(
limit: int = 10,
offset: Optional[int] = None,
show_expired: bool = False,
show_before_live_at: bool = False,
metatx_requester_id: Optional[uuid.UUID] = None,
) -> List[Row[Tuple[CallRequest, RegisteredContract, CallRequestType]]]:
"""
List call requests for the given moonstream_user_id
List call requests.
Argument moonstream_user_id took from authorization workflow. And if it is specified
then user has access to call_requests before live_at param.
"""
if caller is None:
raise ValueError("caller must be specified")
@ -507,6 +520,21 @@ def list_call_requests(
CallRequest.expires_at > func.now(),
)
# If user id not specified, do not show call_requests before live_at.
# Otherwise check show_before_live_at argument from query parameter
if metatx_requester_id is not None:
query = query.filter(
CallRequest.metatx_requester_id == metatx_requester_id,
)
if not show_before_live_at:
query = query.filter(
or_(CallRequest.live_at < func.now(), CallRequest.live_at == None)
)
else:
query = query.filter(
or_(CallRequest.live_at < func.now(), CallRequest.live_at == None)
)
if offset is not None:
query = query.offset(offset)
@ -633,7 +661,7 @@ def handle_request_calls(args: argparse.Namespace) -> None:
try:
with db.yield_db_session_ctx() as db_session:
request_calls(
create_request_calls(
db_session=db_session,
moonstream_user_id=args.moonstream_user_id,
registered_contract_id=args.registered_contract_id,

Wyświetl plik

@ -284,6 +284,7 @@ class CreateCallRequestsAPIRequest(BaseModel):
contract_address: Optional[str] = None
specifications: List[CallSpecification] = Field(default_factory=list)
ttl_days: Optional[int] = None
live_at: Optional[int] = None
# Solution found thanks to https://github.com/pydantic/pydantic/issues/506
@root_validator
@ -306,6 +307,7 @@ class CallRequestResponse(BaseModel):
request_id: str
parameters: Dict[str, Any]
expires_at: Optional[datetime] = None
live_at: Optional[datetime] = None
created_at: datetime
updated_at: datetime

Wyświetl plik

@ -6,8 +6,9 @@ from uuid import UUID
from bugout.data import BugoutResource, BugoutResources, BugoutUser
from bugout.exceptions import BugoutResponseException
from fastapi import HTTPException, Request, Response
from fastapi import Header, HTTPException, Request, Response
from pydantic import AnyHttpUrl, parse_obj_as
from starlette.datastructures import Headers
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.middleware.cors import CORSMiddleware
from starlette.responses import Response
@ -26,14 +27,97 @@ from .settings import (
BUGOUT_REQUEST_TIMEOUT_SECONDS,
BUGOUT_RESOURCE_TYPE_APPLICATION_CONFIG,
MOONSTREAM_ADMIN_ACCESS_TOKEN,
MOONSTREAM_APPLICATION_ID,
MOONSTREAM_ADMIN_ID,
MOONSTREAM_APPLICATION_ID,
)
from .settings import bugout_client as bc
logger = logging.getLogger(__name__)
class InvalidAuthHeaderFormat(Exception):
"""
Raised when authorization header not pass validation.
"""
class BugoutUnverifiedAuth(Exception):
"""
Raised when attempted access by unverified Brood account.
"""
class BugoutAuthWrongApp(Exception):
"""
Raised when user does not belong to this application.
"""
def parse_auth_header(auth_header: str) -> Tuple[str, str]:
"""
Returns: auth_format and user_token passed in authorization header.
"""
auth_list = auth_header.split()
if len(auth_list) != 2:
raise InvalidAuthHeaderFormat("Wrong authorization header")
return auth_list[0], auth_list[1]
def bugout_auth(token: str) -> BugoutUser:
"""
Extended bugout.get_user with additional checks.
"""
user: BugoutUser = bc.get_user(token)
if not user.verified:
raise BugoutUnverifiedAuth("Only verified accounts can have access")
if str(user.application_id) != str(MOONSTREAM_APPLICATION_ID):
raise BugoutAuthWrongApp("User does not belong to this application")
return user
async def user_for_auth_header(
authorization: str = Header(None),
) -> Optional[BugoutUser]:
"""
Fetch Bugout user if authorization token provided.
"""
user: Optional[BugoutUser] = None
if authorization is not None:
user_token: str = ""
try:
_, user_token = parse_auth_header(auth_header=authorization)
except InvalidAuthHeaderFormat:
raise EngineHTTPException(
status_code=403, detail="Wrong authorization header"
)
except Exception as e:
logger.error(f"Error processing Brood response: {str(e)}")
raise EngineHTTPException(status_code=500, detail="Internal server error")
if user_token != "":
try:
user: BugoutUser = bugout_auth(token=user_token)
except BugoutUnverifiedAuth:
logger.info(f"Attempted access by unverified Brood account: {user.id}")
raise EngineHTTPException(
status_code=403,
detail="Only verified accounts can have access",
)
except BugoutAuthWrongApp:
raise EngineHTTPException(
status_code=403, detail="User does not belong to this application"
)
except BugoutResponseException as e:
raise HTTPException(status_code=e.status_code, detail=e.detail)
except Exception as e:
logger.error(f"Error processing Brood response: {str(e)}")
raise HTTPException(status_code=500, detail="Internal server error")
return user
class BroodAuthMiddleware(BaseHTTPMiddleware):
"""
Checks the authorization header on the request. If it represents a verified Brood user,
@ -59,30 +143,33 @@ class BroodAuthMiddleware(BaseHTTPMiddleware):
if path in self.whitelist.keys() and self.whitelist[path] == method:
return await call_next(request)
authorization_header = request.headers.get("authorization")
if authorization_header is None:
authorization = request.headers.get("authorization")
if authorization is None:
return Response(
status_code=403, content="No authorization header passed with request"
status_code=403,
content="No authorization header passed with request",
)
user_token_list = authorization_header.split()
if len(user_token_list) != 2:
return Response(status_code=403, content="Wrong authorization header")
user_token: str = user_token_list[-1]
try:
user: BugoutUser = bc.get_user(user_token)
if not user.verified:
logger.info(
f"Attempted journal access by unverified Brood account: {user.id}"
)
return Response(
status_code=403,
content="Only verified accounts can access journals",
)
if str(user.application_id) != str(MOONSTREAM_APPLICATION_ID):
return Response(
status_code=403, content="User does not belong to this application"
)
_, user_token = parse_auth_header(auth_header=authorization)
except InvalidAuthHeaderFormat:
return Response(status_code=403, content="Wrong authorization header")
except Exception as e:
logger.error(f"Error processing Brood response: {str(e)}")
return Response(status_code=500, content="Internal server error")
try:
user: BugoutUser = bugout_auth(token=user_token)
except BugoutUnverifiedAuth:
logger.info(f"Attempted access by unverified Brood account: {user.id}")
return Response(
status_code=403,
content="Only verified accounts can have access",
)
except BugoutAuthWrongApp:
return Response(
status_code=403, content="User does not belong to this application"
)
except BugoutResponseException as e:
return Response(status_code=e.status_code, content=e.detail)
except Exception as e:

Wyświetl plik

@ -317,6 +317,7 @@ class CallRequest(Base):
parameters = Column(JSONB, nullable=False)
expires_at = Column(DateTime(timezone=True), nullable=True, index=True)
live_at = Column(DateTime(timezone=True), nullable=True)
created_at = Column(
DateTime(timezone=True), server_default=utcnow(), nullable=False

Wyświetl plik

@ -9,12 +9,18 @@ import logging
from typing import Dict, List, Optional
from uuid import UUID
from bugout.data import BugoutUser
from fastapi import Body, Depends, FastAPI, Path, Query, Request
from sqlalchemy.exc import NoResultFound
from sqlalchemy.orm import Session
from .. import contracts_actions, data, db
from ..middleware import BroodAuthMiddleware, BugoutCORSMiddleware, EngineHTTPException
from ..middleware import (
BroodAuthMiddleware,
BugoutCORSMiddleware,
EngineHTTPException,
user_for_auth_header,
)
from ..settings import DOCS_TARGET_PATH
from ..version import VERSION
@ -40,7 +46,7 @@ whitelist_paths = {
"/metatx/blockchains": "GET",
"/metatx/contracts/types": "GET",
"/metatx/requests/types": "GET",
"/metatx/requests": "GET",
"/metatx/requests": "GET", # Controls by custom authentication check
}
app = FastAPI(
@ -278,14 +284,20 @@ async def call_request_types_route(
return call_request_types
@app.get("/requests", tags=["requests"], response_model=List[data.CallRequestResponse])
@app.get(
"/requests",
tags=["requests"],
response_model=List[data.CallRequestResponse],
)
async def list_requests_route(
contract_id: Optional[UUID] = Query(None),
contract_address: Optional[str] = Query(None),
caller: str = Query(...),
limit: int = Query(100),
offset: Optional[int] = Query(None),
show_expired: Optional[bool] = Query(False),
show_expired: bool = Query(False),
show_before_live_at: bool = Query(False),
user: Optional[BugoutUser] = Depends(user_for_auth_header),
db_session: Session = Depends(db.yield_db_read_only_session),
) -> List[data.CallRequestResponse]:
"""
@ -302,6 +314,8 @@ async def list_requests_route(
limit=limit,
offset=offset,
show_expired=show_expired,
show_before_live_at=show_before_live_at,
metatx_requester_id=user.id if user is not None else None,
)
except ValueError as e:
logger.error(repr(e))
@ -326,7 +340,7 @@ async def get_request(
At least one of `contract_id` or `contract_address` must be provided as query parameters.
"""
try:
request = contracts_actions.get_call_requests(
request = contracts_actions.get_call_request(
db_session=db_session,
request_id=request_id,
)
@ -354,13 +368,14 @@ async def create_requests(
At least one of `contract_id` or `contract_address` must be provided in the request body.
"""
try:
num_requests = contracts_actions.request_calls(
num_requests = contracts_actions.create_request_calls(
db_session=db_session,
metatx_requester_id=request.state.user.id,
registered_contract_id=data.contract_id,
contract_address=data.contract_address,
call_specs=data.specifications,
ttl_days=data.ttl_days,
live_at=data.live_at,
)
except contracts_actions.InvalidAddressFormat as err:
raise EngineHTTPException(

Wyświetl plik

@ -1 +1 @@
0.0.7
0.0.8