kopia lustrzana https://github.com/Yakifo/amqtt
Merge remote-tracking branch 'source/0.11.2-rc' into session_persistence
commit
3fa54ab7a5
|
@ -183,6 +183,8 @@ class Broker:
|
|||
self._subscriptions: dict[str, list[tuple[Session, int]]] = {}
|
||||
self._retained_messages: dict[str, RetainedApplicationMessage] = {}
|
||||
|
||||
self._topic_filter_matchers: dict[str, re.Pattern[str]] = {}
|
||||
|
||||
# Broadcast queue for outgoing messages
|
||||
self._broadcast_queue: asyncio.Queue[dict[str, Any]] = asyncio.Queue()
|
||||
self._broadcast_task: asyncio.Task[Any] | None = None
|
||||
|
@ -690,6 +692,11 @@ class Broker:
|
|||
f"[MQTT-3.3.2-2] - {client_session.client_id} invalid TOPIC sent in PUBLISH message, closing connection",
|
||||
)
|
||||
return False
|
||||
if app_message.topic.startswith("$"):
|
||||
self.logger.warning(
|
||||
f"[MQTT-4.7.2-1] - {client_session.client_id} cannot use a topic with a leading $ character."
|
||||
)
|
||||
return False
|
||||
|
||||
permitted = await self._topic_filtering(client_session, topic=app_message.topic, action=Action.PUBLISH)
|
||||
if not permitted:
|
||||
|
@ -908,9 +915,6 @@ class Broker:
|
|||
self.logger.debug(f"Processing broadcast message: {broadcast}")
|
||||
|
||||
for k_filter, subscriptions in self._subscriptions.items():
|
||||
if broadcast["topic"].startswith("$") and (k_filter.startswith(("+", "#"))):
|
||||
self.logger.debug("[MQTT-4.7.2-1] - ignoring broadcasting $ topic to subscriptions starting with + or #")
|
||||
continue
|
||||
|
||||
# Skip all subscriptions which do not match the topic
|
||||
if not self._matches(broadcast["topic"], k_filter):
|
||||
|
@ -1039,11 +1043,21 @@ class Broker:
|
|||
)
|
||||
|
||||
def _matches(self, topic: str, a_filter: str) -> bool:
|
||||
if topic.startswith("$") and (a_filter.startswith(("+", "#"))):
|
||||
self.logger.debug("[MQTT-4.7.2-1] - ignoring broadcasting $ topic to subscriptions starting with + or #")
|
||||
return False
|
||||
|
||||
if "#" not in a_filter and "+" not in a_filter:
|
||||
# if filter doesn't contain wildcard, return exact match
|
||||
return a_filter == topic
|
||||
# else use regex
|
||||
match_pattern = re.compile(re.escape(a_filter).replace("\\#", "?.*").replace("\\+", "[^/]*").lstrip("?"))
|
||||
|
||||
# else use regex (re.compile is an expensive operation, store the matcher for future use)
|
||||
if a_filter not in self._topic_filter_matchers:
|
||||
self._topic_filter_matchers[a_filter] = re.compile(re.escape(a_filter)
|
||||
.replace("\\#", "?.*")
|
||||
.replace("\\+", "[^/]*")
|
||||
.lstrip("?"))
|
||||
match_pattern = self._topic_filter_matchers[a_filter]
|
||||
return bool(match_pattern.fullmatch(topic))
|
||||
|
||||
def _get_handler(self, session: Session) -> BrokerProtocolHandler | None:
|
||||
|
|
|
@ -520,8 +520,9 @@ class ProtocolHandler(Generic[C]):
|
|||
elif packet.fixed_header.packet_type == DISCONNECT and isinstance(packet, DisconnectPacket):
|
||||
task = asyncio.create_task(self.handle_disconnect(packet))
|
||||
elif packet.fixed_header.packet_type == CONNECT and isinstance(packet, ConnectPacket):
|
||||
# TODO: why is this not like all other inside create_task?
|
||||
await self.handle_connect(packet) # task = asyncio.create_task(self.handle_connect(packet))
|
||||
# q: why is this not like all other inside a create_task?
|
||||
# a: the connection needs to be established before any other packet tasks for this new session are scheduled
|
||||
await self.handle_connect(packet)
|
||||
if task:
|
||||
running_tasks.append(task)
|
||||
except MQTTError:
|
||||
|
|
|
@ -12,7 +12,7 @@ class PublishVariableHeader(MQTTVariableHeader):
|
|||
|
||||
def __init__(self, topic_name: str, packet_id: int | None = None) -> None:
|
||||
super().__init__()
|
||||
if "*" in topic_name:
|
||||
if "#" in topic_name or "+" in topic_name:
|
||||
msg = "[MQTT-3.3.2-2] Topic name in the PUBLISH Packet MUST NOT contain wildcard characters."
|
||||
raise MQTTError(msg)
|
||||
self.topic_name = topic_name
|
||||
|
|
|
@ -0,0 +1,25 @@
|
|||
import pytest
|
||||
from amqtt.errors import AMQTTError
|
||||
from amqtt.mqtt.connack import ConnackPacket
|
||||
from amqtt.mqtt.packet import MQTTFixedHeader, PUBLISH
|
||||
|
||||
|
||||
|
||||
def test_incorrect_fixed_header():
|
||||
header = MQTTFixedHeader(PUBLISH, 0x00)
|
||||
with pytest.raises(AMQTTError):
|
||||
_ = ConnackPacket(fixed=header)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("prop", [
|
||||
"return_code",
|
||||
"session_parent"
|
||||
])
|
||||
def test_empty_variable_header(prop):
|
||||
packet = ConnackPacket()
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
assert getattr(packet, prop) is not None
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
assert setattr(packet, prop, "a value")
|
|
@ -1,9 +1,11 @@
|
|||
import asyncio
|
||||
import unittest
|
||||
import pytest
|
||||
|
||||
from amqtt.adapters import BufferReader
|
||||
from amqtt.errors import AMQTTError
|
||||
from amqtt.mqtt.connect import ConnectPacket, ConnectPayload, ConnectVariableHeader
|
||||
from amqtt.mqtt.packet import CONNECT, MQTTFixedHeader
|
||||
from amqtt.mqtt.packet import CONNECT, MQTTFixedHeader, PUBLISH
|
||||
|
||||
|
||||
class ConnectPacketTest(unittest.TestCase):
|
||||
|
@ -150,3 +152,36 @@ class ConnectPacketTest(unittest.TestCase):
|
|||
assert message.username == "user"
|
||||
assert message.payload.password == "password"
|
||||
assert message.password == "password"
|
||||
|
||||
def test_incorrect_fixed_header():
|
||||
header = MQTTFixedHeader(PUBLISH, 0x00)
|
||||
with pytest.raises(AMQTTError):
|
||||
_ = ConnectPacket(fixed=header)
|
||||
|
||||
@pytest.mark.parametrize("prop", [
|
||||
"proto_name",
|
||||
"proto_level",
|
||||
"username_flag",
|
||||
"password_flag",
|
||||
"clean_session_flag",
|
||||
"will_retain_flag",
|
||||
"will_qos",
|
||||
"will_flag",
|
||||
"reserved_flag",
|
||||
"client_id",
|
||||
"client_id_is_random",
|
||||
"will_topic",
|
||||
"will_message",
|
||||
"username",
|
||||
"password",
|
||||
"keep_alive",
|
||||
])
|
||||
def test_empty_variable_header(prop):
|
||||
packet = ConnectPacket()
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
assert getattr(packet, prop) is not None
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
assert setattr(packet, prop, "a value")
|
||||
|
||||
|
|
|
@ -1,7 +1,10 @@
|
|||
import asyncio
|
||||
import unittest
|
||||
import pytest
|
||||
|
||||
from amqtt.adapters import BufferReader
|
||||
from amqtt.errors import AMQTTError
|
||||
from amqtt.mqtt import PUBLISH, MQTTFixedHeader
|
||||
from amqtt.mqtt.puback import PacketIdVariableHeader, PubackPacket
|
||||
|
||||
|
||||
|
@ -20,3 +23,22 @@ class PubackPacketTest(unittest.TestCase):
|
|||
publish = PubackPacket(variable_header=variable_header)
|
||||
out = publish.to_bytes()
|
||||
assert out == b"@\x02\x00\n"
|
||||
|
||||
|
||||
def test_incorrect_fixed_header():
|
||||
header = MQTTFixedHeader(PUBLISH, 0x00)
|
||||
with pytest.raises(AMQTTError):
|
||||
connect_packet = PubackPacket(fixed=header)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("prop", [
|
||||
"packet_id",
|
||||
])
|
||||
def test_empty_variable_header(prop):
|
||||
connect_packet = PubackPacket()
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
assert getattr(connect_packet, prop) is not None
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
assert setattr(connect_packet, prop, "a value")
|
||||
|
|
|
@ -1,7 +1,10 @@
|
|||
import asyncio
|
||||
import unittest
|
||||
import pytest
|
||||
|
||||
from amqtt.adapters import BufferReader
|
||||
from amqtt.errors import AMQTTError
|
||||
from amqtt.mqtt import MQTTFixedHeader, PUBLISH
|
||||
from amqtt.mqtt.pubcomp import PacketIdVariableHeader, PubcompPacket
|
||||
|
||||
|
||||
|
@ -20,3 +23,21 @@ class PubcompPacketTest(unittest.TestCase):
|
|||
publish = PubcompPacket(variable_header=variable_header)
|
||||
out = publish.to_bytes()
|
||||
assert out == b"p\x02\x00\n"
|
||||
|
||||
def test_incorrect_fixed_header():
|
||||
header = MQTTFixedHeader(PUBLISH, 0x00)
|
||||
with pytest.raises(AMQTTError):
|
||||
_ = PubcompPacket(fixed=header)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("prop", [
|
||||
"packet_id"
|
||||
])
|
||||
def test_empty_variable_header(prop):
|
||||
packet = PubcompPacket()
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
assert getattr(packet, prop) is not None
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
assert setattr(packet, prop, "a value")
|
||||
|
|
|
@ -1,7 +1,10 @@
|
|||
import asyncio
|
||||
import unittest
|
||||
import pytest
|
||||
|
||||
from amqtt.adapters import BufferReader
|
||||
from amqtt.errors import AMQTTError
|
||||
from amqtt.mqtt.packet import MQTTFixedHeader, CONNECT
|
||||
from amqtt.mqtt.constants import QOS_0, QOS_1, QOS_2
|
||||
from amqtt.mqtt.publish import PublishPacket, PublishPayload, PublishVariableHeader
|
||||
|
||||
|
@ -116,3 +119,28 @@ class PublishPacketTest(unittest.TestCase):
|
|||
assert packet.dup_flag
|
||||
assert packet.qos == QOS_2
|
||||
assert packet.retain_flag
|
||||
|
||||
|
||||
def test_incorrect_fixed_header():
|
||||
header = MQTTFixedHeader(CONNECT, 0x00)
|
||||
with pytest.raises(AMQTTError):
|
||||
_ = PublishPacket(fixed=header)
|
||||
|
||||
def test_set_flags():
|
||||
packet = PublishPacket()
|
||||
packet.set_flags(dup_flag=True, qos=QOS_1, retain_flag=True)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("prop", [
|
||||
"packet_id",
|
||||
"data",
|
||||
"topic_name"
|
||||
])
|
||||
def test_empty_variable_header(prop):
|
||||
packet = PublishPacket()
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
assert getattr(packet, prop) is not None
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
assert setattr(packet, prop, "a value")
|
||||
|
|
|
@ -1,7 +1,10 @@
|
|||
import asyncio
|
||||
import unittest
|
||||
import pytest
|
||||
|
||||
from amqtt.adapters import BufferReader
|
||||
from amqtt.errors import AMQTTError
|
||||
from amqtt.mqtt.packet import MQTTFixedHeader, PUBLISH
|
||||
from amqtt.mqtt.pubrec import PacketIdVariableHeader, PubrecPacket
|
||||
|
||||
|
||||
|
@ -20,3 +23,22 @@ class PubrecPacketTest(unittest.TestCase):
|
|||
publish = PubrecPacket(variable_header=variable_header)
|
||||
out = publish.to_bytes()
|
||||
assert out == b"P\x02\x00\n"
|
||||
|
||||
|
||||
def test_incorrect_fixed_header():
|
||||
header = MQTTFixedHeader(PUBLISH, 0x00)
|
||||
with pytest.raises(AMQTTError):
|
||||
_ = PubrecPacket(fixed=header)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("prop", [
|
||||
"packet_id"
|
||||
])
|
||||
def test_empty_variable_header(prop):
|
||||
packet = PubrecPacket()
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
assert getattr(packet, prop) is not None
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
assert setattr(packet, prop, "a value")
|
||||
|
|
|
@ -1,7 +1,10 @@
|
|||
import asyncio
|
||||
import unittest
|
||||
import pytest
|
||||
|
||||
from amqtt.adapters import BufferReader
|
||||
from amqtt.errors import AMQTTError
|
||||
from amqtt.mqtt.packet import MQTTFixedHeader, PUBLISH
|
||||
from amqtt.mqtt.pubrel import PacketIdVariableHeader, PubrelPacket
|
||||
|
||||
|
||||
|
@ -20,3 +23,22 @@ class PubrelPacketTest(unittest.TestCase):
|
|||
publish = PubrelPacket(variable_header=variable_header)
|
||||
out = publish.to_bytes()
|
||||
assert out == b"b\x02\x00\n"
|
||||
|
||||
|
||||
def test_incorrect_fixed_header():
|
||||
header = MQTTFixedHeader(PUBLISH, 0x00)
|
||||
with pytest.raises(AMQTTError):
|
||||
_ = PubrelPacket(fixed=header)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("prop", [
|
||||
"packet_id"
|
||||
])
|
||||
def test_empty_variable_header(prop):
|
||||
packet = PubrelPacket()
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
assert getattr(packet, prop) is not None
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
assert setattr(packet, prop, "a value")
|
|
@ -510,17 +510,62 @@ async def test_client_publish_dup(broker):
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_client_publish_invalid_topic(broker):
|
||||
async def test_client_publishing_invalid_topic(broker):
|
||||
assert broker.transitions.is_started()
|
||||
pub_client = MQTTClient()
|
||||
pub_client = MQTTClient(config={'auto_reconnect': False})
|
||||
ret = await pub_client.connect("mqtt://127.0.0.1/")
|
||||
assert ret == 0
|
||||
|
||||
await pub_client.publish("/+", b"data", QOS_0)
|
||||
await pub_client.subscribe([
|
||||
("my/+/topic", QOS_0)
|
||||
])
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
# need to build & send packet directly to bypass client's check of invalid topic name
|
||||
# see test_client.py::test_publish_to_incorrect_wildcard for client checks
|
||||
packet = PublishPacket.build(topic_name='my/topic', message=b'messages',
|
||||
packet_id=None, dup_flag=False, qos=QOS_0, retain=False)
|
||||
packet.topic_name = "my/+/topic"
|
||||
await pub_client._handler._send_packet(packet)
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
with pytest.raises(asyncio.TimeoutError):
|
||||
msg = await pub_client.deliver_message(timeout_duration=1)
|
||||
assert msg is None
|
||||
|
||||
await asyncio.sleep(0.1)
|
||||
await pub_client.disconnect()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_client_publish_asterisk(broker):
|
||||
"""'*' is a valid, non-wildcard character for MQTT."""
|
||||
assert broker.transitions.is_started()
|
||||
pub_client = MQTTClient(config={'auto_reconnect': False})
|
||||
ret = await pub_client.connect("mqtt://127.0.0.1/")
|
||||
assert ret == 0
|
||||
|
||||
await pub_client.subscribe([
|
||||
("my*/topic", QOS_0),
|
||||
("my/+/topic", QOS_0)
|
||||
])
|
||||
await asyncio.sleep(0.1)
|
||||
await pub_client.publish('my*/topic', b'my valid message', QOS_0, retain=False)
|
||||
await asyncio.sleep(0.1)
|
||||
msg = await pub_client.deliver_message(timeout_duration=1)
|
||||
assert msg is not None
|
||||
assert msg.topic == "my*/topic"
|
||||
assert msg.data == b'my valid message'
|
||||
|
||||
await asyncio.sleep(0.1)
|
||||
msg = await pub_client.publish('my/****/topic', b'my valid message', QOS_0, retain=False)
|
||||
assert msg is not None
|
||||
assert msg.topic == "my/****/topic"
|
||||
assert msg.data == b'my valid message'
|
||||
|
||||
await pub_client.disconnect()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_client_publish_big(broker, mock_plugin_manager):
|
||||
pub_client = MQTTClient()
|
||||
|
|
|
@ -7,7 +7,7 @@ import pytest
|
|||
|
||||
from amqtt.broker import Broker
|
||||
from amqtt.client import MQTTClient
|
||||
from amqtt.errors import ClientError, ConnectError
|
||||
from amqtt.errors import ClientError, ConnectError, MQTTError
|
||||
from amqtt.mqtt.constants import QOS_0, QOS_1, QOS_2
|
||||
|
||||
formatter = "[%(asctime)s] %(name)s {%(filename)s:%(lineno)d} %(levelname)s - %(message)s"
|
||||
|
@ -490,3 +490,17 @@ async def test_client_no_auth():
|
|||
await client.connect("mqtt://127.0.0.1:1883/")
|
||||
|
||||
await broker.shutdown()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_publish_to_incorrect_wildcard(broker_fixture):
|
||||
client = MQTTClient(config={'auto_reconnect': False})
|
||||
await client.connect("mqtt://127.0.0.1/")
|
||||
|
||||
with pytest.raises(MQTTError):
|
||||
await client.publish("my/+/topic", b'plus-sign wildcard topic invalid publish')
|
||||
with pytest.raises(MQTTError):
|
||||
await client.publish("topic/#", b'hash wildcard topic invalid publish')
|
||||
|
||||
await client.publish("topic/*", b'asterisk topic normal publish')
|
||||
await client.disconnect()
|
||||
|
|
|
@ -0,0 +1,110 @@
|
|||
import asyncio
|
||||
import logging
|
||||
|
||||
import pytest
|
||||
|
||||
from amqtt.broker import Broker
|
||||
from amqtt.client import MQTTClient
|
||||
from amqtt.mqtt.constants import QOS_0
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_publish_to_dollar_sign_topics():
|
||||
"""Applications cannot use a topic with a leading $ character for their own purposes [MQTT-4.7.2-1]."""
|
||||
|
||||
cfg = {
|
||||
'listeners': {'default': {'type': 'tcp', 'bind': '127.0.0.1'}},
|
||||
'plugins': {'amqtt.plugins.authentication.AnonymousAuthPlugin': {"allow_anonymous": True}},
|
||||
}
|
||||
|
||||
b = Broker(config=cfg)
|
||||
await b.start()
|
||||
await asyncio.sleep(0.1)
|
||||
c = MQTTClient(config={'auto_reconnect': False})
|
||||
await c.connect()
|
||||
await asyncio.sleep(0.1)
|
||||
await c.subscribe(
|
||||
[('$#', QOS_0),
|
||||
('#', QOS_0)]
|
||||
)
|
||||
await asyncio.sleep(0.1)
|
||||
await c.publish('$MY', b'message should be blocked')
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
with pytest.raises(asyncio.TimeoutError):
|
||||
# wait long enough for broker sys plugin to run
|
||||
_ = await c.deliver_message(timeout_duration=1)
|
||||
|
||||
await c.disconnect()
|
||||
await asyncio.sleep(0.1)
|
||||
await b.shutdown()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_hash_will_not_receive_dollar():
|
||||
"""A subscription to “#” will not receive any messages published to a topic beginning with a $ [MQTT-4.7.2-1]."""
|
||||
|
||||
cfg = {
|
||||
'listeners': {'default': {'type': 'tcp', 'bind': '127.0.0.1'}},
|
||||
'plugins': {
|
||||
'amqtt.plugins.authentication.AnonymousAuthPlugin': {"allow_anonymous": True},
|
||||
'amqtt.plugins.sys.broker.BrokerSysPlugin': {"sys_interval": 2}
|
||||
}
|
||||
}
|
||||
|
||||
b = Broker(config=cfg)
|
||||
await b.start()
|
||||
await asyncio.sleep(0.1)
|
||||
c = MQTTClient(config={'auto_reconnect': False})
|
||||
await c.connect()
|
||||
await asyncio.sleep(0.1)
|
||||
await c.subscribe(
|
||||
[('#', QOS_0)]
|
||||
)
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
with pytest.raises(asyncio.TimeoutError):
|
||||
# wait long enough for broker sys plugin to run
|
||||
_ = await c.deliver_message(timeout_duration=5)
|
||||
|
||||
await c.disconnect()
|
||||
await asyncio.sleep(0.1)
|
||||
await b.shutdown()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_plus_will_not_receive_dollar():
|
||||
"""A subscription to “+/monitor/Clients” will not receive any messages published to “$SYS/monitor/Clients [MQTT-4.7.2-1]"""
|
||||
# BrokerSysPlugin doesn't use $SYS/monitor/Clients, so this is an equivalent test with $SYS/broker topics
|
||||
|
||||
cfg = {
|
||||
'listeners': {'default': {'type': 'tcp', 'bind': '127.0.0.1'}},
|
||||
'plugins': {
|
||||
'amqtt.plugins.authentication.AnonymousAuthPlugin': {"allow_anonymous": True},
|
||||
'amqtt.plugins.sys.broker.BrokerSysPlugin': {"sys_interval": 2}
|
||||
}
|
||||
}
|
||||
|
||||
b = Broker(config=cfg)
|
||||
await b.start()
|
||||
await asyncio.sleep(0.1)
|
||||
c = MQTTClient(config={'auto_reconnect': False})
|
||||
await c.connect()
|
||||
await asyncio.sleep(0.1)
|
||||
await c.subscribe(
|
||||
[('+/broker/#', QOS_0),
|
||||
('+/broker/time', QOS_0),
|
||||
('+/broker/clients/#', QOS_0),
|
||||
('+/broker/+/maximum', QOS_0)
|
||||
]
|
||||
)
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
with pytest.raises(asyncio.TimeoutError):
|
||||
# wait long enough for broker sys plugin to run
|
||||
_ = await c.deliver_message(timeout_duration=5)
|
||||
|
||||
await c.disconnect()
|
||||
await asyncio.sleep(0.1)
|
||||
await b.shutdown()
|
Ładowanie…
Reference in New Issue