instead of filtering upon call of the plugin's coro, filter upon plugin loading

pull/212/head
Andrew Mirsky 2025-06-12 19:12:33 -04:00
rodzic 7b936d785c
commit 43efa4c829
Nie znaleziono w bazie danych klucza dla tego podpisu
ID klucza GPG: A98E67635CDF2C39
11 zmienionych plików z 144 dodań i 366 usunięć

Wyświetl plik

@ -689,11 +689,6 @@ class Broker:
:param listener: :param listener:
:return: :return:
""" """
auth_plugins = None
auth_config = self.config.get("auth", None)
if isinstance(auth_config, dict):
auth_plugins = auth_config.get("plugins", None)
# returns = await self.plugins_manager.map_plugin_coro("authenticate", session=session, filter_plugins=auth_plugins)
returns = await self.plugins_manager.map_plugin_auth(session=session) returns = await self.plugins_manager.map_plugin_auth(session=session)
auth_result = True auth_result = True
if returns: if returns:
@ -765,22 +760,14 @@ class Broker:
""" """
topic_config = self.config.get("topic-check", {}) topic_config = self.config.get("topic-check", {})
enabled = False enabled = False
topic_plugins: list[str] | None = None
if isinstance(topic_config, dict): if isinstance(topic_config, dict):
enabled = topic_config.get("enabled", False) enabled = topic_config.get("enabled", False)
topic_plugins = topic_config.get("plugins")
if not enabled: if not enabled:
return True return True
results = await self.plugins_manager.map_plugin_topic(session=session, topic=topic, action=action) results = await self.plugins_manager.map_plugin_topic(session=session, topic=topic, action=action)
# results = await self.plugins_manager.map_plugin_coro(
# "topic_filtering",
# session=session,
# topic=topic,
# action=action,
# filter_plugins=topic_plugins,
# )
return all(result for result in results.values()) return all(result for result in results.values())
async def _delete_session(self, client_id: str) -> None: async def _delete_session(self, client_id: str) -> None:

Wyświetl plik

@ -34,4 +34,3 @@ class ProtocolHandlerError(Exception):
class PluginLoadError(Exception): class PluginLoadError(Exception):
"""Exception thrown when loading a plugin.""" """Exception thrown when loading a plugin."""

Wyświetl plik

@ -11,7 +11,7 @@ from amqtt.session import Session
_PARTS_EXPECTED_LENGTH = 2 # Expected number of parts in a valid line _PARTS_EXPECTED_LENGTH = 2 # Expected number of parts in a valid line
class BaseAuthPlugin(BasePlugin): class BaseAuthPlugin(BasePlugin[BaseContext]):
"""Base class for authentication plugins.""" """Base class for authentication plugins."""
def __init__(self, context: BaseContext) -> None: def __init__(self, context: BaseContext) -> None:

Wyświetl plik

@ -1,18 +1,19 @@
from dataclasses import dataclass from typing import Any, Generic, TypeVar
from typing import Any
from amqtt.plugins.manager import BaseContext from amqtt.plugins.manager import BaseContext
C = TypeVar("C", bound=BaseContext)
class BasePlugin:
class BasePlugin(Generic[C]):
"""The base from which all plugins should inherit.""" """The base from which all plugins should inherit."""
def __init__(self, context: BaseContext) -> None: def __init__(self, context: C) -> None:
self.context = context self.context: C = context
def _get_config_section(self, name: str) -> dict[str, Any] | None: def _get_config_section(self, name: str) -> dict[str, Any] | None:
if not self.context.config or not hasattr(self.context.config, 'get') or self.context.config.get(name, None): if not self.context.config or not hasattr(self.context.config, "get") or not self.context.config.get(name, None):
return None return None
section_config: int | dict[str, Any] | None = self.context.config.get(name, None) section_config: int | dict[str, Any] | None = self.context.config.get(name, None)
@ -21,6 +22,5 @@ class BasePlugin:
return None return None
return section_config return section_config
@dataclass async def close(self) -> None:
class Config:
pass pass

Wyświetl plik

