diff --git a/amqtt/broker.py b/amqtt/broker.py index 5416f52..2e1f979 100644 --- a/amqtt/broker.py +++ b/amqtt/broker.py @@ -41,18 +41,24 @@ _defaults = read_yaml_config(Path(__file__).parent / "scripts/default_broker.yam DEFAULT_PORTS = {"tcp": 1883, "ws": 8883} AMQTT_MAGIC_VALUE_RET_SUBSCRIBED = 0x80 -EVENT_BROKER_PRE_START = "broker_pre_start" -EVENT_BROKER_POST_START = "broker_post_start" -EVENT_BROKER_PRE_SHUTDOWN = "broker_pre_shutdown" -EVENT_BROKER_POST_SHUTDOWN = "broker_post_shutdown" -EVENT_BROKER_CLIENT_CONNECTED = "broker_client_connected" -EVENT_BROKER_CLIENT_DISCONNECTED = "broker_client_disconnected" -EVENT_BROKER_CLIENT_SUBSCRIBED = "broker_client_subscribed" -EVENT_BROKER_CLIENT_UNSUBSCRIBED = "broker_client_unsubscribed" -EVENT_BROKER_MESSAGE_RECEIVED = "broker_message_received" + +class EventBroker(Enum): + """Events issued by the broker.""" + + PRE_START = "broker_pre_start" + POST_START = "broker_post_start" + PRE_SHUTDOWN = "broker_pre_shutdown" + POST_SHUTDOWN = "broker_post_shutdown" + CLIENT_CONNECTED = "broker_client_connected" + CLIENT_DISCONNECTED = "broker_client_disconnected" + CLIENT_SUBSCRIBED = "broker_client_subscribed" + CLIENT_UNSUBSCRIBED = "broker_client_unsubscribed" + MESSAGE_RECEIVED = "broker_message_received" class Action(Enum): + """Actions issued by the broker.""" + SUBSCRIBE = "subscribe" PUBLISH = "publish" @@ -242,11 +248,11 @@ class Broker: msg = f"Broker instance can't be started: {exc}" raise BrokerError(msg) from exc - await self.plugins_manager.fire_event(EVENT_BROKER_PRE_START) + await self.plugins_manager.fire_event(EventBroker.PRE_START.value) try: await self._start_listeners() self.transitions.starting_success() - await self.plugins_manager.fire_event(EVENT_BROKER_POST_START) + await self.plugins_manager.fire_event(EventBroker.POST_START.value) self._broadcast_task = asyncio.ensure_future(self._broadcast_loop()) self.logger.debug("Broker started") except Exception as e: @@ -327,7 +333,7 @@ class Broker: """Stop broker instance.""" self.logger.info("Shutting down broker...") # Fire broker_shutdown event to plugins - await self.plugins_manager.fire_event(EVENT_BROKER_PRE_SHUTDOWN) + await self.plugins_manager.fire_event(EventBroker.PRE_SHUTDOWN.value) # Cleanup all sessions for client_id in list(self._sessions.keys()): @@ -351,7 +357,7 @@ class Broker: self._broadcast_queue.get_nowait() self.logger.info("Broker closed") - await self.plugins_manager.fire_event(EVENT_BROKER_POST_SHUTDOWN) + await self.plugins_manager.fire_event(EventBroker.POST_SHUTDOWN.value) self.transitions.stopping_success() async def _cleanup_session(self, client_id: str) -> None: @@ -494,7 +500,7 @@ class Broker: self._sessions[client_session.client_id] = (client_session, handler) await handler.mqtt_connack_authorize(authenticated) - await self.plugins_manager.fire_event(EVENT_BROKER_CLIENT_CONNECTED, client_id=client_session.client_id) + await self.plugins_manager.fire_event(EventBroker.CLIENT_CONNECTED.value, client_id=client_session.client_id) self.logger.debug(f"{client_session.client_id} Start messages handling") await handler.start() @@ -582,7 +588,7 @@ class Broker: self.logger.debug(f"{client_session.client_id} Disconnecting session") await self._stop_handler(handler) client_session.transitions.disconnect() - await self.plugins_manager.fire_event(EVENT_BROKER_CLIENT_DISCONNECTED, client_id=client_session.client_id) + await self.plugins_manager.fire_event(EventBroker.CLIENT_DISCONNECTED.value, client_id=client_session.client_id) return False return True @@ -600,7 +606,7 @@ class Broker: for index, subscription in enumerate(subscriptions.topics): if return_codes[index] != AMQTT_MAGIC_VALUE_RET_SUBSCRIBED: await self.plugins_manager.fire_event( - EVENT_BROKER_CLIENT_SUBSCRIBED, + EventBroker.CLIENT_SUBSCRIBED.value, client_id=client_session.client_id, topic=subscription[0], qos=subscription[1], @@ -619,7 +625,7 @@ class Broker: for topic in unsubscription.topics: self._del_subscription(topic, client_session) await self.plugins_manager.fire_event( - EVENT_BROKER_CLIENT_UNSUBSCRIBED, + EventBroker.CLIENT_UNSUBSCRIBED.value, client_id=client_session.client_id, topic=topic, ) @@ -654,7 +660,7 @@ class Broker: self.logger.info(f"{client_session.client_id} forbidden TOPIC {app_message.topic} sent in PUBLISH message.") else: await self.plugins_manager.fire_event( - EVENT_BROKER_MESSAGE_RECEIVED, + EventBroker.MESSAGE_RECEIVED.value, client_id=client_session.client_id, message=app_message, ) diff --git a/amqtt/client.py b/amqtt/client.py index 8da9aba..c7c7d5d 100644 --- a/amqtt/client.py +++ b/amqtt/client.py @@ -219,7 +219,7 @@ class MQTTClient: self.logger.debug(f"Reconnecting with session parameters: {self.session}") reconnect_max_interval = self.config.get("reconnect_max_interval", 10) - reconnect_retries = self.config.get("reconnect_retries", 5) + reconnect_retries = self.config.get("reconnect_retries", 2) nb_attempt = 1 while True: @@ -232,7 +232,7 @@ class MQTTClient: except Exception as e: self.logger.warning(f"Reconnection attempt failed: {e!r}") self.logger.debug("", exc_info=True) - if reconnect_retries < nb_attempt: # reconnect_retries >= 0 and + if 0 <= reconnect_retries < nb_attempt: self.logger.exception("Maximum connection attempts reached. Reconnection aborted.") self.logger.debug("", exc_info=True) msg = "Too many failed attempts" @@ -470,6 +470,7 @@ class MQTTClient: reader: StreamReaderAdapter | WebSocketsReader | None = None writer: StreamWriterAdapter | WebSocketsWriter | None = None self._connected_state.clear() + # Open connection if scheme in ("mqtt", "mqtts"): conn_reader, conn_writer = await asyncio.open_connection( @@ -489,11 +490,11 @@ class MQTTClient: ) reader = WebSocketsReader(websocket) writer = WebSocketsWriter(websocket) - - if reader is None or writer is None: - self.session.transitions.disconnect() - self.logger.warning("reader or writer not initialized") - msg = "reader or writer not initialized" + elif not self.session.broker_uri: + msg = "missing broker uri" + raise ClientError(msg) + else: + msg = f"incorrect scheme defined in uri: '{scheme!r}'" raise ClientError(msg) # Start MQTT protocol @@ -533,7 +534,7 @@ class MQTTClient: while self.client_tasks: task = self.client_tasks.popleft() if not task.done(): - task.cancel() + task.cancel(msg="Connection closed.") self.logger.debug("Monitoring broker disconnection") # Wait for disconnection from broker (like connection lost) diff --git a/tests/test_broker.py b/tests/test_broker.py index 2d71a0a..3591112 100644 --- a/tests/test_broker.py +++ b/tests/test_broker.py @@ -7,18 +7,7 @@ import psutil import pytest from amqtt.adapters import StreamReaderAdapter, StreamWriterAdapter -from amqtt.broker import ( - EVENT_BROKER_CLIENT_CONNECTED, - EVENT_BROKER_CLIENT_DISCONNECTED, - EVENT_BROKER_CLIENT_SUBSCRIBED, - EVENT_BROKER_CLIENT_UNSUBSCRIBED, - EVENT_BROKER_MESSAGE_RECEIVED, - EVENT_BROKER_POST_SHUTDOWN, - EVENT_BROKER_POST_START, - EVENT_BROKER_PRE_SHUTDOWN, - EVENT_BROKER_PRE_START, - Broker, -) +from amqtt.broker import EventBroker, Broker from amqtt.client import MQTTClient from amqtt.errors import ConnectError from amqtt.mqtt.connack import ConnackPacket @@ -67,8 +56,8 @@ def test_split_bindaddr_port(input_str, output_addr, output_port): async def test_start_stop(broker, mock_plugin_manager): mock_plugin_manager.assert_has_calls( [ - call().fire_event(EVENT_BROKER_PRE_START), - call().fire_event(EVENT_BROKER_POST_START), + call().fire_event(EventBroker.PRE_START.value), + call().fire_event(EventBroker.POST_START.value), ], any_order=True, ) @@ -76,8 +65,8 @@ async def test_start_stop(broker, mock_plugin_manager): await broker.shutdown() mock_plugin_manager.assert_has_calls( [ - call().fire_event(EVENT_BROKER_PRE_SHUTDOWN), - call().fire_event(EVENT_BROKER_POST_SHUTDOWN), + call().fire_event(EventBroker.PRE_SHUTDOWN.value), + call().fire_event(EventBroker.POST_SHUTDOWN.value), ], any_order=True, ) @@ -98,11 +87,11 @@ async def test_client_connect(broker, mock_plugin_manager): mock_plugin_manager.assert_has_calls( [ call().fire_event( - EVENT_BROKER_CLIENT_CONNECTED, + EventBroker.CLIENT_CONNECTED.value, client_id=client.session.client_id, ), call().fire_event( - EVENT_BROKER_CLIENT_DISCONNECTED, + EventBroker.CLIENT_DISCONNECTED.value, client_id=client.session.client_id, ), ], @@ -235,7 +224,7 @@ async def test_client_subscribe(broker, mock_plugin_manager): mock_plugin_manager.assert_has_calls( [ call().fire_event( - EVENT_BROKER_CLIENT_SUBSCRIBED, + EventBroker.CLIENT_SUBSCRIBED.value, client_id=client.session.client_id, topic="/topic", qos=QOS_0, @@ -272,7 +261,7 @@ async def test_client_subscribe_twice(broker, mock_plugin_manager): mock_plugin_manager.assert_has_calls( [ call().fire_event( - EVENT_BROKER_CLIENT_SUBSCRIBED, + EventBroker.CLIENT_SUBSCRIBED.value, client_id=client.session.client_id, topic="/topic", qos=QOS_0, @@ -306,13 +295,13 @@ async def test_client_unsubscribe(broker, mock_plugin_manager): mock_plugin_manager.assert_has_calls( [ call().fire_event( - EVENT_BROKER_CLIENT_SUBSCRIBED, + EventBroker.CLIENT_SUBSCRIBED.value, client_id=client.session.client_id, topic="/topic", qos=QOS_0, ), call().fire_event( - EVENT_BROKER_CLIENT_UNSUBSCRIBED, + EventBroker.CLIENT_UNSUBSCRIBED.value, client_id=client.session.client_id, topic="/topic", ), @@ -337,7 +326,7 @@ async def test_client_publish(broker, mock_plugin_manager): mock_plugin_manager.assert_has_calls( [ call().fire_event( - EVENT_BROKER_MESSAGE_RECEIVED, + EventBroker.MESSAGE_RECEIVED.value, client_id=pub_client.session.client_id, message=ret_message, ), @@ -509,7 +498,7 @@ async def test_client_publish_big(broker, mock_plugin_manager): mock_plugin_manager.assert_has_calls( [ call().fire_event( - EVENT_BROKER_MESSAGE_RECEIVED, + EventBroker.MESSAGE_RECEIVED.value, client_id=pub_client.session.client_id, message=ret_message, ), @@ -740,3 +729,16 @@ async def test_broker_broadcast_cancellation(broker): await _client_publish(topic, data, qos) message = await asyncio.wait_for(sub_client.deliver_message(), timeout=1) assert message + + +@pytest.mark.asyncio +async def test_broker_socket_open_close(broker): + + # check that https://github.com/Yakifo/amqtt/issues/86 is fixed + + # mqtt 3.1 requires a connect packet, otherwise the socket connection is rejected + static_connect_packet = b'\x10\x1b\x00\x04MQTT\x04\x02\x00<\x00\x0ftest-client-123' + s = socket.create_connection(("127.0.0.1", 1883)) + s.send(static_connect_packet) + await asyncio.sleep(0.1) + s.close() diff --git a/tests/test_client.py b/tests/test_client.py index be504e4..6c5db8b 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -7,7 +7,7 @@ import pytest from amqtt.broker import Broker from amqtt.client import MQTTClient -from amqtt.errors import ConnectError +from amqtt.errors import ClientError, ConnectError from amqtt.mqtt.constants import QOS_0, QOS_1, QOS_2 formatter = "[%(asctime)s] %(name)s {%(filename)s:%(lineno)d} %(levelname)s - %(message)s" @@ -332,9 +332,23 @@ async def test_client_with_will_empty_message(broker_fixture): await client2.disconnect() -async def test_client_no_auth(): +async def test_connect_broken_uri(): + config = {"auto_reconnect": False} + client = MQTTClient(config=config) + with pytest.raises(ClientError): + await client.connect('"mqtt://someplace') +@pytest.mark.asyncio +async def test_connect_incorrect_scheme(): + config = {"auto_reconnect": False} + client = MQTTClient(config=config) + with pytest.raises(ClientError): + await client.connect('"mq://someplace') + + +async def test_client_no_auth(): + class MockEntryPoints: def select(self, group) -> list[EntryPoint]: @@ -368,4 +382,3 @@ async def test_client_no_auth(): await client.connect("mqtt://127.0.0.1:1883/") await broker.shutdown() - diff --git a/tests/test_paho.py b/tests/test_paho.py index 2ffca53..2ba8936 100644 --- a/tests/test_paho.py +++ b/tests/test_paho.py @@ -6,8 +6,7 @@ from unittest.mock import MagicMock, call, patch import pytest from paho.mqtt import client as mqtt_client -from amqtt.broker import EVENT_BROKER_CLIENT_CONNECTED, EVENT_BROKER_CLIENT_DISCONNECTED, EVENT_BROKER_PRE_START, \ - EVENT_BROKER_POST_START +from amqtt.broker import EventBroker from amqtt.client import MQTTClient from amqtt.mqtt.constants import QOS_1, QOS_2 @@ -54,11 +53,11 @@ async def test_paho_connect(broker, mock_plugin_manager): broker.plugins_manager.assert_has_calls( [ call.fire_event( - EVENT_BROKER_CLIENT_CONNECTED, + EventBroker.CLIENT_CONNECTED.value, client_id=client_id, ), call.fire_event( - EVENT_BROKER_CLIENT_DISCONNECTED, + EventBroker.CLIENT_DISCONNECTED.value, client_id=client_id, ), ],