diff --git a/amqtt/broker.py b/amqtt/broker.py index 80245ab..066d76b 100644 --- a/amqtt/broker.py +++ b/amqtt/broker.py @@ -183,6 +183,8 @@ class Broker: self._subscriptions: dict[str, list[tuple[Session, int]]] = {} self._retained_messages: dict[str, RetainedApplicationMessage] = {} + self._topic_filter_matchers: dict[str, re.Pattern[str]] = {} + # Broadcast queue for outgoing messages self._broadcast_queue: asyncio.Queue[dict[str, Any]] = asyncio.Queue() self._broadcast_task: asyncio.Task[Any] | None = None @@ -690,6 +692,11 @@ class Broker: f"[MQTT-3.3.2-2] - {client_session.client_id} invalid TOPIC sent in PUBLISH message, closing connection", ) return False + if app_message.topic.startswith("$"): + self.logger.warning( + f"[MQTT-4.7.2-1] - {client_session.client_id} cannot use a topic with a leading $ character." + ) + return False permitted = await self._topic_filtering(client_session, topic=app_message.topic, action=Action.PUBLISH) if not permitted: @@ -908,9 +915,6 @@ class Broker: self.logger.debug(f"Processing broadcast message: {broadcast}") for k_filter, subscriptions in self._subscriptions.items(): - if broadcast["topic"].startswith("$") and (k_filter.startswith(("+", "#"))): - self.logger.debug("[MQTT-4.7.2-1] - ignoring broadcasting $ topic to subscriptions starting with + or #") - continue # Skip all subscriptions which do not match the topic if not self._matches(broadcast["topic"], k_filter): @@ -1039,11 +1043,21 @@ class Broker: ) def _matches(self, topic: str, a_filter: str) -> bool: + if topic.startswith("$") and (a_filter.startswith(("+", "#"))): + self.logger.debug("[MQTT-4.7.2-1] - ignoring broadcasting $ topic to subscriptions starting with + or #") + return False + if "#" not in a_filter and "+" not in a_filter: # if filter doesn't contain wildcard, return exact match return a_filter == topic - # else use regex - match_pattern = re.compile(re.escape(a_filter).replace("\\#", "?.*").replace("\\+", "[^/]*").lstrip("?")) + + # else use regex (re.compile is an expensive operation, store the matcher for future use) + if a_filter not in self._topic_filter_matchers: + self._topic_filter_matchers[a_filter] = re.compile(re.escape(a_filter) + .replace("\\#", "?.*") + .replace("\\+", "[^/]*") + .lstrip("?")) + match_pattern = self._topic_filter_matchers[a_filter] return bool(match_pattern.fullmatch(topic)) def _get_handler(self, session: Session) -> BrokerProtocolHandler | None: diff --git a/amqtt/mqtt/protocol/handler.py b/amqtt/mqtt/protocol/handler.py index eff191c..66bc1cb 100644 --- a/amqtt/mqtt/protocol/handler.py +++ b/amqtt/mqtt/protocol/handler.py @@ -520,8 +520,9 @@ class ProtocolHandler(Generic[C]): elif packet.fixed_header.packet_type == DISCONNECT and isinstance(packet, DisconnectPacket): task = asyncio.create_task(self.handle_disconnect(packet)) elif packet.fixed_header.packet_type == CONNECT and isinstance(packet, ConnectPacket): - # TODO: why is this not like all other inside create_task? - await self.handle_connect(packet) # task = asyncio.create_task(self.handle_connect(packet)) + # q: why is this not like all other inside a create_task? + # a: the connection needs to be established before any other packet tasks for this new session are scheduled + await self.handle_connect(packet) if task: running_tasks.append(task) except MQTTError: diff --git a/amqtt/mqtt/publish.py b/amqtt/mqtt/publish.py index 5d368c7..b3972da 100644 --- a/amqtt/mqtt/publish.py +++ b/amqtt/mqtt/publish.py @@ -12,7 +12,7 @@ class PublishVariableHeader(MQTTVariableHeader): def __init__(self, topic_name: str, packet_id: int | None = None) -> None: super().__init__() - if "*" in topic_name: + if "#" in topic_name or "+" in topic_name: msg = "[MQTT-3.3.2-2] Topic name in the PUBLISH Packet MUST NOT contain wildcard characters." raise MQTTError(msg) self.topic_name = topic_name diff --git a/tests/mqtt/test_connack.py b/tests/mqtt/test_connack.py new file mode 100644 index 0000000..58de854 --- /dev/null +++ b/tests/mqtt/test_connack.py @@ -0,0 +1,25 @@ +import pytest +from amqtt.errors import AMQTTError +from amqtt.mqtt.connack import ConnackPacket +from amqtt.mqtt.packet import MQTTFixedHeader, PUBLISH + + + +def test_incorrect_fixed_header(): + header = MQTTFixedHeader(PUBLISH, 0x00) + with pytest.raises(AMQTTError): + _ = ConnackPacket(fixed=header) + + +@pytest.mark.parametrize("prop", [ + "return_code", + "session_parent" +]) +def test_empty_variable_header(prop): + packet = ConnackPacket() + + with pytest.raises(ValueError): + assert getattr(packet, prop) is not None + + with pytest.raises(ValueError): + assert setattr(packet, prop, "a value") diff --git a/tests/mqtt/test_connect.py b/tests/mqtt/test_connect.py index f552d46..4f1385a 100644 --- a/tests/mqtt/test_connect.py +++ b/tests/mqtt/test_connect.py @@ -1,9 +1,11 @@ import asyncio import unittest +import pytest from amqtt.adapters import BufferReader +from amqtt.errors import AMQTTError from amqtt.mqtt.connect import ConnectPacket, ConnectPayload, ConnectVariableHeader -from amqtt.mqtt.packet import CONNECT, MQTTFixedHeader +from amqtt.mqtt.packet import CONNECT, MQTTFixedHeader, PUBLISH class ConnectPacketTest(unittest.TestCase): @@ -150,3 +152,36 @@ class ConnectPacketTest(unittest.TestCase): assert message.username == "user" assert message.payload.password == "password" assert message.password == "password" + +def test_incorrect_fixed_header(): + header = MQTTFixedHeader(PUBLISH, 0x00) + with pytest.raises(AMQTTError): + _ = ConnectPacket(fixed=header) + +@pytest.mark.parametrize("prop", [ + "proto_name", + "proto_level", + "username_flag", + "password_flag", + "clean_session_flag", + "will_retain_flag", + "will_qos", + "will_flag", + "reserved_flag", + "client_id", + "client_id_is_random", + "will_topic", + "will_message", + "username", + "password", + "keep_alive", +]) +def test_empty_variable_header(prop): + packet = ConnectPacket() + + with pytest.raises(ValueError): + assert getattr(packet, prop) is not None + + with pytest.raises(ValueError): + assert setattr(packet, prop, "a value") + diff --git a/tests/mqtt/test_puback.py b/tests/mqtt/test_puback.py index d2f3a17..bce08c7 100644 --- a/tests/mqtt/test_puback.py +++ b/tests/mqtt/test_puback.py @@ -1,7 +1,10 @@ import asyncio import unittest +import pytest from amqtt.adapters import BufferReader +from amqtt.errors import AMQTTError +from amqtt.mqtt import PUBLISH, MQTTFixedHeader from amqtt.mqtt.puback import PacketIdVariableHeader, PubackPacket @@ -20,3 +23,22 @@ class PubackPacketTest(unittest.TestCase): publish = PubackPacket(variable_header=variable_header) out = publish.to_bytes() assert out == b"@\x02\x00\n" + + +def test_incorrect_fixed_header(): + header = MQTTFixedHeader(PUBLISH, 0x00) + with pytest.raises(AMQTTError): + connect_packet = PubackPacket(fixed=header) + + +@pytest.mark.parametrize("prop", [ + "packet_id", +]) +def test_empty_variable_header(prop): + connect_packet = PubackPacket() + + with pytest.raises(ValueError): + assert getattr(connect_packet, prop) is not None + + with pytest.raises(ValueError): + assert setattr(connect_packet, prop, "a value") diff --git a/tests/mqtt/test_pubcomp.py b/tests/mqtt/test_pubcomp.py index 86ea7ca..8f0b69b 100644 --- a/tests/mqtt/test_pubcomp.py +++ b/tests/mqtt/test_pubcomp.py @@ -1,7 +1,10 @@ import asyncio import unittest +import pytest from amqtt.adapters import BufferReader +from amqtt.errors import AMQTTError +from amqtt.mqtt import MQTTFixedHeader, PUBLISH from amqtt.mqtt.pubcomp import PacketIdVariableHeader, PubcompPacket @@ -20,3 +23,21 @@ class PubcompPacketTest(unittest.TestCase): publish = PubcompPacket(variable_header=variable_header) out = publish.to_bytes() assert out == b"p\x02\x00\n" + +def test_incorrect_fixed_header(): + header = MQTTFixedHeader(PUBLISH, 0x00) + with pytest.raises(AMQTTError): + _ = PubcompPacket(fixed=header) + + +@pytest.mark.parametrize("prop", [ + "packet_id" +]) +def test_empty_variable_header(prop): + packet = PubcompPacket() + + with pytest.raises(ValueError): + assert getattr(packet, prop) is not None + + with pytest.raises(ValueError): + assert setattr(packet, prop, "a value") diff --git a/tests/mqtt/test_publish.py b/tests/mqtt/test_publish.py index 71a115d..c67e0c5 100644 --- a/tests/mqtt/test_publish.py +++ b/tests/mqtt/test_publish.py @@ -1,7 +1,10 @@ import asyncio import unittest +import pytest from amqtt.adapters import BufferReader +from amqtt.errors import AMQTTError +from amqtt.mqtt.packet import MQTTFixedHeader, CONNECT from amqtt.mqtt.constants import QOS_0, QOS_1, QOS_2 from amqtt.mqtt.publish import PublishPacket, PublishPayload, PublishVariableHeader @@ -116,3 +119,28 @@ class PublishPacketTest(unittest.TestCase): assert packet.dup_flag assert packet.qos == QOS_2 assert packet.retain_flag + + +def test_incorrect_fixed_header(): + header = MQTTFixedHeader(CONNECT, 0x00) + with pytest.raises(AMQTTError): + _ = PublishPacket(fixed=header) + +def test_set_flags(): + packet = PublishPacket() + packet.set_flags(dup_flag=True, qos=QOS_1, retain_flag=True) + + +@pytest.mark.parametrize("prop", [ + "packet_id", + "data", + "topic_name" +]) +def test_empty_variable_header(prop): + packet = PublishPacket() + + with pytest.raises(ValueError): + assert getattr(packet, prop) is not None + + with pytest.raises(ValueError): + assert setattr(packet, prop, "a value") diff --git a/tests/mqtt/test_pubrec.py b/tests/mqtt/test_pubrec.py index 08f8088..0911570 100644 --- a/tests/mqtt/test_pubrec.py +++ b/tests/mqtt/test_pubrec.py @@ -1,7 +1,10 @@ import asyncio import unittest +import pytest from amqtt.adapters import BufferReader +from amqtt.errors import AMQTTError +from amqtt.mqtt.packet import MQTTFixedHeader, PUBLISH from amqtt.mqtt.pubrec import PacketIdVariableHeader, PubrecPacket @@ -20,3 +23,22 @@ class PubrecPacketTest(unittest.TestCase): publish = PubrecPacket(variable_header=variable_header) out = publish.to_bytes() assert out == b"P\x02\x00\n" + + +def test_incorrect_fixed_header(): + header = MQTTFixedHeader(PUBLISH, 0x00) + with pytest.raises(AMQTTError): + _ = PubrecPacket(fixed=header) + + +@pytest.mark.parametrize("prop", [ + "packet_id" +]) +def test_empty_variable_header(prop): + packet = PubrecPacket() + + with pytest.raises(ValueError): + assert getattr(packet, prop) is not None + + with pytest.raises(ValueError): + assert setattr(packet, prop, "a value") diff --git a/tests/mqtt/test_pubrel.py b/tests/mqtt/test_pubrel.py index 2c89d49..a803540 100644 --- a/tests/mqtt/test_pubrel.py +++ b/tests/mqtt/test_pubrel.py @@ -1,7 +1,10 @@ import asyncio import unittest +import pytest from amqtt.adapters import BufferReader +from amqtt.errors import AMQTTError +from amqtt.mqtt.packet import MQTTFixedHeader, PUBLISH from amqtt.mqtt.pubrel import PacketIdVariableHeader, PubrelPacket @@ -20,3 +23,22 @@ class PubrelPacketTest(unittest.TestCase): publish = PubrelPacket(variable_header=variable_header) out = publish.to_bytes() assert out == b"b\x02\x00\n" + + +def test_incorrect_fixed_header(): + header = MQTTFixedHeader(PUBLISH, 0x00) + with pytest.raises(AMQTTError): + _ = PubrelPacket(fixed=header) + + +@pytest.mark.parametrize("prop", [ + "packet_id" +]) +def test_empty_variable_header(prop): + packet = PubrelPacket() + + with pytest.raises(ValueError): + assert getattr(packet, prop) is not None + + with pytest.raises(ValueError): + assert setattr(packet, prop, "a value") \ No newline at end of file diff --git a/tests/test_broker.py b/tests/test_broker.py index 93e3f9f..a73c4d9 100644 --- a/tests/test_broker.py +++ b/tests/test_broker.py @@ -510,17 +510,62 @@ async def test_client_publish_dup(broker): @pytest.mark.asyncio -async def test_client_publish_invalid_topic(broker): +async def test_client_publishing_invalid_topic(broker): assert broker.transitions.is_started() - pub_client = MQTTClient() + pub_client = MQTTClient(config={'auto_reconnect': False}) ret = await pub_client.connect("mqtt://127.0.0.1/") assert ret == 0 - await pub_client.publish("/+", b"data", QOS_0) + await pub_client.subscribe([ + ("my/+/topic", QOS_0) + ]) + await asyncio.sleep(0.5) + + # need to build & send packet directly to bypass client's check of invalid topic name + # see test_client.py::test_publish_to_incorrect_wildcard for client checks + packet = PublishPacket.build(topic_name='my/topic', message=b'messages', + packet_id=None, dup_flag=False, qos=QOS_0, retain=False) + packet.topic_name = "my/+/topic" + await pub_client._handler._send_packet(packet) + await asyncio.sleep(0.5) + + with pytest.raises(asyncio.TimeoutError): + msg = await pub_client.deliver_message(timeout_duration=1) + assert msg is None + await asyncio.sleep(0.1) await pub_client.disconnect() +@pytest.mark.asyncio +async def test_client_publish_asterisk(broker): + """'*' is a valid, non-wildcard character for MQTT.""" + assert broker.transitions.is_started() + pub_client = MQTTClient(config={'auto_reconnect': False}) + ret = await pub_client.connect("mqtt://127.0.0.1/") + assert ret == 0 + + await pub_client.subscribe([ + ("my*/topic", QOS_0), + ("my/+/topic", QOS_0) + ]) + await asyncio.sleep(0.1) + await pub_client.publish('my*/topic', b'my valid message', QOS_0, retain=False) + await asyncio.sleep(0.1) + msg = await pub_client.deliver_message(timeout_duration=1) + assert msg is not None + assert msg.topic == "my*/topic" + assert msg.data == b'my valid message' + + await asyncio.sleep(0.1) + msg = await pub_client.publish('my/****/topic', b'my valid message', QOS_0, retain=False) + assert msg is not None + assert msg.topic == "my/****/topic" + assert msg.data == b'my valid message' + + await pub_client.disconnect() + + @pytest.mark.asyncio async def test_client_publish_big(broker, mock_plugin_manager): pub_client = MQTTClient() diff --git a/tests/test_client.py b/tests/test_client.py index 4859c4e..f0975ab 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -7,7 +7,7 @@ import pytest from amqtt.broker import Broker from amqtt.client import MQTTClient -from amqtt.errors import ClientError, ConnectError +from amqtt.errors import ClientError, ConnectError, MQTTError from amqtt.mqtt.constants import QOS_0, QOS_1, QOS_2 formatter = "[%(asctime)s] %(name)s {%(filename)s:%(lineno)d} %(levelname)s - %(message)s" @@ -490,3 +490,17 @@ async def test_client_no_auth(): await client.connect("mqtt://127.0.0.1:1883/") await broker.shutdown() + + +@pytest.mark.asyncio +async def test_publish_to_incorrect_wildcard(broker_fixture): + client = MQTTClient(config={'auto_reconnect': False}) + await client.connect("mqtt://127.0.0.1/") + + with pytest.raises(MQTTError): + await client.publish("my/+/topic", b'plus-sign wildcard topic invalid publish') + with pytest.raises(MQTTError): + await client.publish("topic/#", b'hash wildcard topic invalid publish') + + await client.publish("topic/*", b'asterisk topic normal publish') + await client.disconnect() diff --git a/tests/test_dollar_topics.py b/tests/test_dollar_topics.py new file mode 100644 index 0000000..374f548 --- /dev/null +++ b/tests/test_dollar_topics.py @@ -0,0 +1,110 @@ +import asyncio +import logging + +import pytest + +from amqtt.broker import Broker +from amqtt.client import MQTTClient +from amqtt.mqtt.constants import QOS_0 + + +logger = logging.getLogger(__name__) + +@pytest.mark.asyncio +async def test_publish_to_dollar_sign_topics(): + """Applications cannot use a topic with a leading $ character for their own purposes [MQTT-4.7.2-1].""" + + cfg = { + 'listeners': {'default': {'type': 'tcp', 'bind': '127.0.0.1'}}, + 'plugins': {'amqtt.plugins.authentication.AnonymousAuthPlugin': {"allow_anonymous": True}}, + } + + b = Broker(config=cfg) + await b.start() + await asyncio.sleep(0.1) + c = MQTTClient(config={'auto_reconnect': False}) + await c.connect() + await asyncio.sleep(0.1) + await c.subscribe( + [('$#', QOS_0), + ('#', QOS_0)] + ) + await asyncio.sleep(0.1) + await c.publish('$MY', b'message should be blocked') + await asyncio.sleep(0.1) + + with pytest.raises(asyncio.TimeoutError): + # wait long enough for broker sys plugin to run + _ = await c.deliver_message(timeout_duration=1) + + await c.disconnect() + await asyncio.sleep(0.1) + await b.shutdown() + +@pytest.mark.asyncio +async def test_hash_will_not_receive_dollar(): + """A subscription to “#” will not receive any messages published to a topic beginning with a $ [MQTT-4.7.2-1].""" + + cfg = { + 'listeners': {'default': {'type': 'tcp', 'bind': '127.0.0.1'}}, + 'plugins': { + 'amqtt.plugins.authentication.AnonymousAuthPlugin': {"allow_anonymous": True}, + 'amqtt.plugins.sys.broker.BrokerSysPlugin': {"sys_interval": 2} + } + } + + b = Broker(config=cfg) + await b.start() + await asyncio.sleep(0.1) + c = MQTTClient(config={'auto_reconnect': False}) + await c.connect() + await asyncio.sleep(0.1) + await c.subscribe( + [('#', QOS_0)] + ) + await asyncio.sleep(0.1) + + with pytest.raises(asyncio.TimeoutError): + # wait long enough for broker sys plugin to run + _ = await c.deliver_message(timeout_duration=5) + + await c.disconnect() + await asyncio.sleep(0.1) + await b.shutdown() + + +@pytest.mark.asyncio +async def test_plus_will_not_receive_dollar(): + """A subscription to “+/monitor/Clients” will not receive any messages published to “$SYS/monitor/Clients [MQTT-4.7.2-1]""" + # BrokerSysPlugin doesn't use $SYS/monitor/Clients, so this is an equivalent test with $SYS/broker topics + + cfg = { + 'listeners': {'default': {'type': 'tcp', 'bind': '127.0.0.1'}}, + 'plugins': { + 'amqtt.plugins.authentication.AnonymousAuthPlugin': {"allow_anonymous": True}, + 'amqtt.plugins.sys.broker.BrokerSysPlugin': {"sys_interval": 2} + } + } + + b = Broker(config=cfg) + await b.start() + await asyncio.sleep(0.1) + c = MQTTClient(config={'auto_reconnect': False}) + await c.connect() + await asyncio.sleep(0.1) + await c.subscribe( + [('+/broker/#', QOS_0), + ('+/broker/time', QOS_0), + ('+/broker/clients/#', QOS_0), + ('+/broker/+/maximum', QOS_0) + ] + ) + await asyncio.sleep(0.1) + + with pytest.raises(asyncio.TimeoutError): + # wait long enough for broker sys plugin to run + _ = await c.deliver_message(timeout_duration=5) + + await c.disconnect() + await asyncio.sleep(0.1) + await b.shutdown()