import dataclasses
import urllib.parse
from collections.abc import Callable
from typing import Any, Generic, Protocol, TypeVar

from django.db import models
from django.db.models.expressions import Case, F, When
from django.http import HttpRequest
from hatchway.http import ApiResponse

from activities.models import PostInteraction, TimelineEvent

T = TypeVar("T")


class SchemaWithId(Protocol):
    """
    Little protocol type to represent schemas that have an ID attribute
    """

    id: str


TI = TypeVar("TI", bound=SchemaWithId)
TM = TypeVar("TM", bound=models.Model)


def _filter_query_params(
    request: HttpRequest, allowed_params: list[str]
) -> dict[str, Any]:
    """
    Returns the subset of the request's GET parameters whose key is listed
    in allowed_params and whose value is truthy (missing or empty values
    are dropped).
    """
    params: dict[str, Any] = {}
    for key in allowed_params:
        value = request.GET.get(key, None)
        if value:
            params[key] = value
    return params


class PaginatingApiResponse(ApiResponse[list[TI]]):
    """
    An ApiResponse subclass that also handles pagination link headers
    """

    def __init__(
        self,
        data: list[TI],
        request: HttpRequest,
        include_params: list[str],
        **kwargs,
    ):
        """
        Builds the response and, when there is data, sets the "link" header.

        data is the list of items being returned, request is the request
        being answered (used for the base URL and the GET parameters), and
        include_params names the GET parameters to carry over into the
        generated next/prev URLs. Extra keyword arguments are passed to the
        superclass unchanged.
        """
        # Call superclass
        super().__init__(data, **kwargs)
        # Figure out if we need link headers
        self._request = request
        self.extra_params = self.filter_params(self._request, include_params)
        link_header = self.build_link_header()
        if link_header:
            self.headers["link"] = link_header

    @staticmethod
    def filter_params(
        request: HttpRequest, allowed_params: list[str]
    ) -> dict[str, Any]:
        """
        Returns the allowed GET parameters of the request that have a
        truthy value.
        """
        return _filter_query_params(request, allowed_params)

    def get_part(self, data_index: int, param_name: str, rel: str) -> str | None:
        """
        Used to get next/prev URLs

        Builds a single link header entry of the form '<url>; rel="..."',
        where the URL carries the preserved parameters plus param_name set
        to the ID of the item at data_index. Returns None when there is no
        data to take an ID from.
        """
        if not self.data:
            return None
        # The ID of the item at data_index is the cursor for the linked page
        # (callers use the last item for "next" and the first for "prev").
        # Copy the preserved params so they are not mutated between calls.
        params = dict(self.extra_params)
        params[param_name] = self.data[data_index].id
        return (
            "<"
            + self._request.build_absolute_uri(self._request.path)
            + "?"
            + urllib.parse.urlencode(params)
            + f'>; rel="{rel}"'
        )

    def build_link_header(self) -> str | None:
        """
        Returns the combined next/prev link header value, or None if there
        is nothing to link to.
        """
        parts = [
            entry
            for entry in [
                self.get_part(-1, "max_id", "next"),
                self.get_part(0, "min_id", "prev"),
            ]
            if entry
        ]
        if not parts:
            return None
        return ", ".join(parts)


