Merge pull request #251 from ajmirsky/increased_test_coverage

publishing to a topic with `*` is allowed, while `#` and `+` are not
pull/257/head
Andrew Mirsky 2025-07-07 12:07:11 -04:00 zatwierdzone przez GitHub
commit 8022e01bb0
Nie znaleziono w bazie danych klucza dla tego podpisu
ID klucza GPG: B5690EEEBB952194
11 zmienionych plików z 243 dodań i 8 usunięć

Wyświetl plik

@ -520,8 +520,9 @@ class ProtocolHandler(Generic[C]):
elif packet.fixed_header.packet_type == DISCONNECT and isinstance(packet, DisconnectPacket):
task = asyncio.create_task(self.handle_disconnect(packet))
elif packet.fixed_header.packet_type == CONNECT and isinstance(packet, ConnectPacket):
# TODO: why is this not like all other inside create_task?
await self.handle_connect(packet) # task = asyncio.create_task(self.handle_connect(packet))
# q: why is this not like all other inside a create_task?
# a: the connection needs to be established before any other packet tasks for this new session are scheduled
await self.handle_connect(packet)
if task:
running_tasks.append(task)
except MQTTError:

Wyświetl plik

@ -12,7 +12,7 @@ class PublishVariableHeader(MQTTVariableHeader):
def __init__(self, topic_name: str, packet_id: int | None = None) -> None:
super().__init__()
if "*" in topic_name:
if "#" in topic_name or "+" in topic_name:
msg = "[MQTT-3.3.2-2] Topic name in the PUBLISH Packet MUST NOT contain wildcard characters."
raise MQTTError(msg)
self.topic_name = topic_name

Wyświetl plik

@ -0,0 +1,25 @@
import pytest
from amqtt.errors import AMQTTError
from amqtt.mqtt.connack import ConnackPacket
from amqtt.mqtt.packet import MQTTFixedHeader, PUBLISH
def test_incorrect_fixed_header():
header = MQTTFixedHeader(PUBLISH, 0x00)
with pytest.raises(AMQTTError):
_ = ConnackPacket(fixed=header)
@pytest.mark.parametrize("prop", [
"return_code",
"session_parent"
])
def test_empty_variable_header(prop):
packet = ConnackPacket()
with pytest.raises(ValueError):
assert getattr(packet, prop) is not None
with pytest.raises(ValueError):
assert setattr(packet, prop, "a value")

Wyświetl plik

@ -1,9 +1,11 @@
import asyncio
import unittest
import pytest
from amqtt.adapters import BufferReader
from amqtt.errors import AMQTTError
from amqtt.mqtt.connect import ConnectPacket, ConnectPayload, ConnectVariableHeader
from amqtt.mqtt.packet import CONNECT, MQTTFixedHeader
from amqtt.mqtt.packet import CONNECT, MQTTFixedHeader, PUBLISH
class ConnectPacketTest(unittest.TestCase):
@ -150,3 +152,36 @@ class ConnectPacketTest(unittest.TestCase):
assert message.username == "user"
assert message.payload.password == "password"
assert message.password == "password"
def test_incorrect_fixed_header():
header = MQTTFixedHeader(PUBLISH, 0x00)
with pytest.raises(AMQTTError):
_ = ConnectPacket(fixed=header)
@pytest.mark.parametrize("prop", [
"proto_name",
"proto_level",
"username_flag",
"password_flag",
"clean_session_flag",
"will_retain_flag",
"will_qos",
"will_flag",
"reserved_flag",
"client_id",
"client_id_is_random",
"will_topic",
"will_message",
"username",
"password",
"keep_alive",
])
def test_empty_variable_header(prop):
packet = ConnectPacket()
with pytest.raises(ValueError):
assert getattr(packet, prop) is not None
with pytest.raises(ValueError):
assert setattr(packet, prop, "a value")

Wyświetl plik

