updating type-hinting to set generics for base, client and broker contexts

pull/212/head
Andrew Mirsky 2025-06-12 19:52:16 -04:00
rodzic 43efa4c829
commit 5bd7513e80
Nie znaleziono w bazie danych klucza dla tego podpisu
ID klucza GPG: A98E67635CDF2C39
6 zmienionych plików z 42 dodań i 32 usunięć

Wyświetl plik

@ -107,7 +107,7 @@ class MQTTClient:
# Init plugins manager # Init plugins manager
context = ClientContext() context = ClientContext()
context.config = self.config context.config = self.config
self.plugins_manager = PluginManager("amqtt.client.plugins", context) self.plugins_manager: PluginManager[ClientContext] = PluginManager("amqtt.client.plugins", context)
self.client_tasks: deque[asyncio.Task[Any]] = deque() self.client_tasks: deque[asyncio.Task[Any]] = deque()
async def connect( async def connect(

Wyświetl plik

@ -1,5 +1,6 @@
import asyncio import asyncio
from asyncio import AbstractEventLoop, Queue from asyncio import AbstractEventLoop, Queue
from typing import TYPE_CHECKING
from amqtt.adapters import ReaderAdapter, WriterAdapter from amqtt.adapters import ReaderAdapter, WriterAdapter
from amqtt.errors import MQTTError from amqtt.errors import MQTTError
@ -28,6 +29,8 @@ from .handler import EVENT_MQTT_PACKET_RECEIVED, EVENT_MQTT_PACKET_SENT
_MQTT_PROTOCOL_LEVEL_SUPPORTED = 4 _MQTT_PROTOCOL_LEVEL_SUPPORTED = 4
if TYPE_CHECKING:
from amqtt.broker import BrokerContext
class Subscription: class Subscription:
def __init__(self, packet_id: int, topics: list[tuple[str, int]]) -> None: def __init__(self, packet_id: int, topics: list[tuple[str, int]]) -> None:
@ -41,10 +44,10 @@ class UnSubscription:
self.topics = topics self.topics = topics
class BrokerProtocolHandler(ProtocolHandler): class BrokerProtocolHandler(ProtocolHandler["BrokerContext"]):
def __init__( def __init__(
self, self,
plugins_manager: PluginManager, plugins_manager: PluginManager["BrokerContext"],
session: Session | None = None, session: Session | None = None,
loop: AbstractEventLoop | None = None, loop: AbstractEventLoop | None = None,
) -> None: ) -> None:
@ -156,7 +159,7 @@ class BrokerProtocolHandler(ProtocolHandler):
cls, cls,
reader: ReaderAdapter, reader: ReaderAdapter,
writer: WriterAdapter, writer: WriterAdapter,
plugins_manager: PluginManager, plugins_manager: PluginManager["BrokerContext"],
loop: asyncio.AbstractEventLoop | None = None, loop: asyncio.AbstractEventLoop | None = None,
) -> tuple["BrokerProtocolHandler", Session]: ) -> tuple["BrokerProtocolHandler", Session]:
"""Initialize from a CONNECT packet and validates the connection.""" """Initialize from a CONNECT packet and validates the connection."""

Wyświetl plik

@ -1,5 +1,5 @@
import asyncio import asyncio
from typing import Any from typing import TYPE_CHECKING, Any
from amqtt.errors import AMQTTError from amqtt.errors import AMQTTError
from amqtt.mqtt.connack import ConnackPacket from amqtt.mqtt.connack import ConnackPacket
@ -15,11 +15,13 @@ from amqtt.mqtt.unsubscribe import UnsubscribePacket
from amqtt.plugins.manager import PluginManager from amqtt.plugins.manager import PluginManager
from amqtt.session import Session from amqtt.session import Session
if TYPE_CHECKING:
from amqtt.client import ClientContext
class ClientProtocolHandler(ProtocolHandler): class ClientProtocolHandler(ProtocolHandler["ClientContext"]):
def __init__( def __init__(
self, self,
plugins_manager: PluginManager, plugins_manager: PluginManager["ClientContext"],
session: Session | None = None, session: Session | None = None,
loop: asyncio.AbstractEventLoop | None = None, loop: asyncio.AbstractEventLoop | None = None,
) -> None: ) -> None:

