instead of filtering upon call of the plugin's coro, filter upon plugin loading

pull/212/head
Andrew Mirsky 2025-06-12 19:12:33 -04:00
rodzic 7b936d785c
commit 43efa4c829
Nie znaleziono w bazie danych klucza dla tego podpisu
ID klucza GPG: A98E67635CDF2C39
11 zmienionych plików z 144 dodań i 366 usunięć

Wyświetl plik

@ -689,11 +689,6 @@ class Broker:
:param listener:
:return:
"""
auth_plugins = None
auth_config = self.config.get("auth", None)
if isinstance(auth_config, dict):
auth_plugins = auth_config.get("plugins", None)
# returns = await self.plugins_manager.map_plugin_coro("authenticate", session=session, filter_plugins=auth_plugins)
returns = await self.plugins_manager.map_plugin_auth(session=session)
auth_result = True
if returns:
@ -765,22 +760,14 @@ class Broker:
"""
topic_config = self.config.get("topic-check", {})
enabled = False
topic_plugins: list[str] | None = None
if isinstance(topic_config, dict):
enabled = topic_config.get("enabled", False)
topic_plugins = topic_config.get("plugins")
if not enabled:
return True
results = await self.plugins_manager.map_plugin_topic(session=session, topic=topic, action=action)
# results = await self.plugins_manager.map_plugin_coro(
# "topic_filtering",
# session=session,
# topic=topic,
# action=action,
# filter_plugins=topic_plugins,
# )
return all(result for result in results.values())
async def _delete_session(self, client_id: str) -> None:

Wyświetl plik

@ -34,4 +34,3 @@ class ProtocolHandlerError(Exception):
class PluginLoadError(Exception):
"""Exception thrown when loading a plugin."""

Wyświetl plik

@ -11,7 +11,7 @@ from amqtt.session import Session
_PARTS_EXPECTED_LENGTH = 2 # Expected number of parts in a valid line
class BaseAuthPlugin(BasePlugin):
class BaseAuthPlugin(BasePlugin[BaseContext]):
"""Base class for authentication plugins."""
def __init__(self, context: BaseContext) -> None:

Wyświetl plik

@ -1,18 +1,19 @@
from dataclasses import dataclass
from typing import Any
from typing import Any, Generic, TypeVar
from amqtt.plugins.manager import BaseContext
C = TypeVar("C", bound=BaseContext)
class BasePlugin:
class BasePlugin(Generic[C]):
"""The base from which all plugins should inherit."""
def __init__(self, context: BaseContext) -> None:
self.context = context
def __init__(self, context: C) -> None:
self.context: C = context
def _get_config_section(self, name: str) -> dict[str, Any] | None:
if not self.context.config or not hasattr(self.context.config, 'get') or self.context.config.get(name, None):
if not self.context.config or not hasattr(self.context.config, "get") or not self.context.config.get(name, None):
return None
section_config: int | dict[str, Any] | None = self.context.config.get(name, None)
@ -21,6 +22,5 @@ class BasePlugin:
return None
return section_config
@dataclass
class Config:
async def close(self) -> None:
pass

Wyświetl plik

