resolves Yakifo/amqtt#250 : * was being blocked as a valid topic character in publish/receive, even though the invalid topic wildcard characters are '#' and '+'. also, add test coverage for error cases when creating different packet types.

pull/251/head
Andrew Mirsky 2025-07-03 14:32:13 -04:00
rodzic 38b2145234
commit a2e5a67059
Nie znaleziono w bazie danych klucza dla tego podpisu
ID klucza GPG: A98E67635CDF2C39
11 zmienionych plików z 214 dodań i 14 usunięć

Wyświetl plik

@ -520,8 +520,9 @@ class ProtocolHandler(Generic[C]):
elif packet.fixed_header.packet_type == DISCONNECT and isinstance(packet, DisconnectPacket):
task = asyncio.create_task(self.handle_disconnect(packet))
elif packet.fixed_header.packet_type == CONNECT and isinstance(packet, ConnectPacket):
# TODO: why is this not like all other inside create_task?
await self.handle_connect(packet) # task = asyncio.create_task(self.handle_connect(packet))
# q: why is this not like all other inside a create_task?
# a: the connection needs to be established before any other packet tasks for this new session are scheduled
await self.handle_connect(packet)
if task:
running_tasks.append(task)
except MQTTError:

Wyświetl plik

@ -12,7 +12,7 @@ class PublishVariableHeader(MQTTVariableHeader):
def __init__(self, topic_name: str, packet_id: int | None = None) -> None:
super().__init__()
if "*" in topic_name:
if "#" in topic_name or "+" in topic_name:
msg = "[MQTT-3.3.2-2] Topic name in the PUBLISH Packet MUST NOT contain wildcard characters."
raise MQTTError(msg)
self.topic_name = topic_name

Wyświetl plik

@ -0,0 +1,25 @@
import pytest
from amqtt.errors import AMQTTError
from amqtt.mqtt.connack import ConnackPacket
from amqtt.mqtt.packet import MQTTFixedHeader, PUBLISH
def test_incorrect_fixed_header():
header = MQTTFixedHeader(PUBLISH, 0x00)
with pytest.raises(AMQTTError):
_ = ConnackPacket(fixed=header)
@pytest.mark.parametrize("prop", [
"return_code",
"session_parent"
])
def test_empty_variable_header(prop):
packet = ConnackPacket()
with pytest.raises(ValueError):
assert getattr(packet, prop) is not None
with pytest.raises(ValueError):
assert setattr(packet, prop, "a value")

Wyświetl plik

@ -153,10 +153,10 @@ class ConnectPacketTest(unittest.TestCase):
assert message.payload.password == "password"
assert message.password == "password"
def test_incorrect_fixed_header(self):
def test_incorrect_fixed_header():
header = MQTTFixedHeader(PUBLISH, 0x00)
with pytest.raises(AMQTTError):
connect_packet = ConnectPacket(fixed=header)
_ = ConnectPacket(fixed=header)
@pytest.mark.parametrize("prop", [
"proto_name",
@ -177,11 +177,11 @@ class ConnectPacketTest(unittest.TestCase):
"keep_alive",
])
def test_empty_variable_header(prop):
connect_packet = ConnectPacket()
packet = ConnectPacket()
with pytest.raises(ValueError):
assert getattr(connect_packet, prop) is not None
assert getattr(packet, prop) is not None
with pytest.raises(ValueError):
assert setattr(connect_packet, prop, "a value")
assert setattr(packet, prop, "a value")

Wyświetl plik

@ -1,7 +1,10 @@
import asyncio
import unittest
import pytest
from amqtt.adapters import BufferReader
from amqtt.errors import AMQTTError
from amqtt.mqtt import PUBLISH, MQTTFixedHeader
from amqtt.mqtt.puback import PacketIdVariableHeader, PubackPacket
@ -20,3 +23,22 @@ class PubackPacketTest(unittest.TestCase):
publish = PubackPacket(variable_header=variable_header)
out = publish.to_bytes()
assert out == b"@\x02\x00\n"
def test_incorrect_fixed_header():
header = MQTTFixedHeader(PUBLISH, 0x00)
with pytest.raises(AMQTTError):
connect_packet = PubackPacket(fixed=header)
@pytest.mark.parametrize("prop", [
"packet_id",
])
def test_empty_variable_header(prop):
connect_packet = PubackPacket()
with pytest.raises(ValueError):
assert getattr(connect_packet, prop) is not None
with pytest.raises(ValueError):
assert setattr(connect_packet, prop, "a value")