Wyświetl plik

@ -17,7 +17,7 @@ except ImportError:
import collections import collections
import itertools import itertools
import logging import logging
from typing import cast from typing import Generic, TypeVar, cast
from amqtt.adapters import ReaderAdapter, WriterAdapter from amqtt.adapters import ReaderAdapter, WriterAdapter
from amqtt.errors import AMQTTError, MQTTError, NoDataError, ProtocolHandlerError from amqtt.errors import AMQTTError, MQTTError, NoDataError, ProtocolHandlerError
@ -56,19 +56,20 @@ from amqtt.mqtt.suback import SubackPacket
from amqtt.mqtt.subscribe import SubscribePacket from amqtt.mqtt.subscribe import SubscribePacket
from amqtt.mqtt.unsuback import UnsubackPacket from amqtt.mqtt.unsuback import UnsubackPacket
from amqtt.mqtt.unsubscribe import UnsubscribePacket from amqtt.mqtt.unsubscribe import UnsubscribePacket
from amqtt.plugins.manager import PluginManager from amqtt.plugins.manager import BaseContext, PluginManager
from amqtt.session import INCOMING, OUTGOING, ApplicationMessage, IncomingApplicationMessage, OutgoingApplicationMessage, Session from amqtt.session import INCOMING, OUTGOING, ApplicationMessage, IncomingApplicationMessage, OutgoingApplicationMessage, Session
EVENT_MQTT_PACKET_SENT = "mqtt_packet_sent" EVENT_MQTT_PACKET_SENT = "mqtt_packet_sent"
EVENT_MQTT_PACKET_RECEIVED = "mqtt_packet_received" EVENT_MQTT_PACKET_RECEIVED = "mqtt_packet_received"
C = TypeVar("C", bound=BaseContext)
class ProtocolHandler: class ProtocolHandler(Generic[C]):
"""Class implementing the MQTT communication protocol using asyncio features.""" """Class implementing the MQTT communication protocol using asyncio features."""
def __init__( def __init__(
self, self,
plugins_manager: PluginManager, plugins_manager: PluginManager[C],
session: Session | None = None, session: Session | None = None,
loop: asyncio.AbstractEventLoop | None = None, loop: asyncio.AbstractEventLoop | None = None,
) -> None: ) -> None:
@ -79,7 +80,7 @@ class ProtocolHandler:
self.session: Session | None = None self.session: Session | None = None
self.reader: ReaderAdapter | None = None self.reader: ReaderAdapter | None = None
self.writer: WriterAdapter | None = None self.writer: WriterAdapter | None = None
self.plugins_manager: PluginManager = plugins_manager self.plugins_manager: PluginManager[C] = plugins_manager
try: try:
self._loop = loop if loop is not None else asyncio.get_running_loop() self._loop = loop if loop is not None else asyncio.get_running_loop()

Wyświetl plik

@ -4,12 +4,13 @@ import logging
from typing import TYPE_CHECKING, Any from typing import TYPE_CHECKING, Any
from amqtt.plugins.base import BasePlugin from amqtt.plugins.base import BasePlugin
from amqtt.plugins.manager import BaseContext
if TYPE_CHECKING: if TYPE_CHECKING:
from amqtt.session import Session from amqtt.session import Session
class EventLoggerPlugin(BasePlugin): class EventLoggerPlugin(BasePlugin[BaseContext]):
"""A plugin to log events dynamically based on method names.""" """A plugin to log events dynamically based on method names."""
async def log_event(self, *args: Any, **kwargs: Any) -> None: async def log_event(self, *args: Any, **kwargs: Any) -> None:
@ -25,7 +26,7 @@ class EventLoggerPlugin(BasePlugin):
raise AttributeError(msg) raise AttributeError(msg)
class PacketLoggerPlugin(BasePlugin): class PacketLoggerPlugin(BasePlugin[BaseContext]):
"""A plugin to log MQTT packets sent and received.""" """A plugin to log MQTT packets sent and received."""
async def on_mqtt_packet_received(self, *args: Any, **kwargs: Any) -> None: async def on_mqtt_packet_received(self, *args: Any, **kwargs: Any) -> None:

