Merge remote-tracking branch 'source/0.11.2-rc' into session_persistence

pull/256/head
Andrew Mirsky 2025-07-07 12:10:00 -04:00
commit 3fa54ab7a5
Nie znaleziono w bazie danych klucza dla tego podpisu
ID klucza GPG: A98E67635CDF2C39
13 zmienionych plików z 372 dodań i 13 usunięć

Wyświetl plik

@ -183,6 +183,8 @@ class Broker:
self._subscriptions: dict[str, list[tuple[Session, int]]] = {} self._subscriptions: dict[str, list[tuple[Session, int]]] = {}
self._retained_messages: dict[str, RetainedApplicationMessage] = {} self._retained_messages: dict[str, RetainedApplicationMessage] = {}
self._topic_filter_matchers: dict[str, re.Pattern[str]] = {}
# Broadcast queue for outgoing messages # Broadcast queue for outgoing messages
self._broadcast_queue: asyncio.Queue[dict[str, Any]] = asyncio.Queue() self._broadcast_queue: asyncio.Queue[dict[str, Any]] = asyncio.Queue()
self._broadcast_task: asyncio.Task[Any] | None = None self._broadcast_task: asyncio.Task[Any] | None = None
@ -690,6 +692,11 @@ class Broker:
f"[MQTT-3.3.2-2] - {client_session.client_id} invalid TOPIC sent in PUBLISH message, closing connection", f"[MQTT-3.3.2-2] - {client_session.client_id} invalid TOPIC sent in PUBLISH message, closing connection",
) )
return False return False
if app_message.topic.startswith("$"):
self.logger.warning(
f"[MQTT-4.7.2-1] - {client_session.client_id} cannot use a topic with a leading $ character."
)
return False
permitted = await self._topic_filtering(client_session, topic=app_message.topic, action=Action.PUBLISH) permitted = await self._topic_filtering(client_session, topic=app_message.topic, action=Action.PUBLISH)
if not permitted: if not permitted:
@ -908,9 +915,6 @@ class Broker:
self.logger.debug(f"Processing broadcast message: {broadcast}") self.logger.debug(f"Processing broadcast message: {broadcast}")
for k_filter, subscriptions in self._subscriptions.items(): for k_filter, subscriptions in self._subscriptions.items():
if broadcast["topic"].startswith("$") and (k_filter.startswith(("+", "#"))):
self.logger.debug("[MQTT-4.7.2-1] - ignoring broadcasting $ topic to subscriptions starting with + or #")
continue
# Skip all subscriptions which do not match the topic # Skip all subscriptions which do not match the topic
if not self._matches(broadcast["topic"], k_filter): if not self._matches(broadcast["topic"], k_filter):
@ -1039,11 +1043,21 @@ class Broker:
) )
def _matches(self, topic: str, a_filter: str) -> bool: def _matches(self, topic: str, a_filter: str) -> bool:
if topic.startswith("$") and (a_filter.startswith(("+", "#"))):
self.logger.debug("[MQTT-4.7.2-1] - ignoring broadcasting $ topic to subscriptions starting with + or #")
return False
if "#" not in a_filter and "+" not in a_filter: if "#" not in a_filter and "+" not in a_filter:
# if filter doesn't contain wildcard, return exact match # if filter doesn't contain wildcard, return exact match
return a_filter == topic return a_filter == topic
# else use regex
match_pattern = re.compile(re.escape(a_filter).replace("\\#", "?.*").replace("\\+", "[^/]*").lstrip("?")) # else use regex (re.compile is an expensive operation, store the matcher for future use)
if a_filter not in self._topic_filter_matchers:
self._topic_filter_matchers[a_filter] = re.compile(re.escape(a_filter)
.replace("\\#", "?.*")
.replace("\\+", "[^/]*")
.lstrip("?"))
match_pattern = self._topic_filter_matchers[a_filter]
return bool(match_pattern.fullmatch(topic)) return bool(match_pattern.fullmatch(topic))
def _get_handler(self, session: Session) -> BrokerProtocolHandler | None: def _get_handler(self, session: Session) -> BrokerProtocolHandler | None:

Wyświetl plik

