updating type-hinting to set generics for base, client and broker contexts

pull/212/head
Andrew Mirsky 2025-06-12 19:52:16 -04:00
rodzic 43efa4c829
commit 5bd7513e80
Nie znaleziono w bazie danych klucza dla tego podpisu
ID klucza GPG: A98E67635CDF2C39
6 zmienionych plików z 42 dodań i 32 usunięć

Wyświetl plik

@ -107,7 +107,7 @@ class MQTTClient:
# Init plugins manager
context = ClientContext()
context.config = self.config
self.plugins_manager = PluginManager("amqtt.client.plugins", context)
self.plugins_manager: PluginManager[ClientContext] = PluginManager("amqtt.client.plugins", context)
self.client_tasks: deque[asyncio.Task[Any]] = deque()
async def connect(

Wyświetl plik

@ -1,5 +1,6 @@
import asyncio
from asyncio import AbstractEventLoop, Queue
from typing import TYPE_CHECKING
from amqtt.adapters import ReaderAdapter, WriterAdapter
from amqtt.errors import MQTTError
@ -28,6 +29,8 @@ from .handler import EVENT_MQTT_PACKET_RECEIVED, EVENT_MQTT_PACKET_SENT
_MQTT_PROTOCOL_LEVEL_SUPPORTED = 4
if TYPE_CHECKING:
from amqtt.broker import BrokerContext
class Subscription:
def __init__(self, packet_id: int, topics: list[tuple[str, int]]) -> None:
@ -41,10 +44,10 @@ class UnSubscription:
self.topics = topics
class BrokerProtocolHandler(ProtocolHandler):
class BrokerProtocolHandler(ProtocolHandler["BrokerContext"]):
def __init__(
self,
plugins_manager: PluginManager,
plugins_manager: PluginManager["BrokerContext"],
session: Session | None = None,
loop: AbstractEventLoop | None = None,
) -> None:
@ -156,7 +159,7 @@ class BrokerProtocolHandler(ProtocolHandler):
cls,
reader: ReaderAdapter,
writer: WriterAdapter,
plugins_manager: PluginManager,
plugins_manager: PluginManager["BrokerContext"],
loop: asyncio.AbstractEventLoop | None = None,
) -> tuple["BrokerProtocolHandler", Session]:
"""Initialize from a CONNECT packet and validates the connection."""

Wyświetl plik

@ -1,5 +1,5 @@
import asyncio
from typing import Any
from typing import TYPE_CHECKING, Any
from amqtt.errors import AMQTTError
from amqtt.mqtt.connack import ConnackPacket
@ -15,11 +15,13 @@ from amqtt.mqtt.unsubscribe import UnsubscribePacket
from amqtt.plugins.manager import PluginManager
from amqtt.session import Session
if TYPE_CHECKING:
from amqtt.client import ClientContext
class ClientProtocolHandler(ProtocolHandler):
class ClientProtocolHandler(ProtocolHandler["ClientContext"]):
def __init__(
self,
plugins_manager: PluginManager,
plugins_manager: PluginManager["ClientContext"],
session: Session | None = None,
loop: asyncio.AbstractEventLoop | None = None,
) -> None:

Wyświetl plik

@ -17,7 +17,7 @@ except ImportError:
import collections
import itertools
import logging
from typing import cast
from typing import Generic, TypeVar, cast
from amqtt.adapters import ReaderAdapter, WriterAdapter
from amqtt.errors import AMQTTError, MQTTError, NoDataError, ProtocolHandlerError
@ -56,19 +56,20 @@ from amqtt.mqtt.suback import SubackPacket
from amqtt.mqtt.subscribe import SubscribePacket
from amqtt.mqtt.unsuback import UnsubackPacket
from amqtt.mqtt.unsubscribe import UnsubscribePacket
from amqtt.plugins.manager import PluginManager
from amqtt.plugins.manager import BaseContext, PluginManager
from amqtt.session import INCOMING, OUTGOING, ApplicationMessage, IncomingApplicationMessage, OutgoingApplicationMessage, Session
EVENT_MQTT_PACKET_SENT = "mqtt_packet_sent"
EVENT_MQTT_PACKET_RECEIVED = "mqtt_packet_received"
C = TypeVar("C", bound=BaseContext)
class ProtocolHandler:
class ProtocolHandler(Generic[C]):
"""Class implementing the MQTT communication protocol using asyncio features."""
def __init__(
self,
plugins_manager: PluginManager,
plugins_manager: PluginManager[C],
session: Session | None = None,
loop: asyncio.AbstractEventLoop | None = None,
) -> None:
@ -79,7 +80,7 @@ class ProtocolHandler:
self.session: Session | None = None
self.reader: ReaderAdapter | None = None
self.writer: WriterAdapter | None = None
self.plugins_manager: PluginManager = plugins_manager
self.plugins_manager: PluginManager[C] = plugins_manager
try:
self._loop = loop if loop is not None else asyncio.get_running_loop()

Wyświetl plik

@ -4,12 +4,13 @@ import logging
from typing import TYPE_CHECKING, Any
from amqtt.plugins.base import BasePlugin
from amqtt.plugins.manager import BaseContext
if TYPE_CHECKING:
from amqtt.session import Session
class EventLoggerPlugin(BasePlugin):
class EventLoggerPlugin(BasePlugin[BaseContext]):
"""A plugin to log events dynamically based on method names."""
async def log_event(self, *args: Any, **kwargs: Any) -> None:
@ -25,7 +26,7 @@ class EventLoggerPlugin(BasePlugin):
raise AttributeError(msg)
class PacketLoggerPlugin(BasePlugin):
class PacketLoggerPlugin(BasePlugin[BaseContext]):
"""A plugin to log MQTT packets sent and received."""
async def on_mqtt_packet_received(self, *args: Any, **kwargs: Any) -> None:

