additional test cases for message retention for retain flag and disconnected state

pull/248/head
Andrew Mirsky 2025-07-02 08:09:31 -04:00
rodzic 2ceb2ae43b
commit 341c6c1732
Nie znaleziono w bazie danych klucza dla tego podpisu
ID klucza GPG: A98E67635CDF2C39
3 zmienionych plików z 174 dodań i 14 usunięć

Wyświetl plik

@ -29,7 +29,7 @@ from amqtt.session import ApplicationMessage, OutgoingApplicationMessage, Sessio
from amqtt.utils import format_client_message, gen_client_id, read_yaml_config from amqtt.utils import format_client_message, gen_client_id, read_yaml_config
from .events import BrokerEvents from .events import BrokerEvents
from .mqtt.constants import QOS_1, QOS_2 from .mqtt.constants import QOS_1, QOS_2, QOS_0
from .mqtt.disconnect import DisconnectPacket from .mqtt.disconnect import DisconnectPacket
from .plugins.manager import BaseContext, PluginManager from .plugins.manager import BaseContext, PluginManager
@ -497,9 +497,18 @@ class Broker:
self.logger.debug(f"{client_session.client_id} Start messages handling") self.logger.debug(f"{client_session.client_id} Start messages handling")
await handler.start() await handler.start()
self.logger.debug(f"Retained messages queue size: {client_session.retained_messages.qsize()}")
# publish messages that were retained because the client session was disconnecte
self.logger.debug(f"Offline messages queue size: {client_session.retained_messages.qsize()}")
await self._publish_session_retained_messages(client_session) await self._publish_session_retained_messages(client_session)
# publish messages that were marked as retained for a specific
# self.logger.debug(f"Publish messages that have been marked as retained.")
# for topic in self._subscriptions.keys():
# await self._publish_retained_messages_for_subscription( (topic, QOS_0), client_session)
await self._client_message_loop(client_session, handler) await self._client_message_loop(client_session, handler)
async def _client_message_loop(self, client_session: Session, handler: BrokerProtocolHandler) -> None: async def _client_message_loop(self, client_session: Session, handler: BrokerProtocolHandler) -> None:

Wyświetl plik

@ -145,7 +145,7 @@ class Session:
# Used to store incoming ApplicationMessage while publish protocol flows # Used to store incoming ApplicationMessage while publish protocol flows
self.inflight_in: OrderedDict[int, IncomingApplicationMessage] = OrderedDict() self.inflight_in: OrderedDict[int, IncomingApplicationMessage] = OrderedDict()
# Stores messages retained for this session # Stores messages retained for this session (specifically when the client is disconnected)
self.retained_messages: Queue[ApplicationMessage] = Queue() self.retained_messages: Queue[ApplicationMessage] = Queue()
# Stores PUBLISH messages ID received in order and ready for application process # Stores PUBLISH messages ID received in order and ready for application process

Wyświetl plik

