kopia lustrzana https://github.com/Yakifo/amqtt
Merge pull request #251 from ajmirsky/increased_test_coverage
publishing to a topic with `*` is allowed, while `#` and `+` are notpull/257/head
commit
8022e01bb0
|
@ -520,8 +520,9 @@ class ProtocolHandler(Generic[C]):
|
||||||
elif packet.fixed_header.packet_type == DISCONNECT and isinstance(packet, DisconnectPacket):
|
elif packet.fixed_header.packet_type == DISCONNECT and isinstance(packet, DisconnectPacket):
|
||||||
task = asyncio.create_task(self.handle_disconnect(packet))
|
task = asyncio.create_task(self.handle_disconnect(packet))
|
||||||
elif packet.fixed_header.packet_type == CONNECT and isinstance(packet, ConnectPacket):
|
elif packet.fixed_header.packet_type == CONNECT and isinstance(packet, ConnectPacket):
|
||||||
# TODO: why is this not like all other inside create_task?
|
# q: why is this not like all other inside a create_task?
|
||||||
await self.handle_connect(packet) # task = asyncio.create_task(self.handle_connect(packet))
|
# a: the connection needs to be established before any other packet tasks for this new session are scheduled
|
||||||
|
await self.handle_connect(packet)
|
||||||
if task:
|
if task:
|
||||||
running_tasks.append(task)
|
running_tasks.append(task)
|
||||||
except MQTTError:
|
except MQTTError:
|
||||||
|
|
|
@ -12,7 +12,7 @@ class PublishVariableHeader(MQTTVariableHeader):
|
||||||
|
|
||||||
def __init__(self, topic_name: str, packet_id: int | None = None) -> None:
|
def __init__(self, topic_name: str, packet_id: int | None = None) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
if "*" in topic_name:
|
if "#" in topic_name or "+" in topic_name:
|
||||||
msg = "[MQTT-3.3.2-2] Topic name in the PUBLISH Packet MUST NOT contain wildcard characters."
|
msg = "[MQTT-3.3.2-2] Topic name in the PUBLISH Packet MUST NOT contain wildcard characters."
|
||||||
raise MQTTError(msg)
|
raise MQTTError(msg)
|
||||||
self.topic_name = topic_name
|
self.topic_name = topic_name
|
||||||
|
|
|
@ -0,0 +1,25 @@
|
||||||
|
import pytest
|
||||||
|
from amqtt.errors import AMQTTError
|
||||||
|
from amqtt.mqtt.connack import ConnackPacket
|
||||||
|
from amqtt.mqtt.packet import MQTTFixedHeader, PUBLISH
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def test_incorrect_fixed_header():
|
||||||
|
header = MQTTFixedHeader(PUBLISH, 0x00)
|
||||||
|
with pytest.raises(AMQTTError):
|
||||||
|
_ = ConnackPacket(fixed=header)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("prop", [
|
||||||
|
"return_code",
|
||||||
|
"session_parent"
|
||||||
|
])
|
||||||
|
def test_empty_variable_header(prop):
|
||||||
|
packet = ConnackPacket()
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
assert getattr(packet, prop) is not None
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
assert setattr(packet, prop, "a value")
|
|
@ -1,9 +1,11 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
import unittest
|
import unittest
|
||||||
|
import pytest
|
||||||
|
|
||||||
from amqtt.adapters import BufferReader
|
from amqtt.adapters import BufferReader
|
||||||
|
from amqtt.errors import AMQTTError
|
||||||
from amqtt.mqtt.connect import ConnectPacket, ConnectPayload, ConnectVariableHeader
|
from amqtt.mqtt.connect import ConnectPacket, ConnectPayload, ConnectVariableHeader
|
||||||
from amqtt.mqtt.packet import CONNECT, MQTTFixedHeader
|
from amqtt.mqtt.packet import CONNECT, MQTTFixedHeader, PUBLISH
|
||||||
|
|
||||||
|
|
||||||
class ConnectPacketTest(unittest.TestCase):
|
class ConnectPacketTest(unittest.TestCase):
|
||||||
|
@ -150,3 +152,36 @@ class ConnectPacketTest(unittest.TestCase):
|
||||||
assert message.username == "user"
|
assert message.username == "user"
|
||||||
assert message.payload.password == "password"
|
assert message.payload.password == "password"
|
||||||
assert message.password == "password"
|
assert message.password == "password"
|
||||||
|
|
||||||
|
def test_incorrect_fixed_header():
|
||||||
|
header = MQTTFixedHeader(PUBLISH, 0x00)
|
||||||
|
with pytest.raises(AMQTTError):
|
||||||
|
_ = ConnectPacket(fixed=header)
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("prop", [
|
||||||
|
"proto_name",
|
||||||
|
"proto_level",
|
||||||
|
"username_flag",
|
||||||
|
"password_flag",
|
||||||
|
"clean_session_flag",
|
||||||
|
"will_retain_flag",
|
||||||
|
"will_qos",
|
||||||
|
"will_flag",
|
||||||
|
"reserved_flag",
|
||||||
|
"client_id",
|
||||||
|
"client_id_is_random",
|
||||||
|
"will_topic",
|
||||||
|
"will_message",
|
||||||
|
"username",
|
||||||
|
"password",
|
||||||
|
"keep_alive",
|
||||||
|
])
|
||||||
|
def test_empty_variable_header(prop):
|
||||||
|
packet = ConnectPacket()
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
assert getattr(packet, prop) is not None
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
assert setattr(packet, prop, "a value")
|
||||||
|
|
||||||
|
|
|
@ -1,7 +1,10 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
import unittest
|
import unittest
|
||||||
|
import pytest
|
||||||
|
|
||||||
from amqtt.adapters import BufferReader
|
from amqtt.adapters import BufferReader
|
||||||
|
from amqtt.errors import AMQTTError
|
||||||
|
from amqtt.mqtt import PUBLISH, MQTTFixedHeader
|
||||||
from amqtt.mqtt.puback import PacketIdVariableHeader, PubackPacket
|
from amqtt.mqtt.puback import PacketIdVariableHeader, PubackPacket
|
||||||
|
|
||||||
|
|
||||||
|
@ -20,3 +23,22 @@ class PubackPacketTest(unittest.TestCase):
|
||||||
publish = PubackPacket(variable_header=variable_header)
|
publish = PubackPacket(variable_header=variable_header)
|
||||||
out = publish.to_bytes()
|
out = publish.to_bytes()
|
||||||
assert out == b"@\x02\x00\n"
|
assert out == b"@\x02\x00\n"
|
||||||
|
|
||||||
|
|
||||||
|
def test_incorrect_fixed_header():
|
||||||
|
header = MQTTFixedHeader(PUBLISH, 0x00)
|
||||||
|
with pytest.raises(AMQTTError):
|
||||||
|
connect_packet = PubackPacket(fixed=header)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("prop", [
|
||||||
|
"packet_id",
|
||||||
|
])
|
||||||
|
def test_empty_variable_header(prop):
|
||||||
|
connect_packet = PubackPacket()
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
assert getattr(connect_packet, prop) is not None
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
assert setattr(connect_packet, prop, "a value")
|
||||||
|
|
|
@ -1,7 +1,10 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
import unittest
|
import unittest
|
||||||
|
import pytest
|
||||||
|
|
||||||
from amqtt.adapters import BufferReader
|
from amqtt.adapters import BufferReader
|
||||||
|
from amqtt.errors import AMQTTError
|
||||||
|
from amqtt.mqtt import MQTTFixedHeader, PUBLISH
|
||||||
from amqtt.mqtt.pubcomp import PacketIdVariableHeader, PubcompPacket
|
from amqtt.mqtt.pubcomp import PacketIdVariableHeader, PubcompPacket
|
||||||
|
|
||||||
|
|
||||||
|
@ -20,3 +23,21 @@ class PubcompPacketTest(unittest.TestCase):
|
||||||
publish = PubcompPacket(variable_header=variable_header)
|
publish = PubcompPacket(variable_header=variable_header)
|
||||||
out = publish.to_bytes()
|
out = publish.to_bytes()
|
||||||
assert out == b"p\x02\x00\n"
|
assert out == b"p\x02\x00\n"
|
||||||
|
|
||||||
|
def test_incorrect_fixed_header():
|
||||||
|
header = MQTTFixedHeader(PUBLISH, 0x00)
|
||||||
|
with pytest.raises(AMQTTError):
|
||||||
|
_ = PubcompPacket(fixed=header)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("prop", [
|
||||||
|
"packet_id"
|
||||||
|
])
|
||||||
|
def test_empty_variable_header(prop):
|
||||||
|
packet = PubcompPacket()
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
assert getattr(packet, prop) is not None
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
assert setattr(packet, prop, "a value")
|
||||||
|
|
|
@ -1,7 +1,10 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
import unittest
|
import unittest
|
||||||
|
import pytest
|
||||||
|
|
||||||
from amqtt.adapters import BufferReader
|
from amqtt.adapters import BufferReader
|
||||||
|
from amqtt.errors import AMQTTError
|
||||||
|
from amqtt.mqtt.packet import MQTTFixedHeader, CONNECT
|
||||||
from amqtt.mqtt.constants import QOS_0, QOS_1, QOS_2
|
from amqtt.mqtt.constants import QOS_0, QOS_1, QOS_2
|
||||||
from amqtt.mqtt.publish import PublishPacket, PublishPayload, PublishVariableHeader
|
from amqtt.mqtt.publish import PublishPacket, PublishPayload, PublishVariableHeader
|
||||||
|
|
||||||
|
@ -116,3 +119,28 @@ class PublishPacketTest(unittest.TestCase):
|
||||||
assert packet.dup_flag
|
assert packet.dup_flag
|
||||||
assert packet.qos == QOS_2
|
assert packet.qos == QOS_2
|
||||||
assert packet.retain_flag
|
assert packet.retain_flag
|
||||||
|
|
||||||
|
|
||||||
|
def test_incorrect_fixed_header():
|
||||||
|
header = MQTTFixedHeader(CONNECT, 0x00)
|
||||||
|
with pytest.raises(AMQTTError):
|
||||||
|
_ = PublishPacket(fixed=header)
|
||||||
|
|
||||||
|
def test_set_flags():
|
||||||
|
packet = PublishPacket()
|
||||||
|
packet.set_flags(dup_flag=True, qos=QOS_1, retain_flag=True)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("prop", [
|
||||||
|
"packet_id",
|
||||||
|
"data",
|
||||||
|
"topic_name"
|
||||||
|
])
|
||||||
|
def test_empty_variable_header(prop):
|
||||||
|
packet = PublishPacket()
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
assert getattr(packet, prop) is not None
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
assert setattr(packet, prop, "a value")
|
||||||
|
|
|
@ -1,7 +1,10 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
import unittest
|
import unittest
|
||||||
|
import pytest
|
||||||
|
|
||||||
from amqtt.adapters import BufferReader
|
from amqtt.adapters import BufferReader
|
||||||
|
from amqtt.errors import AMQTTError
|
||||||
|
from amqtt.mqtt.packet import MQTTFixedHeader, PUBLISH
|
||||||
from amqtt.mqtt.pubrec import PacketIdVariableHeader, PubrecPacket
|
from amqtt.mqtt.pubrec import PacketIdVariableHeader, PubrecPacket
|
||||||
|
|
||||||
|
|
||||||
|
@ -20,3 +23,22 @@ class PubrecPacketTest(unittest.TestCase):
|
||||||
publish = PubrecPacket(variable_header=variable_header)
|
publish = PubrecPacket(variable_header=variable_header)
|
||||||
out = publish.to_bytes()
|
out = publish.to_bytes()
|
||||||
assert out == b"P\x02\x00\n"
|
assert out == b"P\x02\x00\n"
|
||||||
|
|
||||||
|
|
||||||
|
def test_incorrect_fixed_header():
|
||||||
|
header = MQTTFixedHeader(PUBLISH, 0x00)
|
||||||
|
with pytest.raises(AMQTTError):
|
||||||
|
_ = PubrecPacket(fixed=header)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("prop", [
|
||||||
|
"packet_id"
|
||||||
|
])
|
||||||
|
def test_empty_variable_header(prop):
|
||||||
|
packet = PubrecPacket()
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
assert getattr(packet, prop) is not None
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
assert setattr(packet, prop, "a value")
|
||||||
|
|
|
@ -1,7 +1,10 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
import unittest
|
import unittest
|
||||||
|
import pytest
|
||||||
|
|
||||||
from amqtt.adapters import BufferReader
|
from amqtt.adapters import BufferReader
|
||||||
|
from amqtt.errors import AMQTTError
|
||||||
|
from amqtt.mqtt.packet import MQTTFixedHeader, PUBLISH
|
||||||
from amqtt.mqtt.pubrel import PacketIdVariableHeader, PubrelPacket
|
from amqtt.mqtt.pubrel import PacketIdVariableHeader, PubrelPacket
|
||||||
|
|
||||||
|
|
||||||
|
@ -20,3 +23,22 @@ class PubrelPacketTest(unittest.TestCase):
|
||||||
publish = PubrelPacket(variable_header=variable_header)
|
publish = PubrelPacket(variable_header=variable_header)
|
||||||
out = publish.to_bytes()
|
out = publish.to_bytes()
|
||||||
assert out == b"b\x02\x00\n"
|
assert out == b"b\x02\x00\n"
|
||||||
|
|
||||||
|
|
||||||
|
def test_incorrect_fixed_header():
|
||||||
|
header = MQTTFixedHeader(PUBLISH, 0x00)
|
||||||
|
with pytest.raises(AMQTTError):
|
||||||
|
_ = PubrelPacket(fixed=header)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("prop", [
|
||||||
|
"packet_id"
|
||||||
|
])
|
||||||
|
def test_empty_variable_header(prop):
|
||||||
|
packet = PubrelPacket()
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
assert getattr(packet, prop) is not None
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
assert setattr(packet, prop, "a value")
|
|
@ -510,17 +510,62 @@ async def test_client_publish_dup(broker):
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_client_publish_invalid_topic(broker):
|
async def test_client_publishing_invalid_topic(broker):
|
||||||
assert broker.transitions.is_started()
|
assert broker.transitions.is_started()
|
||||||
pub_client = MQTTClient()
|
pub_client = MQTTClient(config={'auto_reconnect': False})
|
||||||
ret = await pub_client.connect("mqtt://127.0.0.1/")
|
ret = await pub_client.connect("mqtt://127.0.0.1/")
|
||||||
assert ret == 0
|
assert ret == 0
|
||||||
|
|
||||||
await pub_client.publish("/+", b"data", QOS_0)
|
await pub_client.subscribe([
|
||||||
|
("my/+/topic", QOS_0)
|
||||||
|
])
|
||||||
|
await asyncio.sleep(0.5)
|
||||||
|
|
||||||
|
# need to build & send packet directly to bypass client's check of invalid topic name
|
||||||
|
# see test_client.py::test_publish_to_incorrect_wildcard for client checks
|
||||||
|
packet = PublishPacket.build(topic_name='my/topic', message=b'messages',
|
||||||
|
packet_id=None, dup_flag=False, qos=QOS_0, retain=False)
|
||||||
|
packet.topic_name = "my/+/topic"
|
||||||
|
await pub_client._handler._send_packet(packet)
|
||||||
|
await asyncio.sleep(0.5)
|
||||||
|
|
||||||
|
with pytest.raises(asyncio.TimeoutError):
|
||||||
|
msg = await pub_client.deliver_message(timeout_duration=1)
|
||||||
|
assert msg is None
|
||||||
|
|
||||||
await asyncio.sleep(0.1)
|
await asyncio.sleep(0.1)
|
||||||
await pub_client.disconnect()
|
await pub_client.disconnect()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_client_publish_asterisk(broker):
|
||||||
|
"""'*' is a valid, non-wildcard character for MQTT."""
|
||||||
|
assert broker.transitions.is_started()
|
||||||
|
pub_client = MQTTClient(config={'auto_reconnect': False})
|
||||||
|
ret = await pub_client.connect("mqtt://127.0.0.1/")
|
||||||
|
assert ret == 0
|
||||||
|
|
||||||
|
await pub_client.subscribe([
|
||||||
|
("my*/topic", QOS_0),
|
||||||
|
("my/+/topic", QOS_0)
|
||||||
|
])
|
||||||
|
await asyncio.sleep(0.1)
|
||||||
|
await pub_client.publish('my*/topic', b'my valid message', QOS_0, retain=False)
|
||||||
|
await asyncio.sleep(0.1)
|
||||||
|
msg = await pub_client.deliver_message(timeout_duration=1)
|
||||||
|
assert msg is not None
|
||||||
|
assert msg.topic == "my*/topic"
|
||||||
|
assert msg.data == b'my valid message'
|
||||||
|
|
||||||
|
await asyncio.sleep(0.1)
|
||||||
|
msg = await pub_client.publish('my/****/topic', b'my valid message', QOS_0, retain=False)
|
||||||
|
assert msg is not None
|
||||||
|
assert msg.topic == "my/****/topic"
|
||||||
|
assert msg.data == b'my valid message'
|
||||||
|
|
||||||
|
await pub_client.disconnect()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_client_publish_big(broker, mock_plugin_manager):
|
async def test_client_publish_big(broker, mock_plugin_manager):
|
||||||
pub_client = MQTTClient()
|
pub_client = MQTTClient()
|
||||||
|
|
|
@ -7,7 +7,7 @@ import pytest
|
||||||
|
|
||||||
from amqtt.broker import Broker
|
from amqtt.broker import Broker
|
||||||
from amqtt.client import MQTTClient
|
from amqtt.client import MQTTClient
|
||||||
from amqtt.errors import ClientError, ConnectError
|
from amqtt.errors import ClientError, ConnectError, MQTTError
|
||||||
from amqtt.mqtt.constants import QOS_0, QOS_1, QOS_2
|
from amqtt.mqtt.constants import QOS_0, QOS_1, QOS_2
|
||||||
|
|
||||||
formatter = "[%(asctime)s] %(name)s {%(filename)s:%(lineno)d} %(levelname)s - %(message)s"
|
formatter = "[%(asctime)s] %(name)s {%(filename)s:%(lineno)d} %(levelname)s - %(message)s"
|
||||||
|
@ -490,3 +490,17 @@ async def test_client_no_auth():
|
||||||
await client.connect("mqtt://127.0.0.1:1883/")
|
await client.connect("mqtt://127.0.0.1:1883/")
|
||||||
|
|
||||||
await broker.shutdown()
|
await broker.shutdown()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_publish_to_incorrect_wildcard(broker_fixture):
|
||||||
|
client = MQTTClient(config={'auto_reconnect': False})
|
||||||
|
await client.connect("mqtt://127.0.0.1/")
|
||||||
|
|
||||||
|
with pytest.raises(MQTTError):
|
||||||
|
await client.publish("my/+/topic", b'plus-sign wildcard topic invalid publish')
|
||||||
|
with pytest.raises(MQTTError):
|
||||||
|
await client.publish("topic/#", b'hash wildcard topic invalid publish')
|
||||||
|
|
||||||
|
await client.publish("topic/*", b'asterisk topic normal publish')
|
||||||
|
await client.disconnect()
|
||||||
|
|
Ładowanie…
Reference in New Issue