refactor: modularize broker listener startup and improve session management

- Extracted listener startup logic into a separate method `_start_listeners` for better readability and maintainability.
- Created helper methods for SSL context creation.
- Improved session cleanup during shutdown by centralizing logic in `_cleanup_session`.
- Refactored message handling and subscription management into dedicated methods for clarity.
- Updated tests to reflect changes in method visibility for matching topics.
pull/165/head
MVladislav 2025-04-06 19:03:30 +02:00
rodzic 3e0902cc8b
commit 12d5cb6866
7 zmienionych plików z 439 dodań i 417 usunięć

Wyświetl plik

@ -242,72 +242,10 @@ class Broker:
await self.plugins_manager.fire_event(EVENT_BROKER_PRE_START) await self.plugins_manager.fire_event(EVENT_BROKER_PRE_START)
try: try:
# Start network listeners await self._start_listeners()
for listener_name, listener in self.listeners_config.items():
if "bind" not in listener:
self.logger.debug(f"Listener configuration '{listener_name}' is not bound")
continue
max_connections = listener.get("max_connections", -1)
ssl_conection = None
ssl_active = listener.get("ssl", False) # accept string "on" / "off" or boolean
if isinstance(ssl_active, str):
ssl_active = ssl_active.upper() == "ON"
if ssl_active:
try:
ssl_conection = ssl.create_default_context(
ssl.Purpose.CLIENT_AUTH,
cafile=listener.get("cafile"),
capath=listener.get("capath"),
cadata=listener.get("cadata"),
)
ssl_conection.load_cert_chain(listener["certfile"], listener["keyfile"])
ssl_conection.verify_mode = ssl.CERT_OPTIONAL
except KeyError as ke:
msg = f"'certfile' or 'keyfile' configuration parameter missing: {ke}"
raise BrokerError(msg) from ke
except FileNotFoundError as fnfe:
msg = f"Can't read cert files '{listener['certfile']}' or '{listener['keyfile']}' : {fnfe}"
raise BrokerError(msg) from fnfe
try:
address, port = self._split_bindaddr_port(listener["bind"], DEFAULT_PORTS[listener["type"]])
except ValueError as e:
msg = f"Invalid port value in bind value: {listener['bind']}"
raise BrokerError(msg) from e
instance: asyncio.Server | websockets.asyncio.server.Server | None = None
if listener["type"] == "tcp":
cb_partial = partial(self.stream_connected, listener_name=listener_name)
instance = await asyncio.start_server(
cb_partial,
address,
port,
reuse_address=True,
ssl=ssl_conection,
)
self._servers[listener_name] = Server(listener_name, instance, max_connections)
elif listener["type"] == "ws":
cb_partial = partial(self.ws_connected, listener_name=listener_name)
instance = await websockets.serve(
cb_partial,
address,
port,
ssl=ssl_conection,
subprotocols=[websockets.Subprotocol("mqtt")],
)
self._servers[listener_name] = Server(listener_name, instance, max_connections)
self.logger.info(f"Listener '{listener_name}' bind to {listener['bind']} (max_connections={max_connections})")
self.transitions.starting_success() self.transitions.starting_success()
await self.plugins_manager.fire_event(EVENT_BROKER_POST_START) await self.plugins_manager.fire_event(EVENT_BROKER_POST_START)
# Start broadcast loop
self._broadcast_task = asyncio.ensure_future(self._broadcast_loop()) self._broadcast_task = asyncio.ensure_future(self._broadcast_loop())
self.logger.debug("Broker started") self.logger.debug("Broker started")
except Exception as e: except Exception as e:
self.logger.exception("Broker startup failed") self.logger.exception("Broker startup failed")
@ -315,47 +253,127 @@ class Broker:
msg = f"Broker instance can't be started: {e}" msg = f"Broker instance can't be started: {e}"
raise BrokerError(msg) from e raise BrokerError(msg) from e
async def _start_listeners(self) -> None:
"""Start network listeners based on the configuration."""
for listener_name, listener in self.listeners_config.items():
if "bind" not in listener:
self.logger.debug(f"Listener configuration '{listener_name}' is not bound")
continue
max_connections = listener.get("max_connections", -1)
ssl_context = self._create_ssl_context(listener) if listener.get("ssl", False) else None
try:
address, port = self._split_bindaddr_port(listener["bind"], DEFAULT_PORTS[listener["type"]])
except ValueError as e:
msg = f"Invalid port value in bind value: {listener['bind']}"
raise BrokerError(msg) from e
instance = await self._create_server_instance(listener_name, listener["type"], address, port, ssl_context)
self._servers[listener_name] = Server(listener_name, instance, max_connections)
self.logger.info(f"Listener '{listener_name}' bind to {listener['bind']} (max_connections={max_connections})")
@classmethod
def _split_bindaddr_port(cls, port_str: str, default_port: int) -> tuple[str | None, int]:
"""Split an address:port pair into separate IP address and port. with IPv6 special-case handling.
- Address can be specified using one of the following methods:
- empty string - all interfaces default port
- 1883 - Port number only (listen all interfaces)
- :1883 - Port number only (listen all interfaces)
- 0.0.0.0:1883 - IPv4 address
- [::]:1883 - IPv6 address
"""
def _parse_port(port_str: str) -> int:
port_str = port_str.removeprefix(":")
if not port_str:
return default_port
return int(port_str)
if port_str.startswith("["): # IPv6 literal
try:
addr_end = port_str.index("]")
except ValueError as e:
msg = "Expecting '[' to be followed by ']'"
raise ValueError(msg) from e
return (port_str[0 : addr_end + 1], _parse_port(port_str[addr_end + 1 :]))
if ":" in port_str:
address, port_str = port_str.rsplit(":", 1)
return (address or None, _parse_port(port_str))
try:
return (None, _parse_port(port_str))
except ValueError:
return (port_str, default_port)
def _create_ssl_context(self, listener: dict[str, Any]) -> ssl.SSLContext:
"""Create an SSL context for a listener."""
try:
ssl_context = ssl.create_default_context(
ssl.Purpose.CLIENT_AUTH,
cafile=listener.get("cafile"),
capath=listener.get("capath"),
cadata=listener.get("cadata"),
)
ssl_context.load_cert_chain(listener["certfile"], listener["keyfile"])
ssl_context.verify_mode = ssl.CERT_OPTIONAL
return ssl_context
except KeyError as ke:
msg = f"'certfile' or 'keyfile' configuration parameter missing: {ke}"
raise BrokerError(msg) from ke
except FileNotFoundError as fnfe:
msg = f"Can't read cert files '{listener['certfile']}' or '{listener['keyfile']}' : {fnfe}"
raise BrokerError(msg) from fnfe
async def _create_server_instance(
self,
listener_name: str,
listener_type: str,
address: str | None,
port: int,
ssl_context: ssl.SSLContext | None,
) -> asyncio.Server | websockets.asyncio.server.Server:
"""Create a server instance for a listener."""
if listener_type == "tcp":
return await asyncio.start_server(
partial(self.stream_connected, listener_name=listener_name),
address,
port,
reuse_address=True,
ssl=ssl_context,
)
if listener_type == "ws":
return await websockets.serve(
partial(self.ws_connected, listener_name=listener_name),
address,
port,
ssl=ssl_context,
subprotocols=[websockets.Subprotocol("mqtt")],
)
msg = f"Unsupported listener type: {listener_type}"
raise BrokerError(msg)
async def shutdown(self) -> None: async def shutdown(self) -> None:
"""Stop broker instance.""" """Stop broker instance."""
try: self.logger.info("Shutting down broker...")
# # Wait for all in-flight tasks to complete before stopping session handlers # Fire broker_shutdown event to plugins
# for client_id, (session, handler) in self._sessions.items(): await self.plugins_manager.fire_event(EVENT_BROKER_PRE_SHUTDOWN)
# if handler:
# self.logger.debug(f"Waiting for in-flight tasks to complete for session {client_id}")
# if session.inflight_out:
# # Directly use asyncio.sleep or another async operation in the loop
# await asyncio.gather(
# *(asyncio.sleep(0) for _ in session.inflight_out.values()),
# return_exceptions=True,
# )
# Stop all session handlers # Cleanup all sessions
for client_id, (_, handler) in self._sessions.items(): for client_id in list(self._sessions.keys()):
if handler: await self._cleanup_session(client_id)
self.logger.debug(f"Stopping handler for session {client_id}")
await self._stop_handler(handler)
# Clear subscriptions
for topic, subscriptions in self._subscriptions.items():
self.logger.debug(f"Clearing subscriptions for topic '{topic}'")
for session, _ in subscriptions:
self._del_subscription(topic, session)
self._subscriptions.clear()
# Clear retained messages # Clear retained messages
if self._retained_messages:
self.logger.debug(f"Clearing {len(self._retained_messages)} retained messages") self.logger.debug(f"Clearing {len(self._retained_messages)} retained messages")
self._retained_messages.clear() self._retained_messages.clear()
self._sessions.clear()
self.transitions.shutdown() self.transitions.shutdown()
except (MachineError, ValueError) as exc:
# Backwards compat: MachineError is raised by transitions < 0.5.0.
self.logger.debug(f"Invalid method call at this moment: {exc}")
raise
# Fire broker_shutdown event to plugins
await self.plugins_manager.fire_event(EVENT_BROKER_PRE_SHUTDOWN)
await self._shutdown_broadcast_loop() await self._shutdown_broadcast_loop()
@ -372,59 +390,89 @@ class Broker:
await self.plugins_manager.fire_event(EVENT_BROKER_POST_SHUTDOWN) await self.plugins_manager.fire_event(EVENT_BROKER_POST_SHUTDOWN)
self.transitions.stopping_success() self.transitions.stopping_success()
async def _cleanup_session(self, client_id: str) -> None:
"""Centralized cleanup logic for a session."""
session, handler = self._sessions.pop(client_id, (None, None))
if handler:
self.logger.debug(f"Stopping handler for session {client_id}")
await self._stop_handler(handler)
if session:
self.logger.debug(f"Clearing all subscriptions for session {client_id}")
await self._del_all_subscriptions(session)
session.clear_queues()
async def internal_message_broadcast(self, topic: str, data: bytes, qos: int | None = None) -> None: async def internal_message_broadcast(self, topic: str, data: bytes, qos: int | None = None) -> None:
return await self._broadcast_message(None, topic, data, qos) return await self._broadcast_message(None, topic, data, qos)
async def ws_connected(self, websocket: ServerConnection, listener_name: str) -> None: async def ws_connected(self, websocket: ServerConnection, listener_name: str) -> None:
await self.client_connected(listener_name, WebSocketsReader(websocket), WebSocketsWriter(websocket)) await self._client_connected(listener_name, WebSocketsReader(websocket), WebSocketsWriter(websocket))
async def stream_connected(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter, listener_name: str) -> None: async def stream_connected(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter, listener_name: str) -> None:
await self.client_connected(listener_name, StreamReaderAdapter(reader), StreamWriterAdapter(writer)) await self._client_connected(listener_name, StreamReaderAdapter(reader), StreamWriterAdapter(writer))
async def client_connected(self, listener_name: str, reader: ReaderAdapter, writer: WriterAdapter) -> None: async def _client_connected(self, listener_name: str, reader: ReaderAdapter, writer: WriterAdapter) -> None:
# Wait for connection available on listener """Handle a new client connection."""
server = self._servers.get(listener_name, None) server = self._servers.get(listener_name)
if not server: if not server:
msg = f"Invalid listener name '{listener_name}'" msg = f"Invalid listener name '{listener_name}'"
raise BrokerError(msg) raise BrokerError(msg)
await server.acquire_connection()
await server.acquire_connection()
remote_info = writer.get_peer_info() remote_info = writer.get_peer_info()
if remote_info is None: if remote_info is None:
self.logger.warning("remote info could not get from peer info") self.logger.warning("Remote info could not be retrieved from peer info")
return return
remote_address, remote_port = remote_info remote_address, remote_port = remote_info
self.logger.info(f"Connection from {remote_address}:{remote_port} on listener '{listener_name}'") self.logger.info(f"Connection from {remote_address}:{remote_port} on listener '{listener_name}'")
try:
handler, client_session = await self._initialize_client_session(reader, writer, remote_address, remote_port)
except (AMQTTError, MQTTError, NoDataError) as exc:
self.logger.warning(f"Failed to initialize client session: {exc}")
server.release_connection()
return
try:
await self._handle_client_session(reader, writer, client_session, handler, server, listener_name)
except (AMQTTError, MQTTError, NoDataError) as exc:
self.logger.warning(f"Error while handling client session: {exc}")
finally:
self.logger.debug(f"{client_session.client_id} Client disconnected")
server.release_connection()
async def _initialize_client_session(
self,
reader: ReaderAdapter,
writer: WriterAdapter,
remote_address: str,
remote_port: int,
) -> tuple[BrokerProtocolHandler, Session]:
"""Initialize a client session and protocol handler."""
# Wait for first packet and expect a CONNECT # Wait for first packet and expect a CONNECT
try: try:
handler, client_session = await BrokerProtocolHandler.init_from_connect(reader, writer, self.plugins_manager) handler, client_session = await BrokerProtocolHandler.init_from_connect(reader, writer, self.plugins_manager)
except AMQTTError as exc: except AMQTTError as exc:
self.logger.warning( self.logger.warning(
f"[MQTT-3.1.0-1] {format_client_message(address=remote_address, port=remote_port)}:" f"[MQTT-3.1.0-1] {format_client_message(address=remote_address, port=remote_port)}:"
f"Can't read first packet an CONNECT: {exc}", f" Can't read first packet as CONNECT: {exc}",
) )
self.logger.debug("Connection closed") raise AMQTTError(exc) from exc
server.release_connection() except MQTTError as exc:
return
except MQTTError:
self.logger.exception( self.logger.exception(
f"Invalid connection from {format_client_message(address=remote_address, port=remote_port)}", f"Invalid connection from {format_client_message(address=remote_address, port=remote_port)}",
) )
await writer.close() await writer.close()
server.release_connection() raise MQTTError(exc) from exc
self.logger.debug("Connection closed") except NoDataError as exc:
return self.logger.error(f"No data from {format_client_message(address=remote_address, port=remote_port)} : {exc}") # noqa: TRY400 # cannot replace with exception else pytest fails
except NoDataError as ne: raise NoDataError(exc) from exc
self.logger.error(f"No data from {format_client_message(address=remote_address, port=remote_port)} : {ne}") # noqa: TRY400 # cannot replace with exception else test fails
server.release_connection()
return
if client_session.clean_session: if client_session.clean_session:
# Delete existing session and create a new one # Delete existing session and create a new one
if client_session.client_id is not None and client_session.client_id != "": if client_session.client_id is not None and client_session.client_id != "":
self.delete_session(client_session.client_id) await self._delete_session(client_session.client_id)
else: else:
client_session.client_id = gen_client_id() client_session.client_id = gen_client_id()
client_session.parent = 0 client_session.parent = 0
@ -436,21 +484,32 @@ class Broker:
else: else:
client_session.parent = 0 client_session.parent = 0
timeout_disconnect_delay = self.config.get("timeout-disconnect-delay", 0)
if client_session.keep_alive > 0 and isinstance(timeout_disconnect_delay, int):
client_session.keep_alive += timeout_disconnect_delay
self.logger.debug(f"Keep-alive timeout={client_session.keep_alive}")
return handler, client_session
async def _handle_client_session(
self,
reader: ReaderAdapter,
writer: WriterAdapter,
client_session: Session,
handler: BrokerProtocolHandler,
server: Server,
listener_name: str,
) -> None:
"""Handle the lifecycle of a client session."""
authenticated = await self._authenticate(client_session, self.listeners_config[listener_name])
if not authenticated:
await writer.close()
return
if client_session.client_id is None: if client_session.client_id is None:
msg = "Client ID was not correctly created/set." msg = "Client ID was not correctly created/set."
raise BrokerError(msg) raise BrokerError(msg)
timeout_disconnect_delay = self.config.get("timeout-disconnect-delay")
if client_session.keep_alive > 0 and isinstance(timeout_disconnect_delay, int):
client_session.keep_alive += timeout_disconnect_delay
self.logger.debug(f"Keep-alive timeout={client_session.keep_alive}")
authenticated = await self.authenticate(client_session, self.listeners_config[listener_name])
if not authenticated:
await writer.close()
server.release_connection() # Delete client from connections list
return
while True: while True:
try: try:
client_session.transitions.connect() client_session.transitions.connect()
@ -469,14 +528,17 @@ class Broker:
self._sessions[client_session.client_id] = (client_session, handler) self._sessions[client_session.client_id] = (client_session, handler)
await handler.mqtt_connack_authorize(authenticated) await handler.mqtt_connack_authorize(authenticated)
await self.plugins_manager.fire_event(EVENT_BROKER_CLIENT_CONNECTED, client_id=client_session.client_id) await self.plugins_manager.fire_event(EVENT_BROKER_CLIENT_CONNECTED, client_id=client_session.client_id)
self.logger.debug(f"{client_session.client_id} Start messages handling") self.logger.debug(f"{client_session.client_id} Start messages handling")
await handler.start() await handler.start()
self.logger.debug(f"Retained messages queue size: {client_session.retained_messages.qsize()}") self.logger.debug(f"Retained messages queue size: {client_session.retained_messages.qsize()}")
await self.publish_session_retained_messages(client_session) await self._publish_session_retained_messages(client_session)
await self._client_message_loop(client_session, handler)
async def _client_message_loop(self, client_session: Session, handler: BrokerProtocolHandler) -> None:
"""Run the main loop to handle client messages."""
# Init and start loop for handling client messages (publish, subscribe/unsubscribe, disconnect) # Init and start loop for handling client messages (publish, subscribe/unsubscribe, disconnect)
disconnect_waiter = asyncio.ensure_future(handler.wait_disconnect()) disconnect_waiter = asyncio.ensure_future(handler.wait_disconnect())
subscribe_waiter = asyncio.ensure_future(handler.get_next_pending_subscription()) subscribe_waiter = asyncio.ensure_future(handler.get_next_pending_subscription())
@ -495,7 +557,41 @@ class Broker:
], ],
return_when=asyncio.FIRST_COMPLETED, return_when=asyncio.FIRST_COMPLETED,
) )
if disconnect_waiter in done: if disconnect_waiter in done:
connected = await self._handle_disconnect(client_session, handler, disconnect_waiter)
disconnect_waiter = asyncio.ensure_future(handler.wait_disconnect())
if subscribe_waiter in done:
await self._handle_subscription(client_session, handler, subscribe_waiter)
subscribe_waiter = asyncio.ensure_future(handler.get_next_pending_subscription())
self.logger.debug(repr(self._subscriptions))
if unsubscribe_waiter in done:
await self._handle_unsubscription(client_session, handler, unsubscribe_waiter)
unsubscribe_waiter = asyncio.ensure_future(handler.get_next_pending_unsubscription())
if wait_deliver in done:
if not await self._handle_message_delivery(client_session, handler, wait_deliver):
break
wait_deliver = asyncio.ensure_future(handler.mqtt_deliver_next_message())
except asyncio.CancelledError:
self.logger.debug("Client loop cancelled")
break
disconnect_waiter.cancel()
subscribe_waiter.cancel()
unsubscribe_waiter.cancel()
wait_deliver.cancel()
async def _handle_disconnect(
self,
client_session: Session,
handler: BrokerProtocolHandler,
disconnect_waiter: asyncio.Future[Any],
) -> bool:
"""Handle client disconnection."""
result = disconnect_waiter.result() result = disconnect_waiter.result()
self.logger.debug(f"{client_session.client_id} Result from wait_disconnect: {result}") self.logger.debug(f"{client_session.client_id} Result from wait_disconnect: {result}")
if result is None: if result is None:
@ -520,14 +616,38 @@ class Broker:
self.logger.debug(f"{client_session.client_id} Disconnecting session") self.logger.debug(f"{client_session.client_id} Disconnecting session")
await self._stop_handler(handler) await self._stop_handler(handler)
client_session.transitions.disconnect() client_session.transitions.disconnect()
await self.plugins_manager.fire_event(EVENT_BROKER_CLIENT_DISCONNECTED, client_id=client_session.client_id)
return False
return True
async def _handle_subscription(
self,
client_session: Session,
handler: BrokerProtocolHandler,
subscribe_waiter: asyncio.Future[Any],
) -> None:
"""Handle client subscription."""
self.logger.debug(f"{client_session.client_id} handling subscription")
subscriptions = subscribe_waiter.result()
return_codes = [await self._add_subscription(subscription, client_session) for subscription in subscriptions.topics]
await handler.mqtt_acknowledge_subscription(subscriptions.packet_id, return_codes)
for index, subscription in enumerate(subscriptions.topics):
if return_codes[index] != AMQTT_MAGIC_VALUE_RET_SUBSCRIBED:
await self.plugins_manager.fire_event( await self.plugins_manager.fire_event(
EVENT_BROKER_CLIENT_DISCONNECTED, EVENT_BROKER_CLIENT_SUBSCRIBED,
client_id=client_session.client_id, client_id=client_session.client_id,
topic=subscription[0],
qos=subscription[1],
) )
connected = False await self._publish_retained_messages_for_subscription(subscription, client_session)
# Recreate the disconnect_waiter task after processing
disconnect_waiter = asyncio.ensure_future(handler.wait_disconnect()) async def _handle_unsubscription(
if unsubscribe_waiter in done: self,
client_session: Session,
handler: BrokerProtocolHandler,
unsubscribe_waiter: asyncio.Future[Any],
) -> None:
"""Handle client unsubscription."""
self.logger.debug(f"{client_session.client_id} handling unsubscription") self.logger.debug(f"{client_session.client_id} handling unsubscription")
unsubscription = unsubscribe_waiter.result() unsubscription = unsubscribe_waiter.result()
for topic in unsubscription.topics: for topic in unsubscription.topics:
@ -538,55 +658,34 @@ class Broker:
topic=topic, topic=topic,
) )
await handler.mqtt_acknowledge_unsubscription(unsubscription.packet_id) await handler.mqtt_acknowledge_unsubscription(unsubscription.packet_id)
# Recreate the unsubscribe_waiter task
unsubscribe_waiter = asyncio.ensure_future(handler.get_next_pending_unsubscription())
if subscribe_waiter in done:
self.logger.debug(f"{client_session.client_id} handling subscription")
subscriptions = subscribe_waiter.result()
return_codes = [
await self.add_subscription(subscription, client_session) for subscription in subscriptions.topics
]
await handler.mqtt_acknowledge_subscription(subscriptions.packet_id, return_codes) async def _handle_message_delivery(
for index, subscription in enumerate(subscriptions.topics): self,
if return_codes[index] != AMQTT_MAGIC_VALUE_RET_SUBSCRIBED: client_session: Session,
await self.plugins_manager.fire_event( handler: BrokerProtocolHandler,
EVENT_BROKER_CLIENT_SUBSCRIBED, wait_deliver: asyncio.Future[Any],
client_id=client_session.client_id, ) -> bool:
topic=subscription[0], """Handle message delivery to the client."""
qos=subscription[1],
)
await self.publish_retained_messages_for_subscription(subscription, client_session)
# Recreate the subscribe_waiter task
subscribe_waiter = asyncio.ensure_future(handler.get_next_pending_subscription())
self.logger.debug(repr(self._subscriptions))
if wait_deliver in done:
if self.logger.isEnabledFor(logging.DEBUG):
self.logger.debug(f"{client_session.client_id} handling message delivery") self.logger.debug(f"{client_session.client_id} handling message delivery")
app_message = wait_deliver.result() app_message = wait_deliver.result()
if app_message is None: if app_message is None:
self.logger.debug("app_message was empty!") self.logger.debug("app_message was empty!")
continue return True
if not app_message.topic: if not app_message.topic:
self.logger.warning( self.logger.warning(
f"[MQTT-4.7.3-1] - {client_session.client_id}" f"[MQTT-4.7.3-1] - {client_session.client_id} invalid TOPIC sent in PUBLISH message, closing connection",
" invalid TOPIC sent in PUBLISH message, closing connection",
) )
break return False
if "#" in app_message.topic or "+" in app_message.topic: if "#" in app_message.topic or "+" in app_message.topic:
self.logger.warning( self.logger.warning(
f"[MQTT-3.3.2-2] - {client_session.client_id}" f"[MQTT-3.3.2-2] - {client_session.client_id} invalid TOPIC sent in PUBLISH message, closing connection",
" invalid TOPIC sent in PUBLISH message, closing connection",
) )
break return False
permitted = await self.topic_filtering(client_session, topic=app_message.topic, action=Action.PUBLISH) permitted = await self._topic_filtering(client_session, topic=app_message.topic, action=Action.PUBLISH)
if not permitted: if not permitted:
self.logger.info( self.logger.info(f"{client_session.client_id} forbidden TOPIC {app_message.topic} sent in PUBLISH message.")
f"{client_session.client_id} forbidden TOPIC {app_message.topic} sent in PUBLISH message.",
)
else: else:
await self.plugins_manager.fire_event( await self.plugins_manager.fire_event(
EVENT_BROKER_MESSAGE_RECEIVED, EVENT_BROKER_MESSAGE_RECEIVED,
@ -594,42 +693,24 @@ class Broker:
message=app_message, message=app_message,
) )
await self._broadcast_message(client_session, app_message.topic, app_message.data) await self._broadcast_message(client_session, app_message.topic, app_message.data)
if app_message.publish_packet is not None and app_message.publish_packet.retain_flag: if app_message.publish_packet and app_message.publish_packet.retain_flag:
self.retain_message(client_session, app_message.topic, app_message.data, app_message.qos) self.retain_message(client_session, app_message.topic, app_message.data, app_message.qos)
# Recreate the wait_deliver task return True
wait_deliver = asyncio.ensure_future(handler.mqtt_deliver_next_message())
except asyncio.CancelledError:
self.logger.debug("Client loop cancelled")
break
disconnect_waiter.cancel()
subscribe_waiter.cancel()
unsubscribe_waiter.cancel()
wait_deliver.cancel()
self.logger.debug(f"{client_session.client_id} Client disconnected") # async def _init_handler(self, session: Session, reader: ReaderAdapter, writer: WriterAdapter) -> BrokerProtocolHandler:
server.release_connection() # """Create a BrokerProtocolHandler and attach to a session."""
# handler = BrokerProtocolHandler(self.plugins_manager, loop=self._loop)
async def _init_handler(self, session: Session, reader: ReaderAdapter, writer: WriterAdapter) -> BrokerProtocolHandler: # handler.attach(session, reader, writer)
"""Create a BrokerProtocolHandler and attach to a session. # return handler
:return:
"""
handler = BrokerProtocolHandler(self.plugins_manager, loop=self._loop)
handler.attach(session, reader, writer)
return handler
async def _stop_handler(self, handler: BrokerProtocolHandler) -> None: async def _stop_handler(self, handler: BrokerProtocolHandler) -> None:
"""Stop a running handler and detach if from the session. """Stop a running handler and detach if from the session."""
:param handler:
:return:
"""
try: try:
await handler.stop() await handler.stop()
except Exception: except Exception:
self.logger.exception("Failed to stop handler") self.logger.exception("Failed to stop handler")
async def authenticate(self, session: Session, _: dict[str, Any]) -> bool: async def _authenticate(self, session: Session, _: dict[str, Any]) -> bool:
"""Call the authenticate method on registered plugins to test user authentication. """Call the authenticate method on registered plugins to test user authentication.
User is considered authenticated if all plugins called returns True. User is considered authenticated if all plugins called returns True.
@ -658,7 +739,49 @@ class Broker:
# If all plugins returned True, authentication is success # If all plugins returned True, authentication is success
return auth_result return auth_result
async def topic_filtering(self, session: Session, topic: str, action: Action) -> bool: def retain_message(
self,
source_session: Session | None,
topic_name: str | None,
data: bytes | bytearray | None,
qos: int | None = None,
) -> None:
if data and topic_name is not None:
# If retained flag set, store the message for further subscriptions
self.logger.debug(f"Retaining message on topic {topic_name}")
self._retained_messages[topic_name] = RetainedApplicationMessage(source_session, topic_name, data, qos)
# [MQTT-3.3.1-10]
elif topic_name in self._retained_messages:
self.logger.debug(f"Clearing retained messages for topic '{topic_name}'")
del self._retained_messages[topic_name]
async def _add_subscription(self, subscription: tuple[str, int], session: Session) -> int:
topic_filter, qos = subscription
if "#" in topic_filter and not topic_filter.endswith("#"):
# [MQTT-4.7.1-2] Wildcard character '#' is only allowed as last character in filter
return 0x80
if topic_filter != "+" and "+" in topic_filter and ("/+" not in topic_filter and "+/" not in topic_filter):
# [MQTT-4.7.1-3] + wildcard character must occupy entire level
return 0x80
# Check if the client is authorised to connect to the topic
if not await self._topic_filtering(session, topic_filter, Action.SUBSCRIBE):
return 0x80
# Ensure "max-qos" is an integer before using it
max_qos = self.config.get("max-qos", qos)
if not isinstance(max_qos, int):
max_qos = qos
qos = min(qos, max_qos)
if topic_filter not in self._subscriptions:
self._subscriptions[topic_filter] = []
if all(s.client_id != session.client_id for s, _ in self._subscriptions[topic_filter]):
self._subscriptions[topic_filter].append((session, qos))
else:
self.logger.debug(f"Client {format_client_message(session=session)} has already subscribed to {topic_filter}")
return qos
async def _topic_filtering(self, session: Session, topic: str, action: Action) -> bool:
"""Call the topic_filtering method on registered plugins to check that the subscription is allowed. """Call the topic_filtering method on registered plugins to check that the subscription is allowed.
User is considered allowed if all plugins called return True. User is considered allowed if all plugins called return True.
@ -690,48 +813,29 @@ class Broker:
) )
return all(result for result in results.values()) return all(result for result in results.values())
def retain_message( async def _delete_session(self, client_id: str) -> None:
self, """Delete an existing session data, for example due to clean session set in CONNECT."""
source_session: Session | None, session = self._sessions.pop(client_id, (None, None))[0]
topic_name: str | None,
data: bytes | bytearray | None,
qos: int | None = None,
) -> None:
if data and topic_name is not None:
# If retained flag set, store the message for further subscriptions
self.logger.debug(f"Retaining message on topic {topic_name}")
self._retained_messages[topic_name] = RetainedApplicationMessage(source_session, topic_name, data, qos)
# [MQTT-3.3.1-10]
elif topic_name in self._retained_messages:
self.logger.debug(f"Clearing retained messages for topic '{topic_name}'")
del self._retained_messages[topic_name]
# NOTE: issue #61 remove try block if session is None:
async def add_subscription(self, subscription: tuple[str, int], session: Session) -> int: self.logger.debug(f"Delete session : session {client_id} doesn't exist")
topic_filter, qos = subscription return
if "#" in topic_filter and not topic_filter.endswith("#"): self.logger.debug(f"Deleted existing session {session!r}")
# [MQTT-4.7.1-2] Wildcard character '#' is only allowed as last character in filter
return 0x80
if topic_filter != "+" and "+" in topic_filter and ("/+" not in topic_filter and "+/" not in topic_filter):
# [MQTT-4.7.1-3] + wildcard character must occupy entire level
return 0x80
# Check if the client is authorised to connect to the topic
if not await self.topic_filtering(session, topic_filter, Action.SUBSCRIBE):
return 0x80
# Ensure "max-qos" is an integer before using it # Delete subscriptions
max_qos = self.config.get("max-qos", qos) self.logger.debug(f"Deleting session {session!r} subscriptions")
if not isinstance(max_qos, int): await self._del_all_subscriptions(session)
max_qos = qos session.clear_queues()
qos = min(qos, max_qos) async def _del_all_subscriptions(self, session: Session) -> None:
if topic_filter not in self._subscriptions: """Delete all topic subscriptions for a given session."""
self._subscriptions[topic_filter] = [] filter_queue: deque[str] = deque()
if all(s.client_id != session.client_id for s, _ in self._subscriptions[topic_filter]): for topic in self._subscriptions:
self._subscriptions[topic_filter].append((session, qos)) if self._del_subscription(topic, session):
else: filter_queue.append(topic)
self.logger.debug(f"Client {format_client_message(session=session)} has already subscribed to {topic_filter}") for topic in filter_queue:
return qos if not self._subscriptions[topic]:
del self._subscriptions[topic]
def _del_subscription(self, a_filter: str, session: Session) -> int: def _del_subscription(self, a_filter: str, session: Session) -> int:
"""Delete a session subscription on a given topic. """Delete a session subscription on a given topic.
@ -752,32 +856,9 @@ class Broker:
deleted += 1 deleted += 1
break break
except KeyError: except KeyError:
# Unsubscribe topic not found in current subscribed topics self.logger.debug(f"Unsubscription on topic '{a_filter}' for client {format_client_message(session=session)}")
pass
return deleted return deleted
def _del_all_subscriptions(self, session: Session) -> None:
"""Delete all topic subscriptions for a given session.
:param session:
:return:
"""
filter_queue: deque[str] = deque()
for topic in self._subscriptions:
if self._del_subscription(topic, session):
filter_queue.append(topic)
for topic in filter_queue:
if not self._subscriptions[topic]:
del self._subscriptions[topic]
def matches(self, topic: str, a_filter: str) -> bool:
if "#" not in a_filter and "+" not in a_filter:
# if filter doesn't contain wildcard, return exact match
return a_filter == topic
# else use regex
match_pattern = re.compile(re.escape(a_filter).replace("\\#", "?.*").replace("\\+", "[^/]*").lstrip("?"))
return bool(match_pattern.fullmatch(topic))
async def _broadcast_loop(self) -> None: async def _broadcast_loop(self) -> None:
"""Run the main loop to broadcast messages.""" """Run the main loop to broadcast messages."""
running_tasks: deque[asyncio.Task[OutgoingApplicationMessage]] = self._tasks_queue running_tasks: deque[asyncio.Task[OutgoingApplicationMessage]] = self._tasks_queue
@ -824,7 +905,7 @@ class Broker:
continue continue
# Skip all subscriptions which do not match the topic # Skip all subscriptions which do not match the topic
if not self.matches(broadcast["topic"], k_filter): if not self._matches(broadcast["topic"], k_filter):
self.logger.debug(f"Topic '{broadcast['topic']}' does not match filter '{k_filter}'") self.logger.debug(f"Topic '{broadcast['topic']}' does not match filter '{k_filter}'")
continue continue
@ -892,7 +973,7 @@ class Broker:
broadcast["qos"] = force_qos broadcast["qos"] = force_qos
await self._broadcast_queue.put(broadcast) await self._broadcast_queue.put(broadcast)
async def publish_session_retained_messages(self, session: Session) -> None: async def _publish_session_retained_messages(self, session: Session) -> None:
self.logger.debug( self.logger.debug(
f"Publishing {session.retained_messages.qsize()}" f"Publishing {session.retained_messages.qsize()}"
f" messages retained for session {format_client_message(session=session)}", f" messages retained for session {format_client_message(session=session)}",
@ -910,7 +991,7 @@ class Broker:
if publish_tasks: if publish_tasks:
await asyncio.wait(publish_tasks) await asyncio.wait(publish_tasks)
async def publish_retained_messages_for_subscription(self, subscription: tuple[str, int], session: Session) -> None: async def _publish_retained_messages_for_subscription(self, subscription: tuple[str, int], session: Session) -> None:
self.logger.debug( self.logger.debug(
f"Begin broadcasting messages retained due to subscription on '{subscription[0]}'" f"Begin broadcasting messages retained due to subscription on '{subscription[0]}'"
f" from {format_client_message(session=session)}", f" from {format_client_message(session=session)}",
@ -920,7 +1001,7 @@ class Broker:
topic_filter, qos = subscription topic_filter, qos = subscription
for topic, retained in self._retained_messages.items(): for topic, retained in self._retained_messages.items():
self.logger.debug(f"matching : {topic} {topic_filter}") self.logger.debug(f"matching : {topic} {topic_filter}")
if self.matches(topic, topic_filter): if self._matches(topic, topic_filter):
self.logger.debug(f"{topic} and {topic_filter} match") self.logger.debug(f"{topic} and {topic_filter} match")
handler = self._get_handler(session) handler = self._get_handler(session)
if handler: if handler:
@ -936,62 +1017,16 @@ class Broker:
f" from {format_client_message(session=session)}", f" from {format_client_message(session=session)}",
) )
def delete_session(self, client_id: str) -> None: def _matches(self, topic: str, a_filter: str) -> bool:
"""Delete an existing session data, for example due to clean session set in CONNECT.""" if "#" not in a_filter and "+" not in a_filter:
session = self._sessions.pop(client_id, (None, None))[0] # if filter doesn't contain wildcard, return exact match
return a_filter == topic
if session is None: # else use regex
self.logger.debug(f"Delete session : session {client_id} doesn't exist") match_pattern = re.compile(re.escape(a_filter).replace("\\#", "?.*").replace("\\+", "[^/]*").lstrip("?"))
return return bool(match_pattern.fullmatch(topic))
self.logger.debug(f"Deleted existing session {session!r}")
# Delete subscriptions
self.logger.debug(f"Deleting session {session!r} subscriptions")
self._del_all_subscriptions(session)
session.clear_queues()
def _get_handler(self, session: Session) -> BrokerProtocolHandler | None: def _get_handler(self, session: Session) -> BrokerProtocolHandler | None:
client_id = session.client_id client_id = session.client_id
if client_id: if client_id:
return self._sessions.get(client_id, (None, None))[1] return self._sessions.get(client_id, (None, None))[1]
return None return None
@classmethod
def _split_bindaddr_port(cls, port_str: str, default_port: int) -> tuple[str | None, int]:
"""Split an address:port pair into separate IP address and port. with IPv6 special-case handling.
NOTE: issue #72
- Address can be specified using one of the following methods:
- empty string - all interfaces default port
- 1883 - Port number only (listen all interfaces)
- :1883 - Port number only (listen all interfaces)
- 0.0.0.0:1883 - IPv4 address
- [::]:1883 - IPv6 address
"""
def _parse_port(port_str: str) -> int:
port_str = port_str.removeprefix(":")
if not port_str:
return default_port
return int(port_str)
if port_str.startswith("["): # IPv6 literal
try:
addr_end = port_str.index("]")
except ValueError as e:
msg = "Expecting '[' to be followed by ']'"
raise ValueError(msg) from e
return (port_str[0 : addr_end + 1], _parse_port(port_str[addr_end + 1 :]))
if ":" in port_str:
address, port_str = port_str.rsplit(":", 1)
return (address or None, _parse_port(port_str))
try:
return (None, _parse_port(port_str))
except ValueError:
return (port_str, default_port)

Wyświetl plik

@ -526,8 +526,7 @@ class MQTTClient:
while self.client_tasks: while self.client_tasks:
task = self.client_tasks.popleft() task = self.client_tasks.popleft()
if not task.done(): if not task.done():
# task.set_exception(ClientError("Connection lost")) task.cancel()
task.cancel() # NOTE: issue #153
self.logger.debug("Monitoring broker disconnection") self.logger.debug("Monitoring broker disconnection")
# Wait for disconnection from broker (like connection lost) # Wait for disconnection from broker (like connection lost)

Wyświetl plik

@ -195,7 +195,6 @@ class ClientProtocolHandler(ProtocolHandler):
self.logger.debug("Broker closed connection") self.logger.debug("Broker closed connection")
if self._disconnect_waiter is not None and not self._disconnect_waiter.done(): if self._disconnect_waiter is not None and not self._disconnect_waiter.done():
self._disconnect_waiter.set_result(None) self._disconnect_waiter.set_result(None)
# await self.stop() # NOTE: issue #119
async def wait_disconnect(self) -> None: async def wait_disconnect(self) -> None:
if self._disconnect_waiter is not None: if self._disconnect_waiter is not None:

Wyświetl plik

@ -67,15 +67,11 @@ class ProtocolHandler:
self.writer: WriterAdapter | None = None self.writer: WriterAdapter | None = None
self.plugins_manager: PluginManager = plugins_manager self.plugins_manager: PluginManager = plugins_manager
# TODO: check how to update loop usage best try:
self._loop = loop if loop is not None else asyncio.get_event_loop_policy().get_event_loop() self._loop = loop if loop is not None else asyncio.get_running_loop()
# try: except RuntimeError:
# # Use the currently running loop if available self._loop = asyncio.new_event_loop()
# self._loop = loop if loop is not None else asyncio.get_running_loop() asyncio.set_event_loop(self._loop)
# except RuntimeError:
# # If no running loop is found, create a new one
# self._loop = asyncio.new_event_loop()
# asyncio.set_event_loop(self._loop)
self._reader_task: asyncio.Task[None] | None = None self._reader_task: asyncio.Task[None] | None = None
self._keepalive_task: asyncio.TimerHandle | None = None self._keepalive_task: asyncio.TimerHandle | None = None

Wyświetl plik

@ -27,8 +27,6 @@ def get_plugin_manager(namespace: str) -> "PluginManager | None":
class BaseContext: class BaseContext:
def __init__(self) -> None: def __init__(self) -> None:
self.loop: asyncio.AbstractEventLoop | None = None self.loop: asyncio.AbstractEventLoop | None = None
# TODO: change this usage
# self.logger: logging.Logger | None = None
self.logger: logging.Logger = _LOGGER self.logger: logging.Logger = _LOGGER
self.config: dict[str, Any] | None = None self.config: dict[str, Any] | None = None
@ -41,16 +39,11 @@ class PluginManager:
""" """
def __init__(self, namespace: str, context: BaseContext | None, loop: asyncio.AbstractEventLoop | None = None) -> None: def __init__(self, namespace: str, context: BaseContext | None, loop: asyncio.AbstractEventLoop | None = None) -> None:
# TODO: check how to update loop usage best
# self._loop = loop if loop is not None else asyncio.get_event_loop_policy().get_event_loop()
if loop is None:
try: try:
self._loop = asyncio.get_running_loop() self._loop = loop if loop is not None else asyncio.get_running_loop()
except RuntimeError: except RuntimeError:
self._loop = asyncio.new_event_loop() self._loop = asyncio.new_event_loop()
# asyncio.set_event_loop(self._loop) asyncio.set_event_loop(self._loop)
else:
self._loop = loop
self.logger = logging.getLogger(namespace) self.logger = logging.getLogger(namespace)
self.context = context if context is not None else BaseContext() self.context = context if context is not None else BaseContext()

Wyświetl plik

@ -158,7 +158,6 @@ timeout = 10
# ------------------------------------ MYPY ------------------------------------ # ------------------------------------ MYPY ------------------------------------
[tool.mypy] [tool.mypy]
# mypy_path = "amqtt"
exclude = ["^tests/.*", "^docs/.*", "^samples/.*"] exclude = ["^tests/.*", "^docs/.*", "^samples/.*"]
follow_imports = "silent" follow_imports = "silent"
show_error_codes = true show_error_codes = true
@ -235,10 +234,11 @@ never-returning-functions = ["sys.exit", "argparse.parse_error"]
[tool.pylint.DESIGN] [tool.pylint.DESIGN]
max-branches = 20 # too-many-branches max-branches = 20 # too-many-branches
max-parents = 10 max-parents = 10 # too-many-parents
max-positional-arguments = 10 # too-many-positional-arguments max-positional-arguments = 10 # too-many-positional-arguments
max-returns = 7 max-returns = 7 # too-many-returns
max-statements = 61 # too-many-statements max-statements = 61 # too-many-statements
max-module-lines = 1500 # too-many-lines
# ---------------------------------- COVERAGE ---------------------------------- # ---------------------------------- COVERAGE ----------------------------------
[tool.coverage.run] [tool.coverage.run]

Wyświetl plik

@ -1,7 +1,7 @@
import asyncio import asyncio
import logging import logging
import socket import socket
from unittest.mock import MagicMock, call from unittest.mock import MagicMock, call, patch
import psutil import psutil
import pytest import pytest
@ -25,6 +25,7 @@ from amqtt.mqtt.connack import ConnackPacket
from amqtt.mqtt.connect import ConnectPacket, ConnectPayload, ConnectVariableHeader from amqtt.mqtt.connect import ConnectPacket, ConnectPayload, ConnectVariableHeader
from amqtt.mqtt.constants import QOS_0, QOS_1, QOS_2 from amqtt.mqtt.constants import QOS_0, QOS_1, QOS_2
from amqtt.mqtt.disconnect import DisconnectPacket from amqtt.mqtt.disconnect import DisconnectPacket
from amqtt.mqtt.protocol.broker_handler import BrokerProtocolHandler
from amqtt.mqtt.pubcomp import PubcompPacket from amqtt.mqtt.pubcomp import PubcompPacket
from amqtt.mqtt.publish import PublishPacket from amqtt.mqtt.publish import PublishPacket
from amqtt.mqtt.pubrec import PubrecPacket from amqtt.mqtt.pubrec import PubrecPacket
@ -677,7 +678,7 @@ def test_matches_multi_level_wildcard(broker):
"sport/tennis", "sport/tennis",
"sport/tennis/", "sport/tennis/",
]: ]:
assert not broker.matches(bad_topic, test_filter) assert not broker._matches(bad_topic, test_filter)
for good_topic in [ for good_topic in [
"sport/tennis/player1", "sport/tennis/player1",
@ -685,7 +686,7 @@ def test_matches_multi_level_wildcard(broker):
"sport/tennis/player1/ranking", "sport/tennis/player1/ranking",
"sport/tennis/player1/score/wimbledon", "sport/tennis/player1/score/wimbledon",
]: ]:
assert broker.matches(good_topic, test_filter) assert broker._matches(good_topic, test_filter)
def test_matches_single_level_wildcard(broker): def test_matches_single_level_wildcard(broker):
@ -696,37 +697,36 @@ def test_matches_single_level_wildcard(broker):
"sport/tennis/player1/", "sport/tennis/player1/",
"sport/tennis/player1/ranking", "sport/tennis/player1/ranking",
]: ]:
assert not broker.matches(bad_topic, test_filter) assert not broker._matches(bad_topic, test_filter)
for good_topic in [ for good_topic in [
"sport/tennis/", "sport/tennis/",
"sport/tennis/player1", "sport/tennis/player1",
"sport/tennis/player2", "sport/tennis/player2",
]: ]:
assert broker.matches(good_topic, test_filter) assert broker._matches(good_topic, test_filter)
# @pytest.mark.asyncio @pytest.mark.asyncio
# async def test_broker_broadcast_cancellation(broker): async def test_broker_broadcast_cancellation(broker):
# topic = "test" topic = "test"
# data = b"data" data = b"data"
# qos = QOS_0 qos = QOS_0
# sub_client = MQTTClient() sub_client = MQTTClient()
# await sub_client.connect("mqtt://127.0.0.1") await sub_client.connect("mqtt://127.0.0.1")
# await sub_client.subscribe([(topic, qos)]) await sub_client.subscribe([(topic, qos)])
# with patch.object(BrokerProtocolHandler, "mqtt_publish", side_effect=asyncio.CancelledError) as mocked_mqtt_publish: with patch.object(BrokerProtocolHandler, "mqtt_publish", side_effect=asyncio.CancelledError) as mocked_mqtt_publish:
# await _client_publish(topic, data, qos) await _client_publish(topic, data, qos)
# # Second publish triggers the awaiting of first `mqtt_publish` task # Second publish triggers the awaiting of first `mqtt_publish` task
# await _client_publish(topic, data, qos) await _client_publish(topic, data, qos)
# await asyncio.sleep(0.01) await asyncio.sleep(0.01)
# # `assert_awaited` does not exist in Python before `3.8` mocked_mqtt_publish.assert_awaited()
# mocked_mqtt_publish.assert_awaited()
# # Ensure broadcast loop is still functional and can deliver the message # Ensure broadcast loop is still functional and can deliver the message
# await _client_publish(topic, data, qos) await _client_publish(topic, data, qos)
# message = await asyncio.wait_for(sub_client.deliver_message(), timeout_duration=1) message = await asyncio.wait_for(sub_client.deliver_message(), timeout=1)
# assert message assert message