2022-02-05 19:44:30 +00:00
|
|
|
import inspect
|
2025-06-09 04:11:06 +00:00
|
|
|
from importlib.metadata import EntryPoint
|
2022-02-05 19:44:30 +00:00
|
|
|
from logging import getLogger
|
2024-12-29 18:23:27 +00:00
|
|
|
from pathlib import Path
|
|
|
|
from types import ModuleType
|
|
|
|
from typing import Any
|
2025-06-09 04:11:06 +00:00
|
|
|
from unittest.mock import patch
|
2022-02-05 19:44:30 +00:00
|
|
|
|
|
|
|
import pytest
|
|
|
|
|
|
|
|
import amqtt.plugins
|
2025-06-09 04:11:06 +00:00
|
|
|
from amqtt.broker import Broker, BrokerContext
|
|
|
|
from amqtt.errors import PluginError, PluginInitError, PluginImportError
|
|
|
|
from amqtt.plugins.base import BasePlugin
|
2022-02-05 19:44:30 +00:00
|
|
|
from amqtt.plugins.manager import BaseContext
|
|
|
|
|
2024-12-29 18:23:27 +00:00
|
|
|
# Suffix that marks a class as a plugin for ``_verify_module``: classes whose
# name ends with it (other than the bare name itself) get instantiated and checked.
_PLUGIN: str = "Plugin"
# Attribute expected to be undefined on every plugin; looking it up must raise
# AttributeError so that ``hasattr`` reports False.
_INVALID_METHOD: str = "invalid_foo"
|
2022-02-05 19:44:30 +00:00
|
|
|
|
|
|
|
|
|
|
|
class _TestContext(BaseContext):
    """Minimal plugin context for tests: an empty ``auth`` config plus a logger."""

    def __init__(self) -> None:
        # Let BaseContext initialise first, then attach the test config and logger.
        super().__init__()
        self.config: dict[str, Any] = dict(auth={})
        self.logger = getLogger(__name__)
|
2022-02-05 19:44:30 +00:00
|
|
|
|
|
|
|
|
2024-12-29 18:23:27 +00:00
|
|
|
def _verify_module(module: ModuleType, plugin_module_name: str) -> None:
|
2022-02-05 19:44:30 +00:00
|
|
|
if not module.__name__.startswith(plugin_module_name):
|
|
|
|
return
|
|
|
|
|
|
|
|
for name, clazz in inspect.getmembers(module, inspect.isclass):
|
|
|
|
if not name.endswith(_PLUGIN) or name == _PLUGIN:
|
|
|
|
continue
|
|
|
|
|
|
|
|
obj = clazz(_TestContext())
|
2022-02-05 20:08:17 +00:00
|
|
|
with pytest.raises(
|
|
|
|
AttributeError,
|
|
|
|
match=f"'{name}' object has no attribute '{_INVALID_METHOD}'",
|
|
|
|
):
|
2022-02-05 19:44:30 +00:00
|
|
|
getattr(obj, _INVALID_METHOD)
|
|
|
|
assert hasattr(obj, _INVALID_METHOD) is False
|
|
|
|
|
2024-12-29 18:23:27 +00:00
|
|
|
for _, obj in inspect.getmembers(module, inspect.ismodule):
|
2022-02-05 19:44:30 +00:00
|
|
|
_verify_module(obj, plugin_module_name)
|
|
|
|
|
|
|
|
|
2022-02-06 13:39:05 +00:00
|
|
|
def removesuffix(self: str, suffix: str) -> str:
    """Return *self* with *suffix* removed from its end, if present.

    Kept as a thin wrapper for existing callers. It delegates to
    :meth:`str.removesuffix` (Python 3.9+); the hand-rolled pre-3.9 fallback is
    unnecessary because this module already relies on 3.9+ features such as
    builtin generic annotations.

    Args:
        self: String to trim.
        suffix: Suffix to strip; an empty suffix leaves *self* unchanged.

    Returns:
        The trimmed string, or an equal string when *self* does not end with
        *suffix*.
    """
    return self.removesuffix(suffix)
