diff --git a/amqtt/broker.py b/amqtt/broker.py index a42b9c3..b8e2fd4 100644 --- a/amqtt/broker.py +++ b/amqtt/broker.py @@ -711,20 +711,16 @@ class Broker: :param session: :return: """ - auth_plugins = None - auth_config = self.config.get("auth", None) - if isinstance(auth_config, dict): - auth_plugins = auth_config.get("plugins", None) - returns = await self.plugins_manager.map_plugin_coro("authenticate", session=session, filter_plugins=auth_plugins) + returns = await self.plugins_manager.map_plugin_auth(session=session) auth_result = True if returns: for plugin in returns: res = returns[plugin] if res is False: auth_result = False - self.logger.debug(f"Authentication failed due to '{plugin.name}' plugin result: {res}") + self.logger.debug(f"Authentication failed due to '{plugin.__class__}' plugin result: {res}") else: - self.logger.debug(f"'{plugin.name}' plugin result: {res}") + self.logger.debug(f"'{plugin.__class__}' plugin result: {res}") # If all plugins returned True, authentication is success return auth_result @@ -785,20 +781,14 @@ class Broker: """ topic_config = self.config.get("topic-check", {}) enabled = False - topic_plugins: list[str] | None = None + if isinstance(topic_config, dict): enabled = topic_config.get("enabled", False) - topic_plugins = topic_config.get("plugins") if not enabled: return True - results = await self.plugins_manager.map_plugin_coro( - "topic_filtering", - session=session, - topic=topic, - action=action, - filter_plugins=topic_plugins, - ) + + results = await self.plugins_manager.map_plugin_topic(session=session, topic=topic, action=action) return all(result for result in results.values()) async def _delete_session(self, client_id: str) -> None: diff --git a/amqtt/client.py b/amqtt/client.py index b56c236..5154a98 100644 --- a/amqtt/client.py +++ b/amqtt/client.py @@ -110,7 +110,7 @@ class MQTTClient: # Init plugins manager context = ClientContext() context.config = self.config - self.plugins_manager = PluginManager("amqtt.client.plugins", context) + self.plugins_manager: PluginManager[ClientContext] = PluginManager("amqtt.client.plugins", context) self.client_tasks: deque[asyncio.Task[Any]] = deque() async def connect( diff --git a/amqtt/errors.py b/amqtt/errors.py index a7d2084..4c653a8 100644 --- a/amqtt/errors.py +++ b/amqtt/errors.py @@ -50,3 +50,7 @@ class ConnectError(ClientError): class ProtocolHandlerError(Exception): """Exceptions thrown by protocol handle.""" + + +class PluginLoadError(Exception): + """Exception thrown when loading a plugin.""" diff --git a/amqtt/mqtt/protocol/broker_handler.py b/amqtt/mqtt/protocol/broker_handler.py index c23211e..c8cae86 100644 --- a/amqtt/mqtt/protocol/broker_handler.py +++ b/amqtt/mqtt/protocol/broker_handler.py @@ -1,5 +1,6 @@ import asyncio from asyncio import AbstractEventLoop, Queue +from typing import TYPE_CHECKING from amqtt.adapters import ReaderAdapter, WriterAdapter from amqtt.errors import MQTTError @@ -28,6 +29,8 @@ from .handler import EVENT_MQTT_PACKET_RECEIVED, EVENT_MQTT_PACKET_SENT _MQTT_PROTOCOL_LEVEL_SUPPORTED = 4 +if TYPE_CHECKING: + from amqtt.broker import BrokerContext class Subscription: def __init__(self, packet_id: int, topics: list[tuple[str, int]]) -> None: @@ -41,10 +44,10 @@ class UnSubscription: self.topics = topics -class BrokerProtocolHandler(ProtocolHandler): +class BrokerProtocolHandler(ProtocolHandler["BrokerContext"]): def __init__( self, - plugins_manager: PluginManager, + plugins_manager: PluginManager["BrokerContext"], session: Session | None = None, loop: AbstractEventLoop | None = None, ) -> None: @@ -156,7 +159,7 @@ class BrokerProtocolHandler(ProtocolHandler): cls, reader: ReaderAdapter, writer: WriterAdapter, - plugins_manager: PluginManager, + plugins_manager: PluginManager["BrokerContext"], loop: asyncio.AbstractEventLoop | None = None, ) -> tuple["BrokerProtocolHandler", Session]: """Initialize from a CONNECT packet and validates the connection.""" diff --git a/amqtt/mqtt/protocol/client_handler.py b/amqtt/mqtt/protocol/client_handler.py index 9b7d9d8..6815ab7 100644 --- a/amqtt/mqtt/protocol/client_handler.py +++ b/amqtt/mqtt/protocol/client_handler.py @@ -1,5 +1,5 @@ import asyncio -from typing import Any +from typing import TYPE_CHECKING, Any from amqtt.errors import AMQTTError, NoDataError from amqtt.mqtt.connack import ConnackPacket @@ -15,11 +15,13 @@ from amqtt.mqtt.unsubscribe import UnsubscribePacket from amqtt.plugins.manager import PluginManager from amqtt.session import Session +if TYPE_CHECKING: + from amqtt.client import ClientContext -class ClientProtocolHandler(ProtocolHandler): +class ClientProtocolHandler(ProtocolHandler["ClientContext"]): def __init__( self, - plugins_manager: PluginManager, + plugins_manager: PluginManager["ClientContext"], session: Session | None = None, loop: asyncio.AbstractEventLoop | None = None, ) -> None: diff --git a/amqtt/mqtt/protocol/handler.py b/amqtt/mqtt/protocol/handler.py index 244bd46..6197868 100644 --- a/amqtt/mqtt/protocol/handler.py +++ b/amqtt/mqtt/protocol/handler.py @@ -17,7 +17,7 @@ except ImportError: import collections import itertools import logging -from typing import cast +from typing import Generic, TypeVar, cast from amqtt.adapters import ReaderAdapter, WriterAdapter from amqtt.errors import AMQTTError, MQTTError, NoDataError, ProtocolHandlerError @@ -56,19 +56,20 @@ from amqtt.mqtt.suback import SubackPacket from amqtt.mqtt.subscribe import SubscribePacket from amqtt.mqtt.unsuback import UnsubackPacket from amqtt.mqtt.unsubscribe import UnsubscribePacket -from amqtt.plugins.manager import PluginManager +from amqtt.plugins.manager import BaseContext, PluginManager from amqtt.session import INCOMING, OUTGOING, ApplicationMessage, IncomingApplicationMessage, OutgoingApplicationMessage, Session EVENT_MQTT_PACKET_SENT = "mqtt_packet_sent" EVENT_MQTT_PACKET_RECEIVED = "mqtt_packet_received" +C = TypeVar("C", bound=BaseContext) -class ProtocolHandler: +class ProtocolHandler(Generic[C]): """Class implementing the MQTT communication protocol using asyncio features.""" def __init__( self, - plugins_manager: PluginManager, + plugins_manager: PluginManager[C], session: Session | None = None, loop: asyncio.AbstractEventLoop | None = None, ) -> None: @@ -79,7 +80,7 @@ class ProtocolHandler: self.session: Session | None = None self.reader: ReaderAdapter | None = None self.writer: WriterAdapter | None = None - self.plugins_manager: PluginManager = plugins_manager + self.plugins_manager: PluginManager[C] = plugins_manager try: self._loop = loop if loop is not None else asyncio.get_running_loop() diff --git a/amqtt/plugins/authentication.py b/amqtt/plugins/authentication.py index 8b85917..90f7eb9 100644 --- a/amqtt/plugins/authentication.py +++ b/amqtt/plugins/authentication.py @@ -5,15 +5,16 @@ from passlib.apps import custom_app_context as pwd_context from amqtt.broker import BrokerContext from amqtt.plugins.base import BasePlugin +from amqtt.plugins.manager import BaseContext from amqtt.session import Session _PARTS_EXPECTED_LENGTH = 2 # Expected number of parts in a valid line -class BaseAuthPlugin(BasePlugin): +class BaseAuthPlugin(BasePlugin[BaseContext]): """Base class for authentication plugins.""" - def __init__(self, context: BrokerContext) -> None: + def __init__(self, context: BaseContext) -> None: super().__init__(context) self.auth_config: dict[str, Any] | None = self._get_config_section("auth") diff --git a/amqtt/plugins/base.py b/amqtt/plugins/base.py index 89f4443..2d5f644 100644 --- a/amqtt/plugins/base.py +++ b/amqtt/plugins/base.py @@ -1,19 +1,26 @@ -from typing import Any +from typing import Any, Generic, TypeVar -from amqtt.broker import BrokerContext +from amqtt.plugins.manager import BaseContext + +C = TypeVar("C", bound=BaseContext) -class BasePlugin: +class BasePlugin(Generic[C]): """The base from which all plugins should inherit.""" - def __init__(self, context: BrokerContext) -> None: - self.context = context + def __init__(self, context: C) -> None: + self.context: C = context def _get_config_section(self, name: str) -> dict[str, Any] | None: - if not self.context.config or not self.context.config.get(name, None): + + if not self.context.config or not hasattr(self.context.config, "get") or not self.context.config.get(name, None): return None + section_config: int | dict[str, Any] | None = self.context.config.get(name, None) # mypy has difficulty excluding int from `config`'s type, unless isinstance` is its own check if isinstance(section_config, int): return None return section_config + + async def close(self) -> None: + """Override if plugin needs to clean up resources upon shutdown.""" diff --git a/amqtt/plugins/logging_amqtt.py b/amqtt/plugins/logging_amqtt.py index 84cf6d7..13b0f20 100644 --- a/amqtt/plugins/logging_amqtt.py +++ b/amqtt/plugins/logging_amqtt.py @@ -4,12 +4,13 @@ import logging from typing import TYPE_CHECKING, Any from amqtt.plugins.base import BasePlugin +from amqtt.plugins.manager import BaseContext if TYPE_CHECKING: from amqtt.session import Session -class EventLoggerPlugin(BasePlugin): +class EventLoggerPlugin(BasePlugin[BaseContext]): """A plugin to log events dynamically based on method names.""" async def log_event(self, *args: Any, **kwargs: Any) -> None: @@ -25,7 +26,7 @@ class EventLoggerPlugin(BasePlugin): raise AttributeError(msg) -class PacketLoggerPlugin(BasePlugin): +class PacketLoggerPlugin(BasePlugin[BaseContext]): """A plugin to log MQTT packets sent and received.""" async def on_mqtt_packet_received(self, *args: Any, **kwargs: Any) -> None: diff --git a/amqtt/plugins/manager.py b/amqtt/plugins/manager.py index 61bbb49..39e9bb0 100644 --- a/amqtt/plugins/manager.py +++ b/amqtt/plugins/manager.py @@ -1,17 +1,24 @@ __all__ = ["BaseContext", "PluginManager", "get_plugin_manager"] import asyncio -from collections.abc import Awaitable, Callable +from collections.abc import Awaitable import contextlib import copy from importlib.metadata import EntryPoint, EntryPoints, entry_points import logging -from typing import Any, NamedTuple +from typing import TYPE_CHECKING, Any, Generic, NamedTuple, Optional, TypeVar + +from amqtt.session import Session from amqtt.errors import PluginImportError, PluginInitError _LOGGER = logging.getLogger(__name__) +if TYPE_CHECKING: + from amqtt.broker import Action + from amqtt.plugins.authentication import BaseAuthPlugin + from amqtt.plugins.base import BasePlugin + from amqtt.plugins.topic_checking import BaseTopicPlugin class Plugin(NamedTuple): name: str @@ -19,10 +26,10 @@ class Plugin(NamedTuple): object: Any -plugins_manager: dict[str, "PluginManager"] = {} +plugins_manager: dict[str, "PluginManager[Any]"] = {} -def get_plugin_manager(namespace: str) -> "PluginManager | None": +def get_plugin_manager(namespace: str) -> "PluginManager[Any] | None": """Get the plugin manager for a given namespace. :param namespace: The namespace of the plugin manager to retrieve. @@ -38,14 +45,17 @@ class BaseContext: self.config: dict[str, Any] | None = None -class PluginManager: +C = TypeVar("C", bound=BaseContext) + + +class PluginManager(Generic[C]): """Wraps contextlib Entry point mechanism to provide a basic plugin system. Plugins are loaded for a given namespace (group). This plugin manager uses coroutines to run plugin calls asynchronously in an event queue. """ - def __init__(self, namespace: str, context: BaseContext | None, loop: asyncio.AbstractEventLoop | None = None) -> None: + def __init__(self, namespace: str, context: C | None, loop: asyncio.AbstractEventLoop | None = None) -> None: try: self._loop = loop if loop is not None else asyncio.get_running_loop() except RuntimeError: @@ -55,7 +65,9 @@ class PluginManager: self.logger = logging.getLogger(namespace) self.context = context if context is not None else BaseContext() self.context.loop = self._loop - self._plugins: list[Plugin] = [] + self._plugins: list[BasePlugin[C]] = [] + self._auth_plugins: list[BaseAuthPlugin] = [] + self._topic_plugins: list[BaseTopicPlugin] = [] self._load_plugins(namespace) self._fired_events: list[asyncio.Future[Any]] = [] plugins_manager[namespace] = self @@ -65,7 +77,16 @@ class PluginManager: return self.context def _load_plugins(self, namespace: str) -> None: + self.logger.debug(f"Loading plugins for namespace {namespace}") + + auth_filter_list = [] + topic_filter_list = [] + if self.app_context.config and "auth" in self.app_context.config: + auth_filter_list = self.app_context.config["auth"].get("plugins", []) + if self.app_context.config and "topic" in self.app_context.config: + topic_filter_list = self.app_context.config["topic"].get("plugins", []) + ep: EntryPoints | list[EntryPoint] = [] if hasattr(entry_points(), "select"): ep = entry_points().select(group=namespace) @@ -73,12 +94,16 @@ class PluginManager: ep = [entry_points()[namespace]] for item in ep: - plugin = self._load_plugin(item) + plugin = self._load_ep_plugin(item) if plugin is not None: - self._plugins.append(plugin) + self._plugins.append(plugin.object) + if (not auth_filter_list or plugin.name in auth_filter_list) and hasattr(plugin.object, "authenticate"): + self._auth_plugins.append(plugin.object) + if (not topic_filter_list or plugin.name in topic_filter_list) and hasattr(plugin.object, "topic_filtering"): + self._topic_plugins.append(plugin.object) self.logger.debug(f" Plugin {item.name} ready") - def _load_plugin(self, ep: EntryPoint) -> Plugin | None: + def _load_ep_plugin(self, ep: EntryPoint) -> Plugin | None: try: self.logger.debug(f" Loading plugin {ep!s}") plugin = ep.load() @@ -98,26 +123,27 @@ class PluginManager: self.logger.debug(f"Plugin init failed: {ep!r}", exc_info=True) raise PluginInitError(ep) from e - def get_plugin(self, name: str) -> Plugin | None: + def get_plugin(self, name: str) -> Optional["BasePlugin[C]"]: """Get a plugin by its name from the plugins loaded for the current namespace. :param name: :return: """ for p in self._plugins: - if p.name == name: + self.logger.debug(f"plugin name >>>> {p.__class__.__name__}") + if p.__class__.__name__ == name: return p return None async def close(self) -> None: """Free PluginManager resources and cancel pending event methods.""" - await self.map_plugin_coro("close") + await self.map_plugin_close() for task in self._fired_events: task.cancel() self._fired_events.clear() @property - def plugins(self) -> list[Plugin]: + def plugins(self) -> list["BasePlugin[C]"]: """Get the loaded plugins list. :return: @@ -143,7 +169,7 @@ class PluginManager: tasks: list[asyncio.Future[Any]] = [] event_method_name = "on_" + event_name for plugin in self._plugins: - event_method = getattr(plugin.object, event_method_name, None) + event_method = getattr(plugin, event_method_name, None) if event_method: try: task = self._schedule_coro(event_method(*args, **kwargs)) @@ -155,66 +181,73 @@ class PluginManager: task.add_done_callback(clean_fired_events) except AssertionError: - self.logger.exception(f"Method '{event_method_name}' on plugin '{plugin.name}' is not a coroutine") + self.logger.exception(f"Method '{event_method_name}' on plugin '{plugin.__class__}' is not a coroutine") self._fired_events.extend(tasks) if wait and tasks: await asyncio.wait(tasks) self.logger.debug(f"Plugins len(_fired_events)={len(self._fired_events)}") - async def map( - self, - coro: Callable[[Plugin, Any], Awaitable[str | bool | None]], - *args: Any, - **kwargs: Any, - ) -> dict[Plugin, str | bool | None]: - """Schedule a given coroutine call for each plugin. + @staticmethod + async def _map_plugin_method( + plugins: list["BasePlugin[C]"], + method_name: str, + method_kwargs: dict[str, Any], + ) -> dict["BasePlugin[C]", str | bool | None]: + """Call plugin coroutines. - The coro called gets the Plugin instance as the first argument of its method call. - :param coro: coro to call on each plugin - :param filter_plugins: list of plugin names to filter (only plugin whose name is - in the filter are called). None will call all plugins. [] will call None. - :param args: arguments to pass to coro - :param kwargs: arguments to pass to coro + :param plugins: List of plugins to execute the method on + :param method_name: Name of the method to call on each plugin + :param method_kwargs: Keyword arguments to pass to the method :return: dict containing return from coro call for each plugin. """ - p_list = kwargs.pop("filter_plugins", None) - if p_list is None: - p_list = [p.name for p in self.plugins] tasks: list[asyncio.Future[Any]] = [] - plugins_list: list[Plugin] = [] - for plugin in self._plugins: - if plugin.name in p_list: - coro_instance = coro(plugin, *args, **kwargs) - if coro_instance: - try: - tasks.append(self._schedule_coro(coro_instance)) - plugins_list.append(plugin) - except AssertionError: - self.logger.exception(f"Method '{coro!r}' on plugin '{plugin.name}' is not a coroutine") + + for plugin in plugins: + if not hasattr(plugin, method_name): + continue + + async def call_method(p: "BasePlugin[C]", kwargs: dict[str, Any]) -> Any: + method = getattr(p, method_name) + return await method(**kwargs) + + coro_instance: Awaitable[Any] = call_method(plugin, method_kwargs) + tasks.append(asyncio.ensure_future(coro_instance)) + + ret_dict: dict[BasePlugin[C], str | bool | None] = {} if tasks: ret_list = await asyncio.gather(*tasks) - # Create result map plugin => ret - ret_dict = dict(zip(plugins_list, ret_list, strict=False)) - else: - ret_dict = {} + ret_dict = dict(zip(plugins, ret_list, strict=False)) + return ret_dict - @staticmethod - async def _call_coro(plugin: Plugin, coro_name: str, *args: Any, **kwargs: Any) -> str | bool | None: - if not hasattr(plugin.object, coro_name): - _LOGGER.warning(f"Plugin doesn't implement coro_name '{coro_name}': {plugin.name}") - return None + async def map_plugin_auth(self, *, session: Session) -> dict["BasePlugin[C]", str | bool | None]: + """Schedule a coroutine for plugin 'authenticate' calls. - coro: Awaitable[str | bool | None] = getattr(plugin.object, coro_name)(*args, **kwargs) - return await coro - - async def map_plugin_coro(self, coro_name: str, *args: Any, **kwargs: Any) -> dict[Plugin, str | bool | None]: - """Call a plugin declared by plugin by its name. - - :param coro_name: - :param args: - :param kwargs: - :return: + :param session: the client session associated with the authentication check + :return: dict containing return from coro call for each plugin. """ - return await self.map(self._call_coro, coro_name, *args, **kwargs) + return await self._map_plugin_method( + self._auth_plugins, "authenticate", {"session": session }) # type: ignore[arg-type] + + async def map_plugin_topic( + self, *, session: Session, topic: str, action: "Action" + ) -> dict["BasePlugin[C]", str | bool | None]: + """Schedule a coroutine for plugin 'topic_filtering' calls. + + :param session: the client session associated with the topic_filtering check + :param topic: the topic that needs to be filtered + :param action: the action being executed + :return: dict containing return from coro call for each plugin. + """ + return await self._map_plugin_method( + self._topic_plugins, "topic_filtering", # type: ignore[arg-type] + {"session": session, "topic": topic, "action": action} + ) + + async def map_plugin_close(self) -> None: + """Schedule a coroutine for plugin 'close' calls. + + :return: dict containing return from coro call for each plugin. + """ + await self._map_plugin_method(self._plugins, "close", {}) diff --git a/amqtt/plugins/sys/broker.py b/amqtt/plugins/sys/broker.py index 996e399..4f42a5f 100644 --- a/amqtt/plugins/sys/broker.py +++ b/amqtt/plugins/sys/broker.py @@ -42,7 +42,7 @@ STAT_CLIENTS_CONNECTED = "clients_connected" STAT_CLIENTS_DISCONNECTED = "clients_disconnected" -class BrokerSysPlugin(BasePlugin): +class BrokerSysPlugin(BasePlugin[BrokerContext]): def __init__(self, context: BrokerContext) -> None: super().__init__(context) # Broker statistics initialization diff --git a/amqtt/plugins/topic_checking.py b/amqtt/plugins/topic_checking.py index b9f91ae..d92d672 100644 --- a/amqtt/plugins/topic_checking.py +++ b/amqtt/plugins/topic_checking.py @@ -1,14 +1,15 @@ from typing import Any -from amqtt.broker import Action, BrokerContext +from amqtt.broker import Action from amqtt.plugins.base import BasePlugin +from amqtt.plugins.manager import BaseContext from amqtt.session import Session -class BaseTopicPlugin(BasePlugin): +class BaseTopicPlugin(BasePlugin[BaseContext]): """Base class for topic plugins.""" - def __init__(self, context: BrokerContext) -> None: + def __init__(self, context: BaseContext) -> None: super().__init__(context) self.topic_config: dict[str, Any] | None = self._get_config_section("topic-check") @@ -37,7 +38,7 @@ class BaseTopicPlugin(BasePlugin): class TopicTabooPlugin(BaseTopicPlugin): - def __init__(self, context: BrokerContext) -> None: + def __init__(self, context: BaseContext) -> None: super().__init__(context) self._taboo: list[str] = ["prohibited", "top-secret", "data/classified"] diff --git a/amqtt/scripts/default_client.yaml b/amqtt/scripts/default_client.yaml index b422baf..2feb944 100644 --- a/amqtt/scripts/default_client.yaml +++ b/amqtt/scripts/default_client.yaml @@ -6,3 +6,5 @@ default_retain: false auto_reconnect: true reconnect_max_interval: 10 reconnect_retries: 2 +broker: + uri: "mqtt://127.0.0.1" \ No newline at end of file diff --git a/docs/custom_plugins.md b/docs/custom_plugins.md index 218270c..f77a3f8 100644 --- a/docs/custom_plugins.md +++ b/docs/custom_plugins.md @@ -23,7 +23,9 @@ its own variables to configure its behavior. ::: amqtt.plugins.base.BasePlugin -Plugins that are defined in the`project.entry-points` are loaded and notified of events by when the subclass +## Events + +Plugins that are defined in the`project.entry-points` are notified of events if the subclass implements one or more of these methods: - `on_mqtt_packet_sent` diff --git a/pyproject.toml b/pyproject.toml index 226ac67..7420115 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,7 +32,8 @@ dependencies = [ "websockets==15.0.1", # https://pypi.org/project/websockets "passlib==1.7.4", # https://pypi.org/project/passlib "PyYAML==6.0.2", # https://pypi.org/project/PyYAML - "typer==0.15.4" + "typer==0.15.4", + "dacite>=1.9.2", ] [dependency-groups] diff --git a/tests/plugins/broker_plugin.yml b/tests/plugins/broker_plugin.yml new file mode 100644 index 0000000..cd5de3c --- /dev/null +++ b/tests/plugins/broker_plugin.yml @@ -0,0 +1,10 @@ +--- +listeners: + default: + type: tcp + bind: 0.0.0.0:1883 +plugins: + - test.plugins.plugins.TestSimplePlugin + - test.plugins.plugins.TestConfigPlugin: + option1: foo + option2: bar diff --git a/tests/plugins/mocks.py b/tests/plugins/mocks.py index 3d80ca2..dab1229 100644 --- a/tests/plugins/mocks.py +++ b/tests/plugins/mocks.py @@ -1,18 +1,55 @@ import logging +from dataclasses import dataclass + +from amqtt.broker import Action + +from amqtt.plugins.base import BasePlugin +from amqtt.plugins.manager import BaseContext +from amqtt.plugins.topic_checking import BaseTopicPlugin from amqtt.plugins.authentication import BaseAuthPlugin + from amqtt.session import Session logger = logging.getLogger(__name__) -class NoAuthPlugin(BaseAuthPlugin): +class TestSimplePlugin(BasePlugin): + + def __init__(self, context: BaseContext): + super().__init__(context) + + +class TestConfigPlugin(BasePlugin): + + def __init__(self, context: BaseContext): + super().__init__(context) + + @dataclass + class Config: + option1: int + option2: str - async def authenticate(self, *, session: Session) -> bool | None: - return False class AuthPlugin(BaseAuthPlugin): async def authenticate(self, *, session: Session) -> bool | None: return True + +class NoAuthPlugin(BaseAuthPlugin): + + + async def authenticate(self, *, session: Session) -> bool | None: + return False + + +class TestTopicPlugin(BaseTopicPlugin): + + def __init__(self, context: BaseContext): + super().__init__(context) + + def topic_filtering( + self, *, session: Session | None = None, topic: str | None = None, action: Action | None = None + ) -> bool: + return True diff --git a/tests/plugins/test_manager.py b/tests/plugins/test_manager.py index 33b0d73..14a5ff6 100644 --- a/tests/plugins/test_manager.py +++ b/tests/plugins/test_manager.py @@ -1,9 +1,12 @@ import asyncio import logging -from typing import Any import unittest -from amqtt.plugins.manager import BaseContext, Plugin, PluginManager +from amqtt.broker import Action +from amqtt.plugins.authentication import BaseAuthPlugin +from amqtt.plugins.manager import BaseContext, PluginManager +from amqtt.plugins.topic_checking import BaseTopicPlugin +from amqtt.session import Session formatter = "[%(asctime)s] %(name)s {%(filename)s:%(lineno)d} %(levelname)s - %(message)s" logging.basicConfig(level=logging.INFO, format=formatter) @@ -14,21 +17,29 @@ class EmptyTestPlugin: self.context = context -class EventTestPlugin: +class EventTestPlugin(BaseAuthPlugin, BaseTopicPlugin): def __init__(self, context: BaseContext) -> None: - self.context = context - self.test_flag = False - self.coro_flag = False + super().__init__(context) + self.test_close_flag = False + self.test_auth_flag = False + self.test_topic_flag = False + self.test_event_flag = False async def on_test(self) -> None: - self.test_flag = True - self.context.logger.info("on_test") + self.test_event_flag = True - async def test_coro(self) -> None: - self.coro_flag = True + async def authenticate(self, *, session: Session) -> bool | None: + self.test_auth_flag = True + return None - async def ret_coro(self) -> str: - return "TEST" + async def topic_filtering( + self, *, session: Session | None = None, topic: str | None = None, action: Action | None = None + ) -> bool: + self.test_topic_flag = True + return False + + async def close(self) -> None: + self.test_close_flag = True class TestPluginManager(unittest.TestCase): @@ -47,9 +58,9 @@ class TestPluginManager(unittest.TestCase): manager = PluginManager("amqtt.test.plugins", context=None) self.loop.run_until_complete(fire_event()) - plugin = manager.get_plugin("event_plugin") + plugin = manager.get_plugin("EventTestPlugin") assert plugin is not None - assert plugin.object.test_flag + assert plugin.test_event_flag def test_fire_event_wait(self) -> None: async def fire_event() -> None: @@ -58,36 +69,33 @@ class TestPluginManager(unittest.TestCase): manager = PluginManager("amqtt.test.plugins", context=None) self.loop.run_until_complete(fire_event()) - plugin = manager.get_plugin("event_plugin") + plugin = manager.get_plugin("EventTestPlugin") assert plugin is not None - assert plugin.object.test_flag + assert plugin.test_event_flag - def test_map_coro(self) -> None: - async def call_coro() -> None: - await manager.map_plugin_coro("test_coro") + def test_plugin_close_coro(self) -> None: manager = PluginManager("amqtt.test.plugins", context=None) - self.loop.run_until_complete(call_coro()) - plugin = manager.get_plugin("event_plugin") + self.loop.run_until_complete(manager.map_plugin_close()) + self.loop.run_until_complete(asyncio.sleep(0.5)) + plugin = manager.get_plugin("EventTestPlugin") assert plugin is not None - assert plugin.object.test_coro + assert plugin.test_close_flag - def test_map_coro_return(self) -> None: - async def call_coro() -> dict[Plugin, str]: - return await manager.map_plugin_coro("ret_coro") + def test_plugin_auth_coro(self) -> None: manager = PluginManager("amqtt.test.plugins", context=None) - ret = self.loop.run_until_complete(call_coro()) - plugin = manager.get_plugin("event_plugin") + self.loop.run_until_complete(manager.map_plugin_auth(session=Session())) + self.loop.run_until_complete(asyncio.sleep(0.5)) + plugin = manager.get_plugin("EventTestPlugin") assert plugin is not None - assert ret[plugin] == "TEST" + assert plugin.test_auth_flag - def test_map_coro_filter(self) -> None: - """Run plugin coro but expect no return as an empty filter is given.""" - - async def call_coro() -> dict[Plugin, Any]: - return await manager.map_plugin_coro("ret_coro", filter_plugins=[]) + def test_plugin_topic_coro(self) -> None: manager = PluginManager("amqtt.test.plugins", context=None) - ret = self.loop.run_until_complete(call_coro()) - assert len(ret) == 0 + self.loop.run_until_complete(manager.map_plugin_topic(session=Session(), topic="test", action=Action.PUBLISH)) + self.loop.run_until_complete(asyncio.sleep(0.5)) + plugin = manager.get_plugin("EventTestPlugin") + assert plugin is not None + assert plugin.test_topic_flag diff --git a/uv.lock b/uv.lock index b01c8f6..228dcd9 100644 --- a/uv.lock +++ b/uv.lock @@ -12,6 +12,7 @@ name = "amqtt" version = "0.11.0rc1" source = { editable = "." } dependencies = [ + { name = "dacite" }, { name = "passlib" }, { name = "pyyaml" }, { name = "transitions" }, @@ -66,6 +67,7 @@ docs = [ [package.metadata] requires-dist = [ { name = "coveralls", marker = "extra == 'ci'", specifier = "==4.0.1" }, + { name = "dacite", specifier = ">=1.9.2" }, { name = "passlib", specifier = "==1.7.4" }, { name = "pyyaml", specifier = "==6.0.2" }, { name = "transitions", specifier = "==0.9.2" }, @@ -485,6 +487,15 @@ version = "0.9.5" source = { registry = "https://pypi.org/simple" } sdist = { url = "https://files.pythonhosted.org/packages/f1/2a/8c3ac3d8bc94e6de8d7ae270bb5bc437b210bb9d6d9e46630c98f4abd20c/csscompressor-0.9.5.tar.gz", hash = "sha256:afa22badbcf3120a4f392e4d22f9fff485c044a1feda4a950ecc5eba9dd31a05", size = 237808, upload-time = "2017-11-26T21:13:08.238Z" } +[[package]] +name = "dacite" +version = "1.9.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/55/a0/7ca79796e799a3e782045d29bf052b5cde7439a2bbb17f15ff44f7aacc63/dacite-1.9.2.tar.gz", hash = "sha256:6ccc3b299727c7aa17582f0021f6ae14d5de47c7227932c47fec4cdfefd26f09", size = 22420, upload-time = "2025-02-05T09:27:29.757Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/94/35/386550fd60316d1e37eccdda609b074113298f23cef5bddb2049823fe666/dacite-1.9.2-py3-none-any.whl", hash = "sha256:053f7c3f5128ca2e9aceb66892b1a3c8936d02c686e707bee96e19deef4bc4a0", size = 16600, upload-time = "2025-02-05T09:27:24.345Z" }, +] + [[package]] name = "dill" version = "0.4.0"