diff --git a/api/urls.py b/api/urls.py index e3e541a..faff8aa 100644 --- a/api/urls.py +++ b/api/urls.py @@ -76,6 +76,7 @@ urlpatterns = [ path("v1/statuses//source", statuses.status_source), # Notifications path("v1/notifications", notifications.notifications), + path("v1/notifications/", notifications.get_notification), # Polls path("v1/polls/", polls.get_poll), path("v1/polls//votes", polls.vote_poll), diff --git a/api/views/notifications.py b/api/views/notifications.py index 1320fe6..1dcb924 100644 --- a/api/views/notifications.py +++ b/api/views/notifications.py @@ -1,4 +1,5 @@ from django.http import HttpRequest +from django.shortcuts import get_object_or_404 from hatchway import ApiResponse, api_view from activities.models import TimelineEvent @@ -7,6 +8,15 @@ from api import schemas from api.decorators import scope_required from api.pagination import MastodonPaginator, PaginatingApiResponse, PaginationResult +# Types/exclude_types use weird syntax so we have to handle them manually +NOTIFICATION_TYPES = { + "favourite": TimelineEvent.Types.liked, + "reblog": TimelineEvent.Types.boosted, + "mention": TimelineEvent.Types.mentioned, + "follow": TimelineEvent.Types.followed, + "admin.sign_up": TimelineEvent.Types.identity_created, +} + @scope_required("read:notifications") @api_view.get @@ -18,22 +28,14 @@ def notifications( limit: int = 20, account_id: str | None = None, ) -> ApiResponse[list[schemas.Notification]]: - # Types/exclude_types use weird syntax so we have to handle them manually - base_types = { - "favourite": TimelineEvent.Types.liked, - "reblog": TimelineEvent.Types.boosted, - "mention": TimelineEvent.Types.mentioned, - "follow": TimelineEvent.Types.followed, - "admin.sign_up": TimelineEvent.Types.identity_created, - } requested_types = set(request.GET.getlist("types[]")) excluded_types = set(request.GET.getlist("exclude_types[]")) if not requested_types: - requested_types = set(base_types.keys()) + requested_types = set(NOTIFICATION_TYPES.keys()) requested_types.difference_update(excluded_types) # Use that to pull relevant events queryset = TimelineService(request.identity).notifications( - [base_types[r] for r in requested_types if r in base_types] + [NOTIFICATION_TYPES[r] for r in requested_types if r in NOTIFICATION_TYPES] ) paginator = MastodonPaginator() pager: PaginationResult[TimelineEvent] = paginator.paginate( @@ -48,3 +50,18 @@ def notifications( request=request, include_params=["limit", "account_id"], ) + + +@scope_required("read:notifications") +@api_view.get +def get_notification( + request: HttpRequest, + id: str, +) -> schemas.Notification: + notification = get_object_or_404( + TimelineService(request.identity).notifications( + list(NOTIFICATION_TYPES.values()) + ), + id=id, + ) + return schemas.Notification.from_timeline_event(notification) diff --git a/tests/api/notifications.py b/tests/api/notifications.py new file mode 100644 index 0000000..cebf2c9 --- /dev/null +++ b/tests/api/notifications.py @@ -0,0 +1,35 @@ +import pytest + +from activities.models import TimelineEvent + + +@pytest.mark.django_db +def test_notifications(api_client, identity, remote_identity): + event = TimelineEvent.objects.create( + identity=identity, + type=TimelineEvent.Types.followed, + subject_identity=remote_identity, + ) + + response = api_client.get("/api/v1/notifications").json() + + assert len(response) == 1 + assert response[0]["type"] == "follow" + assert response[0]["account"]["id"] == str(remote_identity.id) + + event.delete() + + +@pytest.mark.django_db +def test_get_notification(api_client, identity, remote_identity): + event = TimelineEvent.objects.create( + identity=identity, + type=TimelineEvent.Types.followed, + subject_identity=remote_identity, + ) + + response = api_client.get(f"/api/v1/notifications/{event.id}").json() + assert response["type"] == "follow" + assert response["account"]["id"] == str(remote_identity.id) + + event.delete()