@ -1,7 +1,10 @@
import asyncio
import unittest
import pytest
from amqtt.adapters import BufferReader
from amqtt.errors import AMQTTError
from amqtt.mqtt import PUBLISH, MQTTFixedHeader
from amqtt.mqtt.puback import PacketIdVariableHeader, PubackPacket
@ -20,3 +23,22 @@ class PubackPacketTest(unittest.TestCase):
publish = PubackPacket(variable_header=variable_header)
out = publish.to_bytes()
assert out == b"@\x02\x00\n"
def test_incorrect_fixed_header():
header = MQTTFixedHeader(PUBLISH, 0x00)
with pytest.raises(AMQTTError):
connect_packet = PubackPacket(fixed=header)
@pytest.mark.parametrize("prop", [
"packet_id",
])
def test_empty_variable_header(prop):
connect_packet = PubackPacket()
with pytest.raises(ValueError):
assert getattr(connect_packet, prop) is not None
with pytest.raises(ValueError):
assert setattr(connect_packet, prop, "a value")

Wyświetl plik

@ -1,7 +1,10 @@
import asyncio
import unittest
import pytest
from amqtt.adapters import BufferReader
from amqtt.errors import AMQTTError
from amqtt.mqtt import MQTTFixedHeader, PUBLISH
from amqtt.mqtt.pubcomp import PacketIdVariableHeader, PubcompPacket
@ -20,3 +23,21 @@ class PubcompPacketTest(unittest.TestCase):
publish = PubcompPacket(variable_header=variable_header)
out = publish.to_bytes()
assert out == b"p\x02\x00\n"
def test_incorrect_fixed_header():
header = MQTTFixedHeader(PUBLISH, 0x00)
with pytest.raises(AMQTTError):
_ = PubcompPacket(fixed=header)
@pytest.mark.parametrize("prop", [
"packet_id"
])
def test_empty_variable_header(prop):
packet = PubcompPacket()
with pytest.raises(ValueError):
assert getattr(packet, prop) is not None
with pytest.raises(ValueError):
assert setattr(packet, prop, "a value")

Wyświetl plik

@ -1,7 +1,10 @@
import asyncio
import unittest
import pytest
from amqtt.adapters import BufferReader
from amqtt.errors import AMQTTError
from amqtt.mqtt.packet import MQTTFixedHeader, CONNECT
from amqtt.mqtt.constants import QOS_0, QOS_1, QOS_2
from amqtt.mqtt.publish import PublishPacket, PublishPayload, PublishVariableHeader
@ -116,3 +119,28 @@ class PublishPacketTest(unittest.TestCase):
assert packet.dup_flag
assert packet.qos == QOS_2
assert packet.retain_flag
def test_incorrect_fixed_header():
header = MQTTFixedHeader(CONNECT, 0x00)
with pytest.raises(AMQTTError):
_ = PublishPacket(fixed=header)
def test_set_flags():
packet = PublishPacket()
packet.set_flags(dup_flag=True, qos=QOS_1, retain_flag=True)
@pytest.mark.parametrize("prop", [
"packet_id",
"data",
"topic_name"
])
def test_empty_variable_header(prop):
packet = PublishPacket()
with pytest.raises(ValueError):
assert getattr(packet, prop) is not None
with pytest.raises(ValueError):
assert setattr(packet, prop, "a value")

Wyświetl plik

@ -1,7 +1,10 @@
import asyncio
import unittest
import pytest
from amqtt.adapters import BufferReader
from amqtt.errors import AMQTTError
from amqtt.mqtt.packet import MQTTFixedHeader, PUBLISH
from amqtt.mqtt.pubrec import PacketIdVariableHeader, PubrecPacket
@ -20,3 +23,22 @@ class PubrecPacketTest(unittest.TestCase):
publish = PubrecPacket(variable_header=variable_header)
out = publish.to_bytes()
assert out == b"P\x02\x00\n"
def test_incorrect_fixed_header():
header = MQTTFixedHeader(PUBLISH, 0x00)
with pytest.raises(AMQTTError):
_ = PubrecPacket(fixed=header)
@pytest.mark.parametrize("prop", [
"packet_id"
])
def test_empty_variable_header(prop):
packet = PubrecPacket()
with pytest.raises(ValueError):
assert getattr(packet, prop) is not None
with pytest.raises(ValueError):
assert setattr(packet, prop, "a value")

Wyświetl plik

