fixes Yakifo/amqtt#51 : plugins which fail to import or load should stop the broker from starting

pull/203/head
Andrew Mirsky 2025-06-09 00:11:06 -04:00
rodzic 80016d8cca
commit 2caa792d69
Nie znaleziono w bazie danych klucza dla tego podpisu
ID klucza GPG: A98E67635CDF2C39
8 zmienionych plików z 115 dodań i 12 usunięć

Wyświetl plik

@ -145,6 +145,9 @@ class Broker:
loop: asyncio loop. defaults to `asyncio.get_event_loop()`. loop: asyncio loop. defaults to `asyncio.get_event_loop()`.
plugin_namespace: plugin namespace to use when loading plugin entry_points. defaults to `amqtt.broker.plugins`. plugin_namespace: plugin namespace to use when loading plugin entry_points. defaults to `amqtt.broker.plugins`.
Raises:
BrokerError, ParserError, PluginError
""" """
states: ClassVar[list[str]] = [ states: ClassVar[list[str]] = [

Wyświetl plik

@ -88,6 +88,9 @@ class MQTTClient:
it will be generated randomly by `amqtt.utils.gen_client_id` it will be generated randomly by `amqtt.utils.gen_client_id`
config: dictionary of configuration options (see [client configuration](client_config.md)). config: dictionary of configuration options (see [client configuration](client_config.md)).
Raises:
PluginError
""" """
def __init__(self, client_id: str | None = None, config: dict[str, Any] | None = None) -> None: def __init__(self, client_id: str | None = None, config: dict[str, Any] | None = None) -> None:
@ -142,7 +145,7 @@ class MQTTClient:
[CONNACK](http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc398718033)'s return code [CONNACK](http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc398718033)'s return code
Raises: Raises:
amqtt.client.ConnectException: if connection fails ClientError, ConnectError
""" """
additional_headers = additional_headers if additional_headers is not None else {} additional_headers = additional_headers if additional_headers is not None else {}

Wyświetl plik

@ -1,3 +1,6 @@
from typing import Any
class AMQTTError(Exception): class AMQTTError(Exception):
"""aMQTT base exception.""" """aMQTT base exception."""
@ -18,6 +21,20 @@ class BrokerError(Exception):
"""Exceptions thrown by broker.""" """Exceptions thrown by broker."""
class PluginError(Exception):
"""Exceptions thrown when loading or initializing a plugin."""
class PluginImportError(PluginError):
def __init__(self, plugin: Any) -> None:
super().__init__(f"Plugin import failed: {plugin!r}")
class PluginInitError(PluginError):
def __init__(self, plugin: Any) -> None:
super().__init__(f"Plugin init failed: {plugin!r}")
class ClientError(Exception): class ClientError(Exception):
"""Exceptions thrown by client.""" """Exceptions thrown by client."""

Wyświetl plik

@ -8,6 +8,8 @@ from importlib.metadata import EntryPoint, EntryPoints, entry_points
import logging import logging
from typing import Any, NamedTuple from typing import Any, NamedTuple
from amqtt.errors import PluginImportError, PluginInitError
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -80,17 +82,21 @@ class PluginManager:
try: try:
self.logger.debug(f" Loading plugin {ep!s}") self.logger.debug(f" Loading plugin {ep!s}")
plugin = ep.load() plugin = ep.load()
self.logger.debug(f" Initializing plugin {ep!s}")
plugin_context = copy.copy(self.app_context) except ImportError as e:
plugin_context.logger = self.logger.getChild(ep.name) self.logger.debug(f"Plugin import failed: {ep!r}", exc_info=True)
raise PluginImportError(ep) from e
self.logger.debug(f" Initializing plugin {ep!s}")
plugin_context = copy.copy(self.app_context)
plugin_context.logger = self.logger.getChild(ep.name)
try:
obj = plugin(plugin_context) obj = plugin(plugin_context)
return Plugin(ep.name, ep, obj) return Plugin(ep.name, ep, obj)
except ImportError: except Exception as e:
self.logger.warning(f"Plugin {ep!r} import failed") self.logger.debug(f"Plugin init failed: {ep!r}", exc_info=True)
self.logger.debug("", exc_info=True) raise PluginInitError(ep) from e
return None
def get_plugin(self, name: str) -> Plugin | None: def get_plugin(self, name: str) -> Plugin | None:
"""Get a plugin by its name from the plugins loaded for the current namespace. """Get a plugin by its name from the plugins loaded for the current namespace.

Wyświetl plik

@ -7,7 +7,7 @@ from yaml.parser import ParserError
from amqtt import __version__ as amqtt_version from amqtt import __version__ as amqtt_version
from amqtt.broker import Broker from amqtt.broker import Broker
from amqtt.errors import BrokerError from amqtt.errors import BrokerError, PluginError
from amqtt.utils import read_yaml_config from amqtt.utils import read_yaml_config
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -58,7 +58,7 @@ def broker_main(
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
try: try:
broker = Broker(config) broker = Broker(config)
except (BrokerError, ParserError) as exc: except (BrokerError, ParserError, PluginError) as exc:
typer.echo(f"❌ Broker failed to start: {exc}", err=True) typer.echo(f"❌ Broker failed to start: {exc}", err=True)
raise typer.Exit(code=1) from exc raise typer.Exit(code=1) from exc

Wyświetl plik

@ -185,7 +185,7 @@ timeout = 10
asyncio_default_fixture_loop_scope = "function" asyncio_default_fixture_loop_scope = "function"
#addopts = ["--tb=short", "--capture=tee-sys"] #addopts = ["--tb=short", "--capture=tee-sys"]
#log_cli = true #log_cli = true
#log_level = "DEBUG" log_level = "DEBUG"
# ------------------------------------ MYPY ------------------------------------ # ------------------------------------ MYPY ------------------------------------
[tool.mypy] [tool.mypy]

Wyświetl plik

@ -0,0 +1,10 @@
from amqtt.broker import BrokerContext
from amqtt.plugins.base import BasePlugin
# intentional import error to test broker response
from pathlib import Pat # noqa
class MockImportErrorPlugin(BasePlugin):
def __init__(self, context: BrokerContext) -> None:
super().__init__(context)

Wyświetl plik

@ -1,12 +1,17 @@
import inspect import inspect
from importlib.metadata import EntryPoint
from logging import getLogger from logging import getLogger
from pathlib import Path from pathlib import Path
from types import ModuleType from types import ModuleType
from typing import Any from typing import Any
from unittest.mock import patch
import pytest import pytest
import amqtt.plugins import amqtt.plugins
from amqtt.broker import Broker, BrokerContext
from amqtt.errors import PluginError, PluginInitError, PluginImportError
from amqtt.plugins.base import BasePlugin
from amqtt.plugins.manager import BaseContext from amqtt.plugins.manager import BaseContext
_INVALID_METHOD: str = "invalid_foo" _INVALID_METHOD: str = "invalid_foo"
@ -61,3 +66,62 @@ def test_plugins_correct_has_attr() -> None:
__import__(name) __import__(name)
_verify_module(module, module.__name__) _verify_module(module, module.__name__)
class MockInitErrorPlugin(BasePlugin):
def __init__(self, context: BrokerContext) -> None:
super().__init__(context)
raise KeyError
@pytest.mark.asyncio
async def test_plugin_exception_while_init() -> None:
class MockEntryPoints:
def select(self, group) -> list[EntryPoint]:
match group:
case 'tests.mock_plugins':
return [
EntryPoint(name='TestExceptionPlugin', group='tests.mock_plugins', value='tests.plugins.test_plugins:MockInitErrorPlugin'),
]
case _:
return list()
with patch("amqtt.plugins.manager.entry_points", side_effect=MockEntryPoints) as mocked_mqtt_publish:
config = {
"listeners": {
"default": {"type": "tcp", "bind": "127.0.0.1:1883", "max_connections": 10},
},
'sys_interval': 1
}
with pytest.raises(PluginInitError):
_ = Broker(plugin_namespace='tests.mock_plugins', config=config)
@pytest.mark.asyncio
async def test_plugin_exception_while_loading() -> None:
class MockEntryPoints:
def select(self, group) -> list[EntryPoint]:
match group:
case 'tests.mock_plugins':
return [
EntryPoint(name='TestExceptionPlugin', group='tests.mock_plugins', value='tests.plugins.mock_plugins:MockImportErrorPlugin'),
]
case _:
return list()
with patch("amqtt.plugins.manager.entry_points", side_effect=MockEntryPoints) as mocked_mqtt_publish:
config = {
"listeners": {
"default": {"type": "tcp", "bind": "127.0.0.1:1883", "max_connections": 10},
},
'sys_interval': 1
}
with pytest.raises(PluginImportError):
_ = Broker(plugin_namespace='tests.mock_plugins', config=config)