Merge remote-tracking branch 'source/0.11.2-rc' into session_persistence

pull/256/head
Andrew Mirsky 2025-07-07 12:10:00 -04:00
commit 3fa54ab7a5
Nie znaleziono w bazie danych klucza dla tego podpisu
ID klucza GPG: A98E67635CDF2C39
13 zmienionych plików z 372 dodań i 13 usunięć

Wyświetl plik

@ -183,6 +183,8 @@ class Broker:
self._subscriptions: dict[str, list[tuple[Session, int]]] = {}
self._retained_messages: dict[str, RetainedApplicationMessage] = {}
self._topic_filter_matchers: dict[str, re.Pattern[str]] = {}
# Broadcast queue for outgoing messages
self._broadcast_queue: asyncio.Queue[dict[str, Any]] = asyncio.Queue()
self._broadcast_task: asyncio.Task[Any] | None = None
@ -690,6 +692,11 @@ class Broker:
f"[MQTT-3.3.2-2] - {client_session.client_id} invalid TOPIC sent in PUBLISH message, closing connection",
)
return False
if app_message.topic.startswith("$"):
self.logger.warning(
f"[MQTT-4.7.2-1] - {client_session.client_id} cannot use a topic with a leading $ character."
)
return False
permitted = await self._topic_filtering(client_session, topic=app_message.topic, action=Action.PUBLISH)
if not permitted:
@ -908,9 +915,6 @@ class Broker:
self.logger.debug(f"Processing broadcast message: {broadcast}")
for k_filter, subscriptions in self._subscriptions.items():
if broadcast["topic"].startswith("$") and (k_filter.startswith(("+", "#"))):
self.logger.debug("[MQTT-4.7.2-1] - ignoring broadcasting $ topic to subscriptions starting with + or #")
continue
# Skip all subscriptions which do not match the topic
if not self._matches(broadcast["topic"], k_filter):
@ -1039,11 +1043,21 @@ class Broker:
)
def _matches(self, topic: str, a_filter: str) -> bool:
if topic.startswith("$") and (a_filter.startswith(("+", "#"))):
self.logger.debug("[MQTT-4.7.2-1] - ignoring broadcasting $ topic to subscriptions starting with + or #")
return False
if "#" not in a_filter and "+" not in a_filter:
# if filter doesn't contain wildcard, return exact match
return a_filter == topic
# else use regex
match_pattern = re.compile(re.escape(a_filter).replace("\\#", "?.*").replace("\\+", "[^/]*").lstrip("?"))
# else use regex (re.compile is an expensive operation, store the matcher for future use)
if a_filter not in self._topic_filter_matchers:
self._topic_filter_matchers[a_filter] = re.compile(re.escape(a_filter)
.replace("\\#", "?.*")
.replace("\\+", "[^/]*")
.lstrip("?"))
match_pattern = self._topic_filter_matchers[a_filter]
return bool(match_pattern.fullmatch(topic))
def _get_handler(self, session: Session) -> BrokerProtocolHandler | None:

Wyświetl plik

@ -520,8 +520,9 @@ class ProtocolHandler(Generic[C]):
elif packet.fixed_header.packet_type == DISCONNECT and isinstance(packet, DisconnectPacket):
task = asyncio.create_task(self.handle_disconnect(packet))
elif packet.fixed_header.packet_type == CONNECT and isinstance(packet, ConnectPacket):
# TODO: why is this not like all other inside create_task?
await self.handle_connect(packet) # task = asyncio.create_task(self.handle_connect(packet))
# q: why is this not like all other inside a create_task?
# a: the connection needs to be established before any other packet tasks for this new session are scheduled
await self.handle_connect(packet)
if task:
running_tasks.append(task)
except MQTTError:

Wyświetl plik

@ -12,7 +12,7 @@ class PublishVariableHeader(MQTTVariableHeader):
def __init__(self, topic_name: str, packet_id: int | None = None) -> None:
super().__init__()
if "*" in topic_name:
if "#" in topic_name or "+" in topic_name:
msg = "[MQTT-3.3.2-2] Topic name in the PUBLISH Packet MUST NOT contain wildcard characters."
raise MQTTError(msg)
self.topic_name = topic_name

Wyświetl plik

@ -0,0 +1,25 @@
import pytest
from amqtt.errors import AMQTTError
from amqtt.mqtt.connack import ConnackPacket
from amqtt.mqtt.packet import MQTTFixedHeader, PUBLISH
def test_incorrect_fixed_header():
header = MQTTFixedHeader(PUBLISH, 0x00)
with pytest.raises(AMQTTError):
_ = ConnackPacket(fixed=header)
@pytest.mark.parametrize("prop", [
"return_code",
"session_parent"
])
def test_empty_variable_header(prop):
packet = ConnackPacket()
with pytest.raises(ValueError):
assert getattr(packet, prop) is not None
with pytest.raises(ValueError):
assert setattr(packet, prop, "a value")

