From 371e76fb308ecc4739c3190f97e8fa2368aea3ba Mon Sep 17 00:00:00 2001 From: Andrew Mirsky Date: Sun, 1 Jun 2025 11:09:40 -0400 Subject: [PATCH] adding test to cover adapter base classes. adding test for BrokerSysPlugin --- amqtt/plugins/sys/broker.py | 2 +- tests/plugins/test_sys.py | 66 +++++++++++++++++++++++++++++++++++++ tests/test_adapters.py | 50 ++++++++++++++++++++++++++++ 3 files changed, 117 insertions(+), 1 deletion(-) create mode 100644 tests/plugins/test_sys.py create mode 100644 tests/test_adapters.py diff --git a/amqtt/plugins/sys/broker.py b/amqtt/plugins/sys/broker.py index f7ede3f..996e399 100644 --- a/amqtt/plugins/sys/broker.py +++ b/amqtt/plugins/sys/broker.py @@ -104,7 +104,7 @@ class BrokerSysPlugin(BasePlugin): pass # 'sys_interval' config parameter not found - async def on_broker_pre_stop(self, *args: None, **kwargs: None) -> None: + async def on_broker_pre_shutdown(self, *args: None, **kwargs: None) -> None: """Stop $SYS topics broadcasting.""" if self._sys_handle: self._sys_handle.cancel() diff --git a/tests/plugins/test_sys.py b/tests/plugins/test_sys.py new file mode 100644 index 0000000..4ee0454 --- /dev/null +++ b/tests/plugins/test_sys.py @@ -0,0 +1,66 @@ +import asyncio +import inspect +import logging +from importlib.metadata import EntryPoint +from logging import getLogger +from pathlib import Path +from types import ModuleType +from typing import Any +from unittest.mock import patch + +import pytest + +from amqtt.broker import Broker +from amqtt.client import MQTTClient +from amqtt.mqtt.constants import QOS_0 + + +logger = logging.getLogger(__name__) + +# test broker sys +@pytest.mark.asyncio +async def test_broker_sys_plugin() -> None: + + class MockEntryPoints: + + def select(self, group) -> list[EntryPoint]: + match group: + case 'tests.mock_plugins': + return [ + EntryPoint(name='BrokerSysPlugin', group='tests.mock_plugins', value='amqtt.plugins.sys.broker:BrokerSysPlugin'), + ] + case _: + return list() + + + with patch("amqtt.plugins.manager.entry_points", side_effect=MockEntryPoints) as mocked_mqtt_publish: + + config = { + "listeners": { + "default": {"type": "tcp", "bind": "127.0.0.1:1883", "max_connections": 10}, + }, + 'sys_interval': 1 + } + + broker = Broker(plugin_namespace='tests.mock_plugins', config=config) + await broker.start() + client = MQTTClient() + await client.connect("mqtt://127.0.0.1:1883/") + await client.subscribe([("$SYS/broker/uptime", QOS_0),]) + await client.publish('test/topic', b'my test message') + await asyncio.sleep(2) + sys_msg_count = 0 + try: + while True: + message = await client.deliver_message(timeout_duration=0.5) + if '$SYS' in message.topic: + sys_msg_count += 1 + except TimeoutError: + pass + + logger.warning(f">>> sys message: {message.topic} - {message.data}") + await client.disconnect() + await broker.shutdown() + + + assert sys_msg_count > 1 diff --git a/tests/test_adapters.py b/tests/test_adapters.py new file mode 100644 index 0000000..c715f1b --- /dev/null +++ b/tests/test_adapters.py @@ -0,0 +1,50 @@ +import pytest + +from amqtt.adapters import ReaderAdapter, WriterAdapter + + +class BrokenReaderAdapter(ReaderAdapter): + + async def read(self, n: int = -1) -> bytes: + return await super().read(n) + + def feed_eof(self) -> None: + return super().feed_eof() + + +@pytest.mark.asyncio +async def test_abstract_read_raises(): + reader = BrokenReaderAdapter() + with pytest.raises(NotImplementedError): + await reader.read() + + with pytest.raises(NotImplementedError): + reader.feed_eof() + +class BrokerWriterAdapter(WriterAdapter): + def write(self, data: bytes) -> None: + super().write(data) + + async def drain(self) -> None: + await super().drain() + + def get_peer_info(self) -> tuple[str, int] | None: + return super().get_peer_info() + + async def close(self) -> None: + await super().close() + +@pytest.mark.asyncio +async def test_abstract_write_raises(): + writer = BrokerWriterAdapter() + with pytest.raises(NotImplementedError): + writer.write(b'') + + with pytest.raises(NotImplementedError): + await writer.drain() + + with pytest.raises(NotImplementedError): + writer.get_peer_info() + + with pytest.raises(NotImplementedError): + await writer.close()