loading plugins from config file, bypassing entry points. authenticate and subscribe/publish

pull/212/head
Andrew Mirsky 2025-06-12 08:37:27 -04:00
rodzic 06053ce7ee
commit 7b936d785c
Nie znaleziono w bazie danych klucza dla tego podpisu
ID klucza GPG: A98E67635CDF2C39
13 zmienionych plików z 399 dodań i 47 usunięć

Wyświetl plik

@ -693,16 +693,17 @@ class Broker:
auth_config = self.config.get("auth", None)
if isinstance(auth_config, dict):
auth_plugins = auth_config.get("plugins", None)
returns = await self.plugins_manager.map_plugin_coro("authenticate", session=session, filter_plugins=auth_plugins)
# returns = await self.plugins_manager.map_plugin_coro("authenticate", session=session, filter_plugins=auth_plugins)
returns = await self.plugins_manager.map_plugin_auth(session=session)
auth_result = True
if returns:
for plugin in returns:
res = returns[plugin]
if res is False:
auth_result = False
self.logger.debug(f"Authentication failed due to '{plugin.name}' plugin result: {res}")
self.logger.debug(f"Authentication failed due to '{plugin.__class__}' plugin result: {res}")
else:
self.logger.debug(f"'{plugin.name}' plugin result: {res}")
self.logger.debug(f"'{plugin.__class__}' plugin result: {res}")
# If all plugins returned True, authentication is success
return auth_result
@ -771,13 +772,15 @@ class Broker:
if not enabled:
return True
results = await self.plugins_manager.map_plugin_coro(
"topic_filtering",
session=session,
topic=topic,
action=action,
filter_plugins=topic_plugins,
)
results = await self.plugins_manager.map_plugin_topic(session=session, topic=topic, action=action)
# results = await self.plugins_manager.map_plugin_coro(
# "topic_filtering",
# session=session,
# topic=topic,
# action=action,
# filter_plugins=topic_plugins,
# )
return all(result for result in results.values())
async def _delete_session(self, client_id: str) -> None:

Wyświetl plik

@ -30,3 +30,8 @@ class ConnectError(ClientError):
class ProtocolHandlerError(Exception):
"""Exceptions thrown by protocol handle."""
class PluginLoadError(Exception):
"""Exception thrown when loading a plugin."""

Wyświetl plik

@ -5,6 +5,7 @@ from passlib.apps import custom_app_context as pwd_context
from amqtt.broker import BrokerContext
from amqtt.plugins.base import BasePlugin
from amqtt.plugins.manager import BaseContext
from amqtt.session import Session
_PARTS_EXPECTED_LENGTH = 2 # Expected number of parts in a valid line
@ -13,7 +14,7 @@ _PARTS_EXPECTED_LENGTH = 2 # Expected number of parts in a valid line
class BaseAuthPlugin(BasePlugin):
"""Base class for authentication plugins."""
def __init__(self, context: BrokerContext) -> None:
def __init__(self, context: BaseContext) -> None:
super().__init__(context)
self.auth_config: dict[str, Any] | None = self._get_config_section("auth")

Wyświetl plik

@ -1,19 +1,26 @@
from dataclasses import dataclass
from typing import Any
from amqtt.broker import BrokerContext
from amqtt.plugins.manager import BaseContext
class BasePlugin:
"""The base from which all plugins should inherit."""
def __init__(self, context: BrokerContext) -> None:
def __init__(self, context: BaseContext) -> None:
self.context = context
def _get_config_section(self, name: str) -> dict[str, Any] | None:
if not self.context.config or not self.context.config.get(name, None):
if not self.context.config or not hasattr(self.context.config, 'get') or self.context.config.get(name, None):
return None
section_config: int | dict[str, Any] | None = self.context.config.get(name, None)
# mypy has difficulty excluding int from `config`'s type, unless isinstance` is its own check
if isinstance(section_config, int):
return None
return section_config
@dataclass
class Config:
pass

Wyświetl plik