@ -520,8 +520,9 @@ class ProtocolHandler(Generic[C]):
elif packet.fixed_header.packet_type == DISCONNECT and isinstance(packet, DisconnectPacket): elif packet.fixed_header.packet_type == DISCONNECT and isinstance(packet, DisconnectPacket):
task = asyncio.create_task(self.handle_disconnect(packet)) task = asyncio.create_task(self.handle_disconnect(packet))
elif packet.fixed_header.packet_type == CONNECT and isinstance(packet, ConnectPacket): elif packet.fixed_header.packet_type == CONNECT and isinstance(packet, ConnectPacket):
# TODO: why is this not like all other inside create_task? # q: why is this not like all other inside a create_task?
await self.handle_connect(packet) # task = asyncio.create_task(self.handle_connect(packet)) # a: the connection needs to be established before any other packet tasks for this new session are scheduled
await self.handle_connect(packet)
if task: if task:
running_tasks.append(task) running_tasks.append(task)
except MQTTError: except MQTTError:

Wyświetl plik

@ -12,7 +12,7 @@ class PublishVariableHeader(MQTTVariableHeader):
def __init__(self, topic_name: str, packet_id: int | None = None) -> None: def __init__(self, topic_name: str, packet_id: int | None = None) -> None:
super().__init__() super().__init__()
if "*" in topic_name: if "#" in topic_name or "+" in topic_name:
msg = "[MQTT-3.3.2-2] Topic name in the PUBLISH Packet MUST NOT contain wildcard characters." msg = "[MQTT-3.3.2-2] Topic name in the PUBLISH Packet MUST NOT contain wildcard characters."
raise MQTTError(msg) raise MQTTError(msg)
self.topic_name = topic_name self.topic_name = topic_name

Wyświetl plik

@ -0,0 +1,25 @@
import pytest
from amqtt.errors import AMQTTError
from amqtt.mqtt.connack import ConnackPacket
from amqtt.mqtt.packet import MQTTFixedHeader, PUBLISH
def test_incorrect_fixed_header():
header = MQTTFixedHeader(PUBLISH, 0x00)
with pytest.raises(AMQTTError):
_ = ConnackPacket(fixed=header)
@pytest.mark.parametrize("prop", [
"return_code",
"session_parent"
])
def test_empty_variable_header(prop):
packet = ConnackPacket()
with pytest.raises(ValueError):
assert getattr(packet, prop) is not None
with pytest.raises(ValueError):
assert setattr(packet, prop, "a value")

Wyświetl plik

@ -1,9 +1,11 @@
import asyncio import asyncio
import unittest import unittest
import pytest
from amqtt.adapters import BufferReader from amqtt.adapters import BufferReader
from amqtt.errors import AMQTTError
from amqtt.mqtt.connect import ConnectPacket, ConnectPayload, ConnectVariableHeader from amqtt.mqtt.connect import ConnectPacket, ConnectPayload, ConnectVariableHeader
from amqtt.mqtt.packet import CONNECT, MQTTFixedHeader from amqtt.mqtt.packet import CONNECT, MQTTFixedHeader, PUBLISH
class ConnectPacketTest(unittest.TestCase): class ConnectPacketTest(unittest.TestCase):
@ -150,3 +152,36 @@ class ConnectPacketTest(unittest.TestCase):
assert message.username == "user" assert message.username == "user"
assert message.payload.password == "password" assert message.payload.password == "password"
assert message.password == "password" assert message.password == "password"
def test_incorrect_fixed_header():
header = MQTTFixedHeader(PUBLISH, 0x00)
with pytest.raises(AMQTTError):
_ = ConnectPacket(fixed=header)
@pytest.mark.parametrize("prop", [
"proto_name",
"proto_level",
"username_flag",
"password_flag",
"clean_session_flag",
"will_retain_flag",
"will_qos",
"will_flag",
"reserved_flag",
"client_id",
"client_id_is_random",
"will_topic",
"will_message",
"username",
"password",
"keep_alive",
])
def test_empty_variable_header(prop):
packet = ConnectPacket()
with pytest.raises(ValueError):
assert getattr(packet, prop) is not None
with pytest.raises(ValueError):
assert setattr(packet, prop, "a value")

Wyświetl plik

@ -1,7 +1,10 @@
import asyncio import asyncio
import unittest import unittest
import pytest
from amqtt.adapters import BufferReader from amqtt.adapters import BufferReader
from amqtt.errors import AMQTTError
from amqtt.mqtt import PUBLISH, MQTTFixedHeader
from amqtt.mqtt.puback import PacketIdVariableHeader, PubackPacket from amqtt.mqtt.puback import PacketIdVariableHeader, PubackPacket
@ -20,3 +23,22 @@ class PubackPacketTest(unittest.TestCase):
publish = PubackPacket(variable_header=variable_header) publish = PubackPacket(variable_header=variable_header)
out = publish.to_bytes() out = publish.to_bytes()
assert out == b"@\x02\x00\n" assert out == b"@\x02\x00\n"
def test_incorrect_fixed_header():
header = MQTTFixedHeader(PUBLISH, 0x00)
with pytest.raises(AMQTTError):
connect_packet = PubackPacket(fixed=header)
@pytest.mark.parametrize("prop", [
"packet_id",
])
def test_empty_variable_header(prop):
connect_packet = PubackPacket()
with pytest.raises(ValueError):
assert getattr(connect_packet, prop) is not None
with pytest.raises(ValueError):
assert setattr(connect_packet, prop, "a value")

