From 2caa792d69d4899332693e8ddee1571c06d58042 Mon Sep 17 00:00:00 2001 From: Andrew Mirsky Date: Mon, 9 Jun 2025 00:11:06 -0400 Subject: [PATCH] fixes Yakifo/amqtt#51 : plugins which fail to import or load should stop the broker from starting --- amqtt/broker.py | 3 ++ amqtt/client.py | 5 ++- amqtt/errors.py | 17 +++++++++ amqtt/plugins/manager.py | 22 +++++++----- amqtt/scripts/broker_script.py | 4 +-- pyproject.toml | 2 +- tests/plugins/mock_plugins.py | 10 ++++++ tests/plugins/test_plugins.py | 64 ++++++++++++++++++++++++++++++++++ 8 files changed, 115 insertions(+), 12 deletions(-) create mode 100644 tests/plugins/mock_plugins.py diff --git a/amqtt/broker.py b/amqtt/broker.py index 5690227..62502ad 100644 --- a/amqtt/broker.py +++ b/amqtt/broker.py @@ -145,6 +145,9 @@ class Broker: loop: asyncio loop. defaults to `asyncio.get_event_loop()`. plugin_namespace: plugin namespace to use when loading plugin entry_points. defaults to `amqtt.broker.plugins`. + Raises: + BrokerError, ParserError, PluginError + """ states: ClassVar[list[str]] = [ diff --git a/amqtt/client.py b/amqtt/client.py index 8da9aba..511dae9 100644 --- a/amqtt/client.py +++ b/amqtt/client.py @@ -88,6 +88,9 @@ class MQTTClient: it will be generated randomly by `amqtt.utils.gen_client_id` config: dictionary of configuration options (see [client configuration](client_config.md)). + Raises: + PluginError + """ def __init__(self, client_id: str | None = None, config: dict[str, Any] | None = None) -> None: @@ -142,7 +145,7 @@ class MQTTClient: [CONNACK](http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc398718033)'s return code Raises: - amqtt.client.ConnectException: if connection fails + ClientError, ConnectError """ additional_headers = additional_headers if additional_headers is not None else {} diff --git a/amqtt/errors.py b/amqtt/errors.py index 71d65c7..844a135 100644 --- a/amqtt/errors.py +++ b/amqtt/errors.py @@ -1,3 +1,6 @@ +from typing import Any + + class AMQTTError(Exception): """aMQTT base exception.""" @@ -18,6 +21,20 @@ class BrokerError(Exception): """Exceptions thrown by broker.""" +class PluginError(Exception): + """Exceptions thrown when loading or initializing a plugin.""" + + +class PluginImportError(PluginError): + def __init__(self, plugin: Any) -> None: + super().__init__(f"Plugin import failed: {plugin!r}") + + +class PluginInitError(PluginError): + def __init__(self, plugin: Any) -> None: + super().__init__(f"Plugin init failed: {plugin!r}") + + class ClientError(Exception): """Exceptions thrown by client.""" diff --git a/amqtt/plugins/manager.py b/amqtt/plugins/manager.py index a0e5c71..61bbb49 100644 --- a/amqtt/plugins/manager.py +++ b/amqtt/plugins/manager.py @@ -8,6 +8,8 @@ from importlib.metadata import EntryPoint, EntryPoints, entry_points import logging from typing import Any, NamedTuple +from amqtt.errors import PluginImportError, PluginInitError + _LOGGER = logging.getLogger(__name__) @@ -80,17 +82,21 @@ class PluginManager: try: self.logger.debug(f" Loading plugin {ep!s}") plugin = ep.load() - self.logger.debug(f" Initializing plugin {ep!s}") - plugin_context = copy.copy(self.app_context) - plugin_context.logger = self.logger.getChild(ep.name) + except ImportError as e: + self.logger.debug(f"Plugin import failed: {ep!r}", exc_info=True) + raise PluginImportError(ep) from e + + self.logger.debug(f" Initializing plugin {ep!s}") + + plugin_context = copy.copy(self.app_context) + plugin_context.logger = self.logger.getChild(ep.name) + try: obj = plugin(plugin_context) return Plugin(ep.name, ep, obj) - except ImportError: - self.logger.warning(f"Plugin {ep!r} import failed") - self.logger.debug("", exc_info=True) - - return None + except Exception as e: + self.logger.debug(f"Plugin init failed: {ep!r}", exc_info=True) + raise PluginInitError(ep) from e def get_plugin(self, name: str) -> Plugin | None: """Get a plugin by its name from the plugins loaded for the current namespace. diff --git a/amqtt/scripts/broker_script.py b/amqtt/scripts/broker_script.py index 7790108..ad649fc 100644 --- a/amqtt/scripts/broker_script.py +++ b/amqtt/scripts/broker_script.py @@ -7,7 +7,7 @@ from yaml.parser import ParserError from amqtt import __version__ as amqtt_version from amqtt.broker import Broker -from amqtt.errors import BrokerError +from amqtt.errors import BrokerError, PluginError from amqtt.utils import read_yaml_config logger = logging.getLogger(__name__) @@ -58,7 +58,7 @@ def broker_main( loop = asyncio.get_event_loop() try: broker = Broker(config) - except (BrokerError, ParserError) as exc: + except (BrokerError, ParserError, PluginError) as exc: typer.echo(f"❌ Broker failed to start: {exc}", err=True) raise typer.Exit(code=1) from exc diff --git a/pyproject.toml b/pyproject.toml index 672aa27..03f4bae 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -185,7 +185,7 @@ timeout = 10 asyncio_default_fixture_loop_scope = "function" #addopts = ["--tb=short", "--capture=tee-sys"] #log_cli = true -#log_level = "DEBUG" +log_level = "DEBUG" # ------------------------------------ MYPY ------------------------------------ [tool.mypy] diff --git a/tests/plugins/mock_plugins.py b/tests/plugins/mock_plugins.py new file mode 100644 index 0000000..45a6ea7 --- /dev/null +++ b/tests/plugins/mock_plugins.py @@ -0,0 +1,10 @@ +from amqtt.broker import BrokerContext +from amqtt.plugins.base import BasePlugin + +# intentional import error to test broker response +from pathlib import Pat # noqa + +class MockImportErrorPlugin(BasePlugin): + + def __init__(self, context: BrokerContext) -> None: + super().__init__(context) diff --git a/tests/plugins/test_plugins.py b/tests/plugins/test_plugins.py index 58c86a1..0e6a0fc 100644 --- a/tests/plugins/test_plugins.py +++ b/tests/plugins/test_plugins.py @@ -1,12 +1,17 @@ import inspect +from importlib.metadata import EntryPoint from logging import getLogger from pathlib import Path from types import ModuleType from typing import Any +from unittest.mock import patch import pytest import amqtt.plugins +from amqtt.broker import Broker, BrokerContext +from amqtt.errors import PluginError, PluginInitError, PluginImportError +from amqtt.plugins.base import BasePlugin from amqtt.plugins.manager import BaseContext _INVALID_METHOD: str = "invalid_foo" @@ -61,3 +66,62 @@ def test_plugins_correct_has_attr() -> None: __import__(name) _verify_module(module, module.__name__) + + +class MockInitErrorPlugin(BasePlugin): + + def __init__(self, context: BrokerContext) -> None: + super().__init__(context) + raise KeyError + + +@pytest.mark.asyncio +async def test_plugin_exception_while_init() -> None: + class MockEntryPoints: + + def select(self, group) -> list[EntryPoint]: + match group: + case 'tests.mock_plugins': + return [ + EntryPoint(name='TestExceptionPlugin', group='tests.mock_plugins', value='tests.plugins.test_plugins:MockInitErrorPlugin'), + ] + case _: + return list() + + with patch("amqtt.plugins.manager.entry_points", side_effect=MockEntryPoints) as mocked_mqtt_publish: + + config = { + "listeners": { + "default": {"type": "tcp", "bind": "127.0.0.1:1883", "max_connections": 10}, + }, + 'sys_interval': 1 + } + + with pytest.raises(PluginInitError): + _ = Broker(plugin_namespace='tests.mock_plugins', config=config) + + +@pytest.mark.asyncio +async def test_plugin_exception_while_loading() -> None: + class MockEntryPoints: + + def select(self, group) -> list[EntryPoint]: + match group: + case 'tests.mock_plugins': + return [ + EntryPoint(name='TestExceptionPlugin', group='tests.mock_plugins', value='tests.plugins.mock_plugins:MockImportErrorPlugin'), + ] + case _: + return list() + + with patch("amqtt.plugins.manager.entry_points", side_effect=MockEntryPoints) as mocked_mqtt_publish: + + config = { + "listeners": { + "default": {"type": "tcp", "bind": "127.0.0.1:1883", "max_connections": 10}, + }, + 'sys_interval': 1 + } + + with pytest.raises(PluginImportError): + _ = Broker(plugin_namespace='tests.mock_plugins', config=config)