diff --git a/.gitignore b/.gitignore index 1d775a3..77680f3 100644 --- a/.gitignore +++ b/.gitignore @@ -34,3 +34,8 @@ site/ _build/ .hypothesis/ coverage.xml + +#----- generated files ----- +*.log +*memray* +.coverage* diff --git a/amqtt/broker.py b/amqtt/broker.py index ef5bdcc..6f311ae 100644 --- a/amqtt/broker.py +++ b/amqtt/broker.py @@ -29,6 +29,7 @@ from amqtt.session import ApplicationMessage, OutgoingApplicationMessage, Sessio from amqtt.utils import format_client_message, gen_client_id, read_yaml_config from .events import BrokerEvents +from .mqtt.constants import QOS_0, QOS_1, QOS_2 from .mqtt.disconnect import DisconnectPacket from .plugins.manager import PluginManager @@ -435,6 +436,7 @@ class Broker: await self._delete_session(client_session.client_id) else: client_session.client_id = gen_client_id() + client_session.parent = 0 # Get session from cache elif client_session.client_id in self._sessions: @@ -494,9 +496,18 @@ class Broker: self.logger.debug(f"{client_session.client_id} Start messages handling") await handler.start() + + # publish messages that were retained because the client session was disconnected self.logger.debug(f"Retained messages queue size: {client_session.retained_messages.qsize()}") await self._publish_session_retained_messages(client_session) + # if this is not a new session, there are subscriptions associated with them; publish any topic retained messages + self.logger.debug("Publish retained messages to a pre-existing session's subscriptions.") + for topic in self._subscriptions: + await self._publish_retained_messages_for_subscription( (topic, QOS_0), client_session) + + + await self._client_message_loop(client_session, handler) async def _client_message_loop(self, client_session: Session, handler: BrokerProtocolHandler) -> None: @@ -878,11 +889,20 @@ class Broker: qos = broadcast.get("qos", sub_qos) # Retain all messages which cannot be broadcasted, due to the session not being connected - if target_session.transitions.state != "connected": + # but only when clean session is false and qos is 1 or 2 [MQTT 3.1.2.4] + # and, if a client used anonymous authentication, there is no expectation that messages should be retained + if (target_session.transitions.state != "connected" + and not target_session.clean_session + and qos in (QOS_1, QOS_2) + and not target_session.is_anonymous): self.logger.debug(f"Session {target_session.client_id} is not connected, retaining message.") await self._retain_broadcast_message(broadcast, qos, target_session) continue + # Only broadcast the message to connected clients + if target_session.transitions.state != "connected": + continue + self.logger.debug( f"Broadcasting message from {format_client_message(session=broadcast['session'])}" f" on topic '{broadcast['topic']}' to {format_client_message(session=target_session)}", diff --git a/amqtt/client.py b/amqtt/client.py index 70781a1..6e781f6 100644 --- a/amqtt/client.py +++ b/amqtt/client.py @@ -598,7 +598,7 @@ class MQTTClient: session.cadata = broker_conf.get("cadata") if cleansession is not None: - broker_conf["cleansession"] = cleansession + broker_conf["cleansession"] = cleansession # noop? session.clean_session = cleansession else: session.clean_session = self.config.get("cleansession", True) diff --git a/amqtt/mqtt/connect.py b/amqtt/mqtt/connect.py index 361fbdd..0370947 100644 --- a/amqtt/mqtt/connect.py +++ b/amqtt/mqtt/connect.py @@ -192,7 +192,7 @@ class ConnectPayload(MQTTPayload[ConnectVariableHeader]): # A Server MAY allow a Client to supply a ClientId that has a length of zero bytes # [MQTT-3.1.3-6] payload.client_id = gen_client_id() - # indicator to trow exception in case CLEAN_SESSION_FLAG is set to False + # indicator to throw exception in case CLEAN_SESSION_FLAG is set to False payload.client_id_is_random = True # Read will topic, username and password diff --git a/amqtt/plugins/authentication.py b/amqtt/plugins/authentication.py index aef38e5..0b64286 100644 --- a/amqtt/plugins/authentication.py +++ b/amqtt/plugins/authentication.py @@ -26,6 +26,7 @@ class AnonymousAuthPlugin(BaseAuthPlugin): if self._allow_anonymous: self.context.logger.debug("Authentication success: config allows anonymous") + session.is_anonymous = True return True if session and session.username: diff --git a/amqtt/plugins/manager.py b/amqtt/plugins/manager.py index c31edc2..301d6b6 100644 --- a/amqtt/plugins/manager.py +++ b/amqtt/plugins/manager.py @@ -262,6 +262,10 @@ class PluginManager(Generic[C]): def _schedule_coro(self, coro: Awaitable[str | bool | None]) -> asyncio.Future[str | bool | None]: return asyncio.ensure_future(coro) + def _clean_fired_events(self, future: asyncio.Future[Any]) -> None: + with contextlib.suppress(KeyError, ValueError): + self._fired_events.remove(future) + async def fire_event(self, event_name: Events, *, wait: bool = False, **method_kwargs: Any) -> None: """Fire an event to plugins. @@ -287,12 +291,7 @@ class PluginManager(Generic[C]): coro_instance: Awaitable[Any] = call_method(event_awaitable, method_kwargs) tasks.append(asyncio.ensure_future(coro_instance)) - - def clean_fired_events(future: asyncio.Future[Any]) -> None: - with contextlib.suppress(KeyError, ValueError): - self._fired_events.remove(future) - - tasks[-1].add_done_callback(clean_fired_events) + tasks[-1].add_done_callback(self._clean_fired_events) self._fired_events.extend(tasks) if wait and tasks: diff --git a/amqtt/scripts/default_client.yaml b/amqtt/scripts/default_client.yaml index 3921e49..1fd91be 100644 --- a/amqtt/scripts/default_client.yaml +++ b/amqtt/scripts/default_client.yaml @@ -4,6 +4,7 @@ ping_delay: 1 default_qos: 0 default_retain: false auto_reconnect: true +cleansession: true reconnect_max_interval: 10 reconnect_retries: 2 broker: diff --git a/amqtt/session.py b/amqtt/session.py index 1937080..a2b0f00 100644 --- a/amqtt/session.py +++ b/amqtt/session.py @@ -145,12 +145,15 @@ class Session: # Used to store incoming ApplicationMessage while publish protocol flows self.inflight_in: OrderedDict[int, IncomingApplicationMessage] = OrderedDict() - # Stores messages retained for this session + # Stores messages retained for this session (specifically when the client is disconnected) self.retained_messages: Queue[ApplicationMessage] = Queue() # Stores PUBLISH messages ID received in order and ready for application process self.delivered_message_queue: Queue[ApplicationMessage] = Queue() + # identify anonymous client sessions or clients which didn't identify themselves + self.is_anonymous: bool = False + def _init_states(self) -> None: self.transitions = Machine(states=Session.states, initial="new") self.transitions.add_transition( diff --git a/docs/references/broker_config.md b/docs/references/broker_config.md index 7afbc9d..47cd900 100644 --- a/docs/references/broker_config.md +++ b/docs/references/broker_config.md @@ -22,8 +22,7 @@ listener. ### `timeout-disconnect-delay` *(int)* -Client disconnect timeout without a keep-alive - +Client disconnect timeout without a keep-alive. ### `plugins` *(mapping)* diff --git a/tests/conftest.py b/tests/conftest.py index 9a8e6cb..43a1d61 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -17,9 +17,9 @@ pytest_plugins = ["pytest_logdog"] test_config = { "listeners": { - "default": {"type": "tcp", "bind": "127.0.0.1:1883", "max_connections": 10}, - "ws": {"type": "ws", "bind": "127.0.0.1:8080", "max_connections": 10}, - "wss": {"type": "ws", "bind": "127.0.0.1:8081", "max_connections": 10}, + "default": {"type": "tcp", "bind": "127.0.0.1:1883", "max_connections": 15}, + "ws": {"type": "ws", "bind": "127.0.0.1:8080", "max_connections": 15}, + "wss": {"type": "ws", "bind": "127.0.0.1:8081", "max_connections": 15}, }, "sys_interval": 0, "auth": { diff --git a/tests/test_broker.py b/tests/test_broker.py index b817438..93e3f9f 100644 --- a/tests/test_broker.py +++ b/tests/test_broker.py @@ -1,6 +1,9 @@ import asyncio import logging +import logging.config +import secrets import socket +import string from unittest.mock import MagicMock, call, patch import psutil @@ -22,8 +25,49 @@ from amqtt.mqtt.pubrec import PubrecPacket from amqtt.mqtt.pubrel import PubrelPacket from amqtt.session import OutgoingApplicationMessage -formatter = "[%(asctime)s] %(name)s {%(filename)s:%(lineno)d} %(levelname)s - %(message)s" -logging.basicConfig(level=logging.DEBUG, format=formatter) +# formatter = "[%(asctime)s] %(name)s {%(filename)s:%(lineno)d} %(levelname)s - %(message)s" +# logging.basicConfig(level=logging.DEBUG, format=formatter) + +LOGGING_CONFIG = { + 'version': 1, + 'disable_existing_loggers': False, + + 'formatters': { + 'default': { + 'format': '[%(asctime)s] %(levelname)s %(name)s: %(message)s', + }, + }, + + 'handlers': { + 'console': { + 'class': 'logging.StreamHandler', + 'level': 'DEBUG', + 'formatter': 'default', + 'stream': 'ext://sys.stdout', + } + }, + + 'root': { + 'handlers': ['console'], + 'level': 'DEBUG', + }, + + 'loggers': { + 'transitions': { + 'handlers': ['console'], + 'level': 'WARNING', + 'propagate': False, + }, + }, +} + +logging.config.dictConfig(LOGGING_CONFIG) + + + + + + log = logging.getLogger(__name__) @@ -101,10 +145,11 @@ async def test_connect_tcp(broker): connections_number = 10 # mqtt 3.1 requires a connect packet, otherwise the socket connection is rejected - static_connect_packet = b'\x10\x1b\x00\x04MQTT\x04\x02\x00<\x00\x0ftest-client-123' sockets = [] for i in range(connections_number): + static_connect_packet = b'\x10\x1b\x00\x04MQTT\x04\x02\x00<\x00\x0ftest-client-12' + f"{i}".encode() + s = socket.create_connection(("127.0.0.1", 1883)) s.send(static_connect_packet) sockets.append(s) @@ -122,9 +167,11 @@ async def test_connect_tcp(broker): tcp_connections = [conn for conn in connections if conn.laddr.port == 1883] assert len(tcp_connections) == connections_number + 1 # Including the Broker's listening socket + await asyncio.sleep(0.1) for conn in connections: assert conn.status in ("ESTABLISHED", "LISTEN") + await asyncio.sleep(0.1) # close all connections for s in sockets: s.close() @@ -626,35 +673,142 @@ async def test_client_subscribe_publish_dollar_topic_2(broker): @pytest.mark.asyncio -async def test_client_publish_retain_subscribe(broker): - sub_client = MQTTClient() +async def test_client_publish_clean_session_subscribe(broker): + + sub_client = MQTTClient(client_id='test_client', config={'auto_reconnect': False}) await sub_client.connect("mqtt://127.0.0.1", cleansession=False) ret = await sub_client.subscribe( [("/qos0", QOS_0), ("/qos1", QOS_1), ("/qos2", QOS_2)], ) assert ret == [QOS_0, QOS_1, QOS_2] - await sub_client.disconnect() - await asyncio.sleep(0.1) - await _client_publish("/qos0", b"data", QOS_0, retain=True) - await _client_publish("/qos1", b"data", QOS_1, retain=True) - await _client_publish("/qos2", b"data", QOS_2, retain=True) - await sub_client.reconnect() - for qos in [QOS_0, QOS_1, QOS_2]: + await sub_client.disconnect() + await asyncio.sleep(0.5) + + + await _client_publish("/qos0", b"data0", QOS_0) # should not be retained + await _client_publish("/qos1", b"data1", QOS_1) + await _client_publish("/qos2", b"data2", QOS_2) + await asyncio.sleep(2) + + await sub_client.reconnect(cleansession=False) + for qos in [QOS_1, QOS_2]: log.debug(f"TEST QOS: {qos}") - message = await sub_client.deliver_message() + message = await sub_client.deliver_message(timeout_duration=2) log.debug(f"Message: {message.publish_packet if message else None!r}") assert message is not None assert message.topic == f"/qos{qos}" - assert message.data == b"data" + assert message.data == f"data{qos}".encode("utf-8") assert message.qos == qos + + try: + while True: + message = await sub_client.deliver_message(timeout_duration=1) + assert message is not None, "no other messages should have been retained" + except asyncio.TimeoutError: + pass + + await sub_client.disconnect() + await asyncio.sleep(0.1) + + +@pytest.mark.asyncio +async def test_client_publish_retain_with_new_subscribe(broker): + await asyncio.sleep(2) + sub_client1 = MQTTClient(client_id='test_client1') + await sub_client1.connect("mqtt://127.0.0.1") + + await sub_client1.disconnect() + await asyncio.sleep(0.5) + + await _client_publish("/qos0", b"data0", QOS_0, retain=True) + await asyncio.sleep(0.5) + + sub_client2 = MQTTClient(client_id='test_client2') + await sub_client2.connect("mqtt://127.0.0.1") + + # should receive the retained message on subscription + ret = await sub_client2.subscribe( + [("/qos0", QOS_0)], + ) + assert ret == [QOS_0] + + message = await sub_client2.deliver_message(timeout_duration=1) + assert message is not None + assert message.topic == "/qos0" + assert message.data == b"data0" + assert message.qos == QOS_0 + await sub_client2.disconnect() + await asyncio.sleep(0.1) + + +@pytest.mark.asyncio +async def test_client_publish_retain_latest_with_new_subscribe(broker): + await asyncio.sleep(2) + sub_client1 = MQTTClient(client_id='test_client1') + await sub_client1.connect("mqtt://127.0.0.1") + + await sub_client1.disconnect() + await asyncio.sleep(0.5) + + await _client_publish("/qos0", b"data a", QOS_0, retain=True) + await asyncio.sleep(0.5) + + sub_client2 = MQTTClient(client_id='test_client2') + await sub_client2.connect("mqtt://127.0.0.1") + + await _client_publish("/qos0", b"data b", QOS_0, retain=True) + + # should receive the retained message on subscription + ret = await sub_client2.subscribe( + [("/qos0", QOS_0)], + ) + assert ret == [QOS_0] + + message = await sub_client2.deliver_message(timeout_duration=1) + assert message is not None + assert message.topic == "/qos0" + assert message.data == b"data b" + assert message.qos == QOS_0 + await sub_client2.disconnect() + await asyncio.sleep(0.1) + + +@pytest.mark.asyncio +async def test_client_publish_retain_subscribe_on_reconnect(broker): + await asyncio.sleep(2) + sub_client = MQTTClient(client_id='test_client') + await sub_client.connect("mqtt://127.0.0.1", cleansession=False) + ret = await sub_client.subscribe( + [("/qos0", QOS_0)], + ) + assert ret == [QOS_0] + + await sub_client.disconnect() + await asyncio.sleep(0.5) + + await _client_publish("/qos0", b"data0", QOS_0, retain=True) + await asyncio.sleep(0.5) + + await sub_client.reconnect(cleansession=False) + + message = await sub_client.deliver_message(timeout_duration=1) + assert message is not None + assert message.topic == "/qos0" + assert message.data == b"data0" + assert message.qos == QOS_0 await sub_client.disconnect() await asyncio.sleep(0.1) @pytest.mark.asyncio async def _client_publish(topic, data, qos, retain=False) -> int | OutgoingApplicationMessage: - pub_client = MQTTClient() + + gen_id = "pub_" + valid_chars = string.ascii_letters + string.digits + gen_id += "".join(secrets.choice(valid_chars) for _ in range(16)) + + pub_client = MQTTClient(client_id=gen_id) ret: int | OutgoingApplicationMessage = await pub_client.connect("mqtt://127.0.0.1/") assert ret == 0 ret = await pub_client.publish(topic, data, qos, retain)