@ -1,6 +1,9 @@
import asyncio import asyncio
import logging import logging
import logging.config
import secrets
import socket import socket
import string
from unittest.mock import MagicMock, call, patch from unittest.mock import MagicMock, call, patch
import psutil import psutil
@ -22,8 +25,49 @@ from amqtt.mqtt.pubrec import PubrecPacket
from amqtt.mqtt.pubrel import PubrelPacket from amqtt.mqtt.pubrel import PubrelPacket
from amqtt.session import OutgoingApplicationMessage from amqtt.session import OutgoingApplicationMessage
formatter = "[%(asctime)s] %(name)s {%(filename)s:%(lineno)d} %(levelname)s - %(message)s" # formatter = "[%(asctime)s] %(name)s {%(filename)s:%(lineno)d} %(levelname)s - %(message)s"
logging.basicConfig(level=logging.DEBUG, format=formatter) # logging.basicConfig(level=logging.DEBUG, format=formatter)
LOGGING_CONFIG = {
'version': 1,
'disable_existing_loggers': False,
'formatters': {
'default': {
'format': '[%(asctime)s] %(levelname)s %(name)s: %(message)s',
},
},
'handlers': {
'console': {
'class': 'logging.StreamHandler',
'level': 'DEBUG',
'formatter': 'default',
'stream': 'ext://sys.stdout',
}
},
'root': {
'handlers': ['console'],
'level': 'DEBUG',
},
'loggers': {
'transitions': {
'handlers': ['console'],
'level': 'WARNING',
'propagate': False,
},
},
}
logging.config.dictConfig(LOGGING_CONFIG)
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -631,35 +675,142 @@ async def test_client_subscribe_publish_dollar_topic_2(broker):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_client_publish_retain_subscribe(broker): async def test_client_publish_clean_session_subscribe(broker):
sub_client = MQTTClient(client_id='test_client') sub_client = MQTTClient(client_id='test_client')
await sub_client.connect("mqtt://127.0.0.1", cleansession=False) await sub_client.connect("mqtt://127.0.0.1", cleansession=False)
ret = await sub_client.subscribe( ret = await sub_client.subscribe(
[("/qos0", QOS_0), ("/qos1", QOS_1), ("/qos2", QOS_2)], [("/qos0", QOS_0), ("/qos1", QOS_1), ("/qos2", QOS_2)],
) )
assert ret == [QOS_0, QOS_1, QOS_2] assert ret == [QOS_0, QOS_1, QOS_2]
await sub_client.disconnect()
await asyncio.sleep(0.1)
await _client_publish("/qos0", b"data", QOS_0, retain=True) await sub_client.disconnect()
await _client_publish("/qos1", b"data", QOS_1, retain=True) await asyncio.sleep(0.5)
await _client_publish("/qos2", b"data", QOS_2, retain=True)
await _client_publish("/qos0", b"data0", QOS_0) # should not be retained
await _client_publish("/qos1", b"data1", QOS_1)
await _client_publish("/qos2", b"data2", QOS_2)
await asyncio.sleep(0.5)
await sub_client.reconnect(cleansession=False) await sub_client.reconnect(cleansession=False)
for qos in [QOS_0, QOS_1, QOS_2]: for qos in [QOS_1, QOS_2]:
log.debug(f"TEST QOS: {qos}") log.debug(f"TEST QOS: {qos}")
message = await sub_client.deliver_message() message = await sub_client.deliver_message()
log.debug(f"Message: {message.publish_packet if message else None!r}") log.debug(f"Message: {message.publish_packet if message else None!r}")
assert message is not None assert message is not None
assert message.topic == f"/qos{qos}" assert message.topic == f"/qos{qos}"
assert message.data == b"data" assert message.data == f"data{qos}".encode("utf-8")
assert message.qos == qos assert message.qos == qos
try:
while True:
message = await sub_client.deliver_message(timeout_duration=1)
assert message is not None, "no other messages should have been retained"
except TimeoutError:
pass
await sub_client.disconnect()
await asyncio.sleep(0.1)
@pytest.mark.asyncio
async def test_client_publish_retain_with_new_subscribe(broker):
await asyncio.sleep(2)
sub_client1 = MQTTClient(client_id='test_client1')
await sub_client1.connect("mqtt://127.0.0.1")
await sub_client1.disconnect()
await asyncio.sleep(0.5)
await _client_publish("/qos0", b"data0", QOS_0, retain=True)
await asyncio.sleep(0.5)
sub_client2 = MQTTClient(client_id='test_client2')
await sub_client2.connect("mqtt://127.0.0.1")
# should receive the retained message on subscription
ret = await sub_client2.subscribe(
[("/qos0", QOS_0)],
)
assert ret == [QOS_0]
message = await sub_client2.deliver_message(timeout_duration=1)
assert message is not None
assert message.topic == "/qos0"
assert message.data == b"data0"
assert message.qos == QOS_0
await sub_client2.disconnect()
await asyncio.sleep(0.1)
@pytest.mark.asyncio
async def test_client_publish_retain_latest_with_new_subscribe(broker):
await asyncio.sleep(2)
sub_client1 = MQTTClient(client_id='test_client1')
await sub_client1.connect("mqtt://127.0.0.1")
await sub_client1.disconnect()
await asyncio.sleep(0.5)
await _client_publish("/qos0", b"data a", QOS_0, retain=True)
await asyncio.sleep(0.5)
sub_client2 = MQTTClient(client_id='test_client2')
await sub_client2.connect("mqtt://127.0.0.1")
await _client_publish("/qos0", b"data b", QOS_0, retain=True)
# should receive the retained message on subscription
ret = await sub_client2.subscribe(
[("/qos0", QOS_0)],
)
assert ret == [QOS_0]
message = await sub_client2.deliver_message(timeout_duration=1)
assert message is not None
assert message.topic == "/qos0"
assert message.data == b"data b"
assert message.qos == QOS_0
await sub_client2.disconnect()
await asyncio.sleep(0.1)
@pytest.mark.asyncio
async def test_client_publish_retain_subscribe_on_reconnect(broker):
await asyncio.sleep(2)
sub_client = MQTTClient(client_id='test_client')
await sub_client.connect("mqtt://127.0.0.1", cleansession=False)
ret = await sub_client.subscribe(
[("/qos0", QOS_0)],
)
assert ret == [QOS_0]
await sub_client.disconnect()
await asyncio.sleep(0.5)
await _client_publish("/qos0", b"data0", QOS_0, retain=True)
await asyncio.sleep(0.5)
await sub_client.reconnect(cleansession=False)
message = await sub_client.deliver_message(timeout_duration=1)
assert message is not None
assert message.topic == "/qos0"
assert message.data == b"data0"
assert message.qos == QOS_0
await sub_client.disconnect() await sub_client.disconnect()
await asyncio.sleep(0.1) await asyncio.sleep(0.1)
@pytest.mark.asyncio @pytest.mark.asyncio
async def _client_publish(topic, data, qos, retain=False) -> int | OutgoingApplicationMessage: async def _client_publish(topic, data, qos, retain=False) -> int | OutgoingApplicationMessage:
pub_client = MQTTClient()
gen_id = "pub_"
valid_chars = string.ascii_letters + string.digits
gen_id += "".join(secrets.choice(valid_chars) for _ in range(16))
pub_client = MQTTClient(client_id=gen_id)
ret: int | OutgoingApplicationMessage = await pub_client.connect("mqtt://127.0.0.1/") ret: int | OutgoingApplicationMessage = await pub_client.connect("mqtt://127.0.0.1/")
assert ret == 0 assert ret == 0
ret = await pub_client.publish(topic, data, qos, retain) ret = await pub_client.publish(topic, data, qos, retain)