@ -1,25 +1,22 @@
__all__ = ["BaseContext", "PluginManager", "get_plugin_manager"]
import asyncio
from collections.abc import Awaitable, Callable
from collections.abc import Awaitable
import contextlib
import copy
from importlib.metadata import EntryPoint, EntryPoints, entry_points
import logging
from typing import Any, NamedTuple, TYPE_CHECKING
from typing import TYPE_CHECKING, Any, NamedTuple, Optional
from amqtt.errors import MQTTError, PluginLoadError
from amqtt.session import Session
from amqtt.utils import import_string
from dacite import from_dict, Config, DaciteError
_LOGGER = logging.getLogger(__name__)
if TYPE_CHECKING:
from amqtt.plugins.base import BasePlugin
from amqtt.plugins.authentication import BaseAuthPlugin
from amqtt.plugins.topic_checking import BaseTopicPlugin
from amqtt.broker import Action
from amqtt.plugins.authentication import BaseAuthPlugin
from amqtt.plugins.base import BasePlugin
from amqtt.plugins.topic_checking import BaseTopicPlugin
class Plugin(NamedTuple):
name: str
@ -75,52 +72,31 @@ class PluginManager:
return self.context
def _load_plugins(self, namespace: str) -> None:
from amqtt.plugins.authentication import BaseAuthPlugin
from amqtt.plugins.topic_checking import BaseTopicPlugin
if 'plugins' in self.app_context.config:
self.logger.info("Loading plugins from config file")
for plugin_info in self.app_context.config['plugins']:
self.logger.debug(f"Loading plugins for namespace {namespace}")
if isinstance(plugin_info, dict):
assert len(plugin_info.keys()) == 1
plugin_path = list(plugin_info.keys())[0]
plugin_cfg = plugin_info[plugin_path]
plugin = self._load_str_plugin(plugin_path, plugin_cfg)
elif isinstance(plugin_info, str):
plugin = self._load_str_plugin(plugin_info, {})
else:
msg = 'Unexpected entry in plugins config'
raise PluginLoadError(msg)
auth_filter_list = []
topic_filter_list = []
if self.app_context.config and "auth" in self.app_context.config:
auth_filter_list = self.app_context.config["auth"].get("plugins", [])
if self.app_context.config and "topic" in self.app_context.config:
topic_filter_list = self.app_context.config["topic"].get("plugins", [])
self._plugins.append(plugin)
if isinstance(plugin, BaseAuthPlugin):
self._auth_plugins.append(plugin)
if isinstance(plugin, BaseTopicPlugin):
self._topic_plugins.append(plugin)
ep: EntryPoints | list[EntryPoint] = []
if hasattr(entry_points(), "select"):
ep = entry_points().select(group=namespace)
elif namespace in entry_points():
ep = [entry_points()[namespace]]
else:
self.logger.debug(f"Loading plugins for namespace {namespace}")
auth_filter_list = self.app_context.config['auth'].get('plugins', []) if 'auth' in self.app_context.config else []
topic_filter_list = self.app_context.config['topic'].get('plugins', []) if 'topic' in self.app_context.config else []
ep: EntryPoints | list[EntryPoint] = []
if hasattr(entry_points(), "select"):
ep = entry_points().select(group=namespace)
elif namespace in entry_points():
ep = [entry_points()[namespace]]
for item in ep:
plugin = self._load_ep_plugin(item)
if plugin is not None:
self._plugins.append(plugin.object)
if plugin.name in auth_filter_list:
self._auth_plugins.append(plugin.object)
elif plugin.name in topic_filter_list:
self._topic_plugins.append(plugin.object)
self.logger.debug(f" Plugin {item.name} ready")
for item in ep:
plugin = self._load_ep_plugin(item)
if plugin is not None:
self._plugins.append(plugin.object)
if (not auth_filter_list or plugin.name in auth_filter_list) and hasattr(plugin.object, "authenticate"):
self._auth_plugins.append(plugin.object)
if (not topic_filter_list or plugin.name in topic_filter_list) and hasattr(plugin.object, "topic_filtering"):
self._topic_plugins.append(plugin.object)
self.logger.debug(f" Plugin {item.name} ready")
def _load_ep_plugin(self, ep: EntryPoint) -> Plugin | None:
try:
@ -138,53 +114,27 @@ class PluginManager:
return None
def _load_str_plugin(self, plugin_path: str, plugin_cfg: dict[str, Any] | None = None) -> 'BasePlugin':
from amqtt.plugins.base import BasePlugin
from amqtt.plugins.authentication import BaseAuthPlugin
from amqtt.plugins.topic_checking import BaseTopicPlugin
def get_plugin(self, name: str) -> Optional["BasePlugin"]:
"""Get a plugin by its name from the plugins loaded for the current namespace.
try:
plugin_class = import_string(plugin_path)
except ModuleNotFoundError as ep:
self.logger.error(f"Plugin import failed: {plugin_path}")
raise MQTTError() from ep
if not issubclass(plugin_class, BasePlugin):
msg = f"Plugin {plugin_path} is not a subclass of 'BasePlugin'"
raise PluginLoadError(msg)
plugin_context = copy.copy(self.app_context)
plugin_context.logger = self.logger.getChild(plugin_class.__name__)
try:
plugin_context.config = from_dict(data_class=plugin_class.Config, data=plugin_cfg or {}, config=Config(strict=True))
except DaciteError as e:
raise PluginLoadError from e
try:
return plugin_class(plugin_context)
except ImportError as e:
raise PluginLoadError from e
# def get_plugin(self, name: str) -> Plugin | None:
# """Get a plugin by its name from the plugins loaded for the current namespace.
#
# :param name:
# :return:
# """
# for p in self._plugins:
# if p.name == name:
# return p
# return None
:param name:
:return:
"""
for p in self._plugins:
self.logger.debug(f"plugin name >>>> {p.__class__.__name__}")
if p.__class__.__name__ == name:
return p
return None
async def close(self) -> None:
"""Free PluginManager resources and cancel pending event methods."""
await self.map_plugin_coro("close")
await self.map_plugin_close()
for task in self._fired_events:
task.cancel()
self._fired_events.clear()
@property
def plugins(self) -> list['BasePlugin']:
def plugins(self) -> list["BasePlugin"]:
"""Get the loaded plugins list.
:return:
@ -229,100 +179,83 @@ class PluginManager:
await asyncio.wait(tasks)
self.logger.debug(f"Plugins len(_fired_events)={len(self._fired_events)}")
async def map(
self,
coro: Callable[[Plugin, Any], Awaitable[str | bool | None]],
*args: Any,
**kwargs: Any,
) -> dict[Plugin, str | bool | None]:
"""Schedule a given coroutine call for each plugin.
async def map_plugin_auth(self, session: Session) -> dict["BasePlugin", str | bool | None]:
"""Schedule a coroutine for plugin 'authenticate' calls.
The coro called gets the Plugin instance as the first argument of its method call.
:param coro: coro to call on each plugin
:param filter_plugins: list of plugin names to filter (only plugin whose name is
in the filter are called). None will call all plugins. [] will call None.
:param args: arguments to pass to coro
:param kwargs: arguments to pass to coro
:param session: the client session associated with the authentication check
:return: dict containing return from coro call for each plugin.
"""
p_list = kwargs.pop("filter_plugins", None)
if p_list is None:
p_list = [p.name for p in self.plugins]
tasks: list[asyncio.Future[Any]] = []
plugins_list: list[Plugin] = []
for plugin in self._plugins:
if plugin.name in p_list:
coro_instance = coro(plugin, *args, **kwargs)
if coro_instance:
try:
tasks.append(self._schedule_coro(coro_instance))
plugins_list.append(plugin)
except AssertionError:
self.logger.exception(f"Method '{coro!r}' on plugin '{plugin.name}' is not a coroutine")
if tasks:
ret_list = await asyncio.gather(*tasks)
# Create result map plugin => ret
ret_dict = dict(zip(plugins_list, ret_list, strict=False))
else:
ret_dict = {}
return ret_dict
@staticmethod
async def _call_coro(plugin: Plugin, coro_name: str, *args: Any, **kwargs: Any) -> str | bool | None:
if not hasattr(plugin.object, coro_name):
_LOGGER.warning(f"Plugin doesn't implement coro_name '{coro_name}': {plugin.name}")
return None
coro: Awaitable[str | bool | None] = getattr(plugin.object, coro_name)(*args, **kwargs)
return await coro
async def map_plugin_coro(self, coro_name: str, *args: Any, **kwargs: Any) -> dict[Plugin, str | bool | None]:
"""Call a plugin declared by plugin by its name.
:param coro_name:
:param args:
:param kwargs:
:return:
"""
return await self.map(self._call_coro, coro_name, *args, **kwargs)
async def map_plugin_auth(self, session: Session) -> dict['BaseAuthPlugin', str | bool | None]:
tasks: list[asyncio.Future[Any]] = []
for plugin in self._auth_plugins:
async def auth_coro(p: 'BaseAuthPlugin', s: Session) -> str | bool | None:
async def auth_coro(p: "BaseAuthPlugin", s: Session) -> str | bool | None:
return await p.authenticate(session=s)
if not hasattr(plugin, "authenticate"):
continue
coro_instance: Awaitable[str | bool | None] = auth_coro(plugin, session)
tasks.append(asyncio.ensure_future(coro_instance))
ret_dict: dict[BasePlugin, str | bool | None] = {}
if tasks:
ret_list = await asyncio.gather(*tasks)
# Create result map plugin => ret
ret_dict = dict(zip(self._auth_plugins, ret_list, strict=False))
else:
ret_dict = {}
return ret_dict
async def map_plugin_topic(self, session: Session, topic: str, action: 'Action') -> dict['BaseTopicPlugin', str | bool | None]:
async def map_plugin_topic(self,
session: Session, topic: str, action: "Action"
) -> dict["BasePlugin", str | bool | None]:
"""Schedule a coroutine for plugin 'topic_filtering' calls.
:param session: the client session associated with the topic_filtering check
:param topic: the topic that needs to be filtered
:param action: the action being executed
:return: dict containing return from coro call for each plugin.
"""
tasks: list[asyncio.Future[Any]] = []
for plugin in self._topic_plugins:
async def topic_coro(p: 'BaseTopicPlugin', s: Session, t: str, a: 'Action') -> str | bool | None:
async def topic_coro(p: "BaseTopicPlugin", s: Session, t: str, a: "Action") -> str | bool | None:
return await p.topic_filtering(session=s, topic=t, action=a)
if not hasattr(plugin, "topic_filtering"):
continue
coro_instance: Awaitable[str | bool | None] = topic_coro(plugin, session, topic, action)
tasks.append(asyncio.ensure_future(coro_instance))
ret_dict: dict[BasePlugin, str | bool | None] = {}
if tasks:
ret_list = await asyncio.gather(*tasks)
# Create result map plugin => ret
ret_dict = {dict(zip(self._auth_plugins, ret_list, strict=False))}
else:
ret_dict = {}
ret_dict= dict(zip(self._auth_plugins, ret_list, strict=False))
return ret_dict
async def map_plugin_close(self) -> dict["BasePlugin", str | bool | None]:
tasks: list[asyncio.Future[Any]] = []
for plugin in self._plugins:
async def close_coro(p: "BasePlugin") -> None:
await p.close()
if not hasattr(plugin, "close"):
continue
coro_instance: Awaitable[str | bool | None] = close_coro(plugin)
tasks.append(asyncio.ensure_future(coro_instance))
ret_dict: dict[BasePlugin, str | bool | None] = {}
if tasks:
ret_list = await asyncio.gather(*tasks)
# Create result map plugin => ret
ret_dict = dict(zip(self._auth_plugins, ret_list, strict=False))
return ret_dict

Wyświetl plik

@ -42,7 +42,7 @@ STAT_CLIENTS_CONNECTED = "clients_connected"
STAT_CLIENTS_DISCONNECTED = "clients_disconnected"
class BrokerSysPlugin(BasePlugin):
class BrokerSysPlugin(BasePlugin[BrokerContext]):
def __init__(self, context: BrokerContext) -> None:
super().__init__(context)
# Broker statistics initialization

Wyświetl plik

@ -6,7 +6,7 @@ from amqtt.plugins.manager import BaseContext
from amqtt.session import Session
class BaseTopicPlugin(BasePlugin):
class BaseTopicPlugin(BasePlugin[BaseContext]):
"""Base class for topic plugins."""
def __init__(self, context: BaseContext) -> None:

Wyświetl plik

@ -1,13 +1,10 @@
from __future__ import annotations
import logging
import sys
from importlib import import_module
from pathlib import Path
import secrets
import string
import typing
from types import ModuleType
from typing import Any
import yaml
@ -51,32 +48,3 @@ def read_yaml_config(config_file: str | Path) -> dict[str, Any] | None:
except yaml.YAMLError:
logger.exception(f"Invalid config_file {config_file}")
return None
def cached_import(module_path: str, class_name: str=None) -> ModuleType:
# Check whether module is loaded and fully initialized.
if not ((module := sys.modules.get(module_path))
and (spec := getattr(module, "__spec__", None)) # noqa
and getattr(spec, "_initializing", False) is False): # noqa
module = import_module(module_path)
if class_name:
return getattr(module, class_name)
return module
# TODO : figure out proper return type
def import_string(dotted_path) -> Any:
"""
Import a dotted module path and return the attribute/class designated by the
last name in the path. Raise ImportError if the import failed.
"""
try:
module_path, class_name = dotted_path.rsplit(".", 1)
except ValueError as err:
raise ImportError(f"{dotted_path} doesn't look like a module path") from err
try:
return cached_import(module_path, class_name)
except AttributeError as err:
raise ImportError(
f'Module "{module_path}" does not define a "{class_name}" attribute/class'
) from err

Wyświetl plik

@ -185,13 +185,13 @@ max-returns = 10
# ----------------------------------- PYTEST -----------------------------------
[tool.pytest.ini_options]
#addopts = ["--cov=amqtt", "--cov-report=term-missing", "--cov-report=html"]
addopts = ["--cov=amqtt", "--cov-report=term-missing", "--cov-report=html"]
testpaths = ["tests"]
asyncio_mode = "auto"
timeout = 10
asyncio_default_fixture_loop_scope = "function"
addopts = ["--tb=short", "--capture=tee-sys"]
log_cli = true
#addopts = ["--tb=short", "--capture=tee-sys"]
#log_cli = true
log_level = "DEBUG"
# ------------------------------------ MYPY ------------------------------------

Wyświetl plik

@ -1,117 +0,0 @@
import asyncio
import logging
from dataclasses import dataclass, field
from typing import Any
import pytest
import yaml
from amqtt.broker import Broker
from yaml import CLoader as Loader
from dacite import from_dict, Config, UnexpectedDataError
from amqtt.client import MQTTClient
from amqtt.errors import PluginLoadError
logger = logging.getLogger(__name__)
plugin_config = """---
listeners:
default:
type: tcp
bind: 0.0.0.0:1883
plugins:
- tests.plugins.mocks.TestSimplePlugin:
- tests.plugins.mocks.TestConfigPlugin:
option1: 1
option2: bar
"""
plugin_invalid_config_one = """---
listeners:
default:
type: tcp
bind: 0.0.0.0:1883
plugins:
- tests.plugins.mocks.TestSimplePlugin:
option1: 1
option2: bar
"""
plugin_invalid_config_two = """---
listeners:
default:
type: tcp
bind: 0.0.0.0:1883
plugins:
- tests.plugins.mocks.TestConfigPlugin:
"""
plugin_config_auth = """---
listeners:
default:
type: tcp
bind: 0.0.0.0:1883
plugins:
- tests.plugins.mocks.TestAuthPlugin:
"""
plugin_config_topic = """---
listeners:
default:
type: tcp
bind: 0.0.0.0:1883
plugins:
- tests.plugins.mocks.TestTopicPlugin:
"""
@pytest.mark.asyncio
async def test_plugin_config_extra_fields():
cfg: dict[str, Any] = yaml.load(plugin_invalid_config_one, Loader=Loader)
with pytest.raises(PluginLoadError):
_ = Broker(config=cfg)
@pytest.mark.asyncio
async def test_plugin_config_missing_fields():
cfg: dict[str, Any] = yaml.load(plugin_invalid_config_one, Loader=Loader)
with pytest.raises(PluginLoadError):
_ = Broker(config=cfg)
@pytest.mark.asyncio
async def test_alternate_plugin_load():
cfg: dict[str, Any] = yaml.load(plugin_config, Loader=Loader)
broker = Broker(config=cfg)
await broker.start()
await broker.shutdown()
@pytest.mark.asyncio
async def test_auth_plugin_load():
cfg: dict[str, Any] = yaml.load(plugin_config_auth, Loader=Loader)
broker = Broker(config=cfg)
await broker.start()
await asyncio.sleep(0.5)
client1 = MQTTClient()
await client1.connect()
await client1.publish('my/topic', b'my message')
await client1.disconnect()
await asyncio.sleep(0.5)
await broker.shutdown()
@pytest.mark.asyncio
async def test_topic_plugin_load():
cfg: dict[str, Any] = yaml.load(plugin_config_topic, Loader=Loader)
broker = Broker(config=cfg)
await broker.start()
await broker.shutdown()

Wyświetl plik

@ -1,9 +1,12 @@
import asyncio
import logging
from typing import Any
import unittest
from amqtt.plugins.manager import BaseContext, Plugin, PluginManager
from amqtt.broker import Action
from amqtt.plugins.authentication import BaseAuthPlugin
from amqtt.plugins.manager import BaseContext, PluginManager
from amqtt.plugins.topic_checking import BaseTopicPlugin
from amqtt.session import Session
formatter = "[%(asctime)s] %(name)s {%(filename)s:%(lineno)d} %(levelname)s - %(message)s"
logging.basicConfig(level=logging.INFO, format=formatter)
@ -14,21 +17,29 @@ class EmptyTestPlugin:
self.context = context
class EventTestPlugin:
class EventTestPlugin(BaseAuthPlugin, BaseTopicPlugin):
def __init__(self, context: BaseContext) -> None:
self.context = context
self.test_flag = False
self.coro_flag = False
super().__init__(context)
self.test_close_flag = False
self.test_auth_flag = False
self.test_topic_flag = False
self.test_event_flag = False
async def on_test(self) -> None:
self.test_flag = True
self.context.logger.info("on_test")
self.test_event_flag = True
async def test_coro(self) -> None:
self.coro_flag = True
async def authenticate(self, *, session: Session) -> bool | None:
self.test_auth_flag = True
return None
async def ret_coro(self) -> str:
return "TEST"
async def topic_filtering(
self, *, session: Session | None = None, topic: str | None = None, action: Action | None = None
) -> bool:
self.test_topic_flag = True
return False
async def close(self) -> None:
self.test_close_flag = True
class TestPluginManager(unittest.TestCase):
@ -47,9 +58,9 @@ class TestPluginManager(unittest.TestCase):
manager = PluginManager("amqtt.test.plugins", context=None)
self.loop.run_until_complete(fire_event())
plugin = manager.get_plugin("event_plugin")
plugin = manager.get_plugin("EventTestPlugin")
assert plugin is not None
assert plugin.object.test_flag
assert plugin.test_event_flag
def test_fire_event_wait(self) -> None:
async def fire_event() -> None:
@ -58,36 +69,33 @@ class TestPluginManager(unittest.TestCase):
manager = PluginManager("amqtt.test.plugins", context=None)
self.loop.run_until_complete(fire_event())
plugin = manager.get_plugin("event_plugin")
plugin = manager.get_plugin("EventTestPlugin")
assert plugin is not None
assert plugin.object.test_flag
assert plugin.test_event_flag
def test_map_coro(self) -> None:
async def call_coro() -> None:
await manager.map_plugin_coro("test_coro")
def test_plugin_close_coro(self) -> None:
manager = PluginManager("amqtt.test.plugins", context=None)
self.loop.run_until_complete(call_coro())
plugin = manager.get_plugin("event_plugin")
self.loop.run_until_complete(manager.map_plugin_close())
self.loop.run_until_complete(asyncio.sleep(0.5))
plugin = manager.get_plugin("EventTestPlugin")
assert plugin is not None
assert plugin.object.test_coro
assert plugin.test_close_flag
def test_map_coro_return(self) -> None:
async def call_coro() -> dict[Plugin, str]:
return await manager.map_plugin_coro("ret_coro")
def test_plugin_auth_coro(self) -> None:
manager = PluginManager("amqtt.test.plugins", context=None)
ret = self.loop.run_until_complete(call_coro())
plugin = manager.get_plugin("event_plugin")
self.loop.run_until_complete(manager.map_plugin_auth(session=Session()))
self.loop.run_until_complete(asyncio.sleep(0.5))
plugin = manager.get_plugin("EventTestPlugin")
assert plugin is not None
assert ret[plugin] == "TEST"
assert plugin.test_auth_flag
def test_map_coro_filter(self) -> None:
"""Run plugin coro but expect no return as an empty filter is given."""
async def call_coro() -> dict[Plugin, Any]:
return await manager.map_plugin_coro("ret_coro", filter_plugins=[])
def test_plugin_topic_coro(self) -> None:
manager = PluginManager("amqtt.test.plugins", context=None)
ret = self.loop.run_until_complete(call_coro())
assert len(ret) == 0
self.loop.run_until_complete(manager.map_plugin_topic(session=Session(), topic="test", action=Action.PUBLISH))
self.loop.run_until_complete(asyncio.sleep(0.5))
plugin = manager.get_plugin("EventTestPlugin")
assert plugin is not None
assert plugin.test_topic_flag