Wyświetl plik

@ -1,7 +1,10 @@
import asyncio import asyncio
import unittest import unittest
import pytest
from amqtt.adapters import BufferReader from amqtt.adapters import BufferReader
from amqtt.errors import AMQTTError
from amqtt.mqtt import MQTTFixedHeader, PUBLISH
from amqtt.mqtt.pubcomp import PacketIdVariableHeader, PubcompPacket from amqtt.mqtt.pubcomp import PacketIdVariableHeader, PubcompPacket
@ -20,3 +23,21 @@ class PubcompPacketTest(unittest.TestCase):
publish = PubcompPacket(variable_header=variable_header) publish = PubcompPacket(variable_header=variable_header)
out = publish.to_bytes() out = publish.to_bytes()
assert out == b"p\x02\x00\n" assert out == b"p\x02\x00\n"
def test_incorrect_fixed_header():
header = MQTTFixedHeader(PUBLISH, 0x00)
with pytest.raises(AMQTTError):
_ = PubcompPacket(fixed=header)
@pytest.mark.parametrize("prop", [
"packet_id"
])
def test_empty_variable_header(prop):
packet = PubcompPacket()
with pytest.raises(ValueError):
assert getattr(packet, prop) is not None
with pytest.raises(ValueError):
assert setattr(packet, prop, "a value")

Wyświetl plik

@ -1,7 +1,10 @@
import asyncio import asyncio
import unittest import unittest
import pytest
from amqtt.adapters import BufferReader from amqtt.adapters import BufferReader
from amqtt.errors import AMQTTError
from amqtt.mqtt.packet import MQTTFixedHeader, CONNECT
from amqtt.mqtt.constants import QOS_0, QOS_1, QOS_2 from amqtt.mqtt.constants import QOS_0, QOS_1, QOS_2
from amqtt.mqtt.publish import PublishPacket, PublishPayload, PublishVariableHeader from amqtt.mqtt.publish import PublishPacket, PublishPayload, PublishVariableHeader
@ -116,3 +119,28 @@ class PublishPacketTest(unittest.TestCase):
assert packet.dup_flag assert packet.dup_flag
assert packet.qos == QOS_2 assert packet.qos == QOS_2
assert packet.retain_flag assert packet.retain_flag
def test_incorrect_fixed_header():
header = MQTTFixedHeader(CONNECT, 0x00)
with pytest.raises(AMQTTError):
_ = PublishPacket(fixed=header)
def test_set_flags():
packet = PublishPacket()
packet.set_flags(dup_flag=True, qos=QOS_1, retain_flag=True)
@pytest.mark.parametrize("prop", [
"packet_id",
"data",
"topic_name"
])
def test_empty_variable_header(prop):
packet = PublishPacket()
with pytest.raises(ValueError):
assert getattr(packet, prop) is not None
with pytest.raises(ValueError):
assert setattr(packet, prop, "a value")

Wyświetl plik

@ -1,7 +1,10 @@
import asyncio import asyncio
import unittest import unittest
import pytest
from amqtt.adapters import BufferReader from amqtt.adapters import BufferReader
from amqtt.errors import AMQTTError
from amqtt.mqtt.packet import MQTTFixedHeader, PUBLISH
from amqtt.mqtt.pubrec import PacketIdVariableHeader, PubrecPacket from amqtt.mqtt.pubrec import PacketIdVariableHeader, PubrecPacket
@ -20,3 +23,22 @@ class PubrecPacketTest(unittest.TestCase):
publish = PubrecPacket(variable_header=variable_header) publish = PubrecPacket(variable_header=variable_header)
out = publish.to_bytes() out = publish.to_bytes()
assert out == b"P\x02\x00\n" assert out == b"P\x02\x00\n"
def test_incorrect_fixed_header():
header = MQTTFixedHeader(PUBLISH, 0x00)
with pytest.raises(AMQTTError):
_ = PubrecPacket(fixed=header)
@pytest.mark.parametrize("prop", [
"packet_id"
])
def test_empty_variable_header(prop):
packet = PubrecPacket()
with pytest.raises(ValueError):
assert getattr(packet, prop) is not None
with pytest.raises(ValueError):
assert setattr(packet, prop, "a value")

