diff --git a/app/incoming_activities.py b/app/incoming_activities.py index da1bf26..ba4b186 100644 --- a/app/incoming_activities.py +++ b/app/incoming_activities.py @@ -69,13 +69,11 @@ def _set_next_try( async def fetch_next_incoming_activity( db_session: AsyncSession, - in_flight: set[int], ) -> models.IncomingActivity | None: where = [ models.IncomingActivity.next_try <= now(), models.IncomingActivity.is_errored.is_(False), models.IncomingActivity.is_processed.is_(False), - models.IncomingActivity.id.not_in(in_flight), ] q_count = await db_session.scalar( select(func.count(models.IncomingActivity.id)).where(*where) @@ -144,11 +142,11 @@ class IncomingActivityWorker(Worker[models.IncomingActivity]): self, db_session: AsyncSession, ) -> models.IncomingActivity | None: - return await fetch_next_incoming_activity(db_session, self.in_flight_ids()) + return await fetch_next_incoming_activity(db_session) async def loop() -> None: - await IncomingActivityWorker(workers_count=1).run_forever() + await IncomingActivityWorker().run_forever() if __name__ == "__main__": diff --git a/app/outgoing_activities.py b/app/outgoing_activities.py index 9f08de6..31bbb96 100644 --- a/app/outgoing_activities.py +++ b/app/outgoing_activities.py @@ -170,13 +170,11 @@ def _set_next_try( async def fetch_next_outgoing_activity( db_session: AsyncSession, - in_fligh: set[int], ) -> models.OutgoingActivity | None: where = [ models.OutgoingActivity.next_try <= now(), models.OutgoingActivity.is_errored.is_(False), models.OutgoingActivity.is_sent.is_(False), - models.OutgoingActivity.id.not_in(in_fligh), ] q_count = await db_session.scalar( select(func.count(models.OutgoingActivity.id)).where(*where) @@ -289,14 +287,14 @@ class OutgoingActivityWorker(Worker[models.OutgoingActivity]): self, db_session: AsyncSession, ) -> models.OutgoingActivity | None: - return await fetch_next_outgoing_activity(db_session, self.in_flight_ids()) + return await fetch_next_outgoing_activity(db_session) async def startup(self, db_session: AsyncSession) -> None: await _send_actor_update_if_needed(db_session) async def loop() -> None: - await OutgoingActivityWorker(workers_count=3).run_forever() + await OutgoingActivityWorker().run_forever() if __name__ == "__main__": diff --git a/app/utils/workers.py b/app/utils/workers.py index a817834..25ed331 100644 --- a/app/utils/workers.py +++ b/app/utils/workers.py @@ -12,30 +12,9 @@ T = TypeVar("T") class Worker(Generic[T]): - def __init__(self, workers_count: int) -> None: + def __init__(self) -> None: self._loop = asyncio.get_event_loop() - self._in_flight: set[int] = set() - self._queue: asyncio.Queue[T] = asyncio.Queue(maxsize=1) self._stop_event = asyncio.Event() - self._workers_count = workers_count - - async def _consumer(self, db_session: AsyncSession) -> None: - while not self._stop_event.is_set(): - message = await self._queue.get() - try: - await self.process_message(db_session, message) - finally: - self._in_flight.remove(message.id) # type: ignore - self._queue.task_done() - - async def _producer(self, db_session: AsyncSession) -> None: - while not self._stop_event.is_set(): - next_message = await self.get_next_message(db_session) - if next_message: - self._in_flight.add(next_message.id) # type: ignore - await self._queue.put(next_message) - else: - await asyncio.sleep(1) async def process_message(self, db_session: AsyncSession, message: T) -> None: raise NotImplementedError @@ -46,8 +25,16 @@ class Worker(Generic[T]): async def startup(self, db_session: AsyncSession) -> None: return None - def in_flight_ids(self) -> set[int]: - return self._in_flight + async def _main_loop(self, db_session: AsyncSession) -> None: + while not self._stop_event.is_set(): + next_message = await self.get_next_message(db_session) + if next_message: + await self.process_message(db_session, next_message) + else: + await asyncio.sleep(1) + + async def _until_stopped(self) -> None: + await self._stop_event.wait() async def run_forever(self) -> None: signals = (signal.SIGHUP, signal.SIGTERM, signal.SIGINT) @@ -59,13 +46,14 @@ class Worker(Generic[T]): async with async_session() as db_session: await self.startup(db_session) - self._loop.create_task(self._producer(db_session)) - for _ in range(self._workers_count): - self._loop.create_task(self._consumer(db_session)) + task = self._loop.create_task(self._main_loop(db_session)) + stop_task = self._loop.create_task(self._until_stopped()) - await self._stop_event.wait() - logger.info("Waiting for tasks to finish") - await self._queue.join() + done, pending = await asyncio.wait( + {task, stop_task}, return_when=asyncio.FIRST_COMPLETED + ) + logger.info(f"Waiting for tasks to finish {done=}/{pending=}") + await asyncio.sleep(5) tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()] logger.info(f"Cancelling {len(tasks)} tasks") [task.cancel() for task in tasks] diff --git a/tests/test_inbox.py b/tests/test_inbox.py index 73aded9..1e83e95 100644 --- a/tests/test_inbox.py +++ b/tests/test_inbox.py @@ -24,7 +24,7 @@ from tests.utils import setup_remote_actor_as_follower async def _process_next_incoming_activity(db_session: AsyncSession) -> None: - next_activity = await fetch_next_incoming_activity(db_session, set()) + next_activity = await fetch_next_incoming_activity(db_session) assert next_activity await process_next_incoming_activity(db_session, next_activity) diff --git a/tests/test_process_outgoing_activities.py b/tests/test_process_outgoing_activities.py index 8b6e951..7da510e 100644 --- a/tests/test_process_outgoing_activities.py +++ b/tests/test_process_outgoing_activities.py @@ -70,7 +70,7 @@ async def test_process_next_outgoing_activity__no_next_activity( respx_mock: respx.MockRouter, async_db_session: AsyncSession, ) -> None: - next_activity = await fetch_next_outgoing_activity(async_db_session, set()) + next_activity = await fetch_next_outgoing_activity(async_db_session) assert next_activity is None @@ -94,7 +94,7 @@ async def test_process_next_outgoing_activity__server_200( # When processing the next outgoing activity # Then it is processed - next_activity = await fetch_next_outgoing_activity(async_db_session, set()) + next_activity = await fetch_next_outgoing_activity(async_db_session) assert next_activity await process_next_outgoing_activity(async_db_session, next_activity) @@ -129,7 +129,7 @@ async def test_process_next_outgoing_activity__webmention( # When processing the next outgoing activity # Then it is processed - next_activity = await fetch_next_outgoing_activity(async_db_session, set()) + next_activity = await fetch_next_outgoing_activity(async_db_session) assert next_activity await process_next_outgoing_activity(async_db_session, next_activity) @@ -165,7 +165,7 @@ async def test_process_next_outgoing_activity__error_500( # When processing the next outgoing activity # Then it is processed - next_activity = await fetch_next_outgoing_activity(async_db_session, set()) + next_activity = await fetch_next_outgoing_activity(async_db_session) assert next_activity await process_next_outgoing_activity(async_db_session, next_activity) @@ -203,7 +203,7 @@ async def test_process_next_outgoing_activity__errored( # When processing the next outgoing activity # Then it is processed - next_activity = await fetch_next_outgoing_activity(async_db_session, set()) + next_activity = await fetch_next_outgoing_activity(async_db_session) assert next_activity await process_next_outgoing_activity(async_db_session, next_activity) @@ -218,7 +218,7 @@ async def test_process_next_outgoing_activity__errored( assert outgoing_activity.is_errored is True # And it is skipped from processing - next_activity = await fetch_next_outgoing_activity(async_db_session, set()) + next_activity = await fetch_next_outgoing_activity(async_db_session) assert next_activity is None @@ -241,7 +241,7 @@ async def test_process_next_outgoing_activity__connect_error( # When processing the next outgoing activity # Then it is processed - next_activity = await fetch_next_outgoing_activity(async_db_session, set()) + next_activity = await fetch_next_outgoing_activity(async_db_session) assert next_activity await process_next_outgoing_activity(async_db_session, next_activity)