diff --git a/stator/admin.py b/stator/admin.py index 025f225..790fc38 100644 --- a/stator/admin.py +++ b/stator/admin.py @@ -10,8 +10,7 @@ class DomainAdmin(admin.ModelAdmin): "date", "model_label", "instance_pk", - "from_state", - "to_state", + "state", "error", ] ordering = ["-date"] diff --git a/stator/graph.py b/stator/graph.py index 7fc23f7..7a8455c 100644 --- a/stator/graph.py +++ b/stator/graph.py @@ -1,16 +1,4 @@ -from typing import ( - Any, - Callable, - ClassVar, - Dict, - List, - Optional, - Set, - Tuple, - Type, - Union, - cast, -) +from typing import Any, Callable, ClassVar, Dict, List, Optional, Set, Tuple, Type class StateGraph: @@ -44,20 +32,43 @@ class StateGraph: terminal_states = set() initial_state = None for state in cls.states.values(): + # Check for multiple initial states if state.initial: if initial_state: raise ValueError( f"The graph has more than one initial state: {initial_state} and {state}" ) initial_state = state + # Collect terminal states if state.terminal: terminal_states.add(state) + # Ensure they do NOT have a handler + try: + state.handler + except AttributeError: + pass + else: + raise ValueError( + f"Terminal state '{state}' should not have a handler method ({state.handler_name})" + ) + else: + # Ensure non-terminal states have a try interval and a handler + if not state.try_interval: + raise ValueError( + f"State '{state}' has no try_interval and is not terminal" + ) + try: + state.handler + except AttributeError: + raise ValueError( + f"State '{state}' does not have a handler method ({state.handler_name})" + ) if initial_state is None: raise ValueError("The graph has no initial state") cls.initial_state = initial_state cls.terminal_states = terminal_states # Generate choices - cls.choices = [(state, name) for name, state in cls.states.items()] + cls.choices = [(name, name) for name in cls.states.keys()] class State: @@ -65,49 +76,37 @@ class State: Represents an individual state """ - def __init__(self, try_interval: float = 300): + def __init__( + self, + try_interval: Optional[float] = None, + handler_name: Optional[str] = None, + ): self.try_interval = try_interval + self.handler_name = handler_name self.parents: Set["State"] = set() - self.children: Dict["State", "Transition"] = {} + self.children: Set["State"] = set() def _add_to_graph(self, graph: Type[StateGraph], name: str): self.graph = graph self.name = name self.graph.states[name] = self + if self.handler_name is None: + self.handler_name = f"handle_{self.name}" def __repr__(self): return f"" - def __str__(self): - return self.name + def __eq__(self, other): + if isinstance(other, State): + return self is other + return self.name == other - def __len__(self): - return len(self.name) + def __hash__(self): + return hash(id(self)) - def add_transition( - self, - other: "State", - handler: Optional[Callable] = None, - priority: int = 0, - ) -> Callable: - def decorator(handler: Callable[[Any], bool]): - self.children[other] = Transition( - self, - other, - handler, - priority=priority, - ) - other.parents.add(self) - return handler - - # If we're not being called as a decorator, invoke it immediately - if handler is not None: - decorator(handler) - return decorator - - def add_manual_transition(self, other: "State"): - self.children[other] = ManualTransition(self, other) - other.parents.add(self) + def transitions_to(self, other: "State"): + self.children.add(other) + other.parents.add(other) @property def initial(self): @@ -117,59 +116,8 @@ class State: def terminal(self): return not self.children - def transitions(self, automatic_only=False) -> List["Transition"]: - """ - Returns all transitions from this State in priority order - """ - if automatic_only: - transitions = [t for t in self.children.values() if t.automatic] - else: - transitions = list(self.children.values()) - return sorted(transitions, key=lambda t: t.priority, reverse=True) - - -class Transition: - """ - A possible transition from one state to another - """ - - def __init__( - self, - from_state: State, - to_state: State, - handler: Union[str, Callable], - priority: int = 0, - ): - self.from_state = from_state - self.to_state = to_state - self.handler = handler - self.priority = priority - self.automatic = True - - def get_handler(self) -> Callable: - """ - Returns the handler (it might need resolving from a string) - """ - if isinstance(self.handler, str): - self.handler = getattr(self.from_state.graph, self.handler) - return cast(Callable, self.handler) - - def __repr__(self): - return f" {self.to_state}>" - - -class ManualTransition(Transition): - """ - A possible transition from one state to another that cannot be done by - the stator task runner, and must come from an external source. - """ - - def __init__( - self, - from_state: State, - to_state: State, - ): - self.from_state = from_state - self.to_state = to_state - self.priority = 0 - self.automatic = False + @property + def handler(self) -> Callable[[Any], Optional[str]]: + if self.handler_name is None: + raise AttributeError("No handler defined") + return getattr(self.graph, self.handler_name) diff --git a/stator/migrations/0001_initial.py b/stator/migrations/0001_initial.py index d56ed5c..f7d652e 100644 --- a/stator/migrations/0001_initial.py +++ b/stator/migrations/0001_initial.py @@ -1,4 +1,4 @@ -# Generated by Django 4.1.3 on 2022-11-10 03:24 +# Generated by Django 4.1.3 on 2022-11-10 05:56 from django.db import migrations, models @@ -24,8 +24,7 @@ class Migration(migrations.Migration): ), ("model_label", models.CharField(max_length=200)), ("instance_pk", models.CharField(max_length=200)), - ("from_state", models.CharField(max_length=200)), - ("to_state", models.CharField(max_length=200)), + ("state", models.CharField(max_length=200)), ("date", models.DateTimeField(auto_now_add=True)), ("error", models.TextField()), ("error_details", models.TextField(blank=True, null=True)), diff --git a/stator/models.py b/stator/models.py index 235b18c..50ee622 100644 --- a/stator/models.py +++ b/stator/models.py @@ -1,13 +1,13 @@ import datetime import traceback -from typing import ClassVar, List, Optional, Type, cast +from typing import ClassVar, List, Optional, Type, Union, cast from asgiref.sync import sync_to_async from django.db import models, transaction from django.utils import timezone from django.utils.functional import classproperty -from stator.graph import State, StateGraph, Transition +from stator.graph import State, StateGraph class StateField(models.CharField): @@ -29,16 +29,6 @@ class StateField(models.CharField): kwargs["graph"] = self.graph return name, path, args, kwargs - def from_db_value(self, value, expression, connection): - if value is None: - return value - return self.graph.states[value] - - def to_python(self, value): - if isinstance(value, State) or value is None: - return value - return self.graph.states[value] - def get_prep_value(self, value): if isinstance(value, State): return value.name @@ -95,7 +85,9 @@ class StatorModel(models.Model): ( models.Q( state_attempted__lte=timezone.now() - - datetime.timedelta(seconds=state.try_interval) + - datetime.timedelta( + seconds=cast(float, state.try_interval) + ) ) | models.Q(state_attempted__isnull=True) ), @@ -117,7 +109,7 @@ class StatorModel(models.Model): ].select_for_update() ) cls.objects.filter(pk__in=[i.pk for i in selected]).update( - state_locked_until=timezone.now() + state_locked_until=lock_expiry ) return selected @@ -143,36 +135,36 @@ class StatorModel(models.Model): self.state_ready = True self.save() - async def atransition_attempt(self) -> bool: + async def atransition_attempt(self) -> Optional[str]: """ Attempts to transition the current state by running its handler(s). """ - # Try each transition in priority order - for transition in self.state.transitions(automatic_only=True): - try: - success = await transition.get_handler()(self) - except BaseException as e: - await StatorError.acreate_from_instance(self, transition, e) - traceback.print_exc() - continue - if success: - await self.atransition_perform(transition.to_state.name) - return True + try: + next_state = await self.state_graph.states[self.state].handler(self) + except BaseException as e: + await StatorError.acreate_from_instance(self, e) + traceback.print_exc() + else: + if next_state: + await self.atransition_perform(next_state) + return next_state await self.__class__.objects.filter(pk=self.pk).aupdate( state_attempted=timezone.now(), state_locked_until=None, state_ready=False, ) - return False + return None - def transition_perform(self, state_name): + def transition_perform(self, state: Union[State, str]): """ Transitions the instance to the given state name, forcibly. """ - if state_name not in self.state_graph.states: - raise ValueError(f"Invalid state {state_name}") + if isinstance(state, State): + state = state.name + if state not in self.state_graph.states: + raise ValueError(f"Invalid state {state}") self.__class__.objects.filter(pk=self.pk).update( - state=state_name, + state=state, state_changed=timezone.now(), state_attempted=None, state_locked_until=None, @@ -194,11 +186,8 @@ class StatorError(models.Model): # The primary key of that model (probably int or str) instance_pk = models.CharField(max_length=200) - # The state we moved from - from_state = models.CharField(max_length=200) - - # The state we moved to (or tried to) - to_state = models.CharField(max_length=200) + # The state we were on + state = models.CharField(max_length=200) # When it happened date = models.DateTimeField(auto_now_add=True) @@ -213,14 +202,12 @@ class StatorError(models.Model): async def acreate_from_instance( cls, instance: StatorModel, - transition: Transition, exception: Optional[BaseException] = None, ): return await cls.objects.acreate( model_label=instance._meta.label_lower, instance_pk=str(instance.pk), - from_state=transition.from_state, - to_state=transition.to_state, + state=instance.state, error=str(exception), error_details=traceback.format_exc(), ) diff --git a/stator/runner.py b/stator/runner.py index f9c726e..1392e4d 100644 --- a/stator/runner.py +++ b/stator/runner.py @@ -1,6 +1,7 @@ import asyncio import datetime import time +import traceback import uuid from typing import List, Type @@ -53,7 +54,7 @@ class StatorRunner: f"Attempting transition on {instance._meta.label_lower}#{instance.pk}" ) self.tasks.append( - asyncio.create_task(instance.atransition_attempt()) + asyncio.create_task(self.run_transition(instance)) ) self.handled += 1 space_remaining -= 1 @@ -70,5 +71,17 @@ class StatorRunner: print("Complete") return self.handled + async def run_transition(self, instance: StatorModel): + """ + Wrapper for atransition_attempt with fallback error handling + """ + try: + await instance.atransition_attempt() + except BaseException: + traceback.print_exc() + def remove_completed_tasks(self): + """ + Removes all completed asyncio.Tasks from our local in-progress list + """ self.tasks = [t for t in self.tasks if not t.done()] diff --git a/stator/tests/test_graph.py b/stator/tests/test_graph.py index 0a7113d..c66f441 100644 --- a/stator/tests/test_graph.py +++ b/stator/tests/test_graph.py @@ -9,39 +9,29 @@ def test_declare(): lookups. """ - fake_handler = lambda: True - class TestGraph(StateGraph): - initial = State() - second = State() + initial = State(try_interval=3600) + second = State(try_interval=1) third = State() - fourth = State() - final = State() - initial.add_transition(second, 60, handler=fake_handler) - second.add_transition(third, 60, handler="check_third") + initial.transitions_to(second) + second.transitions_to(third) - def check_third(cls): - return True + @classmethod + def handle_initial(cls): + pass - @third.add_transition(fourth, 60) - def check_fourth(cls): - return True - - fourth.add_manual_transition(final) + @classmethod + def handle_second(cls): + pass assert TestGraph.initial_state == TestGraph.initial - assert TestGraph.terminal_states == {TestGraph.final} + assert TestGraph.terminal_states == {TestGraph.third} - assert TestGraph.initial.children[TestGraph.second].get_handler() == fake_handler - assert ( - TestGraph.second.children[TestGraph.third].get_handler() - == TestGraph.check_third - ) - assert ( - TestGraph.third.children[TestGraph.fourth].get_handler().__name__ - == "check_fourth" - ) + assert TestGraph.initial.handler == TestGraph.handle_initial + assert TestGraph.initial.try_interval == 3600 + assert TestGraph.second.handler == TestGraph.handle_second + assert TestGraph.second.try_interval == 1 def test_bad_declarations(): @@ -62,5 +52,18 @@ def test_bad_declarations(): loop = State() loop2 = State() - loop.add_transition(loop2, 1, handler="fake") - loop2.add_transition(loop, 1, handler="fake") + loop.transitions_to(loop2) + loop2.transitions_to(loop) + + +def test_state(): + """ + Tests basic values of the State class + """ + + class TestGraph(StateGraph): + initial = State() + + assert "initial" == TestGraph.initial + assert TestGraph.initial == "initial" + assert TestGraph.initial == TestGraph.initial diff --git a/statuses/migrations/0001_initial.py b/statuses/migrations/0001_initial.py index 55c6c6c..c4a8fec 100644 --- a/statuses/migrations/0001_initial.py +++ b/statuses/migrations/0001_initial.py @@ -1,4 +1,4 @@ -# Generated by Django 4.1.3 on 2022-11-07 04:19 +# Generated by Django 4.1.3 on 2022-11-10 05:58 import django.db.models.deletion from django.db import migrations, models diff --git a/users/admin.py b/users/admin.py index e517b0a..f2b807c 100644 --- a/users/admin.py +++ b/users/admin.py @@ -20,7 +20,7 @@ class UserEventAdmin(admin.ModelAdmin): @admin.register(Identity) class IdentityAdmin(admin.ModelAdmin): - list_display = ["id", "handle", "actor_uri", "name", "local"] + list_display = ["id", "handle", "actor_uri", "state", "local"] @admin.register(Follow) diff --git a/users/migrations/0001_initial.py b/users/migrations/0001_initial.py index f5ebf55..2f64337 100644 --- a/users/migrations/0001_initial.py +++ b/users/migrations/0001_initial.py @@ -1,4 +1,4 @@ -# Generated by Django 4.1.3 on 2022-11-07 04:19 +# Generated by Django 4.1.3 on 2022-11-10 05:58 import functools @@ -6,7 +6,10 @@ import django.db.models.deletion from django.conf import settings from django.db import migrations, models +import stator.models +import users.models.follow import users.models.identity +import users.models.inbox_message class Migration(migrations.Migration): @@ -77,6 +80,37 @@ class Migration(migrations.Migration): ), ], ), + migrations.CreateModel( + name="InboxMessage", + fields=[ + ( + "id", + models.BigAutoField( + auto_created=True, + primary_key=True, + serialize=False, + verbose_name="ID", + ), + ), + ("state_ready", models.BooleanField(default=False)), + ("state_changed", models.DateTimeField(auto_now_add=True)), + ("state_attempted", models.DateTimeField(blank=True, null=True)), + ("state_locked_until", models.DateTimeField(blank=True, null=True)), + ("message", models.JSONField()), + ( + "state", + stator.models.StateField( + choices=[("received", "received"), ("processed", "processed")], + default="received", + graph=users.models.inbox_message.InboxMessageStates, + max_length=100, + ), + ), + ], + options={ + "abstract": False, + }, + ), migrations.CreateModel( name="UserEvent", fields=[ @@ -124,7 +158,20 @@ class Migration(migrations.Migration): verbose_name="ID", ), ), + ("state_ready", models.BooleanField(default=False)), + ("state_changed", models.DateTimeField(auto_now_add=True)), + ("state_attempted", models.DateTimeField(blank=True, null=True)), + ("state_locked_until", models.DateTimeField(blank=True, null=True)), ("actor_uri", models.CharField(max_length=500, unique=True)), + ( + "state", + stator.models.StateField( + choices=[("outdated", "outdated"), ("updated", "updated")], + default="outdated", + graph=users.models.identity.IdentityStates, + max_length=100, + ), + ), ("local", models.BooleanField()), ("username", models.CharField(blank=True, max_length=500, null=True)), ("name", models.CharField(blank=True, max_length=500, null=True)), @@ -239,10 +286,25 @@ class Migration(migrations.Migration): verbose_name="ID", ), ), + ("state_ready", models.BooleanField(default=False)), + ("state_changed", models.DateTimeField(auto_now_add=True)), + ("state_attempted", models.DateTimeField(blank=True, null=True)), + ("state_locked_until", models.DateTimeField(blank=True, null=True)), ("uri", models.CharField(blank=True, max_length=500, null=True)), ("note", models.TextField(blank=True, null=True)), - ("requested", models.BooleanField(default=False)), - ("accepted", models.BooleanField(default=False)), + ( + "state", + stator.models.StateField( + choices=[ + ("pending", "pending"), + ("requested", "requested"), + ("accepted", "accepted"), + ], + default="pending", + graph=users.models.follow.FollowStates, + max_length=100, + ), + ), ("created", models.DateTimeField(auto_now_add=True)), ("updated", models.DateTimeField(auto_now=True)), ( diff --git a/users/migrations/0002_follow_state_follow_state_attempted_and_more.py b/users/migrations/0002_follow_state_follow_state_attempted_and_more.py deleted file mode 100644 index b33236a..0000000 --- a/users/migrations/0002_follow_state_follow_state_attempted_and_more.py +++ /dev/null @@ -1,44 +0,0 @@ -# Generated by Django 4.1.3 on 2022-11-07 19:22 - -import django.utils.timezone -from django.db import migrations, models - -import stator.models -import users.models.follow - - -class Migration(migrations.Migration): - - dependencies = [ - ("users", "0001_initial"), - ] - - operations = [ - migrations.AddField( - model_name="follow", - name="state", - field=stator.models.StateField( - choices=[ - ("pending", "pending"), - ("requested", "requested"), - ("accepted", "accepted"), - ], - default="pending", - graph=users.models.follow.FollowStates, - max_length=100, - ), - ), - migrations.AddField( - model_name="follow", - name="state_attempted", - field=models.DateTimeField(blank=True, null=True), - ), - migrations.AddField( - model_name="follow", - name="state_changed", - field=models.DateTimeField( - auto_now_add=True, default=django.utils.timezone.now - ), - preserve_default=False, - ), - ] diff --git a/users/migrations/0003_remove_follow_accepted_remove_follow_requested_and_more.py b/users/migrations/0003_remove_follow_accepted_remove_follow_requested_and_more.py deleted file mode 100644 index 180bfdd..0000000 --- a/users/migrations/0003_remove_follow_accepted_remove_follow_requested_and_more.py +++ /dev/null @@ -1,31 +0,0 @@ -# Generated by Django 4.1.3 on 2022-11-08 03:58 - -from django.db import migrations, models - - -class Migration(migrations.Migration): - - dependencies = [ - ("users", "0002_follow_state_follow_state_attempted_and_more"), - ] - - operations = [ - migrations.RemoveField( - model_name="follow", - name="accepted", - ), - migrations.RemoveField( - model_name="follow", - name="requested", - ), - migrations.AddField( - model_name="follow", - name="state_locked", - field=models.DateTimeField(blank=True, null=True), - ), - migrations.AddField( - model_name="follow", - name="state_runner", - field=models.CharField(blank=True, max_length=100, null=True), - ), - ] diff --git a/users/migrations/0004_remove_follow_state_locked_and_more.py b/users/migrations/0004_remove_follow_state_locked_and_more.py deleted file mode 100644 index bf98080..0000000 --- a/users/migrations/0004_remove_follow_state_locked_and_more.py +++ /dev/null @@ -1,21 +0,0 @@ -# Generated by Django 4.1.3 on 2022-11-09 05:15 - -from django.db import migrations - - -class Migration(migrations.Migration): - - dependencies = [ - ("users", "0003_remove_follow_accepted_remove_follow_requested_and_more"), - ] - - operations = [ - migrations.RemoveField( - model_name="follow", - name="state_locked", - ), - migrations.RemoveField( - model_name="follow", - name="state_runner", - ), - ] diff --git a/users/migrations/0005_follow_state_locked_until_follow_state_ready.py b/users/migrations/0005_follow_state_locked_until_follow_state_ready.py deleted file mode 100644 index 3aba08e..0000000 --- a/users/migrations/0005_follow_state_locked_until_follow_state_ready.py +++ /dev/null @@ -1,23 +0,0 @@ -# Generated by Django 4.1.3 on 2022-11-10 03:24 - -from django.db import migrations, models - - -class Migration(migrations.Migration): - - dependencies = [ - ("users", "0004_remove_follow_state_locked_and_more"), - ] - - operations = [ - migrations.AddField( - model_name="follow", - name="state_locked_until", - field=models.DateTimeField(blank=True, null=True), - ), - migrations.AddField( - model_name="follow", - name="state_ready", - field=models.BooleanField(default=False), - ), - ] diff --git a/users/models/__init__.py b/users/models/__init__.py index d46003f..28d62b0 100644 --- a/users/models/__init__.py +++ b/users/models/__init__.py @@ -2,5 +2,6 @@ from .block import Block # noqa from .domain import Domain # noqa from .follow import Follow, FollowStates # noqa from .identity import Identity, IdentityStates # noqa +from .inbox_message import InboxMessage, InboxMessageStates # noqa from .user import User # noqa from .user_event import UserEvent # noqa diff --git a/users/models/follow.py b/users/models/follow.py index 3325a0b..6f62481 100644 --- a/users/models/follow.py +++ b/users/models/follow.py @@ -6,16 +6,20 @@ from stator.models import State, StateField, StateGraph, StatorModel class FollowStates(StateGraph): - pending = State(try_interval=30) - requested = State() + unrequested = State(try_interval=30) + requested = State(try_interval=24 * 60 * 60) accepted = State() - @pending.add_transition(requested) - async def try_request(instance: "Follow"): # type:ignore - print("Would have tried to follow on", instance) - return False + unrequested.transitions_to(requested) + requested.transitions_to(accepted) - requested.add_manual_transition(accepted) + @classmethod + async def handle_unrequested(cls, instance: "Follow"): + print("Would have tried to follow on", instance) + + @classmethod + async def handle_requested(cls, instance: "Follow"): + print("Would have tried to requested on", instance) class Follow(StatorModel): @@ -73,3 +77,17 @@ class Follow(StatorModel): follow.state = FollowStates.accepted follow.save() return follow + + @classmethod + def remote_created(cls, source, target, uri): + follow = cls.maybe_get(source=source, target=target) + if follow is None: + follow = Follow.objects.create(source=source, target=target, uri=uri) + if follow.state == FollowStates.fresh: + follow.transition_perform(FollowStates.requested) + + @classmethod + def remote_accepted(cls, source, target): + follow = cls.maybe_get(source=source, target=target) + if follow and follow.state == FollowStates.requested: + follow.transition_perform(FollowStates.accepted) diff --git a/users/models/identity.py b/users/models/identity.py index 5e2cd06..7dff492 100644 --- a/users/models/identity.py +++ b/users/models/identity.py @@ -22,11 +22,16 @@ class IdentityStates(StateGraph): outdated = State(try_interval=3600) updated = State() - @outdated.add_transition(updated) - async def fetch_identity(identity: "Identity"): # type:ignore + outdated.transitions_to(updated) + + @classmethod + async def handle_outdated(cls, identity: "Identity"): + # Local identities never need fetching if identity.local: - return True - return await identity.fetch_actor() + return "updated" + # Run the actor fetch and progress to updated if it succeeds + if await identity.fetch_actor(): + return "updated" def upload_namer(prefix, instance, filename): diff --git a/users/models/inbox_message.py b/users/models/inbox_message.py new file mode 100644 index 0000000..0dbdc3a --- /dev/null +++ b/users/models/inbox_message.py @@ -0,0 +1,71 @@ +from asgiref.sync import sync_to_async +from django.db import models + +from stator.models import State, StateField, StateGraph, StatorModel +from users.models import Follow, Identity + + +class InboxMessageStates(StateGraph): + received = State(try_interval=300) + processed = State() + + received.transitions_to(processed) + + @classmethod + async def handle_received(cls, instance: "InboxMessage"): + type = instance.message["type"].lower() + if type == "follow": + await instance.follow_request() + elif type == "accept": + inner_type = instance.message["object"]["type"].lower() + if inner_type == "follow": + await instance.follow_accepted() + else: + raise ValueError(f"Cannot handle activity of type accept.{inner_type}") + elif type == "undo": + inner_type = instance.message["object"]["type"].lower() + if inner_type == "follow": + await instance.follow_undo() + else: + raise ValueError(f"Cannot handle activity of type undo.{inner_type}") + else: + raise ValueError(f"Cannot handle activity of type {type}") + + +class InboxMessage(StatorModel): + """ + an incoming inbox message that needs processing. + + Yes, this is kind of its own message queue built on the state graph system. + It's fine. It'll scale up to a decent point. + """ + + message = models.JSONField() + + state = StateField(InboxMessageStates) + + @sync_to_async + def follow_request(self): + """ + Handles an incoming follow request + """ + Follow.remote_created( + source=Identity.by_actor_uri_with_create(self.message["actor"]), + target=Identity.by_actor_uri(self.message["object"]), + uri=self.message["id"], + ) + + @sync_to_async + def follow_accepted(self): + """ + Handles an incoming acceptance of one of our follow requests + """ + Follow.remote_accepted( + source=Identity.by_actor_uri_with_create(self.message["actor"]), + target=Identity.by_actor_uri(self.message["object"]), + ) + + async def follow_undo(self): + """ + Handles an incoming follow undo + """ diff --git a/users/shortcuts.py b/users/shortcuts.py index 3e7618a..8e20a09 100644 --- a/users/shortcuts.py +++ b/users/shortcuts.py @@ -19,10 +19,7 @@ def by_handle_or_404(request, handle, local=True, fetch=False) -> Identity: else: username, domain = handle.split("@", 1) # Resolve the domain to the display domain - domain_instance = Domain.get_domain(domain) - if domain_instance is None: - raise Http404("No matching domains found") - domain = domain_instance.domain + domain = Domain.get_remote_domain(domain).domain identity = Identity.by_username_and_domain( username, domain, diff --git a/users/tasks/identity.py b/users/tasks/identity.py deleted file mode 100644 index f5cd214..0000000 --- a/users/tasks/identity.py +++ /dev/null @@ -1,11 +0,0 @@ -from asgiref.sync import sync_to_async - -from users.models import Identity - - -async def handle_identity_fetch(task_handler): - # Get the actor URI via webfinger - actor_uri, handle = await Identity.fetch_webfinger(task_handler.subject) - # Get or create the identity, then fetch - identity = await sync_to_async(Identity.by_actor_uri_with_create)(actor_uri) - await identity.fetch_actor() diff --git a/users/tasks/inbox.py b/users/tasks/inbox.py deleted file mode 100644 index 27c602d..0000000 --- a/users/tasks/inbox.py +++ /dev/null @@ -1,56 +0,0 @@ -from asgiref.sync import sync_to_async - -from users.models import Follow, Identity - - -async def handle_inbox_item(task_handler): - type = task_handler.payload["type"].lower() - if type == "follow": - await inbox_follow(task_handler.payload) - elif type == "accept": - inner_type = task_handler.payload["object"]["type"].lower() - if inner_type == "follow": - await sync_to_async(accept_follow)(task_handler.payload["object"]) - else: - raise ValueError(f"Cannot handle activity of type accept.{inner_type}") - elif type == "undo": - inner_type = task_handler.payload["object"]["type"].lower() - if inner_type == "follow": - await inbox_unfollow(task_handler.payload["object"]) - else: - raise ValueError(f"Cannot handle activity of type undo.{inner_type}") - else: - raise ValueError(f"Cannot handle activity of type {inner_type}") - - -async def inbox_follow(payload): - """ - Handles an incoming follow request - """ - # TODO: Manually approved follows - source = Identity.by_actor_uri_with_create(payload["actor"]) - target = Identity.by_actor_uri(payload["object"]) - # See if this follow already exists - try: - follow = Follow.objects.get(source=source, target=target) - except Follow.DoesNotExist: - follow = Follow.objects.create(source=source, target=target, uri=payload["id"]) - # See if we need to acknowledge it - if not follow.acknowledged: - pass - - -async def inbox_unfollow(payload): - pass - - -def accept_follow(payload): - """ - Another server has acknowledged our follow request - """ - source = Identity.by_actor_uri_with_create(payload["actor"]) - target = Identity.by_actor_uri(payload["object"]) - follow = Follow.maybe_get(source, target) - if follow: - follow.accepted = True - follow.save() diff --git a/users/views/identity.py b/users/views/identity.py index d02505f..3e69dae 100644 --- a/users/views/identity.py +++ b/users/views/identity.py @@ -17,7 +17,7 @@ from core.forms import FormHelper from core.ld import canonicalise from core.signatures import HttpSignature from users.decorators import identity_required -from users.models import Domain, Follow, Identity, IdentityStates +from users.models import Domain, Follow, Identity, IdentityStates, InboxMessage from users.shortcuts import by_handle_or_404 @@ -117,9 +117,13 @@ class CreateIdentity(FormView): def clean(self): # Check for existing users - username = self.cleaned_data["username"] - domain = self.cleaned_data["domain"] - if Identity.objects.filter(username=username, domain=domain).exists(): + username = self.cleaned_data.get("username") + domain = self.cleaned_data.get("domain") + if ( + username + and domain + and Identity.objects.filter(username=username, domain=domain).exists() + ): raise forms.ValidationError(f"{username}@{domain} is already taken") def get_form(self): @@ -219,7 +223,7 @@ class Inbox(View): ): return HttpResponseBadRequest("Bad signature") # Hand off the item to be processed by the queue - Task.submit("inbox_item", subject=identity.actor_uri, payload=document) + InboxMessage.objects.create(message=document) return HttpResponse(status=202)