@ -6,10 +6,20 @@ import contextlib
import copy
from importlib.metadata import EntryPoint, EntryPoints, entry_points
import logging
from typing import Any, NamedTuple
from typing import Any, NamedTuple, TYPE_CHECKING
from amqtt.errors import MQTTError, PluginLoadError
from amqtt.session import Session
from amqtt.utils import import_string
from dacite import from_dict, Config, DaciteError
_LOGGER = logging.getLogger(__name__)
if TYPE_CHECKING:
from amqtt.plugins.base import BasePlugin
from amqtt.plugins.authentication import BaseAuthPlugin
from amqtt.plugins.topic_checking import BaseTopicPlugin
from amqtt.broker import Action
class Plugin(NamedTuple):
name: str
@ -53,7 +63,9 @@ class PluginManager:
self.logger = logging.getLogger(namespace)
self.context = context if context is not None else BaseContext()
self.context.loop = self._loop
self._plugins: list[Plugin] = []
self._plugins: list[BasePlugin] = []
self._auth_plugins: list[BaseAuthPlugin] = []
self._topic_plugins: list[BaseTopicPlugin] = []
self._load_plugins(namespace)
self._fired_events: list[asyncio.Future[Any]] = []
plugins_manager[namespace] = self
@ -63,20 +75,54 @@ class PluginManager:
return self.context
def _load_plugins(self, namespace: str) -> None:
self.logger.debug(f"Loading plugins for namespace {namespace}")
ep: EntryPoints | list[EntryPoint] = []
if hasattr(entry_points(), "select"):
ep = entry_points().select(group=namespace)
elif namespace in entry_points():
ep = [entry_points()[namespace]]
from amqtt.plugins.authentication import BaseAuthPlugin
from amqtt.plugins.topic_checking import BaseTopicPlugin
if 'plugins' in self.app_context.config:
self.logger.info("Loading plugins from config file")
for plugin_info in self.app_context.config['plugins']:
if isinstance(plugin_info, dict):
assert len(plugin_info.keys()) == 1
plugin_path = list(plugin_info.keys())[0]
plugin_cfg = plugin_info[plugin_path]
plugin = self._load_str_plugin(plugin_path, plugin_cfg)
elif isinstance(plugin_info, str):
plugin = self._load_str_plugin(plugin_info, {})
else:
msg = 'Unexpected entry in plugins config'
raise PluginLoadError(msg)
for item in ep:
plugin = self._load_plugin(item)
if plugin is not None:
self._plugins.append(plugin)
self.logger.debug(f" Plugin {item.name} ready")
if isinstance(plugin, BaseAuthPlugin):
self._auth_plugins.append(plugin)
if isinstance(plugin, BaseTopicPlugin):
self._topic_plugins.append(plugin)
def _load_plugin(self, ep: EntryPoint) -> Plugin | None:
else:
self.logger.debug(f"Loading plugins for namespace {namespace}")
auth_filter_list = self.app_context.config['auth'].get('plugins', []) if 'auth' in self.app_context.config else []
topic_filter_list = self.app_context.config['topic'].get('plugins', []) if 'topic' in self.app_context.config else []
ep: EntryPoints | list[EntryPoint] = []
if hasattr(entry_points(), "select"):
ep = entry_points().select(group=namespace)
elif namespace in entry_points():
ep = [entry_points()[namespace]]
for item in ep:
plugin = self._load_ep_plugin(item)
if plugin is not None:
self._plugins.append(plugin.object)
if plugin.name in auth_filter_list:
self._auth_plugins.append(plugin.object)
elif plugin.name in topic_filter_list:
self._topic_plugins.append(plugin.object)
self.logger.debug(f" Plugin {item.name} ready")
def _load_ep_plugin(self, ep: EntryPoint) -> Plugin | None:
try:
self.logger.debug(f" Loading plugin {ep!s}")
plugin = ep.load()
@ -92,16 +138,43 @@ class PluginManager:
return None
def get_plugin(self, name: str) -> Plugin | None:
"""Get a plugin by its name from the plugins loaded for the current namespace.
def _load_str_plugin(self, plugin_path: str, plugin_cfg: dict[str, Any] | None = None) -> 'BasePlugin':
from amqtt.plugins.base import BasePlugin
from amqtt.plugins.authentication import BaseAuthPlugin
from amqtt.plugins.topic_checking import BaseTopicPlugin
:param name:
:return:
"""
for p in self._plugins:
if p.name == name:
return p
return None
try:
plugin_class = import_string(plugin_path)
except ModuleNotFoundError as ep:
self.logger.error(f"Plugin import failed: {plugin_path}")
raise MQTTError() from ep
if not issubclass(plugin_class, BasePlugin):
msg = f"Plugin {plugin_path} is not a subclass of 'BasePlugin'"
raise PluginLoadError(msg)
plugin_context = copy.copy(self.app_context)
plugin_context.logger = self.logger.getChild(plugin_class.__name__)
try:
plugin_context.config = from_dict(data_class=plugin_class.Config, data=plugin_cfg or {}, config=Config(strict=True))
except DaciteError as e:
raise PluginLoadError from e
try:
return plugin_class(plugin_context)
except ImportError as e:
raise PluginLoadError from e
# def get_plugin(self, name: str) -> Plugin | None:
# """Get a plugin by its name from the plugins loaded for the current namespace.
#
# :param name:
# :return:
# """
# for p in self._plugins:
# if p.name == name:
# return p
# return None
async def close(self) -> None:
"""Free PluginManager resources and cancel pending event methods."""
@ -111,7 +184,7 @@ class PluginManager:
self._fired_events.clear()
@property
def plugins(self) -> list[Plugin]:
def plugins(self) -> list['BasePlugin']:
"""Get the loaded plugins list.
:return:
@ -137,7 +210,7 @@ class PluginManager:
tasks: list[asyncio.Future[Any]] = []
event_method_name = "on_" + event_name
for plugin in self._plugins:
event_method = getattr(plugin.object, event_method_name, None)
event_method = getattr(plugin, event_method_name, None)
if event_method:
try:
task = self._schedule_coro(event_method(*args, **kwargs))
@ -149,7 +222,7 @@ class PluginManager:
task.add_done_callback(clean_fired_events)
except AssertionError:
self.logger.exception(f"Method '{event_method_name}' on plugin '{plugin.name}' is not a coroutine")
self.logger.exception(f"Method '{event_method_name}' on plugin '{plugin.__class__}' is not a coroutine")
self._fired_events.extend(tasks)
if wait and tasks:
@ -212,3 +285,44 @@ class PluginManager:
:return:
"""
return await self.map(self._call_coro, coro_name, *args, **kwargs)
async def map_plugin_auth(self, session: Session) -> dict['BaseAuthPlugin', str | bool | None]:
tasks: list[asyncio.Future[Any]] = []
for plugin in self._auth_plugins:
async def auth_coro(p: 'BaseAuthPlugin', s: Session) -> str | bool | None:
return await p.authenticate(session=s)
coro_instance: Awaitable[str | bool | None] = auth_coro(plugin, session)
tasks.append(asyncio.ensure_future(coro_instance))
if tasks:
ret_list = await asyncio.gather(*tasks)
# Create result map plugin => ret
ret_dict = dict(zip(self._auth_plugins, ret_list, strict=False))
else:
ret_dict = {}
return ret_dict
async def map_plugin_topic(self, session: Session, topic: str, action: 'Action') -> dict['BaseTopicPlugin', str | bool | None]:
tasks: list[asyncio.Future[Any]] = []
for plugin in self._topic_plugins:
async def topic_coro(p: 'BaseTopicPlugin', s: Session, t: str, a: 'Action') -> str | bool | None:
return await p.topic_filtering(session=s, topic=t, action=a)
coro_instance: Awaitable[str | bool | None] = topic_coro(plugin, session, topic, action)
tasks.append(asyncio.ensure_future(coro_instance))
if tasks:
ret_list = await asyncio.gather(*tasks)
# Create result map plugin => ret
ret_dict = {dict(zip(self._auth_plugins, ret_list, strict=False))}
else:
ret_dict = {}
return ret_dict

