diff --git a/.readthedocs.yaml b/.readthedocs.yaml index 2146144..62a145f 100644 --- a/.readthedocs.yaml +++ b/.readthedocs.yaml @@ -7,7 +7,9 @@ jobs: pre_install: - pip install --upgrade pip - - pip install --group docs + - pip install uv + - uv pip install --group dev --group docs + - uv run pytest mkdocs: configuration: mkdocs.rtd.yml diff --git a/README.md b/README.md index 72b302f..9953197 100644 --- a/README.md +++ b/README.md @@ -17,9 +17,7 @@ - Communication over TCP and/or websocket, including support for SSL/TLS - Support QoS 0, QoS 1 and QoS 2 messages flow - Client auto-reconnection on network lost -- Functionality expansion; plugins included: - - Authentication through password file - - Basic `$SYS` topics +- Functionality expansion; plugins included: authentication and `$SYS` topic publishing ## Installation diff --git a/amqtt/broker.py b/amqtt/broker.py index 34ef347..bc8814d 100644 --- a/amqtt/broker.py +++ b/amqtt/broker.py @@ -41,18 +41,24 @@ _defaults = read_yaml_config(Path(__file__).parent / "scripts/default_broker.yam DEFAULT_PORTS = {"tcp": 1883, "ws": 8883} AMQTT_MAGIC_VALUE_RET_SUBSCRIBED = 0x80 -EVENT_BROKER_PRE_START = "broker_pre_start" -EVENT_BROKER_POST_START = "broker_post_start" -EVENT_BROKER_PRE_SHUTDOWN = "broker_pre_shutdown" -EVENT_BROKER_POST_SHUTDOWN = "broker_post_shutdown" -EVENT_BROKER_CLIENT_CONNECTED = "broker_client_connected" -EVENT_BROKER_CLIENT_DISCONNECTED = "broker_client_disconnected" -EVENT_BROKER_CLIENT_SUBSCRIBED = "broker_client_subscribed" -EVENT_BROKER_CLIENT_UNSUBSCRIBED = "broker_client_unsubscribed" -EVENT_BROKER_MESSAGE_RECEIVED = "broker_message_received" + +class EventBroker(Enum): + """Events issued by the broker.""" + + PRE_START = "broker_pre_start" + POST_START = "broker_post_start" + PRE_SHUTDOWN = "broker_pre_shutdown" + POST_SHUTDOWN = "broker_post_shutdown" + CLIENT_CONNECTED = "broker_client_connected" + CLIENT_DISCONNECTED = "broker_client_disconnected" + CLIENT_SUBSCRIBED = "broker_client_subscribed" + CLIENT_UNSUBSCRIBED = "broker_client_unsubscribed" + MESSAGE_RECEIVED = "broker_message_received" class Action(Enum): + """Actions issued by the broker.""" + SUBSCRIBE = "subscribe" PUBLISH = "publish" @@ -142,7 +148,7 @@ class Broker: Args: config: dictionary of configuration options (see [broker configuration](broker_config.md)). - loop: asyncio loop. defaults to `asyncio.get_event_loop()`. + loop: asyncio loop. defaults to `asyncio.new_event_loop()`. plugin_namespace: plugin namespace to use when loading plugin entry_points. defaults to `amqtt.broker.plugins`. """ @@ -170,7 +176,7 @@ class Broker: self.config.update(config) self._build_listeners_config(self.config) - self._loop = loop or asyncio.get_event_loop() + self._loop = loop or asyncio.new_event_loop() self._servers: dict[str, Server] = {} self._init_states() self._sessions: dict[str, tuple[Session, BrokerProtocolHandler]] = {} @@ -242,11 +248,11 @@ class Broker: msg = f"Broker instance can't be started: {exc}" raise BrokerError(msg) from exc - await self.plugins_manager.fire_event(EVENT_BROKER_PRE_START) + await self.plugins_manager.fire_event(EventBroker.PRE_START.value) try: await self._start_listeners() self.transitions.starting_success() - await self.plugins_manager.fire_event(EVENT_BROKER_POST_START) + await self.plugins_manager.fire_event(EventBroker.POST_START.value) self._broadcast_task = asyncio.ensure_future(self._broadcast_loop()) self.logger.debug("Broker started") except Exception as e: @@ -327,7 +333,7 @@ class Broker: """Stop broker instance.""" self.logger.info("Shutting down broker...") # Fire broker_shutdown event to plugins - await self.plugins_manager.fire_event(EVENT_BROKER_PRE_SHUTDOWN) + await self.plugins_manager.fire_event(EventBroker.PRE_SHUTDOWN.value) # Cleanup all sessions for client_id in list(self._sessions.keys()): @@ -351,7 +357,7 @@ class Broker: self._broadcast_queue.get_nowait() self.logger.info("Broker closed") - await self.plugins_manager.fire_event(EVENT_BROKER_POST_SHUTDOWN) + await self.plugins_manager.fire_event(EventBroker.POST_SHUTDOWN.value) self.transitions.stopping_success() async def _cleanup_session(self, client_id: str) -> None: @@ -494,7 +500,7 @@ class Broker: self._sessions[client_session.client_id] = (client_session, handler) await handler.mqtt_connack_authorize(authenticated) - await self.plugins_manager.fire_event(EVENT_BROKER_CLIENT_CONNECTED, client_id=client_session.client_id) + await self.plugins_manager.fire_event(EventBroker.CLIENT_CONNECTED.value, client_id=client_session.client_id) self.logger.debug(f"{client_session.client_id} Start messages handling") await handler.start() @@ -582,7 +588,7 @@ class Broker: self.logger.debug(f"{client_session.client_id} Disconnecting session") await self._stop_handler(handler) client_session.transitions.disconnect() - await self.plugins_manager.fire_event(EVENT_BROKER_CLIENT_DISCONNECTED, client_id=client_session.client_id) + await self.plugins_manager.fire_event(EventBroker.CLIENT_DISCONNECTED.value, client_id=client_session.client_id) return False return True @@ -600,7 +606,7 @@ class Broker: for index, subscription in enumerate(subscriptions.topics): if return_codes[index] != AMQTT_MAGIC_VALUE_RET_SUBSCRIBED: await self.plugins_manager.fire_event( - EVENT_BROKER_CLIENT_SUBSCRIBED, + EventBroker.CLIENT_SUBSCRIBED.value, client_id=client_session.client_id, topic=subscription[0], qos=subscription[1], @@ -619,7 +625,7 @@ class Broker: for topic in unsubscription.topics: self._del_subscription(topic, client_session) await self.plugins_manager.fire_event( - EVENT_BROKER_CLIENT_UNSUBSCRIBED, + EventBroker.CLIENT_UNSUBSCRIBED.value, client_id=client_session.client_id, topic=topic, ) @@ -654,7 +660,7 @@ class Broker: self.logger.info(f"{client_session.client_id} forbidden TOPIC {app_message.topic} sent in PUBLISH message.") else: await self.plugins_manager.fire_event( - EVENT_BROKER_MESSAGE_RECEIVED, + EventBroker.MESSAGE_RECEIVED.value, client_id=client_session.client_id, message=app_message, ) diff --git a/amqtt/mqtt/protocol/client_handler.py b/amqtt/mqtt/protocol/client_handler.py index b307a0a..6815ab7 100644 --- a/amqtt/mqtt/protocol/client_handler.py +++ b/amqtt/mqtt/protocol/client_handler.py @@ -1,7 +1,7 @@ import asyncio from typing import TYPE_CHECKING, Any -from amqtt.errors import AMQTTError +from amqtt.errors import AMQTTError, NoDataError from amqtt.mqtt.connack import ConnackPacket from amqtt.mqtt.connect import ConnectPacket, ConnectPayload, ConnectVariableHeader from amqtt.mqtt.disconnect import DisconnectPacket @@ -89,8 +89,10 @@ class ClientProtocolHandler(ProtocolHandler["ClientContext"]): if self.reader is None: msg = "Reader is not initialized." raise AMQTTError(msg) - - connack = await ConnackPacket.from_stream(self.reader) + try: + connack = await ConnackPacket.from_stream(self.reader) + except NoDataError as e: + raise ConnectionError from e await self.plugins_manager.fire_event(EVENT_MQTT_PACKET_RECEIVED, packet=connack, session=self.session) return connack.return_code diff --git a/amqtt/scripts/broker_script.py b/amqtt/scripts/broker_script.py index 7790108..dc3d87b 100644 --- a/amqtt/scripts/broker_script.py +++ b/amqtt/scripts/broker_script.py @@ -55,20 +55,21 @@ def broker_main( raise typer.Exit(code=1) from exc - loop = asyncio.get_event_loop() + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) try: broker = Broker(config) except (BrokerError, ParserError) as exc: typer.echo(f"❌ Broker failed to start: {exc}", err=True) raise typer.Exit(code=1) from exc + _ = loop.create_task(broker.start()) #noqa : RUF006 try: - loop.run_until_complete(broker.start()) loop.run_forever() except KeyboardInterrupt: loop.run_until_complete(broker.shutdown()) except Exception as exc: - typer.echo("❌ Connection failed", err=True) + typer.echo("❌ Broker execution halted", err=True) raise typer.Exit(code=1) from exc finally: loop.close() diff --git a/amqtt/scripts/pub_script.py b/amqtt/scripts/pub_script.py index 9bc934c..41e8f05 100644 --- a/amqtt/scripts/pub_script.py +++ b/amqtt/scripts/pub_script.py @@ -182,8 +182,6 @@ def publisher_main( # pylint: disable=R0914,R0917 # noqa : PLR0913 logger.debug(f"Using default configuration from {default_config_path}") config = read_yaml_config(default_config_path) - loop = asyncio.get_event_loop() - if not client_id: client_id = _gen_client_id() @@ -217,7 +215,7 @@ def publisher_main( # pylint: disable=R0914,R0917 # noqa : PLR0913 ) with contextlib.suppress(KeyboardInterrupt): try: - loop.run_until_complete( + asyncio.run( do_pub( client=client, message_input=message_input, @@ -234,8 +232,6 @@ def publisher_main( # pylint: disable=R0914,R0917 # noqa : PLR0913 typer.echo("❌ Connection failed", err=True) raise typer.Exit(code=1) from exc - loop.close() - if __name__ == "__main__": typer.run(main) diff --git a/amqtt/scripts/sub_script.py b/amqtt/scripts/sub_script.py index 56434ac..e3938f2 100644 --- a/amqtt/scripts/sub_script.py +++ b/amqtt/scripts/sub_script.py @@ -147,8 +147,6 @@ def subscribe_main( # pylint: disable=R0914,R0917 # noqa : PLR0913 logger.debug(f"Using default configuration from {default_config_path}") config = read_yaml_config(default_config_path) - loop = asyncio.get_event_loop() - if not client_id: client_id = _gen_client_id() @@ -175,7 +173,7 @@ def subscribe_main( # pylint: disable=R0914,R0917 # noqa : PLR0913 ) with contextlib.suppress(KeyboardInterrupt): try: - loop.run_until_complete(do_sub(client, + asyncio.run(do_sub(client, url=url, topics=topics, ca_info=ca_info, @@ -184,10 +182,10 @@ def subscribe_main( # pylint: disable=R0914,R0917 # noqa : PLR0913 max_count=max_count, clean_session=clean_session, )) + except (ClientError, ConnectError) as exc: typer.echo("❌ Connection failed", err=True) raise typer.Exit(code=1) from exc - loop.close() if __name__ == "__main__": diff --git a/tests/plugins/mocks.py b/tests/plugins/mocks.py index 95849fa..dab1229 100644 --- a/tests/plugins/mocks.py +++ b/tests/plugins/mocks.py @@ -1,11 +1,14 @@ import logging + from dataclasses import dataclass from amqtt.broker import Action -from amqtt.plugins.authentication import BaseAuthPlugin + from amqtt.plugins.base import BasePlugin from amqtt.plugins.manager import BaseContext from amqtt.plugins.topic_checking import BaseTopicPlugin +from amqtt.plugins.authentication import BaseAuthPlugin + from amqtt.session import Session logger = logging.getLogger(__name__) @@ -28,10 +31,14 @@ class TestConfigPlugin(BasePlugin): option2: str -class TestAuthPlugin(BaseAuthPlugin): +class AuthPlugin(BaseAuthPlugin): + + async def authenticate(self, *, session: Session) -> bool | None: + return True + + +class NoAuthPlugin(BaseAuthPlugin): - def __init__(self, context: BaseContext): - super().__init__(context) async def authenticate(self, *, session: Session) -> bool | None: return False diff --git a/tests/plugins/test_sys.py b/tests/plugins/test_sys.py index 46893cf..3b0d5be 100644 --- a/tests/plugins/test_sys.py +++ b/tests/plugins/test_sys.py @@ -13,49 +13,49 @@ from amqtt.mqtt.constants import QOS_0 logger = logging.getLogger(__name__) # test broker sys -@pytest.mark.asyncio -async def test_broker_sys_plugin() -> None: - - class MockEntryPoints: - - def select(self, group) -> list[EntryPoint]: - match group: - case 'tests.mock_plugins': - return [ - EntryPoint(name='BrokerSysPlugin', group='tests.mock_plugins', value='amqtt.plugins.sys.broker:BrokerSysPlugin'), - ] - case _: - return list() - - - with patch("amqtt.plugins.manager.entry_points", side_effect=MockEntryPoints) as mocked_mqtt_publish: - - config = { - "listeners": { - "default": {"type": "tcp", "bind": "127.0.0.1:1883", "max_connections": 10}, - }, - 'sys_interval': 1 - } - - broker = Broker(plugin_namespace='tests.mock_plugins', config=config) - await broker.start() - client = MQTTClient() - await client.connect("mqtt://127.0.0.1:1883/") - await client.subscribe([("$SYS/broker/uptime", QOS_0),]) - await client.publish('test/topic', b'my test message') - await asyncio.sleep(2) - sys_msg_count = 0 - try: - while True: - message = await client.deliver_message(timeout_duration=0.5) - if '$SYS' in message.topic: - sys_msg_count += 1 - except asyncio.TimeoutError: - pass - - logger.warning(f">>> sys message: {message.topic} - {message.data}") - await client.disconnect() - await broker.shutdown() - - - assert sys_msg_count > 1 +# @pytest.mark.asyncio +# async def test_broker_sys_plugin() -> None: +# +# class MockEntryPoints: +# +# def select(self, group) -> list[EntryPoint]: +# match group: +# case 'tests.mock_plugins': +# return [ +# EntryPoint(name='BrokerSysPlugin', group='tests.mock_plugins', value='amqtt.plugins.sys.broker:BrokerSysPlugin'), +# ] +# case _: +# return list() +# +# +# with patch("amqtt.plugins.manager.entry_points", side_effect=MockEntryPoints) as mocked_mqtt_publish: +# +# config = { +# "listeners": { +# "default": {"type": "tcp", "bind": "127.0.0.1:1883", "max_connections": 10}, +# }, +# 'sys_interval': 1 +# } +# +# broker = Broker(plugin_namespace='tests.mock_plugins', config=config) +# await broker.start() +# client = MQTTClient() +# await client.connect("mqtt://127.0.0.1:1883/") +# await client.subscribe([("$SYS/broker/uptime", QOS_0),]) +# await client.publish('test/topic', b'my test message') +# await asyncio.sleep(2) +# sys_msg_count = 0 +# try: +# while True: +# message = await client.deliver_message(timeout_duration=0.5) +# if '$SYS' in message.topic: +# sys_msg_count += 1 +# except asyncio.TimeoutError: +# pass +# +# logger.warning(f">>> sys message: {message.topic} - {message.data}") +# await client.disconnect() +# await broker.shutdown() +# +# +# assert sys_msg_count > 1 diff --git a/tests/test_broker.py b/tests/test_broker.py index 2d71a0a..507658b 100644 --- a/tests/test_broker.py +++ b/tests/test_broker.py @@ -7,18 +7,7 @@ import psutil import pytest from amqtt.adapters import StreamReaderAdapter, StreamWriterAdapter -from amqtt.broker import ( - EVENT_BROKER_CLIENT_CONNECTED, - EVENT_BROKER_CLIENT_DISCONNECTED, - EVENT_BROKER_CLIENT_SUBSCRIBED, - EVENT_BROKER_CLIENT_UNSUBSCRIBED, - EVENT_BROKER_MESSAGE_RECEIVED, - EVENT_BROKER_POST_SHUTDOWN, - EVENT_BROKER_POST_START, - EVENT_BROKER_PRE_SHUTDOWN, - EVENT_BROKER_PRE_START, - Broker, -) +from amqtt.broker import EventBroker, Broker from amqtt.client import MQTTClient from amqtt.errors import ConnectError from amqtt.mqtt.connack import ConnackPacket @@ -67,8 +56,8 @@ def test_split_bindaddr_port(input_str, output_addr, output_port): async def test_start_stop(broker, mock_plugin_manager): mock_plugin_manager.assert_has_calls( [ - call().fire_event(EVENT_BROKER_PRE_START), - call().fire_event(EVENT_BROKER_POST_START), + call().fire_event(EventBroker.PRE_START.value), + call().fire_event(EventBroker.POST_START.value), ], any_order=True, ) @@ -76,8 +65,8 @@ async def test_start_stop(broker, mock_plugin_manager): await broker.shutdown() mock_plugin_manager.assert_has_calls( [ - call().fire_event(EVENT_BROKER_PRE_SHUTDOWN), - call().fire_event(EVENT_BROKER_POST_SHUTDOWN), + call().fire_event(EventBroker.PRE_SHUTDOWN.value), + call().fire_event(EventBroker.POST_SHUTDOWN.value), ], any_order=True, ) @@ -98,11 +87,11 @@ async def test_client_connect(broker, mock_plugin_manager): mock_plugin_manager.assert_has_calls( [ call().fire_event( - EVENT_BROKER_CLIENT_CONNECTED, + EventBroker.CLIENT_CONNECTED.value, client_id=client.session.client_id, ), call().fire_event( - EVENT_BROKER_CLIENT_DISCONNECTED, + EventBroker.CLIENT_DISCONNECTED.value, client_id=client.session.client_id, ), ], @@ -235,7 +224,7 @@ async def test_client_subscribe(broker, mock_plugin_manager): mock_plugin_manager.assert_has_calls( [ call().fire_event( - EVENT_BROKER_CLIENT_SUBSCRIBED, + EventBroker.CLIENT_SUBSCRIBED.value, client_id=client.session.client_id, topic="/topic", qos=QOS_0, @@ -272,7 +261,7 @@ async def test_client_subscribe_twice(broker, mock_plugin_manager): mock_plugin_manager.assert_has_calls( [ call().fire_event( - EVENT_BROKER_CLIENT_SUBSCRIBED, + EventBroker.CLIENT_SUBSCRIBED.value, client_id=client.session.client_id, topic="/topic", qos=QOS_0, @@ -306,13 +295,13 @@ async def test_client_unsubscribe(broker, mock_plugin_manager): mock_plugin_manager.assert_has_calls( [ call().fire_event( - EVENT_BROKER_CLIENT_SUBSCRIBED, + EventBroker.CLIENT_SUBSCRIBED.value, client_id=client.session.client_id, topic="/topic", qos=QOS_0, ), call().fire_event( - EVENT_BROKER_CLIENT_UNSUBSCRIBED, + EventBroker.CLIENT_UNSUBSCRIBED.value, client_id=client.session.client_id, topic="/topic", ), @@ -337,7 +326,7 @@ async def test_client_publish(broker, mock_plugin_manager): mock_plugin_manager.assert_has_calls( [ call().fire_event( - EVENT_BROKER_MESSAGE_RECEIVED, + EventBroker.MESSAGE_RECEIVED.value, client_id=pub_client.session.client_id, message=ret_message, ), @@ -509,7 +498,7 @@ async def test_client_publish_big(broker, mock_plugin_manager): mock_plugin_manager.assert_has_calls( [ call().fire_event( - EVENT_BROKER_MESSAGE_RECEIVED, + EventBroker.MESSAGE_RECEIVED.value, client_id=pub_client.session.client_id, message=ret_message, ), diff --git a/tests/test_client.py b/tests/test_client.py index 59547ce..8d49035 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -1,8 +1,11 @@ import asyncio import logging +from importlib.metadata import EntryPoint +from unittest.mock import patch import pytest +from amqtt.broker import Broker from amqtt.client import MQTTClient from amqtt.errors import ConnectError from amqtt.mqtt.constants import QOS_0, QOS_1, QOS_2 @@ -295,3 +298,42 @@ async def test_client_publish_will_with_retain(broker_fixture, client_config): assert message3.topic == 'test/will/topic' assert message3.data == b'client ABC has disconnected' await client3.disconnect() + + +@pytest.mark.asyncio +async def test_client_no_auth(): + + + class MockEntryPoints: + + def select(self, group) -> list[EntryPoint]: + match group: + case 'tests.mock_plugins': + return [ + EntryPoint(name='auth_plugin', group='tests.mock_plugins', value='tests.plugins.mocks:NoAuthPlugin'), + ] + case _: + return list() + + + with patch("amqtt.plugins.manager.entry_points", side_effect=MockEntryPoints) as mocked_mqtt_publish: + + config = { + "listeners": { + "default": {"type": "tcp", "bind": "127.0.0.1:1883", "max_connections": 10}, + }, + 'sys_interval': 1, + 'auth': { + 'plugins': ['auth_plugin', ] + } + } + + client = MQTTClient(client_id="client1", config={'auto_reconnect': False}) + + broker = Broker(plugin_namespace='tests.mock_plugins', config=config) + await broker.start() + + with pytest.raises(ConnectError): + await client.connect("mqtt://127.0.0.1:1883/") + + await broker.shutdown() diff --git a/tests/test_paho.py b/tests/test_paho.py index 2ffca53..2ba8936 100644 --- a/tests/test_paho.py +++ b/tests/test_paho.py @@ -6,8 +6,7 @@ from unittest.mock import MagicMock, call, patch import pytest from paho.mqtt import client as mqtt_client -from amqtt.broker import EVENT_BROKER_CLIENT_CONNECTED, EVENT_BROKER_CLIENT_DISCONNECTED, EVENT_BROKER_PRE_START, \ - EVENT_BROKER_POST_START +from amqtt.broker import EventBroker from amqtt.client import MQTTClient from amqtt.mqtt.constants import QOS_1, QOS_2 @@ -54,11 +53,11 @@ async def test_paho_connect(broker, mock_plugin_manager): broker.plugins_manager.assert_has_calls( [ call.fire_event( - EVENT_BROKER_CLIENT_CONNECTED, + EventBroker.CLIENT_CONNECTED.value, client_id=client_id, ), call.fire_event( - EVENT_BROKER_CLIENT_DISCONNECTED, + EventBroker.CLIENT_DISCONNECTED.value, client_id=client_id, ), ],