Wyświetl plik

@ -1,7 +1,10 @@
import asyncio
import unittest
import pytest
from amqtt.adapters import BufferReader
from amqtt.errors import AMQTTError
from amqtt.mqtt import MQTTFixedHeader, PUBLISH
from amqtt.mqtt.pubcomp import PacketIdVariableHeader, PubcompPacket
@ -20,3 +23,21 @@ class PubcompPacketTest(unittest.TestCase):
publish = PubcompPacket(variable_header=variable_header)
out = publish.to_bytes()
assert out == b"p\x02\x00\n"
def test_incorrect_fixed_header():
header = MQTTFixedHeader(PUBLISH, 0x00)
with pytest.raises(AMQTTError):
_ = PubcompPacket(fixed=header)
@pytest.mark.parametrize("prop", [
"packet_id"
])
def test_empty_variable_header(prop):
packet = PubcompPacket()
with pytest.raises(ValueError):
assert getattr(packet, prop) is not None
with pytest.raises(ValueError):
assert setattr(packet, prop, "a value")

Wyświetl plik

@ -1,7 +1,10 @@
import asyncio
import unittest
import pytest
from amqtt.adapters import BufferReader
from amqtt.errors import AMQTTError
from amqtt.mqtt.packet import MQTTFixedHeader, CONNECT
from amqtt.mqtt.constants import QOS_0, QOS_1, QOS_2
from amqtt.mqtt.publish import PublishPacket, PublishPayload, PublishVariableHeader
@ -116,3 +119,28 @@ class PublishPacketTest(unittest.TestCase):
assert packet.dup_flag
assert packet.qos == QOS_2
assert packet.retain_flag
def test_incorrect_fixed_header():
header = MQTTFixedHeader(CONNECT, 0x00)
with pytest.raises(AMQTTError):
_ = PublishPacket(fixed=header)
def test_set_flags():
packet = PublishPacket()
packet.set_flags(dup_flag=True, qos=QOS_1, retain_flag=True)
@pytest.mark.parametrize("prop", [
"packet_id",
"data",
"topic_name"
])
def test_empty_variable_header(prop):
packet = PublishPacket()
with pytest.raises(ValueError):
assert getattr(packet, prop) is not None
with pytest.raises(ValueError):
assert setattr(packet, prop, "a value")

Wyświetl plik

@ -1,7 +1,10 @@
import asyncio
import unittest
import pytest
from amqtt.adapters import BufferReader
from amqtt.errors import AMQTTError
from amqtt.mqtt.packet import MQTTFixedHeader, PUBLISH
from amqtt.mqtt.pubrec import PacketIdVariableHeader, PubrecPacket
@ -20,3 +23,22 @@ class PubrecPacketTest(unittest.TestCase):
publish = PubrecPacket(variable_header=variable_header)
out = publish.to_bytes()
assert out == b"P\x02\x00\n"
def test_incorrect_fixed_header():
header = MQTTFixedHeader(PUBLISH, 0x00)
with pytest.raises(AMQTTError):
_ = PubrecPacket(fixed=header)
@pytest.mark.parametrize("prop", [
"packet_id"
])
def test_empty_variable_header(prop):
packet = PubrecPacket()
with pytest.raises(ValueError):
assert getattr(packet, prop) is not None
with pytest.raises(ValueError):
assert setattr(packet, prop, "a value")

Wyświetl plik