Wyświetl plik

@ -1,14 +1,15 @@
from typing import Any
from amqtt.broker import Action, BrokerContext
from amqtt.broker import Action
from amqtt.plugins.base import BasePlugin
from amqtt.plugins.manager import BaseContext
from amqtt.session import Session
class BaseTopicPlugin(BasePlugin):
"""Base class for topic plugins."""
def __init__(self, context: BrokerContext) -> None:
def __init__(self, context: BaseContext) -> None:
super().__init__(context)
self.topic_config: dict[str, Any] | None = self._get_config_section("topic-check")
@ -37,7 +38,7 @@ class BaseTopicPlugin(BasePlugin):
class TopicTabooPlugin(BaseTopicPlugin):
def __init__(self, context: BrokerContext) -> None:
def __init__(self, context: BaseContext) -> None:
super().__init__(context)
self._taboo: list[str] = ["prohibited", "top-secret", "data/classified"]

Wyświetl plik

@ -6,3 +6,5 @@ default_retain: false
auto_reconnect: true
reconnect_max_interval: 10
reconnect_retries: 2
broker:
uri: "mqtt://127.0.0.1"

Wyświetl plik

@ -1,10 +1,13 @@
from __future__ import annotations
import logging
import sys
from importlib import import_module
from pathlib import Path
import secrets
import string
import typing
from types import ModuleType
from typing import Any
import yaml
@ -48,3 +51,32 @@ def read_yaml_config(config_file: str | Path) -> dict[str, Any] | None:
except yaml.YAMLError:
logger.exception(f"Invalid config_file {config_file}")
return None
def cached_import(module_path: str, class_name: str=None) -> ModuleType:
# Check whether module is loaded and fully initialized.
if not ((module := sys.modules.get(module_path))
and (spec := getattr(module, "__spec__", None)) # noqa
and getattr(spec, "_initializing", False) is False): # noqa
module = import_module(module_path)
if class_name:
return getattr(module, class_name)
return module
# TODO : figure out proper return type
def import_string(dotted_path) -> Any:
"""
Import a dotted module path and return the attribute/class designated by the
last name in the path. Raise ImportError if the import failed.
"""
try:
module_path, class_name = dotted_path.rsplit(".", 1)
except ValueError as err:
raise ImportError(f"{dotted_path} doesn't look like a module path") from err
try:
return cached_import(module_path, class_name)
except AttributeError as err:
raise ImportError(
f'Module "{module_path}" does not define a "{class_name}" attribute/class'
) from err

