amqtt/tests/plugins/test_plugins.py

200 wiersze
6.7 KiB
Python
Czysty Zwykły widok Historia

import asyncio
2022-02-05 19:44:30 +00:00
import inspect
from functools import partial
from importlib.metadata import EntryPoint
2022-02-05 19:44:30 +00:00
from logging import getLogger
2024-12-29 18:23:27 +00:00
from pathlib import Path
from types import ModuleType
from typing import Any, Callable, Coroutine
from unittest.mock import patch
2022-02-05 19:44:30 +00:00
import pytest
import amqtt.plugins
from amqtt.broker import Broker, BrokerContext
from amqtt.client import MQTTClient
from amqtt.errors import PluginError, PluginInitError, PluginImportError
from amqtt.events import MQTTEvents, BrokerEvents
from amqtt.mqtt.constants import QOS_0
from amqtt.plugins.base import BasePlugin
2022-02-05 19:44:30 +00:00
from amqtt.plugins.manager import BaseContext
2024-12-29 18:23:27 +00:00
_INVALID_METHOD: str = "invalid_foo"
_PLUGIN: str = "Plugin"
2022-02-05 19:44:30 +00:00
class _TestContext(BaseContext):
def __init__(self) -> None:
2022-02-05 19:44:30 +00:00
super().__init__()
2024-12-29 18:23:27 +00:00
self.config: dict[str, Any] = {"auth": {}}
self.logger = getLogger(__name__)
2022-02-05 19:44:30 +00:00
2024-12-29 18:23:27 +00:00
def _verify_module(module: ModuleType, plugin_module_name: str) -> None:
2022-02-05 19:44:30 +00:00
if not module.__name__.startswith(plugin_module_name):
return
for name, clazz in inspect.getmembers(module, inspect.isclass):
if not name.endswith(_PLUGIN) or name == _PLUGIN:
continue
obj = clazz(_TestContext())
2022-02-05 20:08:17 +00:00
with pytest.raises(
AttributeError,
match=f"'{name}' object has no attribute '{_INVALID_METHOD}'",
):
2022-02-05 19:44:30 +00:00
getattr(obj, _INVALID_METHOD)
assert hasattr(obj, _INVALID_METHOD) is False
2024-12-29 18:23:27 +00:00
for _, obj in inspect.getmembers(module, inspect.ismodule):
2022-02-05 19:44:30 +00:00
_verify_module(obj, plugin_module_name)
2022-02-06 13:39:05 +00:00
def removesuffix(self: str, suffix: str) -> str:
2024-12-29 18:23:27 +00:00
"""Compatibility for Python versions prior to 3.9."""
2022-02-06 13:39:05 +00:00
if suffix and self.endswith(suffix):
return self[: -len(suffix)]
return self[:]
2022-02-06 13:39:05 +00:00
2024-12-29 18:23:27 +00:00
def test_plugins_correct_has_attr() -> None:
"""Test plugins to ensure they correctly handle the 'has_attr' check."""
2022-02-05 19:44:30 +00:00
module = amqtt.plugins
2024-12-29 18:23:27 +00:00
for file in Path(module.__file__).parent.rglob("*.py"):
if not Path(file).is_file():
2022-02-05 19:44:30 +00:00
continue
2024-12-29 18:23:27 +00:00
name = file.as_posix().replace("/", ".")
2022-02-06 13:39:05 +00:00
name = name[name.find(module.__name__) : -3]
name = removesuffix(name, ".__init__")
2022-02-05 19:44:30 +00:00
__import__(name)
_verify_module(module, module.__name__)
class MockInitErrorPlugin(BasePlugin):
def __init__(self, context: BrokerContext) -> None:
super().__init__(context)
raise KeyError
@pytest.mark.asyncio
async def test_plugin_exception_while_init() -> None:
class MockEntryPoints:
def select(self, group) -> list[EntryPoint]:
match group:
case 'tests.mock_plugins':
return [
EntryPoint(name='TestExceptionPlugin', group='tests.mock_plugins', value='tests.plugins.test_plugins:MockInitErrorPlugin'),
]
case _:
return list()
with patch("amqtt.plugins.manager.entry_points", side_effect=MockEntryPoints) as mocked_mqtt_publish:
config = {
"listeners": {
"default": {"type": "tcp", "bind": "127.0.0.1:1883", "max_connections": 10},
},
'sys_interval': 1
}
with pytest.raises(PluginInitError):
_ = Broker(plugin_namespace='tests.mock_plugins', config=config)
@pytest.mark.asyncio
async def test_plugin_exception_while_loading() -> None:
class MockEntryPoints:
def select(self, group) -> list[EntryPoint]:
match group:
case 'tests.mock_plugins':
return [
EntryPoint(name='TestExceptionPlugin', group='tests.mock_plugins', value='tests.plugins.mock_plugins:MockImportErrorPlugin'),
]
case _:
return list()
with patch("amqtt.plugins.manager.entry_points", side_effect=MockEntryPoints) as mocked_mqtt_publish:
config = {
"listeners": {
"default": {"type": "tcp", "bind": "127.0.0.1:1883", "max_connections": 10},
},
'sys_interval': 1
}
with pytest.raises(PluginImportError):
_ = Broker(plugin_namespace='tests.mock_plugins', config=config)
class TestAllEventsPlugin(BasePlugin[BaseContext]):
"""A plugin to verify all events get sent to plugins."""
def __init__(self, context: BaseContext) -> None:
super().__init__(context)
self.test_flags = { events:False for events in list(MQTTEvents) + list(BrokerEvents)}
async def call_method(self, event_name: str, **kwargs: Any) -> None:
assert event_name in self.test_flags
self.test_flags[event_name] = True
def __getattr__(self, name: str) -> Callable[..., Coroutine[Any, Any, None]]:
"""Dynamically handle calls to methods starting with 'on_'."""
if name.startswith("on_"):
event_name = name.replace('on_', '')
return partial(self.call_method, event_name)
if name not in ('authenticate', 'topic_filtering'):
pytest.fail(f'unexpected method called: {name}')
@pytest.mark.asyncio
async def test_all_plugin_events():
class MockEntryPoints:
def select(self, group) -> list[EntryPoint]:
match group:
case 'tests.mock_plugins':
return [
EntryPoint(name='TestAllEventsPlugin', group='tests.mock_plugins', value='tests.plugins.test_plugins:TestAllEventsPlugin'),
]
case _:
return list()
# patch the entry points so we can load our test plugin
with patch("amqtt.plugins.manager.entry_points", side_effect=MockEntryPoints) as mocked_mqtt_publish:
config = {
"listeners": {
"default": {"type": "tcp", "bind": "127.0.0.1:1883", "max_connections": 10},
},
'sys_interval': 1
}
broker = Broker(plugin_namespace='tests.mock_plugins', config=config)
await broker.start()
await asyncio.sleep(2)
# make sure all expected events get triggered
client = MQTTClient()
await client.connect("mqtt://127.0.0.1:1883/")
await client.subscribe([('my/test/topic', QOS_0),])
await client.publish('test/topic', b'my test message')
await client.unsubscribe(['my/test/topic',])
await client.disconnect()
await asyncio.sleep(1)
# get the plugin so it doesn't get gc on shutdown
test_plugin = broker.plugins_manager.get_plugin('TestAllEventsPlugin')
await broker.shutdown()
await asyncio.sleep(1)
assert all(test_plugin.test_flags.values()), f'event not received: {[event for event, value in test_plugin.test_flags.items() if not value]}'