Wyświetl plik

@ -1,9 +1,11 @@
import asyncio
import unittest
import pytest
from amqtt.adapters import BufferReader
from amqtt.errors import AMQTTError
from amqtt.mqtt.connect import ConnectPacket, ConnectPayload, ConnectVariableHeader
from amqtt.mqtt.packet import CONNECT, MQTTFixedHeader
from amqtt.mqtt.packet import CONNECT, MQTTFixedHeader, PUBLISH
class ConnectPacketTest(unittest.TestCase):
@ -150,3 +152,36 @@ class ConnectPacketTest(unittest.TestCase):
assert message.username == "user"
assert message.payload.password == "password"
assert message.password == "password"
def test_incorrect_fixed_header():
header = MQTTFixedHeader(PUBLISH, 0x00)
with pytest.raises(AMQTTError):
_ = ConnectPacket(fixed=header)
@pytest.mark.parametrize("prop", [
"proto_name",
"proto_level",
"username_flag",
"password_flag",
"clean_session_flag",
"will_retain_flag",
"will_qos",
"will_flag",
"reserved_flag",
"client_id",
"client_id_is_random",
"will_topic",
"will_message",
"username",
"password",
"keep_alive",
])
def test_empty_variable_header(prop):
packet = ConnectPacket()
with pytest.raises(ValueError):
assert getattr(packet, prop) is not None
with pytest.raises(ValueError):
assert setattr(packet, prop, "a value")

Wyświetl plik

@ -1,7 +1,10 @@
import asyncio
import unittest
import pytest
from amqtt.adapters import BufferReader
from amqtt.errors import AMQTTError
from amqtt.mqtt import PUBLISH, MQTTFixedHeader
from amqtt.mqtt.puback import PacketIdVariableHeader, PubackPacket
@ -20,3 +23,22 @@ class PubackPacketTest(unittest.TestCase):
publish = PubackPacket(variable_header=variable_header)
out = publish.to_bytes()
assert out == b"@\x02\x00\n"
def test_incorrect_fixed_header():
header = MQTTFixedHeader(PUBLISH, 0x00)
with pytest.raises(AMQTTError):
connect_packet = PubackPacket(fixed=header)
@pytest.mark.parametrize("prop", [
"packet_id",
])
def test_empty_variable_header(prop):
connect_packet = PubackPacket()
with pytest.raises(ValueError):
assert getattr(connect_packet, prop) is not None
with pytest.raises(ValueError):
assert setattr(connect_packet, prop, "a value")

Wyświetl plik

@ -1,7 +1,10 @@
import asyncio
import unittest
import pytest
from amqtt.adapters import BufferReader
from amqtt.errors import AMQTTError
from amqtt.mqtt import MQTTFixedHeader, PUBLISH
from amqtt.mqtt.pubcomp import PacketIdVariableHeader, PubcompPacket
@ -20,3 +23,21 @@ class PubcompPacketTest(unittest.TestCase):
publish = PubcompPacket(variable_header=variable_header)
out = publish.to_bytes()
assert out == b"p\x02\x00\n"
def test_incorrect_fixed_header():
header = MQTTFixedHeader(PUBLISH, 0x00)
with pytest.raises(AMQTTError):
_ = PubcompPacket(fixed=header)
@pytest.mark.parametrize("prop", [
"packet_id"
])
def test_empty_variable_header(prop):
packet = PubcompPacket()
with pytest.raises(ValueError):
assert getattr(packet, prop) is not None
with pytest.raises(ValueError):
assert setattr(packet, prop, "a value")

Wyświetl plik

@ -1,7 +1,10 @@
import asyncio
import unittest
import pytest
from amqtt.adapters import BufferReader
from amqtt.errors import AMQTTError
from amqtt.mqtt.packet import MQTTFixedHeader, CONNECT
from amqtt.mqtt.constants import QOS_0, QOS_1, QOS_2
from amqtt.mqtt.publish import PublishPacket, PublishPayload, PublishVariableHeader
@ -116,3 +119,28 @@ class PublishPacketTest(unittest.TestCase):
assert packet.dup_flag
assert packet.qos == QOS_2
assert packet.retain_flag
def test_incorrect_fixed_header():
header = MQTTFixedHeader(CONNECT, 0x00)
with pytest.raises(AMQTTError):
_ = PublishPacket(fixed=header)
def test_set_flags():
packet = PublishPacket()
packet.set_flags(dup_flag=True, qos=QOS_1, retain_flag=True)
@pytest.mark.parametrize("prop", [
"packet_id",
"data",
"topic_name"
])
def test_empty_variable_header(prop):
packet = PublishPacket()
with pytest.raises(ValueError):
assert getattr(packet, prop) is not None
with pytest.raises(ValueError):
assert setattr(packet, prop, "a value")

