diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 260a987..abc8050 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -14,6 +14,12 @@ repos: - id: pretty-format-json - id: trailing-whitespace + - repo: https://github.com/asottile/pyupgrade + rev: "v3.3.0" + hooks: + - id: pyupgrade + args: [--py310-plus] + - repo: https://github.com/psf/black rev: 22.10.0 hooks: diff --git a/activities/models/hashtag.py b/activities/models/hashtag.py index b7f0832..162f8b4 100644 --- a/activities/models/hashtag.py +++ b/activities/models/hashtag.py @@ -1,6 +1,5 @@ import re from datetime import date, timedelta -from typing import Dict, List import urlman from asgiref.sync import sync_to_async @@ -138,7 +137,7 @@ class Hashtag(StatorModel): def __str__(self): return self.display_name - def usage_months(self, num: int = 12) -> Dict[date, int]: + def usage_months(self, num: int = 12) -> dict[date, int]: """ Return the most recent num months of stats """ @@ -153,7 +152,7 @@ class Hashtag(StatorModel): results[date(year, month, 1)] = val return dict(sorted(results.items(), reverse=True)[:num]) - def usage_days(self, num: int = 7) -> Dict[date, int]: + def usage_days(self, num: int = 7) -> dict[date, int]: """ Return the most recent num days of stats """ @@ -170,7 +169,7 @@ class Hashtag(StatorModel): return dict(sorted(results.items(), reverse=True)[:num]) @classmethod - def hashtags_from_content(cls, content) -> List[str]: + def hashtags_from_content(cls, content) -> list[str]: """ Return a parsed and sanitized of hashtags found in content without leading '#'. diff --git a/activities/models/post.py b/activities/models/post.py index ee1f393..5ca5a1b 100644 --- a/activities/models/post.py +++ b/activities/models/post.py @@ -1,5 +1,6 @@ import re -from typing import Dict, Iterable, List, Optional, Set +from collections.abc import Iterable +from typing import Optional import httpx import urlman @@ -324,10 +325,10 @@ class Post(StatorModel): cls, author: Identity, content: str, - summary: Optional[str] = None, + summary: str | None = None, visibility: int = Visibilities.public, reply_to: Optional["Post"] = None, - attachments: Optional[List] = None, + attachments: list | None = None, ) -> "Post": with transaction.atomic(): # Find mentions in this post @@ -363,9 +364,9 @@ class Post(StatorModel): def edit_local( self, content: str, - summary: Optional[str] = None, + summary: str | None = None, visibility: int = Visibilities.public, - attachments: Optional[List] = None, + attachments: list | None = None, ): with transaction.atomic(): # Strip all HTML and apply linebreaks filter @@ -380,7 +381,7 @@ class Post(StatorModel): self.save() @classmethod - def mentions_from_content(cls, content, author) -> Set[Identity]: + def mentions_from_content(cls, content, author) -> set[Identity]: mention_hits = cls.mention_regex.findall(content) mentions = set() for precursor, handle in mention_hits: @@ -413,7 +414,7 @@ class Post(StatorModel): ### ActivityPub (outbound) ### - def to_ap(self) -> Dict: + def to_ap(self) -> dict: """ Returns the AP JSON for this object """ diff --git a/activities/models/post_interaction.py b/activities/models/post_interaction.py index a913a0f..0bc2fff 100644 --- a/activities/models/post_interaction.py +++ b/activities/models/post_interaction.py @@ -1,5 +1,3 @@ -from typing import Dict - from django.db import models, transaction from django.utils import timezone @@ -195,7 +193,7 @@ class PostInteraction(StatorModel): ### ActivityPub (outbound) ### - def to_ap(self) -> Dict: + def to_ap(self) -> dict: """ Returns the AP JSON for this object """ @@ -223,7 +221,7 @@ class PostInteraction(StatorModel): raise ValueError("Cannot turn into AP") return value - def to_undo_ap(self) -> Dict: + def to_undo_ap(self) -> dict: """ Returns the AP JSON to undo this object """ diff --git a/activities/views/search.py b/activities/views/search.py index ccfc1a4..f7ab237 100644 --- a/activities/views/search.py +++ b/activities/views/search.py @@ -1,5 +1,3 @@ -from typing import Set - import httpx from asgiref.sync import async_to_sync from django import forms @@ -32,7 +30,7 @@ class Search(FormView): # Try to fetch the user by handle query = query.lstrip("@") - results: Set[Identity] = set() + results: set[Identity] = set() if "@" in query: username, domain = query.split("@", 1) @@ -118,7 +116,7 @@ class Search(FormView): if "@" in query or "://" in query: return set() - results: Set[Hashtag] = set() + results: set[Hashtag] = set() query = query.lstrip("#") for hashtag in Hashtag.objects.public().hashtag_or_alias(query)[:10]: results.add(hashtag) diff --git a/core/htmx.py b/core/htmx.py index c83fba9..a6cf6dd 100644 --- a/core/htmx.py +++ b/core/htmx.py @@ -1,8 +1,5 @@ -from typing import Optional - - class HTMXMixin: - template_name_htmx: Optional[str] = None + template_name_htmx: str | None = None def get_template_name(self): if self.request.htmx and self.template_name_htmx: diff --git a/core/ld.py b/core/ld.py index f70642a..4b01e71 100644 --- a/core/ld.py +++ b/core/ld.py @@ -1,7 +1,6 @@ import datetime import os import urllib.parse as urllib_parse -from typing import Dict, List, Optional, Union from pyld import jsonld from pyld.jsonld import JsonLdError @@ -396,7 +395,7 @@ def builtin_document_loader(url: str, options={}): ) -def canonicalise(json_data: Dict, include_security: bool = False) -> Dict: +def canonicalise(json_data: dict, include_security: bool = False) -> dict: """ Given an ActivityPub JSON-LD document, round-trips it through the LD systems to end up in a canonicalised, compacted format. @@ -408,7 +407,7 @@ def canonicalise(json_data: Dict, include_security: bool = False) -> Dict: """ if not isinstance(json_data, dict): raise ValueError("Pass decoded JSON data into LDDocument") - context: Union[str, List[str]] + context: str | list[str] if include_security: context = [ "https://www.w3.org/ns/activitystreams", @@ -422,7 +421,7 @@ def canonicalise(json_data: Dict, include_security: bool = False) -> Dict: return jsonld.compact(jsonld.expand(json_data), context) -def get_list(container, key) -> List: +def get_list(container, key) -> list: """ Given a JSON-LD value (that can be either a list, or a dict if it's just one item), always returns a list""" @@ -438,7 +437,7 @@ def format_ld_date(value: datetime.datetime) -> str: return value.strftime(DATETIME_FORMAT) -def parse_ld_date(value: Optional[str]) -> Optional[datetime.datetime]: +def parse_ld_date(value: str | None) -> datetime.datetime | None: if value is None: return None try: diff --git a/core/signatures.py b/core/signatures.py index ead33da..640483a 100644 --- a/core/signatures.py +++ b/core/signatures.py @@ -1,6 +1,6 @@ import base64 import json -from typing import Dict, List, Literal, Optional, Tuple, TypedDict +from typing import Literal, TypedDict from urllib.parse import urlparse import httpx @@ -35,7 +35,7 @@ class VerificationFormatError(VerificationError): class RsaKeys: @classmethod - def generate_keypair(cls) -> Tuple[str, str]: + def generate_keypair(cls) -> tuple[str, str]: """ Generates a new RSA keypair """ @@ -77,7 +77,7 @@ class HttpSignature: raise ValueError(f"Unknown digest algorithm {algorithm}") @classmethod - def headers_from_request(cls, request: HttpRequest, header_names: List[str]) -> str: + def headers_from_request(cls, request: HttpRequest, header_names: list[str]) -> str: """ Creates the to-be-signed header payload from a Django request """ @@ -170,7 +170,7 @@ class HttpSignature: async def signed_request( cls, uri: str, - body: Optional[Dict], + body: dict | None, private_key: str, key_id: str, content_type: str = "application/json", @@ -239,7 +239,7 @@ class HttpSignature: class HttpSignatureDetails(TypedDict): algorithm: str - headers: List[str] + headers: list[str] signature: bytes keyid: str @@ -250,7 +250,7 @@ class LDSignature: """ @classmethod - def verify_signature(cls, document: Dict, public_key: str) -> None: + def verify_signature(cls, document: dict, public_key: str) -> None: """ Verifies a document """ @@ -285,13 +285,13 @@ class LDSignature: @classmethod def create_signature( - cls, document: Dict, private_key: str, key_id: str - ) -> Dict[str, str]: + cls, document: dict, private_key: str, key_id: str + ) -> dict[str, str]: """ Creates the signature for a document """ # Create the options document - options: Dict[str, str] = { + options: dict[str, str] = { "@context": "https://w3id.org/identity/v1", "creator": key_id, "created": format_ld_date(timezone.now()), diff --git a/stator/graph.py b/stator/graph.py index 5c71d4a..0ec5ee7 100644 --- a/stator/graph.py +++ b/stator/graph.py @@ -1,4 +1,5 @@ -from typing import Any, Callable, ClassVar, Dict, List, Optional, Set, Tuple, Type +from collections.abc import Callable +from typing import Any, ClassVar class StateGraph: @@ -7,11 +8,11 @@ class StateGraph: Does not support subclasses of existing graphs yet. """ - states: ClassVar[Dict[str, "State"]] - choices: ClassVar[List[Tuple[object, str]]] + states: ClassVar[dict[str, "State"]] + choices: ClassVar[list[tuple[object, str]]] initial_state: ClassVar["State"] - terminal_states: ClassVar[Set["State"]] - automatic_states: ClassVar[Set["State"]] + terminal_states: ClassVar[set["State"]] + automatic_states: ClassVar[set["State"]] def __init_subclass__(cls) -> None: # Collect state members @@ -84,8 +85,8 @@ class State: def __init__( self, - try_interval: Optional[float] = None, - handler_name: Optional[str] = None, + try_interval: float | None = None, + handler_name: str | None = None, externally_progressed: bool = False, attempt_immediately: bool = True, force_initial: bool = False, @@ -95,10 +96,10 @@ class State: self.externally_progressed = externally_progressed self.attempt_immediately = attempt_immediately self.force_initial = force_initial - self.parents: Set["State"] = set() - self.children: Set["State"] = set() + self.parents: set["State"] = set() + self.children: set["State"] = set() - def _add_to_graph(self, graph: Type[StateGraph], name: str): + def _add_to_graph(self, graph: type[StateGraph], name: str): self.graph = graph self.name = name self.graph.states[name] = self @@ -132,7 +133,7 @@ class State: return not self.children @property - def handler(self) -> Callable[[Any], Optional[str]]: + def handler(self) -> Callable[[Any], str | None]: # Retrieve it by name off the graph if self.handler_name is None: raise AttributeError("No handler defined") diff --git a/stator/management/commands/runstator.py b/stator/management/commands/runstator.py index 4d52520..bec88d6 100644 --- a/stator/management/commands/runstator.py +++ b/stator/management/commands/runstator.py @@ -1,4 +1,4 @@ -from typing import List, Type, cast +from typing import cast from asgiref.sync import async_to_sync from django.apps import apps @@ -44,7 +44,7 @@ class Command(BaseCommand): def handle( self, - model_labels: List[str], + model_labels: list[str], concurrency: int, liveness_file: str, schedule_interval: int, @@ -56,7 +56,7 @@ class Command(BaseCommand): Config.system = Config.load_system() # Resolve the models list into names models = cast( - List[Type[StatorModel]], + list[type[StatorModel]], [apps.get_model(label) for label in model_labels], ) if not models: diff --git a/stator/models.py b/stator/models.py index 5257ac9..261584c 100644 --- a/stator/models.py +++ b/stator/models.py @@ -1,7 +1,7 @@ import datetime import pprint import traceback -from typing import ClassVar, List, Optional, Type, Union, cast +from typing import ClassVar, cast from asgiref.sync import sync_to_async from django.db import models, transaction @@ -17,7 +17,7 @@ class StateField(models.CharField): A special field that automatically gets choices from a state graph """ - def __init__(self, graph: Type[StateGraph], **kwargs): + def __init__(self, graph: type[StateGraph], **kwargs): # Sensible default for state length kwargs.setdefault("max_length", 100) # Add choices and initial @@ -61,7 +61,7 @@ class StatorModel(models.Model): state_locked_until = models.DateTimeField(null=True, blank=True) # Collection of subclasses of us - subclasses: ClassVar[List[Type["StatorModel"]]] = [] + subclasses: ClassVar[list[type["StatorModel"]]] = [] class Meta: abstract = True @@ -71,7 +71,7 @@ class StatorModel(models.Model): cls.subclasses.append(cls) @classproperty - def state_graph(cls) -> Type[StateGraph]: + def state_graph(cls) -> type[StateGraph]: return cls._meta.get_field("state").graph @property @@ -104,7 +104,7 @@ class StatorModel(models.Model): @classmethod def transition_get_with_lock( cls, number: int, lock_expiry: datetime.datetime - ) -> List["StatorModel"]: + ) -> list["StatorModel"]: """ Returns up to `number` tasks for execution, having locked them. """ @@ -124,7 +124,7 @@ class StatorModel(models.Model): @classmethod async def atransition_get_with_lock( cls, number: int, lock_expiry: datetime.datetime - ) -> List["StatorModel"]: + ) -> list["StatorModel"]: return await sync_to_async(cls.transition_get_with_lock)(number, lock_expiry) @classmethod @@ -143,7 +143,7 @@ class StatorModel(models.Model): self.state_ready = True self.save() - async def atransition_attempt(self) -> Optional[State]: + async def atransition_attempt(self) -> State | None: """ Attempts to transition the current state by running its handler(s). """ @@ -180,7 +180,7 @@ class StatorModel(models.Model): ) return None - def transition_perform(self, state: Union[State, str]): + def transition_perform(self, state: State | str): """ Transitions the instance to the given state name, forcibly. """ @@ -237,7 +237,7 @@ class StatorError(models.Model): async def acreate_from_instance( cls, instance: StatorModel, - exception: Optional[BaseException] = None, + exception: BaseException | None = None, ): detail = traceback.format_exc() if exception and len(exception.args) > 1: diff --git a/stator/runner.py b/stator/runner.py index ecbaae6..0d8f9ea 100644 --- a/stator/runner.py +++ b/stator/runner.py @@ -3,7 +3,6 @@ import datetime import time import traceback import uuid -from typing import List, Optional, Type from django.utils import timezone @@ -20,10 +19,10 @@ class StatorRunner: def __init__( self, - models: List[Type[StatorModel]], + models: list[type[StatorModel]], concurrency: int = 50, concurrency_per_model: int = 10, - liveness_file: Optional[str] = None, + liveness_file: str | None = None, schedule_interval: int = 30, lock_expiry: int = 300, run_for: int = 0, diff --git a/takahe/settings.py b/takahe/settings.py index 89d4d3a..7952024 100644 --- a/takahe/settings.py +++ b/takahe/settings.py @@ -3,7 +3,7 @@ import secrets import sys import urllib.parse from pathlib import Path -from typing import List, Literal, Optional, Union +from typing import Literal import dj_database_url import sentry_sdk @@ -24,7 +24,7 @@ class MediaBackendUrl(AnyUrl): allowed_schemes = {"s3", "gcs", "local"} -def as_bool(v: Optional[Union[str, List[str]]]): +def as_bool(v: str | list[str] | None): if v is None: return False @@ -48,7 +48,7 @@ class Settings(BaseSettings): """ #: The default database. - DATABASE_SERVER: Optional[ImplicitHostname] + DATABASE_SERVER: ImplicitHostname | None #: The currently running environment, used for things such as sentry #: error reporting. @@ -66,19 +66,19 @@ class Settings(BaseSettings): #: If set, a list of allowed values for the HOST header. The default value #: of '*' means any host will be accepted. - ALLOWED_HOSTS: List[str] = Field(default_factory=lambda: ["*"]) + ALLOWED_HOSTS: list[str] = Field(default_factory=lambda: ["*"]) #: If set, a list of hosts to accept for CORS. - CORS_HOSTS: List[str] = Field(default_factory=list) + CORS_HOSTS: list[str] = Field(default_factory=list) #: If set, a list of hosts to accept for CSRF. - CSRF_HOSTS: List[str] = Field(default_factory=list) + CSRF_HOSTS: list[str] = Field(default_factory=list) #: If enabled, trust the HTTP_X_FORWARDED_FOR header. USE_PROXY_HEADERS: bool = False #: An optional Sentry DSN for error reporting. - SENTRY_DSN: Optional[str] = None + SENTRY_DSN: str | None = None SENTRY_SAMPLE_RATE: float = 1.0 SENTRY_TRACES_SAMPLE_RATE: float = 1.0 @@ -87,12 +87,12 @@ class Settings(BaseSettings): EMAIL_SERVER: AnyUrl = "console://localhost" EMAIL_FROM: EmailStr = "test@example.com" - AUTO_ADMIN_EMAIL: Optional[EmailStr] = None - ERROR_EMAILS: Optional[List[EmailStr]] = None + AUTO_ADMIN_EMAIL: EmailStr | None = None + ERROR_EMAILS: list[EmailStr] | None = None MEDIA_URL: str = "/media/" MEDIA_ROOT: str = str(BASE_DIR / "media") - MEDIA_BACKEND: Optional[MediaBackendUrl] = None + MEDIA_BACKEND: MediaBackendUrl | None = None #: Maximum filesize when uploading images. Increasing this may increase memory utilization #: because all images with a dimension greater than 2000px are resized to meet that limit, which @@ -107,11 +107,11 @@ class Settings(BaseSettings): #: (placeholder setting, no effect) SEARCH: bool = True - PGHOST: Optional[str] = None - PGPORT: Optional[int] = 5432 + PGHOST: str | None = None + PGPORT: int | None = 5432 PGNAME: str = "takahe" PGUSER: str = "postgres" - PGPASSWORD: Optional[str] = None + PGPASSWORD: str | None = None @validator("PGHOST", always=True) def validate_db(cls, PGHOST, values): # noqa diff --git a/tests/activities/views/test_compose.py b/tests/activities/views/test_compose.py index 60bbca1..2b8c4ea 100644 --- a/tests/activities/views/test_compose.py +++ b/tests/activities/views/test_compose.py @@ -1,6 +1,6 @@ import re +from unittest import mock -import mock import pytest from django.core.exceptions import PermissionDenied diff --git a/tests/activities/views/test_timelines.py b/tests/activities/views/test_timelines.py index 6c8b355..74bf43d 100644 --- a/tests/activities/views/test_timelines.py +++ b/tests/activities/views/test_timelines.py @@ -1,4 +1,5 @@ -import mock +from unittest import mock + import pytest from activities.views.timelines import Home diff --git a/tests/users/views/test_auth.py b/tests/users/views/test_auth.py index 22e1fb6..f3a34c0 100644 --- a/tests/users/views/test_auth.py +++ b/tests/users/views/test_auth.py @@ -1,4 +1,5 @@ -import mock +from unittest import mock + import pytest from core.models import Config diff --git a/users/models/identity.py b/users/models/identity.py index bbedceb..c674bf4 100644 --- a/users/models/identity.py +++ b/users/models/identity.py @@ -1,5 +1,5 @@ from functools import cached_property, partial -from typing import Dict, Literal, Optional, Tuple +from typing import Literal from urllib.parse import urlparse import httpx @@ -334,7 +334,7 @@ class Identity(StatorModel): ### Actor/Webfinger fetching ### @classmethod - async def fetch_webfinger(cls, handle: str) -> Tuple[Optional[str], Optional[str]]: + async def fetch_webfinger(cls, handle: str) -> tuple[str | None, str | None]: """ Given a username@domain handle, returns a tuple of (actor uri, canonical handle) or None, None if it does not resolve. @@ -458,7 +458,7 @@ class Identity(StatorModel): raise ValueError( f"Could not save Identity at end of actor fetch: {e}" ) - self.pk: Optional[int] = other_row.pk + self.pk: int | None = other_row.pk await sync_to_async(self.save)() return True @@ -468,7 +468,7 @@ class Identity(StatorModel): self, method: Literal["get", "post"], uri: str, - body: Optional[Dict] = None, + body: dict | None = None, ): """ Performs a signed request on behalf of the System Actor. diff --git a/users/models/system_actor.py b/users/models/system_actor.py index c337d78..c4319b9 100644 --- a/users/models/system_actor.py +++ b/users/models/system_actor.py @@ -1,4 +1,4 @@ -from typing import Dict, Literal, Optional +from typing import Literal from django.conf import settings @@ -57,7 +57,7 @@ class SystemActor: self, method: Literal["get", "post"], uri: str, - body: Optional[Dict] = None, + body: dict | None = None, ): """ Performs a signed request on behalf of the System Actor. diff --git a/users/models/user.py b/users/models/user.py index 08a703e..e0cac9d 100644 --- a/users/models/user.py +++ b/users/models/user.py @@ -1,5 +1,3 @@ -from typing import List - from django.contrib.auth.models import AbstractBaseUser, BaseUserManager from django.db import models @@ -42,7 +40,7 @@ class User(AbstractBaseUser): USERNAME_FIELD = "email" EMAIL_FIELD = "email" - REQUIRED_FIELDS: List[str] = [] + REQUIRED_FIELDS: list[str] = [] objects = UserManager() diff --git a/users/views/settings/interface.py b/users/views/settings/interface.py index fe8e1e9..5c4f229 100644 --- a/users/views/settings/interface.py +++ b/users/views/settings/interface.py @@ -1,5 +1,5 @@ from functools import partial -from typing import ClassVar, Dict, List +from typing import ClassVar from django import forms from django.core.files import File @@ -21,8 +21,8 @@ class SettingsPage(FormView): options_class = Config.IdentityOptions template_name = "settings/settings.html" section: ClassVar[str] - options: Dict[str, Dict[str, str]] - layout: Dict[str, List[str]] + options: dict[str, dict[str, str]] + layout: dict[str, list[str]] def get_form_class(self): # Create the fields dict from the config object