Wyświetl plik

@ -1,7 +1,10 @@
import asyncio import asyncio
import unittest import unittest
import pytest
from amqtt.adapters import BufferReader from amqtt.adapters import BufferReader
from amqtt.errors import AMQTTError
from amqtt.mqtt.packet import MQTTFixedHeader, PUBLISH
from amqtt.mqtt.pubrel import PacketIdVariableHeader, PubrelPacket from amqtt.mqtt.pubrel import PacketIdVariableHeader, PubrelPacket
@ -20,3 +23,22 @@ class PubrelPacketTest(unittest.TestCase):
publish = PubrelPacket(variable_header=variable_header) publish = PubrelPacket(variable_header=variable_header)
out = publish.to_bytes() out = publish.to_bytes()
assert out == b"b\x02\x00\n" assert out == b"b\x02\x00\n"
def test_incorrect_fixed_header():
header = MQTTFixedHeader(PUBLISH, 0x00)
with pytest.raises(AMQTTError):
_ = PubrelPacket(fixed=header)
@pytest.mark.parametrize("prop", [
"packet_id"
])
def test_empty_variable_header(prop):
packet = PubrelPacket()
with pytest.raises(ValueError):
assert getattr(packet, prop) is not None
with pytest.raises(ValueError):
assert setattr(packet, prop, "a value")

Wyświetl plik

@ -510,17 +510,62 @@ async def test_client_publish_dup(broker):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_client_publish_invalid_topic(broker): async def test_client_publishing_invalid_topic(broker):
assert broker.transitions.is_started() assert broker.transitions.is_started()
pub_client = MQTTClient() pub_client = MQTTClient(config={'auto_reconnect': False})
ret = await pub_client.connect("mqtt://127.0.0.1/") ret = await pub_client.connect("mqtt://127.0.0.1/")
assert ret == 0 assert ret == 0
await pub_client.publish("/+", b"data", QOS_0) await pub_client.subscribe([
("my/+/topic", QOS_0)
])
await asyncio.sleep(0.5)
# need to build & send packet directly to bypass client's check of invalid topic name
# see test_client.py::test_publish_to_incorrect_wildcard for client checks
packet = PublishPacket.build(topic_name='my/topic', message=b'messages',
packet_id=None, dup_flag=False, qos=QOS_0, retain=False)
packet.topic_name = "my/+/topic"
await pub_client._handler._send_packet(packet)
await asyncio.sleep(0.5)
with pytest.raises(asyncio.TimeoutError):
msg = await pub_client.deliver_message(timeout_duration=1)
assert msg is None
await asyncio.sleep(0.1) await asyncio.sleep(0.1)
await pub_client.disconnect() await pub_client.disconnect()
@pytest.mark.asyncio
async def test_client_publish_asterisk(broker):
"""'*' is a valid, non-wildcard character for MQTT."""
assert broker.transitions.is_started()
pub_client = MQTTClient(config={'auto_reconnect': False})
ret = await pub_client.connect("mqtt://127.0.0.1/")
assert ret == 0
await pub_client.subscribe([
("my*/topic", QOS_0),
("my/+/topic", QOS_0)
])
await asyncio.sleep(0.1)
await pub_client.publish('my*/topic', b'my valid message', QOS_0, retain=False)
await asyncio.sleep(0.1)
msg = await pub_client.deliver_message(timeout_duration=1)
assert msg is not None
assert msg.topic == "my*/topic"
assert msg.data == b'my valid message'
await asyncio.sleep(0.1)
msg = await pub_client.publish('my/****/topic', b'my valid message', QOS_0, retain=False)
assert msg is not None
assert msg.topic == "my/****/topic"
assert msg.data == b'my valid message'
await pub_client.disconnect()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_client_publish_big(broker, mock_plugin_manager): async def test_client_publish_big(broker, mock_plugin_manager):
pub_client = MQTTClient() pub_client = MQTTClient()

Wyświetl plik

