|
|
|
@ -242,72 +242,10 @@ class Broker:
|
|
|
|
|
|
|
|
|
|
await self.plugins_manager.fire_event(EVENT_BROKER_PRE_START)
|
|
|
|
|
try:
|
|
|
|
|
# Start network listeners
|
|
|
|
|
for listener_name, listener in self.listeners_config.items():
|
|
|
|
|
if "bind" not in listener:
|
|
|
|
|
self.logger.debug(f"Listener configuration '{listener_name}' is not bound")
|
|
|
|
|
continue
|
|
|
|
|
|
|
|
|
|
max_connections = listener.get("max_connections", -1)
|
|
|
|
|
|
|
|
|
|
ssl_conection = None
|
|
|
|
|
ssl_active = listener.get("ssl", False) # accept string "on" / "off" or boolean
|
|
|
|
|
if isinstance(ssl_active, str):
|
|
|
|
|
ssl_active = ssl_active.upper() == "ON"
|
|
|
|
|
|
|
|
|
|
if ssl_active:
|
|
|
|
|
try:
|
|
|
|
|
ssl_conection = ssl.create_default_context(
|
|
|
|
|
ssl.Purpose.CLIENT_AUTH,
|
|
|
|
|
cafile=listener.get("cafile"),
|
|
|
|
|
capath=listener.get("capath"),
|
|
|
|
|
cadata=listener.get("cadata"),
|
|
|
|
|
)
|
|
|
|
|
ssl_conection.load_cert_chain(listener["certfile"], listener["keyfile"])
|
|
|
|
|
ssl_conection.verify_mode = ssl.CERT_OPTIONAL
|
|
|
|
|
except KeyError as ke:
|
|
|
|
|
msg = f"'certfile' or 'keyfile' configuration parameter missing: {ke}"
|
|
|
|
|
raise BrokerError(msg) from ke
|
|
|
|
|
except FileNotFoundError as fnfe:
|
|
|
|
|
msg = f"Can't read cert files '{listener['certfile']}' or '{listener['keyfile']}' : {fnfe}"
|
|
|
|
|
raise BrokerError(msg) from fnfe
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
address, port = self._split_bindaddr_port(listener["bind"], DEFAULT_PORTS[listener["type"]])
|
|
|
|
|
except ValueError as e:
|
|
|
|
|
msg = f"Invalid port value in bind value: {listener['bind']}"
|
|
|
|
|
raise BrokerError(msg) from e
|
|
|
|
|
|
|
|
|
|
instance: asyncio.Server | websockets.asyncio.server.Server | None = None
|
|
|
|
|
if listener["type"] == "tcp":
|
|
|
|
|
cb_partial = partial(self.stream_connected, listener_name=listener_name)
|
|
|
|
|
instance = await asyncio.start_server(
|
|
|
|
|
cb_partial,
|
|
|
|
|
address,
|
|
|
|
|
port,
|
|
|
|
|
reuse_address=True,
|
|
|
|
|
ssl=ssl_conection,
|
|
|
|
|
)
|
|
|
|
|
self._servers[listener_name] = Server(listener_name, instance, max_connections)
|
|
|
|
|
elif listener["type"] == "ws":
|
|
|
|
|
cb_partial = partial(self.ws_connected, listener_name=listener_name)
|
|
|
|
|
instance = await websockets.serve(
|
|
|
|
|
cb_partial,
|
|
|
|
|
address,
|
|
|
|
|
port,
|
|
|
|
|
ssl=ssl_conection,
|
|
|
|
|
subprotocols=[websockets.Subprotocol("mqtt")],
|
|
|
|
|
)
|
|
|
|
|
self._servers[listener_name] = Server(listener_name, instance, max_connections)
|
|
|
|
|
|
|
|
|
|
self.logger.info(f"Listener '{listener_name}' bind to {listener['bind']} (max_connections={max_connections})")
|
|
|
|
|
|
|
|
|
|
await self._start_listeners()
|
|
|
|
|
self.transitions.starting_success()
|
|
|
|
|
await self.plugins_manager.fire_event(EVENT_BROKER_POST_START)
|
|
|
|
|
|
|
|
|
|
# Start broadcast loop
|
|
|
|
|
self._broadcast_task = asyncio.ensure_future(self._broadcast_loop())
|
|
|
|
|
|
|
|
|
|
self.logger.debug("Broker started")
|
|
|
|
|
except Exception as e:
|
|
|
|
|
self.logger.exception("Broker startup failed")
|
|
|
|
@ -315,48 +253,128 @@ class Broker:
|
|
|
|
|
msg = f"Broker instance can't be started: {e}"
|
|
|
|
|
raise BrokerError(msg) from e
|
|
|
|
|
|
|
|
|
|
async def _start_listeners(self) -> None:
|
|
|
|
|
"""Start network listeners based on the configuration."""
|
|
|
|
|
for listener_name, listener in self.listeners_config.items():
|
|
|
|
|
if "bind" not in listener:
|
|
|
|
|
self.logger.debug(f"Listener configuration '{listener_name}' is not bound")
|
|
|
|
|
continue
|
|
|
|
|
|
|
|
|
|
max_connections = listener.get("max_connections", -1)
|
|
|
|
|
ssl_context = self._create_ssl_context(listener) if listener.get("ssl", False) else None
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
address, port = self._split_bindaddr_port(listener["bind"], DEFAULT_PORTS[listener["type"]])
|
|
|
|
|
except ValueError as e:
|
|
|
|
|
msg = f"Invalid port value in bind value: {listener['bind']}"
|
|
|
|
|
raise BrokerError(msg) from e
|
|
|
|
|
|
|
|
|
|
instance = await self._create_server_instance(listener_name, listener["type"], address, port, ssl_context)
|
|
|
|
|
self._servers[listener_name] = Server(listener_name, instance, max_connections)
|
|
|
|
|
|
|
|
|
|
self.logger.info(f"Listener '{listener_name}' bind to {listener['bind']} (max_connections={max_connections})")
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
def _split_bindaddr_port(cls, port_str: str, default_port: int) -> tuple[str | None, int]:
|
|
|
|
|
"""Split an address:port pair into separate IP address and port. with IPv6 special-case handling.
|
|
|
|
|
|
|
|
|
|
- Address can be specified using one of the following methods:
|
|
|
|
|
- empty string - all interfaces default port
|
|
|
|
|
- 1883 - Port number only (listen all interfaces)
|
|
|
|
|
- :1883 - Port number only (listen all interfaces)
|
|
|
|
|
- 0.0.0.0:1883 - IPv4 address
|
|
|
|
|
- [::]:1883 - IPv6 address
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
def _parse_port(port_str: str) -> int:
|
|
|
|
|
port_str = port_str.removeprefix(":")
|
|
|
|
|
|
|
|
|
|
if not port_str:
|
|
|
|
|
return default_port
|
|
|
|
|
|
|
|
|
|
return int(port_str)
|
|
|
|
|
|
|
|
|
|
if port_str.startswith("["): # IPv6 literal
|
|
|
|
|
try:
|
|
|
|
|
addr_end = port_str.index("]")
|
|
|
|
|
except ValueError as e:
|
|
|
|
|
msg = "Expecting '[' to be followed by ']'"
|
|
|
|
|
raise ValueError(msg) from e
|
|
|
|
|
|
|
|
|
|
return (port_str[0 : addr_end + 1], _parse_port(port_str[addr_end + 1 :]))
|
|
|
|
|
|
|
|
|
|
if ":" in port_str:
|
|
|
|
|
address, port_str = port_str.rsplit(":", 1)
|
|
|
|
|
return (address or None, _parse_port(port_str))
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
return (None, _parse_port(port_str))
|
|
|
|
|
except ValueError:
|
|
|
|
|
return (port_str, default_port)
|
|
|
|
|
|
|
|
|
|
def _create_ssl_context(self, listener: dict[str, Any]) -> ssl.SSLContext:
|
|
|
|
|
"""Create an SSL context for a listener."""
|
|
|
|
|
try:
|
|
|
|
|
ssl_context = ssl.create_default_context(
|
|
|
|
|
ssl.Purpose.CLIENT_AUTH,
|
|
|
|
|
cafile=listener.get("cafile"),
|
|
|
|
|
capath=listener.get("capath"),
|
|
|
|
|
cadata=listener.get("cadata"),
|
|
|
|
|
)
|
|
|
|
|
ssl_context.load_cert_chain(listener["certfile"], listener["keyfile"])
|
|
|
|
|
ssl_context.verify_mode = ssl.CERT_OPTIONAL
|
|
|
|
|
return ssl_context
|
|
|
|
|
except KeyError as ke:
|
|
|
|
|
msg = f"'certfile' or 'keyfile' configuration parameter missing: {ke}"
|
|
|
|
|
raise BrokerError(msg) from ke
|
|
|
|
|
except FileNotFoundError as fnfe:
|
|
|
|
|
msg = f"Can't read cert files '{listener['certfile']}' or '{listener['keyfile']}' : {fnfe}"
|
|
|
|
|
raise BrokerError(msg) from fnfe
|
|
|
|
|
|
|
|
|
|
async def _create_server_instance(
|
|
|
|
|
self,
|
|
|
|
|
listener_name: str,
|
|
|
|
|
listener_type: str,
|
|
|
|
|
address: str | None,
|
|
|
|
|
port: int,
|
|
|
|
|
ssl_context: ssl.SSLContext | None,
|
|
|
|
|
) -> asyncio.Server | websockets.asyncio.server.Server:
|
|
|
|
|
"""Create a server instance for a listener."""
|
|
|
|
|
if listener_type == "tcp":
|
|
|
|
|
return await asyncio.start_server(
|
|
|
|
|
partial(self.stream_connected, listener_name=listener_name),
|
|
|
|
|
address,
|
|
|
|
|
port,
|
|
|
|
|
reuse_address=True,
|
|
|
|
|
ssl=ssl_context,
|
|
|
|
|
)
|
|
|
|
|
if listener_type == "ws":
|
|
|
|
|
return await websockets.serve(
|
|
|
|
|
partial(self.ws_connected, listener_name=listener_name),
|
|
|
|
|
address,
|
|
|
|
|
port,
|
|
|
|
|
ssl=ssl_context,
|
|
|
|
|
subprotocols=[websockets.Subprotocol("mqtt")],
|
|
|
|
|
)
|
|
|
|
|
msg = f"Unsupported listener type: {listener_type}"
|
|
|
|
|
raise BrokerError(msg)
|
|
|
|
|
|
|
|
|
|
async def shutdown(self) -> None:
|
|
|
|
|
"""Stop broker instance."""
|
|
|
|
|
try:
|
|
|
|
|
# # Wait for all in-flight tasks to complete before stopping session handlers
|
|
|
|
|
# for client_id, (session, handler) in self._sessions.items():
|
|
|
|
|
# if handler:
|
|
|
|
|
# self.logger.debug(f"Waiting for in-flight tasks to complete for session {client_id}")
|
|
|
|
|
# if session.inflight_out:
|
|
|
|
|
# # Directly use asyncio.sleep or another async operation in the loop
|
|
|
|
|
# await asyncio.gather(
|
|
|
|
|
# *(asyncio.sleep(0) for _ in session.inflight_out.values()),
|
|
|
|
|
# return_exceptions=True,
|
|
|
|
|
# )
|
|
|
|
|
|
|
|
|
|
# Stop all session handlers
|
|
|
|
|
for client_id, (_, handler) in self._sessions.items():
|
|
|
|
|
if handler:
|
|
|
|
|
self.logger.debug(f"Stopping handler for session {client_id}")
|
|
|
|
|
await self._stop_handler(handler)
|
|
|
|
|
|
|
|
|
|
# Clear subscriptions
|
|
|
|
|
for topic, subscriptions in self._subscriptions.items():
|
|
|
|
|
self.logger.debug(f"Clearing subscriptions for topic '{topic}'")
|
|
|
|
|
for session, _ in subscriptions:
|
|
|
|
|
self._del_subscription(topic, session)
|
|
|
|
|
self._subscriptions.clear()
|
|
|
|
|
|
|
|
|
|
# Clear retained messages
|
|
|
|
|
if self._retained_messages:
|
|
|
|
|
self.logger.debug(f"Clearing {len(self._retained_messages)} retained messages")
|
|
|
|
|
self._retained_messages.clear()
|
|
|
|
|
|
|
|
|
|
self._sessions.clear()
|
|
|
|
|
self.transitions.shutdown()
|
|
|
|
|
except (MachineError, ValueError) as exc:
|
|
|
|
|
# Backwards compat: MachineError is raised by transitions < 0.5.0.
|
|
|
|
|
self.logger.debug(f"Invalid method call at this moment: {exc}")
|
|
|
|
|
raise
|
|
|
|
|
|
|
|
|
|
self.logger.info("Shutting down broker...")
|
|
|
|
|
# Fire broker_shutdown event to plugins
|
|
|
|
|
await self.plugins_manager.fire_event(EVENT_BROKER_PRE_SHUTDOWN)
|
|
|
|
|
|
|
|
|
|
# Cleanup all sessions
|
|
|
|
|
for client_id in list(self._sessions.keys()):
|
|
|
|
|
await self._cleanup_session(client_id)
|
|
|
|
|
|
|
|
|
|
# Clear retained messages
|
|
|
|
|
self.logger.debug(f"Clearing {len(self._retained_messages)} retained messages")
|
|
|
|
|
self._retained_messages.clear()
|
|
|
|
|
|
|
|
|
|
self.transitions.shutdown()
|
|
|
|
|
|
|
|
|
|
await self._shutdown_broadcast_loop()
|
|
|
|
|
|
|
|
|
|
for server in self._servers.values():
|
|
|
|
@ -372,59 +390,89 @@ class Broker:
|
|
|
|
|
await self.plugins_manager.fire_event(EVENT_BROKER_POST_SHUTDOWN)
|
|
|
|
|
self.transitions.stopping_success()
|
|
|
|
|
|
|
|
|
|
async def _cleanup_session(self, client_id: str) -> None:
|
|
|
|
|
"""Centralized cleanup logic for a session."""
|
|
|
|
|
session, handler = self._sessions.pop(client_id, (None, None))
|
|
|
|
|
|
|
|
|
|
if handler:
|
|
|
|
|
self.logger.debug(f"Stopping handler for session {client_id}")
|
|
|
|
|
await self._stop_handler(handler)
|
|
|
|
|
if session:
|
|
|
|
|
self.logger.debug(f"Clearing all subscriptions for session {client_id}")
|
|
|
|
|
await self._del_all_subscriptions(session)
|
|
|
|
|
session.clear_queues()
|
|
|
|
|
|
|
|
|
|
async def internal_message_broadcast(self, topic: str, data: bytes, qos: int | None = None) -> None:
|
|
|
|
|
return await self._broadcast_message(None, topic, data, qos)
|
|
|
|
|
|
|
|
|
|
async def ws_connected(self, websocket: ServerConnection, listener_name: str) -> None:
|
|
|
|
|
await self.client_connected(listener_name, WebSocketsReader(websocket), WebSocketsWriter(websocket))
|
|
|
|
|
await self._client_connected(listener_name, WebSocketsReader(websocket), WebSocketsWriter(websocket))
|
|
|
|
|
|
|
|
|
|
async def stream_connected(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter, listener_name: str) -> None:
|
|
|
|
|
await self.client_connected(listener_name, StreamReaderAdapter(reader), StreamWriterAdapter(writer))
|
|
|
|
|
await self._client_connected(listener_name, StreamReaderAdapter(reader), StreamWriterAdapter(writer))
|
|
|
|
|
|
|
|
|
|
async def client_connected(self, listener_name: str, reader: ReaderAdapter, writer: WriterAdapter) -> None:
|
|
|
|
|
# Wait for connection available on listener
|
|
|
|
|
server = self._servers.get(listener_name, None)
|
|
|
|
|
async def _client_connected(self, listener_name: str, reader: ReaderAdapter, writer: WriterAdapter) -> None:
|
|
|
|
|
"""Handle a new client connection."""
|
|
|
|
|
server = self._servers.get(listener_name)
|
|
|
|
|
if not server:
|
|
|
|
|
msg = f"Invalid listener name '{listener_name}'"
|
|
|
|
|
raise BrokerError(msg)
|
|
|
|
|
await server.acquire_connection()
|
|
|
|
|
|
|
|
|
|
await server.acquire_connection()
|
|
|
|
|
remote_info = writer.get_peer_info()
|
|
|
|
|
if remote_info is None:
|
|
|
|
|
self.logger.warning("remote info could not get from peer info")
|
|
|
|
|
self.logger.warning("Remote info could not be retrieved from peer info")
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
remote_address, remote_port = remote_info
|
|
|
|
|
self.logger.info(f"Connection from {remote_address}:{remote_port} on listener '{listener_name}'")
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
handler, client_session = await self._initialize_client_session(reader, writer, remote_address, remote_port)
|
|
|
|
|
except (AMQTTError, MQTTError, NoDataError) as exc:
|
|
|
|
|
self.logger.warning(f"Failed to initialize client session: {exc}")
|
|
|
|
|
server.release_connection()
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
await self._handle_client_session(reader, writer, client_session, handler, server, listener_name)
|
|
|
|
|
except (AMQTTError, MQTTError, NoDataError) as exc:
|
|
|
|
|
self.logger.warning(f"Error while handling client session: {exc}")
|
|
|
|
|
finally:
|
|
|
|
|
self.logger.debug(f"{client_session.client_id} Client disconnected")
|
|
|
|
|
server.release_connection()
|
|
|
|
|
|
|
|
|
|
async def _initialize_client_session(
|
|
|
|
|
self,
|
|
|
|
|
reader: ReaderAdapter,
|
|
|
|
|
writer: WriterAdapter,
|
|
|
|
|
remote_address: str,
|
|
|
|
|
remote_port: int,
|
|
|
|
|
) -> tuple[BrokerProtocolHandler, Session]:
|
|
|
|
|
"""Initialize a client session and protocol handler."""
|
|
|
|
|
# Wait for first packet and expect a CONNECT
|
|
|
|
|
try:
|
|
|
|
|
handler, client_session = await BrokerProtocolHandler.init_from_connect(reader, writer, self.plugins_manager)
|
|
|
|
|
except AMQTTError as exc:
|
|
|
|
|
self.logger.warning(
|
|
|
|
|
f"[MQTT-3.1.0-1] {format_client_message(address=remote_address, port=remote_port)}:"
|
|
|
|
|
f"Can't read first packet an CONNECT: {exc}",
|
|
|
|
|
f" Can't read first packet as CONNECT: {exc}",
|
|
|
|
|
)
|
|
|
|
|
self.logger.debug("Connection closed")
|
|
|
|
|
server.release_connection()
|
|
|
|
|
return
|
|
|
|
|
except MQTTError:
|
|
|
|
|
raise AMQTTError(exc) from exc
|
|
|
|
|
except MQTTError as exc:
|
|
|
|
|
self.logger.exception(
|
|
|
|
|
f"Invalid connection from {format_client_message(address=remote_address, port=remote_port)}",
|
|
|
|
|
)
|
|
|
|
|
await writer.close()
|
|
|
|
|
server.release_connection()
|
|
|
|
|
self.logger.debug("Connection closed")
|
|
|
|
|
return
|
|
|
|
|
except NoDataError as ne:
|
|
|
|
|
self.logger.error(f"No data from {format_client_message(address=remote_address, port=remote_port)} : {ne}") # noqa: TRY400 # cannot replace with exception else test fails
|
|
|
|
|
server.release_connection()
|
|
|
|
|
return
|
|
|
|
|
raise MQTTError(exc) from exc
|
|
|
|
|
except NoDataError as exc:
|
|
|
|
|
self.logger.error(f"No data from {format_client_message(address=remote_address, port=remote_port)} : {exc}") # noqa: TRY400 # cannot replace with exception else pytest fails
|
|
|
|
|
raise NoDataError(exc) from exc
|
|
|
|
|
|
|
|
|
|
if client_session.clean_session:
|
|
|
|
|
# Delete existing session and create a new one
|
|
|
|
|
if client_session.client_id is not None and client_session.client_id != "":
|
|
|
|
|
self.delete_session(client_session.client_id)
|
|
|
|
|
await self._delete_session(client_session.client_id)
|
|
|
|
|
else:
|
|
|
|
|
client_session.client_id = gen_client_id()
|
|
|
|
|
client_session.parent = 0
|
|
|
|
@ -436,21 +484,32 @@ class Broker:
|
|
|
|
|
else:
|
|
|
|
|
client_session.parent = 0
|
|
|
|
|
|
|
|
|
|
timeout_disconnect_delay = self.config.get("timeout-disconnect-delay", 0)
|
|
|
|
|
if client_session.keep_alive > 0 and isinstance(timeout_disconnect_delay, int):
|
|
|
|
|
client_session.keep_alive += timeout_disconnect_delay
|
|
|
|
|
|
|
|
|
|
self.logger.debug(f"Keep-alive timeout={client_session.keep_alive}")
|
|
|
|
|
return handler, client_session
|
|
|
|
|
|
|
|
|
|
async def _handle_client_session(
|
|
|
|
|
self,
|
|
|
|
|
reader: ReaderAdapter,
|
|
|
|
|
writer: WriterAdapter,
|
|
|
|
|
client_session: Session,
|
|
|
|
|
handler: BrokerProtocolHandler,
|
|
|
|
|
server: Server,
|
|
|
|
|
listener_name: str,
|
|
|
|
|
) -> None:
|
|
|
|
|
"""Handle the lifecycle of a client session."""
|
|
|
|
|
authenticated = await self._authenticate(client_session, self.listeners_config[listener_name])
|
|
|
|
|
if not authenticated:
|
|
|
|
|
await writer.close()
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
if client_session.client_id is None:
|
|
|
|
|
msg = "Client ID was not correctly created/set."
|
|
|
|
|
raise BrokerError(msg)
|
|
|
|
|
|
|
|
|
|
timeout_disconnect_delay = self.config.get("timeout-disconnect-delay")
|
|
|
|
|
if client_session.keep_alive > 0 and isinstance(timeout_disconnect_delay, int):
|
|
|
|
|
client_session.keep_alive += timeout_disconnect_delay
|
|
|
|
|
self.logger.debug(f"Keep-alive timeout={client_session.keep_alive}")
|
|
|
|
|
|
|
|
|
|
authenticated = await self.authenticate(client_session, self.listeners_config[listener_name])
|
|
|
|
|
if not authenticated:
|
|
|
|
|
await writer.close()
|
|
|
|
|
server.release_connection() # Delete client from connections list
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
while True:
|
|
|
|
|
try:
|
|
|
|
|
client_session.transitions.connect()
|
|
|
|
@ -469,14 +528,17 @@ class Broker:
|
|
|
|
|
self._sessions[client_session.client_id] = (client_session, handler)
|
|
|
|
|
|
|
|
|
|
await handler.mqtt_connack_authorize(authenticated)
|
|
|
|
|
|
|
|
|
|
await self.plugins_manager.fire_event(EVENT_BROKER_CLIENT_CONNECTED, client_id=client_session.client_id)
|
|
|
|
|
|
|
|
|
|
self.logger.debug(f"{client_session.client_id} Start messages handling")
|
|
|
|
|
await handler.start()
|
|
|
|
|
self.logger.debug(f"Retained messages queue size: {client_session.retained_messages.qsize()}")
|
|
|
|
|
await self.publish_session_retained_messages(client_session)
|
|
|
|
|
await self._publish_session_retained_messages(client_session)
|
|
|
|
|
|
|
|
|
|
await self._client_message_loop(client_session, handler)
|
|
|
|
|
|
|
|
|
|
async def _client_message_loop(self, client_session: Session, handler: BrokerProtocolHandler) -> None:
|
|
|
|
|
"""Run the main loop to handle client messages."""
|
|
|
|
|
# Init and start loop for handling client messages (publish, subscribe/unsubscribe, disconnect)
|
|
|
|
|
disconnect_waiter = asyncio.ensure_future(handler.wait_disconnect())
|
|
|
|
|
subscribe_waiter = asyncio.ensure_future(handler.get_next_pending_subscription())
|
|
|
|
@ -495,141 +557,160 @@ class Broker:
|
|
|
|
|
],
|
|
|
|
|
return_when=asyncio.FIRST_COMPLETED,
|
|
|
|
|
)
|
|
|
|
|
if disconnect_waiter in done:
|
|
|
|
|
result = disconnect_waiter.result()
|
|
|
|
|
self.logger.debug(f"{client_session.client_id} Result from wait_disconnect: {result}")
|
|
|
|
|
if result is None:
|
|
|
|
|
self.logger.debug(f"Will flag: {client_session.will_flag}")
|
|
|
|
|
if client_session.will_flag:
|
|
|
|
|
self.logger.debug(
|
|
|
|
|
f"Client {format_client_message(client_session)} disconnected abnormally, sending will message",
|
|
|
|
|
)
|
|
|
|
|
await self._broadcast_message(
|
|
|
|
|
client_session,
|
|
|
|
|
client_session.will_topic,
|
|
|
|
|
client_session.will_message,
|
|
|
|
|
client_session.will_qos,
|
|
|
|
|
)
|
|
|
|
|
if client_session.will_retain:
|
|
|
|
|
self.retain_message(
|
|
|
|
|
client_session,
|
|
|
|
|
client_session.will_topic,
|
|
|
|
|
client_session.will_message,
|
|
|
|
|
client_session.will_qos,
|
|
|
|
|
)
|
|
|
|
|
self.logger.debug(f"{client_session.client_id} Disconnecting session")
|
|
|
|
|
await self._stop_handler(handler)
|
|
|
|
|
client_session.transitions.disconnect()
|
|
|
|
|
await self.plugins_manager.fire_event(
|
|
|
|
|
EVENT_BROKER_CLIENT_DISCONNECTED,
|
|
|
|
|
client_id=client_session.client_id,
|
|
|
|
|
)
|
|
|
|
|
connected = False
|
|
|
|
|
# Recreate the disconnect_waiter task after processing
|
|
|
|
|
disconnect_waiter = asyncio.ensure_future(handler.wait_disconnect())
|
|
|
|
|
if unsubscribe_waiter in done:
|
|
|
|
|
self.logger.debug(f"{client_session.client_id} handling unsubscription")
|
|
|
|
|
unsubscription = unsubscribe_waiter.result()
|
|
|
|
|
for topic in unsubscription.topics:
|
|
|
|
|
self._del_subscription(topic, client_session)
|
|
|
|
|
await self.plugins_manager.fire_event(
|
|
|
|
|
EVENT_BROKER_CLIENT_UNSUBSCRIBED,
|
|
|
|
|
client_id=client_session.client_id,
|
|
|
|
|
topic=topic,
|
|
|
|
|
)
|
|
|
|
|
await handler.mqtt_acknowledge_unsubscription(unsubscription.packet_id)
|
|
|
|
|
# Recreate the unsubscribe_waiter task
|
|
|
|
|
unsubscribe_waiter = asyncio.ensure_future(handler.get_next_pending_unsubscription())
|
|
|
|
|
if subscribe_waiter in done:
|
|
|
|
|
self.logger.debug(f"{client_session.client_id} handling subscription")
|
|
|
|
|
subscriptions = subscribe_waiter.result()
|
|
|
|
|
return_codes = [
|
|
|
|
|
await self.add_subscription(subscription, client_session) for subscription in subscriptions.topics
|
|
|
|
|
]
|
|
|
|
|
|
|
|
|
|
await handler.mqtt_acknowledge_subscription(subscriptions.packet_id, return_codes)
|
|
|
|
|
for index, subscription in enumerate(subscriptions.topics):
|
|
|
|
|
if return_codes[index] != AMQTT_MAGIC_VALUE_RET_SUBSCRIBED:
|
|
|
|
|
await self.plugins_manager.fire_event(
|
|
|
|
|
EVENT_BROKER_CLIENT_SUBSCRIBED,
|
|
|
|
|
client_id=client_session.client_id,
|
|
|
|
|
topic=subscription[0],
|
|
|
|
|
qos=subscription[1],
|
|
|
|
|
)
|
|
|
|
|
await self.publish_retained_messages_for_subscription(subscription, client_session)
|
|
|
|
|
# Recreate the subscribe_waiter task
|
|
|
|
|
if disconnect_waiter in done:
|
|
|
|
|
connected = await self._handle_disconnect(client_session, handler, disconnect_waiter)
|
|
|
|
|
disconnect_waiter = asyncio.ensure_future(handler.wait_disconnect())
|
|
|
|
|
|
|
|
|
|
if subscribe_waiter in done:
|
|
|
|
|
await self._handle_subscription(client_session, handler, subscribe_waiter)
|
|
|
|
|
subscribe_waiter = asyncio.ensure_future(handler.get_next_pending_subscription())
|
|
|
|
|
self.logger.debug(repr(self._subscriptions))
|
|
|
|
|
|
|
|
|
|
if unsubscribe_waiter in done:
|
|
|
|
|
await self._handle_unsubscription(client_session, handler, unsubscribe_waiter)
|
|
|
|
|
unsubscribe_waiter = asyncio.ensure_future(handler.get_next_pending_unsubscription())
|
|
|
|
|
|
|
|
|
|
if wait_deliver in done:
|
|
|
|
|
if self.logger.isEnabledFor(logging.DEBUG):
|
|
|
|
|
self.logger.debug(f"{client_session.client_id} handling message delivery")
|
|
|
|
|
app_message = wait_deliver.result()
|
|
|
|
|
|
|
|
|
|
if app_message is None:
|
|
|
|
|
self.logger.debug("app_message was empty!")
|
|
|
|
|
continue
|
|
|
|
|
|
|
|
|
|
if not app_message.topic:
|
|
|
|
|
self.logger.warning(
|
|
|
|
|
f"[MQTT-4.7.3-1] - {client_session.client_id}"
|
|
|
|
|
" invalid TOPIC sent in PUBLISH message, closing connection",
|
|
|
|
|
)
|
|
|
|
|
if not await self._handle_message_delivery(client_session, handler, wait_deliver):
|
|
|
|
|
break
|
|
|
|
|
if "#" in app_message.topic or "+" in app_message.topic:
|
|
|
|
|
self.logger.warning(
|
|
|
|
|
f"[MQTT-3.3.2-2] - {client_session.client_id}"
|
|
|
|
|
" invalid TOPIC sent in PUBLISH message, closing connection",
|
|
|
|
|
)
|
|
|
|
|
break
|
|
|
|
|
|
|
|
|
|
permitted = await self.topic_filtering(client_session, topic=app_message.topic, action=Action.PUBLISH)
|
|
|
|
|
if not permitted:
|
|
|
|
|
self.logger.info(
|
|
|
|
|
f"{client_session.client_id} forbidden TOPIC {app_message.topic} sent in PUBLISH message.",
|
|
|
|
|
)
|
|
|
|
|
else:
|
|
|
|
|
await self.plugins_manager.fire_event(
|
|
|
|
|
EVENT_BROKER_MESSAGE_RECEIVED,
|
|
|
|
|
client_id=client_session.client_id,
|
|
|
|
|
message=app_message,
|
|
|
|
|
)
|
|
|
|
|
await self._broadcast_message(client_session, app_message.topic, app_message.data)
|
|
|
|
|
if app_message.publish_packet is not None and app_message.publish_packet.retain_flag:
|
|
|
|
|
self.retain_message(client_session, app_message.topic, app_message.data, app_message.qos)
|
|
|
|
|
# Recreate the wait_deliver task
|
|
|
|
|
wait_deliver = asyncio.ensure_future(handler.mqtt_deliver_next_message())
|
|
|
|
|
|
|
|
|
|
except asyncio.CancelledError:
|
|
|
|
|
self.logger.debug("Client loop cancelled")
|
|
|
|
|
break
|
|
|
|
|
|
|
|
|
|
disconnect_waiter.cancel()
|
|
|
|
|
subscribe_waiter.cancel()
|
|
|
|
|
unsubscribe_waiter.cancel()
|
|
|
|
|
wait_deliver.cancel()
|
|
|
|
|
|
|
|
|
|
self.logger.debug(f"{client_session.client_id} Client disconnected")
|
|
|
|
|
server.release_connection()
|
|
|
|
|
async def _handle_disconnect(
|
|
|
|
|
self,
|
|
|
|
|
client_session: Session,
|
|
|
|
|
handler: BrokerProtocolHandler,
|
|
|
|
|
disconnect_waiter: asyncio.Future[Any],
|
|
|
|
|
) -> bool:
|
|
|
|
|
"""Handle client disconnection."""
|
|
|
|
|
result = disconnect_waiter.result()
|
|
|
|
|
self.logger.debug(f"{client_session.client_id} Result from wait_disconnect: {result}")
|
|
|
|
|
if result is None:
|
|
|
|
|
self.logger.debug(f"Will flag: {client_session.will_flag}")
|
|
|
|
|
if client_session.will_flag:
|
|
|
|
|
self.logger.debug(
|
|
|
|
|
f"Client {format_client_message(client_session)} disconnected abnormally, sending will message",
|
|
|
|
|
)
|
|
|
|
|
await self._broadcast_message(
|
|
|
|
|
client_session,
|
|
|
|
|
client_session.will_topic,
|
|
|
|
|
client_session.will_message,
|
|
|
|
|
client_session.will_qos,
|
|
|
|
|
)
|
|
|
|
|
if client_session.will_retain:
|
|
|
|
|
self.retain_message(
|
|
|
|
|
client_session,
|
|
|
|
|
client_session.will_topic,
|
|
|
|
|
client_session.will_message,
|
|
|
|
|
client_session.will_qos,
|
|
|
|
|
)
|
|
|
|
|
self.logger.debug(f"{client_session.client_id} Disconnecting session")
|
|
|
|
|
await self._stop_handler(handler)
|
|
|
|
|
client_session.transitions.disconnect()
|
|
|
|
|
await self.plugins_manager.fire_event(EVENT_BROKER_CLIENT_DISCONNECTED, client_id=client_session.client_id)
|
|
|
|
|
return False
|
|
|
|
|
return True
|
|
|
|
|
|
|
|
|
|
async def _init_handler(self, session: Session, reader: ReaderAdapter, writer: WriterAdapter) -> BrokerProtocolHandler:
|
|
|
|
|
"""Create a BrokerProtocolHandler and attach to a session.
|
|
|
|
|
async def _handle_subscription(
|
|
|
|
|
self,
|
|
|
|
|
client_session: Session,
|
|
|
|
|
handler: BrokerProtocolHandler,
|
|
|
|
|
subscribe_waiter: asyncio.Future[Any],
|
|
|
|
|
) -> None:
|
|
|
|
|
"""Handle client subscription."""
|
|
|
|
|
self.logger.debug(f"{client_session.client_id} handling subscription")
|
|
|
|
|
subscriptions = subscribe_waiter.result()
|
|
|
|
|
return_codes = [await self._add_subscription(subscription, client_session) for subscription in subscriptions.topics]
|
|
|
|
|
await handler.mqtt_acknowledge_subscription(subscriptions.packet_id, return_codes)
|
|
|
|
|
for index, subscription in enumerate(subscriptions.topics):
|
|
|
|
|
if return_codes[index] != AMQTT_MAGIC_VALUE_RET_SUBSCRIBED:
|
|
|
|
|
await self.plugins_manager.fire_event(
|
|
|
|
|
EVENT_BROKER_CLIENT_SUBSCRIBED,
|
|
|
|
|
client_id=client_session.client_id,
|
|
|
|
|
topic=subscription[0],
|
|
|
|
|
qos=subscription[1],
|
|
|
|
|
)
|
|
|
|
|
await self._publish_retained_messages_for_subscription(subscription, client_session)
|
|
|
|
|
|
|
|
|
|
:return:
|
|
|
|
|
"""
|
|
|
|
|
handler = BrokerProtocolHandler(self.plugins_manager, loop=self._loop)
|
|
|
|
|
handler.attach(session, reader, writer)
|
|
|
|
|
return handler
|
|
|
|
|
async def _handle_unsubscription(
|
|
|
|
|
self,
|
|
|
|
|
client_session: Session,
|
|
|
|
|
handler: BrokerProtocolHandler,
|
|
|
|
|
unsubscribe_waiter: asyncio.Future[Any],
|
|
|
|
|
) -> None:
|
|
|
|
|
"""Handle client unsubscription."""
|
|
|
|
|
self.logger.debug(f"{client_session.client_id} handling unsubscription")
|
|
|
|
|
unsubscription = unsubscribe_waiter.result()
|
|
|
|
|
for topic in unsubscription.topics:
|
|
|
|
|
self._del_subscription(topic, client_session)
|
|
|
|
|
await self.plugins_manager.fire_event(
|
|
|
|
|
EVENT_BROKER_CLIENT_UNSUBSCRIBED,
|
|
|
|
|
client_id=client_session.client_id,
|
|
|
|
|
topic=topic,
|
|
|
|
|
)
|
|
|
|
|
await handler.mqtt_acknowledge_unsubscription(unsubscription.packet_id)
|
|
|
|
|
|
|
|
|
|
async def _handle_message_delivery(
|
|
|
|
|
self,
|
|
|
|
|
client_session: Session,
|
|
|
|
|
handler: BrokerProtocolHandler,
|
|
|
|
|
wait_deliver: asyncio.Future[Any],
|
|
|
|
|
) -> bool:
|
|
|
|
|
"""Handle message delivery to the client."""
|
|
|
|
|
self.logger.debug(f"{client_session.client_id} handling message delivery")
|
|
|
|
|
app_message = wait_deliver.result()
|
|
|
|
|
|
|
|
|
|
if app_message is None:
|
|
|
|
|
self.logger.debug("app_message was empty!")
|
|
|
|
|
return True
|
|
|
|
|
if not app_message.topic:
|
|
|
|
|
self.logger.warning(
|
|
|
|
|
f"[MQTT-4.7.3-1] - {client_session.client_id} invalid TOPIC sent in PUBLISH message, closing connection",
|
|
|
|
|
)
|
|
|
|
|
return False
|
|
|
|
|
if "#" in app_message.topic or "+" in app_message.topic:
|
|
|
|
|
self.logger.warning(
|
|
|
|
|
f"[MQTT-3.3.2-2] - {client_session.client_id} invalid TOPIC sent in PUBLISH message, closing connection",
|
|
|
|
|
)
|
|
|
|
|
return False
|
|
|
|
|
|
|
|
|
|
permitted = await self._topic_filtering(client_session, topic=app_message.topic, action=Action.PUBLISH)
|
|
|
|
|
if not permitted:
|
|
|
|
|
self.logger.info(f"{client_session.client_id} forbidden TOPIC {app_message.topic} sent in PUBLISH message.")
|
|
|
|
|
else:
|
|
|
|
|
await self.plugins_manager.fire_event(
|
|
|
|
|
EVENT_BROKER_MESSAGE_RECEIVED,
|
|
|
|
|
client_id=client_session.client_id,
|
|
|
|
|
message=app_message,
|
|
|
|
|
)
|
|
|
|
|
await self._broadcast_message(client_session, app_message.topic, app_message.data)
|
|
|
|
|
if app_message.publish_packet and app_message.publish_packet.retain_flag:
|
|
|
|
|
self.retain_message(client_session, app_message.topic, app_message.data, app_message.qos)
|
|
|
|
|
return True
|
|
|
|
|
|
|
|
|
|
# async def _init_handler(self, session: Session, reader: ReaderAdapter, writer: WriterAdapter) -> BrokerProtocolHandler:
|
|
|
|
|
# """Create a BrokerProtocolHandler and attach to a session."""
|
|
|
|
|
# handler = BrokerProtocolHandler(self.plugins_manager, loop=self._loop)
|
|
|
|
|
# handler.attach(session, reader, writer)
|
|
|
|
|
# return handler
|
|
|
|
|
|
|
|
|
|
async def _stop_handler(self, handler: BrokerProtocolHandler) -> None:
|
|
|
|
|
"""Stop a running handler and detach if from the session.
|
|
|
|
|
|
|
|
|
|
:param handler:
|
|
|
|
|
:return:
|
|
|
|
|
"""
|
|
|
|
|
"""Stop a running handler and detach if from the session."""
|
|
|
|
|
try:
|
|
|
|
|
await handler.stop()
|
|
|
|
|
except Exception:
|
|
|
|
|
self.logger.exception("Failed to stop handler")
|
|
|
|
|
|
|
|
|
|
async def authenticate(self, session: Session, _: dict[str, Any]) -> bool:
|
|
|
|
|
async def _authenticate(self, session: Session, _: dict[str, Any]) -> bool:
|
|
|
|
|
"""Call the authenticate method on registered plugins to test user authentication.
|
|
|
|
|
|
|
|
|
|
User is considered authenticated if all plugins called returns True.
|
|
|
|
@ -658,7 +739,49 @@ class Broker:
|
|
|
|
|
# If all plugins returned True, authentication is success
|
|
|
|
|
return auth_result
|
|
|
|
|
|
|
|
|
|
async def topic_filtering(self, session: Session, topic: str, action: Action) -> bool:
|
|
|
|
|
def retain_message(
|
|
|
|
|
self,
|
|
|
|
|
source_session: Session | None,
|
|
|
|
|
topic_name: str | None,
|
|
|
|
|
data: bytes | bytearray | None,
|
|
|
|
|
qos: int | None = None,
|
|
|
|
|
) -> None:
|
|
|
|
|
if data and topic_name is not None:
|
|
|
|
|
# If retained flag set, store the message for further subscriptions
|
|
|
|
|
self.logger.debug(f"Retaining message on topic {topic_name}")
|
|
|
|
|
self._retained_messages[topic_name] = RetainedApplicationMessage(source_session, topic_name, data, qos)
|
|
|
|
|
# [MQTT-3.3.1-10]
|
|
|
|
|
elif topic_name in self._retained_messages:
|
|
|
|
|
self.logger.debug(f"Clearing retained messages for topic '{topic_name}'")
|
|
|
|
|
del self._retained_messages[topic_name]
|
|
|
|
|
|
|
|
|
|
async def _add_subscription(self, subscription: tuple[str, int], session: Session) -> int:
|
|
|
|
|
topic_filter, qos = subscription
|
|
|
|
|
if "#" in topic_filter and not topic_filter.endswith("#"):
|
|
|
|
|
# [MQTT-4.7.1-2] Wildcard character '#' is only allowed as last character in filter
|
|
|
|
|
return 0x80
|
|
|
|
|
if topic_filter != "+" and "+" in topic_filter and ("/+" not in topic_filter and "+/" not in topic_filter):
|
|
|
|
|
# [MQTT-4.7.1-3] + wildcard character must occupy entire level
|
|
|
|
|
return 0x80
|
|
|
|
|
# Check if the client is authorised to connect to the topic
|
|
|
|
|
if not await self._topic_filtering(session, topic_filter, Action.SUBSCRIBE):
|
|
|
|
|
return 0x80
|
|
|
|
|
|
|
|
|
|
# Ensure "max-qos" is an integer before using it
|
|
|
|
|
max_qos = self.config.get("max-qos", qos)
|
|
|
|
|
if not isinstance(max_qos, int):
|
|
|
|
|
max_qos = qos
|
|
|
|
|
|
|
|
|
|
qos = min(qos, max_qos)
|
|
|
|
|
if topic_filter not in self._subscriptions:
|
|
|
|
|
self._subscriptions[topic_filter] = []
|
|
|
|
|
if all(s.client_id != session.client_id for s, _ in self._subscriptions[topic_filter]):
|
|
|
|
|
self._subscriptions[topic_filter].append((session, qos))
|
|
|
|
|
else:
|
|
|
|
|
self.logger.debug(f"Client {format_client_message(session=session)} has already subscribed to {topic_filter}")
|
|
|
|
|
return qos
|
|
|
|
|
|
|
|
|
|
async def _topic_filtering(self, session: Session, topic: str, action: Action) -> bool:
|
|
|
|
|
"""Call the topic_filtering method on registered plugins to check that the subscription is allowed.
|
|
|
|
|
|
|
|
|
|
User is considered allowed if all plugins called return True.
|
|
|
|
@ -690,48 +813,29 @@ class Broker:
|
|
|
|
|
)
|
|
|
|
|
return all(result for result in results.values())
|
|
|
|
|
|
|
|
|
|
def retain_message(
|
|
|
|
|
self,
|
|
|
|
|
source_session: Session | None,
|
|
|
|
|
topic_name: str | None,
|
|
|
|
|
data: bytes | bytearray | None,
|
|
|
|
|
qos: int | None = None,
|
|
|
|
|
) -> None:
|
|
|
|
|
if data and topic_name is not None:
|
|
|
|
|
# If retained flag set, store the message for further subscriptions
|
|
|
|
|
self.logger.debug(f"Retaining message on topic {topic_name}")
|
|
|
|
|
self._retained_messages[topic_name] = RetainedApplicationMessage(source_session, topic_name, data, qos)
|
|
|
|
|
# [MQTT-3.3.1-10]
|
|
|
|
|
elif topic_name in self._retained_messages:
|
|
|
|
|
self.logger.debug(f"Clearing retained messages for topic '{topic_name}'")
|
|
|
|
|
del self._retained_messages[topic_name]
|
|
|
|
|
async def _delete_session(self, client_id: str) -> None:
|
|
|
|
|
"""Delete an existing session data, for example due to clean session set in CONNECT."""
|
|
|
|
|
session = self._sessions.pop(client_id, (None, None))[0]
|
|
|
|
|
|
|
|
|
|
# NOTE: issue #61 remove try block
|
|
|
|
|
async def add_subscription(self, subscription: tuple[str, int], session: Session) -> int:
|
|
|
|
|
topic_filter, qos = subscription
|
|
|
|
|
if "#" in topic_filter and not topic_filter.endswith("#"):
|
|
|
|
|
# [MQTT-4.7.1-2] Wildcard character '#' is only allowed as last character in filter
|
|
|
|
|
return 0x80
|
|
|
|
|
if topic_filter != "+" and "+" in topic_filter and ("/+" not in topic_filter and "+/" not in topic_filter):
|
|
|
|
|
# [MQTT-4.7.1-3] + wildcard character must occupy entire level
|
|
|
|
|
return 0x80
|
|
|
|
|
# Check if the client is authorised to connect to the topic
|
|
|
|
|
if not await self.topic_filtering(session, topic_filter, Action.SUBSCRIBE):
|
|
|
|
|
return 0x80
|
|
|
|
|
if session is None:
|
|
|
|
|
self.logger.debug(f"Delete session : session {client_id} doesn't exist")
|
|
|
|
|
return
|
|
|
|
|
self.logger.debug(f"Deleted existing session {session!r}")
|
|
|
|
|
|
|
|
|
|
# Ensure "max-qos" is an integer before using it
|
|
|
|
|
max_qos = self.config.get("max-qos", qos)
|
|
|
|
|
if not isinstance(max_qos, int):
|
|
|
|
|
max_qos = qos
|
|
|
|
|
# Delete subscriptions
|
|
|
|
|
self.logger.debug(f"Deleting session {session!r} subscriptions")
|
|
|
|
|
await self._del_all_subscriptions(session)
|
|
|
|
|
session.clear_queues()
|
|
|
|
|
|
|
|
|
|
qos = min(qos, max_qos)
|
|
|
|
|
if topic_filter not in self._subscriptions:
|
|
|
|
|
self._subscriptions[topic_filter] = []
|
|
|
|
|
if all(s.client_id != session.client_id for s, _ in self._subscriptions[topic_filter]):
|
|
|
|
|
self._subscriptions[topic_filter].append((session, qos))
|
|
|
|
|
else:
|
|
|
|
|
self.logger.debug(f"Client {format_client_message(session=session)} has already subscribed to {topic_filter}")
|
|
|
|
|
return qos
|
|
|
|
|
async def _del_all_subscriptions(self, session: Session) -> None:
|
|
|
|
|
"""Delete all topic subscriptions for a given session."""
|
|
|
|
|
filter_queue: deque[str] = deque()
|
|
|
|
|
for topic in self._subscriptions:
|
|
|
|
|
if self._del_subscription(topic, session):
|
|
|
|
|
filter_queue.append(topic)
|
|
|
|
|
for topic in filter_queue:
|
|
|
|
|
if not self._subscriptions[topic]:
|
|
|
|
|
del self._subscriptions[topic]
|
|
|
|
|
|
|
|
|
|
def _del_subscription(self, a_filter: str, session: Session) -> int:
|
|
|
|
|
"""Delete a session subscription on a given topic.
|
|
|
|
@ -752,32 +856,9 @@ class Broker:
|
|
|
|
|
deleted += 1
|
|
|
|
|
break
|
|
|
|
|
except KeyError:
|
|
|
|
|
# Unsubscribe topic not found in current subscribed topics
|
|
|
|
|
pass
|
|
|
|
|
self.logger.debug(f"Unsubscription on topic '{a_filter}' for client {format_client_message(session=session)}")
|
|
|
|
|
return deleted
|
|
|
|
|
|
|
|
|
|
def _del_all_subscriptions(self, session: Session) -> None:
|
|
|
|
|
"""Delete all topic subscriptions for a given session.
|
|
|
|
|
|
|
|
|
|
:param session:
|
|
|
|
|
:return:
|
|
|
|
|
"""
|
|
|
|
|
filter_queue: deque[str] = deque()
|
|
|
|
|
for topic in self._subscriptions:
|
|
|
|
|
if self._del_subscription(topic, session):
|
|
|
|
|
filter_queue.append(topic)
|
|
|
|
|
for topic in filter_queue:
|
|
|
|
|
if not self._subscriptions[topic]:
|
|
|
|
|
del self._subscriptions[topic]
|
|
|
|
|
|
|
|
|
|
def matches(self, topic: str, a_filter: str) -> bool:
|
|
|
|
|
if "#" not in a_filter and "+" not in a_filter:
|
|
|
|
|
# if filter doesn't contain wildcard, return exact match
|
|
|
|
|
return a_filter == topic
|
|
|
|
|
# else use regex
|
|
|
|
|
match_pattern = re.compile(re.escape(a_filter).replace("\\#", "?.*").replace("\\+", "[^/]*").lstrip("?"))
|
|
|
|
|
return bool(match_pattern.fullmatch(topic))
|
|
|
|
|
|
|
|
|
|
async def _broadcast_loop(self) -> None:
|
|
|
|
|
"""Run the main loop to broadcast messages."""
|
|
|
|
|
running_tasks: deque[asyncio.Task[OutgoingApplicationMessage]] = self._tasks_queue
|
|
|
|
@ -824,7 +905,7 @@ class Broker:
|
|
|
|
|
continue
|
|
|
|
|
|
|
|
|
|
# Skip all subscriptions which do not match the topic
|
|
|
|
|
if not self.matches(broadcast["topic"], k_filter):
|
|
|
|
|
if not self._matches(broadcast["topic"], k_filter):
|
|
|
|
|
self.logger.debug(f"Topic '{broadcast['topic']}' does not match filter '{k_filter}'")
|
|
|
|
|
continue
|
|
|
|
|
|
|
|
|
@ -892,7 +973,7 @@ class Broker:
|
|
|
|
|
broadcast["qos"] = force_qos
|
|
|
|
|
await self._broadcast_queue.put(broadcast)
|
|
|
|
|
|
|
|
|
|
async def publish_session_retained_messages(self, session: Session) -> None:
|
|
|
|
|
async def _publish_session_retained_messages(self, session: Session) -> None:
|
|
|
|
|
self.logger.debug(
|
|
|
|
|
f"Publishing {session.retained_messages.qsize()}"
|
|
|
|
|
f" messages retained for session {format_client_message(session=session)}",
|
|
|
|
@ -910,7 +991,7 @@ class Broker:
|
|
|
|
|
if publish_tasks:
|
|
|
|
|
await asyncio.wait(publish_tasks)
|
|
|
|
|
|
|
|
|
|
async def publish_retained_messages_for_subscription(self, subscription: tuple[str, int], session: Session) -> None:
|
|
|
|
|
async def _publish_retained_messages_for_subscription(self, subscription: tuple[str, int], session: Session) -> None:
|
|
|
|
|
self.logger.debug(
|
|
|
|
|
f"Begin broadcasting messages retained due to subscription on '{subscription[0]}'"
|
|
|
|
|
f" from {format_client_message(session=session)}",
|
|
|
|
@ -920,7 +1001,7 @@ class Broker:
|
|
|
|
|
topic_filter, qos = subscription
|
|
|
|
|
for topic, retained in self._retained_messages.items():
|
|
|
|
|
self.logger.debug(f"matching : {topic} {topic_filter}")
|
|
|
|
|
if self.matches(topic, topic_filter):
|
|
|
|
|
if self._matches(topic, topic_filter):
|
|
|
|
|
self.logger.debug(f"{topic} and {topic_filter} match")
|
|
|
|
|
handler = self._get_handler(session)
|
|
|
|
|
if handler:
|
|
|
|
@ -936,62 +1017,16 @@ class Broker:
|
|
|
|
|
f" from {format_client_message(session=session)}",
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
def delete_session(self, client_id: str) -> None:
|
|
|
|
|
"""Delete an existing session data, for example due to clean session set in CONNECT."""
|
|
|
|
|
session = self._sessions.pop(client_id, (None, None))[0]
|
|
|
|
|
|
|
|
|
|
if session is None:
|
|
|
|
|
self.logger.debug(f"Delete session : session {client_id} doesn't exist")
|
|
|
|
|
return
|
|
|
|
|
self.logger.debug(f"Deleted existing session {session!r}")
|
|
|
|
|
|
|
|
|
|
# Delete subscriptions
|
|
|
|
|
self.logger.debug(f"Deleting session {session!r} subscriptions")
|
|
|
|
|
self._del_all_subscriptions(session)
|
|
|
|
|
session.clear_queues()
|
|
|
|
|
def _matches(self, topic: str, a_filter: str) -> bool:
|
|
|
|
|
if "#" not in a_filter and "+" not in a_filter:
|
|
|
|
|
# if filter doesn't contain wildcard, return exact match
|
|
|
|
|
return a_filter == topic
|
|
|
|
|
# else use regex
|
|
|
|
|
match_pattern = re.compile(re.escape(a_filter).replace("\\#", "?.*").replace("\\+", "[^/]*").lstrip("?"))
|
|
|
|
|
return bool(match_pattern.fullmatch(topic))
|
|
|
|
|
|
|
|
|
|
def _get_handler(self, session: Session) -> BrokerProtocolHandler | None:
|
|
|
|
|
client_id = session.client_id
|
|
|
|
|
if client_id:
|
|
|
|
|
return self._sessions.get(client_id, (None, None))[1]
|
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
def _split_bindaddr_port(cls, port_str: str, default_port: int) -> tuple[str | None, int]:
|
|
|
|
|
"""Split an address:port pair into separate IP address and port. with IPv6 special-case handling.
|
|
|
|
|
|
|
|
|
|
NOTE: issue #72
|
|
|
|
|
|
|
|
|
|
- Address can be specified using one of the following methods:
|
|
|
|
|
- empty string - all interfaces default port
|
|
|
|
|
- 1883 - Port number only (listen all interfaces)
|
|
|
|
|
- :1883 - Port number only (listen all interfaces)
|
|
|
|
|
- 0.0.0.0:1883 - IPv4 address
|
|
|
|
|
- [::]:1883 - IPv6 address
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
def _parse_port(port_str: str) -> int:
|
|
|
|
|
port_str = port_str.removeprefix(":")
|
|
|
|
|
|
|
|
|
|
if not port_str:
|
|
|
|
|
return default_port
|
|
|
|
|
|
|
|
|
|
return int(port_str)
|
|
|
|
|
|
|
|
|
|
if port_str.startswith("["): # IPv6 literal
|
|
|
|
|
try:
|
|
|
|
|
addr_end = port_str.index("]")
|
|
|
|
|
except ValueError as e:
|
|
|
|
|
msg = "Expecting '[' to be followed by ']'"
|
|
|
|
|
raise ValueError(msg) from e
|
|
|
|
|
|
|
|
|
|
return (port_str[0 : addr_end + 1], _parse_port(port_str[addr_end + 1 :]))
|
|
|
|
|
|
|
|
|
|
if ":" in port_str:
|
|
|
|
|
address, port_str = port_str.rsplit(":", 1)
|
|
|
|
|
return (address or None, _parse_port(port_str))
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
return (None, _parse_port(port_str))
|
|
|
|
|
except ValueError:
|
|
|
|
|
return (port_str, default_port)
|
|
|
|
|