kopia lustrzana https://github.com/Yakifo/amqtt
1190 wiersze
53 KiB
Python
1190 wiersze
53 KiB
Python
import asyncio
|
|
from asyncio import CancelledError, futures
|
|
from collections import deque
|
|
from collections.abc import Generator
|
|
from functools import partial
|
|
import logging
|
|
from math import floor
|
|
import re
|
|
import ssl
|
|
import time
|
|
from typing import Any, ClassVar, TypeAlias
|
|
|
|
from transitions import Machine, MachineError
|
|
import websockets.asyncio.server
|
|
from websockets.asyncio.server import ServerConnection
|
|
|
|
from amqtt.adapters import (
|
|
ReaderAdapter,
|
|
StreamReaderAdapter,
|
|
StreamWriterAdapter,
|
|
WebSocketsReader,
|
|
WebSocketsWriter,
|
|
WriterAdapter,
|
|
)
|
|
from amqtt.contexts import Action, BaseContext, BrokerConfig, ListenerConfig, ListenerType
|
|
from amqtt.errors import AMQTTError, BrokerError, MQTTError, NoDataError
|
|
from amqtt.mqtt3.protocol.broker_handler import BrokerProtocolHandler
|
|
from amqtt.session import ApplicationMessage, OutgoingApplicationMessage, Session
|
|
from amqtt.utils import format_client_message, gen_client_id
|
|
|
|
from .events import BrokerEvents
|
|
from .mqtt3.constants import QOS_0, QOS_1, QOS_2
|
|
from .mqtt3.disconnect import DisconnectPacket
|
|
from .plugins.manager import PluginManager
|
|
|
|
_BROADCAST: TypeAlias = dict[str, Session | str | bytes | bytearray | int | None]
|
|
|
|
# Default port numbers
|
|
DEFAULT_PORTS = {"tcp": 1883, "ws": 8883}
|
|
AMQTT_MAGIC_VALUE_RET_SUBSCRIBED = 0x80
|
|
|
|
|
|
class RetainedApplicationMessage(ApplicationMessage):
|
|
__slots__ = ("data", "qos", "source_session", "topic")
|
|
|
|
def __init__(self, source_session: Session | None, topic: str, data: bytes | bytearray, qos: int | None = None) -> None:
|
|
super().__init__(None, topic, qos, data, retain=True)
|
|
self.source_session = source_session
|
|
self.topic = topic
|
|
self.data = data
|
|
self.qos = qos
|
|
|
|
|
|
class Server:
|
|
"""Used to encapsulate the server associated with a listener. Allows broker to interact with the connection lifecycle."""
|
|
|
|
def __init__(
|
|
self,
|
|
listener_name: str,
|
|
server_instance: asyncio.Server | websockets.asyncio.server.Server,
|
|
max_connections: int = -1,
|
|
) -> None:
|
|
self.logger = logging.getLogger(__name__)
|
|
self.instance = server_instance
|
|
self.conn_count = 0
|
|
self.listener_name = listener_name
|
|
self.max_connections = max_connections
|
|
self.semaphore = asyncio.Semaphore(max_connections) if max_connections > 0 else None
|
|
|
|
async def acquire_connection(self) -> None:
|
|
if self.semaphore:
|
|
await self.semaphore.acquire()
|
|
self.conn_count += 1
|
|
self.logger.info(
|
|
f"Listener '{self.listener_name}': {self.conn_count}/"
|
|
f"{self.max_connections if self.max_connections > 0 else '∞'} connections acquired",
|
|
)
|
|
|
|
def release_connection(self) -> None:
|
|
if self.semaphore:
|
|
self.semaphore.release()
|
|
self.conn_count -= 1
|
|
self.logger.info(
|
|
f"Listener '{self.listener_name}': {self.conn_count}/"
|
|
f"{self.max_connections if self.max_connections > 0 else '∞'} connections acquired",
|
|
)
|
|
|
|
async def close_instance(self) -> None:
|
|
if self.instance:
|
|
self.instance.close()
|
|
await self.instance.wait_closed()
|
|
|
|
|
|
class ExternalServer(Server):
|
|
"""For external listeners, the connection lifecycle is handled by that implementation so these are no-ops."""
|
|
|
|
def __init__(self) -> None:
|
|
super().__init__("aiohttp", None) # type: ignore[arg-type]
|
|
|
|
async def acquire_connection(self) -> None:
|
|
pass
|
|
|
|
def release_connection(self) -> None:
|
|
pass
|
|
|
|
async def close_instance(self) -> None:
|
|
pass
|
|
|
|
|
|
class BrokerContext(BaseContext):
|
|
"""Used to provide the server's context as well as public methods for accessing internal state."""
|
|
|
|
def __init__(self, broker: "Broker") -> None:
|
|
super().__init__()
|
|
self.config: BrokerConfig | None = None
|
|
self._broker_instance = broker
|
|
|
|
async def broadcast_message(self, topic: str, data: bytes, qos: int | None = None) -> None:
|
|
"""Send message to all client sessions subscribing to `topic`."""
|
|
await self._broker_instance.internal_message_broadcast(topic, data, qos)
|
|
|
|
async def retain_message(self, topic_name: str, data: bytes | bytearray, qos: int | None = None) -> None:
|
|
await self._broker_instance.retain_message(None, topic_name, data, qos)
|
|
|
|
@property
|
|
def sessions(self) -> Generator[Session]:
|
|
for session in self._broker_instance.sessions.values():
|
|
yield session[0]
|
|
|
|
def get_session(self, client_id: str) -> Session | None:
|
|
"""Return the session associated with `client_id`, if it exists."""
|
|
return self._broker_instance.sessions.get(client_id, (None, None))[0]
|
|
|
|
@property
|
|
def retained_messages(self) -> dict[str, RetainedApplicationMessage]:
|
|
return self._broker_instance.retained_messages
|
|
|
|
@property
|
|
def subscriptions(self) -> dict[str, list[tuple[Session, int]]]:
|
|
return self._broker_instance.subscriptions
|
|
|
|
async def add_subscription(self, client_id: str, topic: str | None, qos: int | None) -> None:
|
|
"""Create a topic subscription for the given `client_id`.
|
|
|
|
If a client session doesn't exist for `client_id`, create a disconnected session.
|
|
If `topic` and `qos` are both `None`, only create the client session.
|
|
"""
|
|
if client_id not in self._broker_instance.sessions:
|
|
broker_handler, session = self._broker_instance.create_offline_session(client_id)
|
|
self._broker_instance._sessions[client_id] = (session, broker_handler) # noqa: SLF001
|
|
|
|
if topic is not None and qos is not None:
|
|
session, _ = self._broker_instance.sessions[client_id]
|
|
await self._broker_instance.add_subscription((topic, qos), session)
|
|
|
|
|
|
class Broker:
|
|
"""MQTT 3.1.1 compliant broker implementation.
|
|
|
|
Args:
|
|
config: `BrokerConfig` or dictionary of equivalent structure options (see [broker configuration](broker_config.md)).
|
|
loop: asyncio loop. defaults to `asyncio.new_event_loop()`.
|
|
plugin_namespace: plugin namespace to use when loading plugin entry_points. defaults to `amqtt.broker.plugins`.
|
|
|
|
Raises:
|
|
BrokerError: problem with broker configuration
|
|
PluginImportError: if importing a plugin from configuration
|
|
PluginInitError: if initialization plugin fails
|
|
|
|
"""
|
|
|
|
states: ClassVar[list[str]] = [
|
|
"new",
|
|
"starting",
|
|
"started",
|
|
"not_started",
|
|
"stopping",
|
|
"stopped",
|
|
"not_stopped",
|
|
]
|
|
|
|
def __init__(
|
|
self,
|
|
config: BrokerConfig | dict[str, Any] | None = None,
|
|
loop: asyncio.AbstractEventLoop | None = None,
|
|
plugin_namespace: str | None = None,
|
|
) -> None:
|
|
"""Initialize the broker."""
|
|
self.logger = logging.getLogger(__name__)
|
|
|
|
if isinstance(config, dict):
|
|
self.config = BrokerConfig.from_dict(config)
|
|
else:
|
|
self.config = config or BrokerConfig()
|
|
|
|
# listeners are populated from default within BrokerConfig
|
|
self.listeners_config = self.config.listeners
|
|
|
|
self._loop = loop or asyncio.get_running_loop()
|
|
self._servers: dict[str, Server] = {}
|
|
self._init_states()
|
|
self._sessions: dict[str, tuple[Session, BrokerProtocolHandler]] = {}
|
|
self._subscriptions: dict[str, list[tuple[Session, int]]] = {}
|
|
self._retained_messages: dict[str, RetainedApplicationMessage] = {}
|
|
|
|
self._topic_filter_matchers: dict[str, re.Pattern[str]] = {}
|
|
|
|
# Broadcast queue for outgoing messages
|
|
self._broadcast_queue: asyncio.Queue[dict[str, Any]] = asyncio.Queue()
|
|
self._broadcast_task: asyncio.Task[Any] | None = None
|
|
self._broadcast_shutdown_waiter: asyncio.Future[Any] = futures.Future()
|
|
|
|
# Tasks queue for managing broadcasting tasks
|
|
self._tasks_queue: deque[asyncio.Task[OutgoingApplicationMessage]] = deque()
|
|
|
|
# Task for session monitor
|
|
self._session_monitor_task: asyncio.Task[Any] | None = None
|
|
|
|
# Initialize plugins manager
|
|
|
|
context = BrokerContext(self)
|
|
context.config = self.config
|
|
namespace = plugin_namespace or "amqtt.broker.plugins"
|
|
self.plugins_manager = PluginManager(namespace, context, self._loop)
|
|
|
|
def _init_states(self) -> None:
|
|
self.transitions = Machine(states=Broker.states, initial="new")
|
|
self.transitions.add_transition(trigger="start", source="new", dest="starting", before=self._log_state_change)
|
|
self.transitions.add_transition(trigger="starting_fail", source="starting", dest="not_started")
|
|
self.transitions.add_transition(trigger="starting_success", source="starting", dest="started")
|
|
self.transitions.add_transition(trigger="shutdown", source="started", dest="stopping")
|
|
self.transitions.add_transition(trigger="stopping_success", source="stopping", dest="stopped")
|
|
self.transitions.add_transition(trigger="stopping_failure", source="stopping", dest="not_stopped")
|
|
self.transitions.add_transition(trigger="start", source="stopped", dest="starting")
|
|
|
|
def _log_state_change(self) -> None:
|
|
self.logger.debug(f"State transition: {self.transitions.state}")
|
|
|
|
async def start(self) -> None:
|
|
"""Start the broker to serve with the given configuration.
|
|
|
|
Start method opens network sockets and will start listening for incoming connections.
|
|
"""
|
|
try:
|
|
self._sessions.clear()
|
|
self._subscriptions.clear()
|
|
self._retained_messages.clear()
|
|
self.transitions.start()
|
|
self.logger.debug("Broker starting")
|
|
except (MachineError, ValueError) as exc:
|
|
# Backwards compat: MachineError is raised by transitions < 0.5.0.
|
|
self.logger.warning(f"[WARN-0001] Invalid method call at this moment: {exc}")
|
|
msg = f"Broker instance can't be started: {exc}"
|
|
raise BrokerError(msg) from exc
|
|
|
|
await self.plugins_manager.fire_event(BrokerEvents.PRE_START)
|
|
try:
|
|
await self._start_listeners()
|
|
self.transitions.starting_success()
|
|
await self.plugins_manager.fire_event(BrokerEvents.POST_START)
|
|
self._broadcast_task = asyncio.ensure_future(self._broadcast_loop())
|
|
self._session_monitor_task = asyncio.create_task(self._session_monitor())
|
|
self.logger.debug("Broker started")
|
|
except Exception as e:
|
|
self.logger.exception("Broker startup failed")
|
|
self.transitions.starting_fail()
|
|
msg = f"Broker instance can't be started: {e}"
|
|
raise BrokerError(msg) from e
|
|
|
|
async def _start_listeners(self) -> None:
|
|
"""Start network listeners based on the configuration."""
|
|
for listener_name, listener in self.listeners_config.items():
|
|
if "bind" not in listener:
|
|
self.logger.debug(f"Listener configuration '{listener_name}' is not bound")
|
|
continue
|
|
|
|
max_connections = listener.get("max_connections", -1)
|
|
ssl_context = self._create_ssl_context(listener) if listener.get("ssl", False) else None
|
|
|
|
# for listeners which are external, don't need to create a server
|
|
if listener.type == ListenerType.EXTERNAL:
|
|
|
|
# broker still needs to associate a new connection to the listener
|
|
self.logger.info(f"External listener exists for '{listener_name}' ")
|
|
self._servers[listener_name] = ExternalServer()
|
|
else:
|
|
# for tcp and websockets, start servers to listen for inbound connections
|
|
try:
|
|
address, port = self._split_bindaddr_port(listener["bind"], DEFAULT_PORTS[listener["type"]])
|
|
except ValueError as e:
|
|
msg = f"Invalid port value in bind value: {listener['bind']}"
|
|
raise BrokerError(msg) from e
|
|
|
|
instance = await self._create_server_instance(listener_name, listener.type, address, port, ssl_context)
|
|
self._servers[listener_name] = Server(listener_name, instance, max_connections)
|
|
|
|
self.logger.info(f"Listener '{listener_name}' bind to {listener['bind']} (max_connections={max_connections})")
|
|
|
|
@staticmethod
|
|
def _create_ssl_context(listener: ListenerConfig) -> ssl.SSLContext:
|
|
"""Create an SSL context for a listener."""
|
|
try:
|
|
ssl_context = ssl.create_default_context(
|
|
ssl.Purpose.CLIENT_AUTH,
|
|
cafile=listener.get("cafile"),
|
|
capath=listener.get("capath"),
|
|
cadata=listener.get("cadata"),
|
|
)
|
|
ssl_context.load_cert_chain(listener["certfile"], listener["keyfile"])
|
|
ssl_context.verify_mode = ssl.CERT_OPTIONAL
|
|
except KeyError as ke:
|
|
msg = f"'certfile' or 'keyfile' configuration parameter missing: {ke}"
|
|
raise BrokerError(msg) from ke
|
|
except FileNotFoundError as fnfe:
|
|
msg = f"Can't read cert files '{listener['certfile']}' or '{listener['keyfile']}' : {fnfe}"
|
|
raise BrokerError(msg) from fnfe
|
|
return ssl_context
|
|
|
|
async def _create_server_instance(
|
|
self,
|
|
listener_name: str,
|
|
listener_type: ListenerType,
|
|
address: str | None,
|
|
port: int,
|
|
ssl_context: ssl.SSLContext | None,
|
|
) -> asyncio.Server | websockets.asyncio.server.Server:
|
|
"""Create a server instance for a listener."""
|
|
match listener_type:
|
|
case ListenerType.TCP:
|
|
return await asyncio.start_server(
|
|
partial(self.stream_connected, listener_name=listener_name),
|
|
address,
|
|
port,
|
|
reuse_address=True,
|
|
ssl=ssl_context,
|
|
)
|
|
case ListenerType.WS:
|
|
return await websockets.serve(
|
|
partial(self.ws_connected, listener_name=listener_name),
|
|
address,
|
|
port,
|
|
ssl=ssl_context,
|
|
subprotocols=[websockets.Subprotocol("mqtt")],
|
|
)
|
|
case _:
|
|
msg = f"Unsupported listener type: {listener_type}"
|
|
raise BrokerError(msg)
|
|
|
|
async def _session_monitor(self) -> None:
|
|
|
|
self.logger.info("Starting session expiration monitor.")
|
|
|
|
while True:
|
|
|
|
session_count_before = len(self._sessions)
|
|
|
|
# clean or anonymous sessions don't retain messages (or subscriptions); the session can be filtered out
|
|
sessions_to_remove = [client_id for client_id, (session, _) in self._sessions.items()
|
|
if session.transitions.state == "disconnected" and (session.is_anonymous or session.clean_session)]
|
|
|
|
# if session expiration is enabled, check to see if any of the sessions are disconnected and past expiration
|
|
if self.config.session_expiry_interval is not None:
|
|
retain_after = floor(time.time() - self.config.session_expiry_interval)
|
|
|
|
sessions_to_remove += [client_id for client_id, (session, _) in self._sessions.items()
|
|
if session.transitions.state == "disconnected" and
|
|
session.last_disconnect_time and
|
|
session.last_disconnect_time < retain_after]
|
|
|
|
for client_id in sessions_to_remove:
|
|
await self._cleanup_session(client_id)
|
|
|
|
if session_count_before > (session_count_after := len(self._sessions)):
|
|
self.logger.debug(f"Expired {session_count_before - session_count_after} sessions")
|
|
|
|
await asyncio.sleep(1)
|
|
|
|
async def shutdown(self) -> None:
|
|
"""Stop broker instance."""
|
|
self.logger.info("Shutting down broker...")
|
|
# Fire broker_shutdown event to plugins
|
|
await self.plugins_manager.fire_event(BrokerEvents.PRE_SHUTDOWN)
|
|
|
|
# Cleanup all sessions
|
|
for client_id in list(self._sessions.keys()):
|
|
await self._cleanup_session(client_id)
|
|
|
|
# Clear retained messages
|
|
self.logger.debug(f"Clearing {len(self._retained_messages)} retained messages")
|
|
self._retained_messages.clear()
|
|
|
|
self.transitions.shutdown()
|
|
|
|
await self._shutdown_broadcast_loop()
|
|
if self._session_monitor_task:
|
|
self._session_monitor_task.cancel()
|
|
|
|
for server in self._servers.values():
|
|
await server.close_instance()
|
|
|
|
if not self._broadcast_queue.empty():
|
|
self.logger.warning(f"{self._broadcast_queue.qsize()} messages not broadcasted")
|
|
# Clear the broadcast queue
|
|
while not self._broadcast_queue.empty():
|
|
self._broadcast_queue.get_nowait()
|
|
|
|
self.logger.info("Broker closed")
|
|
await self.plugins_manager.fire_event(BrokerEvents.POST_SHUTDOWN)
|
|
self.transitions.stopping_success()
|
|
|
|
async def _cleanup_session(self, client_id: str) -> None:
|
|
"""Centralized cleanup logic for a session."""
|
|
session, handler = self._sessions.pop(client_id, (None, None))
|
|
|
|
if handler:
|
|
self.logger.debug(f"Stopping handler for session {client_id}")
|
|
await self._stop_handler(handler)
|
|
if session:
|
|
self.logger.debug(f"Clearing all subscriptions for session {client_id}")
|
|
await self._del_all_subscriptions(session)
|
|
session.clear_queues()
|
|
|
|
async def internal_message_broadcast(self, topic: str, data: bytes, qos: int | None = None) -> None:
|
|
return await self._broadcast_message(None, topic, data, qos)
|
|
|
|
async def ws_connected(self, websocket: ServerConnection, listener_name: str) -> None:
|
|
await self._client_connected(listener_name, WebSocketsReader(websocket), WebSocketsWriter(websocket))
|
|
|
|
async def stream_connected(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter, listener_name: str) -> None:
|
|
await self._client_connected(listener_name, StreamReaderAdapter(reader), StreamWriterAdapter(writer))
|
|
|
|
async def external_connected(self, reader: ReaderAdapter, writer: WriterAdapter, listener_name: str) -> None:
|
|
"""Engage the broker in handling the data stream to/from an established connection."""
|
|
await self._client_connected(listener_name, reader, writer)
|
|
|
|
async def _client_connected(self, listener_name: str, reader: ReaderAdapter, writer: WriterAdapter) -> None:
|
|
"""Handle a new client connection."""
|
|
server = self._servers.get(listener_name)
|
|
if not server:
|
|
msg = f"Invalid listener name '{listener_name}'"
|
|
raise BrokerError(msg)
|
|
|
|
await server.acquire_connection()
|
|
remote_info = writer.get_peer_info()
|
|
if remote_info is None:
|
|
self.logger.warning("Remote info could not be retrieved from peer info")
|
|
return
|
|
|
|
remote_address, remote_port = remote_info
|
|
self.logger.info(f"Connection from {remote_address}:{remote_port} on listener '{listener_name}'")
|
|
|
|
try:
|
|
handler, client_session = await self._initialize_client_session(reader, writer, remote_address, remote_port)
|
|
except (AMQTTError, MQTTError, NoDataError) as exc:
|
|
self.logger.warning(f"Failed to initialize client session: {exc}")
|
|
server.release_connection()
|
|
return
|
|
|
|
try:
|
|
await self._handle_client_session(reader, writer, client_session, handler, server, listener_name)
|
|
except (AMQTTError, MQTTError, NoDataError) as exc:
|
|
self.logger.warning(f"Error while handling client session: {exc}")
|
|
finally:
|
|
self.logger.debug(f"{client_session.client_id} Client disconnected")
|
|
server.release_connection()
|
|
|
|
async def _initialize_client_session(
|
|
self,
|
|
reader: ReaderAdapter,
|
|
writer: WriterAdapter,
|
|
remote_address: str,
|
|
remote_port: int,
|
|
) -> tuple[BrokerProtocolHandler, Session]:
|
|
"""Initialize a client session and protocol handler."""
|
|
# Wait for first packet and expect a CONNECT
|
|
try:
|
|
handler, client_session = await BrokerProtocolHandler.init_from_connect(reader, writer, self.plugins_manager)
|
|
except AMQTTError as exc:
|
|
self.logger.warning(
|
|
f"[MQTT-3.1.0-1] {format_client_message(address=remote_address, port=remote_port)}:"
|
|
f" Can't read first packet as CONNECT: {exc}",
|
|
)
|
|
raise AMQTTError(exc) from exc
|
|
except MQTTError as exc:
|
|
self.logger.exception(
|
|
f"Invalid connection from {format_client_message(address=remote_address, port=remote_port)}",
|
|
)
|
|
await writer.close()
|
|
raise MQTTError(exc) from exc
|
|
except NoDataError as exc:
|
|
self.logger.error( # noqa: TRY400
|
|
f"No data from {format_client_message(address=remote_address, port=remote_port)} : {exc}",
|
|
)
|
|
raise AMQTTError(exc) from exc
|
|
|
|
if client_session.clean_session:
|
|
# Delete existing session and create a new one
|
|
if client_session.client_id is not None and client_session.client_id != "":
|
|
await self._delete_session(client_session.client_id)
|
|
else:
|
|
client_session.client_id = gen_client_id()
|
|
|
|
client_session.parent = 0
|
|
# Get session from cache
|
|
elif client_session.client_id in self._sessions:
|
|
self.logger.debug(f"Found old session {self._sessions[client_session.client_id]!r}")
|
|
|
|
# even though the session previously existed, the new connection can bring updated configuration and credentials
|
|
existing_client_session, _ = self._sessions[client_session.client_id]
|
|
existing_client_session.will_flag = client_session.will_flag
|
|
existing_client_session.will_message = client_session.will_message
|
|
existing_client_session.will_topic = client_session.will_topic
|
|
existing_client_session.will_qos = client_session.will_qos
|
|
existing_client_session.keep_alive = client_session.keep_alive
|
|
existing_client_session.username = client_session.username
|
|
existing_client_session.password = client_session.password
|
|
client_session = existing_client_session
|
|
client_session.parent = 1
|
|
else:
|
|
client_session.parent = 0
|
|
|
|
timeout_disconnect_delay = self.config.get("timeout-disconnect-delay", 0)
|
|
if client_session.keep_alive > 0 and isinstance(timeout_disconnect_delay, int):
|
|
client_session.keep_alive += timeout_disconnect_delay
|
|
|
|
self.logger.debug(f"Keep-alive timeout={client_session.keep_alive}")
|
|
return handler, client_session
|
|
|
|
def create_offline_session(self, client_id: str) -> tuple[BrokerProtocolHandler, Session]:
|
|
session = Session()
|
|
session.client_id = client_id
|
|
|
|
bph = BrokerProtocolHandler(self.plugins_manager, session)
|
|
session.transitions.disconnect()
|
|
return bph, session
|
|
|
|
async def _handle_client_session(
|
|
self,
|
|
reader: ReaderAdapter,
|
|
writer: WriterAdapter,
|
|
client_session: Session,
|
|
handler: BrokerProtocolHandler,
|
|
server: Server,
|
|
listener_name: str,
|
|
) -> None:
|
|
"""Handle the lifecycle of a client session."""
|
|
authenticated = await self._authenticate(client_session, self.listeners_config[listener_name])
|
|
if not authenticated:
|
|
await writer.close()
|
|
return
|
|
|
|
if client_session.client_id is None:
|
|
msg = "Client ID was not correctly created/set."
|
|
raise BrokerError(msg)
|
|
|
|
while True:
|
|
try:
|
|
client_session.transitions.connect()
|
|
break
|
|
except (MachineError, ValueError):
|
|
if client_session.transitions.is_connected():
|
|
self.logger.warning(f"Client {client_session.client_id} is already connected, performing take-over.")
|
|
old_session = self._sessions[client_session.client_id]
|
|
await old_session[1].handle_connection_closed()
|
|
await old_session[1].stop()
|
|
break
|
|
self.logger.warning(f"Client {client_session.client_id} is reconnecting too quickly, make it wait")
|
|
await asyncio.sleep(1)
|
|
|
|
handler.attach(client_session, reader, writer)
|
|
self._sessions[client_session.client_id] = (client_session, handler)
|
|
|
|
await handler.mqtt_connack_authorize(authenticated)
|
|
await self.plugins_manager.fire_event(BrokerEvents.CLIENT_CONNECTED,
|
|
client_id=client_session.client_id,
|
|
client_session=client_session)
|
|
|
|
self.logger.debug(f"{client_session.client_id} Start messages handling")
|
|
await handler.start()
|
|
|
|
# publish messages that were retained because the client session was disconnected
|
|
self.logger.debug(f"Retained messages queue size: {client_session.retained_messages.qsize()}")
|
|
await self._publish_session_retained_messages(client_session)
|
|
|
|
# if this is not a new session, there are subscriptions associated with them; publish any topic retained messages
|
|
self.logger.debug("Publish retained messages to a pre-existing session's subscriptions.")
|
|
for topic in self._subscriptions:
|
|
await self._publish_retained_messages_for_subscription((topic, QOS_0), client_session)
|
|
|
|
await self._client_message_loop(client_session, handler)
|
|
|
|
async def _client_message_loop(self, client_session: Session, handler: BrokerProtocolHandler) -> None:
|
|
"""Run the main loop to handle client messages."""
|
|
# Init and start loop for handling client messages (publish, subscribe/unsubscribe, disconnect)
|
|
disconnect_waiter = asyncio.ensure_future(handler.wait_disconnect())
|
|
subscribe_waiter = asyncio.ensure_future(handler.get_next_pending_subscription())
|
|
unsubscribe_waiter = asyncio.ensure_future(handler.get_next_pending_unsubscription())
|
|
wait_deliver = asyncio.ensure_future(handler.mqtt_deliver_next_message())
|
|
connected = True
|
|
|
|
while connected:
|
|
try:
|
|
done, _ = await asyncio.wait(
|
|
[
|
|
disconnect_waiter,
|
|
subscribe_waiter,
|
|
unsubscribe_waiter,
|
|
wait_deliver,
|
|
],
|
|
return_when=asyncio.FIRST_COMPLETED,
|
|
)
|
|
|
|
if disconnect_waiter in done:
|
|
# handle the disconnection: normal or abnormal result, either way, the client is no longer connected
|
|
await self._handle_disconnect(client_session, handler, disconnect_waiter)
|
|
connected = False
|
|
|
|
# no need to reschedule the `disconnect_waiter` since we're exiting the message loop
|
|
|
|
if subscribe_waiter in done:
|
|
await self._handle_subscription(client_session, handler, subscribe_waiter)
|
|
subscribe_waiter = asyncio.ensure_future(handler.get_next_pending_subscription())
|
|
self.logger.debug(repr(self._subscriptions))
|
|
|
|
if unsubscribe_waiter in done:
|
|
await self._handle_unsubscription(client_session, handler, unsubscribe_waiter)
|
|
unsubscribe_waiter = asyncio.ensure_future(handler.get_next_pending_unsubscription())
|
|
|
|
if wait_deliver in done:
|
|
if not await self._handle_message_delivery(client_session, handler, wait_deliver):
|
|
break
|
|
wait_deliver = asyncio.ensure_future(handler.mqtt_deliver_next_message())
|
|
|
|
except asyncio.CancelledError:
|
|
self.logger.debug("Client loop cancelled")
|
|
break
|
|
|
|
disconnect_waiter.cancel()
|
|
subscribe_waiter.cancel()
|
|
unsubscribe_waiter.cancel()
|
|
wait_deliver.cancel()
|
|
|
|
async def _handle_disconnect(
|
|
self,
|
|
client_session: Session,
|
|
handler: BrokerProtocolHandler,
|
|
disconnect_waiter: asyncio.Future[Any],
|
|
) -> None:
|
|
"""Handle client disconnection.
|
|
|
|
Args:
|
|
client_session (Session): client session
|
|
handler (BrokerProtocolHandler): broker protocol handler
|
|
disconnect_waiter (asyncio.Future[Any]): future to wait for disconnection
|
|
|
|
"""
|
|
# check the disconnected waiter result
|
|
result = disconnect_waiter.result()
|
|
self.logger.debug(f"{client_session.client_id} Result from wait_disconnect: {result}")
|
|
# if the client disconnects abruptly by sending no message or the message isn't a disconnect packet
|
|
if result is None or not isinstance(result, DisconnectPacket):
|
|
self.logger.debug(f"Will flag: {client_session.will_flag}")
|
|
if client_session.will_flag:
|
|
self.logger.debug(
|
|
f"Client {format_client_message(client_session)} disconnected abnormally, sending will message",
|
|
)
|
|
await self._broadcast_message(
|
|
client_session,
|
|
client_session.will_topic,
|
|
client_session.will_message,
|
|
client_session.will_qos,
|
|
)
|
|
if client_session.will_retain:
|
|
await self.retain_message(
|
|
client_session,
|
|
client_session.will_topic,
|
|
client_session.will_message,
|
|
client_session.will_qos,
|
|
)
|
|
|
|
# normal or not, let's end the client's session
|
|
self.logger.debug(f"{client_session.client_id} Disconnecting session")
|
|
await self._stop_handler(handler)
|
|
client_session.transitions.disconnect()
|
|
await self.plugins_manager.fire_event(BrokerEvents.CLIENT_DISCONNECTED,
|
|
client_id=client_session.client_id,
|
|
client_session=client_session)
|
|
|
|
async def _handle_subscription(
|
|
self,
|
|
client_session: Session,
|
|
handler: BrokerProtocolHandler,
|
|
subscribe_waiter: asyncio.Future[Any],
|
|
) -> None:
|
|
"""Handle client subscription."""
|
|
self.logger.debug(f"{client_session.client_id} handling subscription")
|
|
subscriptions = subscribe_waiter.result()
|
|
return_codes = [await self.add_subscription(subscription, client_session) for subscription in subscriptions.topics]
|
|
await handler.mqtt_acknowledge_subscription(subscriptions.packet_id, return_codes)
|
|
for index, subscription in enumerate(subscriptions.topics):
|
|
if return_codes[index] != AMQTT_MAGIC_VALUE_RET_SUBSCRIBED:
|
|
await self.plugins_manager.fire_event(
|
|
BrokerEvents.CLIENT_SUBSCRIBED,
|
|
client_id=client_session.client_id,
|
|
topic=subscription[0],
|
|
qos=subscription[1],
|
|
)
|
|
await self._publish_retained_messages_for_subscription(subscription, client_session)
|
|
|
|
async def _handle_unsubscription(
|
|
self,
|
|
client_session: Session,
|
|
handler: BrokerProtocolHandler,
|
|
unsubscribe_waiter: asyncio.Future[Any],
|
|
) -> None:
|
|
"""Handle client unsubscription."""
|
|
self.logger.debug(f"{client_session.client_id} handling unsubscription")
|
|
unsubscription = unsubscribe_waiter.result()
|
|
for topic in unsubscription.topics:
|
|
self._del_subscription(topic, client_session)
|
|
await self.plugins_manager.fire_event(
|
|
BrokerEvents.CLIENT_UNSUBSCRIBED,
|
|
client_id=client_session.client_id,
|
|
topic=topic,
|
|
)
|
|
await handler.mqtt_acknowledge_unsubscription(unsubscription.packet_id)
|
|
|
|
async def _handle_message_delivery(
|
|
self,
|
|
client_session: Session,
|
|
handler: BrokerProtocolHandler,
|
|
wait_deliver: asyncio.Future[Any],
|
|
) -> bool:
|
|
"""Handle message delivery to the client."""
|
|
self.logger.debug(f"{client_session.client_id} handling message delivery")
|
|
app_message = wait_deliver.result()
|
|
|
|
# notify of a message's receipt, even if a client isn't necessarily allowed to send it
|
|
await self.plugins_manager.fire_event(
|
|
BrokerEvents.MESSAGE_RECEIVED,
|
|
client_id=client_session.client_id,
|
|
message=app_message,
|
|
)
|
|
|
|
if app_message is None:
|
|
self.logger.debug("app_message was empty!")
|
|
return True
|
|
if not app_message.topic:
|
|
self.logger.warning(
|
|
f"[MQTT-4.7.3-1] - {client_session.client_id} invalid TOPIC sent in PUBLISH message, closing connection",
|
|
)
|
|
return False
|
|
if "#" in app_message.topic or "+" in app_message.topic:
|
|
self.logger.warning(
|
|
f"[MQTT-3.3.2-2] - {client_session.client_id} invalid TOPIC sent in PUBLISH message, closing connection",
|
|
)
|
|
return False
|
|
if app_message.topic.startswith("$"):
|
|
self.logger.warning(
|
|
f"[MQTT-4.7.2-1] - {client_session.client_id} cannot use a topic with a leading $ character."
|
|
)
|
|
return False
|
|
|
|
permitted = await self._topic_filtering(client_session, topic=app_message.topic, action=Action.PUBLISH)
|
|
if not permitted:
|
|
self.logger.info(f"{client_session.client_id} not allowed to publish to TOPIC {app_message.topic}.")
|
|
else:
|
|
# notify that a received message is valid and is allowed to be distributed to other clients
|
|
await self.plugins_manager.fire_event(
|
|
BrokerEvents.MESSAGE_BROADCAST,
|
|
client_id=client_session.client_id,
|
|
message=app_message,
|
|
)
|
|
await self._broadcast_message(client_session, app_message.topic, app_message.data)
|
|
if app_message.publish_packet and app_message.publish_packet.retain_flag:
|
|
await self.retain_message(client_session, app_message.topic, app_message.data, app_message.qos)
|
|
return True
|
|
|
|
async def _init_handler(self, session: Session, reader: ReaderAdapter, writer: WriterAdapter) -> BrokerProtocolHandler:
|
|
"""Create a BrokerProtocolHandler and attach to a session."""
|
|
handler = BrokerProtocolHandler(self.plugins_manager, loop=self._loop)
|
|
handler.attach(session, reader, writer)
|
|
return handler
|
|
|
|
async def _stop_handler(self, handler: BrokerProtocolHandler) -> None:
|
|
"""Stop a running handler and detach if from the session."""
|
|
try:
|
|
await handler.stop()
|
|
# a failure in stopping a handler shouldn't cause the broker to fail
|
|
except asyncio.QueueEmpty:
|
|
self.logger.exception("Failed to stop handler")
|
|
|
|
async def _authenticate(self, session: Session, _: ListenerConfig) -> bool:
|
|
"""Call the authenticate method on registered plugins to test user authentication.
|
|
|
|
User is considered authenticated if all plugins called returns True.
|
|
Plugins authenticate() method are supposed to return :
|
|
- True if user is authentication succeed
|
|
- False if user authentication fails
|
|
- None if authentication can't be achieved (then plugin result is then ignored)
|
|
:param session:
|
|
:return:
|
|
"""
|
|
returns = await self.plugins_manager.map_plugin_auth(session=session)
|
|
|
|
results = [result for _, result in returns.items() if result is not None] if returns else []
|
|
if len(results) < 1:
|
|
self.logger.debug("Authentication failed: no plugin responded with a boolean")
|
|
return False
|
|
|
|
if all(results):
|
|
self.logger.debug("Authentication succeeded")
|
|
return True
|
|
|
|
for plugin, result in returns.items():
|
|
self.logger.debug(f"Authentication '{plugin.__class__.__name__}' result: {result}")
|
|
|
|
return False
|
|
|
|
async def retain_message(
|
|
self,
|
|
source_session: Session | None,
|
|
topic_name: str | None,
|
|
data: bytes | bytearray | None,
|
|
qos: int | None = None,
|
|
) -> None:
|
|
if data and topic_name is not None:
|
|
# If retained flag set, store the message for further subscriptions
|
|
self.logger.debug(f"Retaining message on topic {topic_name}")
|
|
self._retained_messages[topic_name] = RetainedApplicationMessage(source_session, topic_name, data, qos)
|
|
|
|
await self.plugins_manager.fire_event(BrokerEvents.RETAINED_MESSAGE,
|
|
client_id=None,
|
|
retained_message=self._retained_messages[topic_name])
|
|
|
|
# [MQTT-3.3.1-10]
|
|
elif topic_name in self._retained_messages:
|
|
self.logger.debug(f"Clearing retained messages for topic '{topic_name}'")
|
|
|
|
cleared_message = self._retained_messages[topic_name]
|
|
cleared_message.data = b""
|
|
|
|
await self.plugins_manager.fire_event(BrokerEvents.RETAINED_MESSAGE,
|
|
client_id=None,
|
|
retained_message=cleared_message)
|
|
|
|
del self._retained_messages[topic_name]
|
|
|
|
async def add_subscription(self, subscription: tuple[str, int], session: Session) -> int:
|
|
topic_filter, qos = subscription
|
|
if "#" in topic_filter and not topic_filter.endswith("#"):
|
|
# [MQTT-4.7.1-2] Wildcard character '#' is only allowed as last character in filter
|
|
return 0x80
|
|
if topic_filter != "+" and "+" in topic_filter and ("/+" not in topic_filter and "+/" not in topic_filter):
|
|
# [MQTT-4.7.1-3] + wildcard character must occupy entire level
|
|
return 0x80
|
|
# Check if the client is authorised to connect to the topic
|
|
if not await self._topic_filtering(session, topic_filter, Action.SUBSCRIBE):
|
|
return 0x80
|
|
|
|
# Ensure "max-qos" is an integer before using it
|
|
max_qos = self.config.get("max-qos", qos)
|
|
if not isinstance(max_qos, int):
|
|
max_qos = qos
|
|
|
|
qos = min(qos, max_qos)
|
|
if topic_filter not in self._subscriptions:
|
|
self._subscriptions[topic_filter] = []
|
|
if all(s.client_id != session.client_id for s, _ in self._subscriptions[topic_filter]):
|
|
self._subscriptions[topic_filter].append((session, qos))
|
|
else:
|
|
self.logger.debug(f"Client {format_client_message(session=session)} has already subscribed to {topic_filter}")
|
|
return qos
|
|
|
|
async def _topic_filtering(self, session: Session, topic: str, action: Action) -> bool:
|
|
"""Call the topic_filtering method on registered plugins to check that the subscription is allowed.
|
|
|
|
User is considered allowed if all plugins called return True.
|
|
Plugins topic_filtering() method are supposed to return :
|
|
- True if MQTT client can be subscribed to the topic
|
|
- False if MQTT client is not allowed to subscribe to the topic
|
|
- None if topic filtering can't be achieved (then plugin result is then ignored)
|
|
:param session:
|
|
:param topic: Topic in which the client wants to subscribe / publish
|
|
:param action: What is being done with the topic? subscribe or publish
|
|
:return:
|
|
"""
|
|
if not self.plugins_manager.is_topic_filtering_enabled():
|
|
return True
|
|
|
|
results = await self.plugins_manager.map_plugin_topic(session=session, topic=topic, action=action)
|
|
return all(result for result in results.values())
|
|
|
|
async def _delete_session(self, client_id: str) -> None:
|
|
"""Delete an existing session data, for example due to clean session set in CONNECT."""
|
|
session = self._sessions.pop(client_id, (None, None))[0]
|
|
|
|
if session is None:
|
|
self.logger.debug(f"Delete session : session {client_id} doesn't exist")
|
|
return
|
|
self.logger.debug(f"Deleted existing session {session!r}")
|
|
|
|
# Delete subscriptions
|
|
self.logger.debug(f"Deleting session {session!r} subscriptions")
|
|
await self._del_all_subscriptions(session)
|
|
session.clear_queues()
|
|
|
|
async def _del_all_subscriptions(self, session: Session) -> None:
|
|
"""Delete all topic subscriptions for a given session."""
|
|
filter_queue: deque[str] = deque()
|
|
for topic in self._subscriptions:
|
|
if self._del_subscription(topic, session):
|
|
filter_queue.append(topic)
|
|
for topic in filter_queue:
|
|
if not self._subscriptions[topic]:
|
|
del self._subscriptions[topic]
|
|
|
|
def _del_subscription(self, a_filter: str, session: Session) -> int:
|
|
"""Delete a session subscription on a given topic.
|
|
|
|
:param a_filter: The topic filter for the subscription.
|
|
:param session: The session to be unsubscribed.
|
|
:return: The number of deleted subscriptions (0 or 1).
|
|
"""
|
|
deleted = 0
|
|
try:
|
|
subscriptions = self._subscriptions[a_filter]
|
|
for index, (sub_session, _qos) in enumerate(subscriptions):
|
|
if sub_session.client_id == session.client_id:
|
|
self.logger.debug(
|
|
f"Removing subscription on topic '{a_filter}' for client {format_client_message(session=session)}",
|
|
)
|
|
subscriptions.pop(index)
|
|
deleted += 1
|
|
break
|
|
except KeyError:
|
|
self.logger.debug(f"Unsubscription on topic '{a_filter}' for client {format_client_message(session=session)}")
|
|
|
|
return deleted
|
|
|
|
async def _broadcast_loop(self) -> None:
|
|
"""Run the main loop to broadcast messages."""
|
|
running_tasks: deque[asyncio.Task[OutgoingApplicationMessage]] = self._tasks_queue
|
|
|
|
try:
|
|
while True:
|
|
while running_tasks and running_tasks[0].done():
|
|
task = running_tasks.popleft()
|
|
try:
|
|
task.result()
|
|
except CancelledError:
|
|
self.logger.info(f"Task has been cancelled: {task}")
|
|
# if a task fails, don't want it to cause the broker to fail
|
|
except Exception: # pylint: disable=W0718
|
|
self.logger.exception(f"Task failed and will be skipped: {task}")
|
|
|
|
run_broadcast_task = asyncio.ensure_future(self._run_broadcast(running_tasks))
|
|
|
|
completed, _ = await asyncio.wait(
|
|
[run_broadcast_task, self._broadcast_shutdown_waiter],
|
|
return_when=asyncio.FIRST_COMPLETED,
|
|
)
|
|
|
|
# Shutdown has been triggered by the broker, so stop the loop execution
|
|
if self._broadcast_shutdown_waiter in completed:
|
|
run_broadcast_task.cancel()
|
|
break
|
|
|
|
except BaseException:
|
|
self.logger.exception("Broadcast loop stopped by exception")
|
|
raise
|
|
finally:
|
|
# Wait until current broadcasting tasks end
|
|
if running_tasks:
|
|
await asyncio.gather(*running_tasks)
|
|
|
|
async def _run_broadcast(self, running_tasks: deque[asyncio.Task[OutgoingApplicationMessage]]) -> None:
|
|
"""Process a single broadcast message."""
|
|
broadcast = await self._broadcast_queue.get()
|
|
|
|
self.logger.debug(f"Processing broadcast message: {broadcast}")
|
|
|
|
for k_filter, subscriptions in self._subscriptions.items():
|
|
|
|
# Skip all subscriptions which do not match the topic
|
|
if not self._matches(broadcast["topic"], k_filter):
|
|
self.logger.debug(f"Topic '{broadcast['topic']}' does not match filter '{k_filter}'")
|
|
continue
|
|
|
|
for target_session, sub_qos in subscriptions:
|
|
qos = broadcast.get("qos", sub_qos)
|
|
|
|
sendable = await self._topic_filtering(target_session, topic=broadcast["topic"], action=Action.RECEIVE)
|
|
if not sendable:
|
|
self.logger.info(
|
|
f"{target_session.client_id} not allowed to receive messages from TOPIC {broadcast['topic']}.")
|
|
continue
|
|
|
|
# Retain all messages which cannot be broadcasted, due to the session not being connected
|
|
# but only when clean session is false and qos is 1 or 2 [MQTT 3.1.2.4]
|
|
# and, if a client used anonymous authentication, there is no expectation that messages should be retained
|
|
if (target_session.transitions.state != "connected"
|
|
and not target_session.clean_session
|
|
and qos in (QOS_1, QOS_2)
|
|
and not target_session.is_anonymous):
|
|
self.logger.debug(f"Session {target_session.client_id} is not connected, retaining message.")
|
|
await self._retain_broadcast_message(broadcast, qos, target_session)
|
|
continue
|
|
|
|
# Only broadcast the message to connected clients
|
|
if target_session.transitions.state != "connected":
|
|
continue
|
|
|
|
self.logger.debug(
|
|
f"Broadcasting message from {format_client_message(session=broadcast['session'])}"
|
|
f" on topic '{broadcast['topic']}' to {format_client_message(session=target_session)}",
|
|
)
|
|
|
|
handler = self._get_handler(target_session)
|
|
if handler:
|
|
task = asyncio.ensure_future(
|
|
handler.mqtt_publish(
|
|
broadcast["topic"],
|
|
broadcast["data"],
|
|
qos,
|
|
retain=False,
|
|
),
|
|
)
|
|
running_tasks.append(task)
|
|
|
|
async def _retain_broadcast_message(self, broadcast: dict[str, Any], qos: int, target_session: Session) -> None:
|
|
if self.logger.isEnabledFor(logging.DEBUG):
|
|
self.logger.debug(
|
|
f"retaining application message from {format_client_message(session=broadcast['session'])}"
|
|
f" on topic '{broadcast['topic']}' to client '{format_client_message(session=target_session)}'",
|
|
)
|
|
|
|
retained_message = RetainedApplicationMessage(broadcast["session"], broadcast["topic"], broadcast["data"], qos)
|
|
await target_session.retained_messages.put(retained_message)
|
|
|
|
await self.plugins_manager.fire_event(BrokerEvents.RETAINED_MESSAGE,
|
|
client_id=target_session.client_id,
|
|
retained_message=retained_message)
|
|
|
|
if self.logger.isEnabledFor(logging.DEBUG):
|
|
self.logger.debug(f"target_session.retained_messages={target_session.retained_messages.qsize()}")
|
|
|
|
async def _shutdown_broadcast_loop(self) -> None:
|
|
if self._broadcast_task and not self._broadcast_shutdown_waiter.done():
|
|
self._broadcast_shutdown_waiter.set_result(True)
|
|
try:
|
|
await asyncio.wait_for(self._broadcast_task, timeout=30)
|
|
except TimeoutError as e:
|
|
self.logger.warning(f"Failed to cleanly shutdown broadcast loop: {e}")
|
|
|
|
if not self._broadcast_queue.empty():
|
|
self.logger.warning(f"{self._broadcast_queue.qsize()} messages not broadcasted")
|
|
|
|
self._broadcast_shutdown_waiter = asyncio.Future()
|
|
|
|
async def _broadcast_message(
|
|
self,
|
|
session: Session | None,
|
|
topic: str | None,
|
|
data: bytes | bytearray | None,
|
|
force_qos: int | None = None,
|
|
) -> None:
|
|
broadcast: _BROADCAST = {"session": session, "topic": topic, "data": data}
|
|
if force_qos is not None:
|
|
broadcast["qos"] = force_qos
|
|
await self._broadcast_queue.put(broadcast)
|
|
|
|
async def _publish_session_retained_messages(self, session: Session) -> None:
|
|
self.logger.debug(
|
|
f"Publishing {session.retained_messages.qsize()}"
|
|
f" messages retained for session {format_client_message(session=session)}",
|
|
)
|
|
publish_tasks = []
|
|
handler = self._get_handler(session)
|
|
if handler:
|
|
while not session.retained_messages.empty():
|
|
retained = await session.retained_messages.get()
|
|
publish_tasks.append(
|
|
asyncio.ensure_future(
|
|
handler.mqtt_publish(retained.topic, retained.data, retained.qos, retain=True),
|
|
),
|
|
)
|
|
if publish_tasks:
|
|
await asyncio.wait(publish_tasks)
|
|
|
|
async def _publish_retained_messages_for_subscription(self, subscription: tuple[str, int], session: Session) -> None:
|
|
self.logger.debug(
|
|
f"Begin broadcasting messages retained due to subscription on '{subscription[0]}'"
|
|
f" from {format_client_message(session=session)}",
|
|
)
|
|
publish_tasks = []
|
|
|
|
topic_filter, qos = subscription
|
|
for topic, retained in self._retained_messages.items():
|
|
self.logger.debug(f"matching : {topic} {topic_filter}")
|
|
if self._matches(topic, topic_filter):
|
|
self.logger.debug(f"{topic} and {topic_filter} match")
|
|
handler = self._get_handler(session)
|
|
if handler:
|
|
publish_tasks.append(
|
|
asyncio.Task(
|
|
handler.mqtt_publish(retained.topic, retained.data, min(qos, retained.qos or qos), retain=True),
|
|
),
|
|
)
|
|
if publish_tasks:
|
|
await asyncio.wait(publish_tasks)
|
|
self.logger.debug(
|
|
f"End broadcasting messages retained due to subscription on '{subscription[0]}'"
|
|
f" from {format_client_message(session=session)}",
|
|
)
|
|
|
|
def _matches(self, topic: str, a_filter: str) -> bool:
|
|
if topic.startswith("$") and (a_filter.startswith(("+", "#"))):
|
|
self.logger.debug("[MQTT-4.7.2-1] - ignoring broadcasting $ topic to subscriptions starting with + or #")
|
|
return False
|
|
|
|
if "#" not in a_filter and "+" not in a_filter:
|
|
# if filter doesn't contain wildcard, return exact match
|
|
return a_filter == topic
|
|
|
|
# else use regex (re.compile is an expensive operation, store the matcher for future use)
|
|
if a_filter not in self._topic_filter_matchers:
|
|
self._topic_filter_matchers[a_filter] = re.compile(re.escape(a_filter)
|
|
.replace("\\#", "?.*")
|
|
.replace("\\+", "[^/]*")
|
|
.lstrip("?"))
|
|
match_pattern = self._topic_filter_matchers[a_filter]
|
|
return bool(match_pattern.fullmatch(topic))
|
|
|
|
def _get_handler(self, session: Session) -> BrokerProtocolHandler | None:
|
|
client_id = session.client_id
|
|
if client_id:
|
|
return self._sessions.get(client_id, (None, None))[1]
|
|
return None
|
|
|
|
@classmethod
|
|
def _split_bindaddr_port(cls, port_str: str, default_port: int) -> tuple[str | None, int]:
|
|
"""Split an address:port pair into separate IP address and port. with IPv6 special-case handling.
|
|
|
|
- Address can be specified using one of the following methods:
|
|
- empty string - all interfaces default port
|
|
- 1883 - Port number only (listen all interfaces)
|
|
- :1883 - Port number only (listen all interfaces)
|
|
- 0.0.0.0:1883 - IPv4 address
|
|
- [::]:1883 - IPv6 address
|
|
"""
|
|
|
|
def _parse_port(port_str: str) -> int:
|
|
port_str = port_str.removeprefix(":")
|
|
|
|
if not port_str:
|
|
return default_port
|
|
|
|
return int(port_str)
|
|
|
|
if port_str.startswith("["): # IPv6 literal
|
|
try:
|
|
addr_end = port_str.index("]")
|
|
except ValueError as e:
|
|
msg = "Expecting '[' to be followed by ']'"
|
|
raise ValueError(msg) from e
|
|
|
|
return (port_str[0 : addr_end + 1], _parse_port(port_str[addr_end + 1 :]))
|
|
|
|
if ":" in port_str:
|
|
address, port_str = port_str.rsplit(":", 1)
|
|
return (address or None, _parse_port(port_str))
|
|
|
|
try:
|
|
return (None, _parse_port(port_str))
|
|
except ValueError:
|
|
return (port_str, default_port)
|
|
|
|
@property
|
|
def subscriptions(self) -> dict[str, list[tuple[Session, int]]]:
|
|
return self._subscriptions
|
|
|
|
@property
|
|
def retained_messages(self) -> dict[str, RetainedApplicationMessage]:
|
|
return self._retained_messages
|
|
|
|
@property
|
|
def sessions(self) -> dict[str, tuple[Session, BrokerProtocolHandler]]:
|
|
return self._sessions
|