@ -7,7 +7,7 @@ import pytest
from amqtt.broker import Broker from amqtt.broker import Broker
from amqtt.client import MQTTClient from amqtt.client import MQTTClient
from amqtt.errors import ClientError, ConnectError from amqtt.errors import ClientError, ConnectError, MQTTError
from amqtt.mqtt.constants import QOS_0, QOS_1, QOS_2 from amqtt.mqtt.constants import QOS_0, QOS_1, QOS_2
formatter = "[%(asctime)s] %(name)s {%(filename)s:%(lineno)d} %(levelname)s - %(message)s" formatter = "[%(asctime)s] %(name)s {%(filename)s:%(lineno)d} %(levelname)s - %(message)s"
@ -490,3 +490,17 @@ async def test_client_no_auth():
await client.connect("mqtt://127.0.0.1:1883/") await client.connect("mqtt://127.0.0.1:1883/")
await broker.shutdown() await broker.shutdown()
@pytest.mark.asyncio
async def test_publish_to_incorrect_wildcard(broker_fixture):
client = MQTTClient(config={'auto_reconnect': False})
await client.connect("mqtt://127.0.0.1/")
with pytest.raises(MQTTError):
await client.publish("my/+/topic", b'plus-sign wildcard topic invalid publish')
with pytest.raises(MQTTError):
await client.publish("topic/#", b'hash wildcard topic invalid publish')
await client.publish("topic/*", b'asterisk topic normal publish')
await client.disconnect()

Wyświetl plik

@ -0,0 +1,110 @@
import asyncio
import logging
import pytest
from amqtt.broker import Broker
from amqtt.client import MQTTClient
from amqtt.mqtt.constants import QOS_0
logger = logging.getLogger(__name__)
@pytest.mark.asyncio
async def test_publish_to_dollar_sign_topics():
"""Applications cannot use a topic with a leading $ character for their own purposes [MQTT-4.7.2-1]."""
cfg = {
'listeners': {'default': {'type': 'tcp', 'bind': '127.0.0.1'}},
'plugins': {'amqtt.plugins.authentication.AnonymousAuthPlugin': {"allow_anonymous": True}},
}
b = Broker(config=cfg)
await b.start()
await asyncio.sleep(0.1)
c = MQTTClient(config={'auto_reconnect': False})
await c.connect()
await asyncio.sleep(0.1)
await c.subscribe(
[('$#', QOS_0),
('#', QOS_0)]
)
await asyncio.sleep(0.1)
await c.publish('$MY', b'message should be blocked')
await asyncio.sleep(0.1)
with pytest.raises(asyncio.TimeoutError):
# wait long enough for broker sys plugin to run
_ = await c.deliver_message(timeout_duration=1)
await c.disconnect()
await asyncio.sleep(0.1)
await b.shutdown()
@pytest.mark.asyncio
async def test_hash_will_not_receive_dollar():
"""A subscription to “#” will not receive any messages published to a topic beginning with a $ [MQTT-4.7.2-1]."""
cfg = {
'listeners': {'default': {'type': 'tcp', 'bind': '127.0.0.1'}},
'plugins': {
'amqtt.plugins.authentication.AnonymousAuthPlugin': {"allow_anonymous": True},
'amqtt.plugins.sys.broker.BrokerSysPlugin': {"sys_interval": 2}
}
}
b = Broker(config=cfg)
await b.start()
await asyncio.sleep(0.1)
c = MQTTClient(config={'auto_reconnect': False})
await c.connect()
await asyncio.sleep(0.1)
await c.subscribe(
[('#', QOS_0)]
)
await asyncio.sleep(0.1)
with pytest.raises(asyncio.TimeoutError):
# wait long enough for broker sys plugin to run
_ = await c.deliver_message(timeout_duration=5)
await c.disconnect()
await asyncio.sleep(0.1)
await b.shutdown()
@pytest.mark.asyncio
async def test_plus_will_not_receive_dollar():
"""A subscription to “+/monitor/Clients” will not receive any messages published to “$SYS/monitor/Clients [MQTT-4.7.2-1]"""
# BrokerSysPlugin doesn't use $SYS/monitor/Clients, so this is an equivalent test with $SYS/broker topics
cfg = {
'listeners': {'default': {'type': 'tcp', 'bind': '127.0.0.1'}},
'plugins': {
'amqtt.plugins.authentication.AnonymousAuthPlugin': {"allow_anonymous": True},
'amqtt.plugins.sys.broker.BrokerSysPlugin': {"sys_interval": 2}
}
}
b = Broker(config=cfg)
await b.start()
await asyncio.sleep(0.1)
c = MQTTClient(config={'auto_reconnect': False})
await c.connect()
await asyncio.sleep(0.1)
await c.subscribe(
[('+/broker/#', QOS_0),
('+/broker/time', QOS_0),
('+/broker/clients/#', QOS_0),
('+/broker/+/maximum', QOS_0)
]
)
await asyncio.sleep(0.1)
with pytest.raises(asyncio.TimeoutError):
# wait long enough for broker sys plugin to run
_ = await c.deliver_message(timeout_duration=5)
await c.disconnect()
await asyncio.sleep(0.1)
await b.shutdown()