Most of the way through the stator refactor

pull/2/head
Andrew Godwin 2022-11-09 22:29:33 -07:00
rodzic 61c324508e
commit 7746abbbb7
16 zmienionych plików z 277 dodań i 152 usunięć

Wyświetl plik

@ -1,8 +1,17 @@
from django.contrib import admin
from stator.models import StatorTask
from stator.models import StatorError
@admin.register(StatorTask)
@admin.register(StatorError)
class DomainAdmin(admin.ModelAdmin):
list_display = ["id", "model_label", "instance_pk", "locked_until"]
list_display = [
"id",
"date",
"model_label",
"instance_pk",
"from_state",
"to_state",
"error",
]
ordering = ["-date"]

Wyświetl plik

@ -1,9 +1,16 @@
import datetime
from functools import wraps
from typing import Callable, ClassVar, Dict, List, Optional, Set, Tuple, Union
from django.db import models
from django.utils import timezone
from typing import (
Any,
Callable,
ClassVar,
Dict,
List,
Optional,
Set,
Tuple,
Type,
Union,
cast,
)
class StateGraph:
@ -13,7 +20,7 @@ class StateGraph:
"""
states: ClassVar[Dict[str, "State"]]
choices: ClassVar[List[Tuple[str, str]]]
choices: ClassVar[List[Tuple[object, str]]]
initial_state: ClassVar["State"]
terminal_states: ClassVar[Set["State"]]
@ -50,7 +57,7 @@ class StateGraph:
cls.initial_state = initial_state
cls.terminal_states = terminal_states
# Generate choices
cls.choices = [(name, name) for name in cls.states.keys()]
cls.choices = [(state, name) for name, state in cls.states.items()]
class State:
@ -63,7 +70,7 @@ class State:
self.parents: Set["State"] = set()
self.children: Dict["State", "Transition"] = {}
def _add_to_graph(self, graph: StateGraph, name: str):
def _add_to_graph(self, graph: Type[StateGraph], name: str):
self.graph = graph
self.name = name
self.graph.states[name] = self
@ -71,13 +78,19 @@ class State:
def __repr__(self):
return f"<State {self.name}>"
def __str__(self):
return self.name
def __len__(self):
return len(self.name)
def add_transition(
self,
other: "State",
handler: Optional[Union[str, Callable]] = None,
handler: Optional[Callable] = None,
priority: int = 0,
) -> Callable:
def decorator(handler: Union[str, Callable]):
def decorator(handler: Callable[[Any], bool]):
self.children[other] = Transition(
self,
other,
@ -85,9 +98,7 @@ class State:
priority=priority,
)
other.parents.add(self)
# All handlers should be class methods, so do that automatically.
if callable(handler):
return classmethod(handler)
return handler
# If we're not being called as a decorator, invoke it immediately
if handler is not None:
@ -113,7 +124,7 @@ class State:
if automatic_only:
transitions = [t for t in self.children.values() if t.automatic]
else:
transitions = self.children.values()
transitions = list(self.children.values())
return sorted(transitions, key=lambda t: t.priority, reverse=True)
@ -141,7 +152,10 @@ class Transition:
"""
if isinstance(self.handler, str):
self.handler = getattr(self.from_state.graph, self.handler)
return self.handler
return cast(Callable, self.handler)
def __repr__(self):
return f"<Transition {self.from_state} -> {self.to_state}>"
class ManualTransition(Transition):
@ -157,6 +171,5 @@ class ManualTransition(Transition):
):
self.from_state = from_state
self.to_state = to_state
self.handler = None
self.priority = 0
self.automatic = False

Wyświetl plik

