kopia lustrzana https://github.com/Yakifo/amqtt
commit
6f724b9a23
|
@ -34,3 +34,8 @@ site/
|
|||
_build/
|
||||
.hypothesis/
|
||||
coverage.xml
|
||||
|
||||
#----- generated files -----
|
||||
*.log
|
||||
*memray*
|
||||
.coverage*
|
||||
|
|
|
@ -29,6 +29,7 @@ from amqtt.session import ApplicationMessage, OutgoingApplicationMessage, Sessio
|
|||
from amqtt.utils import format_client_message, gen_client_id, read_yaml_config
|
||||
|
||||
from .events import BrokerEvents
|
||||
from .mqtt.constants import QOS_0, QOS_1, QOS_2
|
||||
from .mqtt.disconnect import DisconnectPacket
|
||||
from .plugins.manager import PluginManager
|
||||
|
||||
|
@ -435,6 +436,7 @@ class Broker:
|
|||
await self._delete_session(client_session.client_id)
|
||||
else:
|
||||
client_session.client_id = gen_client_id()
|
||||
|
||||
client_session.parent = 0
|
||||
# Get session from cache
|
||||
elif client_session.client_id in self._sessions:
|
||||
|
@ -494,9 +496,18 @@ class Broker:
|
|||
|
||||
self.logger.debug(f"{client_session.client_id} Start messages handling")
|
||||
await handler.start()
|
||||
|
||||
# publish messages that were retained because the client session was disconnected
|
||||
self.logger.debug(f"Retained messages queue size: {client_session.retained_messages.qsize()}")
|
||||
await self._publish_session_retained_messages(client_session)
|
||||
|
||||
# if this is not a new session, there are subscriptions associated with them; publish any topic retained messages
|
||||
self.logger.debug("Publish retained messages to a pre-existing session's subscriptions.")
|
||||
for topic in self._subscriptions:
|
||||
await self._publish_retained_messages_for_subscription( (topic, QOS_0), client_session)
|
||||
|
||||
|
||||
|
||||
await self._client_message_loop(client_session, handler)
|
||||
|
||||
async def _client_message_loop(self, client_session: Session, handler: BrokerProtocolHandler) -> None:
|
||||
|
@ -878,11 +889,20 @@ class Broker:
|
|||
qos = broadcast.get("qos", sub_qos)
|
||||
|
||||
# Retain all messages which cannot be broadcasted, due to the session not being connected
|
||||
if target_session.transitions.state != "connected":
|
||||
# but only when clean session is false and qos is 1 or 2 [MQTT 3.1.2.4]
|
||||
# and, if a client used anonymous authentication, there is no expectation that messages should be retained
|
||||
if (target_session.transitions.state != "connected"
|
||||
and not target_session.clean_session
|
||||
and qos in (QOS_1, QOS_2)
|
||||
and not target_session.is_anonymous):
|
||||
self.logger.debug(f"Session {target_session.client_id} is not connected, retaining message.")
|
||||
await self._retain_broadcast_message(broadcast, qos, target_session)
|
||||
continue
|
||||
|
||||
# Only broadcast the message to connected clients
|
||||
if target_session.transitions.state != "connected":
|
||||
continue
|
||||
|
||||
self.logger.debug(
|
||||
f"Broadcasting message from {format_client_message(session=broadcast['session'])}"
|
||||
f" on topic '{broadcast['topic']}' to {format_client_message(session=target_session)}",
|
||||
|
|
|
@ -598,7 +598,7 @@ class MQTTClient:
|
|||
session.cadata = broker_conf.get("cadata")
|
||||
|
||||
if cleansession is not None:
|
||||
broker_conf["cleansession"] = cleansession
|
||||
broker_conf["cleansession"] = cleansession # noop?
|
||||
session.clean_session = cleansession
|
||||
else:
|
||||
session.clean_session = self.config.get("cleansession", True)
|
||||
|
|
|
@ -192,7 +192,7 @@ class ConnectPayload(MQTTPayload[ConnectVariableHeader]):
|
|||
# A Server MAY allow a Client to supply a ClientId that has a length of zero bytes
|
||||
# [MQTT-3.1.3-6]
|
||||
payload.client_id = gen_client_id()
|
||||
# indicator to trow exception in case CLEAN_SESSION_FLAG is set to False
|
||||
# indicator to throw exception in case CLEAN_SESSION_FLAG is set to False
|
||||
payload.client_id_is_random = True
|
||||
|
||||
# Read will topic, username and password
|
||||
|
|
|
@ -26,6 +26,7 @@ class AnonymousAuthPlugin(BaseAuthPlugin):
|
|||
|
||||
if self._allow_anonymous:
|
||||
self.context.logger.debug("Authentication success: config allows anonymous")
|
||||
session.is_anonymous = True
|
||||
return True
|
||||
|
||||
if session and session.username:
|
||||
|
|
|
@ -262,6 +262,10 @@ class PluginManager(Generic[C]):
|
|||
def _schedule_coro(self, coro: Awaitable[str | bool | None]) -> asyncio.Future[str | bool | None]:
|
||||
return asyncio.ensure_future(coro)
|
||||
|
||||
def _clean_fired_events(self, future: asyncio.Future[Any]) -> None:
|
||||
with contextlib.suppress(KeyError, ValueError):
|
||||
self._fired_events.remove(future)
|
||||
|
||||
async def fire_event(self, event_name: Events, *, wait: bool = False, **method_kwargs: Any) -> None:
|
||||
"""Fire an event to plugins.
|
||||
|
||||
|
@ -287,12 +291,7 @@ class PluginManager(Generic[C]):
|
|||
|
||||
coro_instance: Awaitable[Any] = call_method(event_awaitable, method_kwargs)
|
||||
tasks.append(asyncio.ensure_future(coro_instance))
|
||||
|
||||
def clean_fired_events(future: asyncio.Future[Any]) -> None:
|
||||
with contextlib.suppress(KeyError, ValueError):
|
||||
self._fired_events.remove(future)
|
||||
|
||||
tasks[-1].add_done_callback(clean_fired_events)
|
||||
tasks[-1].add_done_callback(self._clean_fired_events)
|
||||
|
||||
self._fired_events.extend(tasks)
|
||||
if wait and tasks:
|
||||
|
|
|
@ -4,6 +4,7 @@ ping_delay: 1
|
|||
default_qos: 0
|
||||
default_retain: false
|
||||
auto_reconnect: true
|
||||
cleansession: true
|
||||
reconnect_max_interval: 10
|
||||
reconnect_retries: 2
|
||||
broker:
|
||||
|
|
|
@ -145,12 +145,15 @@ class Session:
|
|||
# Used to store incoming ApplicationMessage while publish protocol flows
|
||||
self.inflight_in: OrderedDict[int, IncomingApplicationMessage] = OrderedDict()
|
||||
|
||||
# Stores messages retained for this session
|
||||
# Stores messages retained for this session (specifically when the client is disconnected)
|
||||
self.retained_messages: Queue[ApplicationMessage] = Queue()
|
||||
|
||||
# Stores PUBLISH messages ID received in order and ready for application process
|
||||
self.delivered_message_queue: Queue[ApplicationMessage] = Queue()
|
||||
|
||||
# identify anonymous client sessions or clients which didn't identify themselves
|
||||
self.is_anonymous: bool = False
|
||||
|
||||
def _init_states(self) -> None:
|
||||
self.transitions = Machine(states=Session.states, initial="new")
|
||||
self.transitions.add_transition(
|
||||
|
|
|
@ -22,8 +22,7 @@ listener.
|
|||
|
||||
### `timeout-disconnect-delay` *(int)*
|
||||
|
||||
Client disconnect timeout without a keep-alive
|
||||
|
||||
Client disconnect timeout without a keep-alive.
|
||||
|
||||
### `plugins` *(mapping)*
|
||||
|
||||
|
|
|
@ -17,9 +17,9 @@ pytest_plugins = ["pytest_logdog"]
|
|||
|
||||
test_config = {
|
||||
"listeners": {
|
||||
"default": {"type": "tcp", "bind": "127.0.0.1:1883", "max_connections": 10},
|
||||
"ws": {"type": "ws", "bind": "127.0.0.1:8080", "max_connections": 10},
|
||||
"wss": {"type": "ws", "bind": "127.0.0.1:8081", "max_connections": 10},
|
||||
"default": {"type": "tcp", "bind": "127.0.0.1:1883", "max_connections": 15},
|
||||
"ws": {"type": "ws", "bind": "127.0.0.1:8080", "max_connections": 15},
|
||||
"wss": {"type": "ws", "bind": "127.0.0.1:8081", "max_connections": 15},
|
||||
},
|
||||
"sys_interval": 0,
|
||||
"auth": {
|
||||
|
|
|
@ -1,6 +1,9 @@
|
|||
import asyncio
|
||||
import logging
|
||||
import logging.config
|
||||
import secrets
|
||||
import socket
|
||||
import string
|
||||
from unittest.mock import MagicMock, call, patch
|
||||
|
||||
import psutil
|
||||
|
@ -22,8 +25,49 @@ from amqtt.mqtt.pubrec import PubrecPacket
|
|||
from amqtt.mqtt.pubrel import PubrelPacket
|
||||
from amqtt.session import OutgoingApplicationMessage
|
||||
|
||||
formatter = "[%(asctime)s] %(name)s {%(filename)s:%(lineno)d} %(levelname)s - %(message)s"
|
||||
logging.basicConfig(level=logging.DEBUG, format=formatter)
|
||||
# formatter = "[%(asctime)s] %(name)s {%(filename)s:%(lineno)d} %(levelname)s - %(message)s"
|
||||
# logging.basicConfig(level=logging.DEBUG, format=formatter)
|
||||
|
||||
LOGGING_CONFIG = {
|
||||
'version': 1,
|
||||
'disable_existing_loggers': False,
|
||||
|
||||
'formatters': {
|
||||
'default': {
|
||||
'format': '[%(asctime)s] %(levelname)s %(name)s: %(message)s',
|
||||
},
|
||||
},
|
||||
|
||||
'handlers': {
|
||||
'console': {
|
||||
'class': 'logging.StreamHandler',
|
||||
'level': 'DEBUG',
|
||||
'formatter': 'default',
|
||||
'stream': 'ext://sys.stdout',
|
||||
}
|
||||
},
|
||||
|
||||
'root': {
|
||||
'handlers': ['console'],
|
||||
'level': 'DEBUG',
|
||||
},
|
||||
|
||||
'loggers': {
|
||||
'transitions': {
|
||||
'handlers': ['console'],
|
||||
'level': 'WARNING',
|
||||
'propagate': False,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
logging.config.dictConfig(LOGGING_CONFIG)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
@ -101,10 +145,11 @@ async def test_connect_tcp(broker):
|
|||
connections_number = 10
|
||||
|
||||
# mqtt 3.1 requires a connect packet, otherwise the socket connection is rejected
|
||||
static_connect_packet = b'\x10\x1b\x00\x04MQTT\x04\x02\x00<\x00\x0ftest-client-123'
|
||||
|
||||
sockets = []
|
||||
for i in range(connections_number):
|
||||
static_connect_packet = b'\x10\x1b\x00\x04MQTT\x04\x02\x00<\x00\x0ftest-client-12' + f"{i}".encode()
|
||||
|
||||
s = socket.create_connection(("127.0.0.1", 1883))
|
||||
s.send(static_connect_packet)
|
||||
sockets.append(s)
|
||||
|
@ -122,9 +167,11 @@ async def test_connect_tcp(broker):
|
|||
tcp_connections = [conn for conn in connections if conn.laddr.port == 1883]
|
||||
assert len(tcp_connections) == connections_number + 1 # Including the Broker's listening socket
|
||||
|
||||
await asyncio.sleep(0.1)
|
||||
for conn in connections:
|
||||
assert conn.status in ("ESTABLISHED", "LISTEN")
|
||||
|
||||
await asyncio.sleep(0.1)
|
||||
# close all connections
|
||||
for s in sockets:
|
||||
s.close()
|
||||
|
@ -626,35 +673,142 @@ async def test_client_subscribe_publish_dollar_topic_2(broker):
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_client_publish_retain_subscribe(broker):
|
||||
sub_client = MQTTClient()
|
||||
async def test_client_publish_clean_session_subscribe(broker):
|
||||
|
||||
sub_client = MQTTClient(client_id='test_client', config={'auto_reconnect': False})
|
||||
await sub_client.connect("mqtt://127.0.0.1", cleansession=False)
|
||||
ret = await sub_client.subscribe(
|
||||
[("/qos0", QOS_0), ("/qos1", QOS_1), ("/qos2", QOS_2)],
|
||||
)
|
||||
assert ret == [QOS_0, QOS_1, QOS_2]
|
||||
await sub_client.disconnect()
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
await _client_publish("/qos0", b"data", QOS_0, retain=True)
|
||||
await _client_publish("/qos1", b"data", QOS_1, retain=True)
|
||||
await _client_publish("/qos2", b"data", QOS_2, retain=True)
|
||||
await sub_client.reconnect()
|
||||
for qos in [QOS_0, QOS_1, QOS_2]:
|
||||
await sub_client.disconnect()
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
|
||||
await _client_publish("/qos0", b"data0", QOS_0) # should not be retained
|
||||
await _client_publish("/qos1", b"data1", QOS_1)
|
||||
await _client_publish("/qos2", b"data2", QOS_2)
|
||||
await asyncio.sleep(2)
|
||||
|
||||
await sub_client.reconnect(cleansession=False)
|
||||
for qos in [QOS_1, QOS_2]:
|
||||
log.debug(f"TEST QOS: {qos}")
|
||||
message = await sub_client.deliver_message()
|
||||
message = await sub_client.deliver_message(timeout_duration=2)
|
||||
log.debug(f"Message: {message.publish_packet if message else None!r}")
|
||||
assert message is not None
|
||||
assert message.topic == f"/qos{qos}"
|
||||
assert message.data == b"data"
|
||||
assert message.data == f"data{qos}".encode("utf-8")
|
||||
assert message.qos == qos
|
||||
|
||||
try:
|
||||
while True:
|
||||
message = await sub_client.deliver_message(timeout_duration=1)
|
||||
assert message is not None, "no other messages should have been retained"
|
||||
except asyncio.TimeoutError:
|
||||
pass
|
||||
|
||||
await sub_client.disconnect()
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_client_publish_retain_with_new_subscribe(broker):
|
||||
await asyncio.sleep(2)
|
||||
sub_client1 = MQTTClient(client_id='test_client1')
|
||||
await sub_client1.connect("mqtt://127.0.0.1")
|
||||
|
||||
await sub_client1.disconnect()
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
await _client_publish("/qos0", b"data0", QOS_0, retain=True)
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
sub_client2 = MQTTClient(client_id='test_client2')
|
||||
await sub_client2.connect("mqtt://127.0.0.1")
|
||||
|
||||
# should receive the retained message on subscription
|
||||
ret = await sub_client2.subscribe(
|
||||
[("/qos0", QOS_0)],
|
||||
)
|
||||
assert ret == [QOS_0]
|
||||
|
||||
message = await sub_client2.deliver_message(timeout_duration=1)
|
||||
assert message is not None
|
||||
assert message.topic == "/qos0"
|
||||
assert message.data == b"data0"
|
||||
assert message.qos == QOS_0
|
||||
await sub_client2.disconnect()
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_client_publish_retain_latest_with_new_subscribe(broker):
|
||||
await asyncio.sleep(2)
|
||||
sub_client1 = MQTTClient(client_id='test_client1')
|
||||
await sub_client1.connect("mqtt://127.0.0.1")
|
||||
|
||||
await sub_client1.disconnect()
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
await _client_publish("/qos0", b"data a", QOS_0, retain=True)
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
sub_client2 = MQTTClient(client_id='test_client2')
|
||||
await sub_client2.connect("mqtt://127.0.0.1")
|
||||
|
||||
await _client_publish("/qos0", b"data b", QOS_0, retain=True)
|
||||
|
||||
# should receive the retained message on subscription
|
||||
ret = await sub_client2.subscribe(
|
||||
[("/qos0", QOS_0)],
|
||||
)
|
||||
assert ret == [QOS_0]
|
||||
|
||||
message = await sub_client2.deliver_message(timeout_duration=1)
|
||||
assert message is not None
|
||||
assert message.topic == "/qos0"
|
||||
assert message.data == b"data b"
|
||||
assert message.qos == QOS_0
|
||||
await sub_client2.disconnect()
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_client_publish_retain_subscribe_on_reconnect(broker):
|
||||
await asyncio.sleep(2)
|
||||
sub_client = MQTTClient(client_id='test_client')
|
||||
await sub_client.connect("mqtt://127.0.0.1", cleansession=False)
|
||||
ret = await sub_client.subscribe(
|
||||
[("/qos0", QOS_0)],
|
||||
)
|
||||
assert ret == [QOS_0]
|
||||
|
||||
await sub_client.disconnect()
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
await _client_publish("/qos0", b"data0", QOS_0, retain=True)
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
await sub_client.reconnect(cleansession=False)
|
||||
|
||||
message = await sub_client.deliver_message(timeout_duration=1)
|
||||
assert message is not None
|
||||
assert message.topic == "/qos0"
|
||||
assert message.data == b"data0"
|
||||
assert message.qos == QOS_0
|
||||
await sub_client.disconnect()
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def _client_publish(topic, data, qos, retain=False) -> int | OutgoingApplicationMessage:
|
||||
pub_client = MQTTClient()
|
||||
|
||||
gen_id = "pub_"
|
||||
valid_chars = string.ascii_letters + string.digits
|
||||
gen_id += "".join(secrets.choice(valid_chars) for _ in range(16))
|
||||
|
||||
pub_client = MQTTClient(client_id=gen_id)
|
||||
ret: int | OutgoingApplicationMessage = await pub_client.connect("mqtt://127.0.0.1/")
|
||||
assert ret == 0
|
||||
ret = await pub_client.publish(topic, data, qos, retain)
|
||||
|
|
Ładowanie…
Reference in New Issue