Wyświetl plik

@ -32,7 +32,8 @@ dependencies = [
"websockets==15.0.1", # https://pypi.org/project/websockets
"passlib==1.7.4", # https://pypi.org/project/passlib
"PyYAML==6.0.2", # https://pypi.org/project/PyYAML
"typer==0.15.4"
"typer==0.15.4",
"dacite>=1.9.2",
]
[dependency-groups]
@ -184,14 +185,14 @@ max-returns = 10
# ----------------------------------- PYTEST -----------------------------------
[tool.pytest.ini_options]
addopts = ["--cov=amqtt", "--cov-report=term-missing", "--cov-report=html"]
#addopts = ["--cov=amqtt", "--cov-report=term-missing", "--cov-report=html"]
testpaths = ["tests"]
asyncio_mode = "auto"
timeout = 10
asyncio_default_fixture_loop_scope = "function"
#addopts = ["--tb=short", "--capture=tee-sys"]
#log_cli = true
#log_level = "DEBUG"
addopts = ["--tb=short", "--capture=tee-sys"]
log_cli = true
log_level = "DEBUG"
# ------------------------------------ MYPY ------------------------------------
[tool.mypy]

Wyświetl plik

@ -0,0 +1,10 @@
---
listeners:
default:
type: tcp
bind: 0.0.0.0:1883
plugins:
- test.plugins.plugins.TestSimplePlugin
- test.plugins.plugins.TestConfigPlugin:
option1: foo
option2: bar

Wyświetl plik

@ -0,0 +1,48 @@
import logging
from dataclasses import dataclass
from amqtt.broker import Action
from amqtt.plugins.authentication import BaseAuthPlugin
from amqtt.plugins.base import BasePlugin
from amqtt.plugins.manager import BaseContext
from amqtt.plugins.topic_checking import BaseTopicPlugin
from amqtt.session import Session
logger = logging.getLogger(__name__)
class TestSimplePlugin(BasePlugin):
def __init__(self, context: BaseContext):
super().__init__(context)
class TestConfigPlugin(BasePlugin):
def __init__(self, context: BaseContext):
super().__init__(context)
@dataclass
class Config:
option1: int
option2: str
class TestAuthPlugin(BaseAuthPlugin):
def __init__(self, context: BaseContext):
super().__init__(context)
async def authenticate(self, *, session: Session) -> bool | None:
return False
class TestTopicPlugin(BaseTopicPlugin):
def __init__(self, context: BaseContext):
super().__init__(context)
def topic_filtering(
self, *, session: Session | None = None, topic: str | None = None, action: Action | None = None
) -> bool:
return True

Wyświetl plik