@ -1,25 +1,22 @@
__all__ = ["BaseContext", "PluginManager", "get_plugin_manager"] __all__ = ["BaseContext", "PluginManager", "get_plugin_manager"]
import asyncio import asyncio
from collections.abc import Awaitable, Callable from collections.abc import Awaitable
import contextlib import contextlib
import copy import copy
from importlib.metadata import EntryPoint, EntryPoints, entry_points from importlib.metadata import EntryPoint, EntryPoints, entry_points
import logging import logging
from typing import Any, NamedTuple, TYPE_CHECKING from typing import TYPE_CHECKING, Any, NamedTuple, Optional
from amqtt.errors import MQTTError, PluginLoadError
from amqtt.session import Session from amqtt.session import Session
from amqtt.utils import import_string
from dacite import from_dict, Config, DaciteError
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
if TYPE_CHECKING: if TYPE_CHECKING:
from amqtt.plugins.base import BasePlugin
from amqtt.plugins.authentication import BaseAuthPlugin
from amqtt.plugins.topic_checking import BaseTopicPlugin
from amqtt.broker import Action from amqtt.broker import Action
from amqtt.plugins.authentication import BaseAuthPlugin
from amqtt.plugins.base import BasePlugin
from amqtt.plugins.topic_checking import BaseTopicPlugin
class Plugin(NamedTuple): class Plugin(NamedTuple):
name: str name: str
@ -75,52 +72,31 @@ class PluginManager:
return self.context return self.context
def _load_plugins(self, namespace: str) -> None: def _load_plugins(self, namespace: str) -> None:
from amqtt.plugins.authentication import BaseAuthPlugin
from amqtt.plugins.topic_checking import BaseTopicPlugin
if 'plugins' in self.app_context.config: self.logger.debug(f"Loading plugins for namespace {namespace}")
self.logger.info("Loading plugins from config file")
for plugin_info in self.app_context.config['plugins']:
if isinstance(plugin_info, dict): auth_filter_list = []
assert len(plugin_info.keys()) == 1 topic_filter_list = []
plugin_path = list(plugin_info.keys())[0] if self.app_context.config and "auth" in self.app_context.config:
plugin_cfg = plugin_info[plugin_path] auth_filter_list = self.app_context.config["auth"].get("plugins", [])
plugin = self._load_str_plugin(plugin_path, plugin_cfg) if self.app_context.config and "topic" in self.app_context.config:
elif isinstance(plugin_info, str): topic_filter_list = self.app_context.config["topic"].get("plugins", [])
plugin = self._load_str_plugin(plugin_info, {})
else:
msg = 'Unexpected entry in plugins config'
raise PluginLoadError(msg)
self._plugins.append(plugin) ep: EntryPoints | list[EntryPoint] = []
if isinstance(plugin, BaseAuthPlugin): if hasattr(entry_points(), "select"):
self._auth_plugins.append(plugin) ep = entry_points().select(group=namespace)
if isinstance(plugin, BaseTopicPlugin): elif namespace in entry_points():
self._topic_plugins.append(plugin) ep = [entry_points()[namespace]]
for item in ep:
plugin = self._load_ep_plugin(item)
else: if plugin is not None:
self.logger.debug(f"Loading plugins for namespace {namespace}") self._plugins.append(plugin.object)
if (not auth_filter_list or plugin.name in auth_filter_list) and hasattr(plugin.object, "authenticate"):
auth_filter_list = self.app_context.config['auth'].get('plugins', []) if 'auth' in self.app_context.config else [] self._auth_plugins.append(plugin.object)
topic_filter_list = self.app_context.config['topic'].get('plugins', []) if 'topic' in self.app_context.config else [] if (not topic_filter_list or plugin.name in topic_filter_list) and hasattr(plugin.object, "topic_filtering"):
ep: EntryPoints | list[EntryPoint] = [] self._topic_plugins.append(plugin.object)
if hasattr(entry_points(), "select"): self.logger.debug(f" Plugin {item.name} ready")
ep = entry_points().select(group=namespace)
elif namespace in entry_points():
ep = [entry_points()[namespace]]
for item in ep:
plugin = self._load_ep_plugin(item)
if plugin is not None:
self._plugins.append(plugin.object)
if plugin.name in auth_filter_list:
self._auth_plugins.append(plugin.object)
elif plugin.name in topic_filter_list:
self._topic_plugins.append(plugin.object)
self.logger.debug(f" Plugin {item.name} ready")
def _load_ep_plugin(self, ep: EntryPoint) -> Plugin | None: def _load_ep_plugin(self, ep: EntryPoint) -> Plugin | None:
try: try:
@ -138,53 +114,27 @@ class PluginManager:
return None return None
def _load_str_plugin(self, plugin_path: str, plugin_cfg: dict[str, Any] | None = None) -> 'BasePlugin': def get_plugin(self, name: str) -> Optional["BasePlugin"]:
from amqtt.plugins.base import BasePlugin """Get a plugin by its name from the plugins loaded for the current namespace.
from amqtt.plugins.authentication import BaseAuthPlugin
from amqtt.plugins.topic_checking import BaseTopicPlugin
try: :param name:
plugin_class = import_string(plugin_path) :return:
except ModuleNotFoundError as ep: """
self.logger.error(f"Plugin import failed: {plugin_path}") for p in self._plugins:
raise MQTTError() from ep self.logger.debug(f"plugin name >>>> {p.__class__.__name__}")
if p.__class__.__name__ == name:
if not issubclass(plugin_class, BasePlugin): return p
msg = f"Plugin {plugin_path} is not a subclass of 'BasePlugin'" return None
raise PluginLoadError(msg)
plugin_context = copy.copy(self.app_context)
plugin_context.logger = self.logger.getChild(plugin_class.__name__)
try:
plugin_context.config = from_dict(data_class=plugin_class.Config, data=plugin_cfg or {}, config=Config(strict=True))
except DaciteError as e:
raise PluginLoadError from e
try:
return plugin_class(plugin_context)
except ImportError as e:
raise PluginLoadError from e
# def get_plugin(self, name: str) -> Plugin | None:
# """Get a plugin by its name from the plugins loaded for the current namespace.
#
# :param name:
# :return:
# """
# for p in self._plugins:
# if p.name == name:
# return p
# return None
async def close(self) -> None: async def close(self) -> None:
"""Free PluginManager resources and cancel pending event methods.""" """Free PluginManager resources and cancel pending event methods."""
await self.map_plugin_coro("close") await self.map_plugin_close()
for task in self._fired_events: for task in self._fired_events:
task.cancel() task.cancel()
self._fired_events.clear() self._fired_events.clear()
@property @property
def plugins(self) -> list['BasePlugin']: def plugins(self) -> list["BasePlugin"]:
"""Get the loaded plugins list. """Get the loaded plugins list.
:return: :return:
@ -229,100 +179,83 @@ class PluginManager:
await asyncio.wait(tasks) await asyncio.wait(tasks)
self.logger.debug(f"Plugins len(_fired_events)={len(self._fired_events)}") self.logger.debug(f"Plugins len(_fired_events)={len(self._fired_events)}")
async def map( async def map_plugin_auth(self, session: Session) -> dict["BasePlugin", str | bool | None]:
self, """Schedule a coroutine for plugin 'authenticate' calls.
coro: Callable[[Plugin, Any], Awaitable[str | bool | None]],
*args: Any,
**kwargs: Any,
) -> dict[Plugin, str | bool | None]:
"""Schedule a given coroutine call for each plugin.
The coro called gets the Plugin instance as the first argument of its method call. :param session: the client session associated with the authentication check
:param coro: coro to call on each plugin
:param filter_plugins: list of plugin names to filter (only plugin whose name is
in the filter are called). None will call all plugins. [] will call None.
:param args: arguments to pass to coro
:param kwargs: arguments to pass to coro
:return: dict containing return from coro call for each plugin. :return: dict containing return from coro call for each plugin.
""" """
p_list = kwargs.pop("filter_plugins", None)
if p_list is None:
p_list = [p.name for p in self.plugins]
tasks: list[asyncio.Future[Any]] = []
plugins_list: list[Plugin] = []
for plugin in self._plugins:
if plugin.name in p_list:
coro_instance = coro(plugin, *args, **kwargs)
if coro_instance:
try:
tasks.append(self._schedule_coro(coro_instance))
plugins_list.append(plugin)
except AssertionError:
self.logger.exception(f"Method '{coro!r}' on plugin '{plugin.name}' is not a coroutine")
if tasks:
ret_list = await asyncio.gather(*tasks)
# Create result map plugin => ret
ret_dict = dict(zip(plugins_list, ret_list, strict=False))
else:
ret_dict = {}
return ret_dict
@staticmethod
async def _call_coro(plugin: Plugin, coro_name: str, *args: Any, **kwargs: Any) -> str | bool | None:
if not hasattr(plugin.object, coro_name):
_LOGGER.warning(f"Plugin doesn't implement coro_name '{coro_name}': {plugin.name}")
return None
coro: Awaitable[str | bool | None] = getattr(plugin.object, coro_name)(*args, **kwargs)
return await coro
async def map_plugin_coro(self, coro_name: str, *args: Any, **kwargs: Any) -> dict[Plugin, str | bool | None]:
"""Call a plugin declared by plugin by its name.
:param coro_name:
:param args:
:param kwargs:
:return:
"""
return await self.map(self._call_coro, coro_name, *args, **kwargs)
async def map_plugin_auth(self, session: Session) -> dict['BaseAuthPlugin', str | bool | None]:
tasks: list[asyncio.Future[Any]] = [] tasks: list[asyncio.Future[Any]] = []
for plugin in self._auth_plugins: for plugin in self._auth_plugins:
async def auth_coro(p: 'BaseAuthPlugin', s: Session) -> str | bool | None: async def auth_coro(p: "BaseAuthPlugin", s: Session) -> str | bool | None:
return await p.authenticate(session=s) return await p.authenticate(session=s)
if not hasattr(plugin, "authenticate"):
continue
coro_instance: Awaitable[str | bool | None] = auth_coro(plugin, session) coro_instance: Awaitable[str | bool | None] = auth_coro(plugin, session)
tasks.append(asyncio.ensure_future(coro_instance)) tasks.append(asyncio.ensure_future(coro_instance))
ret_dict: dict[BasePlugin, str | bool | None] = {}
if tasks: if tasks:
ret_list = await asyncio.gather(*tasks) ret_list = await asyncio.gather(*tasks)
# Create result map plugin => ret # Create result map plugin => ret
ret_dict = dict(zip(self._auth_plugins, ret_list, strict=False)) ret_dict = dict(zip(self._auth_plugins, ret_list, strict=False))
else:
ret_dict = {}
return ret_dict return ret_dict
async def map_plugin_topic(self, session: Session, topic: str, action: 'Action') -> dict['BaseTopicPlugin', str | bool | None]: async def map_plugin_topic(self,
session: Session, topic: str, action: "Action"
) -> dict["BasePlugin", str | bool | None]:
"""Schedule a coroutine for plugin 'topic_filtering' calls.
:param session: the client session associated with the topic_filtering check
:param topic: the topic that needs to be filtered
:param action: the action being executed
:return: dict containing return from coro call for each plugin.
"""
tasks: list[asyncio.Future[Any]] = [] tasks: list[asyncio.Future[Any]] = []
for plugin in self._topic_plugins: for plugin in self._topic_plugins:
async def topic_coro(p: 'BaseTopicPlugin', s: Session, t: str, a: 'Action') -> str | bool | None: async def topic_coro(p: "BaseTopicPlugin", s: Session, t: str, a: "Action") -> str | bool | None:
return await p.topic_filtering(session=s, topic=t, action=a) return await p.topic_filtering(session=s, topic=t, action=a)
if not hasattr(plugin, "topic_filtering"):
continue
coro_instance: Awaitable[str | bool | None] = topic_coro(plugin, session, topic, action) coro_instance: Awaitable[str | bool | None] = topic_coro(plugin, session, topic, action)
tasks.append(asyncio.ensure_future(coro_instance)) tasks.append(asyncio.ensure_future(coro_instance))
ret_dict: dict[BasePlugin, str | bool | None] = {}
if tasks: if tasks:
ret_list = await asyncio.gather(*tasks) ret_list = await asyncio.gather(*tasks)
# Create result map plugin => ret # Create result map plugin => ret
ret_dict = {dict(zip(self._auth_plugins, ret_list, strict=False))} ret_dict= dict(zip(self._auth_plugins, ret_list, strict=False))
else:
ret_dict = {} return ret_dict
async def map_plugin_close(self) -> dict["BasePlugin", str | bool | None]:
tasks: list[asyncio.Future[Any]] = []
for plugin in self._plugins:
async def close_coro(p: "BasePlugin") -> None:
await p.close()
if not hasattr(plugin, "close"):
continue
coro_instance: Awaitable[str | bool | None] = close_coro(plugin)
tasks.append(asyncio.ensure_future(coro_instance))
ret_dict: dict[BasePlugin, str | bool | None] = {}
if tasks:
ret_list = await asyncio.gather(*tasks)
# Create result map plugin => ret
ret_dict = dict(zip(self._auth_plugins, ret_list, strict=False))
return ret_dict return ret_dict

Wyświetl plik

@ -42,7 +42,7 @@ STAT_CLIENTS_CONNECTED = "clients_connected"
STAT_CLIENTS_DISCONNECTED = "clients_disconnected" STAT_CLIENTS_DISCONNECTED = "clients_disconnected"
class BrokerSysPlugin(BasePlugin): class BrokerSysPlugin(BasePlugin[BrokerContext]):
def __init__(self, context: BrokerContext) -> None: def __init__(self, context: BrokerContext) -> None:
super().__init__(context) super().__init__(context)
# Broker statistics initialization # Broker statistics initialization

Wyświetl plik

@ -6,7 +6,7 @@ from amqtt.plugins.manager import BaseContext
from amqtt.session import Session from amqtt.session import Session
class BaseTopicPlugin(BasePlugin): class BaseTopicPlugin(BasePlugin[BaseContext]):
"""Base class for topic plugins.""" """Base class for topic plugins."""
def __init__(self, context: BaseContext) -> None: def __init__(self, context: BaseContext) -> None:

Wyświetl plik

@ -1,13 +1,10 @@
from __future__ import annotations from __future__ import annotations
import logging import logging
import sys
from importlib import import_module
from pathlib import Path from pathlib import Path
import secrets import secrets
import string import string
import typing import typing
from types import ModuleType
from typing import Any from typing import Any
import yaml import yaml
@ -51,32 +48,3 @@ def read_yaml_config(config_file: str | Path) -> dict[str, Any] | None:
except yaml.YAMLError: except yaml.YAMLError:
logger.exception(f"Invalid config_file {config_file}") logger.exception(f"Invalid config_file {config_file}")
return None return None
def cached_import(module_path: str, class_name: str=None) -> ModuleType:
# Check whether module is loaded and fully initialized.
if not ((module := sys.modules.get(module_path))
and (spec := getattr(module, "__spec__", None)) # noqa
and getattr(spec, "_initializing", False) is False): # noqa
module = import_module(module_path)
if class_name:
return getattr(module, class_name)
return module
# TODO : figure out proper return type
def import_string(dotted_path) -> Any:
"""
Import a dotted module path and return the attribute/class designated by the
last name in the path. Raise ImportError if the import failed.
"""
try:
module_path, class_name = dotted_path.rsplit(".", 1)
except ValueError as err:
raise ImportError(f"{dotted_path} doesn't look like a module path") from err
try:
return cached_import(module_path, class_name)
except AttributeError as err:
raise ImportError(
f'Module "{module_path}" does not define a "{class_name}" attribute/class'
) from err

Wyświetl plik

@ -185,13 +185,13 @@ max-returns = 10
# ----------------------------------- PYTEST ----------------------------------- # ----------------------------------- PYTEST -----------------------------------
[tool.pytest.ini_options] [tool.pytest.ini_options]
#addopts = ["--cov=amqtt", "--cov-report=term-missing", "--cov-report=html"] addopts = ["--cov=amqtt", "--cov-report=term-missing", "--cov-report=html"]
testpaths = ["tests"] testpaths = ["tests"]
asyncio_mode = "auto" asyncio_mode = "auto"
timeout = 10 timeout = 10
asyncio_default_fixture_loop_scope = "function" asyncio_default_fixture_loop_scope = "function"
addopts = ["--tb=short", "--capture=tee-sys"] #addopts = ["--tb=short", "--capture=tee-sys"]
log_cli = true #log_cli = true
log_level = "DEBUG" log_level = "DEBUG"
# ------------------------------------ MYPY ------------------------------------ # ------------------------------------ MYPY ------------------------------------

Wyświetl plik

@ -1,117 +0,0 @@
import asyncio
import logging
from dataclasses import dataclass, field
from typing import Any
import pytest
import yaml
from amqtt.broker import Broker
from yaml import CLoader as Loader
from dacite import from_dict, Config, UnexpectedDataError
from amqtt.client import MQTTClient
from amqtt.errors import PluginLoadError
logger = logging.getLogger(__name__)
plugin_config = """---
listeners:
default:
type: tcp
bind: 0.0.0.0:1883
plugins:
- tests.plugins.mocks.TestSimplePlugin:
- tests.plugins.mocks.TestConfigPlugin:
option1: 1
option2: bar
"""
plugin_invalid_config_one = """---
listeners:
default:
type: tcp
bind: 0.0.0.0:1883
plugins:
- tests.plugins.mocks.TestSimplePlugin:
option1: 1
option2: bar
"""
plugin_invalid_config_two = """---
listeners:
default:
type: tcp
bind: 0.0.0.0:1883
plugins:
- tests.plugins.mocks.TestConfigPlugin:
"""
plugin_config_auth = """---
listeners:
default:
type: tcp
bind: 0.0.0.0:1883
plugins:
- tests.plugins.mocks.TestAuthPlugin:
"""
plugin_config_topic = """---
listeners:
default:
type: tcp
bind: 0.0.0.0:1883
plugins:
- tests.plugins.mocks.TestTopicPlugin:
"""
@pytest.mark.asyncio
async def test_plugin_config_extra_fields():
cfg: dict[str, Any] = yaml.load(plugin_invalid_config_one, Loader=Loader)
with pytest.raises(PluginLoadError):
_ = Broker(config=cfg)
@pytest.mark.asyncio
async def test_plugin_config_missing_fields():
cfg: dict[str, Any] = yaml.load(plugin_invalid_config_one, Loader=Loader)
with pytest.raises(PluginLoadError):
_ = Broker(config=cfg)
@pytest.mark.asyncio
async def test_alternate_plugin_load():
cfg: dict[str, Any] = yaml.load(plugin_config, Loader=Loader)
broker = Broker(config=cfg)
await broker.start()
await broker.shutdown()
@pytest.mark.asyncio
async def test_auth_plugin_load():
cfg: dict[str, Any] = yaml.load(plugin_config_auth, Loader=Loader)
broker = Broker(config=cfg)
await broker.start()
await asyncio.sleep(0.5)
client1 = MQTTClient()
await client1.connect()
await client1.publish('my/topic', b'my message')
await client1.disconnect()
await asyncio.sleep(0.5)
await broker.shutdown()
@pytest.mark.asyncio
async def test_topic_plugin_load():
cfg: dict[str, Any] = yaml.load(plugin_config_topic, Loader=Loader)
broker = Broker(config=cfg)
await broker.start()
await broker.shutdown()

Wyświetl plik

@ -1,9 +1,12 @@
import asyncio import asyncio
import logging import logging
from typing import Any
import unittest import unittest
from amqtt.plugins.manager import BaseContext, Plugin, PluginManager from amqtt.broker import Action
from amqtt.plugins.authentication import BaseAuthPlugin
from amqtt.plugins.manager import BaseContext, PluginManager
from amqtt.plugins.topic_checking import BaseTopicPlugin
from amqtt.session import Session
formatter = "[%(asctime)s] %(name)s {%(filename)s:%(lineno)d} %(levelname)s - %(message)s" formatter = "[%(asctime)s] %(name)s {%(filename)s:%(lineno)d} %(levelname)s - %(message)s"
logging.basicConfig(level=logging.INFO, format=formatter) logging.basicConfig(level=logging.INFO, format=formatter)
@ -14,21 +17,29 @@ class EmptyTestPlugin:
self.context = context self.context = context
class EventTestPlugin: class EventTestPlugin(BaseAuthPlugin, BaseTopicPlugin):
def __init__(self, context: BaseContext) -> None: def __init__(self, context: BaseContext) -> None:
self.context = context super().__init__(context)
self.test_flag = False self.test_close_flag = False
self.coro_flag = False self.test_auth_flag = False
self.test_topic_flag = False
self.test_event_flag = False
async def on_test(self) -> None: async def on_test(self) -> None:
self.test_flag = True self.test_event_flag = True
self.context.logger.info("on_test")
async def test_coro(self) -> None: async def authenticate(self, *, session: Session) -> bool | None:
self.coro_flag = True self.test_auth_flag = True
return None
async def ret_coro(self) -> str: async def topic_filtering(
return "TEST" self, *, session: Session | None = None, topic: str | None = None, action: Action | None = None
) -> bool:
self.test_topic_flag = True
return False
async def close(self) -> None:
self.test_close_flag = True
class TestPluginManager(unittest.TestCase): class TestPluginManager(unittest.TestCase):
@ -47,9 +58,9 @@ class TestPluginManager(unittest.TestCase):
manager = PluginManager("amqtt.test.plugins", context=None) manager = PluginManager("amqtt.test.plugins", context=None)
self.loop.run_until_complete(fire_event()) self.loop.run_until_complete(fire_event())
plugin = manager.get_plugin("event_plugin") plugin = manager.get_plugin("EventTestPlugin")
assert plugin is not None assert plugin is not None
assert plugin.object.test_flag assert plugin.test_event_flag
def test_fire_event_wait(self) -> None: def test_fire_event_wait(self) -> None:
async def fire_event() -> None: async def fire_event() -> None:
@ -58,36 +69,33 @@ class TestPluginManager(unittest.TestCase):
manager = PluginManager("amqtt.test.plugins", context=None) manager = PluginManager("amqtt.test.plugins", context=None)
self.loop.run_until_complete(fire_event()) self.loop.run_until_complete(fire_event())
plugin = manager.get_plugin("event_plugin") plugin = manager.get_plugin("EventTestPlugin")
assert plugin is not None assert plugin is not None
assert plugin.object.test_flag assert plugin.test_event_flag
def test_map_coro(self) -> None: def test_plugin_close_coro(self) -> None:
async def call_coro() -> None:
await manager.map_plugin_coro("test_coro")
manager = PluginManager("amqtt.test.plugins", context=None) manager = PluginManager("amqtt.test.plugins", context=None)
self.loop.run_until_complete(call_coro()) self.loop.run_until_complete(manager.map_plugin_close())
plugin = manager.get_plugin("event_plugin") self.loop.run_until_complete(asyncio.sleep(0.5))
plugin = manager.get_plugin("EventTestPlugin")
assert plugin is not None assert plugin is not None
assert plugin.object.test_coro assert plugin.test_close_flag
def test_map_coro_return(self) -> None: def test_plugin_auth_coro(self) -> None:
async def call_coro() -> dict[Plugin, str]:
return await manager.map_plugin_coro("ret_coro")
manager = PluginManager("amqtt.test.plugins", context=None) manager = PluginManager("amqtt.test.plugins", context=None)
ret = self.loop.run_until_complete(call_coro()) self.loop.run_until_complete(manager.map_plugin_auth(session=Session()))
plugin = manager.get_plugin("event_plugin") self.loop.run_until_complete(asyncio.sleep(0.5))
plugin = manager.get_plugin("EventTestPlugin")
assert plugin is not None assert plugin is not None
assert ret[plugin] == "TEST" assert plugin.test_auth_flag
def test_map_coro_filter(self) -> None: def test_plugin_topic_coro(self) -> None:
"""Run plugin coro but expect no return as an empty filter is given."""
async def call_coro() -> dict[Plugin, Any]:
return await manager.map_plugin_coro("ret_coro", filter_plugins=[])
manager = PluginManager("amqtt.test.plugins", context=None) manager = PluginManager("amqtt.test.plugins", context=None)
ret = self.loop.run_until_complete(call_coro()) self.loop.run_until_complete(manager.map_plugin_topic(session=Session(), topic="test", action=Action.PUBLISH))
assert len(ret) == 0 self.loop.run_until_complete(asyncio.sleep(0.5))
plugin = manager.get_plugin("EventTestPlugin")
assert plugin is not None
assert plugin.test_topic_flag