@ -0,0 +1,28 @@
from typing import List, Type, cast
from asgiref.sync import async_to_sync
from django.apps import apps
from django.core.management.base import BaseCommand
from stator.models import StatorModel
from stator.runner import StatorRunner
class Command(BaseCommand):
    """
    Management command that starts one Stator runner and waits for it to
    finish.
    """

    help = "Runs a Stator runner for a short period"

    def add_arguments(self, parser):
        # Zero or more model labels, in whatever form apps.get_model accepts.
        parser.add_argument("model_labels", nargs="*", type=str)

    def handle(self, model_labels: List[str], *args, **options):
        # Turn the supplied labels into concrete model classes
        resolved = [apps.get_model(label) for label in model_labels]
        models = cast(List[Type[StatorModel]], resolved)
        # No labels given: fall back to every registered StatorModel subclass
        if not models:
            models = StatorModel.subclasses
        labels = " ".join(model._meta.label_lower for model in models)
        print("Running for models: " + labels)
        # Drive the async runner to completion from this synchronous command
        runner = StatorRunner(models)
        async_to_sync(runner.run)()

Wyświetl plik

@ -1,4 +1,4 @@
# Generated by Django 4.1.3 on 2022-11-09 05:46
# Generated by Django 4.1.3 on 2022-11-10 03:24
from django.db import migrations, models
@ -11,7 +11,7 @@ class Migration(migrations.Migration):
operations = [
migrations.CreateModel(
name="StatorTask",
name="StatorError",
fields=[
(
"id",
@ -24,8 +24,11 @@ class Migration(migrations.Migration):
),
("model_label", models.CharField(max_length=200)),
("instance_pk", models.CharField(max_length=200)),
("locked_until", models.DateTimeField(blank=True, null=True)),
("priority", models.IntegerField(default=0)),
("from_state", models.CharField(max_length=200)),
("to_state", models.CharField(max_length=200)),
("date", models.DateTimeField(auto_now_add=True)),
("error", models.TextField()),
("error_details", models.TextField(blank=True, null=True)),
],
),
]

Wyświetl plik

