kopia lustrzana https://github.com/Yakifo/amqtt
additional test cases for message retention for retain flag and disconnected state
rodzic
2ceb2ae43b
commit
341c6c1732
|
@ -29,7 +29,7 @@ from amqtt.session import ApplicationMessage, OutgoingApplicationMessage, Sessio
|
|||
from amqtt.utils import format_client_message, gen_client_id, read_yaml_config
|
||||
|
||||
from .events import BrokerEvents
|
||||
from .mqtt.constants import QOS_1, QOS_2
|
||||
from .mqtt.constants import QOS_1, QOS_2, QOS_0
|
||||
from .mqtt.disconnect import DisconnectPacket
|
||||
from .plugins.manager import BaseContext, PluginManager
|
||||
|
||||
|
@ -497,9 +497,18 @@ class Broker:
|
|||
|
||||
self.logger.debug(f"{client_session.client_id} Start messages handling")
|
||||
await handler.start()
|
||||
self.logger.debug(f"Retained messages queue size: {client_session.retained_messages.qsize()}")
|
||||
|
||||
# publish messages that were retained because the client session was disconnecte
|
||||
self.logger.debug(f"Offline messages queue size: {client_session.retained_messages.qsize()}")
|
||||
await self._publish_session_retained_messages(client_session)
|
||||
|
||||
# publish messages that were marked as retained for a specific
|
||||
# self.logger.debug(f"Publish messages that have been marked as retained.")
|
||||
# for topic in self._subscriptions.keys():
|
||||
# await self._publish_retained_messages_for_subscription( (topic, QOS_0), client_session)
|
||||
|
||||
|
||||
|
||||
await self._client_message_loop(client_session, handler)
|
||||
|
||||
async def _client_message_loop(self, client_session: Session, handler: BrokerProtocolHandler) -> None:
|
||||
|
|
|
@ -145,7 +145,7 @@ class Session:
|
|||
# Used to store incoming ApplicationMessage while publish protocol flows
|
||||
self.inflight_in: OrderedDict[int, IncomingApplicationMessage] = OrderedDict()
|
||||
|
||||
# Stores messages retained for this session
|
||||
# Stores messages retained for this session (specifically when the client is disconnected)
|
||||
self.retained_messages: Queue[ApplicationMessage] = Queue()
|
||||
|
||||
# Stores PUBLISH messages ID received in order and ready for application process
|
||||
|
|
|
@ -1,6 +1,9 @@
|
|||
import asyncio
|
||||
import logging
|
||||
import logging.config
|
||||
import secrets
|
||||
import socket
|
||||
import string
|
||||
from unittest.mock import MagicMock, call, patch
|
||||
|
||||
import psutil
|
||||
|
@ -22,8 +25,49 @@ from amqtt.mqtt.pubrec import PubrecPacket
|
|||
from amqtt.mqtt.pubrel import PubrelPacket
|
||||
from amqtt.session import OutgoingApplicationMessage
|
||||
|
||||
formatter = "[%(asctime)s] %(name)s {%(filename)s:%(lineno)d} %(levelname)s - %(message)s"
|
||||
logging.basicConfig(level=logging.DEBUG, format=formatter)
|
||||
# formatter = "[%(asctime)s] %(name)s {%(filename)s:%(lineno)d} %(levelname)s - %(message)s"
|
||||
# logging.basicConfig(level=logging.DEBUG, format=formatter)
|
||||
|
||||
LOGGING_CONFIG = {
|
||||
'version': 1,
|
||||
'disable_existing_loggers': False,
|
||||
|
||||
'formatters': {
|
||||
'default': {
|
||||
'format': '[%(asctime)s] %(levelname)s %(name)s: %(message)s',
|
||||
},
|
||||
},
|
||||
|
||||
'handlers': {
|
||||
'console': {
|
||||
'class': 'logging.StreamHandler',
|
||||
'level': 'DEBUG',
|
||||
'formatter': 'default',
|
||||
'stream': 'ext://sys.stdout',
|
||||
}
|
||||
},
|
||||
|
||||
'root': {
|
||||
'handlers': ['console'],
|
||||
'level': 'DEBUG',
|
||||
},
|
||||
|
||||
'loggers': {
|
||||
'transitions': {
|
||||
'handlers': ['console'],
|
||||
'level': 'WARNING',
|
||||
'propagate': False,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
logging.config.dictConfig(LOGGING_CONFIG)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
@ -631,35 +675,142 @@ async def test_client_subscribe_publish_dollar_topic_2(broker):
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_client_publish_retain_subscribe(broker):
|
||||
async def test_client_publish_clean_session_subscribe(broker):
|
||||
|
||||
sub_client = MQTTClient(client_id='test_client')
|
||||
await sub_client.connect("mqtt://127.0.0.1", cleansession=False)
|
||||
ret = await sub_client.subscribe(
|
||||
[("/qos0", QOS_0), ("/qos1", QOS_1), ("/qos2", QOS_2)],
|
||||
)
|
||||
assert ret == [QOS_0, QOS_1, QOS_2]
|
||||
await sub_client.disconnect()
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
await _client_publish("/qos0", b"data", QOS_0, retain=True)
|
||||
await _client_publish("/qos1", b"data", QOS_1, retain=True)
|
||||
await _client_publish("/qos2", b"data", QOS_2, retain=True)
|
||||
await sub_client.disconnect()
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
|
||||
await _client_publish("/qos0", b"data0", QOS_0) # should not be retained
|
||||
await _client_publish("/qos1", b"data1", QOS_1)
|
||||
await _client_publish("/qos2", b"data2", QOS_2)
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
await sub_client.reconnect(cleansession=False)
|
||||
for qos in [QOS_0, QOS_1, QOS_2]:
|
||||
for qos in [QOS_1, QOS_2]:
|
||||
log.debug(f"TEST QOS: {qos}")
|
||||
message = await sub_client.deliver_message()
|
||||
log.debug(f"Message: {message.publish_packet if message else None!r}")
|
||||
assert message is not None
|
||||
assert message.topic == f"/qos{qos}"
|
||||
assert message.data == b"data"
|
||||
assert message.data == f"data{qos}".encode("utf-8")
|
||||
assert message.qos == qos
|
||||
|
||||
try:
|
||||
while True:
|
||||
message = await sub_client.deliver_message(timeout_duration=1)
|
||||
assert message is not None, "no other messages should have been retained"
|
||||
except TimeoutError:
|
||||
pass
|
||||
|
||||
await sub_client.disconnect()
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_client_publish_retain_with_new_subscribe(broker):
|
||||
await asyncio.sleep(2)
|
||||
sub_client1 = MQTTClient(client_id='test_client1')
|
||||
await sub_client1.connect("mqtt://127.0.0.1")
|
||||
|
||||
await sub_client1.disconnect()
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
await _client_publish("/qos0", b"data0", QOS_0, retain=True)
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
sub_client2 = MQTTClient(client_id='test_client2')
|
||||
await sub_client2.connect("mqtt://127.0.0.1")
|
||||
|
||||
# should receive the retained message on subscription
|
||||
ret = await sub_client2.subscribe(
|
||||
[("/qos0", QOS_0)],
|
||||
)
|
||||
assert ret == [QOS_0]
|
||||
|
||||
message = await sub_client2.deliver_message(timeout_duration=1)
|
||||
assert message is not None
|
||||
assert message.topic == "/qos0"
|
||||
assert message.data == b"data0"
|
||||
assert message.qos == QOS_0
|
||||
await sub_client2.disconnect()
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_client_publish_retain_latest_with_new_subscribe(broker):
|
||||
await asyncio.sleep(2)
|
||||
sub_client1 = MQTTClient(client_id='test_client1')
|
||||
await sub_client1.connect("mqtt://127.0.0.1")
|
||||
|
||||
await sub_client1.disconnect()
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
await _client_publish("/qos0", b"data a", QOS_0, retain=True)
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
sub_client2 = MQTTClient(client_id='test_client2')
|
||||
await sub_client2.connect("mqtt://127.0.0.1")
|
||||
|
||||
await _client_publish("/qos0", b"data b", QOS_0, retain=True)
|
||||
|
||||
# should receive the retained message on subscription
|
||||
ret = await sub_client2.subscribe(
|
||||
[("/qos0", QOS_0)],
|
||||
)
|
||||
assert ret == [QOS_0]
|
||||
|
||||
message = await sub_client2.deliver_message(timeout_duration=1)
|
||||
assert message is not None
|
||||
assert message.topic == "/qos0"
|
||||
assert message.data == b"data b"
|
||||
assert message.qos == QOS_0
|
||||
await sub_client2.disconnect()
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_client_publish_retain_subscribe_on_reconnect(broker):
|
||||
await asyncio.sleep(2)
|
||||
sub_client = MQTTClient(client_id='test_client')
|
||||
await sub_client.connect("mqtt://127.0.0.1", cleansession=False)
|
||||
ret = await sub_client.subscribe(
|
||||
[("/qos0", QOS_0)],
|
||||
)
|
||||
assert ret == [QOS_0]
|
||||
|
||||
await sub_client.disconnect()
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
await _client_publish("/qos0", b"data0", QOS_0, retain=True)
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
await sub_client.reconnect(cleansession=False)
|
||||
|
||||
message = await sub_client.deliver_message(timeout_duration=1)
|
||||
assert message is not None
|
||||
assert message.topic == "/qos0"
|
||||
assert message.data == b"data0"
|
||||
assert message.qos == QOS_0
|
||||
await sub_client.disconnect()
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def _client_publish(topic, data, qos, retain=False) -> int | OutgoingApplicationMessage:
|
||||
pub_client = MQTTClient()
|
||||
|
||||
gen_id = "pub_"
|
||||
valid_chars = string.ascii_letters + string.digits
|
||||
gen_id += "".join(secrets.choice(valid_chars) for _ in range(16))
|
||||
|
||||
pub_client = MQTTClient(client_id=gen_id)
|
||||
ret: int | OutgoingApplicationMessage = await pub_client.connect("mqtt://127.0.0.1/")
|
||||
assert ret == 0
|
||||
ret = await pub_client.publish(topic, data, qos, retain)
|
||||
|
|
Ładowanie…
Reference in New Issue