@ -1,7 +1,10 @@
import asyncio
import unittest
import pytest
from amqtt.adapters import BufferReader
from amqtt.errors import AMQTTError
from amqtt.mqtt.packet import MQTTFixedHeader, PUBLISH
from amqtt.mqtt.pubrel import PacketIdVariableHeader, PubrelPacket
@ -20,3 +23,22 @@ class PubrelPacketTest(unittest.TestCase):
publish = PubrelPacket(variable_header=variable_header)
out = publish.to_bytes()
assert out == b"b\x02\x00\n"
def test_incorrect_fixed_header():
header = MQTTFixedHeader(PUBLISH, 0x00)
with pytest.raises(AMQTTError):
_ = PubrelPacket(fixed=header)
@pytest.mark.parametrize("prop", [
"packet_id"
])
def test_empty_variable_header(prop):
packet = PubrelPacket()
with pytest.raises(ValueError):
assert getattr(packet, prop) is not None
with pytest.raises(ValueError):
assert setattr(packet, prop, "a value")

Wyświetl plik

@ -510,17 +510,62 @@ async def test_client_publish_dup(broker):
@pytest.mark.asyncio
async def test_client_publish_invalid_topic(broker):
async def test_client_publishing_invalid_topic(broker):
assert broker.transitions.is_started()
pub_client = MQTTClient()
pub_client = MQTTClient(config={'auto_reconnect': False})
ret = await pub_client.connect("mqtt://127.0.0.1/")
assert ret == 0
await pub_client.publish("/+", b"data", QOS_0)
await pub_client.subscribe([
("my/+/topic", QOS_0)
])
await asyncio.sleep(0.5)
# need to build & send packet directly to bypass client's check of invalid topic name
# see test_client.py::test_publish_to_incorrect_wildcard for client checks
packet = PublishPacket.build(topic_name='my/topic', message=b'messages',
packet_id=None, dup_flag=False, qos=QOS_0, retain=False)
packet.topic_name = "my/+/topic"
await pub_client._handler._send_packet(packet)
await asyncio.sleep(0.5)
with pytest.raises(asyncio.TimeoutError):
msg = await pub_client.deliver_message(timeout_duration=1)
assert msg is None
await asyncio.sleep(0.1)
await pub_client.disconnect()
@pytest.mark.asyncio
async def test_client_publish_asterisk(broker):
"""'*' is a valid, non-wildcard character for MQTT."""
assert broker.transitions.is_started()
pub_client = MQTTClient(config={'auto_reconnect': False})
ret = await pub_client.connect("mqtt://127.0.0.1/")
assert ret == 0
await pub_client.subscribe([
("my*/topic", QOS_0),
("my/+/topic", QOS_0)
])
await asyncio.sleep(0.1)
await pub_client.publish('my*/topic', b'my valid message', QOS_0, retain=False)
await asyncio.sleep(0.1)
msg = await pub_client.deliver_message(timeout_duration=1)
assert msg is not None
assert msg.topic == "my*/topic"
assert msg.data == b'my valid message'
await asyncio.sleep(0.1)
msg = await pub_client.publish('my/****/topic', b'my valid message', QOS_0, retain=False)
assert msg is not None
assert msg.topic == "my/****/topic"
assert msg.data == b'my valid message'
await pub_client.disconnect()
@pytest.mark.asyncio
async def test_client_publish_big(broker, mock_plugin_manager):
pub_client = MQTTClient()

Wyświetl plik

@ -7,7 +7,7 @@ import pytest
from amqtt.broker import Broker
from amqtt.client import MQTTClient
from amqtt.errors import ClientError, ConnectError
from amqtt.errors import ClientError, ConnectError, MQTTError
from amqtt.mqtt.constants import QOS_0, QOS_1, QOS_2
formatter = "[%(asctime)s] %(name)s {%(filename)s:%(lineno)d} %(levelname)s - %(message)s"
@ -490,3 +490,17 @@ async def test_client_no_auth():
await client.connect("mqtt://127.0.0.1:1883/")
await broker.shutdown()
@pytest.mark.asyncio
async def test_publish_to_incorrect_wildcard(broker_fixture):
client = MQTTClient(config={'auto_reconnect': False})
await client.connect("mqtt://127.0.0.1/")
with pytest.raises(MQTTError):
await client.publish("my/+/topic", b'plus-sign wildcard topic invalid publish')
with pytest.raises(MQTTError):
await client.publish("topic/#", b'hash wildcard topic invalid publish')
await client.publish("topic/*", b'asterisk topic normal publish')
await client.disconnect()