@ -0,0 +1,117 @@
import asyncio
import logging
from dataclasses import dataclass, field
from typing import Any
import pytest
import yaml
from amqtt.broker import Broker
from yaml import CLoader as Loader
from dacite import from_dict, Config, UnexpectedDataError
from amqtt.client import MQTTClient
from amqtt.errors import PluginLoadError
logger = logging.getLogger(__name__)
plugin_config = """---
listeners:
default:
type: tcp
bind: 0.0.0.0:1883
plugins:
- tests.plugins.mocks.TestSimplePlugin:
- tests.plugins.mocks.TestConfigPlugin:
option1: 1
option2: bar
"""
plugin_invalid_config_one = """---
listeners:
default:
type: tcp
bind: 0.0.0.0:1883
plugins:
- tests.plugins.mocks.TestSimplePlugin:
option1: 1
option2: bar
"""
plugin_invalid_config_two = """---
listeners:
default:
type: tcp
bind: 0.0.0.0:1883
plugins:
- tests.plugins.mocks.TestConfigPlugin:
"""
plugin_config_auth = """---
listeners:
default:
type: tcp
bind: 0.0.0.0:1883
plugins:
- tests.plugins.mocks.TestAuthPlugin:
"""
plugin_config_topic = """---
listeners:
default:
type: tcp
bind: 0.0.0.0:1883
plugins:
- tests.plugins.mocks.TestTopicPlugin:
"""
@pytest.mark.asyncio
async def test_plugin_config_extra_fields():
cfg: dict[str, Any] = yaml.load(plugin_invalid_config_one, Loader=Loader)
with pytest.raises(PluginLoadError):
_ = Broker(config=cfg)
@pytest.mark.asyncio
async def test_plugin_config_missing_fields():
cfg: dict[str, Any] = yaml.load(plugin_invalid_config_one, Loader=Loader)
with pytest.raises(PluginLoadError):
_ = Broker(config=cfg)
@pytest.mark.asyncio
async def test_alternate_plugin_load():
cfg: dict[str, Any] = yaml.load(plugin_config, Loader=Loader)
broker = Broker(config=cfg)
await broker.start()
await broker.shutdown()
@pytest.mark.asyncio
async def test_auth_plugin_load():
cfg: dict[str, Any] = yaml.load(plugin_config_auth, Loader=Loader)
broker = Broker(config=cfg)
await broker.start()
await asyncio.sleep(0.5)
client1 = MQTTClient()
await client1.connect()
await client1.publish('my/topic', b'my message')
await client1.disconnect()
await asyncio.sleep(0.5)
await broker.shutdown()
@pytest.mark.asyncio
async def test_topic_plugin_load():
cfg: dict[str, Any] = yaml.load(plugin_config_topic, Loader=Loader)
broker = Broker(config=cfg)
await broker.start()
await broker.shutdown()

11
uv.lock
Wyświetl plik

@ -12,6 +12,7 @@ name = "amqtt"
version = "0.11.0rc1"
source = { editable = "." }
dependencies = [
{ name = "dacite" },
{ name = "passlib" },
{ name = "pyyaml" },
{ name = "transitions" },
@ -66,6 +67,7 @@ docs = [
[package.metadata]
requires-dist = [
{ name = "coveralls", marker = "extra == 'ci'", specifier = "==4.0.1" },
{ name = "dacite", specifier = ">=1.9.2" },
{ name = "passlib", specifier = "==1.7.4" },
{ name = "pyyaml", specifier = "==6.0.2" },
{ name = "transitions", specifier = "==0.9.2" },
@ -485,6 +487,15 @@ version = "0.9.5"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/f1/2a/8c3ac3d8bc94e6de8d7ae270bb5bc437b210bb9d6d9e46630c98f4abd20c/csscompressor-0.9.5.tar.gz", hash = "sha256:afa22badbcf3120a4f392e4d22f9fff485c044a1feda4a950ecc5eba9dd31a05", size = 237808, upload-time = "2017-11-26T21:13:08.238Z" }
[[package]]
name = "dacite"
version = "1.9.2"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/55/a0/7ca79796e799a3e782045d29bf052b5cde7439a2bbb17f15ff44f7aacc63/dacite-1.9.2.tar.gz", hash = "sha256:6ccc3b299727c7aa17582f0021f6ae14d5de47c7227932c47fec4cdfefd26f09", size = 22420, upload-time = "2025-02-05T09:27:29.757Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/94/35/386550fd60316d1e37eccdda609b074113298f23cef5bddb2049823fe666/dacite-1.9.2-py3-none-any.whl", hash = "sha256:053f7c3f5128ca2e9aceb66892b1a3c8936d02c686e707bee96e19deef4bc4a0", size = 16600, upload-time = "2025-02-05T09:27:24.345Z" },
]
[[package]]
name = "dill"
version = "0.4.0"