Wyświetl plik

@ -1,7 +1,10 @@
import asyncio
import unittest
import pytest
from amqtt.adapters import BufferReader
from amqtt.errors import AMQTTError
from amqtt.mqtt.packet import MQTTFixedHeader, PUBLISH
from amqtt.mqtt.pubrec import PacketIdVariableHeader, PubrecPacket
@ -20,3 +23,22 @@ class PubrecPacketTest(unittest.TestCase):
publish = PubrecPacket(variable_header=variable_header)
out = publish.to_bytes()
assert out == b"P\x02\x00\n"
def test_incorrect_fixed_header():
header = MQTTFixedHeader(PUBLISH, 0x00)
with pytest.raises(AMQTTError):
_ = PubrecPacket(fixed=header)
@pytest.mark.parametrize("prop", [
"packet_id"
])
def test_empty_variable_header(prop):
packet = PubrecPacket()
with pytest.raises(ValueError):
assert getattr(packet, prop) is not None
with pytest.raises(ValueError):
assert setattr(packet, prop, "a value")

Wyświetl plik

@ -1,7 +1,10 @@
import asyncio
import unittest
import pytest
from amqtt.adapters import BufferReader
from amqtt.errors import AMQTTError
from amqtt.mqtt.packet import MQTTFixedHeader, PUBLISH
from amqtt.mqtt.pubrel import PacketIdVariableHeader, PubrelPacket
@ -20,3 +23,22 @@ class PubrelPacketTest(unittest.TestCase):
publish = PubrelPacket(variable_header=variable_header)
out = publish.to_bytes()
assert out == b"b\x02\x00\n"
def test_incorrect_fixed_header():
header = MQTTFixedHeader(PUBLISH, 0x00)
with pytest.raises(AMQTTError):
_ = PubrelPacket(fixed=header)
@pytest.mark.parametrize("prop", [
"packet_id"
])
def test_empty_variable_header(prop):
packet = PubrelPacket()
with pytest.raises(ValueError):
assert getattr(packet, prop) is not None
with pytest.raises(ValueError):
assert setattr(packet, prop, "a value")

Wyświetl plik

@ -510,17 +510,62 @@ async def test_client_publish_dup(broker):
@pytest.mark.asyncio
async def test_client_publish_invalid_topic(broker):
async def test_client_publishing_invalid_topic(broker):
assert broker.transitions.is_started()
pub_client = MQTTClient()
pub_client = MQTTClient(config={'auto_reconnect': False})
ret = await pub_client.connect("mqtt://127.0.0.1/")
assert ret == 0
await pub_client.publish("/+", b"data", QOS_0)
await pub_client.subscribe([
("my/+/topic", QOS_0)
])
await asyncio.sleep(0.5)
# need to build & send packet directly to bypass client's check of invalid topic name
# see test_client.py::test_publish_to_incorrect_wildcard for client checks
packet = PublishPacket.build(topic_name='my/topic', message=b'messages',
packet_id=None, dup_flag=False, qos=QOS_0, retain=False)
packet.topic_name = "my/+/topic"
await pub_client._handler._send_packet(packet)
await asyncio.sleep(0.5)
with pytest.raises(asyncio.TimeoutError):
msg = await pub_client.deliver_message(timeout_duration=1)
assert msg is None
await asyncio.sleep(0.1)
await pub_client.disconnect()
@pytest.mark.asyncio
async def test_client_publish_asterisk(broker):
"""'*' is a valid, non-wildcard character for MQTT."""
assert broker.transitions.is_started()
pub_client = MQTTClient(config={'auto_reconnect': False})
ret = await pub_client.connect("mqtt://127.0.0.1/")
assert ret == 0
await pub_client.subscribe([
("my*/topic", QOS_0),
("my/+/topic", QOS_0)
])
await asyncio.sleep(0.1)
await pub_client.publish('my*/topic', b'my valid message', QOS_0, retain=False)
await asyncio.sleep(0.1)
msg = await pub_client.deliver_message(timeout_duration=1)
assert msg is not None
assert msg.topic == "my*/topic"
assert msg.data == b'my valid message'
await asyncio.sleep(0.1)
msg = await pub_client.publish('my/****/topic', b'my valid message', QOS_0, retain=False)
assert msg is not None
assert msg.topic == "my/****/topic"
assert msg.data == b'my valid message'
await pub_client.disconnect()
@pytest.mark.asyncio
async def test_client_publish_big(broker, mock_plugin_manager):
pub_client = MQTTClient()

Wyświetl plik

