kopia lustrzana https://github.com/Yakifo/amqtt
Merge pull request #203 from ajmirsky/issues/51
fixes Yakifo/amqtt#51 : plugins which fail to import or initializepull/200/head^2
commit
54af134832
|
@ -151,6 +151,9 @@ class Broker:
|
|||
loop: asyncio loop. defaults to `asyncio.new_event_loop()`.
|
||||
plugin_namespace: plugin namespace to use when loading plugin entry_points. defaults to `amqtt.broker.plugins`.
|
||||
|
||||
Raises:
|
||||
BrokerError, ParserError, PluginError
|
||||
|
||||
"""
|
||||
|
||||
states: ClassVar[list[str]] = [
|
||||
|
|
|
@ -88,6 +88,9 @@ class MQTTClient:
|
|||
it will be generated randomly by `amqtt.utils.gen_client_id`
|
||||
config: dictionary of configuration options (see [client configuration](client_config.md)).
|
||||
|
||||
Raises:
|
||||
PluginError
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, client_id: str | None = None, config: dict[str, Any] | None = None) -> None:
|
||||
|
@ -142,7 +145,7 @@ class MQTTClient:
|
|||
[CONNACK](http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc398718033)'s return code
|
||||
|
||||
Raises:
|
||||
amqtt.client.ConnectException: if connection fails
|
||||
ClientError, ConnectError
|
||||
|
||||
"""
|
||||
additional_headers = additional_headers if additional_headers is not None else {}
|
||||
|
|
|
@ -1,3 +1,6 @@
|
|||
from typing import Any
|
||||
|
||||
|
||||
class AMQTTError(Exception):
|
||||
"""aMQTT base exception."""
|
||||
|
||||
|
@ -21,6 +24,20 @@ class BrokerError(Exception):
|
|||
"""Exceptions thrown by broker."""
|
||||
|
||||
|
||||
class PluginError(Exception):
|
||||
"""Exceptions thrown when loading or initializing a plugin."""
|
||||
|
||||
|
||||
class PluginImportError(PluginError):
|
||||
def __init__(self, plugin: Any) -> None:
|
||||
super().__init__(f"Plugin import failed: {plugin!r}")
|
||||
|
||||
|
||||
class PluginInitError(PluginError):
|
||||
def __init__(self, plugin: Any) -> None:
|
||||
super().__init__(f"Plugin init failed: {plugin!r}")
|
||||
|
||||
|
||||
class ClientError(Exception):
|
||||
"""Exceptions thrown by client."""
|
||||
|
||||
|
|
|
@ -8,6 +8,8 @@ from importlib.metadata import EntryPoint, EntryPoints, entry_points
|
|||
import logging
|
||||
from typing import Any, NamedTuple
|
||||
|
||||
from amqtt.errors import PluginImportError, PluginInitError
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
@ -80,17 +82,21 @@ class PluginManager:
|
|||
try:
|
||||
self.logger.debug(f" Loading plugin {ep!s}")
|
||||
plugin = ep.load()
|
||||
self.logger.debug(f" Initializing plugin {ep!s}")
|
||||
|
||||
plugin_context = copy.copy(self.app_context)
|
||||
plugin_context.logger = self.logger.getChild(ep.name)
|
||||
except ImportError as e:
|
||||
self.logger.debug(f"Plugin import failed: {ep!r}", exc_info=True)
|
||||
raise PluginImportError(ep) from e
|
||||
|
||||
self.logger.debug(f" Initializing plugin {ep!s}")
|
||||
|
||||
plugin_context = copy.copy(self.app_context)
|
||||
plugin_context.logger = self.logger.getChild(ep.name)
|
||||
try:
|
||||
obj = plugin(plugin_context)
|
||||
return Plugin(ep.name, ep, obj)
|
||||
except ImportError:
|
||||
self.logger.warning(f"Plugin {ep!r} import failed")
|
||||
self.logger.debug("", exc_info=True)
|
||||
|
||||
return None
|
||||
except Exception as e:
|
||||
self.logger.debug(f"Plugin init failed: {ep!r}", exc_info=True)
|
||||
raise PluginInitError(ep) from e
|
||||
|
||||
def get_plugin(self, name: str) -> Plugin | None:
|
||||
"""Get a plugin by its name from the plugins loaded for the current namespace.
|
||||
|
|
|
@ -7,7 +7,7 @@ from yaml.parser import ParserError
|
|||
|
||||
from amqtt import __version__ as amqtt_version
|
||||
from amqtt.broker import Broker
|
||||
from amqtt.errors import BrokerError
|
||||
from amqtt.errors import BrokerError, PluginError
|
||||
from amqtt.utils import read_yaml_config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
@ -59,7 +59,7 @@ def broker_main(
|
|||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
broker = Broker(config)
|
||||
except (BrokerError, ParserError) as exc:
|
||||
except (BrokerError, ParserError, PluginError) as exc:
|
||||
typer.echo(f"❌ Broker failed to start: {exc}", err=True)
|
||||
raise typer.Exit(code=1) from exc
|
||||
|
||||
|
|
|
@ -191,7 +191,7 @@ timeout = 10
|
|||
asyncio_default_fixture_loop_scope = "function"
|
||||
#addopts = ["--tb=short", "--capture=tee-sys"]
|
||||
#log_cli = true
|
||||
#log_level = "DEBUG"
|
||||
log_level = "DEBUG"
|
||||
|
||||
# ------------------------------------ MYPY ------------------------------------
|
||||
[tool.mypy]
|
||||
|
|
|
@ -0,0 +1,10 @@
|
|||
from amqtt.broker import BrokerContext
|
||||
from amqtt.plugins.base import BasePlugin
|
||||
|
||||
# intentional import error to test broker response
|
||||
from pathlib import Pat # noqa
|
||||
|
||||
class MockImportErrorPlugin(BasePlugin):
|
||||
|
||||
def __init__(self, context: BrokerContext) -> None:
|
||||
super().__init__(context)
|
|
@ -1,12 +1,17 @@
|
|||
import inspect
|
||||
from importlib.metadata import EntryPoint
|
||||
from logging import getLogger
|
||||
from pathlib import Path
|
||||
from types import ModuleType
|
||||
from typing import Any
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
import amqtt.plugins
|
||||
from amqtt.broker import Broker, BrokerContext
|
||||
from amqtt.errors import PluginError, PluginInitError, PluginImportError
|
||||
from amqtt.plugins.base import BasePlugin
|
||||
from amqtt.plugins.manager import BaseContext
|
||||
|
||||
_INVALID_METHOD: str = "invalid_foo"
|
||||
|
@ -61,3 +66,62 @@ def test_plugins_correct_has_attr() -> None:
|
|||
__import__(name)
|
||||
|
||||
_verify_module(module, module.__name__)
|
||||
|
||||
|
||||
class MockInitErrorPlugin(BasePlugin):
|
||||
|
||||
def __init__(self, context: BrokerContext) -> None:
|
||||
super().__init__(context)
|
||||
raise KeyError
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_plugin_exception_while_init() -> None:
|
||||
class MockEntryPoints:
|
||||
|
||||
def select(self, group) -> list[EntryPoint]:
|
||||
match group:
|
||||
case 'tests.mock_plugins':
|
||||
return [
|
||||
EntryPoint(name='TestExceptionPlugin', group='tests.mock_plugins', value='tests.plugins.test_plugins:MockInitErrorPlugin'),
|
||||
]
|
||||
case _:
|
||||
return list()
|
||||
|
||||
with patch("amqtt.plugins.manager.entry_points", side_effect=MockEntryPoints) as mocked_mqtt_publish:
|
||||
|
||||
config = {
|
||||
"listeners": {
|
||||
"default": {"type": "tcp", "bind": "127.0.0.1:1883", "max_connections": 10},
|
||||
},
|
||||
'sys_interval': 1
|
||||
}
|
||||
|
||||
with pytest.raises(PluginInitError):
|
||||
_ = Broker(plugin_namespace='tests.mock_plugins', config=config)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_plugin_exception_while_loading() -> None:
|
||||
class MockEntryPoints:
|
||||
|
||||
def select(self, group) -> list[EntryPoint]:
|
||||
match group:
|
||||
case 'tests.mock_plugins':
|
||||
return [
|
||||
EntryPoint(name='TestExceptionPlugin', group='tests.mock_plugins', value='tests.plugins.mock_plugins:MockImportErrorPlugin'),
|
||||
]
|
||||
case _:
|
||||
return list()
|
||||
|
||||
with patch("amqtt.plugins.manager.entry_points", side_effect=MockEntryPoints) as mocked_mqtt_publish:
|
||||
|
||||
config = {
|
||||
"listeners": {
|
||||
"default": {"type": "tcp", "bind": "127.0.0.1:1883", "max_connections": 10},
|
||||
},
|
||||
'sys_interval': 1
|
||||
}
|
||||
|
||||
with pytest.raises(PluginImportError):
|
||||
_ = Broker(plugin_namespace='tests.mock_plugins', config=config)
|
||||
|
|
Ładowanie…
Reference in New Issue