@ -1,7 +1,10 @@
import asyncio
import unittest
import pytest
from amqtt.adapters import BufferReader
from amqtt.errors import AMQTTError
from amqtt.mqtt.packet import MQTTFixedHeader, PUBLISH
from amqtt.mqtt.pubrel import PacketIdVariableHeader, PubrelPacket
@ -20,3 +23,22 @@ class PubrelPacketTest(unittest.TestCase):
publish = PubrelPacket(variable_header=variable_header)
out = publish.to_bytes()
assert out == b"b\x02\x00\n"
def test_incorrect_fixed_header():
header = MQTTFixedHeader(PUBLISH, 0x00)
with pytest.raises(AMQTTError):
_ = PubrelPacket(fixed=header)
@pytest.mark.parametrize("prop", [
"packet_id"
])
def test_empty_variable_header(prop):
packet = PubrelPacket()
with pytest.raises(ValueError):
assert getattr(packet, prop) is not None
with pytest.raises(ValueError):
assert setattr(packet, prop, "a value")

Wyświetl plik

@ -468,17 +468,62 @@ async def test_client_publish_dup(broker):
@pytest.mark.asyncio
async def test_client_publish_invalid_topic(broker):
async def test_client_publishing_invalid_topic(broker):
assert broker.transitions.is_started()
pub_client = MQTTClient()
pub_client = MQTTClient(config={'auto_reconnect': False})
ret = await pub_client.connect("mqtt://127.0.0.1/")
assert ret == 0
await pub_client.publish("/+", b"data", QOS_0)
await pub_client.subscribe([
("my/+/topic", QOS_0)
])
await asyncio.sleep(0.5)
# need to build & send packet directly to bypass client's check of invalid topic name
# see test_client.py::test_publish_to_incorrect_wildcard for client checks
packet = PublishPacket.build(topic_name='my/topic', message=b'messages',
packet_id=None, dup_flag=False, qos=QOS_0, retain=False)
packet.topic_name = "my/+/topic"
await pub_client._handler._send_packet(packet)
await asyncio.sleep(0.5)
with pytest.raises(asyncio.TimeoutError):
msg = await pub_client.deliver_message(timeout_duration=1)
assert msg is None
await asyncio.sleep(0.1)
await pub_client.disconnect()
@pytest.mark.asyncio
async def test_client_publish_asterisk(broker):
"""'*' is a valid, non-wildcard character for MQTT."""
assert broker.transitions.is_started()
pub_client = MQTTClient(config={'auto_reconnect': False})
ret = await pub_client.connect("mqtt://127.0.0.1/")
assert ret == 0
await pub_client.subscribe([
("my*/topic", QOS_0),
("my/+/topic", QOS_0)
])
await asyncio.sleep(0.1)
await pub_client.publish('my*/topic', b'my valid message', QOS_0, retain=False)
await asyncio.sleep(0.1)
msg = await pub_client.deliver_message(timeout_duration=1)
assert msg is not None
assert msg.topic == "my*/topic"
assert msg.data == b'my valid message'
await asyncio.sleep(0.1)
msg = await pub_client.publish('my/****/topic', b'my valid message', QOS_0, retain=False)
assert msg is not None
assert msg.topic == "my/****/topic"
assert msg.data == b'my valid message'
await pub_client.disconnect()
@pytest.mark.asyncio
async def test_client_publish_big(broker, mock_plugin_manager):
pub_client = MQTTClient()

Wyświetl plik

@ -7,7 +7,7 @@ import pytest
from amqtt.broker import Broker
from amqtt.client import MQTTClient
from amqtt.errors import ClientError, ConnectError
from amqtt.errors import ClientError, ConnectError, MQTTError
from amqtt.mqtt.constants import QOS_0, QOS_1, QOS_2
formatter = "[%(asctime)s] %(name)s {%(filename)s:%(lineno)d} %(levelname)s - %(message)s"
@ -490,3 +490,17 @@ async def test_client_no_auth():
await client.connect("mqtt://127.0.0.1:1883/")
await broker.shutdown()
@pytest.mark.asyncio
async def test_publish_to_incorrect_wildcard(broker_fixture):
client = MQTTClient(config={'auto_reconnect': False})
await client.connect("mqtt://127.0.0.1/")
with pytest.raises(MQTTError):
await client.publish("my/+/topic", b'start wildcard topic publish')
with pytest.raises(MQTTError):
await client.publish("topic/#", b'hash wildcard topic publish')
await client.publish("topic/*", b'start wildcard topic publish')
await client.disconnect()