@ -1,14 +1,13 @@
import datetime
from functools import reduce
from typing import Type, cast
import traceback
from typing import ClassVar, List, Optional, Type, cast
from asgiref.sync import sync_to_async
from django.apps import apps
from django.db import models, transaction
from django.utils import timezone
from django.utils.functional import classproperty
from stator.graph import State, StateGraph
from stator.graph import State, StateGraph, Transition
class StateField(models.CharField):
@ -55,6 +54,9 @@ class StatorModel(models.Model):
concrete model yourself.
"""
# If this row is up for transition attempts
state_ready = models.BooleanField(default=False)
# When the state last actually changed, or the date of instance creation
state_changed = models.DateTimeField(auto_now_add=True)
@ -62,68 +64,128 @@ class StatorModel(models.Model):
# (and not successful, as this is cleared on transition)
state_attempted = models.DateTimeField(blank=True, null=True)
# If a lock is held on this row, the time it is locked until
# (we don't identify the lock owner, as there are no heartbeats)
state_locked_until = models.DateTimeField(null=True, blank=True)
# Collection of subclasses of us
subclasses: ClassVar[List[Type["StatorModel"]]] = []
class Meta:
abstract = True
@classmethod
def schedule_overdue(cls, now=None) -> models.QuerySet:
"""
Finds instances of this model that need to run and schedule them.
"""
q = models.Q()
for transition in cls.state_graph.transitions(automatic_only=True):
q = q | transition.get_query(now=now)
return cls.objects.filter(q)
def __init_subclass__(cls) -> None:
if cls is not StatorModel:
cls.subclasses.append(cls)
@classproperty
def state_graph(cls) -> Type[StateGraph]:
return cls._meta.get_field("state").graph
def schedule_transition(self, priority: int = 0):
@classmethod
async def atransition_schedule_due(cls, now=None) -> None:
    """
    Finds instances of this model that are due a transition attempt and
    flags them as ready (`state_ready=True`).

    A row is due when it is in a non-terminal state and has either never
    been attempted or was last attempted at least that state's
    `try_interval` seconds before `now`.

    `now` is the reference time for the comparison; it defaults to the
    current time when omitted.
    """
    if now is None:
        now = timezone.now()
    due = models.Q()
    found_candidate_state = False
    for state in cls.state_graph.states.values():
        state = cast(State, state)
        if state.terminal:
            continue
        found_candidate_state = True
        retry_cutoff = now - datetime.timedelta(seconds=state.try_interval)
        due = due | models.Q(
            models.Q(state_attempted__lte=retry_cutoff)
            | models.Q(state_attempted__isnull=True),
            state=state.name,
        )
    # An empty Q matches every row, so with no non-terminal states there is
    # nothing to schedule and we must not touch the table at all.
    if not found_candidate_state:
        return
    await cls.objects.filter(due).aupdate(state_ready=True)
@classmethod
def transition_get_with_lock(
    cls, number: int, lock_expiry: datetime.datetime
) -> List["StatorModel"]:
    """
    Returns up to `number` instances for execution, having locked them.

    Only rows flagged `state_ready` with no lock set are picked. The rows
    are selected FOR UPDATE and stamped with the lock inside a single
    transaction.

    `lock_expiry` is the time the lock is held until. It has to be stored
    as given: lock cleanup releases any lock whose time is not in the
    future, so stamping the current time would make every new lock count
    as expired straight away.
    """
    with transaction.atomic():
        selected = list(
            cls.objects.filter(state_locked_until__isnull=True, state_ready=True)[
                :number
            ].select_for_update()
        )
        cls.objects.filter(pk__in=[i.pk for i in selected]).update(
            state_locked_until=lock_expiry
        )
    return selected
@classmethod
async def atransition_get_with_lock(
    cls, number: int, lock_expiry: datetime.datetime
) -> List["StatorModel"]:
    """
    Async wrapper around `transition_get_with_lock`, taking the same
    arguments and returning the same list.
    """
    locked_fetch = sync_to_async(cls.transition_get_with_lock)
    return await locked_fetch(number, lock_expiry)
@classmethod
async def atransition_clean_locks(cls):
    """
    Clears the lock on every row of this model whose `state_locked_until`
    is at or before the current time.
    """
    expired = cls.objects.filter(state_locked_until__lte=timezone.now())
    await expired.aupdate(state_locked_until=None)
def transition_schedule(self):
"""
Adds this instance to the queue to get its state transition attempted.
The scheduler will call this, but you can also call it directly if you
know it'll be ready and want to lower latency.
"""
StatorTask.schedule_for_execution(self, priority=priority)
self.state_ready = True
self.save()
async def attempt_transition(self):
async def atransition_attempt(self) -> bool:
"""
Attempts to transition the current state by running its handler(s).
"""
# Try each transition in priority order
for transition in self.state_graph.states[self.state].transitions(
automatic_only=True
):
success = await transition.get_handler()(self)
for transition in self.state.transitions(automatic_only=True):
try:
success = await transition.get_handler()(self)
except BaseException as e:
await StatorError.acreate_from_instance(self, transition, e)
traceback.print_exc()
continue
if success:
await self.perform_transition(transition.to_state.name)
return
await self.atransition_perform(transition.to_state.name)
return True
await self.__class__.objects.filter(pk=self.pk).aupdate(
state_attempted=timezone.now()
state_attempted=timezone.now(),
state_locked_until=None,
state_ready=False,
)
return False
async def perform_transition(self, state_name):
def transition_perform(self, state_name):
"""
Transitions the instance to the given state name
Transitions the instance to the given state name, forcibly.
"""
if state_name not in self.state_graph.states:
raise ValueError(f"Invalid state {state_name}")
await self.__class__.objects.filter(pk=self.pk).aupdate(
self.__class__.objects.filter(pk=self.pk).update(
state=state_name,
state_changed=timezone.now(),
state_attempted=None,
state_locked_until=None,
state_ready=False,
)
atransition_perform = sync_to_async(transition_perform)
class StatorTask(models.Model):
class StatorError(models.Model):
"""
The model that we use for an internal scheduling queue.
Entries in this queue are up for checking and execution - it also performs
locking to ensure we get closer to exactly-once execution (but we err on
the side of at-least-once)
Tracks any errors running the transitions.
Meant to be cleaned out regularly. Should probably be a log.
"""
# appname.modelname (lowercased) label for the model this represents
@ -132,60 +194,33 @@ class StatorTask(models.Model):
# The primary key of that model (probably int or str)
instance_pk = models.CharField(max_length=200)
# Locking columns (no runner ID, as we have no heartbeats - all runners
# only live for a short amount of time anyway)
locked_until = models.DateTimeField(null=True, blank=True)
# The state we moved from
from_state = models.CharField(max_length=200)
# Basic total ordering priority - higher is more important
priority = models.IntegerField(default=0)
# The state we moved to (or tried to)
to_state = models.CharField(max_length=200)
def __str__(self):
return f"#{self.pk}: {self.model_label}.{self.instance_pk}"
# When it happened
date = models.DateTimeField(auto_now_add=True)
# Error name
error = models.TextField()
# Error details
error_details = models.TextField(blank=True, null=True)
@classmethod
def schedule_for_execution(cls, model_instance: StatorModel, priority: int = 0):
# We don't do a transaction here as it's fine to occasionally double up
model_label = model_instance._meta.label_lower
pk = model_instance.pk
# TODO: Increase priority of existing if present
if not cls.objects.filter(
model_label=model_label, instance_pk=pk, locked__isnull=True
).exists():
StatorTask.objects.create(
model_label=model_label,
instance_pk=pk,
priority=priority,
)
@classmethod
def get_for_execution(cls, number: int, lock_expiry: datetime.datetime):
"""
Returns up to `number` tasks for execution, having locked them.
"""
with transaction.atomic():
selected = list(
cls.objects.filter(locked_until__isnull=True)[
:number
].select_for_update()
)
cls.objects.filter(pk__in=[i.pk for i in selected]).update(
locked_until=timezone.now()
)
return selected
@classmethod
async def aget_for_execution(cls, number: int, lock_expiry: datetime.datetime):
return await sync_to_async(cls.get_for_execution)(number, lock_expiry)
@classmethod
async def aclean_old_locks(cls):
await cls.objects.filter(locked_until__lte=timezone.now()).aupdate(
locked_until=None
async def acreate_from_instance(
cls,
instance: StatorModel,
transition: Transition,
exception: Optional[BaseException] = None,
):
return await cls.objects.acreate(
model_label=instance._meta.label_lower,
instance_pk=str(instance.pk),
from_state=transition.from_state,
to_state=transition.to_state,
error=str(exception),
error_details=traceback.format_exc(),
)
async def aget_model_instance(self) -> StatorModel:
model = apps.get_model(self.model_label)
return cast(StatorModel, await model.objects.aget(pk=self.pk))
async def adelete(self):
self.__class__.objects.adelete(pk=self.pk)

Wyświetl plik

@ -4,11 +4,9 @@ import time
import uuid
from typing import List, Type
from asgiref.sync import sync_to_async
from django.db import transaction
from django.utils import timezone
from stator.models import StatorModel, StatorTask
from stator.models import StatorModel
class StatorRunner:
@ -22,6 +20,7 @@ class StatorRunner:
LOCK_TIMEOUT = 120
MAX_TASKS = 30
MAX_TASKS_PER_MODEL = 5
def __init__(self, models: List[Type[StatorModel]]):
self.models = models
@ -32,38 +31,44 @@ class StatorRunner:
self.handled = 0
self.tasks = []
# Clean up old locks
await StatorTask.aclean_old_locks()
# Examine what needs scheduling
print("Running initial cleaning and scheduling")
initial_tasks = []
for model in self.models:
initial_tasks.append(model.atransition_clean_locks())
initial_tasks.append(model.atransition_schedule_due())
await asyncio.gather(*initial_tasks)
# For the first time period, launch tasks
print("Running main task loop")
while (time.monotonic() - start_time) < self.START_TIMEOUT:
self.remove_completed_tasks()
space_remaining = self.MAX_TASKS - len(self.tasks)
# Fetch new tasks
if space_remaining > 0:
for new_task in await StatorTask.aget_for_execution(
space_remaining,
timezone.now() + datetime.timedelta(seconds=self.LOCK_TIMEOUT),
):
self.tasks.append(asyncio.create_task(self.run_task(new_task)))
self.handled += 1
for model in self.models:
if space_remaining > 0:
for instance in await model.atransition_get_with_lock(
min(space_remaining, self.MAX_TASKS_PER_MODEL),
timezone.now() + datetime.timedelta(seconds=self.LOCK_TIMEOUT),
):
print(
f"Attempting transition on {instance._meta.label_lower}#{instance.pk}"
)
self.tasks.append(
asyncio.create_task(instance.atransition_attempt())
)
self.handled += 1
space_remaining -= 1
# Prevent busylooping
await asyncio.sleep(0.01)
await asyncio.sleep(0.1)
# Then wait for tasks to finish
print("Waiting for tasks to complete")
while (time.monotonic() - start_time) < self.TOTAL_TIMEOUT:
self.remove_completed_tasks()
if not self.tasks:
break
# Prevent busylooping
await asyncio.sleep(1)
print("Complete")
return self.handled
async def run_task(self, task: StatorTask):
# Resolve the model instance
model_instance = await task.aget_model_instance()
await model_instance.attempt_transition()
# Remove ourselves from the database as complete
await task.adelete()
def remove_completed_tasks(self):
    """
    Rebuilds `self.tasks`, keeping only the tasks whose `done()` is false.
    """
    unfinished = []
    for task in self.tasks:
        if not task.done():
            unfinished.append(task)
    self.tasks = unfinished

Wyświetl plik

@ -51,14 +51,14 @@ def test_bad_declarations():
# More than one initial state
with pytest.raises(ValueError):
class TestGraph(StateGraph):
class TestGraph2(StateGraph):
initial = State()
initial2 = State()
# No initial states
with pytest.raises(ValueError):
class TestGraph(StateGraph):
class TestGraph3(StateGraph):
loop = State()
loop2 = State()

Wyświetl plik

@ -0,0 +1,23 @@
# Generated by Django 4.1.3 on 2022-11-10 03:24
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("users", "0004_remove_follow_state_locked_and_more"),
]
operations = [
migrations.AddField(
model_name="follow",
name="state_locked_until",
field=models.DateTimeField(blank=True, null=True),
),
migrations.AddField(
model_name="follow",
name="state_ready",
field=models.BooleanField(default=False),
),
]

Wyświetl plik

@ -1,6 +1,6 @@
from .block import Block # noqa
from .domain import Domain # noqa
from .follow import Follow # noqa
from .identity import Identity # noqa
from .follow import Follow, FollowStates # noqa
from .identity import Identity, IdentityStates # noqa
from .user import User # noqa
from .user_event import UserEvent # noqa

Wyświetl plik

@ -55,7 +55,7 @@ class Domain(models.Model):
return cls.objects.create(domain=domain, local=False)
@classmethod
def get_local_domain(cls, domain: str) -> Optional["Domain"]:
def get_domain(cls, domain: str) -> Optional["Domain"]:
try:
return cls.objects.get(
models.Q(domain=domain) | models.Q(service_domain=domain)

Wyświetl plik

@ -6,13 +6,13 @@ from stator.models import State, StateField, StateGraph, StatorModel
class FollowStates(StateGraph):
pending = State(try_interval=3600)
pending = State(try_interval=30)
requested = State()
accepted = State()
@pending.add_transition(requested)
async def try_request(cls, instance):
print("Would have tried to follow")
async def try_request(instance: "Follow"): # type:ignore
print("Would have tried to follow on", instance)
return False
requested.add_manual_transition(accepted)
@ -73,11 +73,3 @@ class Follow(StatorModel):
follow.state = FollowStates.accepted
follow.save()
return follow
def undo(self):
"""
Undoes this follow
"""
if not self.target.local:
Task.submit("follow_undo", str(self.pk))
self.delete()

Wyświetl plik

@ -14,9 +14,21 @@ from django.utils import timezone
from OpenSSL import crypto
from core.ld import canonicalise
from stator.models import State, StateField, StateGraph, StatorModel
from users.models.domain import Domain
class IdentityStates(StateGraph):
    """
    State graph for identities: `outdated` moves to `updated` once the
    fetch handler below reports success.
    """

    outdated = State(try_interval=3600)
    updated = State()

    @outdated.add_transition(updated)
    async def fetch_identity(identity: "Identity"):  # type:ignore
        # Only non-local identities are fetched; the outcome of the fetch
        # is passed straight back as the handler result.
        if not identity.local:
            return await identity.fetch_actor()
        # Local identities count as updated without any fetch.
        return True
def upload_namer(prefix, instance, filename):
"""
Names uploaded images etc.
@ -26,7 +38,7 @@ def upload_namer(prefix, instance, filename):
return f"{prefix}/{now.year}/{now.month}/{now.day}/{filename}"
class Identity(models.Model):
class Identity(StatorModel):
"""
Represents both local and remote Fediverse identities (actors)
"""
@ -35,6 +47,8 @@ class Identity(models.Model):
# one around as well for making nice URLs etc.
actor_uri = models.CharField(max_length=500, unique=True)
state = StateField(IdentityStates)
local = models.BooleanField()
users = models.ManyToManyField("users.User", related_name="identities")

