From b0cf416c365cedb720d97d9ce378f637a657bc51 Mon Sep 17 00:00:00 2001 From: Andrew Mirsky Date: Sat, 14 Jun 2025 21:28:00 -0400 Subject: [PATCH] using a strenum instead of enum. creating events as well as client/broker events as well --- amqtt/__init__.py | 25 ++++++++++++++++++++++++- amqtt/broker.py | 33 ++++++++++----------------------- amqtt/plugins/manager.py | 6 +++--- tests/test_broker.py | 27 ++++++++++++++------------- tests/test_paho.py | 6 +++--- 5 files changed, 54 insertions(+), 43 deletions(-) diff --git a/amqtt/__init__.py b/amqtt/__init__.py index 18a5553..0ae983c 100644 --- a/amqtt/__init__.py +++ b/amqtt/__init__.py @@ -1,3 +1,26 @@ """INIT.""" -__version__ = "0.11.0-rc" +__version__ = "0.11.0" + +from enum import StrEnum + + +class Events(StrEnum): + """Class for all events.""" + +class ClientEvents(Events): + """Events issued by the client.""" + + +class BrokerEvents(Events): + """Events issued by the broker.""" + + PRE_START = "broker_pre_start" + POST_START = "broker_post_start" + PRE_SHUTDOWN = "broker_pre_shutdown" + POST_SHUTDOWN = "broker_post_shutdown" + CLIENT_CONNECTED = "broker_client_connected" + CLIENT_DISCONNECTED = "broker_client_disconnected" + CLIENT_SUBSCRIBED = "broker_client_subscribed" + CLIENT_UNSUBSCRIBED = "broker_client_unsubscribed" + MESSAGE_RECEIVED = "broker_message_received" diff --git a/amqtt/broker.py b/amqtt/broker.py index b8e2fd4..b06c310 100644 --- a/amqtt/broker.py +++ b/amqtt/broker.py @@ -28,6 +28,7 @@ from amqtt.mqtt.protocol.broker_handler import BrokerProtocolHandler from amqtt.session import ApplicationMessage, OutgoingApplicationMessage, Session from amqtt.utils import format_client_message, gen_client_id, read_yaml_config +from . import BrokerEvents from .mqtt.disconnect import DisconnectPacket from .plugins.manager import BaseContext, PluginManager @@ -43,20 +44,6 @@ DEFAULT_PORTS = {"tcp": 1883, "ws": 8883} AMQTT_MAGIC_VALUE_RET_SUBSCRIBED = 0x80 -class EventBroker(Enum): - """Events issued by the broker.""" - - PRE_START = "broker_pre_start" - POST_START = "broker_post_start" - PRE_SHUTDOWN = "broker_pre_shutdown" - POST_SHUTDOWN = "broker_post_shutdown" - CLIENT_CONNECTED = "broker_client_connected" - CLIENT_DISCONNECTED = "broker_client_disconnected" - CLIENT_SUBSCRIBED = "broker_client_subscribed" - CLIENT_UNSUBSCRIBED = "broker_client_unsubscribed" - MESSAGE_RECEIVED = "broker_message_received" - - class Action(Enum): """Actions issued by the broker.""" @@ -252,11 +239,11 @@ class Broker: msg = f"Broker instance can't be started: {exc}" raise BrokerError(msg) from exc - await self.plugins_manager.fire_event(EventBroker.PRE_START.value) + await self.plugins_manager.fire_event(BrokerEvents.PRE_START) try: await self._start_listeners() self.transitions.starting_success() - await self.plugins_manager.fire_event(EventBroker.POST_START.value) + await self.plugins_manager.fire_event(BrokerEvents.POST_START) self._broadcast_task = asyncio.ensure_future(self._broadcast_loop()) self.logger.debug("Broker started") except Exception as e: @@ -337,7 +324,7 @@ class Broker: """Stop broker instance.""" self.logger.info("Shutting down broker...") # Fire broker_shutdown event to plugins - await self.plugins_manager.fire_event(EventBroker.PRE_SHUTDOWN.value) + await self.plugins_manager.fire_event(BrokerEvents.PRE_SHUTDOWN) # Cleanup all sessions for client_id in list(self._sessions.keys()): @@ -361,7 +348,7 @@ class Broker: self._broadcast_queue.get_nowait() self.logger.info("Broker closed") - await self.plugins_manager.fire_event(EventBroker.POST_SHUTDOWN.value) + await self.plugins_manager.fire_event(BrokerEvents.POST_SHUTDOWN) self.transitions.stopping_success() async def _cleanup_session(self, client_id: str) -> None: @@ -504,7 +491,7 @@ class Broker: self._sessions[client_session.client_id] = (client_session, handler) await handler.mqtt_connack_authorize(authenticated) - await self.plugins_manager.fire_event(EventBroker.CLIENT_CONNECTED.value, client_id=client_session.client_id) + await self.plugins_manager.fire_event(BrokerEvents.CLIENT_CONNECTED, client_id=client_session.client_id) self.logger.debug(f"{client_session.client_id} Start messages handling") await handler.start() @@ -607,7 +594,7 @@ class Broker: self.logger.debug(f"{client_session.client_id} Disconnecting session") await self._stop_handler(handler) client_session.transitions.disconnect() - await self.plugins_manager.fire_event(EventBroker.CLIENT_DISCONNECTED.value, client_id=client_session.client_id) + await self.plugins_manager.fire_event(BrokerEvents.CLIENT_DISCONNECTED, client_id=client_session.client_id) async def _handle_subscription( @@ -624,7 +611,7 @@ class Broker: for index, subscription in enumerate(subscriptions.topics): if return_codes[index] != AMQTT_MAGIC_VALUE_RET_SUBSCRIBED: await self.plugins_manager.fire_event( - EventBroker.CLIENT_SUBSCRIBED.value, + BrokerEvents.CLIENT_SUBSCRIBED, client_id=client_session.client_id, topic=subscription[0], qos=subscription[1], @@ -643,7 +630,7 @@ class Broker: for topic in unsubscription.topics: self._del_subscription(topic, client_session) await self.plugins_manager.fire_event( - EventBroker.CLIENT_UNSUBSCRIBED.value, + BrokerEvents.CLIENT_UNSUBSCRIBED, client_id=client_session.client_id, topic=topic, ) @@ -678,7 +665,7 @@ class Broker: self.logger.info(f"{client_session.client_id} forbidden TOPIC {app_message.topic} sent in PUBLISH message.") else: await self.plugins_manager.fire_event( - EventBroker.MESSAGE_RECEIVED.value, + BrokerEvents.MESSAGE_RECEIVED, client_id=client_session.client_id, message=app_message, ) diff --git a/amqtt/plugins/manager.py b/amqtt/plugins/manager.py index 39e9bb0..9bb4d36 100644 --- a/amqtt/plugins/manager.py +++ b/amqtt/plugins/manager.py @@ -8,9 +8,9 @@ from importlib.metadata import EntryPoint, EntryPoints, entry_points import logging from typing import TYPE_CHECKING, Any, Generic, NamedTuple, Optional, TypeVar -from amqtt.session import Session - +from amqtt import Events from amqtt.errors import PluginImportError, PluginInitError +from amqtt.session import Session _LOGGER = logging.getLogger(__name__) @@ -153,7 +153,7 @@ class PluginManager(Generic[C]): def _schedule_coro(self, coro: Awaitable[str | bool | None]) -> asyncio.Future[str | bool | None]: return asyncio.ensure_future(coro) - async def fire_event(self, event_name: str, *args: Any, wait: bool = False, **kwargs: Any) -> None: + async def fire_event(self, event_name: Events, *args: Any, wait: bool = False, **kwargs: Any) -> None: """Fire an event to plugins. PluginManager schedules async calls for each plugin on method called "on_" + event_name. diff --git a/tests/test_broker.py b/tests/test_broker.py index 3591112..140e005 100644 --- a/tests/test_broker.py +++ b/tests/test_broker.py @@ -6,8 +6,9 @@ from unittest.mock import MagicMock, call, patch import psutil import pytest +from amqtt import BrokerEvents from amqtt.adapters import StreamReaderAdapter, StreamWriterAdapter -from amqtt.broker import EventBroker, Broker +from amqtt.broker import Broker from amqtt.client import MQTTClient from amqtt.errors import ConnectError from amqtt.mqtt.connack import ConnackPacket @@ -56,8 +57,8 @@ def test_split_bindaddr_port(input_str, output_addr, output_port): async def test_start_stop(broker, mock_plugin_manager): mock_plugin_manager.assert_has_calls( [ - call().fire_event(EventBroker.PRE_START.value), - call().fire_event(EventBroker.POST_START.value), + call().fire_event(BrokerEvents.PRE_START), + call().fire_event(BrokerEvents.POST_START), ], any_order=True, ) @@ -65,8 +66,8 @@ async def test_start_stop(broker, mock_plugin_manager): await broker.shutdown() mock_plugin_manager.assert_has_calls( [ - call().fire_event(EventBroker.PRE_SHUTDOWN.value), - call().fire_event(EventBroker.POST_SHUTDOWN.value), + call().fire_event(BrokerEvents.PRE_SHUTDOWN), + call().fire_event(BrokerEvents.POST_SHUTDOWN), ], any_order=True, ) @@ -87,11 +88,11 @@ async def test_client_connect(broker, mock_plugin_manager): mock_plugin_manager.assert_has_calls( [ call().fire_event( - EventBroker.CLIENT_CONNECTED.value, + BrokerEvents.CLIENT_CONNECTED, client_id=client.session.client_id, ), call().fire_event( - EventBroker.CLIENT_DISCONNECTED.value, + BrokerEvents.CLIENT_DISCONNECTED, client_id=client.session.client_id, ), ], @@ -224,7 +225,7 @@ async def test_client_subscribe(broker, mock_plugin_manager): mock_plugin_manager.assert_has_calls( [ call().fire_event( - EventBroker.CLIENT_SUBSCRIBED.value, + BrokerEvents.CLIENT_SUBSCRIBED, client_id=client.session.client_id, topic="/topic", qos=QOS_0, @@ -261,7 +262,7 @@ async def test_client_subscribe_twice(broker, mock_plugin_manager): mock_plugin_manager.assert_has_calls( [ call().fire_event( - EventBroker.CLIENT_SUBSCRIBED.value, + BrokerEvents.CLIENT_SUBSCRIBED, client_id=client.session.client_id, topic="/topic", qos=QOS_0, @@ -295,13 +296,13 @@ async def test_client_unsubscribe(broker, mock_plugin_manager): mock_plugin_manager.assert_has_calls( [ call().fire_event( - EventBroker.CLIENT_SUBSCRIBED.value, + BrokerEvents.CLIENT_SUBSCRIBED, client_id=client.session.client_id, topic="/topic", qos=QOS_0, ), call().fire_event( - EventBroker.CLIENT_UNSUBSCRIBED.value, + BrokerEvents.CLIENT_UNSUBSCRIBED, client_id=client.session.client_id, topic="/topic", ), @@ -326,7 +327,7 @@ async def test_client_publish(broker, mock_plugin_manager): mock_plugin_manager.assert_has_calls( [ call().fire_event( - EventBroker.MESSAGE_RECEIVED.value, + BrokerEvents.MESSAGE_RECEIVED, client_id=pub_client.session.client_id, message=ret_message, ), @@ -498,7 +499,7 @@ async def test_client_publish_big(broker, mock_plugin_manager): mock_plugin_manager.assert_has_calls( [ call().fire_event( - EventBroker.MESSAGE_RECEIVED.value, + BrokerEvents.MESSAGE_RECEIVED, client_id=pub_client.session.client_id, message=ret_message, ), diff --git a/tests/test_paho.py b/tests/test_paho.py index 2ba8936..09d78db 100644 --- a/tests/test_paho.py +++ b/tests/test_paho.py @@ -6,7 +6,7 @@ from unittest.mock import MagicMock, call, patch import pytest from paho.mqtt import client as mqtt_client -from amqtt.broker import EventBroker +from amqtt.broker import BrokerEvents from amqtt.client import MQTTClient from amqtt.mqtt.constants import QOS_1, QOS_2 @@ -53,11 +53,11 @@ async def test_paho_connect(broker, mock_plugin_manager): broker.plugins_manager.assert_has_calls( [ call.fire_event( - EventBroker.CLIENT_CONNECTED.value, + BrokerEvents.CLIENT_CONNECTED, client_id=client_id, ), call.fire_event( - EventBroker.CLIENT_DISCONNECTED.value, + BrokerEvents.CLIENT_DISCONNECTED, client_id=client_id, ), ],