Wyświetl plik

@ -6,7 +6,7 @@ import contextlib
import copy import copy
from importlib.metadata import EntryPoint, EntryPoints, entry_points from importlib.metadata import EntryPoint, EntryPoints, entry_points
import logging import logging
from typing import TYPE_CHECKING, Any, NamedTuple, Optional from typing import TYPE_CHECKING, Any, Generic, NamedTuple, Optional, TypeVar
from amqtt.session import Session from amqtt.session import Session
@ -24,10 +24,10 @@ class Plugin(NamedTuple):
object: Any object: Any
plugins_manager: dict[str, "PluginManager"] = {} plugins_manager: dict[str, "PluginManager[Any]"] = {}
def get_plugin_manager(namespace: str) -> "PluginManager | None": def get_plugin_manager(namespace: str) -> "PluginManager[Any] | None":
"""Get the plugin manager for a given namespace. """Get the plugin manager for a given namespace.
:param namespace: The namespace of the plugin manager to retrieve. :param namespace: The namespace of the plugin manager to retrieve.
@ -43,14 +43,17 @@ class BaseContext:
self.config: dict[str, Any] | None = None self.config: dict[str, Any] | None = None
class PluginManager: C = TypeVar("C", bound=BaseContext)
class PluginManager(Generic[C]):
"""Wraps contextlib Entry point mechanism to provide a basic plugin system. """Wraps contextlib Entry point mechanism to provide a basic plugin system.
Plugins are loaded for a given namespace (group). This plugin manager uses coroutines to Plugins are loaded for a given namespace (group). This plugin manager uses coroutines to
run plugin calls asynchronously in an event queue. run plugin calls asynchronously in an event queue.
""" """
def __init__(self, namespace: str, context: BaseContext | None, loop: asyncio.AbstractEventLoop | None = None) -> None: def __init__(self, namespace: str, context: C | None, loop: asyncio.AbstractEventLoop | None = None) -> None:
try: try:
self._loop = loop if loop is not None else asyncio.get_running_loop() self._loop = loop if loop is not None else asyncio.get_running_loop()
except RuntimeError: except RuntimeError:
@ -60,7 +63,7 @@ class PluginManager:
self.logger = logging.getLogger(namespace) self.logger = logging.getLogger(namespace)
self.context = context if context is not None else BaseContext() self.context = context if context is not None else BaseContext()
self.context.loop = self._loop self.context.loop = self._loop
self._plugins: list[BasePlugin] = [] self._plugins: list[BasePlugin[C]] = []
self._auth_plugins: list[BaseAuthPlugin] = [] self._auth_plugins: list[BaseAuthPlugin] = []
self._topic_plugins: list[BaseTopicPlugin] = [] self._topic_plugins: list[BaseTopicPlugin] = []
self._load_plugins(namespace) self._load_plugins(namespace)
@ -114,7 +117,7 @@ class PluginManager:
return None return None
def get_plugin(self, name: str) -> Optional["BasePlugin"]: def get_plugin(self, name: str) -> Optional["BasePlugin[C]"]:
"""Get a plugin by its name from the plugins loaded for the current namespace. """Get a plugin by its name from the plugins loaded for the current namespace.
:param name: :param name:
@ -134,7 +137,7 @@ class PluginManager:
self._fired_events.clear() self._fired_events.clear()
@property @property
def plugins(self) -> list["BasePlugin"]: def plugins(self) -> list["BasePlugin[C]"]:
"""Get the loaded plugins list. """Get the loaded plugins list.
:return: :return:
@ -179,7 +182,7 @@ class PluginManager:
await asyncio.wait(tasks) await asyncio.wait(tasks)
self.logger.debug(f"Plugins len(_fired_events)={len(self._fired_events)}") self.logger.debug(f"Plugins len(_fired_events)={len(self._fired_events)}")
async def map_plugin_auth(self, session: Session) -> dict["BasePlugin", str | bool | None]: async def map_plugin_auth(self, session: Session) -> dict["BasePlugin[C]", str | bool | None]:
"""Schedule a coroutine for plugin 'authenticate' calls. """Schedule a coroutine for plugin 'authenticate' calls.
:param session: the client session associated with the authentication check :param session: the client session associated with the authentication check
@ -198,17 +201,17 @@ class PluginManager:
coro_instance: Awaitable[str | bool | None] = auth_coro(plugin, session) coro_instance: Awaitable[str | bool | None] = auth_coro(plugin, session)
tasks.append(asyncio.ensure_future(coro_instance)) tasks.append(asyncio.ensure_future(coro_instance))
ret_dict: dict[BasePlugin, str | bool | None] = {} ret_dict: dict[BasePlugin[C], str | bool | None] = {}
if tasks: if tasks:
ret_list = await asyncio.gather(*tasks) ret_list = await asyncio.gather(*tasks)
# Create result map plugin => ret # Create result map plugin => ret
ret_dict = dict(zip(self._auth_plugins, ret_list, strict=False)) ret_dict = dict(zip(self._auth_plugins, ret_list, strict=False)) # type: ignore[arg-type]
return ret_dict return ret_dict
async def map_plugin_topic(self, async def map_plugin_topic(self,
session: Session, topic: str, action: "Action" session: Session, topic: str, action: "Action"
) -> dict["BasePlugin", str | bool | None]: ) -> dict["BasePlugin[C]", str | bool | None]:
"""Schedule a coroutine for plugin 'topic_filtering' calls. """Schedule a coroutine for plugin 'topic_filtering' calls.
:param session: the client session associated with the topic_filtering check :param session: the client session associated with the topic_filtering check
@ -229,21 +232,21 @@ class PluginManager:
coro_instance: Awaitable[str | bool | None] = topic_coro(plugin, session, topic, action) coro_instance: Awaitable[str | bool | None] = topic_coro(plugin, session, topic, action)
tasks.append(asyncio.ensure_future(coro_instance)) tasks.append(asyncio.ensure_future(coro_instance))
ret_dict: dict[BasePlugin, str | bool | None] = {} ret_dict: dict[BasePlugin[C], str | bool | None] = {}
if tasks: if tasks:
ret_list = await asyncio.gather(*tasks) ret_list = await asyncio.gather(*tasks)
# Create result map plugin => ret # Create result map plugin => ret
ret_dict= dict(zip(self._auth_plugins, ret_list, strict=False)) ret_dict= dict(zip(self._topic_plugins, ret_list, strict=False)) # type: ignore[arg-type]
return ret_dict return ret_dict
async def map_plugin_close(self) -> dict["BasePlugin", str | bool | None]: async def map_plugin_close(self) -> dict["BasePlugin[C]", str | bool | None]:
tasks: list[asyncio.Future[Any]] = [] tasks: list[asyncio.Future[Any]] = []
for plugin in self._plugins: for plugin in self._plugins:
async def close_coro(p: "BasePlugin") -> None: async def close_coro(p: "BasePlugin[C]") -> None:
await p.close() await p.close()
if not hasattr(plugin, "close"): if not hasattr(plugin, "close"):
@ -252,10 +255,10 @@ class PluginManager:
coro_instance: Awaitable[str | bool | None] = close_coro(plugin) coro_instance: Awaitable[str | bool | None] = close_coro(plugin)
tasks.append(asyncio.ensure_future(coro_instance)) tasks.append(asyncio.ensure_future(coro_instance))
ret_dict: dict[BasePlugin, str | bool | None] = {} ret_dict: dict[BasePlugin[C], str | bool | None] = {}
if tasks: if tasks:
ret_list = await asyncio.gather(*tasks) ret_list = await asyncio.gather(*tasks)
# Create result map plugin => ret # Create result map plugin => ret
ret_dict = dict(zip(self._auth_plugins, ret_list, strict=False)) ret_dict = dict(zip(self._plugins, ret_list, strict=False))
return ret_dict return ret_dict