From 43efa4c829e5cbaba11fabb727dcbdce47135969 Mon Sep 17 00:00:00 2001 From: Andrew Mirsky Date: Thu, 12 Jun 2025 19:12:33 -0400 Subject: [PATCH] instead of filtering upon call of the plugin's coro, filter upon plugin loading --- amqtt/broker.py | 15 +- amqtt/errors.py | 1 - amqtt/plugins/authentication.py | 2 +- amqtt/plugins/base.py | 16 +-- amqtt/plugins/manager.py | 239 ++++++++++++-------------------- amqtt/plugins/sys/broker.py | 2 +- amqtt/plugins/topic_checking.py | 2 +- amqtt/utils.py | 32 ----- pyproject.toml | 6 +- tests/plugins/test_config.py | 117 ---------------- tests/plugins/test_manager.py | 78 ++++++----- 11 files changed, 144 insertions(+), 366 deletions(-) delete mode 100644 tests/plugins/test_config.py diff --git a/amqtt/broker.py b/amqtt/broker.py index f4e6b3e..34ef347 100644 --- a/amqtt/broker.py +++ b/amqtt/broker.py @@ -689,11 +689,6 @@ class Broker: :param listener: :return: """ - auth_plugins = None - auth_config = self.config.get("auth", None) - if isinstance(auth_config, dict): - auth_plugins = auth_config.get("plugins", None) - # returns = await self.plugins_manager.map_plugin_coro("authenticate", session=session, filter_plugins=auth_plugins) returns = await self.plugins_manager.map_plugin_auth(session=session) auth_result = True if returns: @@ -765,22 +760,14 @@ class Broker: """ topic_config = self.config.get("topic-check", {}) enabled = False - topic_plugins: list[str] | None = None + if isinstance(topic_config, dict): enabled = topic_config.get("enabled", False) - topic_plugins = topic_config.get("plugins") if not enabled: return True results = await self.plugins_manager.map_plugin_topic(session=session, topic=topic, action=action) - # results = await self.plugins_manager.map_plugin_coro( - # "topic_filtering", - # session=session, - # topic=topic, - # action=action, - # filter_plugins=topic_plugins, - # ) return all(result for result in results.values()) async def _delete_session(self, client_id: str) -> None: diff --git a/amqtt/errors.py b/amqtt/errors.py index a82ba07..c54cf43 100644 --- a/amqtt/errors.py +++ b/amqtt/errors.py @@ -34,4 +34,3 @@ class ProtocolHandlerError(Exception): class PluginLoadError(Exception): """Exception thrown when loading a plugin.""" - diff --git a/amqtt/plugins/authentication.py b/amqtt/plugins/authentication.py index 9f8bdd7..90f7eb9 100644 --- a/amqtt/plugins/authentication.py +++ b/amqtt/plugins/authentication.py @@ -11,7 +11,7 @@ from amqtt.session import Session _PARTS_EXPECTED_LENGTH = 2 # Expected number of parts in a valid line -class BaseAuthPlugin(BasePlugin): +class BaseAuthPlugin(BasePlugin[BaseContext]): """Base class for authentication plugins.""" def __init__(self, context: BaseContext) -> None: diff --git a/amqtt/plugins/base.py b/amqtt/plugins/base.py index d93ceca..e81835c 100644 --- a/amqtt/plugins/base.py +++ b/amqtt/plugins/base.py @@ -1,18 +1,19 @@ -from dataclasses import dataclass -from typing import Any +from typing import Any, Generic, TypeVar from amqtt.plugins.manager import BaseContext +C = TypeVar("C", bound=BaseContext) -class BasePlugin: + +class BasePlugin(Generic[C]): """The base from which all plugins should inherit.""" - def __init__(self, context: BaseContext) -> None: - self.context = context + def __init__(self, context: C) -> None: + self.context: C = context def _get_config_section(self, name: str) -> dict[str, Any] | None: - if not self.context.config or not hasattr(self.context.config, 'get') or self.context.config.get(name, None): + if not self.context.config or not hasattr(self.context.config, "get") or not self.context.config.get(name, None): return None section_config: int | dict[str, Any] | None = self.context.config.get(name, None) @@ -21,6 +22,5 @@ class BasePlugin: return None return section_config - @dataclass - class Config: + async def close(self) -> None: pass diff --git a/amqtt/plugins/manager.py b/amqtt/plugins/manager.py index 10d94a4..c621e97 100644 --- a/amqtt/plugins/manager.py +++ b/amqtt/plugins/manager.py @@ -1,25 +1,22 @@ __all__ = ["BaseContext", "PluginManager", "get_plugin_manager"] import asyncio -from collections.abc import Awaitable, Callable +from collections.abc import Awaitable import contextlib import copy from importlib.metadata import EntryPoint, EntryPoints, entry_points import logging -from typing import Any, NamedTuple, TYPE_CHECKING +from typing import TYPE_CHECKING, Any, NamedTuple, Optional -from amqtt.errors import MQTTError, PluginLoadError from amqtt.session import Session -from amqtt.utils import import_string -from dacite import from_dict, Config, DaciteError _LOGGER = logging.getLogger(__name__) if TYPE_CHECKING: - from amqtt.plugins.base import BasePlugin - from amqtt.plugins.authentication import BaseAuthPlugin - from amqtt.plugins.topic_checking import BaseTopicPlugin from amqtt.broker import Action + from amqtt.plugins.authentication import BaseAuthPlugin + from amqtt.plugins.base import BasePlugin + from amqtt.plugins.topic_checking import BaseTopicPlugin class Plugin(NamedTuple): name: str @@ -75,52 +72,31 @@ class PluginManager: return self.context def _load_plugins(self, namespace: str) -> None: - from amqtt.plugins.authentication import BaseAuthPlugin - from amqtt.plugins.topic_checking import BaseTopicPlugin - if 'plugins' in self.app_context.config: - self.logger.info("Loading plugins from config file") - for plugin_info in self.app_context.config['plugins']: + self.logger.debug(f"Loading plugins for namespace {namespace}") - if isinstance(plugin_info, dict): - assert len(plugin_info.keys()) == 1 - plugin_path = list(plugin_info.keys())[0] - plugin_cfg = plugin_info[plugin_path] - plugin = self._load_str_plugin(plugin_path, plugin_cfg) - elif isinstance(plugin_info, str): - plugin = self._load_str_plugin(plugin_info, {}) - else: - msg = 'Unexpected entry in plugins config' - raise PluginLoadError(msg) + auth_filter_list = [] + topic_filter_list = [] + if self.app_context.config and "auth" in self.app_context.config: + auth_filter_list = self.app_context.config["auth"].get("plugins", []) + if self.app_context.config and "topic" in self.app_context.config: + topic_filter_list = self.app_context.config["topic"].get("plugins", []) - self._plugins.append(plugin) - if isinstance(plugin, BaseAuthPlugin): - self._auth_plugins.append(plugin) - if isinstance(plugin, BaseTopicPlugin): - self._topic_plugins.append(plugin) + ep: EntryPoints | list[EntryPoint] = [] + if hasattr(entry_points(), "select"): + ep = entry_points().select(group=namespace) + elif namespace in entry_points(): + ep = [entry_points()[namespace]] - - - else: - self.logger.debug(f"Loading plugins for namespace {namespace}") - - auth_filter_list = self.app_context.config['auth'].get('plugins', []) if 'auth' in self.app_context.config else [] - topic_filter_list = self.app_context.config['topic'].get('plugins', []) if 'topic' in self.app_context.config else [] - ep: EntryPoints | list[EntryPoint] = [] - if hasattr(entry_points(), "select"): - ep = entry_points().select(group=namespace) - elif namespace in entry_points(): - ep = [entry_points()[namespace]] - - for item in ep: - plugin = self._load_ep_plugin(item) - if plugin is not None: - self._plugins.append(plugin.object) - if plugin.name in auth_filter_list: - self._auth_plugins.append(plugin.object) - elif plugin.name in topic_filter_list: - self._topic_plugins.append(plugin.object) - self.logger.debug(f" Plugin {item.name} ready") + for item in ep: + plugin = self._load_ep_plugin(item) + if plugin is not None: + self._plugins.append(plugin.object) + if (not auth_filter_list or plugin.name in auth_filter_list) and hasattr(plugin.object, "authenticate"): + self._auth_plugins.append(plugin.object) + if (not topic_filter_list or plugin.name in topic_filter_list) and hasattr(plugin.object, "topic_filtering"): + self._topic_plugins.append(plugin.object) + self.logger.debug(f" Plugin {item.name} ready") def _load_ep_plugin(self, ep: EntryPoint) -> Plugin | None: try: @@ -138,53 +114,27 @@ class PluginManager: return None - def _load_str_plugin(self, plugin_path: str, plugin_cfg: dict[str, Any] | None = None) -> 'BasePlugin': - from amqtt.plugins.base import BasePlugin - from amqtt.plugins.authentication import BaseAuthPlugin - from amqtt.plugins.topic_checking import BaseTopicPlugin + def get_plugin(self, name: str) -> Optional["BasePlugin"]: + """Get a plugin by its name from the plugins loaded for the current namespace. - try: - plugin_class = import_string(plugin_path) - except ModuleNotFoundError as ep: - self.logger.error(f"Plugin import failed: {plugin_path}") - raise MQTTError() from ep - - if not issubclass(plugin_class, BasePlugin): - msg = f"Plugin {plugin_path} is not a subclass of 'BasePlugin'" - raise PluginLoadError(msg) - - plugin_context = copy.copy(self.app_context) - plugin_context.logger = self.logger.getChild(plugin_class.__name__) - try: - plugin_context.config = from_dict(data_class=plugin_class.Config, data=plugin_cfg or {}, config=Config(strict=True)) - except DaciteError as e: - raise PluginLoadError from e - - try: - return plugin_class(plugin_context) - except ImportError as e: - raise PluginLoadError from e - - # def get_plugin(self, name: str) -> Plugin | None: - # """Get a plugin by its name from the plugins loaded for the current namespace. - # - # :param name: - # :return: - # """ - # for p in self._plugins: - # if p.name == name: - # return p - # return None + :param name: + :return: + """ + for p in self._plugins: + self.logger.debug(f"plugin name >>>> {p.__class__.__name__}") + if p.__class__.__name__ == name: + return p + return None async def close(self) -> None: """Free PluginManager resources and cancel pending event methods.""" - await self.map_plugin_coro("close") + await self.map_plugin_close() for task in self._fired_events: task.cancel() self._fired_events.clear() @property - def plugins(self) -> list['BasePlugin']: + def plugins(self) -> list["BasePlugin"]: """Get the loaded plugins list. :return: @@ -229,100 +179,83 @@ class PluginManager: await asyncio.wait(tasks) self.logger.debug(f"Plugins len(_fired_events)={len(self._fired_events)}") - async def map( - self, - coro: Callable[[Plugin, Any], Awaitable[str | bool | None]], - *args: Any, - **kwargs: Any, - ) -> dict[Plugin, str | bool | None]: - """Schedule a given coroutine call for each plugin. + async def map_plugin_auth(self, session: Session) -> dict["BasePlugin", str | bool | None]: + """Schedule a coroutine for plugin 'authenticate' calls. - The coro called gets the Plugin instance as the first argument of its method call. - :param coro: coro to call on each plugin - :param filter_plugins: list of plugin names to filter (only plugin whose name is - in the filter are called). None will call all plugins. [] will call None. - :param args: arguments to pass to coro - :param kwargs: arguments to pass to coro + :param session: the client session associated with the authentication check :return: dict containing return from coro call for each plugin. """ - p_list = kwargs.pop("filter_plugins", None) - if p_list is None: - p_list = [p.name for p in self.plugins] - tasks: list[asyncio.Future[Any]] = [] - plugins_list: list[Plugin] = [] - for plugin in self._plugins: - if plugin.name in p_list: - coro_instance = coro(plugin, *args, **kwargs) - if coro_instance: - try: - tasks.append(self._schedule_coro(coro_instance)) - plugins_list.append(plugin) - except AssertionError: - self.logger.exception(f"Method '{coro!r}' on plugin '{plugin.name}' is not a coroutine") - if tasks: - ret_list = await asyncio.gather(*tasks) - # Create result map plugin => ret - ret_dict = dict(zip(plugins_list, ret_list, strict=False)) - else: - ret_dict = {} - return ret_dict - - @staticmethod - async def _call_coro(plugin: Plugin, coro_name: str, *args: Any, **kwargs: Any) -> str | bool | None: - if not hasattr(plugin.object, coro_name): - _LOGGER.warning(f"Plugin doesn't implement coro_name '{coro_name}': {plugin.name}") - return None - - coro: Awaitable[str | bool | None] = getattr(plugin.object, coro_name)(*args, **kwargs) - return await coro - - async def map_plugin_coro(self, coro_name: str, *args: Any, **kwargs: Any) -> dict[Plugin, str | bool | None]: - """Call a plugin declared by plugin by its name. - - :param coro_name: - :param args: - :param kwargs: - :return: - """ - return await self.map(self._call_coro, coro_name, *args, **kwargs) - - - async def map_plugin_auth(self, session: Session) -> dict['BaseAuthPlugin', str | bool | None]: - tasks: list[asyncio.Future[Any]] = [] for plugin in self._auth_plugins: - async def auth_coro(p: 'BaseAuthPlugin', s: Session) -> str | bool | None: + async def auth_coro(p: "BaseAuthPlugin", s: Session) -> str | bool | None: return await p.authenticate(session=s) + if not hasattr(plugin, "authenticate"): + continue + coro_instance: Awaitable[str | bool | None] = auth_coro(plugin, session) tasks.append(asyncio.ensure_future(coro_instance)) + ret_dict: dict[BasePlugin, str | bool | None] = {} if tasks: ret_list = await asyncio.gather(*tasks) # Create result map plugin => ret ret_dict = dict(zip(self._auth_plugins, ret_list, strict=False)) - else: - ret_dict = {} + return ret_dict - async def map_plugin_topic(self, session: Session, topic: str, action: 'Action') -> dict['BaseTopicPlugin', str | bool | None]: + async def map_plugin_topic(self, + session: Session, topic: str, action: "Action" + ) -> dict["BasePlugin", str | bool | None]: + """Schedule a coroutine for plugin 'topic_filtering' calls. + :param session: the client session associated with the topic_filtering check + :param topic: the topic that needs to be filtered + :param action: the action being executed + :return: dict containing return from coro call for each plugin. + """ tasks: list[asyncio.Future[Any]] = [] for plugin in self._topic_plugins: - async def topic_coro(p: 'BaseTopicPlugin', s: Session, t: str, a: 'Action') -> str | bool | None: + async def topic_coro(p: "BaseTopicPlugin", s: Session, t: str, a: "Action") -> str | bool | None: return await p.topic_filtering(session=s, topic=t, action=a) + if not hasattr(plugin, "topic_filtering"): + continue + coro_instance: Awaitable[str | bool | None] = topic_coro(plugin, session, topic, action) tasks.append(asyncio.ensure_future(coro_instance)) + ret_dict: dict[BasePlugin, str | bool | None] = {} if tasks: ret_list = await asyncio.gather(*tasks) # Create result map plugin => ret - ret_dict = {dict(zip(self._auth_plugins, ret_list, strict=False))} - else: - ret_dict = {} + ret_dict= dict(zip(self._auth_plugins, ret_list, strict=False)) + + return ret_dict + + async def map_plugin_close(self) -> dict["BasePlugin", str | bool | None]: + + tasks: list[asyncio.Future[Any]] = [] + + for plugin in self._plugins: + + async def close_coro(p: "BasePlugin") -> None: + await p.close() + + if not hasattr(plugin, "close"): + continue + + coro_instance: Awaitable[str | bool | None] = close_coro(plugin) + tasks.append(asyncio.ensure_future(coro_instance)) + + ret_dict: dict[BasePlugin, str | bool | None] = {} + if tasks: + ret_list = await asyncio.gather(*tasks) + # Create result map plugin => ret + ret_dict = dict(zip(self._auth_plugins, ret_list, strict=False)) + return ret_dict diff --git a/amqtt/plugins/sys/broker.py b/amqtt/plugins/sys/broker.py index 996e399..4f42a5f 100644 --- a/amqtt/plugins/sys/broker.py +++ b/amqtt/plugins/sys/broker.py @@ -42,7 +42,7 @@ STAT_CLIENTS_CONNECTED = "clients_connected" STAT_CLIENTS_DISCONNECTED = "clients_disconnected" -class BrokerSysPlugin(BasePlugin): +class BrokerSysPlugin(BasePlugin[BrokerContext]): def __init__(self, context: BrokerContext) -> None: super().__init__(context) # Broker statistics initialization diff --git a/amqtt/plugins/topic_checking.py b/amqtt/plugins/topic_checking.py index 0309317..d92d672 100644 --- a/amqtt/plugins/topic_checking.py +++ b/amqtt/plugins/topic_checking.py @@ -6,7 +6,7 @@ from amqtt.plugins.manager import BaseContext from amqtt.session import Session -class BaseTopicPlugin(BasePlugin): +class BaseTopicPlugin(BasePlugin[BaseContext]): """Base class for topic plugins.""" def __init__(self, context: BaseContext) -> None: diff --git a/amqtt/utils.py b/amqtt/utils.py index a439e64..ca14ad2 100644 --- a/amqtt/utils.py +++ b/amqtt/utils.py @@ -1,13 +1,10 @@ from __future__ import annotations import logging -import sys -from importlib import import_module from pathlib import Path import secrets import string import typing -from types import ModuleType from typing import Any import yaml @@ -51,32 +48,3 @@ def read_yaml_config(config_file: str | Path) -> dict[str, Any] | None: except yaml.YAMLError: logger.exception(f"Invalid config_file {config_file}") return None - -def cached_import(module_path: str, class_name: str=None) -> ModuleType: - # Check whether module is loaded and fully initialized. - if not ((module := sys.modules.get(module_path)) - and (spec := getattr(module, "__spec__", None)) # noqa - and getattr(spec, "_initializing", False) is False): # noqa - module = import_module(module_path) - if class_name: - return getattr(module, class_name) - return module - - -# TODO : figure out proper return type -def import_string(dotted_path) -> Any: - """ - Import a dotted module path and return the attribute/class designated by the - last name in the path. Raise ImportError if the import failed. - """ - try: - module_path, class_name = dotted_path.rsplit(".", 1) - except ValueError as err: - raise ImportError(f"{dotted_path} doesn't look like a module path") from err - - try: - return cached_import(module_path, class_name) - except AttributeError as err: - raise ImportError( - f'Module "{module_path}" does not define a "{class_name}" attribute/class' - ) from err diff --git a/pyproject.toml b/pyproject.toml index bf38122..7420115 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -185,13 +185,13 @@ max-returns = 10 # ----------------------------------- PYTEST ----------------------------------- [tool.pytest.ini_options] -#addopts = ["--cov=amqtt", "--cov-report=term-missing", "--cov-report=html"] +addopts = ["--cov=amqtt", "--cov-report=term-missing", "--cov-report=html"] testpaths = ["tests"] asyncio_mode = "auto" timeout = 10 asyncio_default_fixture_loop_scope = "function" -addopts = ["--tb=short", "--capture=tee-sys"] -log_cli = true +#addopts = ["--tb=short", "--capture=tee-sys"] +#log_cli = true log_level = "DEBUG" # ------------------------------------ MYPY ------------------------------------ diff --git a/tests/plugins/test_config.py b/tests/plugins/test_config.py deleted file mode 100644 index 9d4b38a..0000000 --- a/tests/plugins/test_config.py +++ /dev/null @@ -1,117 +0,0 @@ -import asyncio -import logging -from dataclasses import dataclass, field -from typing import Any - -import pytest -import yaml - -from amqtt.broker import Broker -from yaml import CLoader as Loader -from dacite import from_dict, Config, UnexpectedDataError - -from amqtt.client import MQTTClient -from amqtt.errors import PluginLoadError - -logger = logging.getLogger(__name__) - -plugin_config = """--- -listeners: - default: - type: tcp - bind: 0.0.0.0:1883 -plugins: - - tests.plugins.mocks.TestSimplePlugin: - - tests.plugins.mocks.TestConfigPlugin: - option1: 1 - option2: bar -""" - - -plugin_invalid_config_one = """--- -listeners: - default: - type: tcp - bind: 0.0.0.0:1883 -plugins: - - tests.plugins.mocks.TestSimplePlugin: - option1: 1 - option2: bar -""" - -plugin_invalid_config_two = """--- -listeners: - default: - type: tcp - bind: 0.0.0.0:1883 -plugins: - - tests.plugins.mocks.TestConfigPlugin: -""" - -plugin_config_auth = """--- -listeners: - default: - type: tcp - bind: 0.0.0.0:1883 -plugins: - - tests.plugins.mocks.TestAuthPlugin: -""" - -plugin_config_topic = """--- -listeners: - default: - type: tcp - bind: 0.0.0.0:1883 -plugins: - - tests.plugins.mocks.TestTopicPlugin: -""" - -@pytest.mark.asyncio -async def test_plugin_config_extra_fields(): - - cfg: dict[str, Any] = yaml.load(plugin_invalid_config_one, Loader=Loader) - - with pytest.raises(PluginLoadError): - _ = Broker(config=cfg) - - -@pytest.mark.asyncio -async def test_plugin_config_missing_fields(): - cfg: dict[str, Any] = yaml.load(plugin_invalid_config_one, Loader=Loader) - - with pytest.raises(PluginLoadError): - _ = Broker(config=cfg) - - -@pytest.mark.asyncio -async def test_alternate_plugin_load(): - - cfg: dict[str, Any] = yaml.load(plugin_config, Loader=Loader) - - broker = Broker(config=cfg) - await broker.start() - await broker.shutdown() - - -@pytest.mark.asyncio -async def test_auth_plugin_load(): - cfg: dict[str, Any] = yaml.load(plugin_config_auth, Loader=Loader) - broker = Broker(config=cfg) - await broker.start() - await asyncio.sleep(0.5) - - client1 = MQTTClient() - await client1.connect() - await client1.publish('my/topic', b'my message') - await client1.disconnect() - - await asyncio.sleep(0.5) - await broker.shutdown() - - -@pytest.mark.asyncio -async def test_topic_plugin_load(): - cfg: dict[str, Any] = yaml.load(plugin_config_topic, Loader=Loader) - broker = Broker(config=cfg) - await broker.start() - await broker.shutdown() diff --git a/tests/plugins/test_manager.py b/tests/plugins/test_manager.py index 33b0d73..14a5ff6 100644 --- a/tests/plugins/test_manager.py +++ b/tests/plugins/test_manager.py @@ -1,9 +1,12 @@ import asyncio import logging -from typing import Any import unittest -from amqtt.plugins.manager import BaseContext, Plugin, PluginManager +from amqtt.broker import Action +from amqtt.plugins.authentication import BaseAuthPlugin +from amqtt.plugins.manager import BaseContext, PluginManager +from amqtt.plugins.topic_checking import BaseTopicPlugin +from amqtt.session import Session formatter = "[%(asctime)s] %(name)s {%(filename)s:%(lineno)d} %(levelname)s - %(message)s" logging.basicConfig(level=logging.INFO, format=formatter) @@ -14,21 +17,29 @@ class EmptyTestPlugin: self.context = context -class EventTestPlugin: +class EventTestPlugin(BaseAuthPlugin, BaseTopicPlugin): def __init__(self, context: BaseContext) -> None: - self.context = context - self.test_flag = False - self.coro_flag = False + super().__init__(context) + self.test_close_flag = False + self.test_auth_flag = False + self.test_topic_flag = False + self.test_event_flag = False async def on_test(self) -> None: - self.test_flag = True - self.context.logger.info("on_test") + self.test_event_flag = True - async def test_coro(self) -> None: - self.coro_flag = True + async def authenticate(self, *, session: Session) -> bool | None: + self.test_auth_flag = True + return None - async def ret_coro(self) -> str: - return "TEST" + async def topic_filtering( + self, *, session: Session | None = None, topic: str | None = None, action: Action | None = None + ) -> bool: + self.test_topic_flag = True + return False + + async def close(self) -> None: + self.test_close_flag = True class TestPluginManager(unittest.TestCase): @@ -47,9 +58,9 @@ class TestPluginManager(unittest.TestCase): manager = PluginManager("amqtt.test.plugins", context=None) self.loop.run_until_complete(fire_event()) - plugin = manager.get_plugin("event_plugin") + plugin = manager.get_plugin("EventTestPlugin") assert plugin is not None - assert plugin.object.test_flag + assert plugin.test_event_flag def test_fire_event_wait(self) -> None: async def fire_event() -> None: @@ -58,36 +69,33 @@ class TestPluginManager(unittest.TestCase): manager = PluginManager("amqtt.test.plugins", context=None) self.loop.run_until_complete(fire_event()) - plugin = manager.get_plugin("event_plugin") + plugin = manager.get_plugin("EventTestPlugin") assert plugin is not None - assert plugin.object.test_flag + assert plugin.test_event_flag - def test_map_coro(self) -> None: - async def call_coro() -> None: - await manager.map_plugin_coro("test_coro") + def test_plugin_close_coro(self) -> None: manager = PluginManager("amqtt.test.plugins", context=None) - self.loop.run_until_complete(call_coro()) - plugin = manager.get_plugin("event_plugin") + self.loop.run_until_complete(manager.map_plugin_close()) + self.loop.run_until_complete(asyncio.sleep(0.5)) + plugin = manager.get_plugin("EventTestPlugin") assert plugin is not None - assert plugin.object.test_coro + assert plugin.test_close_flag - def test_map_coro_return(self) -> None: - async def call_coro() -> dict[Plugin, str]: - return await manager.map_plugin_coro("ret_coro") + def test_plugin_auth_coro(self) -> None: manager = PluginManager("amqtt.test.plugins", context=None) - ret = self.loop.run_until_complete(call_coro()) - plugin = manager.get_plugin("event_plugin") + self.loop.run_until_complete(manager.map_plugin_auth(session=Session())) + self.loop.run_until_complete(asyncio.sleep(0.5)) + plugin = manager.get_plugin("EventTestPlugin") assert plugin is not None - assert ret[plugin] == "TEST" + assert plugin.test_auth_flag - def test_map_coro_filter(self) -> None: - """Run plugin coro but expect no return as an empty filter is given.""" - - async def call_coro() -> dict[Plugin, Any]: - return await manager.map_plugin_coro("ret_coro", filter_plugins=[]) + def test_plugin_topic_coro(self) -> None: manager = PluginManager("amqtt.test.plugins", context=None) - ret = self.loop.run_until_complete(call_coro()) - assert len(ret) == 0 + self.loop.run_until_complete(manager.map_plugin_topic(session=Session(), topic="test", action=Action.PUBLISH)) + self.loop.run_until_complete(asyncio.sleep(0.5)) + plugin = manager.get_plugin("EventTestPlugin") + assert plugin is not None + assert plugin.test_topic_flag