@dataclasses.dataclass
class PaginationResult(Generic[T]):
    """
    Represents a pagination result for Mastodon (it does Link header stuff)
    """

    #: A list of objects that matched the pagination query.
    results: list[T]

    #: The actual applied limit, which may be different from what was requested.
    limit: int

    #: A list of transformed JSON objects
    json_results: list[dict] | None = None

    @classmethod
    def empty(cls):
        """
        Returns a result with no items and a limit of 20.
        """
        return cls(results=[], limit=20)

    def next(self, request: HttpRequest, allowed_params: list[str]) -> str | None:
        """
        Returns a URL to the next page of results.

        Returns None when there are no results. Raises ValueError if there
        are results but they have not been JSONified yet, since the cursor
        is read from the "id" key of the last JSON result.
        """
        if not self.results:
            return None
        if self.json_results is None:
            raise ValueError("You must JSONify the results first")

        params = self.filter_params(request, allowed_params)
        params["max_id"] = self.json_results[-1]["id"]

        return f"{request.build_absolute_uri(request.path)}?{urllib.parse.urlencode(params)}"

    def prev(self, request: HttpRequest, allowed_params: list[str]) -> str | None:
        """
        Returns a URL to the previous page of results.

        Returns None when there are no results. Raises ValueError if there
        are results but they have not been JSONified yet, since the cursor
        is read from the "id" key of the first JSON result.
        """
        if not self.results:
            return None
        if self.json_results is None:
            raise ValueError("You must JSONify the results first")

        params = self.filter_params(request, allowed_params)
        params["min_id"] = self.json_results[0]["id"]

        return f"{request.build_absolute_uri(request.path)}?{urllib.parse.urlencode(params)}"

    def link_header(self, request: HttpRequest, allowed_params: list[str]) -> str:
        """
        Creates a link header for the given request

        Only links that actually have a URL are included, so an empty result
        set yields an empty string rather than entries pointing at "None".
        """
        links = (
            (self.next(request, allowed_params), "next"),
            (self.prev(request, allowed_params), "prev"),
        )
        return ", ".join(
            f'<{url}>; rel="{rel}"' for url, rel in links if url is not None
        )

    def jsonify_results(self, map_function: Callable[[Any], Any]):
        """
        Replaces our results with ones transformed via map_function
        """
        self.json_results = [map_function(result) for result in self.results]

    def jsonify_posts(self, identity):
        """
        Predefined way of JSON-ifying Post objects
        """
        interactions = PostInteraction.get_post_interactions(self.results, identity)
        self.jsonify_results(
            lambda post: post.to_mastodon_json(
                interactions=interactions, identity=identity
            )
        )

    def jsonify_status_events(self, identity):
        """
        Predefined way of JSON-ifying TimelineEvent objects representing statuses
        """
        interactions = PostInteraction.get_event_interactions(self.results, identity)
        self.jsonify_results(
            lambda event: event.to_mastodon_status_json(
                interactions=interactions, identity=identity
            )
        )

    def jsonify_notification_events(self, identity):
        """
        Predefined way of JSON-ifying TimelineEvent objects representing notifications
        """
        interactions = PostInteraction.get_event_interactions(self.results, identity)
        self.jsonify_results(
            lambda event: event.to_mastodon_notification_json(interactions=interactions)
        )

    def jsonify_identities(self):
        """
        Predefined way of JSON-ifying Identity objects
        """
        self.jsonify_results(lambda identity: identity.to_mastodon_json())

    @staticmethod
    def filter_params(
        request: HttpRequest, allowed_params: list[str]
    ) -> dict[str, Any]:
        """
        Returns the allowed GET parameters of the request that have a
        truthy value.
        """
        return _filter_query_params(request, allowed_params)


class MastodonPaginator:
    """
    Paginates in the Mastodon style (max_id, min_id, etc).
    Note that this basically _requires_ us to always do it on IDs, so we do.
    """

    def __init__(
        self,
        default_limit: int = 20,
        max_limit: int = 40,
    ):
        """
        default_limit is used when no usable limit is requested; max_limit
        is the upper bound applied to any requested limit.
        """
        self.default_limit = default_limit
        self.max_limit = max_limit

    def paginate(
        self,
        queryset: models.QuerySet[TM],
        min_id: str | None,
        max_id: str | None,
        since_id: str | None,
        limit: int | None,
        home: bool = False,
    ) -> PaginationResult[TM]:
        """
        Applies ID-based pagination to queryset and returns one page.

        max_id restricts to items with a lower ID; since_id and min_id both
        restrict to items with a higher ID, with min_id additionally taking
        the items closest to the given ID (and winning over since_id when
        both are given). A missing or non-positive limit falls back to the
        default limit, and the limit is capped at max_limit. When home is
        true, the ID used is the annotated "subject_id" field rather than
        "id". The returned results are always ordered newest first.
        """
        # Anything below 1 cannot be used to slice the queryset, so it is
        # treated the same as "no limit supplied".
        if limit is None or limit < 1:
            limit = self.default_limit
        limit = min(limit, self.max_limit)
        filters = {}
        id_field = "id"
        reverse = False
        if home:
            # The home timeline interleaves Post IDs and PostInteraction IDs in an
            # annotated field called "subject_id".
            id_field = "subject_id"
            queryset = queryset.annotate(
                subject_id=Case(
                    When(type=TimelineEvent.Types.post, then=F("subject_post_id")),
                    default=F("subject_post_interaction"),
                )
            )

        # These "does not start with interaction" checks can be removed after a
        # couple months, when clients have flushed them out.
        if max_id and not max_id.startswith("interaction"):
            filters[f"{id_field}__lt"] = max_id
        if since_id and not since_id.startswith("interaction"):
            filters[f"{id_field}__gt"] = since_id
        if min_id and not min_id.startswith("interaction"):
            # Min ID requires items _immediately_ newer than specified, so we
            # invert the ordering to accommodate
            filters[f"{id_field}__gt"] = min_id
            reverse = True

        # Default is to order by ID descending (newest first), except for min_id
        # queries, which should order by ID for limiting, then reverse the results to be
        # consistent. The clearest explanation of this I've found so far is this:
        # https://mastodon.social/@Gargron/100846335353411164
        ordering = id_field if reverse else f"-{id_field}"
        results = list(queryset.filter(**filters).order_by(ordering)[:limit])
        if reverse:
            results.reverse()

        return PaginationResult(
            results=results,
            limit=limit,
        )