From 5bd7513e8032c51d6d6fc258242b092807f794d9 Mon Sep 17 00:00:00 2001 From: Andrew Mirsky Date: Thu, 12 Jun 2025 19:52:16 -0400 Subject: [PATCH] updating type-hinting to set generics for base, client and broker contexts --- amqtt/client.py | 2 +- amqtt/mqtt/protocol/broker_handler.py | 9 ++++--- amqtt/mqtt/protocol/client_handler.py | 8 +++--- amqtt/mqtt/protocol/handler.py | 11 ++++---- amqtt/plugins/logging_amqtt.py | 5 ++-- amqtt/plugins/manager.py | 39 ++++++++++++++------------- 6 files changed, 42 insertions(+), 32 deletions(-) diff --git a/amqtt/client.py b/amqtt/client.py index 8da9aba..f3250ca 100644 --- a/amqtt/client.py +++ b/amqtt/client.py @@ -107,7 +107,7 @@ class MQTTClient: # Init plugins manager context = ClientContext() context.config = self.config - self.plugins_manager = PluginManager("amqtt.client.plugins", context) + self.plugins_manager: PluginManager[ClientContext] = PluginManager("amqtt.client.plugins", context) self.client_tasks: deque[asyncio.Task[Any]] = deque() async def connect( diff --git a/amqtt/mqtt/protocol/broker_handler.py b/amqtt/mqtt/protocol/broker_handler.py index c23211e..c8cae86 100644 --- a/amqtt/mqtt/protocol/broker_handler.py +++ b/amqtt/mqtt/protocol/broker_handler.py @@ -1,5 +1,6 @@ import asyncio from asyncio import AbstractEventLoop, Queue +from typing import TYPE_CHECKING from amqtt.adapters import ReaderAdapter, WriterAdapter from amqtt.errors import MQTTError @@ -28,6 +29,8 @@ from .handler import EVENT_MQTT_PACKET_RECEIVED, EVENT_MQTT_PACKET_SENT _MQTT_PROTOCOL_LEVEL_SUPPORTED = 4 +if TYPE_CHECKING: + from amqtt.broker import BrokerContext class Subscription: def __init__(self, packet_id: int, topics: list[tuple[str, int]]) -> None: @@ -41,10 +44,10 @@ class UnSubscription: self.topics = topics -class BrokerProtocolHandler(ProtocolHandler): +class BrokerProtocolHandler(ProtocolHandler["BrokerContext"]): def __init__( self, - plugins_manager: PluginManager, + plugins_manager: PluginManager["BrokerContext"], session: Session | None = None, loop: AbstractEventLoop | None = None, ) -> None: @@ -156,7 +159,7 @@ class BrokerProtocolHandler(ProtocolHandler): cls, reader: ReaderAdapter, writer: WriterAdapter, - plugins_manager: PluginManager, + plugins_manager: PluginManager["BrokerContext"], loop: asyncio.AbstractEventLoop | None = None, ) -> tuple["BrokerProtocolHandler", Session]: """Initialize from a CONNECT packet and validates the connection.""" diff --git a/amqtt/mqtt/protocol/client_handler.py b/amqtt/mqtt/protocol/client_handler.py index 32755f8..b307a0a 100644 --- a/amqtt/mqtt/protocol/client_handler.py +++ b/amqtt/mqtt/protocol/client_handler.py @@ -1,5 +1,5 @@ import asyncio -from typing import Any +from typing import TYPE_CHECKING, Any from amqtt.errors import AMQTTError from amqtt.mqtt.connack import ConnackPacket @@ -15,11 +15,13 @@ from amqtt.mqtt.unsubscribe import UnsubscribePacket from amqtt.plugins.manager import PluginManager from amqtt.session import Session +if TYPE_CHECKING: + from amqtt.client import ClientContext -class ClientProtocolHandler(ProtocolHandler): +class ClientProtocolHandler(ProtocolHandler["ClientContext"]): def __init__( self, - plugins_manager: PluginManager, + plugins_manager: PluginManager["ClientContext"], session: Session | None = None, loop: asyncio.AbstractEventLoop | None = None, ) -> None: diff --git a/amqtt/mqtt/protocol/handler.py b/amqtt/mqtt/protocol/handler.py index 240c520..2153834 100644 --- a/amqtt/mqtt/protocol/handler.py +++ b/amqtt/mqtt/protocol/handler.py @@ -17,7 +17,7 @@ except ImportError: import collections import itertools import logging -from typing import cast +from typing import Generic, TypeVar, cast from amqtt.adapters import ReaderAdapter, WriterAdapter from amqtt.errors import AMQTTError, MQTTError, NoDataError, ProtocolHandlerError @@ -56,19 +56,20 @@ from amqtt.mqtt.suback import SubackPacket from amqtt.mqtt.subscribe import SubscribePacket from amqtt.mqtt.unsuback import UnsubackPacket from amqtt.mqtt.unsubscribe import UnsubscribePacket -from amqtt.plugins.manager import PluginManager +from amqtt.plugins.manager import BaseContext, PluginManager from amqtt.session import INCOMING, OUTGOING, ApplicationMessage, IncomingApplicationMessage, OutgoingApplicationMessage, Session EVENT_MQTT_PACKET_SENT = "mqtt_packet_sent" EVENT_MQTT_PACKET_RECEIVED = "mqtt_packet_received" +C = TypeVar("C", bound=BaseContext) -class ProtocolHandler: +class ProtocolHandler(Generic[C]): """Class implementing the MQTT communication protocol using asyncio features.""" def __init__( self, - plugins_manager: PluginManager, + plugins_manager: PluginManager[C], session: Session | None = None, loop: asyncio.AbstractEventLoop | None = None, ) -> None: @@ -79,7 +80,7 @@ class ProtocolHandler: self.session: Session | None = None self.reader: ReaderAdapter | None = None self.writer: WriterAdapter | None = None - self.plugins_manager: PluginManager = plugins_manager + self.plugins_manager: PluginManager[C] = plugins_manager try: self._loop = loop if loop is not None else asyncio.get_running_loop() diff --git a/amqtt/plugins/logging_amqtt.py b/amqtt/plugins/logging_amqtt.py index 84cf6d7..13b0f20 100644 --- a/amqtt/plugins/logging_amqtt.py +++ b/amqtt/plugins/logging_amqtt.py @@ -4,12 +4,13 @@ import logging from typing import TYPE_CHECKING, Any from amqtt.plugins.base import BasePlugin +from amqtt.plugins.manager import BaseContext if TYPE_CHECKING: from amqtt.session import Session -class EventLoggerPlugin(BasePlugin): +class EventLoggerPlugin(BasePlugin[BaseContext]): """A plugin to log events dynamically based on method names.""" async def log_event(self, *args: Any, **kwargs: Any) -> None: @@ -25,7 +26,7 @@ class EventLoggerPlugin(BasePlugin): raise AttributeError(msg) -class PacketLoggerPlugin(BasePlugin): +class PacketLoggerPlugin(BasePlugin[BaseContext]): """A plugin to log MQTT packets sent and received.""" async def on_mqtt_packet_received(self, *args: Any, **kwargs: Any) -> None: diff --git a/amqtt/plugins/manager.py b/amqtt/plugins/manager.py index c621e97..177c539 100644 --- a/amqtt/plugins/manager.py +++ b/amqtt/plugins/manager.py @@ -6,7 +6,7 @@ import contextlib import copy from importlib.metadata import EntryPoint, EntryPoints, entry_points import logging -from typing import TYPE_CHECKING, Any, NamedTuple, Optional +from typing import TYPE_CHECKING, Any, Generic, NamedTuple, Optional, TypeVar from amqtt.session import Session @@ -24,10 +24,10 @@ class Plugin(NamedTuple): object: Any -plugins_manager: dict[str, "PluginManager"] = {} +plugins_manager: dict[str, "PluginManager[Any]"] = {} -def get_plugin_manager(namespace: str) -> "PluginManager | None": +def get_plugin_manager(namespace: str) -> "PluginManager[Any] | None": """Get the plugin manager for a given namespace. :param namespace: The namespace of the plugin manager to retrieve. @@ -43,14 +43,17 @@ class BaseContext: self.config: dict[str, Any] | None = None -class PluginManager: +C = TypeVar("C", bound=BaseContext) + + +class PluginManager(Generic[C]): """Wraps contextlib Entry point mechanism to provide a basic plugin system. Plugins are loaded for a given namespace (group). This plugin manager uses coroutines to run plugin calls asynchronously in an event queue. """ - def __init__(self, namespace: str, context: BaseContext | None, loop: asyncio.AbstractEventLoop | None = None) -> None: + def __init__(self, namespace: str, context: C | None, loop: asyncio.AbstractEventLoop | None = None) -> None: try: self._loop = loop if loop is not None else asyncio.get_running_loop() except RuntimeError: @@ -60,7 +63,7 @@ class PluginManager: self.logger = logging.getLogger(namespace) self.context = context if context is not None else BaseContext() self.context.loop = self._loop - self._plugins: list[BasePlugin] = [] + self._plugins: list[BasePlugin[C]] = [] self._auth_plugins: list[BaseAuthPlugin] = [] self._topic_plugins: list[BaseTopicPlugin] = [] self._load_plugins(namespace) @@ -114,7 +117,7 @@ class PluginManager: return None - def get_plugin(self, name: str) -> Optional["BasePlugin"]: + def get_plugin(self, name: str) -> Optional["BasePlugin[C]"]: """Get a plugin by its name from the plugins loaded for the current namespace. :param name: @@ -134,7 +137,7 @@ class PluginManager: self._fired_events.clear() @property - def plugins(self) -> list["BasePlugin"]: + def plugins(self) -> list["BasePlugin[C]"]: """Get the loaded plugins list. :return: @@ -179,7 +182,7 @@ class PluginManager: await asyncio.wait(tasks) self.logger.debug(f"Plugins len(_fired_events)={len(self._fired_events)}") - async def map_plugin_auth(self, session: Session) -> dict["BasePlugin", str | bool | None]: + async def map_plugin_auth(self, session: Session) -> dict["BasePlugin[C]", str | bool | None]: """Schedule a coroutine for plugin 'authenticate' calls. :param session: the client session associated with the authentication check @@ -198,17 +201,17 @@ class PluginManager: coro_instance: Awaitable[str | bool | None] = auth_coro(plugin, session) tasks.append(asyncio.ensure_future(coro_instance)) - ret_dict: dict[BasePlugin, str | bool | None] = {} + ret_dict: dict[BasePlugin[C], str | bool | None] = {} if tasks: ret_list = await asyncio.gather(*tasks) # Create result map plugin => ret - ret_dict = dict(zip(self._auth_plugins, ret_list, strict=False)) + ret_dict = dict(zip(self._auth_plugins, ret_list, strict=False)) # type: ignore[arg-type] return ret_dict async def map_plugin_topic(self, session: Session, topic: str, action: "Action" - ) -> dict["BasePlugin", str | bool | None]: + ) -> dict["BasePlugin[C]", str | bool | None]: """Schedule a coroutine for plugin 'topic_filtering' calls. :param session: the client session associated with the topic_filtering check @@ -229,21 +232,21 @@ class PluginManager: coro_instance: Awaitable[str | bool | None] = topic_coro(plugin, session, topic, action) tasks.append(asyncio.ensure_future(coro_instance)) - ret_dict: dict[BasePlugin, str | bool | None] = {} + ret_dict: dict[BasePlugin[C], str | bool | None] = {} if tasks: ret_list = await asyncio.gather(*tasks) # Create result map plugin => ret - ret_dict= dict(zip(self._auth_plugins, ret_list, strict=False)) + ret_dict= dict(zip(self._topic_plugins, ret_list, strict=False)) # type: ignore[arg-type] return ret_dict - async def map_plugin_close(self) -> dict["BasePlugin", str | bool | None]: + async def map_plugin_close(self) -> dict["BasePlugin[C]", str | bool | None]: tasks: list[asyncio.Future[Any]] = [] for plugin in self._plugins: - async def close_coro(p: "BasePlugin") -> None: + async def close_coro(p: "BasePlugin[C]") -> None: await p.close() if not hasattr(plugin, "close"): @@ -252,10 +255,10 @@ class PluginManager: coro_instance: Awaitable[str | bool | None] = close_coro(plugin) tasks.append(asyncio.ensure_future(coro_instance)) - ret_dict: dict[BasePlugin, str | bool | None] = {} + ret_dict: dict[BasePlugin[C], str | bool | None] = {} if tasks: ret_list = await asyncio.gather(*tasks) # Create result map plugin => ret - ret_dict = dict(zip(self._auth_plugins, ret_list, strict=False)) + ret_dict = dict(zip(self._plugins, ret_list, strict=False)) return ret_dict