Wyświetl plik

@ -6,7 +6,7 @@ import contextlib
import copy
from importlib.metadata import EntryPoint, EntryPoints, entry_points
import logging
from typing import TYPE_CHECKING, Any, NamedTuple, Optional
from typing import TYPE_CHECKING, Any, Generic, NamedTuple, Optional, TypeVar
from amqtt.session import Session
@ -24,10 +24,10 @@ class Plugin(NamedTuple):
object: Any
plugins_manager: dict[str, "PluginManager"] = {}
plugins_manager: dict[str, "PluginManager[Any]"] = {}
def get_plugin_manager(namespace: str) -> "PluginManager | None":
def get_plugin_manager(namespace: str) -> "PluginManager[Any] | None":
"""Get the plugin manager for a given namespace.
:param namespace: The namespace of the plugin manager to retrieve.
@ -43,14 +43,17 @@ class BaseContext:
self.config: dict[str, Any] | None = None
class PluginManager:
C = TypeVar("C", bound=BaseContext)
class PluginManager(Generic[C]):
"""Wraps contextlib Entry point mechanism to provide a basic plugin system.
Plugins are loaded for a given namespace (group). This plugin manager uses coroutines to
run plugin calls asynchronously in an event queue.
"""
def __init__(self, namespace: str, context: BaseContext | None, loop: asyncio.AbstractEventLoop | None = None) -> None:
def __init__(self, namespace: str, context: C | None, loop: asyncio.AbstractEventLoop | None = None) -> None:
try:
self._loop = loop if loop is not None else asyncio.get_running_loop()
except RuntimeError:
@ -60,7 +63,7 @@ class PluginManager:
self.logger = logging.getLogger(namespace)
self.context = context if context is not None else BaseContext()
self.context.loop = self._loop
self._plugins: list[BasePlugin] = []
self._plugins: list[BasePlugin[C]] = []
self._auth_plugins: list[BaseAuthPlugin] = []
self._topic_plugins: list[BaseTopicPlugin] = []
self._load_plugins(namespace)
@ -114,7 +117,7 @@ class PluginManager:
return None
def get_plugin(self, name: str) -> Optional["BasePlugin"]:
def get_plugin(self, name: str) -> Optional["BasePlugin[C]"]:
"""Get a plugin by its name from the plugins loaded for the current namespace.
:param name:
@ -134,7 +137,7 @@ class PluginManager:
self._fired_events.clear()
@property
def plugins(self) -> list["BasePlugin"]:
def plugins(self) -> list["BasePlugin[C]"]:
"""Get the loaded plugins list.
:return:
@ -179,7 +182,7 @@ class PluginManager:
await asyncio.wait(tasks)
self.logger.debug(f"Plugins len(_fired_events)={len(self._fired_events)}")
async def map_plugin_auth(self, session: Session) -> dict["BasePlugin", str | bool | None]:
async def map_plugin_auth(self, session: Session) -> dict["BasePlugin[C]", str | bool | None]:
"""Schedule a coroutine for plugin 'authenticate' calls.
:param session: the client session associated with the authentication check
@ -198,17 +201,17 @@ class PluginManager:
coro_instance: Awaitable[str | bool | None] = auth_coro(plugin, session)
tasks.append(asyncio.ensure_future(coro_instance))
ret_dict: dict[BasePlugin, str | bool | None] = {}
ret_dict: dict[BasePlugin[C], str | bool | None] = {}
if tasks:
ret_list = await asyncio.gather(*tasks)
# Create result map plugin => ret
ret_dict = dict(zip(self._auth_plugins, ret_list, strict=False))
ret_dict = dict(zip(self._auth_plugins, ret_list, strict=False)) # type: ignore[arg-type]
return ret_dict
async def map_plugin_topic(self,
session: Session, topic: str, action: "Action"
) -> dict["BasePlugin", str | bool | None]:
) -> dict["BasePlugin[C]", str | bool | None]:
"""Schedule a coroutine for plugin 'topic_filtering' calls.
:param session: the client session associated with the topic_filtering check
@ -229,21 +232,21 @@ class PluginManager:
coro_instance: Awaitable[str | bool | None] = topic_coro(plugin, session, topic, action)
tasks.append(asyncio.ensure_future(coro_instance))
ret_dict: dict[BasePlugin, str | bool | None] = {}
ret_dict: dict[BasePlugin[C], str | bool | None] = {}
if tasks:
ret_list = await asyncio.gather(*tasks)
# Create result map plugin => ret
ret_dict= dict(zip(self._auth_plugins, ret_list, strict=False))
ret_dict= dict(zip(self._topic_plugins, ret_list, strict=False)) # type: ignore[arg-type]
return ret_dict
async def map_plugin_close(self) -> dict["BasePlugin", str | bool | None]:
async def map_plugin_close(self) -> dict["BasePlugin[C]", str | bool | None]:
tasks: list[asyncio.Future[Any]] = []
for plugin in self._plugins:
async def close_coro(p: "BasePlugin") -> None:
async def close_coro(p: "BasePlugin[C]") -> None:
await p.close()
if not hasattr(plugin, "close"):
@ -252,10 +255,10 @@ class PluginManager:
coro_instance: Awaitable[str | bool | None] = close_coro(plugin)
tasks.append(asyncio.ensure_future(coro_instance))
ret_dict: dict[BasePlugin, str | bool | None] = {}
ret_dict: dict[BasePlugin[C], str | bool | None] = {}
if tasks:
ret_list = await asyncio.gather(*tasks)
# Create result map plugin => ret
ret_dict = dict(zip(self._auth_plugins, ret_list, strict=False))
ret_dict = dict(zip(self._plugins, ret_list, strict=False))
return ret_dict