@ -7,7 +7,7 @@ import pytest
from amqtt.broker import Broker
from amqtt.client import MQTTClient
from amqtt.errors import ClientError, ConnectError
from amqtt.errors import ClientError, ConnectError, MQTTError
from amqtt.mqtt.constants import QOS_0, QOS_1, QOS_2
formatter = "[%(asctime)s] %(name)s {%(filename)s:%(lineno)d} %(levelname)s - %(message)s"
@ -490,3 +490,17 @@ async def test_client_no_auth():
await client.connect("mqtt://127.0.0.1:1883/")
await broker.shutdown()
@pytest.mark.asyncio
async def test_publish_to_incorrect_wildcard(broker_fixture):
client = MQTTClient(config={'auto_reconnect': False})
await client.connect("mqtt://127.0.0.1/")
with pytest.raises(MQTTError):
await client.publish("my/+/topic", b'plus-sign wildcard topic invalid publish')
with pytest.raises(MQTTError):
await client.publish("topic/#", b'hash wildcard topic invalid publish')
await client.publish("topic/*", b'asterisk topic normal publish')
await client.disconnect()

Wyświetl plik

@ -0,0 +1,110 @@
import asyncio
import logging
import pytest
from amqtt.broker import Broker
from amqtt.client import MQTTClient
from amqtt.mqtt.constants import QOS_0
logger = logging.getLogger(__name__)
@pytest.mark.asyncio
async def test_publish_to_dollar_sign_topics():
"""Applications cannot use a topic with a leading $ character for their own purposes [MQTT-4.7.2-1]."""
cfg = {
'listeners': {'default': {'type': 'tcp', 'bind': '127.0.0.1'}},
'plugins': {'amqtt.plugins.authentication.AnonymousAuthPlugin': {"allow_anonymous": True}},
}
b = Broker(config=cfg)
await b.start()
await asyncio.sleep(0.1)
c = MQTTClient(config={'auto_reconnect': False})
await c.connect()
await asyncio.sleep(0.1)
await c.subscribe(
[('$#', QOS_0),
('#', QOS_0)]
)
await asyncio.sleep(0.1)
await c.publish('$MY', b'message should be blocked')
await asyncio.sleep(0.1)
with pytest.raises(asyncio.TimeoutError):
# wait long enough for broker sys plugin to run
_ = await c.deliver_message(timeout_duration=1)
await c.disconnect()
await asyncio.sleep(0.1)
await b.shutdown()
@pytest.mark.asyncio
async def test_hash_will_not_receive_dollar():
"""A subscription to “#” will not receive any messages published to a topic beginning with a $ [MQTT-4.7.2-1]."""
cfg = {
'listeners': {'default': {'type': 'tcp', 'bind': '127.0.0.1'}},
'plugins': {
'amqtt.plugins.authentication.AnonymousAuthPlugin': {"allow_anonymous": True},
'amqtt.plugins.sys.broker.BrokerSysPlugin': {"sys_interval": 2}
}
}
b = Broker(config=cfg)
await b.start()
await asyncio.sleep(0.1)
c = MQTTClient(config={'auto_reconnect': False})
await c.connect()
await asyncio.sleep(0.1)
await c.subscribe(
[('#', QOS_0)]
)
await asyncio.sleep(0.1)
with pytest.raises(asyncio.TimeoutError):
# wait long enough for broker sys plugin to run
_ = await c.deliver_message(timeout_duration=5)
await c.disconnect()
await asyncio.sleep(0.1)
await b.shutdown()
@pytest.mark.asyncio
async def test_plus_will_not_receive_dollar():
"""A subscription to “+/monitor/Clients” will not receive any messages published to “$SYS/monitor/Clients [MQTT-4.7.2-1]"""
# BrokerSysPlugin doesn't use $SYS/monitor/Clients, so this is an equivalent test with $SYS/broker topics
cfg = {
'listeners': {'default': {'type': 'tcp', 'bind': '127.0.0.1'}},
'plugins': {
'amqtt.plugins.authentication.AnonymousAuthPlugin': {"allow_anonymous": True},
'amqtt.plugins.sys.broker.BrokerSysPlugin': {"sys_interval": 2}
}
}
b = Broker(config=cfg)
await b.start()
await asyncio.sleep(0.1)
c = MQTTClient(config={'auto_reconnect': False})
await c.connect()
await asyncio.sleep(0.1)
await c.subscribe(
[('+/broker/#', QOS_0),
('+/broker/time', QOS_0),
('+/broker/clients/#', QOS_0),
('+/broker/+/maximum', QOS_0)
]
)
await asyncio.sleep(0.1)
with pytest.raises(asyncio.TimeoutError):
# wait long enough for broker sys plugin to run
_ = await c.deliver_message(timeout_duration=5)
await c.disconnect()
await asyncio.sleep(0.1)
await b.shutdown()