fixes Yakifo/amqtt#51 : plugins which fail to import or load should stop the broker from starting

pull/203/head
Andrew Mirsky 2025-06-09 00:11:06 -04:00
rodzic 80016d8cca
commit 2caa792d69
Nie znaleziono w bazie danych klucza dla tego podpisu
ID klucza GPG: A98E67635CDF2C39
8 zmienionych plików z 115 dodań i 12 usunięć

Wyświetl plik

@ -145,6 +145,9 @@ class Broker:
loop: asyncio loop. defaults to `asyncio.get_event_loop()`.
plugin_namespace: plugin namespace to use when loading plugin entry_points. defaults to `amqtt.broker.plugins`.
Raises:
BrokerError, ParserError, PluginError
"""
states: ClassVar[list[str]] = [

Wyświetl plik

@ -88,6 +88,9 @@ class MQTTClient:
it will be generated randomly by `amqtt.utils.gen_client_id`
config: dictionary of configuration options (see [client configuration](client_config.md)).
Raises:
PluginError
"""
def __init__(self, client_id: str | None = None, config: dict[str, Any] | None = None) -> None:
@ -142,7 +145,7 @@ class MQTTClient:
[CONNACK](http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc398718033)'s return code
Raises:
amqtt.client.ConnectException: if connection fails
ClientError, ConnectError
"""
additional_headers = additional_headers if additional_headers is not None else {}

Wyświetl plik

@ -1,3 +1,6 @@
from typing import Any
class AMQTTError(Exception):
"""aMQTT base exception."""
@ -18,6 +21,20 @@ class BrokerError(Exception):
"""Exceptions thrown by broker."""
class PluginError(Exception):
"""Exceptions thrown when loading or initializing a plugin."""
class PluginImportError(PluginError):
def __init__(self, plugin: Any) -> None:
super().__init__(f"Plugin import failed: {plugin!r}")
class PluginInitError(PluginError):
def __init__(self, plugin: Any) -> None:
super().__init__(f"Plugin init failed: {plugin!r}")
class ClientError(Exception):
"""Exceptions thrown by client."""

Wyświetl plik

@ -8,6 +8,8 @@ from importlib.metadata import EntryPoint, EntryPoints, entry_points
import logging
from typing import Any, NamedTuple
from amqtt.errors import PluginImportError, PluginInitError
_LOGGER = logging.getLogger(__name__)
@ -80,17 +82,21 @@ class PluginManager:
try:
self.logger.debug(f" Loading plugin {ep!s}")
plugin = ep.load()
self.logger.debug(f" Initializing plugin {ep!s}")
plugin_context = copy.copy(self.app_context)
plugin_context.logger = self.logger.getChild(ep.name)
except ImportError as e:
self.logger.debug(f"Plugin import failed: {ep!r}", exc_info=True)
raise PluginImportError(ep) from e
self.logger.debug(f" Initializing plugin {ep!s}")
plugin_context = copy.copy(self.app_context)
plugin_context.logger = self.logger.getChild(ep.name)
try:
obj = plugin(plugin_context)
return Plugin(ep.name, ep, obj)
except ImportError:
self.logger.warning(f"Plugin {ep!r} import failed")
self.logger.debug("", exc_info=True)
return None
except Exception as e:
self.logger.debug(f"Plugin init failed: {ep!r}", exc_info=True)
raise PluginInitError(ep) from e
def get_plugin(self, name: str) -> Plugin | None:
"""Get a plugin by its name from the plugins loaded for the current namespace.

Wyświetl plik

@ -7,7 +7,7 @@ from yaml.parser import ParserError
from amqtt import __version__ as amqtt_version
from amqtt.broker import Broker
from amqtt.errors import BrokerError
from amqtt.errors import BrokerError, PluginError
from amqtt.utils import read_yaml_config
logger = logging.getLogger(__name__)
@ -58,7 +58,7 @@ def broker_main(
loop = asyncio.get_event_loop()
try:
broker = Broker(config)
except (BrokerError, ParserError) as exc:
except (BrokerError, ParserError, PluginError) as exc:
typer.echo(f"❌ Broker failed to start: {exc}", err=True)
raise typer.Exit(code=1) from exc

Wyświetl plik

@ -185,7 +185,7 @@ timeout = 10
asyncio_default_fixture_loop_scope = "function"
#addopts = ["--tb=short", "--capture=tee-sys"]
#log_cli = true
#log_level = "DEBUG"
log_level = "DEBUG"
# ------------------------------------ MYPY ------------------------------------
[tool.mypy]

Wyświetl plik

@ -0,0 +1,10 @@
from amqtt.broker import BrokerContext
from amqtt.plugins.base import BasePlugin
# intentional import error to test broker response
from pathlib import Pat # noqa
class MockImportErrorPlugin(BasePlugin):
def __init__(self, context: BrokerContext) -> None:
super().__init__(context)

Wyświetl plik

@ -1,12 +1,17 @@
import inspect
from importlib.metadata import EntryPoint
from logging import getLogger
from pathlib import Path
from types import ModuleType
from typing import Any
from unittest.mock import patch
import pytest
import amqtt.plugins
from amqtt.broker import Broker, BrokerContext
from amqtt.errors import PluginError, PluginInitError, PluginImportError
from amqtt.plugins.base import BasePlugin
from amqtt.plugins.manager import BaseContext
_INVALID_METHOD: str = "invalid_foo"
@ -61,3 +66,62 @@ def test_plugins_correct_has_attr() -> None:
__import__(name)
_verify_module(module, module.__name__)
class MockInitErrorPlugin(BasePlugin):
def __init__(self, context: BrokerContext) -> None:
super().__init__(context)
raise KeyError
@pytest.mark.asyncio
async def test_plugin_exception_while_init() -> None:
class MockEntryPoints:
def select(self, group) -> list[EntryPoint]:
match group:
case 'tests.mock_plugins':
return [
EntryPoint(name='TestExceptionPlugin', group='tests.mock_plugins', value='tests.plugins.test_plugins:MockInitErrorPlugin'),
]
case _:
return list()
with patch("amqtt.plugins.manager.entry_points", side_effect=MockEntryPoints) as mocked_mqtt_publish:
config = {
"listeners": {
"default": {"type": "tcp", "bind": "127.0.0.1:1883", "max_connections": 10},
},
'sys_interval': 1
}
with pytest.raises(PluginInitError):
_ = Broker(plugin_namespace='tests.mock_plugins', config=config)
@pytest.mark.asyncio
async def test_plugin_exception_while_loading() -> None:
class MockEntryPoints:
def select(self, group) -> list[EntryPoint]:
match group:
case 'tests.mock_plugins':
return [
EntryPoint(name='TestExceptionPlugin', group='tests.mock_plugins', value='tests.plugins.mock_plugins:MockImportErrorPlugin'),
]
case _:
return list()
with patch("amqtt.plugins.manager.entry_points", side_effect=MockEntryPoints) as mocked_mqtt_publish:
config = {
"listeners": {
"default": {"type": "tcp", "bind": "127.0.0.1:1883", "max_connections": 10},
},
'sys_interval': 1
}
with pytest.raises(PluginImportError):
_ = Broker(plugin_namespace='tests.mock_plugins', config=config)