amqtt/tests/test_broker.py

699 wiersze
21 KiB
Python

# Copyright (c) 2015 Nicolas JOUANIN
#
# See the file license.txt for copying permission.
import asyncio
import logging
import socket
import sys
from unittest.mock import call, MagicMock, patch
import psutil
import pytest
from amqtt.adapters import StreamReaderAdapter, StreamWriterAdapter
from amqtt.broker import (
EVENT_BROKER_PRE_START,
EVENT_BROKER_POST_START,
EVENT_BROKER_PRE_SHUTDOWN,
EVENT_BROKER_POST_SHUTDOWN,
EVENT_BROKER_CLIENT_CONNECTED,
EVENT_BROKER_CLIENT_DISCONNECTED,
EVENT_BROKER_CLIENT_SUBSCRIBED,
EVENT_BROKER_CLIENT_UNSUBSCRIBED,
EVENT_BROKER_MESSAGE_RECEIVED,
)
from amqtt.client import MQTTClient, ConnectException
from amqtt.mqtt import (
ConnectPacket,
ConnackPacket,
PublishPacket,
PubrecPacket,
PubrelPacket,
PubcompPacket,
DisconnectPacket,
)
from amqtt.mqtt.connect import ConnectVariableHeader, ConnectPayload
from amqtt.mqtt.constants import QOS_0, QOS_1, QOS_2
from amqtt.mqtt.protocol.broker_handler import BrokerProtocolHandler
formatter = (
"[%(asctime)s] %(name)s {%(filename)s:%(lineno)d} %(levelname)s - %(message)s"
)
logging.basicConfig(level=logging.DEBUG, format=formatter)
log = logging.getLogger(__name__)
# monkey patch MagicMock
# taken from https://stackoverflow.com/questions/51394411/python-object-magicmock-cant-be-used-in-await-expression
async def async_magic():
pass
MagicMock.__await__ = lambda x: async_magic().__await__()
@pytest.mark.asyncio
async def test_start_stop(broker, mock_plugin_manager):
mock_plugin_manager.assert_has_calls(
[
call().fire_event(EVENT_BROKER_PRE_START),
call().fire_event(EVENT_BROKER_POST_START),
],
any_order=True,
)
mock_plugin_manager.reset_mock()
await broker.shutdown()
mock_plugin_manager.assert_has_calls(
[
call().fire_event(EVENT_BROKER_PRE_SHUTDOWN),
call().fire_event(EVENT_BROKER_POST_SHUTDOWN),
],
any_order=True,
)
assert broker.transitions.is_stopped()
@pytest.mark.asyncio
async def test_client_connect(broker, mock_plugin_manager):
client = MQTTClient()
ret = await client.connect("mqtt://127.0.0.1/")
assert ret == 0
assert client.session.client_id in broker._sessions
await client.disconnect()
await asyncio.sleep(0.01)
mock_plugin_manager.assert_has_calls(
[
call().fire_event(
EVENT_BROKER_CLIENT_CONNECTED, client_id=client.session.client_id
),
call().fire_event(
EVENT_BROKER_CLIENT_DISCONNECTED, client_id=client.session.client_id
),
],
any_order=True,
)
@pytest.mark.asyncio
async def test_connect_tcp(broker):
process = psutil.Process()
connections_number = 10
sockets = [
socket.create_connection(("127.0.0.1", 1883)) for _ in range(connections_number)
]
connections = process.connections()
await asyncio.sleep(0.1)
# max number of connections is 10
assert broker._servers["default"].conn_count == connections_number
# Extra connection for the listening Broker
assert len(connections) == connections_number + 1
for conn in connections:
assert conn.status == "ESTABLISHED" or conn.status == "LISTEN"
# close all connections
for s in sockets:
s.close()
connections = process.connections()
for conn in connections:
assert conn.status == "CLOSE_WAIT" or conn.status == "LISTEN"
await asyncio.sleep(0.1)
assert broker._servers["default"].conn_count == 0
# Add one more connection
s = socket.create_connection(("127.0.0.1", 1883))
open_connections = []
for conn in process.connections():
if conn.status == "ESTABLISHED":
open_connections.append(conn)
assert len(open_connections) == 1
assert broker._servers["default"].conn_count == 1
@pytest.mark.asyncio
async def test_client_connect_will_flag(broker):
conn_reader, conn_writer = await asyncio.open_connection("127.0.0.1", 1883)
reader = StreamReaderAdapter(conn_reader)
writer = StreamWriterAdapter(conn_writer)
vh = ConnectVariableHeader()
payload = ConnectPayload()
vh.keep_alive = 10
vh.clean_session_flag = False
vh.will_retain_flag = False
vh.will_flag = True
vh.will_qos = QOS_0
payload.client_id = "test_id"
payload.will_message = b"test"
payload.will_topic = "/topic"
connect = ConnectPacket(vh=vh, payload=payload)
await connect.to_stream(writer)
await ConnackPacket.from_stream(reader)
await asyncio.sleep(0.1)
disconnect = DisconnectPacket()
await disconnect.to_stream(writer)
await asyncio.sleep(0.1)
@pytest.mark.asyncio
async def test_client_connect_clean_session_false(broker):
client = MQTTClient(client_id="", config={"auto_reconnect": False})
return_code = None
try:
await client.connect("mqtt://127.0.0.1/", cleansession=False)
except ConnectException as ce:
return_code = ce.return_code
assert return_code == 0x02
assert client.session.client_id not in broker._sessions
await client.disconnect()
await asyncio.sleep(0.1)
@pytest.mark.asyncio
async def test_client_subscribe(broker, mock_plugin_manager):
client = MQTTClient()
ret = await client.connect("mqtt://127.0.0.1/")
assert ret == 0
await client.subscribe([("/topic", QOS_0)])
# Test if the client test client subscription is registered
assert "/topic" in broker._subscriptions
subs = broker._subscriptions["/topic"]
assert len(subs) == 1
(s, qos) = subs[0]
assert s == client.session
assert qos == QOS_0
await client.disconnect()
await asyncio.sleep(0.1)
mock_plugin_manager.assert_has_calls(
[
call().fire_event(
EVENT_BROKER_CLIENT_SUBSCRIBED,
client_id=client.session.client_id,
topic="/topic",
qos=QOS_0,
)
],
any_order=True,
)
@pytest.mark.asyncio
async def test_client_subscribe_twice(broker, mock_plugin_manager):
client = MQTTClient()
ret = await client.connect("mqtt://127.0.0.1/")
assert ret == 0
await client.subscribe([("/topic", QOS_0)])
# Test if the client test client subscription is registered
assert "/topic" in broker._subscriptions
subs = broker._subscriptions["/topic"]
assert len(subs) == 1
(s, qos) = subs[0]
assert s == client.session
assert qos == QOS_0
await client.subscribe([("/topic", QOS_0)])
assert len(subs) == 1
(s, qos) = subs[0]
assert s == client.session
assert qos == QOS_0
await client.disconnect()
await asyncio.sleep(0.1)
mock_plugin_manager.assert_has_calls(
[
call().fire_event(
EVENT_BROKER_CLIENT_SUBSCRIBED,
client_id=client.session.client_id,
topic="/topic",
qos=QOS_0,
)
],
any_order=True,
)
@pytest.mark.asyncio
async def test_client_unsubscribe(broker, mock_plugin_manager):
client = MQTTClient()
ret = await client.connect("mqtt://127.0.0.1/")
assert ret == 0
await client.subscribe([("/topic", QOS_0)])
# Test if the client test client subscription is registered
assert "/topic" in broker._subscriptions
subs = broker._subscriptions["/topic"]
assert len(subs) == 1
(s, qos) = subs[0]
assert s == client.session
assert qos == QOS_0
await client.unsubscribe(["/topic"])
await asyncio.sleep(0.1)
assert broker._subscriptions["/topic"] == []
await client.disconnect()
await asyncio.sleep(0.1)
mock_plugin_manager.assert_has_calls(
[
call().fire_event(
EVENT_BROKER_CLIENT_SUBSCRIBED,
client_id=client.session.client_id,
topic="/topic",
qos=QOS_0,
),
call().fire_event(
EVENT_BROKER_CLIENT_UNSUBSCRIBED,
client_id=client.session.client_id,
topic="/topic",
),
],
any_order=True,
)
@pytest.mark.asyncio
async def test_client_publish(broker, mock_plugin_manager):
pub_client = MQTTClient()
ret = await pub_client.connect("mqtt://127.0.0.1/")
assert ret == 0
ret_message = await pub_client.publish("/topic", b"data", QOS_0)
await pub_client.disconnect()
assert broker._retained_messages == {}
await asyncio.sleep(0.1)
mock_plugin_manager.assert_has_calls(
[
call().fire_event(
EVENT_BROKER_MESSAGE_RECEIVED,
client_id=pub_client.session.client_id,
message=ret_message,
),
],
any_order=True,
)
@pytest.mark.asyncio
async def test_client_publish_acl_permitted(acl_broker):
sub_client = MQTTClient()
ret = await sub_client.connect("mqtt://user2:user2password@127.0.0.1:1884/")
assert ret == 0
ret = await sub_client.subscribe([("public/subtopic/test", QOS_0)])
assert ret == [QOS_0]
pub_client = MQTTClient()
ret = await pub_client.connect("mqtt://user1:user1password@127.0.0.1:1884/")
assert ret == 0
await pub_client.publish("public/subtopic/test", b"data", QOS_0)
message = await sub_client.deliver_message(timeout=1)
await pub_client.disconnect()
await sub_client.disconnect()
assert message is not None
assert message.topic == "public/subtopic/test"
assert message.data == b"data"
assert message.qos == QOS_0
@pytest.mark.asyncio
async def test_client_publish_acl_forbidden(acl_broker):
sub_client = MQTTClient()
ret = await sub_client.connect("mqtt://user2:user2password@127.0.0.1:1884/")
assert ret == 0
ret = await sub_client.subscribe([("public/forbidden/test", QOS_0)])
assert ret == [QOS_0]
pub_client = MQTTClient()
ret = await pub_client.connect("mqtt://user1:user1password@127.0.0.1:1884/")
assert ret == 0
await pub_client.publish("public/forbidden/test", b"data", QOS_0)
try:
await sub_client.deliver_message(timeout=1)
assert False, "Should not have worked"
except asyncio.TimeoutError:
pass
await pub_client.disconnect()
await sub_client.disconnect()
@pytest.mark.asyncio
async def test_client_publish_acl_permitted_sub_forbidden(acl_broker):
sub_client1 = MQTTClient()
ret = await sub_client1.connect("mqtt://user2:user2password@127.0.0.1:1884/")
assert ret == 0
sub_client2 = MQTTClient()
ret = await sub_client2.connect("mqtt://user3:user3password@127.0.0.1:1884/")
assert ret == 0
ret = await sub_client1.subscribe([("public/subtopic/test", QOS_0)])
assert ret == [QOS_0]
ret = await sub_client2.subscribe([("public/subtopic/test", QOS_0)])
assert ret == [0x80]
pub_client = MQTTClient()
ret = await pub_client.connect("mqtt://user1:user1password@127.0.0.1:1884/")
assert ret == 0
await pub_client.publish("public/subtopic/test", b"data", QOS_0)
message = await sub_client1.deliver_message(timeout=1)
try:
await sub_client2.deliver_message(timeout=1)
assert False, "Should not have worked"
except asyncio.TimeoutError:
pass
await pub_client.disconnect()
await sub_client1.disconnect()
await sub_client2.disconnect()
assert message is not None
assert message.topic == "public/subtopic/test"
assert message.data == b"data"
assert message.qos == QOS_0
@pytest.mark.asyncio
async def test_client_publish_dup(broker):
conn_reader, conn_writer = await asyncio.open_connection("127.0.0.1", 1883)
reader = StreamReaderAdapter(conn_reader)
writer = StreamWriterAdapter(conn_writer)
vh = ConnectVariableHeader()
payload = ConnectPayload()
vh.keep_alive = 10
vh.clean_session_flag = False
vh.will_retain_flag = False
payload.client_id = "test_id"
connect = ConnectPacket(vh=vh, payload=payload)
await connect.to_stream(writer)
await ConnackPacket.from_stream(reader)
publish_1 = PublishPacket.build("/test", b"data", 1, False, QOS_2, False)
await publish_1.to_stream(writer)
asyncio.ensure_future(PubrecPacket.from_stream(reader))
await asyncio.sleep(2)
publish_dup = PublishPacket.build("/test", b"data", 1, True, QOS_2, False)
await publish_dup.to_stream(writer)
await PubrecPacket.from_stream(reader)
pubrel = PubrelPacket.build(1)
await pubrel.to_stream(writer)
await PubcompPacket.from_stream(reader)
disconnect = DisconnectPacket()
await disconnect.to_stream(writer)
@pytest.mark.asyncio
async def test_client_publish_invalid_topic(broker):
assert broker.transitions.is_started()
pub_client = MQTTClient()
ret = await pub_client.connect("mqtt://127.0.0.1/")
assert ret == 0
await pub_client.publish("/+", b"data", QOS_0)
await asyncio.sleep(0.1)
await pub_client.disconnect()
@pytest.mark.asyncio
async def test_client_publish_big(broker, mock_plugin_manager):
pub_client = MQTTClient()
ret = await pub_client.connect("mqtt://127.0.0.1/")
assert ret == 0
ret_message = await pub_client.publish(
"/topic", bytearray(b"\x99" * 256 * 1024), QOS_2
)
await pub_client.disconnect()
assert broker._retained_messages == {}
await asyncio.sleep(0.1)
mock_plugin_manager.assert_has_calls(
[
call().fire_event(
EVENT_BROKER_MESSAGE_RECEIVED,
client_id=pub_client.session.client_id,
message=ret_message,
),
],
any_order=True,
)
@pytest.mark.asyncio
async def test_client_publish_retain(broker):
pub_client = MQTTClient()
ret = await pub_client.connect("mqtt://127.0.0.1/")
assert ret == 0
await pub_client.publish("/topic", b"data", QOS_0, retain=True)
await pub_client.disconnect()
await asyncio.sleep(0.1)
assert "/topic" in broker._retained_messages
retained_message = broker._retained_messages["/topic"]
assert retained_message.source_session == pub_client.session
assert retained_message.topic == "/topic"
assert retained_message.data == b"data"
assert retained_message.qos == QOS_0
@pytest.mark.asyncio
async def test_client_publish_retain_delete(broker):
pub_client = MQTTClient()
ret = await pub_client.connect("mqtt://127.0.0.1/")
assert ret == 0
await pub_client.publish("/topic", b"", QOS_0, retain=True)
await pub_client.disconnect()
await asyncio.sleep(0.1)
assert "/topic" not in broker._retained_messages
@pytest.mark.asyncio
async def test_client_subscribe_publish(broker):
sub_client = MQTTClient()
await sub_client.connect("mqtt://127.0.0.1")
ret = await sub_client.subscribe(
[("/qos0", QOS_0), ("/qos1", QOS_1), ("/qos2", QOS_2)]
)
assert ret == [QOS_0, QOS_1, QOS_2]
await _client_publish("/qos0", b"data", QOS_0)
await _client_publish("/qos1", b"data", QOS_1)
await _client_publish("/qos2", b"data", QOS_2)
await asyncio.sleep(0.1)
for qos in [QOS_0, QOS_1, QOS_2]:
message = await sub_client.deliver_message()
assert message is not None
assert message.topic == "/qos%s" % qos
assert message.data == b"data"
assert message.qos == qos
await sub_client.disconnect()
await asyncio.sleep(0.1)
@pytest.mark.asyncio
async def test_client_subscribe_invalid(broker):
sub_client = MQTTClient()
await sub_client.connect("mqtt://127.0.0.1")
ret = await sub_client.subscribe(
[
("+", QOS_0),
("+/tennis/#", QOS_0),
("sport+", QOS_0),
("sport/+/player1", QOS_0),
]
)
assert ret == [QOS_0, QOS_0, 0x80, QOS_0]
await asyncio.sleep(0.1)
await sub_client.disconnect()
await asyncio.sleep(0.1)
@pytest.mark.asyncio
async def test_client_subscribe_publish_dollar_topic_1(broker):
assert broker.transitions.is_started()
sub_client = MQTTClient()
await sub_client.connect("mqtt://127.0.0.1")
ret = await sub_client.subscribe([("#", QOS_0)])
assert ret == [QOS_0]
await _client_publish("/topic", b"data", QOS_0)
message = await sub_client.deliver_message()
assert message is not None
await _client_publish("$topic", b"data", QOS_0)
await asyncio.sleep(0.1)
message = None
try:
message = await sub_client.deliver_message(timeout=2)
except asyncio.TimeoutError:
pass
except RuntimeError as e:
# The loop is closed with pending tasks. Needs fine tuning.
log.warning(e)
assert message is None
await sub_client.disconnect()
await asyncio.sleep(0.1)
@pytest.mark.asyncio
async def test_client_subscribe_publish_dollar_topic_2(broker):
sub_client = MQTTClient()
await sub_client.connect("mqtt://127.0.0.1")
ret = await sub_client.subscribe([("+/monitor/Clients", QOS_0)])
assert ret == [QOS_0]
await _client_publish("test/monitor/Clients", b"data", QOS_0)
message = await sub_client.deliver_message()
assert message is not None
await _client_publish("$SYS/monitor/Clients", b"data", QOS_0)
await asyncio.sleep(0.1)
message = None
try:
message = await sub_client.deliver_message(timeout=2)
except asyncio.TimeoutError:
pass
except RuntimeError as e:
# The loop is closed with pending tasks. Needs fine tuning.
log.warning(e)
assert message is None
await sub_client.disconnect()
await asyncio.sleep(0.1)
@pytest.mark.asyncio
@pytest.mark.xfail(
reason="see https://github.com/Yakifo/aio-amqtt/issues/16", strict=False
)
async def test_client_publish_retain_subscribe(broker):
sub_client = MQTTClient()
await sub_client.connect("mqtt://127.0.0.1", cleansession=False)
ret = await sub_client.subscribe(
[("/qos0", QOS_0), ("/qos1", QOS_1), ("/qos2", QOS_2)]
)
assert ret == [QOS_0, QOS_1, QOS_2]
await sub_client.disconnect()
await asyncio.sleep(0.1)
await _client_publish("/qos0", b"data", QOS_0, retain=True)
await _client_publish("/qos1", b"data", QOS_1, retain=True)
await _client_publish("/qos2", b"data", QOS_2, retain=True)
await sub_client.reconnect()
for qos in [QOS_0, QOS_1, QOS_2]:
log.debug("TEST QOS: %d" % qos)
message = await sub_client.deliver_message()
log.debug("Message: " + repr(message.publish_packet))
assert message is not None
assert message.topic == "/qos%s" % qos
assert message.data == b"data"
assert message.qos == qos
await sub_client.disconnect()
await asyncio.sleep(0.1)
@pytest.mark.asyncio
async def _client_publish(topic, data, qos, retain=False):
pub_client = MQTTClient()
ret = await pub_client.connect("mqtt://127.0.0.1/")
assert ret == 0
ret = await pub_client.publish(topic, data, qos, retain)
await pub_client.disconnect()
return ret
def test_matches_multi_level_wildcard(broker):
test_filter = "sport/tennis/player1/#"
for bad_topic in [
"sport/tennis",
"sport/tennis/",
]:
assert not broker.matches(bad_topic, test_filter)
for good_topic in [
"sport/tennis/player1",
"sport/tennis/player1/",
"sport/tennis/player1/ranking",
"sport/tennis/player1/score/wimbledon",
]:
assert broker.matches(good_topic, test_filter)
def test_matches_single_level_wildcard(broker):
test_filter = "sport/tennis/+"
for bad_topic in [
"sport/tennis",
"sport/tennis/player1/",
"sport/tennis/player1/ranking",
]:
assert not broker.matches(bad_topic, test_filter)
for good_topic in [
"sport/tennis/",
"sport/tennis/player1",
"sport/tennis/player2",
]:
assert broker.matches(good_topic, test_filter)
@pytest.mark.asyncio
async def test_broker_broadcast_cancellation(broker):
topic = "test"
data = b"data"
qos = QOS_0
sub_client = MQTTClient()
await sub_client.connect("mqtt://127.0.0.1")
await sub_client.subscribe([(topic, qos)])
with patch.object(
BrokerProtocolHandler, "mqtt_publish", side_effect=asyncio.CancelledError
) as mocked_mqtt_publish:
await _client_publish(topic, data, qos)
# Second publish triggers the awaiting of first `mqtt_publish` task
await _client_publish(topic, data, qos)
await asyncio.sleep(0.01)
# `assert_awaited` does not exist in Python before `3.8`
if sys.version_info >= (3, 8):
mocked_mqtt_publish.assert_awaited()
else:
mocked_mqtt_publish.assert_called()
# Ensure broadcast loop is still functional and can deliver the message
await _client_publish(topic, data, qos)
message = await asyncio.wait_for(sub_client.deliver_message(), timeout=1)
assert message