kopia lustrzana https://github.com/Yakifo/amqtt
refactor: modularize broker listener startup and improve session management
- Extracted listener startup logic into a separate method `_start_listeners` for better readability and maintainability. - Created helper methods for SSL context creation. - Improved session cleanup during shutdown by centralizing logic in `_cleanup_session`. - Refactored message handling and subscription management into dedicated methods for clarity. - Updated tests to reflect changes in method visibility for matching topics.pull/165/head
rodzic
3e0902cc8b
commit
12d5cb6866
667
amqtt/broker.py
667
amqtt/broker.py
|
@ -242,72 +242,10 @@ class Broker:
|
||||||
|
|
||||||
await self.plugins_manager.fire_event(EVENT_BROKER_PRE_START)
|
await self.plugins_manager.fire_event(EVENT_BROKER_PRE_START)
|
||||||
try:
|
try:
|
||||||
# Start network listeners
|
await self._start_listeners()
|
||||||
for listener_name, listener in self.listeners_config.items():
|
|
||||||
if "bind" not in listener:
|
|
||||||
self.logger.debug(f"Listener configuration '{listener_name}' is not bound")
|
|
||||||
continue
|
|
||||||
|
|
||||||
max_connections = listener.get("max_connections", -1)
|
|
||||||
|
|
||||||
ssl_conection = None
|
|
||||||
ssl_active = listener.get("ssl", False) # accept string "on" / "off" or boolean
|
|
||||||
if isinstance(ssl_active, str):
|
|
||||||
ssl_active = ssl_active.upper() == "ON"
|
|
||||||
|
|
||||||
if ssl_active:
|
|
||||||
try:
|
|
||||||
ssl_conection = ssl.create_default_context(
|
|
||||||
ssl.Purpose.CLIENT_AUTH,
|
|
||||||
cafile=listener.get("cafile"),
|
|
||||||
capath=listener.get("capath"),
|
|
||||||
cadata=listener.get("cadata"),
|
|
||||||
)
|
|
||||||
ssl_conection.load_cert_chain(listener["certfile"], listener["keyfile"])
|
|
||||||
ssl_conection.verify_mode = ssl.CERT_OPTIONAL
|
|
||||||
except KeyError as ke:
|
|
||||||
msg = f"'certfile' or 'keyfile' configuration parameter missing: {ke}"
|
|
||||||
raise BrokerError(msg) from ke
|
|
||||||
except FileNotFoundError as fnfe:
|
|
||||||
msg = f"Can't read cert files '{listener['certfile']}' or '{listener['keyfile']}' : {fnfe}"
|
|
||||||
raise BrokerError(msg) from fnfe
|
|
||||||
|
|
||||||
try:
|
|
||||||
address, port = self._split_bindaddr_port(listener["bind"], DEFAULT_PORTS[listener["type"]])
|
|
||||||
except ValueError as e:
|
|
||||||
msg = f"Invalid port value in bind value: {listener['bind']}"
|
|
||||||
raise BrokerError(msg) from e
|
|
||||||
|
|
||||||
instance: asyncio.Server | websockets.asyncio.server.Server | None = None
|
|
||||||
if listener["type"] == "tcp":
|
|
||||||
cb_partial = partial(self.stream_connected, listener_name=listener_name)
|
|
||||||
instance = await asyncio.start_server(
|
|
||||||
cb_partial,
|
|
||||||
address,
|
|
||||||
port,
|
|
||||||
reuse_address=True,
|
|
||||||
ssl=ssl_conection,
|
|
||||||
)
|
|
||||||
self._servers[listener_name] = Server(listener_name, instance, max_connections)
|
|
||||||
elif listener["type"] == "ws":
|
|
||||||
cb_partial = partial(self.ws_connected, listener_name=listener_name)
|
|
||||||
instance = await websockets.serve(
|
|
||||||
cb_partial,
|
|
||||||
address,
|
|
||||||
port,
|
|
||||||
ssl=ssl_conection,
|
|
||||||
subprotocols=[websockets.Subprotocol("mqtt")],
|
|
||||||
)
|
|
||||||
self._servers[listener_name] = Server(listener_name, instance, max_connections)
|
|
||||||
|
|
||||||
self.logger.info(f"Listener '{listener_name}' bind to {listener['bind']} (max_connections={max_connections})")
|
|
||||||
|
|
||||||
self.transitions.starting_success()
|
self.transitions.starting_success()
|
||||||
await self.plugins_manager.fire_event(EVENT_BROKER_POST_START)
|
await self.plugins_manager.fire_event(EVENT_BROKER_POST_START)
|
||||||
|
|
||||||
# Start broadcast loop
|
|
||||||
self._broadcast_task = asyncio.ensure_future(self._broadcast_loop())
|
self._broadcast_task = asyncio.ensure_future(self._broadcast_loop())
|
||||||
|
|
||||||
self.logger.debug("Broker started")
|
self.logger.debug("Broker started")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.logger.exception("Broker startup failed")
|
self.logger.exception("Broker startup failed")
|
||||||
|
@ -315,47 +253,127 @@ class Broker:
|
||||||
msg = f"Broker instance can't be started: {e}"
|
msg = f"Broker instance can't be started: {e}"
|
||||||
raise BrokerError(msg) from e
|
raise BrokerError(msg) from e
|
||||||
|
|
||||||
|
async def _start_listeners(self) -> None:
|
||||||
|
"""Start network listeners based on the configuration."""
|
||||||
|
for listener_name, listener in self.listeners_config.items():
|
||||||
|
if "bind" not in listener:
|
||||||
|
self.logger.debug(f"Listener configuration '{listener_name}' is not bound")
|
||||||
|
continue
|
||||||
|
|
||||||
|
max_connections = listener.get("max_connections", -1)
|
||||||
|
ssl_context = self._create_ssl_context(listener) if listener.get("ssl", False) else None
|
||||||
|
|
||||||
|
try:
|
||||||
|
address, port = self._split_bindaddr_port(listener["bind"], DEFAULT_PORTS[listener["type"]])
|
||||||
|
except ValueError as e:
|
||||||
|
msg = f"Invalid port value in bind value: {listener['bind']}"
|
||||||
|
raise BrokerError(msg) from e
|
||||||
|
|
||||||
|
instance = await self._create_server_instance(listener_name, listener["type"], address, port, ssl_context)
|
||||||
|
self._servers[listener_name] = Server(listener_name, instance, max_connections)
|
||||||
|
|
||||||
|
self.logger.info(f"Listener '{listener_name}' bind to {listener['bind']} (max_connections={max_connections})")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _split_bindaddr_port(cls, port_str: str, default_port: int) -> tuple[str | None, int]:
|
||||||
|
"""Split an address:port pair into separate IP address and port. with IPv6 special-case handling.
|
||||||
|
|
||||||
|
- Address can be specified using one of the following methods:
|
||||||
|
- empty string - all interfaces default port
|
||||||
|
- 1883 - Port number only (listen all interfaces)
|
||||||
|
- :1883 - Port number only (listen all interfaces)
|
||||||
|
- 0.0.0.0:1883 - IPv4 address
|
||||||
|
- [::]:1883 - IPv6 address
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _parse_port(port_str: str) -> int:
|
||||||
|
port_str = port_str.removeprefix(":")
|
||||||
|
|
||||||
|
if not port_str:
|
||||||
|
return default_port
|
||||||
|
|
||||||
|
return int(port_str)
|
||||||
|
|
||||||
|
if port_str.startswith("["): # IPv6 literal
|
||||||
|
try:
|
||||||
|
addr_end = port_str.index("]")
|
||||||
|
except ValueError as e:
|
||||||
|
msg = "Expecting '[' to be followed by ']'"
|
||||||
|
raise ValueError(msg) from e
|
||||||
|
|
||||||
|
return (port_str[0 : addr_end + 1], _parse_port(port_str[addr_end + 1 :]))
|
||||||
|
|
||||||
|
if ":" in port_str:
|
||||||
|
address, port_str = port_str.rsplit(":", 1)
|
||||||
|
return (address or None, _parse_port(port_str))
|
||||||
|
|
||||||
|
try:
|
||||||
|
return (None, _parse_port(port_str))
|
||||||
|
except ValueError:
|
||||||
|
return (port_str, default_port)
|
||||||
|
|
||||||
|
def _create_ssl_context(self, listener: dict[str, Any]) -> ssl.SSLContext:
|
||||||
|
"""Create an SSL context for a listener."""
|
||||||
|
try:
|
||||||
|
ssl_context = ssl.create_default_context(
|
||||||
|
ssl.Purpose.CLIENT_AUTH,
|
||||||
|
cafile=listener.get("cafile"),
|
||||||
|
capath=listener.get("capath"),
|
||||||
|
cadata=listener.get("cadata"),
|
||||||
|
)
|
||||||
|
ssl_context.load_cert_chain(listener["certfile"], listener["keyfile"])
|
||||||
|
ssl_context.verify_mode = ssl.CERT_OPTIONAL
|
||||||
|
return ssl_context
|
||||||
|
except KeyError as ke:
|
||||||
|
msg = f"'certfile' or 'keyfile' configuration parameter missing: {ke}"
|
||||||
|
raise BrokerError(msg) from ke
|
||||||
|
except FileNotFoundError as fnfe:
|
||||||
|
msg = f"Can't read cert files '{listener['certfile']}' or '{listener['keyfile']}' : {fnfe}"
|
||||||
|
raise BrokerError(msg) from fnfe
|
||||||
|
|
||||||
|
async def _create_server_instance(
|
||||||
|
self,
|
||||||
|
listener_name: str,
|
||||||
|
listener_type: str,
|
||||||
|
address: str | None,
|
||||||
|
port: int,
|
||||||
|
ssl_context: ssl.SSLContext | None,
|
||||||
|
) -> asyncio.Server | websockets.asyncio.server.Server:
|
||||||
|
"""Create a server instance for a listener."""
|
||||||
|
if listener_type == "tcp":
|
||||||
|
return await asyncio.start_server(
|
||||||
|
partial(self.stream_connected, listener_name=listener_name),
|
||||||
|
address,
|
||||||
|
port,
|
||||||
|
reuse_address=True,
|
||||||
|
ssl=ssl_context,
|
||||||
|
)
|
||||||
|
if listener_type == "ws":
|
||||||
|
return await websockets.serve(
|
||||||
|
partial(self.ws_connected, listener_name=listener_name),
|
||||||
|
address,
|
||||||
|
port,
|
||||||
|
ssl=ssl_context,
|
||||||
|
subprotocols=[websockets.Subprotocol("mqtt")],
|
||||||
|
)
|
||||||
|
msg = f"Unsupported listener type: {listener_type}"
|
||||||
|
raise BrokerError(msg)
|
||||||
|
|
||||||
async def shutdown(self) -> None:
|
async def shutdown(self) -> None:
|
||||||
"""Stop broker instance."""
|
"""Stop broker instance."""
|
||||||
try:
|
self.logger.info("Shutting down broker...")
|
||||||
# # Wait for all in-flight tasks to complete before stopping session handlers
|
# Fire broker_shutdown event to plugins
|
||||||
# for client_id, (session, handler) in self._sessions.items():
|
await self.plugins_manager.fire_event(EVENT_BROKER_PRE_SHUTDOWN)
|
||||||
# if handler:
|
|
||||||
# self.logger.debug(f"Waiting for in-flight tasks to complete for session {client_id}")
|
|
||||||
# if session.inflight_out:
|
|
||||||
# # Directly use asyncio.sleep or another async operation in the loop
|
|
||||||
# await asyncio.gather(
|
|
||||||
# *(asyncio.sleep(0) for _ in session.inflight_out.values()),
|
|
||||||
# return_exceptions=True,
|
|
||||||
# )
|
|
||||||
|
|
||||||
# Stop all session handlers
|
# Cleanup all sessions
|
||||||
for client_id, (_, handler) in self._sessions.items():
|
for client_id in list(self._sessions.keys()):
|
||||||
if handler:
|
await self._cleanup_session(client_id)
|
||||||
self.logger.debug(f"Stopping handler for session {client_id}")
|
|
||||||
await self._stop_handler(handler)
|
|
||||||
|
|
||||||
# Clear subscriptions
|
|
||||||
for topic, subscriptions in self._subscriptions.items():
|
|
||||||
self.logger.debug(f"Clearing subscriptions for topic '{topic}'")
|
|
||||||
for session, _ in subscriptions:
|
|
||||||
self._del_subscription(topic, session)
|
|
||||||
self._subscriptions.clear()
|
|
||||||
|
|
||||||
# Clear retained messages
|
# Clear retained messages
|
||||||
if self._retained_messages:
|
|
||||||
self.logger.debug(f"Clearing {len(self._retained_messages)} retained messages")
|
self.logger.debug(f"Clearing {len(self._retained_messages)} retained messages")
|
||||||
self._retained_messages.clear()
|
self._retained_messages.clear()
|
||||||
|
|
||||||
self._sessions.clear()
|
|
||||||
self.transitions.shutdown()
|
self.transitions.shutdown()
|
||||||
except (MachineError, ValueError) as exc:
|
|
||||||
# Backwards compat: MachineError is raised by transitions < 0.5.0.
|
|
||||||
self.logger.debug(f"Invalid method call at this moment: {exc}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
# Fire broker_shutdown event to plugins
|
|
||||||
await self.plugins_manager.fire_event(EVENT_BROKER_PRE_SHUTDOWN)
|
|
||||||
|
|
||||||
await self._shutdown_broadcast_loop()
|
await self._shutdown_broadcast_loop()
|
||||||
|
|
||||||
|
@ -372,59 +390,89 @@ class Broker:
|
||||||
await self.plugins_manager.fire_event(EVENT_BROKER_POST_SHUTDOWN)
|
await self.plugins_manager.fire_event(EVENT_BROKER_POST_SHUTDOWN)
|
||||||
self.transitions.stopping_success()
|
self.transitions.stopping_success()
|
||||||
|
|
||||||
|
async def _cleanup_session(self, client_id: str) -> None:
|
||||||
|
"""Centralized cleanup logic for a session."""
|
||||||
|
session, handler = self._sessions.pop(client_id, (None, None))
|
||||||
|
|
||||||
|
if handler:
|
||||||
|
self.logger.debug(f"Stopping handler for session {client_id}")
|
||||||
|
await self._stop_handler(handler)
|
||||||
|
if session:
|
||||||
|
self.logger.debug(f"Clearing all subscriptions for session {client_id}")
|
||||||
|
await self._del_all_subscriptions(session)
|
||||||
|
session.clear_queues()
|
||||||
|
|
||||||
async def internal_message_broadcast(self, topic: str, data: bytes, qos: int | None = None) -> None:
|
async def internal_message_broadcast(self, topic: str, data: bytes, qos: int | None = None) -> None:
|
||||||
return await self._broadcast_message(None, topic, data, qos)
|
return await self._broadcast_message(None, topic, data, qos)
|
||||||
|
|
||||||
async def ws_connected(self, websocket: ServerConnection, listener_name: str) -> None:
|
async def ws_connected(self, websocket: ServerConnection, listener_name: str) -> None:
|
||||||
await self.client_connected(listener_name, WebSocketsReader(websocket), WebSocketsWriter(websocket))
|
await self._client_connected(listener_name, WebSocketsReader(websocket), WebSocketsWriter(websocket))
|
||||||
|
|
||||||
async def stream_connected(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter, listener_name: str) -> None:
|
async def stream_connected(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter, listener_name: str) -> None:
|
||||||
await self.client_connected(listener_name, StreamReaderAdapter(reader), StreamWriterAdapter(writer))
|
await self._client_connected(listener_name, StreamReaderAdapter(reader), StreamWriterAdapter(writer))
|
||||||
|
|
||||||
async def client_connected(self, listener_name: str, reader: ReaderAdapter, writer: WriterAdapter) -> None:
|
async def _client_connected(self, listener_name: str, reader: ReaderAdapter, writer: WriterAdapter) -> None:
|
||||||
# Wait for connection available on listener
|
"""Handle a new client connection."""
|
||||||
server = self._servers.get(listener_name, None)
|
server = self._servers.get(listener_name)
|
||||||
if not server:
|
if not server:
|
||||||
msg = f"Invalid listener name '{listener_name}'"
|
msg = f"Invalid listener name '{listener_name}'"
|
||||||
raise BrokerError(msg)
|
raise BrokerError(msg)
|
||||||
await server.acquire_connection()
|
|
||||||
|
|
||||||
|
await server.acquire_connection()
|
||||||
remote_info = writer.get_peer_info()
|
remote_info = writer.get_peer_info()
|
||||||
if remote_info is None:
|
if remote_info is None:
|
||||||
self.logger.warning("remote info could not get from peer info")
|
self.logger.warning("Remote info could not be retrieved from peer info")
|
||||||
return
|
return
|
||||||
|
|
||||||
remote_address, remote_port = remote_info
|
remote_address, remote_port = remote_info
|
||||||
self.logger.info(f"Connection from {remote_address}:{remote_port} on listener '{listener_name}'")
|
self.logger.info(f"Connection from {remote_address}:{remote_port} on listener '{listener_name}'")
|
||||||
|
|
||||||
|
try:
|
||||||
|
handler, client_session = await self._initialize_client_session(reader, writer, remote_address, remote_port)
|
||||||
|
except (AMQTTError, MQTTError, NoDataError) as exc:
|
||||||
|
self.logger.warning(f"Failed to initialize client session: {exc}")
|
||||||
|
server.release_connection()
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
await self._handle_client_session(reader, writer, client_session, handler, server, listener_name)
|
||||||
|
except (AMQTTError, MQTTError, NoDataError) as exc:
|
||||||
|
self.logger.warning(f"Error while handling client session: {exc}")
|
||||||
|
finally:
|
||||||
|
self.logger.debug(f"{client_session.client_id} Client disconnected")
|
||||||
|
server.release_connection()
|
||||||
|
|
||||||
|
async def _initialize_client_session(
|
||||||
|
self,
|
||||||
|
reader: ReaderAdapter,
|
||||||
|
writer: WriterAdapter,
|
||||||
|
remote_address: str,
|
||||||
|
remote_port: int,
|
||||||
|
) -> tuple[BrokerProtocolHandler, Session]:
|
||||||
|
"""Initialize a client session and protocol handler."""
|
||||||
# Wait for first packet and expect a CONNECT
|
# Wait for first packet and expect a CONNECT
|
||||||
try:
|
try:
|
||||||
handler, client_session = await BrokerProtocolHandler.init_from_connect(reader, writer, self.plugins_manager)
|
handler, client_session = await BrokerProtocolHandler.init_from_connect(reader, writer, self.plugins_manager)
|
||||||
except AMQTTError as exc:
|
except AMQTTError as exc:
|
||||||
self.logger.warning(
|
self.logger.warning(
|
||||||
f"[MQTT-3.1.0-1] {format_client_message(address=remote_address, port=remote_port)}:"
|
f"[MQTT-3.1.0-1] {format_client_message(address=remote_address, port=remote_port)}:"
|
||||||
f"Can't read first packet an CONNECT: {exc}",
|
f" Can't read first packet as CONNECT: {exc}",
|
||||||
)
|
)
|
||||||
self.logger.debug("Connection closed")
|
raise AMQTTError(exc) from exc
|
||||||
server.release_connection()
|
except MQTTError as exc:
|
||||||
return
|
|
||||||
except MQTTError:
|
|
||||||
self.logger.exception(
|
self.logger.exception(
|
||||||
f"Invalid connection from {format_client_message(address=remote_address, port=remote_port)}",
|
f"Invalid connection from {format_client_message(address=remote_address, port=remote_port)}",
|
||||||
)
|
)
|
||||||
await writer.close()
|
await writer.close()
|
||||||
server.release_connection()
|
raise MQTTError(exc) from exc
|
||||||
self.logger.debug("Connection closed")
|
except NoDataError as exc:
|
||||||
return
|
self.logger.error(f"No data from {format_client_message(address=remote_address, port=remote_port)} : {exc}") # noqa: TRY400 # cannot replace with exception else pytest fails
|
||||||
except NoDataError as ne:
|
raise NoDataError(exc) from exc
|
||||||
self.logger.error(f"No data from {format_client_message(address=remote_address, port=remote_port)} : {ne}") # noqa: TRY400 # cannot replace with exception else test fails
|
|
||||||
server.release_connection()
|
|
||||||
return
|
|
||||||
|
|
||||||
if client_session.clean_session:
|
if client_session.clean_session:
|
||||||
# Delete existing session and create a new one
|
# Delete existing session and create a new one
|
||||||
if client_session.client_id is not None and client_session.client_id != "":
|
if client_session.client_id is not None and client_session.client_id != "":
|
||||||
self.delete_session(client_session.client_id)
|
await self._delete_session(client_session.client_id)
|
||||||
else:
|
else:
|
||||||
client_session.client_id = gen_client_id()
|
client_session.client_id = gen_client_id()
|
||||||
client_session.parent = 0
|
client_session.parent = 0
|
||||||
|
@ -436,21 +484,32 @@ class Broker:
|
||||||
else:
|
else:
|
||||||
client_session.parent = 0
|
client_session.parent = 0
|
||||||
|
|
||||||
|
timeout_disconnect_delay = self.config.get("timeout-disconnect-delay", 0)
|
||||||
|
if client_session.keep_alive > 0 and isinstance(timeout_disconnect_delay, int):
|
||||||
|
client_session.keep_alive += timeout_disconnect_delay
|
||||||
|
|
||||||
|
self.logger.debug(f"Keep-alive timeout={client_session.keep_alive}")
|
||||||
|
return handler, client_session
|
||||||
|
|
||||||
|
async def _handle_client_session(
|
||||||
|
self,
|
||||||
|
reader: ReaderAdapter,
|
||||||
|
writer: WriterAdapter,
|
||||||
|
client_session: Session,
|
||||||
|
handler: BrokerProtocolHandler,
|
||||||
|
server: Server,
|
||||||
|
listener_name: str,
|
||||||
|
) -> None:
|
||||||
|
"""Handle the lifecycle of a client session."""
|
||||||
|
authenticated = await self._authenticate(client_session, self.listeners_config[listener_name])
|
||||||
|
if not authenticated:
|
||||||
|
await writer.close()
|
||||||
|
return
|
||||||
|
|
||||||
if client_session.client_id is None:
|
if client_session.client_id is None:
|
||||||
msg = "Client ID was not correctly created/set."
|
msg = "Client ID was not correctly created/set."
|
||||||
raise BrokerError(msg)
|
raise BrokerError(msg)
|
||||||
|
|
||||||
timeout_disconnect_delay = self.config.get("timeout-disconnect-delay")
|
|
||||||
if client_session.keep_alive > 0 and isinstance(timeout_disconnect_delay, int):
|
|
||||||
client_session.keep_alive += timeout_disconnect_delay
|
|
||||||
self.logger.debug(f"Keep-alive timeout={client_session.keep_alive}")
|
|
||||||
|
|
||||||
authenticated = await self.authenticate(client_session, self.listeners_config[listener_name])
|
|
||||||
if not authenticated:
|
|
||||||
await writer.close()
|
|
||||||
server.release_connection() # Delete client from connections list
|
|
||||||
return
|
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
client_session.transitions.connect()
|
client_session.transitions.connect()
|
||||||
|
@ -469,14 +528,17 @@ class Broker:
|
||||||
self._sessions[client_session.client_id] = (client_session, handler)
|
self._sessions[client_session.client_id] = (client_session, handler)
|
||||||
|
|
||||||
await handler.mqtt_connack_authorize(authenticated)
|
await handler.mqtt_connack_authorize(authenticated)
|
||||||
|
|
||||||
await self.plugins_manager.fire_event(EVENT_BROKER_CLIENT_CONNECTED, client_id=client_session.client_id)
|
await self.plugins_manager.fire_event(EVENT_BROKER_CLIENT_CONNECTED, client_id=client_session.client_id)
|
||||||
|
|
||||||
self.logger.debug(f"{client_session.client_id} Start messages handling")
|
self.logger.debug(f"{client_session.client_id} Start messages handling")
|
||||||
await handler.start()
|
await handler.start()
|
||||||
self.logger.debug(f"Retained messages queue size: {client_session.retained_messages.qsize()}")
|
self.logger.debug(f"Retained messages queue size: {client_session.retained_messages.qsize()}")
|
||||||
await self.publish_session_retained_messages(client_session)
|
await self._publish_session_retained_messages(client_session)
|
||||||
|
|
||||||
|
await self._client_message_loop(client_session, handler)
|
||||||
|
|
||||||
|
async def _client_message_loop(self, client_session: Session, handler: BrokerProtocolHandler) -> None:
|
||||||
|
"""Run the main loop to handle client messages."""
|
||||||
# Init and start loop for handling client messages (publish, subscribe/unsubscribe, disconnect)
|
# Init and start loop for handling client messages (publish, subscribe/unsubscribe, disconnect)
|
||||||
disconnect_waiter = asyncio.ensure_future(handler.wait_disconnect())
|
disconnect_waiter = asyncio.ensure_future(handler.wait_disconnect())
|
||||||
subscribe_waiter = asyncio.ensure_future(handler.get_next_pending_subscription())
|
subscribe_waiter = asyncio.ensure_future(handler.get_next_pending_subscription())
|
||||||
|
@ -495,7 +557,41 @@ class Broker:
|
||||||
],
|
],
|
||||||
return_when=asyncio.FIRST_COMPLETED,
|
return_when=asyncio.FIRST_COMPLETED,
|
||||||
)
|
)
|
||||||
|
|
||||||
if disconnect_waiter in done:
|
if disconnect_waiter in done:
|
||||||
|
connected = await self._handle_disconnect(client_session, handler, disconnect_waiter)
|
||||||
|
disconnect_waiter = asyncio.ensure_future(handler.wait_disconnect())
|
||||||
|
|
||||||
|
if subscribe_waiter in done:
|
||||||
|
await self._handle_subscription(client_session, handler, subscribe_waiter)
|
||||||
|
subscribe_waiter = asyncio.ensure_future(handler.get_next_pending_subscription())
|
||||||
|
self.logger.debug(repr(self._subscriptions))
|
||||||
|
|
||||||
|
if unsubscribe_waiter in done:
|
||||||
|
await self._handle_unsubscription(client_session, handler, unsubscribe_waiter)
|
||||||
|
unsubscribe_waiter = asyncio.ensure_future(handler.get_next_pending_unsubscription())
|
||||||
|
|
||||||
|
if wait_deliver in done:
|
||||||
|
if not await self._handle_message_delivery(client_session, handler, wait_deliver):
|
||||||
|
break
|
||||||
|
wait_deliver = asyncio.ensure_future(handler.mqtt_deliver_next_message())
|
||||||
|
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
self.logger.debug("Client loop cancelled")
|
||||||
|
break
|
||||||
|
|
||||||
|
disconnect_waiter.cancel()
|
||||||
|
subscribe_waiter.cancel()
|
||||||
|
unsubscribe_waiter.cancel()
|
||||||
|
wait_deliver.cancel()
|
||||||
|
|
||||||
|
async def _handle_disconnect(
|
||||||
|
self,
|
||||||
|
client_session: Session,
|
||||||
|
handler: BrokerProtocolHandler,
|
||||||
|
disconnect_waiter: asyncio.Future[Any],
|
||||||
|
) -> bool:
|
||||||
|
"""Handle client disconnection."""
|
||||||
result = disconnect_waiter.result()
|
result = disconnect_waiter.result()
|
||||||
self.logger.debug(f"{client_session.client_id} Result from wait_disconnect: {result}")
|
self.logger.debug(f"{client_session.client_id} Result from wait_disconnect: {result}")
|
||||||
if result is None:
|
if result is None:
|
||||||
|
@ -520,14 +616,38 @@ class Broker:
|
||||||
self.logger.debug(f"{client_session.client_id} Disconnecting session")
|
self.logger.debug(f"{client_session.client_id} Disconnecting session")
|
||||||
await self._stop_handler(handler)
|
await self._stop_handler(handler)
|
||||||
client_session.transitions.disconnect()
|
client_session.transitions.disconnect()
|
||||||
|
await self.plugins_manager.fire_event(EVENT_BROKER_CLIENT_DISCONNECTED, client_id=client_session.client_id)
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
async def _handle_subscription(
|
||||||
|
self,
|
||||||
|
client_session: Session,
|
||||||
|
handler: BrokerProtocolHandler,
|
||||||
|
subscribe_waiter: asyncio.Future[Any],
|
||||||
|
) -> None:
|
||||||
|
"""Handle client subscription."""
|
||||||
|
self.logger.debug(f"{client_session.client_id} handling subscription")
|
||||||
|
subscriptions = subscribe_waiter.result()
|
||||||
|
return_codes = [await self._add_subscription(subscription, client_session) for subscription in subscriptions.topics]
|
||||||
|
await handler.mqtt_acknowledge_subscription(subscriptions.packet_id, return_codes)
|
||||||
|
for index, subscription in enumerate(subscriptions.topics):
|
||||||
|
if return_codes[index] != AMQTT_MAGIC_VALUE_RET_SUBSCRIBED:
|
||||||
await self.plugins_manager.fire_event(
|
await self.plugins_manager.fire_event(
|
||||||
EVENT_BROKER_CLIENT_DISCONNECTED,
|
EVENT_BROKER_CLIENT_SUBSCRIBED,
|
||||||
client_id=client_session.client_id,
|
client_id=client_session.client_id,
|
||||||
|
topic=subscription[0],
|
||||||
|
qos=subscription[1],
|
||||||
)
|
)
|
||||||
connected = False
|
await self._publish_retained_messages_for_subscription(subscription, client_session)
|
||||||
# Recreate the disconnect_waiter task after processing
|
|
||||||
disconnect_waiter = asyncio.ensure_future(handler.wait_disconnect())
|
async def _handle_unsubscription(
|
||||||
if unsubscribe_waiter in done:
|
self,
|
||||||
|
client_session: Session,
|
||||||
|
handler: BrokerProtocolHandler,
|
||||||
|
unsubscribe_waiter: asyncio.Future[Any],
|
||||||
|
) -> None:
|
||||||
|
"""Handle client unsubscription."""
|
||||||
self.logger.debug(f"{client_session.client_id} handling unsubscription")
|
self.logger.debug(f"{client_session.client_id} handling unsubscription")
|
||||||
unsubscription = unsubscribe_waiter.result()
|
unsubscription = unsubscribe_waiter.result()
|
||||||
for topic in unsubscription.topics:
|
for topic in unsubscription.topics:
|
||||||
|
@ -538,55 +658,34 @@ class Broker:
|
||||||
topic=topic,
|
topic=topic,
|
||||||
)
|
)
|
||||||
await handler.mqtt_acknowledge_unsubscription(unsubscription.packet_id)
|
await handler.mqtt_acknowledge_unsubscription(unsubscription.packet_id)
|
||||||
# Recreate the unsubscribe_waiter task
|
|
||||||
unsubscribe_waiter = asyncio.ensure_future(handler.get_next_pending_unsubscription())
|
|
||||||
if subscribe_waiter in done:
|
|
||||||
self.logger.debug(f"{client_session.client_id} handling subscription")
|
|
||||||
subscriptions = subscribe_waiter.result()
|
|
||||||
return_codes = [
|
|
||||||
await self.add_subscription(subscription, client_session) for subscription in subscriptions.topics
|
|
||||||
]
|
|
||||||
|
|
||||||
await handler.mqtt_acknowledge_subscription(subscriptions.packet_id, return_codes)
|
async def _handle_message_delivery(
|
||||||
for index, subscription in enumerate(subscriptions.topics):
|
self,
|
||||||
if return_codes[index] != AMQTT_MAGIC_VALUE_RET_SUBSCRIBED:
|
client_session: Session,
|
||||||
await self.plugins_manager.fire_event(
|
handler: BrokerProtocolHandler,
|
||||||
EVENT_BROKER_CLIENT_SUBSCRIBED,
|
wait_deliver: asyncio.Future[Any],
|
||||||
client_id=client_session.client_id,
|
) -> bool:
|
||||||
topic=subscription[0],
|
"""Handle message delivery to the client."""
|
||||||
qos=subscription[1],
|
|
||||||
)
|
|
||||||
await self.publish_retained_messages_for_subscription(subscription, client_session)
|
|
||||||
# Recreate the subscribe_waiter task
|
|
||||||
subscribe_waiter = asyncio.ensure_future(handler.get_next_pending_subscription())
|
|
||||||
self.logger.debug(repr(self._subscriptions))
|
|
||||||
if wait_deliver in done:
|
|
||||||
if self.logger.isEnabledFor(logging.DEBUG):
|
|
||||||
self.logger.debug(f"{client_session.client_id} handling message delivery")
|
self.logger.debug(f"{client_session.client_id} handling message delivery")
|
||||||
app_message = wait_deliver.result()
|
app_message = wait_deliver.result()
|
||||||
|
|
||||||
if app_message is None:
|
if app_message is None:
|
||||||
self.logger.debug("app_message was empty!")
|
self.logger.debug("app_message was empty!")
|
||||||
continue
|
return True
|
||||||
|
|
||||||
if not app_message.topic:
|
if not app_message.topic:
|
||||||
self.logger.warning(
|
self.logger.warning(
|
||||||
f"[MQTT-4.7.3-1] - {client_session.client_id}"
|
f"[MQTT-4.7.3-1] - {client_session.client_id} invalid TOPIC sent in PUBLISH message, closing connection",
|
||||||
" invalid TOPIC sent in PUBLISH message, closing connection",
|
|
||||||
)
|
)
|
||||||
break
|
return False
|
||||||
if "#" in app_message.topic or "+" in app_message.topic:
|
if "#" in app_message.topic or "+" in app_message.topic:
|
||||||
self.logger.warning(
|
self.logger.warning(
|
||||||
f"[MQTT-3.3.2-2] - {client_session.client_id}"
|
f"[MQTT-3.3.2-2] - {client_session.client_id} invalid TOPIC sent in PUBLISH message, closing connection",
|
||||||
" invalid TOPIC sent in PUBLISH message, closing connection",
|
|
||||||
)
|
)
|
||||||
break
|
return False
|
||||||
|
|
||||||
permitted = await self.topic_filtering(client_session, topic=app_message.topic, action=Action.PUBLISH)
|
permitted = await self._topic_filtering(client_session, topic=app_message.topic, action=Action.PUBLISH)
|
||||||
if not permitted:
|
if not permitted:
|
||||||
self.logger.info(
|
self.logger.info(f"{client_session.client_id} forbidden TOPIC {app_message.topic} sent in PUBLISH message.")
|
||||||
f"{client_session.client_id} forbidden TOPIC {app_message.topic} sent in PUBLISH message.",
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
await self.plugins_manager.fire_event(
|
await self.plugins_manager.fire_event(
|
||||||
EVENT_BROKER_MESSAGE_RECEIVED,
|
EVENT_BROKER_MESSAGE_RECEIVED,
|
||||||
|
@ -594,42 +693,24 @@ class Broker:
|
||||||
message=app_message,
|
message=app_message,
|
||||||
)
|
)
|
||||||
await self._broadcast_message(client_session, app_message.topic, app_message.data)
|
await self._broadcast_message(client_session, app_message.topic, app_message.data)
|
||||||
if app_message.publish_packet is not None and app_message.publish_packet.retain_flag:
|
if app_message.publish_packet and app_message.publish_packet.retain_flag:
|
||||||
self.retain_message(client_session, app_message.topic, app_message.data, app_message.qos)
|
self.retain_message(client_session, app_message.topic, app_message.data, app_message.qos)
|
||||||
# Recreate the wait_deliver task
|
return True
|
||||||
wait_deliver = asyncio.ensure_future(handler.mqtt_deliver_next_message())
|
|
||||||
except asyncio.CancelledError:
|
|
||||||
self.logger.debug("Client loop cancelled")
|
|
||||||
break
|
|
||||||
disconnect_waiter.cancel()
|
|
||||||
subscribe_waiter.cancel()
|
|
||||||
unsubscribe_waiter.cancel()
|
|
||||||
wait_deliver.cancel()
|
|
||||||
|
|
||||||
self.logger.debug(f"{client_session.client_id} Client disconnected")
|
# async def _init_handler(self, session: Session, reader: ReaderAdapter, writer: WriterAdapter) -> BrokerProtocolHandler:
|
||||||
server.release_connection()
|
# """Create a BrokerProtocolHandler and attach to a session."""
|
||||||
|
# handler = BrokerProtocolHandler(self.plugins_manager, loop=self._loop)
|
||||||
async def _init_handler(self, session: Session, reader: ReaderAdapter, writer: WriterAdapter) -> BrokerProtocolHandler:
|
# handler.attach(session, reader, writer)
|
||||||
"""Create a BrokerProtocolHandler and attach to a session.
|
# return handler
|
||||||
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
handler = BrokerProtocolHandler(self.plugins_manager, loop=self._loop)
|
|
||||||
handler.attach(session, reader, writer)
|
|
||||||
return handler
|
|
||||||
|
|
||||||
async def _stop_handler(self, handler: BrokerProtocolHandler) -> None:
|
async def _stop_handler(self, handler: BrokerProtocolHandler) -> None:
|
||||||
"""Stop a running handler and detach if from the session.
|
"""Stop a running handler and detach if from the session."""
|
||||||
|
|
||||||
:param handler:
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
try:
|
try:
|
||||||
await handler.stop()
|
await handler.stop()
|
||||||
except Exception:
|
except Exception:
|
||||||
self.logger.exception("Failed to stop handler")
|
self.logger.exception("Failed to stop handler")
|
||||||
|
|
||||||
async def authenticate(self, session: Session, _: dict[str, Any]) -> bool:
|
async def _authenticate(self, session: Session, _: dict[str, Any]) -> bool:
|
||||||
"""Call the authenticate method on registered plugins to test user authentication.
|
"""Call the authenticate method on registered plugins to test user authentication.
|
||||||
|
|
||||||
User is considered authenticated if all plugins called returns True.
|
User is considered authenticated if all plugins called returns True.
|
||||||
|
@ -658,7 +739,49 @@ class Broker:
|
||||||
# If all plugins returned True, authentication is success
|
# If all plugins returned True, authentication is success
|
||||||
return auth_result
|
return auth_result
|
||||||
|
|
||||||
async def topic_filtering(self, session: Session, topic: str, action: Action) -> bool:
|
def retain_message(
|
||||||
|
self,
|
||||||
|
source_session: Session | None,
|
||||||
|
topic_name: str | None,
|
||||||
|
data: bytes | bytearray | None,
|
||||||
|
qos: int | None = None,
|
||||||
|
) -> None:
|
||||||
|
if data and topic_name is not None:
|
||||||
|
# If retained flag set, store the message for further subscriptions
|
||||||
|
self.logger.debug(f"Retaining message on topic {topic_name}")
|
||||||
|
self._retained_messages[topic_name] = RetainedApplicationMessage(source_session, topic_name, data, qos)
|
||||||
|
# [MQTT-3.3.1-10]
|
||||||
|
elif topic_name in self._retained_messages:
|
||||||
|
self.logger.debug(f"Clearing retained messages for topic '{topic_name}'")
|
||||||
|
del self._retained_messages[topic_name]
|
||||||
|
|
||||||
|
async def _add_subscription(self, subscription: tuple[str, int], session: Session) -> int:
|
||||||
|
topic_filter, qos = subscription
|
||||||
|
if "#" in topic_filter and not topic_filter.endswith("#"):
|
||||||
|
# [MQTT-4.7.1-2] Wildcard character '#' is only allowed as last character in filter
|
||||||
|
return 0x80
|
||||||
|
if topic_filter != "+" and "+" in topic_filter and ("/+" not in topic_filter and "+/" not in topic_filter):
|
||||||
|
# [MQTT-4.7.1-3] + wildcard character must occupy entire level
|
||||||
|
return 0x80
|
||||||
|
# Check if the client is authorised to connect to the topic
|
||||||
|
if not await self._topic_filtering(session, topic_filter, Action.SUBSCRIBE):
|
||||||
|
return 0x80
|
||||||
|
|
||||||
|
# Ensure "max-qos" is an integer before using it
|
||||||
|
max_qos = self.config.get("max-qos", qos)
|
||||||
|
if not isinstance(max_qos, int):
|
||||||
|
max_qos = qos
|
||||||
|
|
||||||
|
qos = min(qos, max_qos)
|
||||||
|
if topic_filter not in self._subscriptions:
|
||||||
|
self._subscriptions[topic_filter] = []
|
||||||
|
if all(s.client_id != session.client_id for s, _ in self._subscriptions[topic_filter]):
|
||||||
|
self._subscriptions[topic_filter].append((session, qos))
|
||||||
|
else:
|
||||||
|
self.logger.debug(f"Client {format_client_message(session=session)} has already subscribed to {topic_filter}")
|
||||||
|
return qos
|
||||||
|
|
||||||
|
async def _topic_filtering(self, session: Session, topic: str, action: Action) -> bool:
|
||||||
"""Call the topic_filtering method on registered plugins to check that the subscription is allowed.
|
"""Call the topic_filtering method on registered plugins to check that the subscription is allowed.
|
||||||
|
|
||||||
User is considered allowed if all plugins called return True.
|
User is considered allowed if all plugins called return True.
|
||||||
|
@ -690,48 +813,29 @@ class Broker:
|
||||||
)
|
)
|
||||||
return all(result for result in results.values())
|
return all(result for result in results.values())
|
||||||
|
|
||||||
def retain_message(
|
async def _delete_session(self, client_id: str) -> None:
|
||||||
self,
|
"""Delete an existing session data, for example due to clean session set in CONNECT."""
|
||||||
source_session: Session | None,
|
session = self._sessions.pop(client_id, (None, None))[0]
|
||||||
topic_name: str | None,
|
|
||||||
data: bytes | bytearray | None,
|
|
||||||
qos: int | None = None,
|
|
||||||
) -> None:
|
|
||||||
if data and topic_name is not None:
|
|
||||||
# If retained flag set, store the message for further subscriptions
|
|
||||||
self.logger.debug(f"Retaining message on topic {topic_name}")
|
|
||||||
self._retained_messages[topic_name] = RetainedApplicationMessage(source_session, topic_name, data, qos)
|
|
||||||
# [MQTT-3.3.1-10]
|
|
||||||
elif topic_name in self._retained_messages:
|
|
||||||
self.logger.debug(f"Clearing retained messages for topic '{topic_name}'")
|
|
||||||
del self._retained_messages[topic_name]
|
|
||||||
|
|
||||||
# NOTE: issue #61 remove try block
|
if session is None:
|
||||||
async def add_subscription(self, subscription: tuple[str, int], session: Session) -> int:
|
self.logger.debug(f"Delete session : session {client_id} doesn't exist")
|
||||||
topic_filter, qos = subscription
|
return
|
||||||
if "#" in topic_filter and not topic_filter.endswith("#"):
|
self.logger.debug(f"Deleted existing session {session!r}")
|
||||||
# [MQTT-4.7.1-2] Wildcard character '#' is only allowed as last character in filter
|
|
||||||
return 0x80
|
|
||||||
if topic_filter != "+" and "+" in topic_filter and ("/+" not in topic_filter and "+/" not in topic_filter):
|
|
||||||
# [MQTT-4.7.1-3] + wildcard character must occupy entire level
|
|
||||||
return 0x80
|
|
||||||
# Check if the client is authorised to connect to the topic
|
|
||||||
if not await self.topic_filtering(session, topic_filter, Action.SUBSCRIBE):
|
|
||||||
return 0x80
|
|
||||||
|
|
||||||
# Ensure "max-qos" is an integer before using it
|
# Delete subscriptions
|
||||||
max_qos = self.config.get("max-qos", qos)
|
self.logger.debug(f"Deleting session {session!r} subscriptions")
|
||||||
if not isinstance(max_qos, int):
|
await self._del_all_subscriptions(session)
|
||||||
max_qos = qos
|
session.clear_queues()
|
||||||
|
|
||||||
qos = min(qos, max_qos)
|
async def _del_all_subscriptions(self, session: Session) -> None:
|
||||||
if topic_filter not in self._subscriptions:
|
"""Delete all topic subscriptions for a given session."""
|
||||||
self._subscriptions[topic_filter] = []
|
filter_queue: deque[str] = deque()
|
||||||
if all(s.client_id != session.client_id for s, _ in self._subscriptions[topic_filter]):
|
for topic in self._subscriptions:
|
||||||
self._subscriptions[topic_filter].append((session, qos))
|
if self._del_subscription(topic, session):
|
||||||
else:
|
filter_queue.append(topic)
|
||||||
self.logger.debug(f"Client {format_client_message(session=session)} has already subscribed to {topic_filter}")
|
for topic in filter_queue:
|
||||||
return qos
|
if not self._subscriptions[topic]:
|
||||||
|
del self._subscriptions[topic]
|
||||||
|
|
||||||
def _del_subscription(self, a_filter: str, session: Session) -> int:
|
def _del_subscription(self, a_filter: str, session: Session) -> int:
|
||||||
"""Delete a session subscription on a given topic.
|
"""Delete a session subscription on a given topic.
|
||||||
|
@ -752,32 +856,9 @@ class Broker:
|
||||||
deleted += 1
|
deleted += 1
|
||||||
break
|
break
|
||||||
except KeyError:
|
except KeyError:
|
||||||
# Unsubscribe topic not found in current subscribed topics
|
self.logger.debug(f"Unsubscription on topic '{a_filter}' for client {format_client_message(session=session)}")
|
||||||
pass
|
|
||||||
return deleted
|
return deleted
|
||||||
|
|
||||||
def _del_all_subscriptions(self, session: Session) -> None:
|
|
||||||
"""Delete all topic subscriptions for a given session.
|
|
||||||
|
|
||||||
:param session:
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
filter_queue: deque[str] = deque()
|
|
||||||
for topic in self._subscriptions:
|
|
||||||
if self._del_subscription(topic, session):
|
|
||||||
filter_queue.append(topic)
|
|
||||||
for topic in filter_queue:
|
|
||||||
if not self._subscriptions[topic]:
|
|
||||||
del self._subscriptions[topic]
|
|
||||||
|
|
||||||
def matches(self, topic: str, a_filter: str) -> bool:
|
|
||||||
if "#" not in a_filter and "+" not in a_filter:
|
|
||||||
# if filter doesn't contain wildcard, return exact match
|
|
||||||
return a_filter == topic
|
|
||||||
# else use regex
|
|
||||||
match_pattern = re.compile(re.escape(a_filter).replace("\\#", "?.*").replace("\\+", "[^/]*").lstrip("?"))
|
|
||||||
return bool(match_pattern.fullmatch(topic))
|
|
||||||
|
|
||||||
async def _broadcast_loop(self) -> None:
|
async def _broadcast_loop(self) -> None:
|
||||||
"""Run the main loop to broadcast messages."""
|
"""Run the main loop to broadcast messages."""
|
||||||
running_tasks: deque[asyncio.Task[OutgoingApplicationMessage]] = self._tasks_queue
|
running_tasks: deque[asyncio.Task[OutgoingApplicationMessage]] = self._tasks_queue
|
||||||
|
@ -824,7 +905,7 @@ class Broker:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Skip all subscriptions which do not match the topic
|
# Skip all subscriptions which do not match the topic
|
||||||
if not self.matches(broadcast["topic"], k_filter):
|
if not self._matches(broadcast["topic"], k_filter):
|
||||||
self.logger.debug(f"Topic '{broadcast['topic']}' does not match filter '{k_filter}'")
|
self.logger.debug(f"Topic '{broadcast['topic']}' does not match filter '{k_filter}'")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
@ -892,7 +973,7 @@ class Broker:
|
||||||
broadcast["qos"] = force_qos
|
broadcast["qos"] = force_qos
|
||||||
await self._broadcast_queue.put(broadcast)
|
await self._broadcast_queue.put(broadcast)
|
||||||
|
|
||||||
async def publish_session_retained_messages(self, session: Session) -> None:
|
async def _publish_session_retained_messages(self, session: Session) -> None:
|
||||||
self.logger.debug(
|
self.logger.debug(
|
||||||
f"Publishing {session.retained_messages.qsize()}"
|
f"Publishing {session.retained_messages.qsize()}"
|
||||||
f" messages retained for session {format_client_message(session=session)}",
|
f" messages retained for session {format_client_message(session=session)}",
|
||||||
|
@ -910,7 +991,7 @@ class Broker:
|
||||||
if publish_tasks:
|
if publish_tasks:
|
||||||
await asyncio.wait(publish_tasks)
|
await asyncio.wait(publish_tasks)
|
||||||
|
|
||||||
async def publish_retained_messages_for_subscription(self, subscription: tuple[str, int], session: Session) -> None:
|
async def _publish_retained_messages_for_subscription(self, subscription: tuple[str, int], session: Session) -> None:
|
||||||
self.logger.debug(
|
self.logger.debug(
|
||||||
f"Begin broadcasting messages retained due to subscription on '{subscription[0]}'"
|
f"Begin broadcasting messages retained due to subscription on '{subscription[0]}'"
|
||||||
f" from {format_client_message(session=session)}",
|
f" from {format_client_message(session=session)}",
|
||||||
|
@ -920,7 +1001,7 @@ class Broker:
|
||||||
topic_filter, qos = subscription
|
topic_filter, qos = subscription
|
||||||
for topic, retained in self._retained_messages.items():
|
for topic, retained in self._retained_messages.items():
|
||||||
self.logger.debug(f"matching : {topic} {topic_filter}")
|
self.logger.debug(f"matching : {topic} {topic_filter}")
|
||||||
if self.matches(topic, topic_filter):
|
if self._matches(topic, topic_filter):
|
||||||
self.logger.debug(f"{topic} and {topic_filter} match")
|
self.logger.debug(f"{topic} and {topic_filter} match")
|
||||||
handler = self._get_handler(session)
|
handler = self._get_handler(session)
|
||||||
if handler:
|
if handler:
|
||||||
|
@ -936,62 +1017,16 @@ class Broker:
|
||||||
f" from {format_client_message(session=session)}",
|
f" from {format_client_message(session=session)}",
|
||||||
)
|
)
|
||||||
|
|
||||||
def delete_session(self, client_id: str) -> None:
|
def _matches(self, topic: str, a_filter: str) -> bool:
|
||||||
"""Delete an existing session data, for example due to clean session set in CONNECT."""
|
if "#" not in a_filter and "+" not in a_filter:
|
||||||
session = self._sessions.pop(client_id, (None, None))[0]
|
# if filter doesn't contain wildcard, return exact match
|
||||||
|
return a_filter == topic
|
||||||
if session is None:
|
# else use regex
|
||||||
self.logger.debug(f"Delete session : session {client_id} doesn't exist")
|
match_pattern = re.compile(re.escape(a_filter).replace("\\#", "?.*").replace("\\+", "[^/]*").lstrip("?"))
|
||||||
return
|
return bool(match_pattern.fullmatch(topic))
|
||||||
self.logger.debug(f"Deleted existing session {session!r}")
|
|
||||||
|
|
||||||
# Delete subscriptions
|
|
||||||
self.logger.debug(f"Deleting session {session!r} subscriptions")
|
|
||||||
self._del_all_subscriptions(session)
|
|
||||||
session.clear_queues()
|
|
||||||
|
|
||||||
def _get_handler(self, session: Session) -> BrokerProtocolHandler | None:
|
def _get_handler(self, session: Session) -> BrokerProtocolHandler | None:
|
||||||
client_id = session.client_id
|
client_id = session.client_id
|
||||||
if client_id:
|
if client_id:
|
||||||
return self._sessions.get(client_id, (None, None))[1]
|
return self._sessions.get(client_id, (None, None))[1]
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _split_bindaddr_port(cls, port_str: str, default_port: int) -> tuple[str | None, int]:
|
|
||||||
"""Split an address:port pair into separate IP address and port. with IPv6 special-case handling.
|
|
||||||
|
|
||||||
NOTE: issue #72
|
|
||||||
|
|
||||||
- Address can be specified using one of the following methods:
|
|
||||||
- empty string - all interfaces default port
|
|
||||||
- 1883 - Port number only (listen all interfaces)
|
|
||||||
- :1883 - Port number only (listen all interfaces)
|
|
||||||
- 0.0.0.0:1883 - IPv4 address
|
|
||||||
- [::]:1883 - IPv6 address
|
|
||||||
"""
|
|
||||||
|
|
||||||
def _parse_port(port_str: str) -> int:
|
|
||||||
port_str = port_str.removeprefix(":")
|
|
||||||
|
|
||||||
if not port_str:
|
|
||||||
return default_port
|
|
||||||
|
|
||||||
return int(port_str)
|
|
||||||
|
|
||||||
if port_str.startswith("["): # IPv6 literal
|
|
||||||
try:
|
|
||||||
addr_end = port_str.index("]")
|
|
||||||
except ValueError as e:
|
|
||||||
msg = "Expecting '[' to be followed by ']'"
|
|
||||||
raise ValueError(msg) from e
|
|
||||||
|
|
||||||
return (port_str[0 : addr_end + 1], _parse_port(port_str[addr_end + 1 :]))
|
|
||||||
|
|
||||||
if ":" in port_str:
|
|
||||||
address, port_str = port_str.rsplit(":", 1)
|
|
||||||
return (address or None, _parse_port(port_str))
|
|
||||||
|
|
||||||
try:
|
|
||||||
return (None, _parse_port(port_str))
|
|
||||||
except ValueError:
|
|
||||||
return (port_str, default_port)
|
|
||||||
|
|
|
@ -526,8 +526,7 @@ class MQTTClient:
|
||||||
while self.client_tasks:
|
while self.client_tasks:
|
||||||
task = self.client_tasks.popleft()
|
task = self.client_tasks.popleft()
|
||||||
if not task.done():
|
if not task.done():
|
||||||
# task.set_exception(ClientError("Connection lost"))
|
task.cancel()
|
||||||
task.cancel() # NOTE: issue #153
|
|
||||||
|
|
||||||
self.logger.debug("Monitoring broker disconnection")
|
self.logger.debug("Monitoring broker disconnection")
|
||||||
# Wait for disconnection from broker (like connection lost)
|
# Wait for disconnection from broker (like connection lost)
|
||||||
|
|
|
@ -195,7 +195,6 @@ class ClientProtocolHandler(ProtocolHandler):
|
||||||
self.logger.debug("Broker closed connection")
|
self.logger.debug("Broker closed connection")
|
||||||
if self._disconnect_waiter is not None and not self._disconnect_waiter.done():
|
if self._disconnect_waiter is not None and not self._disconnect_waiter.done():
|
||||||
self._disconnect_waiter.set_result(None)
|
self._disconnect_waiter.set_result(None)
|
||||||
# await self.stop() # NOTE: issue #119
|
|
||||||
|
|
||||||
async def wait_disconnect(self) -> None:
|
async def wait_disconnect(self) -> None:
|
||||||
if self._disconnect_waiter is not None:
|
if self._disconnect_waiter is not None:
|
||||||
|
|
|
@ -67,15 +67,11 @@ class ProtocolHandler:
|
||||||
self.writer: WriterAdapter | None = None
|
self.writer: WriterAdapter | None = None
|
||||||
self.plugins_manager: PluginManager = plugins_manager
|
self.plugins_manager: PluginManager = plugins_manager
|
||||||
|
|
||||||
# TODO: check how to update loop usage best
|
try:
|
||||||
self._loop = loop if loop is not None else asyncio.get_event_loop_policy().get_event_loop()
|
self._loop = loop if loop is not None else asyncio.get_running_loop()
|
||||||
# try:
|
except RuntimeError:
|
||||||
# # Use the currently running loop if available
|
self._loop = asyncio.new_event_loop()
|
||||||
# self._loop = loop if loop is not None else asyncio.get_running_loop()
|
asyncio.set_event_loop(self._loop)
|
||||||
# except RuntimeError:
|
|
||||||
# # If no running loop is found, create a new one
|
|
||||||
# self._loop = asyncio.new_event_loop()
|
|
||||||
# asyncio.set_event_loop(self._loop)
|
|
||||||
|
|
||||||
self._reader_task: asyncio.Task[None] | None = None
|
self._reader_task: asyncio.Task[None] | None = None
|
||||||
self._keepalive_task: asyncio.TimerHandle | None = None
|
self._keepalive_task: asyncio.TimerHandle | None = None
|
||||||
|
|
|
@ -27,8 +27,6 @@ def get_plugin_manager(namespace: str) -> "PluginManager | None":
|
||||||
class BaseContext:
|
class BaseContext:
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self.loop: asyncio.AbstractEventLoop | None = None
|
self.loop: asyncio.AbstractEventLoop | None = None
|
||||||
# TODO: change this usage
|
|
||||||
# self.logger: logging.Logger | None = None
|
|
||||||
self.logger: logging.Logger = _LOGGER
|
self.logger: logging.Logger = _LOGGER
|
||||||
self.config: dict[str, Any] | None = None
|
self.config: dict[str, Any] | None = None
|
||||||
|
|
||||||
|
@ -41,16 +39,11 @@ class PluginManager:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, namespace: str, context: BaseContext | None, loop: asyncio.AbstractEventLoop | None = None) -> None:
|
def __init__(self, namespace: str, context: BaseContext | None, loop: asyncio.AbstractEventLoop | None = None) -> None:
|
||||||
# TODO: check how to update loop usage best
|
|
||||||
# self._loop = loop if loop is not None else asyncio.get_event_loop_policy().get_event_loop()
|
|
||||||
if loop is None:
|
|
||||||
try:
|
try:
|
||||||
self._loop = asyncio.get_running_loop()
|
self._loop = loop if loop is not None else asyncio.get_running_loop()
|
||||||
except RuntimeError:
|
except RuntimeError:
|
||||||
self._loop = asyncio.new_event_loop()
|
self._loop = asyncio.new_event_loop()
|
||||||
# asyncio.set_event_loop(self._loop)
|
asyncio.set_event_loop(self._loop)
|
||||||
else:
|
|
||||||
self._loop = loop
|
|
||||||
|
|
||||||
self.logger = logging.getLogger(namespace)
|
self.logger = logging.getLogger(namespace)
|
||||||
self.context = context if context is not None else BaseContext()
|
self.context = context if context is not None else BaseContext()
|
||||||
|
|
|
@ -158,7 +158,6 @@ timeout = 10
|
||||||
|
|
||||||
# ------------------------------------ MYPY ------------------------------------
|
# ------------------------------------ MYPY ------------------------------------
|
||||||
[tool.mypy]
|
[tool.mypy]
|
||||||
# mypy_path = "amqtt"
|
|
||||||
exclude = ["^tests/.*", "^docs/.*", "^samples/.*"]
|
exclude = ["^tests/.*", "^docs/.*", "^samples/.*"]
|
||||||
follow_imports = "silent"
|
follow_imports = "silent"
|
||||||
show_error_codes = true
|
show_error_codes = true
|
||||||
|
@ -235,10 +234,11 @@ never-returning-functions = ["sys.exit", "argparse.parse_error"]
|
||||||
|
|
||||||
[tool.pylint.DESIGN]
|
[tool.pylint.DESIGN]
|
||||||
max-branches = 20 # too-many-branches
|
max-branches = 20 # too-many-branches
|
||||||
max-parents = 10
|
max-parents = 10 # too-many-parents
|
||||||
max-positional-arguments = 10 # too-many-positional-arguments
|
max-positional-arguments = 10 # too-many-positional-arguments
|
||||||
max-returns = 7
|
max-returns = 7 # too-many-returns
|
||||||
max-statements = 61 # too-many-statements
|
max-statements = 61 # too-many-statements
|
||||||
|
max-module-lines = 1500 # too-many-lines
|
||||||
|
|
||||||
# ---------------------------------- COVERAGE ----------------------------------
|
# ---------------------------------- COVERAGE ----------------------------------
|
||||||
[tool.coverage.run]
|
[tool.coverage.run]
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import socket
|
import socket
|
||||||
from unittest.mock import MagicMock, call
|
from unittest.mock import MagicMock, call, patch
|
||||||
|
|
||||||
import psutil
|
import psutil
|
||||||
import pytest
|
import pytest
|
||||||
|
@ -25,6 +25,7 @@ from amqtt.mqtt.connack import ConnackPacket
|
||||||
from amqtt.mqtt.connect import ConnectPacket, ConnectPayload, ConnectVariableHeader
|
from amqtt.mqtt.connect import ConnectPacket, ConnectPayload, ConnectVariableHeader
|
||||||
from amqtt.mqtt.constants import QOS_0, QOS_1, QOS_2
|
from amqtt.mqtt.constants import QOS_0, QOS_1, QOS_2
|
||||||
from amqtt.mqtt.disconnect import DisconnectPacket
|
from amqtt.mqtt.disconnect import DisconnectPacket
|
||||||
|
from amqtt.mqtt.protocol.broker_handler import BrokerProtocolHandler
|
||||||
from amqtt.mqtt.pubcomp import PubcompPacket
|
from amqtt.mqtt.pubcomp import PubcompPacket
|
||||||
from amqtt.mqtt.publish import PublishPacket
|
from amqtt.mqtt.publish import PublishPacket
|
||||||
from amqtt.mqtt.pubrec import PubrecPacket
|
from amqtt.mqtt.pubrec import PubrecPacket
|
||||||
|
@ -677,7 +678,7 @@ def test_matches_multi_level_wildcard(broker):
|
||||||
"sport/tennis",
|
"sport/tennis",
|
||||||
"sport/tennis/",
|
"sport/tennis/",
|
||||||
]:
|
]:
|
||||||
assert not broker.matches(bad_topic, test_filter)
|
assert not broker._matches(bad_topic, test_filter)
|
||||||
|
|
||||||
for good_topic in [
|
for good_topic in [
|
||||||
"sport/tennis/player1",
|
"sport/tennis/player1",
|
||||||
|
@ -685,7 +686,7 @@ def test_matches_multi_level_wildcard(broker):
|
||||||
"sport/tennis/player1/ranking",
|
"sport/tennis/player1/ranking",
|
||||||
"sport/tennis/player1/score/wimbledon",
|
"sport/tennis/player1/score/wimbledon",
|
||||||
]:
|
]:
|
||||||
assert broker.matches(good_topic, test_filter)
|
assert broker._matches(good_topic, test_filter)
|
||||||
|
|
||||||
|
|
||||||
def test_matches_single_level_wildcard(broker):
|
def test_matches_single_level_wildcard(broker):
|
||||||
|
@ -696,37 +697,36 @@ def test_matches_single_level_wildcard(broker):
|
||||||
"sport/tennis/player1/",
|
"sport/tennis/player1/",
|
||||||
"sport/tennis/player1/ranking",
|
"sport/tennis/player1/ranking",
|
||||||
]:
|
]:
|
||||||
assert not broker.matches(bad_topic, test_filter)
|
assert not broker._matches(bad_topic, test_filter)
|
||||||
|
|
||||||
for good_topic in [
|
for good_topic in [
|
||||||
"sport/tennis/",
|
"sport/tennis/",
|
||||||
"sport/tennis/player1",
|
"sport/tennis/player1",
|
||||||
"sport/tennis/player2",
|
"sport/tennis/player2",
|
||||||
]:
|
]:
|
||||||
assert broker.matches(good_topic, test_filter)
|
assert broker._matches(good_topic, test_filter)
|
||||||
|
|
||||||
|
|
||||||
# @pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
# async def test_broker_broadcast_cancellation(broker):
|
async def test_broker_broadcast_cancellation(broker):
|
||||||
# topic = "test"
|
topic = "test"
|
||||||
# data = b"data"
|
data = b"data"
|
||||||
# qos = QOS_0
|
qos = QOS_0
|
||||||
|
|
||||||
# sub_client = MQTTClient()
|
sub_client = MQTTClient()
|
||||||
# await sub_client.connect("mqtt://127.0.0.1")
|
await sub_client.connect("mqtt://127.0.0.1")
|
||||||
# await sub_client.subscribe([(topic, qos)])
|
await sub_client.subscribe([(topic, qos)])
|
||||||
|
|
||||||
# with patch.object(BrokerProtocolHandler, "mqtt_publish", side_effect=asyncio.CancelledError) as mocked_mqtt_publish:
|
with patch.object(BrokerProtocolHandler, "mqtt_publish", side_effect=asyncio.CancelledError) as mocked_mqtt_publish:
|
||||||
# await _client_publish(topic, data, qos)
|
await _client_publish(topic, data, qos)
|
||||||
|
|
||||||
# # Second publish triggers the awaiting of first `mqtt_publish` task
|
# Second publish triggers the awaiting of first `mqtt_publish` task
|
||||||
# await _client_publish(topic, data, qos)
|
await _client_publish(topic, data, qos)
|
||||||
# await asyncio.sleep(0.01)
|
await asyncio.sleep(0.01)
|
||||||
|
|
||||||
# # `assert_awaited` does not exist in Python before `3.8`
|
mocked_mqtt_publish.assert_awaited()
|
||||||
# mocked_mqtt_publish.assert_awaited()
|
|
||||||
|
|
||||||
# # Ensure broadcast loop is still functional and can deliver the message
|
# Ensure broadcast loop is still functional and can deliver the message
|
||||||
# await _client_publish(topic, data, qos)
|
await _client_publish(topic, data, qos)
|
||||||
# message = await asyncio.wait_for(sub_client.deliver_message(), timeout_duration=1)
|
message = await asyncio.wait_for(sub_client.deliver_message(), timeout=1)
|
||||||
# assert message
|
assert message
|
||||||
|
|
Ładowanie…
Reference in New Issue