From 38b2145234a09bdab6c7c2fa4fef821fe4d957b4 Mon Sep 17 00:00:00 2001 From: Andrew Mirsky Date: Thu, 3 Jul 2025 12:05:05 -0400 Subject: [PATCH 1/3] adding tests for failure cases for different connect packet properties --- tests/mqtt/test_connect.py | 37 ++++++++++++++++++++++++++++++++++++- 1 file changed, 36 insertions(+), 1 deletion(-) diff --git a/tests/mqtt/test_connect.py b/tests/mqtt/test_connect.py index f552d46..f11325b 100644 --- a/tests/mqtt/test_connect.py +++ b/tests/mqtt/test_connect.py @@ -1,9 +1,11 @@ import asyncio import unittest +import pytest from amqtt.adapters import BufferReader +from amqtt.errors import AMQTTError from amqtt.mqtt.connect import ConnectPacket, ConnectPayload, ConnectVariableHeader -from amqtt.mqtt.packet import CONNECT, MQTTFixedHeader +from amqtt.mqtt.packet import CONNECT, MQTTFixedHeader, PUBLISH class ConnectPacketTest(unittest.TestCase): @@ -150,3 +152,36 @@ class ConnectPacketTest(unittest.TestCase): assert message.username == "user" assert message.payload.password == "password" assert message.password == "password" + + def test_incorrect_fixed_header(self): + header = MQTTFixedHeader(PUBLISH, 0x00) + with pytest.raises(AMQTTError): + connect_packet = ConnectPacket(fixed=header) + +@pytest.mark.parametrize("prop", [ + "proto_name", + "proto_level", + "username_flag", + "password_flag", + "clean_session_flag", + "will_retain_flag", + "will_qos", + "will_flag", + "reserved_flag", + "client_id", + "client_id_is_random", + "will_topic", + "will_message", + "username", + "password", + "keep_alive", +]) +def test_empty_variable_header(prop): + connect_packet = ConnectPacket() + + with pytest.raises(ValueError): + assert getattr(connect_packet, prop) is not None + + with pytest.raises(ValueError): + assert setattr(connect_packet, prop, "a value") + From a2e5a6705956384ccb1a3838b5fa759bec5658e4 Mon Sep 17 00:00:00 2001 From: Andrew Mirsky Date: Thu, 3 Jul 2025 14:32:13 -0400 Subject: [PATCH 2/3] resolves Yakifo/amqtt#250 : * was being blocked as a valid topic character in publish/receive, even though the invalid topic wildcard characters are '#' and '+'. also, add test coverage for error cases when creating different packet types. --- amqtt/mqtt/protocol/handler.py | 5 ++-- amqtt/mqtt/publish.py | 2 +- tests/mqtt/test_connack.py | 25 +++++++++++++++++ tests/mqtt/test_connect.py | 14 +++++----- tests/mqtt/test_puback.py | 22 +++++++++++++++ tests/mqtt/test_pubcomp.py | 21 ++++++++++++++ tests/mqtt/test_publish.py | 28 +++++++++++++++++++ tests/mqtt/test_pubrec.py | 22 +++++++++++++++ tests/mqtt/test_pubrel.py | 22 +++++++++++++++ tests/test_broker.py | 51 ++++++++++++++++++++++++++++++++-- tests/test_client.py | 16 ++++++++++- 11 files changed, 214 insertions(+), 14 deletions(-) create mode 100644 tests/mqtt/test_connack.py diff --git a/amqtt/mqtt/protocol/handler.py b/amqtt/mqtt/protocol/handler.py index eff191c..66bc1cb 100644 --- a/amqtt/mqtt/protocol/handler.py +++ b/amqtt/mqtt/protocol/handler.py @@ -520,8 +520,9 @@ class ProtocolHandler(Generic[C]): elif packet.fixed_header.packet_type == DISCONNECT and isinstance(packet, DisconnectPacket): task = asyncio.create_task(self.handle_disconnect(packet)) elif packet.fixed_header.packet_type == CONNECT and isinstance(packet, ConnectPacket): - # TODO: why is this not like all other inside create_task? - await self.handle_connect(packet) # task = asyncio.create_task(self.handle_connect(packet)) + # q: why is this not like all other inside a create_task? + # a: the connection needs to be established before any other packet tasks for this new session are scheduled + await self.handle_connect(packet) if task: running_tasks.append(task) except MQTTError: diff --git a/amqtt/mqtt/publish.py b/amqtt/mqtt/publish.py index 5d368c7..b3972da 100644 --- a/amqtt/mqtt/publish.py +++ b/amqtt/mqtt/publish.py @@ -12,7 +12,7 @@ class PublishVariableHeader(MQTTVariableHeader): def __init__(self, topic_name: str, packet_id: int | None = None) -> None: super().__init__() - if "*" in topic_name: + if "#" in topic_name or "+" in topic_name: msg = "[MQTT-3.3.2-2] Topic name in the PUBLISH Packet MUST NOT contain wildcard characters." raise MQTTError(msg) self.topic_name = topic_name diff --git a/tests/mqtt/test_connack.py b/tests/mqtt/test_connack.py new file mode 100644 index 0000000..58de854 --- /dev/null +++ b/tests/mqtt/test_connack.py @@ -0,0 +1,25 @@ +import pytest +from amqtt.errors import AMQTTError +from amqtt.mqtt.connack import ConnackPacket +from amqtt.mqtt.packet import MQTTFixedHeader, PUBLISH + + + +def test_incorrect_fixed_header(): + header = MQTTFixedHeader(PUBLISH, 0x00) + with pytest.raises(AMQTTError): + _ = ConnackPacket(fixed=header) + + +@pytest.mark.parametrize("prop", [ + "return_code", + "session_parent" +]) +def test_empty_variable_header(prop): + packet = ConnackPacket() + + with pytest.raises(ValueError): + assert getattr(packet, prop) is not None + + with pytest.raises(ValueError): + assert setattr(packet, prop, "a value") diff --git a/tests/mqtt/test_connect.py b/tests/mqtt/test_connect.py index f11325b..4f1385a 100644 --- a/tests/mqtt/test_connect.py +++ b/tests/mqtt/test_connect.py @@ -153,10 +153,10 @@ class ConnectPacketTest(unittest.TestCase): assert message.payload.password == "password" assert message.password == "password" - def test_incorrect_fixed_header(self): - header = MQTTFixedHeader(PUBLISH, 0x00) - with pytest.raises(AMQTTError): - connect_packet = ConnectPacket(fixed=header) +def test_incorrect_fixed_header(): + header = MQTTFixedHeader(PUBLISH, 0x00) + with pytest.raises(AMQTTError): + _ = ConnectPacket(fixed=header) @pytest.mark.parametrize("prop", [ "proto_name", @@ -177,11 +177,11 @@ class ConnectPacketTest(unittest.TestCase): "keep_alive", ]) def test_empty_variable_header(prop): - connect_packet = ConnectPacket() + packet = ConnectPacket() with pytest.raises(ValueError): - assert getattr(connect_packet, prop) is not None + assert getattr(packet, prop) is not None with pytest.raises(ValueError): - assert setattr(connect_packet, prop, "a value") + assert setattr(packet, prop, "a value") diff --git a/tests/mqtt/test_puback.py b/tests/mqtt/test_puback.py index d2f3a17..bce08c7 100644 --- a/tests/mqtt/test_puback.py +++ b/tests/mqtt/test_puback.py @@ -1,7 +1,10 @@ import asyncio import unittest +import pytest from amqtt.adapters import BufferReader +from amqtt.errors import AMQTTError +from amqtt.mqtt import PUBLISH, MQTTFixedHeader from amqtt.mqtt.puback import PacketIdVariableHeader, PubackPacket @@ -20,3 +23,22 @@ class PubackPacketTest(unittest.TestCase): publish = PubackPacket(variable_header=variable_header) out = publish.to_bytes() assert out == b"@\x02\x00\n" + + +def test_incorrect_fixed_header(): + header = MQTTFixedHeader(PUBLISH, 0x00) + with pytest.raises(AMQTTError): + connect_packet = PubackPacket(fixed=header) + + +@pytest.mark.parametrize("prop", [ + "packet_id", +]) +def test_empty_variable_header(prop): + connect_packet = PubackPacket() + + with pytest.raises(ValueError): + assert getattr(connect_packet, prop) is not None + + with pytest.raises(ValueError): + assert setattr(connect_packet, prop, "a value") diff --git a/tests/mqtt/test_pubcomp.py b/tests/mqtt/test_pubcomp.py index 86ea7ca..8f0b69b 100644 --- a/tests/mqtt/test_pubcomp.py +++ b/tests/mqtt/test_pubcomp.py @@ -1,7 +1,10 @@ import asyncio import unittest +import pytest from amqtt.adapters import BufferReader +from amqtt.errors import AMQTTError +from amqtt.mqtt import MQTTFixedHeader, PUBLISH from amqtt.mqtt.pubcomp import PacketIdVariableHeader, PubcompPacket @@ -20,3 +23,21 @@ class PubcompPacketTest(unittest.TestCase): publish = PubcompPacket(variable_header=variable_header) out = publish.to_bytes() assert out == b"p\x02\x00\n" + +def test_incorrect_fixed_header(): + header = MQTTFixedHeader(PUBLISH, 0x00) + with pytest.raises(AMQTTError): + _ = PubcompPacket(fixed=header) + + +@pytest.mark.parametrize("prop", [ + "packet_id" +]) +def test_empty_variable_header(prop): + packet = PubcompPacket() + + with pytest.raises(ValueError): + assert getattr(packet, prop) is not None + + with pytest.raises(ValueError): + assert setattr(packet, prop, "a value") diff --git a/tests/mqtt/test_publish.py b/tests/mqtt/test_publish.py index 71a115d..c67e0c5 100644 --- a/tests/mqtt/test_publish.py +++ b/tests/mqtt/test_publish.py @@ -1,7 +1,10 @@ import asyncio import unittest +import pytest from amqtt.adapters import BufferReader +from amqtt.errors import AMQTTError +from amqtt.mqtt.packet import MQTTFixedHeader, CONNECT from amqtt.mqtt.constants import QOS_0, QOS_1, QOS_2 from amqtt.mqtt.publish import PublishPacket, PublishPayload, PublishVariableHeader @@ -116,3 +119,28 @@ class PublishPacketTest(unittest.TestCase): assert packet.dup_flag assert packet.qos == QOS_2 assert packet.retain_flag + + +def test_incorrect_fixed_header(): + header = MQTTFixedHeader(CONNECT, 0x00) + with pytest.raises(AMQTTError): + _ = PublishPacket(fixed=header) + +def test_set_flags(): + packet = PublishPacket() + packet.set_flags(dup_flag=True, qos=QOS_1, retain_flag=True) + + +@pytest.mark.parametrize("prop", [ + "packet_id", + "data", + "topic_name" +]) +def test_empty_variable_header(prop): + packet = PublishPacket() + + with pytest.raises(ValueError): + assert getattr(packet, prop) is not None + + with pytest.raises(ValueError): + assert setattr(packet, prop, "a value") diff --git a/tests/mqtt/test_pubrec.py b/tests/mqtt/test_pubrec.py index 08f8088..0911570 100644 --- a/tests/mqtt/test_pubrec.py +++ b/tests/mqtt/test_pubrec.py @@ -1,7 +1,10 @@ import asyncio import unittest +import pytest from amqtt.adapters import BufferReader +from amqtt.errors import AMQTTError +from amqtt.mqtt.packet import MQTTFixedHeader, PUBLISH from amqtt.mqtt.pubrec import PacketIdVariableHeader, PubrecPacket @@ -20,3 +23,22 @@ class PubrecPacketTest(unittest.TestCase): publish = PubrecPacket(variable_header=variable_header) out = publish.to_bytes() assert out == b"P\x02\x00\n" + + +def test_incorrect_fixed_header(): + header = MQTTFixedHeader(PUBLISH, 0x00) + with pytest.raises(AMQTTError): + _ = PubrecPacket(fixed=header) + + +@pytest.mark.parametrize("prop", [ + "packet_id" +]) +def test_empty_variable_header(prop): + packet = PubrecPacket() + + with pytest.raises(ValueError): + assert getattr(packet, prop) is not None + + with pytest.raises(ValueError): + assert setattr(packet, prop, "a value") diff --git a/tests/mqtt/test_pubrel.py b/tests/mqtt/test_pubrel.py index 2c89d49..a803540 100644 --- a/tests/mqtt/test_pubrel.py +++ b/tests/mqtt/test_pubrel.py @@ -1,7 +1,10 @@ import asyncio import unittest +import pytest from amqtt.adapters import BufferReader +from amqtt.errors import AMQTTError +from amqtt.mqtt.packet import MQTTFixedHeader, PUBLISH from amqtt.mqtt.pubrel import PacketIdVariableHeader, PubrelPacket @@ -20,3 +23,22 @@ class PubrelPacketTest(unittest.TestCase): publish = PubrelPacket(variable_header=variable_header) out = publish.to_bytes() assert out == b"b\x02\x00\n" + + +def test_incorrect_fixed_header(): + header = MQTTFixedHeader(PUBLISH, 0x00) + with pytest.raises(AMQTTError): + _ = PubrelPacket(fixed=header) + + +@pytest.mark.parametrize("prop", [ + "packet_id" +]) +def test_empty_variable_header(prop): + packet = PubrelPacket() + + with pytest.raises(ValueError): + assert getattr(packet, prop) is not None + + with pytest.raises(ValueError): + assert setattr(packet, prop, "a value") \ No newline at end of file diff --git a/tests/test_broker.py b/tests/test_broker.py index cc44947..95ebca5 100644 --- a/tests/test_broker.py +++ b/tests/test_broker.py @@ -468,17 +468,62 @@ async def test_client_publish_dup(broker): @pytest.mark.asyncio -async def test_client_publish_invalid_topic(broker): +async def test_client_publishing_invalid_topic(broker): assert broker.transitions.is_started() - pub_client = MQTTClient() + pub_client = MQTTClient(config={'auto_reconnect': False}) ret = await pub_client.connect("mqtt://127.0.0.1/") assert ret == 0 - await pub_client.publish("/+", b"data", QOS_0) + await pub_client.subscribe([ + ("my/+/topic", QOS_0) + ]) + await asyncio.sleep(0.5) + + # need to build & send packet directly to bypass client's check of invalid topic name + # see test_client.py::test_publish_to_incorrect_wildcard for client checks + packet = PublishPacket.build(topic_name='my/topic', message=b'messages', + packet_id=None, dup_flag=False, qos=QOS_0, retain=False) + packet.topic_name = "my/+/topic" + await pub_client._handler._send_packet(packet) + await asyncio.sleep(0.5) + + with pytest.raises(asyncio.TimeoutError): + msg = await pub_client.deliver_message(timeout_duration=1) + assert msg is None + await asyncio.sleep(0.1) await pub_client.disconnect() +@pytest.mark.asyncio +async def test_client_publish_asterisk(broker): + """'*' is a valid, non-wildcard character for MQTT.""" + assert broker.transitions.is_started() + pub_client = MQTTClient(config={'auto_reconnect': False}) + ret = await pub_client.connect("mqtt://127.0.0.1/") + assert ret == 0 + + await pub_client.subscribe([ + ("my*/topic", QOS_0), + ("my/+/topic", QOS_0) + ]) + await asyncio.sleep(0.1) + await pub_client.publish('my*/topic', b'my valid message', QOS_0, retain=False) + await asyncio.sleep(0.1) + msg = await pub_client.deliver_message(timeout_duration=1) + assert msg is not None + assert msg.topic == "my*/topic" + assert msg.data == b'my valid message' + + await asyncio.sleep(0.1) + msg = await pub_client.publish('my/****/topic', b'my valid message', QOS_0, retain=False) + assert msg is not None + assert msg.topic == "my/****/topic" + assert msg.data == b'my valid message' + + await pub_client.disconnect() + + @pytest.mark.asyncio async def test_client_publish_big(broker, mock_plugin_manager): pub_client = MQTTClient() diff --git a/tests/test_client.py b/tests/test_client.py index 4859c4e..c10999a 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -7,7 +7,7 @@ import pytest from amqtt.broker import Broker from amqtt.client import MQTTClient -from amqtt.errors import ClientError, ConnectError +from amqtt.errors import ClientError, ConnectError, MQTTError from amqtt.mqtt.constants import QOS_0, QOS_1, QOS_2 formatter = "[%(asctime)s] %(name)s {%(filename)s:%(lineno)d} %(levelname)s - %(message)s" @@ -490,3 +490,17 @@ async def test_client_no_auth(): await client.connect("mqtt://127.0.0.1:1883/") await broker.shutdown() + + +@pytest.mark.asyncio +async def test_publish_to_incorrect_wildcard(broker_fixture): + client = MQTTClient(config={'auto_reconnect': False}) + await client.connect("mqtt://127.0.0.1/") + + with pytest.raises(MQTTError): + await client.publish("my/+/topic", b'start wildcard topic publish') + with pytest.raises(MQTTError): + await client.publish("topic/#", b'hash wildcard topic publish') + + await client.publish("topic/*", b'start wildcard topic publish') + await client.disconnect() From 571434ed04caf628a36ad6638fa14b9410d34051 Mon Sep 17 00:00:00 2001 From: Andrew Mirsky Date: Thu, 3 Jul 2025 16:10:47 -0400 Subject: [PATCH 3/3] fixing test comments --- tests/test_client.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_client.py b/tests/test_client.py index c10999a..f0975ab 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -498,9 +498,9 @@ async def test_publish_to_incorrect_wildcard(broker_fixture): await client.connect("mqtt://127.0.0.1/") with pytest.raises(MQTTError): - await client.publish("my/+/topic", b'start wildcard topic publish') + await client.publish("my/+/topic", b'plus-sign wildcard topic invalid publish') with pytest.raises(MQTTError): - await client.publish("topic/#", b'hash wildcard topic publish') + await client.publish("topic/#", b'hash wildcard topic invalid publish') - await client.publish("topic/*", b'start wildcard topic publish') + await client.publish("topic/*", b'asterisk topic normal publish') await client.disconnect()