Wyświetl plik

@ -3,7 +3,7 @@ from django.http import Http404
from users.models import Domain, Identity
def by_handle_or_404(request, handle, local=True, fetch=False):
def by_handle_or_404(request, handle, local=True, fetch=False) -> Identity:
"""
Retrieves an Identity by its long or short handle.
Domain-sensitive, so it will understand short handles on alternate domains.
@ -12,14 +12,17 @@ def by_handle_or_404(request, handle, local=True, fetch=False):
if "HTTP_HOST" not in request.META:
raise Http404("No hostname available")
username = handle
domain_instance = Domain.get_local_domain(request.META["HTTP_HOST"])
domain_instance = Domain.get_domain(request.META["HTTP_HOST"])
if domain_instance is None:
raise Http404("No matching domains found")
domain = domain_instance.domain
else:
username, domain = handle.split("@", 1)
# Resolve the domain to the display domain
domain = Domain.get_local_domain(request.META["HTTP_HOST"]).domain
domain_instance = Domain.get_domain(domain)
if domain_instance is None:
raise Http404("No matching domains found")
domain = domain_instance.domain
identity = Identity.by_username_and_domain(
username,
domain,

Wyświetl plik

@ -17,7 +17,7 @@ from core.forms import FormHelper
from core.ld import canonicalise
from core.signatures import HttpSignature
from users.decorators import identity_required
from users.models import Domain, Follow, Identity
from users.models import Domain, Follow, Identity, IdentityStates
from users.shortcuts import by_handle_or_404
@ -34,7 +34,7 @@ class ViewIdentity(TemplateView):
)
statuses = identity.statuses.all()[:100]
if identity.data_age > settings.IDENTITY_MAX_AGE:
Task.submit("identity_fetch", identity.handle)
identity.transition_perform(IdentityStates.outdated)
return {
"identity": identity,
"statuses": statuses,
@ -129,7 +129,7 @@ class CreateIdentity(FormView):
def form_valid(self, form):
username = form.cleaned_data["username"]
domain = form.cleaned_data["domain"]
domain_instance = Domain.get_local_domain(domain)
domain_instance = Domain.get_domain(domain)
new_identity = Identity.objects.create(
actor_uri=f"https://{domain_instance.uri_domain}/@{username}@{domain}/actor/",
username=username,