kopia lustrzana https://github.com/Yakifo/amqtt
updating type-hinting to set generics for base, client and broker contexts
rodzic
43efa4c829
commit
5bd7513e80
|
@ -107,7 +107,7 @@ class MQTTClient:
|
|||
# Init plugins manager
|
||||
context = ClientContext()
|
||||
context.config = self.config
|
||||
self.plugins_manager = PluginManager("amqtt.client.plugins", context)
|
||||
self.plugins_manager: PluginManager[ClientContext] = PluginManager("amqtt.client.plugins", context)
|
||||
self.client_tasks: deque[asyncio.Task[Any]] = deque()
|
||||
|
||||
async def connect(
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
import asyncio
|
||||
from asyncio import AbstractEventLoop, Queue
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from amqtt.adapters import ReaderAdapter, WriterAdapter
|
||||
from amqtt.errors import MQTTError
|
||||
|
@ -28,6 +29,8 @@ from .handler import EVENT_MQTT_PACKET_RECEIVED, EVENT_MQTT_PACKET_SENT
|
|||
|
||||
_MQTT_PROTOCOL_LEVEL_SUPPORTED = 4
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from amqtt.broker import BrokerContext
|
||||
|
||||
class Subscription:
|
||||
def __init__(self, packet_id: int, topics: list[tuple[str, int]]) -> None:
|
||||
|
@ -41,10 +44,10 @@ class UnSubscription:
|
|||
self.topics = topics
|
||||
|
||||
|
||||
class BrokerProtocolHandler(ProtocolHandler):
|
||||
class BrokerProtocolHandler(ProtocolHandler["BrokerContext"]):
|
||||
def __init__(
|
||||
self,
|
||||
plugins_manager: PluginManager,
|
||||
plugins_manager: PluginManager["BrokerContext"],
|
||||
session: Session | None = None,
|
||||
loop: AbstractEventLoop | None = None,
|
||||
) -> None:
|
||||
|
@ -156,7 +159,7 @@ class BrokerProtocolHandler(ProtocolHandler):
|
|||
cls,
|
||||
reader: ReaderAdapter,
|
||||
writer: WriterAdapter,
|
||||
plugins_manager: PluginManager,
|
||||
plugins_manager: PluginManager["BrokerContext"],
|
||||
loop: asyncio.AbstractEventLoop | None = None,
|
||||
) -> tuple["BrokerProtocolHandler", Session]:
|
||||
"""Initialize from a CONNECT packet and validates the connection."""
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
import asyncio
|
||||
from typing import Any
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from amqtt.errors import AMQTTError
|
||||
from amqtt.mqtt.connack import ConnackPacket
|
||||
|
@ -15,11 +15,13 @@ from amqtt.mqtt.unsubscribe import UnsubscribePacket
|
|||
from amqtt.plugins.manager import PluginManager
|
||||
from amqtt.session import Session
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from amqtt.client import ClientContext
|
||||
|
||||
class ClientProtocolHandler(ProtocolHandler):
|
||||
class ClientProtocolHandler(ProtocolHandler["ClientContext"]):
|
||||
def __init__(
|
||||
self,
|
||||
plugins_manager: PluginManager,
|
||||
plugins_manager: PluginManager["ClientContext"],
|
||||
session: Session | None = None,
|
||||
loop: asyncio.AbstractEventLoop | None = None,
|
||||
) -> None:
|
||||
|
|
|
@ -17,7 +17,7 @@ except ImportError:
|
|||
import collections
|
||||
import itertools
|
||||
import logging
|
||||
from typing import cast
|
||||
from typing import Generic, TypeVar, cast
|
||||
|
||||
from amqtt.adapters import ReaderAdapter, WriterAdapter
|
||||
from amqtt.errors import AMQTTError, MQTTError, NoDataError, ProtocolHandlerError
|
||||
|
@ -56,19 +56,20 @@ from amqtt.mqtt.suback import SubackPacket
|
|||
from amqtt.mqtt.subscribe import SubscribePacket
|
||||
from amqtt.mqtt.unsuback import UnsubackPacket
|
||||
from amqtt.mqtt.unsubscribe import UnsubscribePacket
|
||||
from amqtt.plugins.manager import PluginManager
|
||||
from amqtt.plugins.manager import BaseContext, PluginManager
|
||||
from amqtt.session import INCOMING, OUTGOING, ApplicationMessage, IncomingApplicationMessage, OutgoingApplicationMessage, Session
|
||||
|
||||
EVENT_MQTT_PACKET_SENT = "mqtt_packet_sent"
|
||||
EVENT_MQTT_PACKET_RECEIVED = "mqtt_packet_received"
|
||||
|
||||
C = TypeVar("C", bound=BaseContext)
|
||||
|
||||
class ProtocolHandler:
|
||||
class ProtocolHandler(Generic[C]):
|
||||
"""Class implementing the MQTT communication protocol using asyncio features."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
plugins_manager: PluginManager,
|
||||
plugins_manager: PluginManager[C],
|
||||
session: Session | None = None,
|
||||
loop: asyncio.AbstractEventLoop | None = None,
|
||||
) -> None:
|
||||
|
@ -79,7 +80,7 @@ class ProtocolHandler:
|
|||
self.session: Session | None = None
|
||||
self.reader: ReaderAdapter | None = None
|
||||
self.writer: WriterAdapter | None = None
|
||||
self.plugins_manager: PluginManager = plugins_manager
|
||||
self.plugins_manager: PluginManager[C] = plugins_manager
|
||||
|
||||
try:
|
||||
self._loop = loop if loop is not None else asyncio.get_running_loop()
|
||||
|
|
|
@ -4,12 +4,13 @@ import logging
|
|||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from amqtt.plugins.base import BasePlugin
|
||||
from amqtt.plugins.manager import BaseContext
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from amqtt.session import Session
|
||||
|
||||
|
||||
class EventLoggerPlugin(BasePlugin):
|
||||
class EventLoggerPlugin(BasePlugin[BaseContext]):
|
||||
"""A plugin to log events dynamically based on method names."""
|
||||
|
||||
async def log_event(self, *args: Any, **kwargs: Any) -> None:
|
||||
|
@ -25,7 +26,7 @@ class EventLoggerPlugin(BasePlugin):
|
|||
raise AttributeError(msg)
|
||||
|
||||
|
||||
class PacketLoggerPlugin(BasePlugin):
|
||||
class PacketLoggerPlugin(BasePlugin[BaseContext]):
|
||||
"""A plugin to log MQTT packets sent and received."""
|
||||
|
||||
async def on_mqtt_packet_received(self, *args: Any, **kwargs: Any) -> None:
|
||||
|
|
|
@ -6,7 +6,7 @@ import contextlib
|
|||
import copy
|
||||
from importlib.metadata import EntryPoint, EntryPoints, entry_points
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any, NamedTuple, Optional
|
||||
from typing import TYPE_CHECKING, Any, Generic, NamedTuple, Optional, TypeVar
|
||||
|
||||
from amqtt.session import Session
|
||||
|
||||
|
@ -24,10 +24,10 @@ class Plugin(NamedTuple):
|
|||
object: Any
|
||||
|
||||
|
||||
plugins_manager: dict[str, "PluginManager"] = {}
|
||||
plugins_manager: dict[str, "PluginManager[Any]"] = {}
|
||||
|
||||
|
||||
def get_plugin_manager(namespace: str) -> "PluginManager | None":
|
||||
def get_plugin_manager(namespace: str) -> "PluginManager[Any] | None":
|
||||
"""Get the plugin manager for a given namespace.
|
||||
|
||||
:param namespace: The namespace of the plugin manager to retrieve.
|
||||
|
@ -43,14 +43,17 @@ class BaseContext:
|
|||
self.config: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class PluginManager:
|
||||
C = TypeVar("C", bound=BaseContext)
|
||||
|
||||
|
||||
class PluginManager(Generic[C]):
|
||||
"""Wraps contextlib Entry point mechanism to provide a basic plugin system.
|
||||
|
||||
Plugins are loaded for a given namespace (group). This plugin manager uses coroutines to
|
||||
run plugin calls asynchronously in an event queue.
|
||||
"""
|
||||
|
||||
def __init__(self, namespace: str, context: BaseContext | None, loop: asyncio.AbstractEventLoop | None = None) -> None:
|
||||
def __init__(self, namespace: str, context: C | None, loop: asyncio.AbstractEventLoop | None = None) -> None:
|
||||
try:
|
||||
self._loop = loop if loop is not None else asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
|
@ -60,7 +63,7 @@ class PluginManager:
|
|||
self.logger = logging.getLogger(namespace)
|
||||
self.context = context if context is not None else BaseContext()
|
||||
self.context.loop = self._loop
|
||||
self._plugins: list[BasePlugin] = []
|
||||
self._plugins: list[BasePlugin[C]] = []
|
||||
self._auth_plugins: list[BaseAuthPlugin] = []
|
||||
self._topic_plugins: list[BaseTopicPlugin] = []
|
||||
self._load_plugins(namespace)
|
||||
|
@ -114,7 +117,7 @@ class PluginManager:
|
|||
|
||||
return None
|
||||
|
||||
def get_plugin(self, name: str) -> Optional["BasePlugin"]:
|
||||
def get_plugin(self, name: str) -> Optional["BasePlugin[C]"]:
|
||||
"""Get a plugin by its name from the plugins loaded for the current namespace.
|
||||
|
||||
:param name:
|
||||
|
@ -134,7 +137,7 @@ class PluginManager:
|
|||
self._fired_events.clear()
|
||||
|
||||
@property
|
||||
def plugins(self) -> list["BasePlugin"]:
|
||||
def plugins(self) -> list["BasePlugin[C]"]:
|
||||
"""Get the loaded plugins list.
|
||||
|
||||
:return:
|
||||
|
@ -179,7 +182,7 @@ class PluginManager:
|
|||
await asyncio.wait(tasks)
|
||||
self.logger.debug(f"Plugins len(_fired_events)={len(self._fired_events)}")
|
||||
|
||||
async def map_plugin_auth(self, session: Session) -> dict["BasePlugin", str | bool | None]:
|
||||
async def map_plugin_auth(self, session: Session) -> dict["BasePlugin[C]", str | bool | None]:
|
||||
"""Schedule a coroutine for plugin 'authenticate' calls.
|
||||
|
||||
:param session: the client session associated with the authentication check
|
||||
|
@ -198,17 +201,17 @@ class PluginManager:
|
|||
coro_instance: Awaitable[str | bool | None] = auth_coro(plugin, session)
|
||||
tasks.append(asyncio.ensure_future(coro_instance))
|
||||
|
||||
ret_dict: dict[BasePlugin, str | bool | None] = {}
|
||||
ret_dict: dict[BasePlugin[C], str | bool | None] = {}
|
||||
if tasks:
|
||||
ret_list = await asyncio.gather(*tasks)
|
||||
# Create result map plugin => ret
|
||||
ret_dict = dict(zip(self._auth_plugins, ret_list, strict=False))
|
||||
ret_dict = dict(zip(self._auth_plugins, ret_list, strict=False)) # type: ignore[arg-type]
|
||||
|
||||
return ret_dict
|
||||
|
||||
async def map_plugin_topic(self,
|
||||
session: Session, topic: str, action: "Action"
|
||||
) -> dict["BasePlugin", str | bool | None]:
|
||||
) -> dict["BasePlugin[C]", str | bool | None]:
|
||||
"""Schedule a coroutine for plugin 'topic_filtering' calls.
|
||||
|
||||
:param session: the client session associated with the topic_filtering check
|
||||
|
@ -229,21 +232,21 @@ class PluginManager:
|
|||
coro_instance: Awaitable[str | bool | None] = topic_coro(plugin, session, topic, action)
|
||||
tasks.append(asyncio.ensure_future(coro_instance))
|
||||
|
||||
ret_dict: dict[BasePlugin, str | bool | None] = {}
|
||||
ret_dict: dict[BasePlugin[C], str | bool | None] = {}
|
||||
if tasks:
|
||||
ret_list = await asyncio.gather(*tasks)
|
||||
# Create result map plugin => ret
|
||||
ret_dict= dict(zip(self._auth_plugins, ret_list, strict=False))
|
||||
ret_dict= dict(zip(self._topic_plugins, ret_list, strict=False)) # type: ignore[arg-type]
|
||||
|
||||
return ret_dict
|
||||
|
||||
async def map_plugin_close(self) -> dict["BasePlugin", str | bool | None]:
|
||||
async def map_plugin_close(self) -> dict["BasePlugin[C]", str | bool | None]:
|
||||
|
||||
tasks: list[asyncio.Future[Any]] = []
|
||||
|
||||
for plugin in self._plugins:
|
||||
|
||||
async def close_coro(p: "BasePlugin") -> None:
|
||||
async def close_coro(p: "BasePlugin[C]") -> None:
|
||||
await p.close()
|
||||
|
||||
if not hasattr(plugin, "close"):
|
||||
|
@ -252,10 +255,10 @@ class PluginManager:
|
|||
coro_instance: Awaitable[str | bool | None] = close_coro(plugin)
|
||||
tasks.append(asyncio.ensure_future(coro_instance))
|
||||
|
||||
ret_dict: dict[BasePlugin, str | bool | None] = {}
|
||||
ret_dict: dict[BasePlugin[C], str | bool | None] = {}
|
||||
if tasks:
|
||||
ret_list = await asyncio.gather(*tasks)
|
||||
# Create result map plugin => ret
|
||||
ret_dict = dict(zip(self._auth_plugins, ret_list, strict=False))
|
||||
ret_dict = dict(zip(self._plugins, ret_list, strict=False))
|
||||
|
||||
return ret_dict
|
||||
|
|
Ładowanie…
Reference in New Issue