|
2022-02-06 13:39:05 +00:00
|
|
|
|
|
|
|
|
2024-12-29 18:23:27 +00:00
|
|
|
def test_plugins_correct_has_attr() -> None:
    """Test plugins to ensure they correctly handle the 'has_attr' check.

    Every ``*.py`` file under the ``amqtt.plugins`` package directory is
    imported, which binds each submodule on its parent package. The package
    tree is then verified once with ``_verify_module``.
    """
    module = amqtt.plugins
    package_dir = Path(module.__file__).parent
    for file in package_dir.rglob("*.py"):
        if not file.is_file():
            continue

        # Build the dotted name from the path *relative to the package dir*.
        # Searching the absolute path for the package name breaks when an
        # ancestor directory happens to contain that text.
        parts = [module.__name__, *file.relative_to(package_dir).with_suffix("").parts]
        if parts[-1] == "__init__":
            parts.pop()  # a package's __init__.py is imported as the package itself

        __import__(".".join(parts))

    _verify_module(module, module.__name__)
|
2025-06-09 04:11:06 +00:00
|
|
|
|
|
|
|
|
|
|
|
class MockInitErrorPlugin(BasePlugin):
    """Plugin whose constructor always fails.

    ``BasePlugin.__init__`` runs first, then a bare ``KeyError`` is raised.
    ``test_plugin_exception_while_init`` loads it through a mocked entry point
    and expects the broker to report ``PluginInitError``.
    """

    def __init__(self, context: BrokerContext) -> None:
        super().__init__(context)
        raise KeyError()
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_plugin_exception_while_init() -> None:
    """A plugin whose constructor raises must make ``Broker`` raise ``PluginInitError``.

    ``entry_points`` in the plugin manager is patched so the
    ``tests.mock_plugins`` namespace resolves to ``MockInitErrorPlugin``, whose
    ``__init__`` raises ``KeyError``.
    """

    class MockEntryPoints:
        """Stand-in for the object returned by ``entry_points()``."""

        def select(self, group: str) -> list[EntryPoint]:
            # Only the test namespace has entry points; every other group is empty.
            if group == "tests.mock_plugins":
                return [
                    EntryPoint(
                        name="TestExceptionPlugin",
                        group="tests.mock_plugins",
                        value="tests.plugins.test_plugins:MockInitErrorPlugin",
                    ),
                ]
            return []

    config = {
        "listeners": {
            "default": {"type": "tcp", "bind": "127.0.0.1:1883", "max_connections": 10},
        },
        "sys_interval": 1,
    }

    # side_effect=<class>: every entry_points() call returns a new MockEntryPoints().
    with (
        patch("amqtt.plugins.manager.entry_points", side_effect=MockEntryPoints),
        pytest.raises(PluginInitError),
    ):
        _ = Broker(plugin_namespace="tests.mock_plugins", config=config)
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_plugin_exception_while_loading() -> None:
    """An entry point that cannot be loaded must make ``Broker`` raise ``PluginImportError``.

    ``entry_points`` in the plugin manager is patched so the
    ``tests.mock_plugins`` namespace resolves to
    ``tests.plugins.mock_plugins:MockImportErrorPlugin``. That target presumably
    fails to import/resolve (TODO confirm against tests/plugins/mock_plugins.py).
    """

    class MockEntryPoints:
        """Stand-in for the object returned by ``entry_points()``."""

        def select(self, group: str) -> list[EntryPoint]:
            # Only the test namespace has entry points; every other group is empty.
            if group == "tests.mock_plugins":
                return [
                    EntryPoint(
                        name="TestExceptionPlugin",
                        group="tests.mock_plugins",
                        value="tests.plugins.mock_plugins:MockImportErrorPlugin",
                    ),
                ]
            return []

    config = {
        "listeners": {
            "default": {"type": "tcp", "bind": "127.0.0.1:1883", "max_connections": 10},
        },
        "sys_interval": 1,
    }

    # side_effect=<class>: every entry_points() call returns a new MockEntryPoints().
    with (
        patch("amqtt.plugins.manager.entry_points", side_effect=MockEntryPoints),
        pytest.raises(PluginImportError),
    ):
        _ = Broker(plugin_namespace="tests.mock_plugins", config=config)
|