kopia lustrzana https://github.com/Yakifo/amqtt
Merge pull request #212 from ajmirsky/plugin_call_optimization
reduce call logic for plugin coros: authenticate, topic_filtering, closepull/216/head
commit
aace4d65f7
|
@ -711,20 +711,16 @@ class Broker:
|
|||
:param session:
|
||||
:return:
|
||||
"""
|
||||
auth_plugins = None
|
||||
auth_config = self.config.get("auth", None)
|
||||
if isinstance(auth_config, dict):
|
||||
auth_plugins = auth_config.get("plugins", None)
|
||||
returns = await self.plugins_manager.map_plugin_coro("authenticate", session=session, filter_plugins=auth_plugins)
|
||||
returns = await self.plugins_manager.map_plugin_auth(session=session)
|
||||
auth_result = True
|
||||
if returns:
|
||||
for plugin in returns:
|
||||
res = returns[plugin]
|
||||
if res is False:
|
||||
auth_result = False
|
||||
self.logger.debug(f"Authentication failed due to '{plugin.name}' plugin result: {res}")
|
||||
self.logger.debug(f"Authentication failed due to '{plugin.__class__}' plugin result: {res}")
|
||||
else:
|
||||
self.logger.debug(f"'{plugin.name}' plugin result: {res}")
|
||||
self.logger.debug(f"'{plugin.__class__}' plugin result: {res}")
|
||||
# If all plugins returned True, authentication is success
|
||||
return auth_result
|
||||
|
||||
|
@ -785,20 +781,14 @@ class Broker:
|
|||
"""
|
||||
topic_config = self.config.get("topic-check", {})
|
||||
enabled = False
|
||||
topic_plugins: list[str] | None = None
|
||||
|
||||
if isinstance(topic_config, dict):
|
||||
enabled = topic_config.get("enabled", False)
|
||||
topic_plugins = topic_config.get("plugins")
|
||||
|
||||
if not enabled:
|
||||
return True
|
||||
results = await self.plugins_manager.map_plugin_coro(
|
||||
"topic_filtering",
|
||||
session=session,
|
||||
topic=topic,
|
||||
action=action,
|
||||
filter_plugins=topic_plugins,
|
||||
)
|
||||
|
||||
results = await self.plugins_manager.map_plugin_topic(session=session, topic=topic, action=action)
|
||||
return all(result for result in results.values())
|
||||
|
||||
async def _delete_session(self, client_id: str) -> None:
|
||||
|
|
|
@ -110,7 +110,7 @@ class MQTTClient:
|
|||
# Init plugins manager
|
||||
context = ClientContext()
|
||||
context.config = self.config
|
||||
self.plugins_manager = PluginManager("amqtt.client.plugins", context)
|
||||
self.plugins_manager: PluginManager[ClientContext] = PluginManager("amqtt.client.plugins", context)
|
||||
self.client_tasks: deque[asyncio.Task[Any]] = deque()
|
||||
|
||||
async def connect(
|
||||
|
|
|
@ -50,3 +50,7 @@ class ConnectError(ClientError):
|
|||
|
||||
class ProtocolHandlerError(Exception):
|
||||
"""Exceptions thrown by protocol handle."""
|
||||
|
||||
|
||||
class PluginLoadError(Exception):
|
||||
"""Exception thrown when loading a plugin."""
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
import asyncio
|
||||
from asyncio import AbstractEventLoop, Queue
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from amqtt.adapters import ReaderAdapter, WriterAdapter
|
||||
from amqtt.errors import MQTTError
|
||||
|
@ -28,6 +29,8 @@ from .handler import EVENT_MQTT_PACKET_RECEIVED, EVENT_MQTT_PACKET_SENT
|
|||
|
||||
_MQTT_PROTOCOL_LEVEL_SUPPORTED = 4
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from amqtt.broker import BrokerContext
|
||||
|
||||
class Subscription:
|
||||
def __init__(self, packet_id: int, topics: list[tuple[str, int]]) -> None:
|
||||
|
@ -41,10 +44,10 @@ class UnSubscription:
|
|||
self.topics = topics
|
||||
|
||||
|
||||
class BrokerProtocolHandler(ProtocolHandler):
|
||||
class BrokerProtocolHandler(ProtocolHandler["BrokerContext"]):
|
||||
def __init__(
|
||||
self,
|
||||
plugins_manager: PluginManager,
|
||||
plugins_manager: PluginManager["BrokerContext"],
|
||||
session: Session | None = None,
|
||||
loop: AbstractEventLoop | None = None,
|
||||
) -> None:
|
||||
|
@ -156,7 +159,7 @@ class BrokerProtocolHandler(ProtocolHandler):
|
|||
cls,
|
||||
reader: ReaderAdapter,
|
||||
writer: WriterAdapter,
|
||||
plugins_manager: PluginManager,
|
||||
plugins_manager: PluginManager["BrokerContext"],
|
||||
loop: asyncio.AbstractEventLoop | None = None,
|
||||
) -> tuple["BrokerProtocolHandler", Session]:
|
||||
"""Initialize from a CONNECT packet and validates the connection."""
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
import asyncio
|
||||
from typing import Any
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from amqtt.errors import AMQTTError, NoDataError
|
||||
from amqtt.mqtt.connack import ConnackPacket
|
||||
|
@ -15,11 +15,13 @@ from amqtt.mqtt.unsubscribe import UnsubscribePacket
|
|||
from amqtt.plugins.manager import PluginManager
|
||||
from amqtt.session import Session
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from amqtt.client import ClientContext
|
||||
|
||||
class ClientProtocolHandler(ProtocolHandler):
|
||||
class ClientProtocolHandler(ProtocolHandler["ClientContext"]):
|
||||
def __init__(
|
||||
self,
|
||||
plugins_manager: PluginManager,
|
||||
plugins_manager: PluginManager["ClientContext"],
|
||||
session: Session | None = None,
|
||||
loop: asyncio.AbstractEventLoop | None = None,
|
||||
) -> None:
|
||||
|
|
|
@ -17,7 +17,7 @@ except ImportError:
|
|||
import collections
|
||||
import itertools
|
||||
import logging
|
||||
from typing import cast
|
||||
from typing import Generic, TypeVar, cast
|
||||
|
||||
from amqtt.adapters import ReaderAdapter, WriterAdapter
|
||||
from amqtt.errors import AMQTTError, MQTTError, NoDataError, ProtocolHandlerError
|
||||
|
@ -56,19 +56,20 @@ from amqtt.mqtt.suback import SubackPacket
|
|||
from amqtt.mqtt.subscribe import SubscribePacket
|
||||
from amqtt.mqtt.unsuback import UnsubackPacket
|
||||
from amqtt.mqtt.unsubscribe import UnsubscribePacket
|
||||
from amqtt.plugins.manager import PluginManager
|
||||
from amqtt.plugins.manager import BaseContext, PluginManager
|
||||
from amqtt.session import INCOMING, OUTGOING, ApplicationMessage, IncomingApplicationMessage, OutgoingApplicationMessage, Session
|
||||
|
||||
EVENT_MQTT_PACKET_SENT = "mqtt_packet_sent"
|
||||
EVENT_MQTT_PACKET_RECEIVED = "mqtt_packet_received"
|
||||
|
||||
C = TypeVar("C", bound=BaseContext)
|
||||
|
||||
class ProtocolHandler:
|
||||
class ProtocolHandler(Generic[C]):
|
||||
"""Class implementing the MQTT communication protocol using asyncio features."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
plugins_manager: PluginManager,
|
||||
plugins_manager: PluginManager[C],
|
||||
session: Session | None = None,
|
||||
loop: asyncio.AbstractEventLoop | None = None,
|
||||
) -> None:
|
||||
|
@ -79,7 +80,7 @@ class ProtocolHandler:
|
|||
self.session: Session | None = None
|
||||
self.reader: ReaderAdapter | None = None
|
||||
self.writer: WriterAdapter | None = None
|
||||
self.plugins_manager: PluginManager = plugins_manager
|
||||
self.plugins_manager: PluginManager[C] = plugins_manager
|
||||
|
||||
try:
|
||||
self._loop = loop if loop is not None else asyncio.get_running_loop()
|
||||
|
|
|
@ -5,15 +5,16 @@ from passlib.apps import custom_app_context as pwd_context
|
|||
|
||||
from amqtt.broker import BrokerContext
|
||||
from amqtt.plugins.base import BasePlugin
|
||||
from amqtt.plugins.manager import BaseContext
|
||||
from amqtt.session import Session
|
||||
|
||||
_PARTS_EXPECTED_LENGTH = 2 # Expected number of parts in a valid line
|
||||
|
||||
|
||||
class BaseAuthPlugin(BasePlugin):
|
||||
class BaseAuthPlugin(BasePlugin[BaseContext]):
|
||||
"""Base class for authentication plugins."""
|
||||
|
||||
def __init__(self, context: BrokerContext) -> None:
|
||||
def __init__(self, context: BaseContext) -> None:
|
||||
super().__init__(context)
|
||||
|
||||
self.auth_config: dict[str, Any] | None = self._get_config_section("auth")
|
||||
|
|
|
@ -1,19 +1,26 @@
|
|||
from typing import Any
|
||||
from typing import Any, Generic, TypeVar
|
||||
|
||||
from amqtt.broker import BrokerContext
|
||||
from amqtt.plugins.manager import BaseContext
|
||||
|
||||
C = TypeVar("C", bound=BaseContext)
|
||||
|
||||
|
||||
class BasePlugin:
|
||||
class BasePlugin(Generic[C]):
|
||||
"""The base from which all plugins should inherit."""
|
||||
|
||||
def __init__(self, context: BrokerContext) -> None:
|
||||
self.context = context
|
||||
def __init__(self, context: C) -> None:
|
||||
self.context: C = context
|
||||
|
||||
def _get_config_section(self, name: str) -> dict[str, Any] | None:
|
||||
if not self.context.config or not self.context.config.get(name, None):
|
||||
|
||||
if not self.context.config or not hasattr(self.context.config, "get") or not self.context.config.get(name, None):
|
||||
return None
|
||||
|
||||
section_config: int | dict[str, Any] | None = self.context.config.get(name, None)
|
||||
# mypy has difficulty excluding int from `config`'s type, unless isinstance` is its own check
|
||||
if isinstance(section_config, int):
|
||||
return None
|
||||
return section_config
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Override if plugin needs to clean up resources upon shutdown."""
|
||||
|
|
|
@ -4,12 +4,13 @@ import logging
|
|||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from amqtt.plugins.base import BasePlugin
|
||||
from amqtt.plugins.manager import BaseContext
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from amqtt.session import Session
|
||||
|
||||
|
||||
class EventLoggerPlugin(BasePlugin):
|
||||
class EventLoggerPlugin(BasePlugin[BaseContext]):
|
||||
"""A plugin to log events dynamically based on method names."""
|
||||
|
||||
async def log_event(self, *args: Any, **kwargs: Any) -> None:
|
||||
|
@ -25,7 +26,7 @@ class EventLoggerPlugin(BasePlugin):
|
|||
raise AttributeError(msg)
|
||||
|
||||
|
||||
class PacketLoggerPlugin(BasePlugin):
|
||||
class PacketLoggerPlugin(BasePlugin[BaseContext]):
|
||||
"""A plugin to log MQTT packets sent and received."""
|
||||
|
||||
async def on_mqtt_packet_received(self, *args: Any, **kwargs: Any) -> None:
|
||||
|
|
|
@ -1,17 +1,24 @@
|
|||
__all__ = ["BaseContext", "PluginManager", "get_plugin_manager"]
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Awaitable, Callable
|
||||
from collections.abc import Awaitable
|
||||
import contextlib
|
||||
import copy
|
||||
from importlib.metadata import EntryPoint, EntryPoints, entry_points
|
||||
import logging
|
||||
from typing import Any, NamedTuple
|
||||
from typing import TYPE_CHECKING, Any, Generic, NamedTuple, Optional, TypeVar
|
||||
|
||||
from amqtt.session import Session
|
||||
|
||||
from amqtt.errors import PluginImportError, PluginInitError
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from amqtt.broker import Action
|
||||
from amqtt.plugins.authentication import BaseAuthPlugin
|
||||
from amqtt.plugins.base import BasePlugin
|
||||
from amqtt.plugins.topic_checking import BaseTopicPlugin
|
||||
|
||||
class Plugin(NamedTuple):
|
||||
name: str
|
||||
|
@ -19,10 +26,10 @@ class Plugin(NamedTuple):
|
|||
object: Any
|
||||
|
||||
|
||||
plugins_manager: dict[str, "PluginManager"] = {}
|
||||
plugins_manager: dict[str, "PluginManager[Any]"] = {}
|
||||
|
||||
|
||||
def get_plugin_manager(namespace: str) -> "PluginManager | None":
|
||||
def get_plugin_manager(namespace: str) -> "PluginManager[Any] | None":
|
||||
"""Get the plugin manager for a given namespace.
|
||||
|
||||
:param namespace: The namespace of the plugin manager to retrieve.
|
||||
|
@ -38,14 +45,17 @@ class BaseContext:
|
|||
self.config: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class PluginManager:
|
||||
C = TypeVar("C", bound=BaseContext)
|
||||
|
||||
|
||||
class PluginManager(Generic[C]):
|
||||
"""Wraps contextlib Entry point mechanism to provide a basic plugin system.
|
||||
|
||||
Plugins are loaded for a given namespace (group). This plugin manager uses coroutines to
|
||||
run plugin calls asynchronously in an event queue.
|
||||
"""
|
||||
|
||||
def __init__(self, namespace: str, context: BaseContext | None, loop: asyncio.AbstractEventLoop | None = None) -> None:
|
||||
def __init__(self, namespace: str, context: C | None, loop: asyncio.AbstractEventLoop | None = None) -> None:
|
||||
try:
|
||||
self._loop = loop if loop is not None else asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
|
@ -55,7 +65,9 @@ class PluginManager:
|
|||
self.logger = logging.getLogger(namespace)
|
||||
self.context = context if context is not None else BaseContext()
|
||||
self.context.loop = self._loop
|
||||
self._plugins: list[Plugin] = []
|
||||
self._plugins: list[BasePlugin[C]] = []
|
||||
self._auth_plugins: list[BaseAuthPlugin] = []
|
||||
self._topic_plugins: list[BaseTopicPlugin] = []
|
||||
self._load_plugins(namespace)
|
||||
self._fired_events: list[asyncio.Future[Any]] = []
|
||||
plugins_manager[namespace] = self
|
||||
|
@ -65,7 +77,16 @@ class PluginManager:
|
|||
return self.context
|
||||
|
||||
def _load_plugins(self, namespace: str) -> None:
|
||||
|
||||
self.logger.debug(f"Loading plugins for namespace {namespace}")
|
||||
|
||||
auth_filter_list = []
|
||||
topic_filter_list = []
|
||||
if self.app_context.config and "auth" in self.app_context.config:
|
||||
auth_filter_list = self.app_context.config["auth"].get("plugins", [])
|
||||
if self.app_context.config and "topic" in self.app_context.config:
|
||||
topic_filter_list = self.app_context.config["topic"].get("plugins", [])
|
||||
|
||||
ep: EntryPoints | list[EntryPoint] = []
|
||||
if hasattr(entry_points(), "select"):
|
||||
ep = entry_points().select(group=namespace)
|
||||
|
@ -73,12 +94,16 @@ class PluginManager:
|
|||
ep = [entry_points()[namespace]]
|
||||
|
||||
for item in ep:
|
||||
plugin = self._load_plugin(item)
|
||||
plugin = self._load_ep_plugin(item)
|
||||
if plugin is not None:
|
||||
self._plugins.append(plugin)
|
||||
self._plugins.append(plugin.object)
|
||||
if (not auth_filter_list or plugin.name in auth_filter_list) and hasattr(plugin.object, "authenticate"):
|
||||
self._auth_plugins.append(plugin.object)
|
||||
if (not topic_filter_list or plugin.name in topic_filter_list) and hasattr(plugin.object, "topic_filtering"):
|
||||
self._topic_plugins.append(plugin.object)
|
||||
self.logger.debug(f" Plugin {item.name} ready")
|
||||
|
||||
def _load_plugin(self, ep: EntryPoint) -> Plugin | None:
|
||||
def _load_ep_plugin(self, ep: EntryPoint) -> Plugin | None:
|
||||
try:
|
||||
self.logger.debug(f" Loading plugin {ep!s}")
|
||||
plugin = ep.load()
|
||||
|
@ -98,26 +123,27 @@ class PluginManager:
|
|||
self.logger.debug(f"Plugin init failed: {ep!r}", exc_info=True)
|
||||
raise PluginInitError(ep) from e
|
||||
|
||||
def get_plugin(self, name: str) -> Plugin | None:
|
||||
def get_plugin(self, name: str) -> Optional["BasePlugin[C]"]:
|
||||
"""Get a plugin by its name from the plugins loaded for the current namespace.
|
||||
|
||||
:param name:
|
||||
:return:
|
||||
"""
|
||||
for p in self._plugins:
|
||||
if p.name == name:
|
||||
self.logger.debug(f"plugin name >>>> {p.__class__.__name__}")
|
||||
if p.__class__.__name__ == name:
|
||||
return p
|
||||
return None
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Free PluginManager resources and cancel pending event methods."""
|
||||
await self.map_plugin_coro("close")
|
||||
await self.map_plugin_close()
|
||||
for task in self._fired_events:
|
||||
task.cancel()
|
||||
self._fired_events.clear()
|
||||
|
||||
@property
|
||||
def plugins(self) -> list[Plugin]:
|
||||
def plugins(self) -> list["BasePlugin[C]"]:
|
||||
"""Get the loaded plugins list.
|
||||
|
||||
:return:
|
||||
|
@ -143,7 +169,7 @@ class PluginManager:
|
|||
tasks: list[asyncio.Future[Any]] = []
|
||||
event_method_name = "on_" + event_name
|
||||
for plugin in self._plugins:
|
||||
event_method = getattr(plugin.object, event_method_name, None)
|
||||
event_method = getattr(plugin, event_method_name, None)
|
||||
if event_method:
|
||||
try:
|
||||
task = self._schedule_coro(event_method(*args, **kwargs))
|
||||
|
@ -155,66 +181,73 @@ class PluginManager:
|
|||
|
||||
task.add_done_callback(clean_fired_events)
|
||||
except AssertionError:
|
||||
self.logger.exception(f"Method '{event_method_name}' on plugin '{plugin.name}' is not a coroutine")
|
||||
self.logger.exception(f"Method '{event_method_name}' on plugin '{plugin.__class__}' is not a coroutine")
|
||||
|
||||
self._fired_events.extend(tasks)
|
||||
if wait and tasks:
|
||||
await asyncio.wait(tasks)
|
||||
self.logger.debug(f"Plugins len(_fired_events)={len(self._fired_events)}")
|
||||
|
||||
async def map(
|
||||
self,
|
||||
coro: Callable[[Plugin, Any], Awaitable[str | bool | None]],
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> dict[Plugin, str | bool | None]:
|
||||
"""Schedule a given coroutine call for each plugin.
|
||||
@staticmethod
|
||||
async def _map_plugin_method(
|
||||
plugins: list["BasePlugin[C]"],
|
||||
method_name: str,
|
||||
method_kwargs: dict[str, Any],
|
||||
) -> dict["BasePlugin[C]", str | bool | None]:
|
||||
"""Call plugin coroutines.
|
||||
|
||||
The coro called gets the Plugin instance as the first argument of its method call.
|
||||
:param coro: coro to call on each plugin
|
||||
:param filter_plugins: list of plugin names to filter (only plugin whose name is
|
||||
in the filter are called). None will call all plugins. [] will call None.
|
||||
:param args: arguments to pass to coro
|
||||
:param kwargs: arguments to pass to coro
|
||||
:param plugins: List of plugins to execute the method on
|
||||
:param method_name: Name of the method to call on each plugin
|
||||
:param method_kwargs: Keyword arguments to pass to the method
|
||||
:return: dict containing return from coro call for each plugin.
|
||||
"""
|
||||
p_list = kwargs.pop("filter_plugins", None)
|
||||
if p_list is None:
|
||||
p_list = [p.name for p in self.plugins]
|
||||
tasks: list[asyncio.Future[Any]] = []
|
||||
plugins_list: list[Plugin] = []
|
||||
for plugin in self._plugins:
|
||||
if plugin.name in p_list:
|
||||
coro_instance = coro(plugin, *args, **kwargs)
|
||||
if coro_instance:
|
||||
try:
|
||||
tasks.append(self._schedule_coro(coro_instance))
|
||||
plugins_list.append(plugin)
|
||||
except AssertionError:
|
||||
self.logger.exception(f"Method '{coro!r}' on plugin '{plugin.name}' is not a coroutine")
|
||||
|
||||
for plugin in plugins:
|
||||
if not hasattr(plugin, method_name):
|
||||
continue
|
||||
|
||||
async def call_method(p: "BasePlugin[C]", kwargs: dict[str, Any]) -> Any:
|
||||
method = getattr(p, method_name)
|
||||
return await method(**kwargs)
|
||||
|
||||
coro_instance: Awaitable[Any] = call_method(plugin, method_kwargs)
|
||||
tasks.append(asyncio.ensure_future(coro_instance))
|
||||
|
||||
ret_dict: dict[BasePlugin[C], str | bool | None] = {}
|
||||
if tasks:
|
||||
ret_list = await asyncio.gather(*tasks)
|
||||
# Create result map plugin => ret
|
||||
ret_dict = dict(zip(plugins_list, ret_list, strict=False))
|
||||
else:
|
||||
ret_dict = {}
|
||||
ret_dict = dict(zip(plugins, ret_list, strict=False))
|
||||
|
||||
return ret_dict
|
||||
|
||||
@staticmethod
|
||||
async def _call_coro(plugin: Plugin, coro_name: str, *args: Any, **kwargs: Any) -> str | bool | None:
|
||||
if not hasattr(plugin.object, coro_name):
|
||||
_LOGGER.warning(f"Plugin doesn't implement coro_name '{coro_name}': {plugin.name}")
|
||||
return None
|
||||
async def map_plugin_auth(self, *, session: Session) -> dict["BasePlugin[C]", str | bool | None]:
|
||||
"""Schedule a coroutine for plugin 'authenticate' calls.
|
||||
|
||||
coro: Awaitable[str | bool | None] = getattr(plugin.object, coro_name)(*args, **kwargs)
|
||||
return await coro
|
||||
|
||||
async def map_plugin_coro(self, coro_name: str, *args: Any, **kwargs: Any) -> dict[Plugin, str | bool | None]:
|
||||
"""Call a plugin declared by plugin by its name.
|
||||
|
||||
:param coro_name:
|
||||
:param args:
|
||||
:param kwargs:
|
||||
:return:
|
||||
:param session: the client session associated with the authentication check
|
||||
:return: dict containing return from coro call for each plugin.
|
||||
"""
|
||||
return await self.map(self._call_coro, coro_name, *args, **kwargs)
|
||||
return await self._map_plugin_method(
|
||||
self._auth_plugins, "authenticate", {"session": session }) # type: ignore[arg-type]
|
||||
|
||||
async def map_plugin_topic(
|
||||
self, *, session: Session, topic: str, action: "Action"
|
||||
) -> dict["BasePlugin[C]", str | bool | None]:
|
||||
"""Schedule a coroutine for plugin 'topic_filtering' calls.
|
||||
|
||||
:param session: the client session associated with the topic_filtering check
|
||||
:param topic: the topic that needs to be filtered
|
||||
:param action: the action being executed
|
||||
:return: dict containing return from coro call for each plugin.
|
||||
"""
|
||||
return await self._map_plugin_method(
|
||||
self._topic_plugins, "topic_filtering", # type: ignore[arg-type]
|
||||
{"session": session, "topic": topic, "action": action}
|
||||
)
|
||||
|
||||
async def map_plugin_close(self) -> None:
|
||||
"""Schedule a coroutine for plugin 'close' calls.
|
||||
|
||||
:return: dict containing return from coro call for each plugin.
|
||||
"""
|
||||
await self._map_plugin_method(self._plugins, "close", {})
|
||||
|
|
|
@ -42,7 +42,7 @@ STAT_CLIENTS_CONNECTED = "clients_connected"
|
|||
STAT_CLIENTS_DISCONNECTED = "clients_disconnected"
|
||||
|
||||
|
||||
class BrokerSysPlugin(BasePlugin):
|
||||
class BrokerSysPlugin(BasePlugin[BrokerContext]):
|
||||
def __init__(self, context: BrokerContext) -> None:
|
||||
super().__init__(context)
|
||||
# Broker statistics initialization
|
||||
|
|
|
@ -1,14 +1,15 @@
|
|||
from typing import Any
|
||||
|
||||
from amqtt.broker import Action, BrokerContext
|
||||
from amqtt.broker import Action
|
||||
from amqtt.plugins.base import BasePlugin
|
||||
from amqtt.plugins.manager import BaseContext
|
||||
from amqtt.session import Session
|
||||
|
||||
|
||||
class BaseTopicPlugin(BasePlugin):
|
||||
class BaseTopicPlugin(BasePlugin[BaseContext]):
|
||||
"""Base class for topic plugins."""
|
||||
|
||||
def __init__(self, context: BrokerContext) -> None:
|
||||
def __init__(self, context: BaseContext) -> None:
|
||||
super().__init__(context)
|
||||
|
||||
self.topic_config: dict[str, Any] | None = self._get_config_section("topic-check")
|
||||
|
@ -37,7 +38,7 @@ class BaseTopicPlugin(BasePlugin):
|
|||
|
||||
|
||||
class TopicTabooPlugin(BaseTopicPlugin):
|
||||
def __init__(self, context: BrokerContext) -> None:
|
||||
def __init__(self, context: BaseContext) -> None:
|
||||
super().__init__(context)
|
||||
self._taboo: list[str] = ["prohibited", "top-secret", "data/classified"]
|
||||
|
||||
|
|
|
@ -6,3 +6,5 @@ default_retain: false
|
|||
auto_reconnect: true
|
||||
reconnect_max_interval: 10
|
||||
reconnect_retries: 2
|
||||
broker:
|
||||
uri: "mqtt://127.0.0.1"
|
|
@ -23,7 +23,9 @@ its own variables to configure its behavior.
|
|||
|
||||
::: amqtt.plugins.base.BasePlugin
|
||||
|
||||
Plugins that are defined in the`project.entry-points` are loaded and notified of events by when the subclass
|
||||
## Events
|
||||
|
||||
Plugins that are defined in the`project.entry-points` are notified of events if the subclass
|
||||
implements one or more of these methods:
|
||||
|
||||
- `on_mqtt_packet_sent`
|
||||
|
|
|
@ -32,7 +32,8 @@ dependencies = [
|
|||
"websockets==15.0.1", # https://pypi.org/project/websockets
|
||||
"passlib==1.7.4", # https://pypi.org/project/passlib
|
||||
"PyYAML==6.0.2", # https://pypi.org/project/PyYAML
|
||||
"typer==0.15.4"
|
||||
"typer==0.15.4",
|
||||
"dacite>=1.9.2",
|
||||
]
|
||||
|
||||
[dependency-groups]
|
||||
|
|
|
@ -0,0 +1,10 @@
|
|||
---
|
||||
listeners:
|
||||
default:
|
||||
type: tcp
|
||||
bind: 0.0.0.0:1883
|
||||
plugins:
|
||||
- test.plugins.plugins.TestSimplePlugin
|
||||
- test.plugins.plugins.TestConfigPlugin:
|
||||
option1: foo
|
||||
option2: bar
|
|
@ -1,18 +1,55 @@
|
|||
import logging
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from amqtt.broker import Action
|
||||
|
||||
from amqtt.plugins.base import BasePlugin
|
||||
from amqtt.plugins.manager import BaseContext
|
||||
from amqtt.plugins.topic_checking import BaseTopicPlugin
|
||||
from amqtt.plugins.authentication import BaseAuthPlugin
|
||||
|
||||
from amqtt.session import Session
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class NoAuthPlugin(BaseAuthPlugin):
|
||||
class TestSimplePlugin(BasePlugin):
|
||||
|
||||
def __init__(self, context: BaseContext):
|
||||
super().__init__(context)
|
||||
|
||||
|
||||
class TestConfigPlugin(BasePlugin):
|
||||
|
||||
def __init__(self, context: BaseContext):
|
||||
super().__init__(context)
|
||||
|
||||
@dataclass
|
||||
class Config:
|
||||
option1: int
|
||||
option2: str
|
||||
|
||||
async def authenticate(self, *, session: Session) -> bool | None:
|
||||
return False
|
||||
|
||||
class AuthPlugin(BaseAuthPlugin):
|
||||
|
||||
async def authenticate(self, *, session: Session) -> bool | None:
|
||||
return True
|
||||
|
||||
|
||||
class NoAuthPlugin(BaseAuthPlugin):
|
||||
|
||||
|
||||
async def authenticate(self, *, session: Session) -> bool | None:
|
||||
return False
|
||||
|
||||
|
||||
class TestTopicPlugin(BaseTopicPlugin):
|
||||
|
||||
def __init__(self, context: BaseContext):
|
||||
super().__init__(context)
|
||||
|
||||
def topic_filtering(
|
||||
self, *, session: Session | None = None, topic: str | None = None, action: Action | None = None
|
||||
) -> bool:
|
||||
return True
|
||||
|
|
|
@ -1,9 +1,12 @@
|
|||
import asyncio
|
||||
import logging
|
||||
from typing import Any
|
||||
import unittest
|
||||
|
||||
from amqtt.plugins.manager import BaseContext, Plugin, PluginManager
|
||||
from amqtt.broker import Action
|
||||
from amqtt.plugins.authentication import BaseAuthPlugin
|
||||
from amqtt.plugins.manager import BaseContext, PluginManager
|
||||
from amqtt.plugins.topic_checking import BaseTopicPlugin
|
||||
from amqtt.session import Session
|
||||
|
||||
formatter = "[%(asctime)s] %(name)s {%(filename)s:%(lineno)d} %(levelname)s - %(message)s"
|
||||
logging.basicConfig(level=logging.INFO, format=formatter)
|
||||
|
@ -14,21 +17,29 @@ class EmptyTestPlugin:
|
|||
self.context = context
|
||||
|
||||
|
||||
class EventTestPlugin:
|
||||
class EventTestPlugin(BaseAuthPlugin, BaseTopicPlugin):
|
||||
def __init__(self, context: BaseContext) -> None:
|
||||
self.context = context
|
||||
self.test_flag = False
|
||||
self.coro_flag = False
|
||||
super().__init__(context)
|
||||
self.test_close_flag = False
|
||||
self.test_auth_flag = False
|
||||
self.test_topic_flag = False
|
||||
self.test_event_flag = False
|
||||
|
||||
async def on_test(self) -> None:
|
||||
self.test_flag = True
|
||||
self.context.logger.info("on_test")
|
||||
self.test_event_flag = True
|
||||
|
||||
async def test_coro(self) -> None:
|
||||
self.coro_flag = True
|
||||
async def authenticate(self, *, session: Session) -> bool | None:
|
||||
self.test_auth_flag = True
|
||||
return None
|
||||
|
||||
async def ret_coro(self) -> str:
|
||||
return "TEST"
|
||||
async def topic_filtering(
|
||||
self, *, session: Session | None = None, topic: str | None = None, action: Action | None = None
|
||||
) -> bool:
|
||||
self.test_topic_flag = True
|
||||
return False
|
||||
|
||||
async def close(self) -> None:
|
||||
self.test_close_flag = True
|
||||
|
||||
|
||||
class TestPluginManager(unittest.TestCase):
|
||||
|
@ -47,9 +58,9 @@ class TestPluginManager(unittest.TestCase):
|
|||
|
||||
manager = PluginManager("amqtt.test.plugins", context=None)
|
||||
self.loop.run_until_complete(fire_event())
|
||||
plugin = manager.get_plugin("event_plugin")
|
||||
plugin = manager.get_plugin("EventTestPlugin")
|
||||
assert plugin is not None
|
||||
assert plugin.object.test_flag
|
||||
assert plugin.test_event_flag
|
||||
|
||||
def test_fire_event_wait(self) -> None:
|
||||
async def fire_event() -> None:
|
||||
|
@ -58,36 +69,33 @@ class TestPluginManager(unittest.TestCase):
|
|||
|
||||
manager = PluginManager("amqtt.test.plugins", context=None)
|
||||
self.loop.run_until_complete(fire_event())
|
||||
plugin = manager.get_plugin("event_plugin")
|
||||
plugin = manager.get_plugin("EventTestPlugin")
|
||||
assert plugin is not None
|
||||
assert plugin.object.test_flag
|
||||
assert plugin.test_event_flag
|
||||
|
||||
def test_map_coro(self) -> None:
|
||||
async def call_coro() -> None:
|
||||
await manager.map_plugin_coro("test_coro")
|
||||
def test_plugin_close_coro(self) -> None:
|
||||
|
||||
manager = PluginManager("amqtt.test.plugins", context=None)
|
||||
self.loop.run_until_complete(call_coro())
|
||||
plugin = manager.get_plugin("event_plugin")
|
||||
self.loop.run_until_complete(manager.map_plugin_close())
|
||||
self.loop.run_until_complete(asyncio.sleep(0.5))
|
||||
plugin = manager.get_plugin("EventTestPlugin")
|
||||
assert plugin is not None
|
||||
assert plugin.object.test_coro
|
||||
assert plugin.test_close_flag
|
||||
|
||||
def test_map_coro_return(self) -> None:
|
||||
async def call_coro() -> dict[Plugin, str]:
|
||||
return await manager.map_plugin_coro("ret_coro")
|
||||
def test_plugin_auth_coro(self) -> None:
|
||||
|
||||
manager = PluginManager("amqtt.test.plugins", context=None)
|
||||
ret = self.loop.run_until_complete(call_coro())
|
||||
plugin = manager.get_plugin("event_plugin")
|
||||
self.loop.run_until_complete(manager.map_plugin_auth(session=Session()))
|
||||
self.loop.run_until_complete(asyncio.sleep(0.5))
|
||||
plugin = manager.get_plugin("EventTestPlugin")
|
||||
assert plugin is not None
|
||||
assert ret[plugin] == "TEST"
|
||||
assert plugin.test_auth_flag
|
||||
|
||||
def test_map_coro_filter(self) -> None:
|
||||
"""Run plugin coro but expect no return as an empty filter is given."""
|
||||
|
||||
async def call_coro() -> dict[Plugin, Any]:
|
||||
return await manager.map_plugin_coro("ret_coro", filter_plugins=[])
|
||||
def test_plugin_topic_coro(self) -> None:
|
||||
|
||||
manager = PluginManager("amqtt.test.plugins", context=None)
|
||||
ret = self.loop.run_until_complete(call_coro())
|
||||
assert len(ret) == 0
|
||||
self.loop.run_until_complete(manager.map_plugin_topic(session=Session(), topic="test", action=Action.PUBLISH))
|
||||
self.loop.run_until_complete(asyncio.sleep(0.5))
|
||||
plugin = manager.get_plugin("EventTestPlugin")
|
||||
assert plugin is not None
|
||||
assert plugin.test_topic_flag
|
||||
|
|
11
uv.lock
11
uv.lock
|
@ -12,6 +12,7 @@ name = "amqtt"
|
|||
version = "0.11.0rc1"
|
||||
source = { editable = "." }
|
||||
dependencies = [
|
||||
{ name = "dacite" },
|
||||
{ name = "passlib" },
|
||||
{ name = "pyyaml" },
|
||||
{ name = "transitions" },
|
||||
|
@ -66,6 +67,7 @@ docs = [
|
|||
[package.metadata]
|
||||
requires-dist = [
|
||||
{ name = "coveralls", marker = "extra == 'ci'", specifier = "==4.0.1" },
|
||||
{ name = "dacite", specifier = ">=1.9.2" },
|
||||
{ name = "passlib", specifier = "==1.7.4" },
|
||||
{ name = "pyyaml", specifier = "==6.0.2" },
|
||||
{ name = "transitions", specifier = "==0.9.2" },
|
||||
|
@ -485,6 +487,15 @@ version = "0.9.5"
|
|||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/f1/2a/8c3ac3d8bc94e6de8d7ae270bb5bc437b210bb9d6d9e46630c98f4abd20c/csscompressor-0.9.5.tar.gz", hash = "sha256:afa22badbcf3120a4f392e4d22f9fff485c044a1feda4a950ecc5eba9dd31a05", size = 237808, upload-time = "2017-11-26T21:13:08.238Z" }
|
||||
|
||||
[[package]]
|
||||
name = "dacite"
|
||||
version = "1.9.2"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/55/a0/7ca79796e799a3e782045d29bf052b5cde7439a2bbb17f15ff44f7aacc63/dacite-1.9.2.tar.gz", hash = "sha256:6ccc3b299727c7aa17582f0021f6ae14d5de47c7227932c47fec4cdfefd26f09", size = 22420, upload-time = "2025-02-05T09:27:29.757Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/94/35/386550fd60316d1e37eccdda609b074113298f23cef5bddb2049823fe666/dacite-1.9.2-py3-none-any.whl", hash = "sha256:053f7c3f5128ca2e9aceb66892b1a3c8936d02c686e707bee96e19deef4bc4a0", size = 16600, upload-time = "2025-02-05T09:27:24.345Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "dill"
|
||||
version = "0.4.0"
|
||||
|
|
Ładowanie…
Reference in New Issue