kopia lustrzana https://github.com/Yakifo/amqtt
Merge pull request #251 from ajmirsky/increased_test_coverage
publishing to a topic with `*` is allowed, while `#` and `+` are notpull/257/head
commit
8022e01bb0
|
@ -520,8 +520,9 @@ class ProtocolHandler(Generic[C]):
|
|||
elif packet.fixed_header.packet_type == DISCONNECT and isinstance(packet, DisconnectPacket):
|
||||
task = asyncio.create_task(self.handle_disconnect(packet))
|
||||
elif packet.fixed_header.packet_type == CONNECT and isinstance(packet, ConnectPacket):
|
||||
# TODO: why is this not like all other inside create_task?
|
||||
await self.handle_connect(packet) # task = asyncio.create_task(self.handle_connect(packet))
|
||||
# q: why is this not like all other inside a create_task?
|
||||
# a: the connection needs to be established before any other packet tasks for this new session are scheduled
|
||||
await self.handle_connect(packet)
|
||||
if task:
|
||||
running_tasks.append(task)
|
||||
except MQTTError:
|
||||
|
|
|
@ -12,7 +12,7 @@ class PublishVariableHeader(MQTTVariableHeader):
|
|||
|
||||
def __init__(self, topic_name: str, packet_id: int | None = None) -> None:
|
||||
super().__init__()
|
||||
if "*" in topic_name:
|
||||
if "#" in topic_name or "+" in topic_name:
|
||||
msg = "[MQTT-3.3.2-2] Topic name in the PUBLISH Packet MUST NOT contain wildcard characters."
|
||||
raise MQTTError(msg)
|
||||
self.topic_name = topic_name
|
||||
|
|
|
@ -0,0 +1,25 @@
|
|||
import pytest
|
||||
from amqtt.errors import AMQTTError
|
||||
from amqtt.mqtt.connack import ConnackPacket
|
||||
from amqtt.mqtt.packet import MQTTFixedHeader, PUBLISH
|
||||
|
||||
|
||||
|
||||
def test_incorrect_fixed_header():
|
||||
header = MQTTFixedHeader(PUBLISH, 0x00)
|
||||
with pytest.raises(AMQTTError):
|
||||
_ = ConnackPacket(fixed=header)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("prop", [
|
||||
"return_code",
|
||||
"session_parent"
|
||||
])
|
||||
def test_empty_variable_header(prop):
|
||||
packet = ConnackPacket()
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
assert getattr(packet, prop) is not None
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
assert setattr(packet, prop, "a value")
|
|
@ -1,9 +1,11 @@
|
|||
import asyncio
|
||||
import unittest
|
||||
import pytest
|
||||
|
||||
from amqtt.adapters import BufferReader
|
||||
from amqtt.errors import AMQTTError
|
||||
from amqtt.mqtt.connect import ConnectPacket, ConnectPayload, ConnectVariableHeader
|
||||
from amqtt.mqtt.packet import CONNECT, MQTTFixedHeader
|
||||
from amqtt.mqtt.packet import CONNECT, MQTTFixedHeader, PUBLISH
|
||||
|
||||
|
||||
class ConnectPacketTest(unittest.TestCase):
|
||||
|
@ -150,3 +152,36 @@ class ConnectPacketTest(unittest.TestCase):
|
|||
assert message.username == "user"
|
||||
assert message.payload.password == "password"
|
||||
assert message.password == "password"
|
||||
|
||||
def test_incorrect_fixed_header():
|
||||
header = MQTTFixedHeader(PUBLISH, 0x00)
|
||||
with pytest.raises(AMQTTError):
|
||||
_ = ConnectPacket(fixed=header)
|
||||
|
||||
@pytest.mark.parametrize("prop", [
|
||||
"proto_name",
|
||||
"proto_level",
|
||||
"username_flag",
|
||||
"password_flag",
|
||||
"clean_session_flag",
|
||||
"will_retain_flag",
|
||||
"will_qos",
|
||||
"will_flag",
|
||||
"reserved_flag",
|
||||
"client_id",
|
||||
"client_id_is_random",
|
||||
"will_topic",
|
||||
"will_message",
|
||||
"username",
|
||||
"password",
|
||||
"keep_alive",
|
||||
])
|
||||
def test_empty_variable_header(prop):
|
||||
packet = ConnectPacket()
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
assert getattr(packet, prop) is not None
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
assert setattr(packet, prop, "a value")
|
||||
|
||||
|
|
|
@ -1,7 +1,10 @@
|
|||
import asyncio
|
||||
import unittest
|
||||
import pytest
|
||||
|
||||
from amqtt.adapters import BufferReader
|
||||
from amqtt.errors import AMQTTError
|
||||
from amqtt.mqtt import PUBLISH, MQTTFixedHeader
|
||||
from amqtt.mqtt.puback import PacketIdVariableHeader, PubackPacket
|
||||
|
||||
|
||||
|
@ -20,3 +23,22 @@ class PubackPacketTest(unittest.TestCase):
|
|||
publish = PubackPacket(variable_header=variable_header)
|
||||
out = publish.to_bytes()
|
||||
assert out == b"@\x02\x00\n"
|
||||
|
||||
|
||||
def test_incorrect_fixed_header():
|
||||
header = MQTTFixedHeader(PUBLISH, 0x00)
|
||||
with pytest.raises(AMQTTError):
|
||||
connect_packet = PubackPacket(fixed=header)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("prop", [
|
||||
"packet_id",
|
||||
])
|
||||
def test_empty_variable_header(prop):
|
||||
connect_packet = PubackPacket()
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
assert getattr(connect_packet, prop) is not None
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
assert setattr(connect_packet, prop, "a value")
|
||||
|
|
|
@ -1,7 +1,10 @@
|
|||
import asyncio
|
||||
import unittest
|
||||
import pytest
|
||||
|
||||
from amqtt.adapters import BufferReader
|
||||
from amqtt.errors import AMQTTError
|
||||
from amqtt.mqtt import MQTTFixedHeader, PUBLISH
|
||||
from amqtt.mqtt.pubcomp import PacketIdVariableHeader, PubcompPacket
|
||||
|
||||
|
||||
|
@ -20,3 +23,21 @@ class PubcompPacketTest(unittest.TestCase):
|
|||
publish = PubcompPacket(variable_header=variable_header)
|
||||
out = publish.to_bytes()
|
||||
assert out == b"p\x02\x00\n"
|
||||
|
||||
def test_incorrect_fixed_header():
|
||||
header = MQTTFixedHeader(PUBLISH, 0x00)
|
||||
with pytest.raises(AMQTTError):
|
||||
_ = PubcompPacket(fixed=header)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("prop", [
|
||||
"packet_id"
|
||||
])
|
||||
def test_empty_variable_header(prop):
|
||||
packet = PubcompPacket()
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
assert getattr(packet, prop) is not None
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
assert setattr(packet, prop, "a value")
|
||||
|
|
|
@ -1,7 +1,10 @@
|
|||
import asyncio
|
||||
import unittest
|
||||
import pytest
|
||||
|
||||
from amqtt.adapters import BufferReader
|
||||
from amqtt.errors import AMQTTError
|
||||
from amqtt.mqtt.packet import MQTTFixedHeader, CONNECT
|
||||
from amqtt.mqtt.constants import QOS_0, QOS_1, QOS_2
|
||||
from amqtt.mqtt.publish import PublishPacket, PublishPayload, PublishVariableHeader
|
||||
|
||||
|
@ -116,3 +119,28 @@ class PublishPacketTest(unittest.TestCase):
|
|||
assert packet.dup_flag
|
||||
assert packet.qos == QOS_2
|
||||
assert packet.retain_flag
|
||||
|
||||
|
||||
def test_incorrect_fixed_header():
|
||||
header = MQTTFixedHeader(CONNECT, 0x00)
|
||||
with pytest.raises(AMQTTError):
|
||||
_ = PublishPacket(fixed=header)
|
||||
|
||||
def test_set_flags():
|
||||
packet = PublishPacket()
|
||||
packet.set_flags(dup_flag=True, qos=QOS_1, retain_flag=True)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("prop", [
|
||||
"packet_id",
|
||||
"data",
|
||||
"topic_name"
|
||||
])
|
||||
def test_empty_variable_header(prop):
|
||||
packet = PublishPacket()
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
assert getattr(packet, prop) is not None
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
assert setattr(packet, prop, "a value")
|
||||
|
|
|
@ -1,7 +1,10 @@
|
|||
import asyncio
|
||||
import unittest
|
||||
import pytest
|
||||
|
||||
from amqtt.adapters import BufferReader
|
||||
from amqtt.errors import AMQTTError
|
||||
from amqtt.mqtt.packet import MQTTFixedHeader, PUBLISH
|
||||
from amqtt.mqtt.pubrec import PacketIdVariableHeader, PubrecPacket
|
||||
|
||||
|
||||
|
@ -20,3 +23,22 @@ class PubrecPacketTest(unittest.TestCase):
|
|||
publish = PubrecPacket(variable_header=variable_header)
|
||||
out = publish.to_bytes()
|
||||
assert out == b"P\x02\x00\n"
|
||||
|
||||
|
||||
def test_incorrect_fixed_header():
|
||||
header = MQTTFixedHeader(PUBLISH, 0x00)
|
||||
with pytest.raises(AMQTTError):
|
||||
_ = PubrecPacket(fixed=header)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("prop", [
|
||||
"packet_id"
|
||||
])
|
||||
def test_empty_variable_header(prop):
|
||||
packet = PubrecPacket()
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
assert getattr(packet, prop) is not None
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
assert setattr(packet, prop, "a value")
|
||||
|
|
|
@ -1,7 +1,10 @@
|
|||
import asyncio
|
||||
import unittest
|
||||
import pytest
|
||||
|
||||
from amqtt.adapters import BufferReader
|
||||
from amqtt.errors import AMQTTError
|
||||
from amqtt.mqtt.packet import MQTTFixedHeader, PUBLISH
|
||||
from amqtt.mqtt.pubrel import PacketIdVariableHeader, PubrelPacket
|
||||
|
||||
|
||||
|
@ -20,3 +23,22 @@ class PubrelPacketTest(unittest.TestCase):
|
|||
publish = PubrelPacket(variable_header=variable_header)
|
||||
out = publish.to_bytes()
|
||||
assert out == b"b\x02\x00\n"
|
||||
|
||||
|
||||
def test_incorrect_fixed_header():
|
||||
header = MQTTFixedHeader(PUBLISH, 0x00)
|
||||
with pytest.raises(AMQTTError):
|
||||
_ = PubrelPacket(fixed=header)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("prop", [
|
||||
"packet_id"
|
||||
])
|
||||
def test_empty_variable_header(prop):
|
||||
packet = PubrelPacket()
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
assert getattr(packet, prop) is not None
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
assert setattr(packet, prop, "a value")
|
|
@ -510,17 +510,62 @@ async def test_client_publish_dup(broker):
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_client_publish_invalid_topic(broker):
|
||||
async def test_client_publishing_invalid_topic(broker):
|
||||
assert broker.transitions.is_started()
|
||||
pub_client = MQTTClient()
|
||||
pub_client = MQTTClient(config={'auto_reconnect': False})
|
||||
ret = await pub_client.connect("mqtt://127.0.0.1/")
|
||||
assert ret == 0
|
||||
|
||||
await pub_client.publish("/+", b"data", QOS_0)
|
||||
await pub_client.subscribe([
|
||||
("my/+/topic", QOS_0)
|
||||
])
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
# need to build & send packet directly to bypass client's check of invalid topic name
|
||||
# see test_client.py::test_publish_to_incorrect_wildcard for client checks
|
||||
packet = PublishPacket.build(topic_name='my/topic', message=b'messages',
|
||||
packet_id=None, dup_flag=False, qos=QOS_0, retain=False)
|
||||
packet.topic_name = "my/+/topic"
|
||||
await pub_client._handler._send_packet(packet)
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
with pytest.raises(asyncio.TimeoutError):
|
||||
msg = await pub_client.deliver_message(timeout_duration=1)
|
||||
assert msg is None
|
||||
|
||||
await asyncio.sleep(0.1)
|
||||
await pub_client.disconnect()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_client_publish_asterisk(broker):
|
||||
"""'*' is a valid, non-wildcard character for MQTT."""
|
||||
assert broker.transitions.is_started()
|
||||
pub_client = MQTTClient(config={'auto_reconnect': False})
|
||||
ret = await pub_client.connect("mqtt://127.0.0.1/")
|
||||
assert ret == 0
|
||||
|
||||
await pub_client.subscribe([
|
||||
("my*/topic", QOS_0),
|
||||
("my/+/topic", QOS_0)
|
||||
])
|
||||
await asyncio.sleep(0.1)
|
||||
await pub_client.publish('my*/topic', b'my valid message', QOS_0, retain=False)
|
||||
await asyncio.sleep(0.1)
|
||||
msg = await pub_client.deliver_message(timeout_duration=1)
|
||||
assert msg is not None
|
||||
assert msg.topic == "my*/topic"
|
||||
assert msg.data == b'my valid message'
|
||||
|
||||
await asyncio.sleep(0.1)
|
||||
msg = await pub_client.publish('my/****/topic', b'my valid message', QOS_0, retain=False)
|
||||
assert msg is not None
|
||||
assert msg.topic == "my/****/topic"
|
||||
assert msg.data == b'my valid message'
|
||||
|
||||
await pub_client.disconnect()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_client_publish_big(broker, mock_plugin_manager):
|
||||
pub_client = MQTTClient()
|
||||
|
|
|
@ -7,7 +7,7 @@ import pytest
|
|||
|
||||
from amqtt.broker import Broker
|
||||
from amqtt.client import MQTTClient
|
||||
from amqtt.errors import ClientError, ConnectError
|
||||
from amqtt.errors import ClientError, ConnectError, MQTTError
|
||||
from amqtt.mqtt.constants import QOS_0, QOS_1, QOS_2
|
||||
|
||||
formatter = "[%(asctime)s] %(name)s {%(filename)s:%(lineno)d} %(levelname)s - %(message)s"
|
||||
|
@ -490,3 +490,17 @@ async def test_client_no_auth():
|
|||
await client.connect("mqtt://127.0.0.1:1883/")
|
||||
|
||||
await broker.shutdown()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_publish_to_incorrect_wildcard(broker_fixture):
|
||||
client = MQTTClient(config={'auto_reconnect': False})
|
||||
await client.connect("mqtt://127.0.0.1/")
|
||||
|
||||
with pytest.raises(MQTTError):
|
||||
await client.publish("my/+/topic", b'plus-sign wildcard topic invalid publish')
|
||||
with pytest.raises(MQTTError):
|
||||
await client.publish("topic/#", b'hash wildcard topic invalid publish')
|
||||
|
||||
await client.publish("topic/*", b'asterisk topic normal publish')
|
||||
await client.disconnect()
|
||||
|
|
Ładowanie…
Reference in New Issue