kopia lustrzana https://github.com/Yakifo/amqtt
fixes Yakifo/amqtt#51 : plugins which fail to import or load should stop the broker from starting
rodzic
80016d8cca
commit
2caa792d69
|
@ -145,6 +145,9 @@ class Broker:
|
||||||
loop: asyncio loop. defaults to `asyncio.get_event_loop()`.
|
loop: asyncio loop. defaults to `asyncio.get_event_loop()`.
|
||||||
plugin_namespace: plugin namespace to use when loading plugin entry_points. defaults to `amqtt.broker.plugins`.
|
plugin_namespace: plugin namespace to use when loading plugin entry_points. defaults to `amqtt.broker.plugins`.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
BrokerError, ParserError, PluginError
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
states: ClassVar[list[str]] = [
|
states: ClassVar[list[str]] = [
|
||||||
|
|
|
@ -88,6 +88,9 @@ class MQTTClient:
|
||||||
it will be generated randomly by `amqtt.utils.gen_client_id`
|
it will be generated randomly by `amqtt.utils.gen_client_id`
|
||||||
config: dictionary of configuration options (see [client configuration](client_config.md)).
|
config: dictionary of configuration options (see [client configuration](client_config.md)).
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
PluginError
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, client_id: str | None = None, config: dict[str, Any] | None = None) -> None:
|
def __init__(self, client_id: str | None = None, config: dict[str, Any] | None = None) -> None:
|
||||||
|
@ -142,7 +145,7 @@ class MQTTClient:
|
||||||
[CONNACK](http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc398718033)'s return code
|
[CONNACK](http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc398718033)'s return code
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
amqtt.client.ConnectException: if connection fails
|
ClientError, ConnectError
|
||||||
|
|
||||||
"""
|
"""
|
||||||
additional_headers = additional_headers if additional_headers is not None else {}
|
additional_headers = additional_headers if additional_headers is not None else {}
|
||||||
|
|
|
@ -1,3 +1,6 @@
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
class AMQTTError(Exception):
|
class AMQTTError(Exception):
|
||||||
"""aMQTT base exception."""
|
"""aMQTT base exception."""
|
||||||
|
|
||||||
|
@ -18,6 +21,20 @@ class BrokerError(Exception):
|
||||||
"""Exceptions thrown by broker."""
|
"""Exceptions thrown by broker."""
|
||||||
|
|
||||||
|
|
||||||
|
class PluginError(Exception):
|
||||||
|
"""Exceptions thrown when loading or initializing a plugin."""
|
||||||
|
|
||||||
|
|
||||||
|
class PluginImportError(PluginError):
|
||||||
|
def __init__(self, plugin: Any) -> None:
|
||||||
|
super().__init__(f"Plugin import failed: {plugin!r}")
|
||||||
|
|
||||||
|
|
||||||
|
class PluginInitError(PluginError):
|
||||||
|
def __init__(self, plugin: Any) -> None:
|
||||||
|
super().__init__(f"Plugin init failed: {plugin!r}")
|
||||||
|
|
||||||
|
|
||||||
class ClientError(Exception):
|
class ClientError(Exception):
|
||||||
"""Exceptions thrown by client."""
|
"""Exceptions thrown by client."""
|
||||||
|
|
||||||
|
|
|
@ -8,6 +8,8 @@ from importlib.metadata import EntryPoint, EntryPoints, entry_points
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, NamedTuple
|
from typing import Any, NamedTuple
|
||||||
|
|
||||||
|
from amqtt.errors import PluginImportError, PluginInitError
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@ -80,17 +82,21 @@ class PluginManager:
|
||||||
try:
|
try:
|
||||||
self.logger.debug(f" Loading plugin {ep!s}")
|
self.logger.debug(f" Loading plugin {ep!s}")
|
||||||
plugin = ep.load()
|
plugin = ep.load()
|
||||||
self.logger.debug(f" Initializing plugin {ep!s}")
|
|
||||||
|
|
||||||
plugin_context = copy.copy(self.app_context)
|
except ImportError as e:
|
||||||
plugin_context.logger = self.logger.getChild(ep.name)
|
self.logger.debug(f"Plugin import failed: {ep!r}", exc_info=True)
|
||||||
|
raise PluginImportError(ep) from e
|
||||||
|
|
||||||
|
self.logger.debug(f" Initializing plugin {ep!s}")
|
||||||
|
|
||||||
|
plugin_context = copy.copy(self.app_context)
|
||||||
|
plugin_context.logger = self.logger.getChild(ep.name)
|
||||||
|
try:
|
||||||
obj = plugin(plugin_context)
|
obj = plugin(plugin_context)
|
||||||
return Plugin(ep.name, ep, obj)
|
return Plugin(ep.name, ep, obj)
|
||||||
except ImportError:
|
except Exception as e:
|
||||||
self.logger.warning(f"Plugin {ep!r} import failed")
|
self.logger.debug(f"Plugin init failed: {ep!r}", exc_info=True)
|
||||||
self.logger.debug("", exc_info=True)
|
raise PluginInitError(ep) from e
|
||||||
|
|
||||||
return None
|
|
||||||
|
|
||||||
def get_plugin(self, name: str) -> Plugin | None:
|
def get_plugin(self, name: str) -> Plugin | None:
|
||||||
"""Get a plugin by its name from the plugins loaded for the current namespace.
|
"""Get a plugin by its name from the plugins loaded for the current namespace.
|
||||||
|
|
|
@ -7,7 +7,7 @@ from yaml.parser import ParserError
|
||||||
|
|
||||||
from amqtt import __version__ as amqtt_version
|
from amqtt import __version__ as amqtt_version
|
||||||
from amqtt.broker import Broker
|
from amqtt.broker import Broker
|
||||||
from amqtt.errors import BrokerError
|
from amqtt.errors import BrokerError, PluginError
|
||||||
from amqtt.utils import read_yaml_config
|
from amqtt.utils import read_yaml_config
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
@ -58,7 +58,7 @@ def broker_main(
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
try:
|
try:
|
||||||
broker = Broker(config)
|
broker = Broker(config)
|
||||||
except (BrokerError, ParserError) as exc:
|
except (BrokerError, ParserError, PluginError) as exc:
|
||||||
typer.echo(f"❌ Broker failed to start: {exc}", err=True)
|
typer.echo(f"❌ Broker failed to start: {exc}", err=True)
|
||||||
raise typer.Exit(code=1) from exc
|
raise typer.Exit(code=1) from exc
|
||||||
|
|
||||||
|
|
|
@ -185,7 +185,7 @@ timeout = 10
|
||||||
asyncio_default_fixture_loop_scope = "function"
|
asyncio_default_fixture_loop_scope = "function"
|
||||||
#addopts = ["--tb=short", "--capture=tee-sys"]
|
#addopts = ["--tb=short", "--capture=tee-sys"]
|
||||||
#log_cli = true
|
#log_cli = true
|
||||||
#log_level = "DEBUG"
|
log_level = "DEBUG"
|
||||||
|
|
||||||
# ------------------------------------ MYPY ------------------------------------
|
# ------------------------------------ MYPY ------------------------------------
|
||||||
[tool.mypy]
|
[tool.mypy]
|
||||||
|
|
|
@ -0,0 +1,10 @@
|
||||||
|
from amqtt.broker import BrokerContext
|
||||||
|
from amqtt.plugins.base import BasePlugin
|
||||||
|
|
||||||
|
# intentional import error to test broker response
|
||||||
|
from pathlib import Pat # noqa
|
||||||
|
|
||||||
|
class MockImportErrorPlugin(BasePlugin):
|
||||||
|
|
||||||
|
def __init__(self, context: BrokerContext) -> None:
|
||||||
|
super().__init__(context)
|
|
@ -1,12 +1,17 @@
|
||||||
import inspect
|
import inspect
|
||||||
|
from importlib.metadata import EntryPoint
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from types import ModuleType
|
from types import ModuleType
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
import amqtt.plugins
|
import amqtt.plugins
|
||||||
|
from amqtt.broker import Broker, BrokerContext
|
||||||
|
from amqtt.errors import PluginError, PluginInitError, PluginImportError
|
||||||
|
from amqtt.plugins.base import BasePlugin
|
||||||
from amqtt.plugins.manager import BaseContext
|
from amqtt.plugins.manager import BaseContext
|
||||||
|
|
||||||
_INVALID_METHOD: str = "invalid_foo"
|
_INVALID_METHOD: str = "invalid_foo"
|
||||||
|
@ -61,3 +66,62 @@ def test_plugins_correct_has_attr() -> None:
|
||||||
__import__(name)
|
__import__(name)
|
||||||
|
|
||||||
_verify_module(module, module.__name__)
|
_verify_module(module, module.__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class MockInitErrorPlugin(BasePlugin):
|
||||||
|
|
||||||
|
def __init__(self, context: BrokerContext) -> None:
|
||||||
|
super().__init__(context)
|
||||||
|
raise KeyError
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_plugin_exception_while_init() -> None:
|
||||||
|
class MockEntryPoints:
|
||||||
|
|
||||||
|
def select(self, group) -> list[EntryPoint]:
|
||||||
|
match group:
|
||||||
|
case 'tests.mock_plugins':
|
||||||
|
return [
|
||||||
|
EntryPoint(name='TestExceptionPlugin', group='tests.mock_plugins', value='tests.plugins.test_plugins:MockInitErrorPlugin'),
|
||||||
|
]
|
||||||
|
case _:
|
||||||
|
return list()
|
||||||
|
|
||||||
|
with patch("amqtt.plugins.manager.entry_points", side_effect=MockEntryPoints) as mocked_mqtt_publish:
|
||||||
|
|
||||||
|
config = {
|
||||||
|
"listeners": {
|
||||||
|
"default": {"type": "tcp", "bind": "127.0.0.1:1883", "max_connections": 10},
|
||||||
|
},
|
||||||
|
'sys_interval': 1
|
||||||
|
}
|
||||||
|
|
||||||
|
with pytest.raises(PluginInitError):
|
||||||
|
_ = Broker(plugin_namespace='tests.mock_plugins', config=config)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_plugin_exception_while_loading() -> None:
|
||||||
|
class MockEntryPoints:
|
||||||
|
|
||||||
|
def select(self, group) -> list[EntryPoint]:
|
||||||
|
match group:
|
||||||
|
case 'tests.mock_plugins':
|
||||||
|
return [
|
||||||
|
EntryPoint(name='TestExceptionPlugin', group='tests.mock_plugins', value='tests.plugins.mock_plugins:MockImportErrorPlugin'),
|
||||||
|
]
|
||||||
|
case _:
|
||||||
|
return list()
|
||||||
|
|
||||||
|
with patch("amqtt.plugins.manager.entry_points", side_effect=MockEntryPoints) as mocked_mqtt_publish:
|
||||||
|
|
||||||
|
config = {
|
||||||
|
"listeners": {
|
||||||
|
"default": {"type": "tcp", "bind": "127.0.0.1:1883", "max_connections": 10},
|
||||||
|
},
|
||||||
|
'sys_interval': 1
|
||||||
|
}
|
||||||
|
|
||||||
|
with pytest.raises(PluginImportError):
|
||||||
|
_ = Broker(plugin_namespace='tests.mock_plugins', config=config)
|
||||||
|
|
Ładowanie…
Reference in New Issue