From 1f3f28e8ff43bb41e23d724d4936d6a17e024d37 Mon Sep 17 00:00:00 2001 From: Andrew Godwin Date: Sun, 19 Feb 2023 11:37:02 -0700 Subject: [PATCH] Check scope on API endpoints --- api/decorators.py | 26 +++++++++++++++++++++++++- api/middleware.py | 2 ++ api/models/token.py | 9 +++++++++ api/views/accounts.py | 28 ++++++++++++++-------------- api/views/announcements.py | 6 +++--- api/views/media.py | 8 ++++---- api/views/notifications.py | 4 ++-- api/views/polls.py | 6 +++--- api/views/search.py | 4 ++-- api/views/statuses.py | 22 +++++++++++----------- api/views/timelines.py | 10 +++++----- api/views/trends.py | 8 ++++---- tests/api/test_apps.py | 13 +++++++++++++ tests/api/test_tokens.py | 11 +++++++++++ 14 files changed, 108 insertions(+), 49 deletions(-) create mode 100644 tests/api/test_apps.py create mode 100644 tests/api/test_tokens.py diff --git a/api/decorators.py b/api/decorators.py index 14215e6..4a93715 100644 --- a/api/decorators.py +++ b/api/decorators.py @@ -1,3 +1,4 @@ +from collections.abc import Callable from functools import wraps from django.http import JsonResponse @@ -13,10 +14,33 @@ def identity_required(function): def inner(request, *args, **kwargs): # They need an identity if not request.identity: - return JsonResponse({"error": "identity_token_required"}, status=400) + return JsonResponse({"error": "identity_token_required"}, status=401) return function(request, *args, **kwargs) # This is for the API only inner.csrf_exempt = True return inner + + +def scope_required(scope: str, requires_identity=True): + """ + Asserts that the token we're using has the provided scope + """ + + def decorator(function: Callable): + @wraps(function) + def inner(request, *args, **kwargs): + if not request.token: + return JsonResponse({"error": "identity_token_required"}, status=401) + # They need an identity + if not request.identity and requires_identity: + return JsonResponse({"error": "identity_token_required"}, status=401) + if not request.token.has_scope(scope): + return JsonResponse({"error": "out_of_scope_for_token"}, status=403) + return function(request, *args, **kwargs) + + inner.csrf_exempt = True # type:ignore + return inner + + return decorator diff --git a/api/middleware.py b/api/middleware.py index 0d55fb3..ae3c44a 100644 --- a/api/middleware.py +++ b/api/middleware.py @@ -14,6 +14,7 @@ class ApiTokenMiddleware: def __call__(self, request): auth_header = request.headers.get("authorization", None) + request.token = None if auth_header and auth_header.startswith("Bearer "): token_value = auth_header[7:] try: @@ -22,6 +23,7 @@ class ApiTokenMiddleware: return HttpResponse("Invalid Bearer token", status=400) request.user = token.user request.identity = token.identity + request.token = token request.session = None response = self.get_response(request) return response diff --git a/api/models/token.py b/api/models/token.py index bec8bb3..d070a01 100644 --- a/api/models/token.py +++ b/api/models/token.py @@ -36,3 +36,12 @@ class Token(models.Model): created = models.DateTimeField(auto_now_add=True) updated = models.DateTimeField(auto_now=True) revoked = models.DateTimeField(blank=True, null=True) + + def has_scope(self, scope: str): + """ + Returns if this token has the given scope. + It's a function so we can do mapping/reduction if needed + """ + # TODO: Support granular scopes the other way? + scope_prefix = scope.split(":")[0] + return (scope in self.scopes) or (scope_prefix in self.scopes) diff --git a/api/views/accounts.py b/api/views/accounts.py index 007ac8b..371aca0 100644 --- a/api/views/accounts.py +++ b/api/views/accounts.py @@ -8,7 +8,7 @@ from hatchway import ApiResponse, QueryOrBody, api_view from activities.models import Post from activities.services import SearchService from api import schemas -from api.decorators import identity_required +from api.decorators import scope_required from api.pagination import MastodonPaginator, PaginatingApiResponse, PaginationResult from core.models import Config from users.models import Identity @@ -16,13 +16,13 @@ from users.services import IdentityService from users.shortcuts import by_handle_or_404 -@identity_required +@scope_required("read") @api_view.get def verify_credentials(request) -> schemas.Account: return schemas.Account.from_identity(request.identity, source=True) -@identity_required +@scope_required("write") @api_view.patch def update_credentials( request, @@ -73,7 +73,7 @@ def update_credentials( return schemas.Account.from_identity(identity, source=True) -@identity_required +@scope_required("read") @api_view.get def account_relationships(request, id: list[str] | None) -> list[schemas.Relationship]: result = [] @@ -87,7 +87,7 @@ def account_relationships(request, id: list[str] | None) -> list[schemas.Relatio return result -@identity_required +@scope_required("read") @api_view.get def familiar_followers( request, id: list[str] | None @@ -114,7 +114,7 @@ def familiar_followers( return result -@identity_required +@scope_required("read") @api_view.get def accounts_search( request, @@ -146,8 +146,8 @@ def lookup(request: HttpRequest, acct: str) -> schemas.Account: return schemas.Account.from_identity(identity) +@scope_required("read:accounts") @api_view.get -@identity_required def account(request, id: str) -> schemas.Account: identity = get_object_or_404( Identity.objects.exclude(restriction=Identity.Restriction.blocked), @@ -156,8 +156,8 @@ def account(request, id: str) -> schemas.Account: return schemas.Account.from_identity(identity) +@scope_required("read:statuses") @api_view.get -@identity_required def account_statuses( request: HttpRequest, id: str, @@ -218,8 +218,8 @@ def account_statuses( ) +@scope_required("write:follows") @api_view.post -@identity_required def account_follow(request, id: str, reblogs: bool = True) -> schemas.Relationship: identity = get_object_or_404( Identity.objects.exclude(restriction=Identity.Restriction.blocked), pk=id @@ -229,8 +229,8 @@ def account_follow(request, id: str, reblogs: bool = True) -> schemas.Relationsh return schemas.Relationship.from_identity_pair(identity, request.identity) +@scope_required("write:follows") @api_view.post -@identity_required def account_unfollow(request, id: str) -> schemas.Relationship: identity = get_object_or_404( Identity.objects.exclude(restriction=Identity.Restriction.blocked), pk=id @@ -240,8 +240,8 @@ def account_unfollow(request, id: str) -> schemas.Relationship: return schemas.Relationship.from_identity_pair(identity, request.identity) +@scope_required("write:blocks") @api_view.post -@identity_required def account_block(request, id: str) -> schemas.Relationship: identity = get_object_or_404(Identity, pk=id) service = IdentityService(request.identity) @@ -249,8 +249,8 @@ def account_block(request, id: str) -> schemas.Relationship: return schemas.Relationship.from_identity_pair(identity, request.identity) +@scope_required("write:blocks") @api_view.post -@identity_required def account_unblock(request, id: str) -> schemas.Relationship: identity = get_object_or_404(Identity, pk=id) service = IdentityService(request.identity) @@ -258,7 +258,7 @@ def account_unblock(request, id: str) -> schemas.Relationship: return schemas.Relationship.from_identity_pair(identity, request.identity) -@identity_required +@scope_required("write:blocks") @api_view.post def account_mute( request, @@ -276,7 +276,7 @@ def account_mute( return schemas.Relationship.from_identity_pair(identity, request.identity) -@identity_required +@scope_required("write:blocks") @api_view.post def account_unmute(request, id: str) -> schemas.Relationship: identity = get_object_or_404(Identity, pk=id) diff --git a/api/views/announcements.py b/api/views/announcements.py index 90e3f42..27f1ab9 100644 --- a/api/views/announcements.py +++ b/api/views/announcements.py @@ -2,12 +2,12 @@ from django.shortcuts import get_object_or_404 from hatchway import api_view from api import schemas -from api.decorators import identity_required +from api.decorators import scope_required from users.models import Announcement from users.services import AnnouncementService -@identity_required +@scope_required("read:notifications") @api_view.get def announcement_list(request) -> list[schemas.Announcement]: return [ @@ -16,7 +16,7 @@ def announcement_list(request) -> list[schemas.Announcement]: ] -@identity_required +@scope_required("write:notifications") @api_view.post def announcement_dismiss(request, pk: str): announcement = get_object_or_404(Announcement, pk=pk) diff --git a/api/views/media.py b/api/views/media.py index 3ec9918..6ff1c3a 100644 --- a/api/views/media.py +++ b/api/views/media.py @@ -6,10 +6,10 @@ from activities.models import PostAttachment, PostAttachmentStates from api import schemas from core.files import blurhash_image, resize_image -from ..decorators import identity_required +from ..decorators import scope_required -@identity_required +@scope_required("write:media") @api_view.post def upload_media( request, @@ -47,7 +47,7 @@ def upload_media( return schemas.MediaAttachment.from_post_attachment(attachment) -@identity_required +@scope_required("read:media") @api_view.get def get_media( request, @@ -59,7 +59,7 @@ def get_media( return schemas.MediaAttachment.from_post_attachment(attachment) -@identity_required +@scope_required("write:media") @api_view.put def update_media( request, diff --git a/api/views/notifications.py b/api/views/notifications.py index fad897e..1320fe6 100644 --- a/api/views/notifications.py +++ b/api/views/notifications.py @@ -4,11 +4,11 @@ from hatchway import ApiResponse, api_view from activities.models import TimelineEvent from activities.services import TimelineService from api import schemas -from api.decorators import identity_required +from api.decorators import scope_required from api.pagination import MastodonPaginator, PaginatingApiResponse, PaginationResult -@identity_required +@scope_required("read:notifications") @api_view.get def notifications( request: HttpRequest, diff --git a/api/views/polls.py b/api/views/polls.py index 2658061..7fa58c9 100644 --- a/api/views/polls.py +++ b/api/views/polls.py @@ -3,21 +3,21 @@ from hatchway import Schema, api_view from activities.models import Post, PostInteraction from api import schemas -from api.decorators import identity_required +from api.decorators import scope_required class PostVoteSchema(Schema): choices: list[int] -@identity_required +@scope_required("read:statuses") @api_view.get def get_poll(request, id: str) -> schemas.Poll: post = get_object_or_404(Post, pk=id, type=Post.Types.question) return schemas.Poll.from_post(post, identity=request.identity) -@identity_required +@scope_required("write:statuses") @api_view.post def vote_poll(request, id: str, details: PostVoteSchema) -> schemas.Poll: post = get_object_or_404(Post, pk=id, type=Post.Types.question) diff --git a/api/views/search.py b/api/views/search.py index 853d990..a664af9 100644 --- a/api/views/search.py +++ b/api/views/search.py @@ -5,10 +5,10 @@ from hatchway import Field, api_view from activities.models import PostInteraction from activities.services.search import SearchService from api import schemas -from api.decorators import identity_required +from api.decorators import scope_required -@identity_required +@scope_required("read") @api_view.get def search( request, diff --git a/api/views/statuses.py b/api/views/statuses.py index c473594..5ce6e36 100644 --- a/api/views/statuses.py +++ b/api/views/statuses.py @@ -15,7 +15,7 @@ from activities.models import ( ) from activities.services import PostService from api import schemas -from api.decorators import identity_required +from api.decorators import scope_required from api.pagination import MastodonPaginator, PaginationResult from core.models import Config @@ -72,7 +72,7 @@ def post_for_id(request: HttpRequest, id: str) -> Post: return get_object_or_404(queryset, pk=id) -@identity_required +@scope_required("write:statuses") @api_view.post def post_status(request, details: PostStatusSchema) -> schemas.Status: # Check text length @@ -110,7 +110,7 @@ def post_status(request, details: PostStatusSchema) -> schemas.Status: return schemas.Status.from_post(post, identity=request.identity) -@identity_required +@scope_required("read:statuses") @api_view.get def status(request, id: str) -> schemas.Status: post = post_for_id(request, id) @@ -120,7 +120,7 @@ def status(request, id: str) -> schemas.Status: ) -@identity_required +@scope_required("write:statuses") @api_view.put def edit_status(request, id: str, details: EditStatusSchema) -> schemas.Status: post = post_for_id(request, id) @@ -138,7 +138,7 @@ def edit_status(request, id: str, details: EditStatusSchema) -> schemas.Status: return schemas.Status.from_post(post) -@identity_required +@scope_required("write:statuses") @api_view.delete def delete_status(request, id: str) -> schemas.Status: post = post_for_id(request, id) @@ -148,14 +148,14 @@ def delete_status(request, id: str) -> schemas.Status: return schemas.Status.from_post(post, identity=request.identity) -@identity_required +@scope_required("read:statuses") @api_view.get def status_source(request, id: str) -> schemas.StatusSource: post = post_for_id(request, id) return schemas.StatusSource.from_post(post) -@identity_required +@scope_required("read:statuses") @api_view.get def status_context(request, id: str) -> schemas.Context: post = post_for_id(request, id) @@ -180,7 +180,7 @@ def status_context(request, id: str) -> schemas.Context: ) -@identity_required +@scope_required("write:favourites") @api_view.post def favourite_status(request, id: str) -> schemas.Status: post = post_for_id(request, id) @@ -192,7 +192,7 @@ def favourite_status(request, id: str) -> schemas.Status: ) -@identity_required +@scope_required("write:favourites") @api_view.post def unfavourite_status(request, id: str) -> schemas.Status: post = post_for_id(request, id) @@ -245,7 +245,7 @@ def favourited_by( ) -@identity_required +@scope_required("write:favourites") @api_view.post def reblog_status(request, id: str) -> schemas.Status: post = post_for_id(request, id) @@ -257,7 +257,7 @@ def reblog_status(request, id: str) -> schemas.Status: ) -@identity_required +@scope_required("write:favourites") @api_view.post def unreblog_status(request, id: str) -> schemas.Status: post = post_for_id(request, id) diff --git a/api/views/timelines.py b/api/views/timelines.py index 1a1a5db..c254303 100644 --- a/api/views/timelines.py +++ b/api/views/timelines.py @@ -4,12 +4,12 @@ from hatchway import ApiError, ApiResponse, api_view from activities.models import Post from activities.services import TimelineService from api import schemas -from api.decorators import identity_required +from api.decorators import scope_required from api.pagination import MastodonPaginator, PaginatingApiResponse, PaginationResult from core.models import Config -@identity_required +@scope_required("read:statuses") @api_view.get def home( request: HttpRequest, @@ -86,7 +86,7 @@ def public( ) -@identity_required +@scope_required("read:statuses") @api_view.get def hashtag( request: HttpRequest, @@ -121,7 +121,7 @@ def hashtag( ) -@identity_required +@scope_required("read:conversations") @api_view.get def conversations( request: HttpRequest, @@ -134,7 +134,7 @@ def conversations( return [] -@identity_required +@scope_required("read:favourites") @api_view.get def favourites( request: HttpRequest, diff --git a/api/views/trends.py b/api/views/trends.py index 6d03168..c5a5f26 100644 --- a/api/views/trends.py +++ b/api/views/trends.py @@ -2,10 +2,10 @@ from django.http import HttpRequest from hatchway import api_view from api import schemas -from api.decorators import identity_required +from api.decorators import scope_required -@identity_required +@scope_required("read") @api_view.get def trends_tags( request: HttpRequest, @@ -16,7 +16,7 @@ def trends_tags( return [] -@identity_required +@scope_required("read") @api_view.get def trends_statuses( request: HttpRequest, @@ -27,7 +27,7 @@ def trends_statuses( return [] -@identity_required +@scope_required("read") @api_view.get def trends_links( request: HttpRequest, diff --git a/tests/api/test_apps.py b/tests/api/test_apps.py new file mode 100644 index 0000000..f4cfbb7 --- /dev/null +++ b/tests/api/test_apps.py @@ -0,0 +1,13 @@ +import pytest + + +@pytest.mark.django_db +def test_create(api_client): + """ + Tests creating an app + """ + response = api_client.post( + "/api/v1/apps", {"client_name": "test", "redirect_uris": ""} + ) + assert response.status_code == 200 + assert response.json()["name"] == "test" diff --git a/tests/api/test_tokens.py b/tests/api/test_tokens.py new file mode 100644 index 0000000..16f42d1 --- /dev/null +++ b/tests/api/test_tokens.py @@ -0,0 +1,11 @@ +import pytest + + +@pytest.mark.django_db +def test_has_scope(api_token): + """ + Tests has_scope on the Token model + """ + assert api_token.has_scope("read") + assert api_token.has_scope("read:statuses") + assert not api_token.has_scope("destroyearth")