Merge pull request #248 from ajmirsky/issues/27

improvements in retaining messages
pull/249/head^2
Andrew Mirsky 2025-07-03 11:47:00 -04:00 zatwierdzone przez GitHub
commit 6f724b9a23
Nie znaleziono w bazie danych klucza dla tego podpisu
ID klucza GPG: B5690EEEBB952194
11 zmienionych plików z 212 dodań i 30 usunięć

5
.gitignore vendored
Wyświetl plik

@ -34,3 +34,8 @@ site/
_build/
.hypothesis/
coverage.xml
#----- generated files -----
*.log
*memray*
.coverage*

Wyświetl plik

@ -29,6 +29,7 @@ from amqtt.session import ApplicationMessage, OutgoingApplicationMessage, Sessio
from amqtt.utils import format_client_message, gen_client_id, read_yaml_config
from .events import BrokerEvents
from .mqtt.constants import QOS_0, QOS_1, QOS_2
from .mqtt.disconnect import DisconnectPacket
from .plugins.manager import PluginManager
@ -435,6 +436,7 @@ class Broker:
await self._delete_session(client_session.client_id)
else:
client_session.client_id = gen_client_id()
client_session.parent = 0
# Get session from cache
elif client_session.client_id in self._sessions:
@ -494,9 +496,18 @@ class Broker:
self.logger.debug(f"{client_session.client_id} Start messages handling")
await handler.start()
# publish messages that were retained because the client session was disconnected
self.logger.debug(f"Retained messages queue size: {client_session.retained_messages.qsize()}")
await self._publish_session_retained_messages(client_session)
# if this is not a new session, there are subscriptions associated with them; publish any topic retained messages
self.logger.debug("Publish retained messages to a pre-existing session's subscriptions.")
for topic in self._subscriptions:
await self._publish_retained_messages_for_subscription( (topic, QOS_0), client_session)
await self._client_message_loop(client_session, handler)
async def _client_message_loop(self, client_session: Session, handler: BrokerProtocolHandler) -> None:
@ -878,11 +889,20 @@ class Broker:
qos = broadcast.get("qos", sub_qos)
# Retain all messages which cannot be broadcasted, due to the session not being connected
if target_session.transitions.state != "connected":
# but only when clean session is false and qos is 1 or 2 [MQTT 3.1.2.4]
# and, if a client used anonymous authentication, there is no expectation that messages should be retained
if (target_session.transitions.state != "connected"
and not target_session.clean_session
and qos in (QOS_1, QOS_2)
and not target_session.is_anonymous):
self.logger.debug(f"Session {target_session.client_id} is not connected, retaining message.")
await self._retain_broadcast_message(broadcast, qos, target_session)
continue
# Only broadcast the message to connected clients
if target_session.transitions.state != "connected":
continue
self.logger.debug(
f"Broadcasting message from {format_client_message(session=broadcast['session'])}"
f" on topic '{broadcast['topic']}' to {format_client_message(session=target_session)}",

Wyświetl plik

@ -598,7 +598,7 @@ class MQTTClient:
session.cadata = broker_conf.get("cadata")
if cleansession is not None:
broker_conf["cleansession"] = cleansession
broker_conf["cleansession"] = cleansession # noop?
session.clean_session = cleansession
else:
session.clean_session = self.config.get("cleansession", True)

Wyświetl plik

@ -192,7 +192,7 @@ class ConnectPayload(MQTTPayload[ConnectVariableHeader]):
# A Server MAY allow a Client to supply a ClientId that has a length of zero bytes
# [MQTT-3.1.3-6]
payload.client_id = gen_client_id()
# indicator to trow exception in case CLEAN_SESSION_FLAG is set to False
# indicator to throw exception in case CLEAN_SESSION_FLAG is set to False
payload.client_id_is_random = True
# Read will topic, username and password

Wyświetl plik

@ -26,6 +26,7 @@ class AnonymousAuthPlugin(BaseAuthPlugin):
if self._allow_anonymous:
self.context.logger.debug("Authentication success: config allows anonymous")
session.is_anonymous = True
return True
if session and session.username:

Wyświetl plik

@ -262,6 +262,10 @@ class PluginManager(Generic[C]):
def _schedule_coro(self, coro: Awaitable[str | bool | None]) -> asyncio.Future[str | bool | None]:
return asyncio.ensure_future(coro)
def _clean_fired_events(self, future: asyncio.Future[Any]) -> None:
with contextlib.suppress(KeyError, ValueError):
self._fired_events.remove(future)
async def fire_event(self, event_name: Events, *, wait: bool = False, **method_kwargs: Any) -> None:
"""Fire an event to plugins.
@ -287,12 +291,7 @@ class PluginManager(Generic[C]):
coro_instance: Awaitable[Any] = call_method(event_awaitable, method_kwargs)
tasks.append(asyncio.ensure_future(coro_instance))
def clean_fired_events(future: asyncio.Future[Any]) -> None:
with contextlib.suppress(KeyError, ValueError):
self._fired_events.remove(future)
tasks[-1].add_done_callback(clean_fired_events)
tasks[-1].add_done_callback(self._clean_fired_events)
self._fired_events.extend(tasks)
if wait and tasks:

Wyświetl plik

@ -4,6 +4,7 @@ ping_delay: 1
default_qos: 0
default_retain: false
auto_reconnect: true
cleansession: true
reconnect_max_interval: 10
reconnect_retries: 2
broker:

Wyświetl plik

@ -145,12 +145,15 @@ class Session:
# Used to store incoming ApplicationMessage while publish protocol flows
self.inflight_in: OrderedDict[int, IncomingApplicationMessage] = OrderedDict()
# Stores messages retained for this session
# Stores messages retained for this session (specifically when the client is disconnected)
self.retained_messages: Queue[ApplicationMessage] = Queue()
# Stores PUBLISH messages ID received in order and ready for application process
self.delivered_message_queue: Queue[ApplicationMessage] = Queue()
# identify anonymous client sessions or clients which didn't identify themselves
self.is_anonymous: bool = False
def _init_states(self) -> None:
self.transitions = Machine(states=Session.states, initial="new")
self.transitions.add_transition(

Wyświetl plik

@ -22,8 +22,7 @@ listener.
### `timeout-disconnect-delay` *(int)*
Client disconnect timeout without a keep-alive
Client disconnect timeout without a keep-alive.
### `plugins` *(mapping)*

Wyświetl plik

@ -17,9 +17,9 @@ pytest_plugins = ["pytest_logdog"]
test_config = {
"listeners": {
"default": {"type": "tcp", "bind": "127.0.0.1:1883", "max_connections": 10},
"ws": {"type": "ws", "bind": "127.0.0.1:8080", "max_connections": 10},
"wss": {"type": "ws", "bind": "127.0.0.1:8081", "max_connections": 10},
"default": {"type": "tcp", "bind": "127.0.0.1:1883", "max_connections": 15},
"ws": {"type": "ws", "bind": "127.0.0.1:8080", "max_connections": 15},
"wss": {"type": "ws", "bind": "127.0.0.1:8081", "max_connections": 15},
},
"sys_interval": 0,
"auth": {

Wyświetl plik

@ -1,6 +1,9 @@
import asyncio
import logging
import logging.config
import secrets
import socket
import string
from unittest.mock import MagicMock, call, patch
import psutil
@ -22,8 +25,49 @@ from amqtt.mqtt.pubrec import PubrecPacket
from amqtt.mqtt.pubrel import PubrelPacket
from amqtt.session import OutgoingApplicationMessage
formatter = "[%(asctime)s] %(name)s {%(filename)s:%(lineno)d} %(levelname)s - %(message)s"
logging.basicConfig(level=logging.DEBUG, format=formatter)
# formatter = "[%(asctime)s] %(name)s {%(filename)s:%(lineno)d} %(levelname)s - %(message)s"
# logging.basicConfig(level=logging.DEBUG, format=formatter)
LOGGING_CONFIG = {
'version': 1,
'disable_existing_loggers': False,
'formatters': {
'default': {
'format': '[%(asctime)s] %(levelname)s %(name)s: %(message)s',
},
},
'handlers': {
'console': {
'class': 'logging.StreamHandler',
'level': 'DEBUG',
'formatter': 'default',
'stream': 'ext://sys.stdout',
}
},
'root': {
'handlers': ['console'],
'level': 'DEBUG',
},
'loggers': {
'transitions': {
'handlers': ['console'],
'level': 'WARNING',
'propagate': False,
},
},
}
logging.config.dictConfig(LOGGING_CONFIG)
log = logging.getLogger(__name__)
@ -101,10 +145,11 @@ async def test_connect_tcp(broker):
connections_number = 10
# mqtt 3.1 requires a connect packet, otherwise the socket connection is rejected
static_connect_packet = b'\x10\x1b\x00\x04MQTT\x04\x02\x00<\x00\x0ftest-client-123'
sockets = []
for i in range(connections_number):
static_connect_packet = b'\x10\x1b\x00\x04MQTT\x04\x02\x00<\x00\x0ftest-client-12' + f"{i}".encode()
s = socket.create_connection(("127.0.0.1", 1883))
s.send(static_connect_packet)
sockets.append(s)
@ -122,9 +167,11 @@ async def test_connect_tcp(broker):
tcp_connections = [conn for conn in connections if conn.laddr.port == 1883]
assert len(tcp_connections) == connections_number + 1 # Including the Broker's listening socket
await asyncio.sleep(0.1)
for conn in connections:
assert conn.status in ("ESTABLISHED", "LISTEN")
await asyncio.sleep(0.1)
# close all connections
for s in sockets:
s.close()
@ -626,35 +673,142 @@ async def test_client_subscribe_publish_dollar_topic_2(broker):
@pytest.mark.asyncio
async def test_client_publish_retain_subscribe(broker):
sub_client = MQTTClient()
async def test_client_publish_clean_session_subscribe(broker):
sub_client = MQTTClient(client_id='test_client', config={'auto_reconnect': False})
await sub_client.connect("mqtt://127.0.0.1", cleansession=False)
ret = await sub_client.subscribe(
[("/qos0", QOS_0), ("/qos1", QOS_1), ("/qos2", QOS_2)],
)
assert ret == [QOS_0, QOS_1, QOS_2]
await sub_client.disconnect()
await asyncio.sleep(0.1)
await _client_publish("/qos0", b"data", QOS_0, retain=True)
await _client_publish("/qos1", b"data", QOS_1, retain=True)
await _client_publish("/qos2", b"data", QOS_2, retain=True)
await sub_client.reconnect()
for qos in [QOS_0, QOS_1, QOS_2]:
await sub_client.disconnect()
await asyncio.sleep(0.5)
await _client_publish("/qos0", b"data0", QOS_0) # should not be retained
await _client_publish("/qos1", b"data1", QOS_1)
await _client_publish("/qos2", b"data2", QOS_2)
await asyncio.sleep(2)
await sub_client.reconnect(cleansession=False)
for qos in [QOS_1, QOS_2]:
log.debug(f"TEST QOS: {qos}")
message = await sub_client.deliver_message()
message = await sub_client.deliver_message(timeout_duration=2)
log.debug(f"Message: {message.publish_packet if message else None!r}")
assert message is not None
assert message.topic == f"/qos{qos}"
assert message.data == b"data"
assert message.data == f"data{qos}".encode("utf-8")
assert message.qos == qos
try:
while True:
message = await sub_client.deliver_message(timeout_duration=1)
assert message is not None, "no other messages should have been retained"
except asyncio.TimeoutError:
pass
await sub_client.disconnect()
await asyncio.sleep(0.1)
@pytest.mark.asyncio
async def test_client_publish_retain_with_new_subscribe(broker):
await asyncio.sleep(2)
sub_client1 = MQTTClient(client_id='test_client1')
await sub_client1.connect("mqtt://127.0.0.1")
await sub_client1.disconnect()
await asyncio.sleep(0.5)
await _client_publish("/qos0", b"data0", QOS_0, retain=True)
await asyncio.sleep(0.5)
sub_client2 = MQTTClient(client_id='test_client2')
await sub_client2.connect("mqtt://127.0.0.1")
# should receive the retained message on subscription
ret = await sub_client2.subscribe(
[("/qos0", QOS_0)],
)
assert ret == [QOS_0]
message = await sub_client2.deliver_message(timeout_duration=1)
assert message is not None
assert message.topic == "/qos0"
assert message.data == b"data0"
assert message.qos == QOS_0
await sub_client2.disconnect()
await asyncio.sleep(0.1)
@pytest.mark.asyncio
async def test_client_publish_retain_latest_with_new_subscribe(broker):
await asyncio.sleep(2)
sub_client1 = MQTTClient(client_id='test_client1')
await sub_client1.connect("mqtt://127.0.0.1")
await sub_client1.disconnect()
await asyncio.sleep(0.5)
await _client_publish("/qos0", b"data a", QOS_0, retain=True)
await asyncio.sleep(0.5)
sub_client2 = MQTTClient(client_id='test_client2')
await sub_client2.connect("mqtt://127.0.0.1")
await _client_publish("/qos0", b"data b", QOS_0, retain=True)
# should receive the retained message on subscription
ret = await sub_client2.subscribe(
[("/qos0", QOS_0)],
)
assert ret == [QOS_0]
message = await sub_client2.deliver_message(timeout_duration=1)
assert message is not None
assert message.topic == "/qos0"
assert message.data == b"data b"
assert message.qos == QOS_0
await sub_client2.disconnect()
await asyncio.sleep(0.1)
@pytest.mark.asyncio
async def test_client_publish_retain_subscribe_on_reconnect(broker):
await asyncio.sleep(2)
sub_client = MQTTClient(client_id='test_client')
await sub_client.connect("mqtt://127.0.0.1", cleansession=False)
ret = await sub_client.subscribe(
[("/qos0", QOS_0)],
)
assert ret == [QOS_0]
await sub_client.disconnect()
await asyncio.sleep(0.5)
await _client_publish("/qos0", b"data0", QOS_0, retain=True)
await asyncio.sleep(0.5)
await sub_client.reconnect(cleansession=False)
message = await sub_client.deliver_message(timeout_duration=1)
assert message is not None
assert message.topic == "/qos0"
assert message.data == b"data0"
assert message.qos == QOS_0
await sub_client.disconnect()
await asyncio.sleep(0.1)
@pytest.mark.asyncio
async def _client_publish(topic, data, qos, retain=False) -> int | OutgoingApplicationMessage:
pub_client = MQTTClient()
gen_id = "pub_"
valid_chars = string.ascii_letters + string.digits
gen_id += "".join(secrets.choice(valid_chars) for _ in range(16))
pub_client = MQTTClient(client_id=gen_id)
ret: int | OutgoingApplicationMessage = await pub_client.connect("mqtt://127.0.0.1/")
assert ret == 0
ret = await pub_client.publish(topic, data, qos, retain)