diff --git a/amqtt/broker.py b/amqtt/broker.py index 11f812b..22026ee 100644 --- a/amqtt/broker.py +++ b/amqtt/broker.py @@ -242,72 +242,10 @@ class Broker: await self.plugins_manager.fire_event(EVENT_BROKER_PRE_START) try: - # Start network listeners - for listener_name, listener in self.listeners_config.items(): - if "bind" not in listener: - self.logger.debug(f"Listener configuration '{listener_name}' is not bound") - continue - - max_connections = listener.get("max_connections", -1) - - ssl_conection = None - ssl_active = listener.get("ssl", False) # accept string "on" / "off" or boolean - if isinstance(ssl_active, str): - ssl_active = ssl_active.upper() == "ON" - - if ssl_active: - try: - ssl_conection = ssl.create_default_context( - ssl.Purpose.CLIENT_AUTH, - cafile=listener.get("cafile"), - capath=listener.get("capath"), - cadata=listener.get("cadata"), - ) - ssl_conection.load_cert_chain(listener["certfile"], listener["keyfile"]) - ssl_conection.verify_mode = ssl.CERT_OPTIONAL - except KeyError as ke: - msg = f"'certfile' or 'keyfile' configuration parameter missing: {ke}" - raise BrokerError(msg) from ke - except FileNotFoundError as fnfe: - msg = f"Can't read cert files '{listener['certfile']}' or '{listener['keyfile']}' : {fnfe}" - raise BrokerError(msg) from fnfe - - try: - address, port = self._split_bindaddr_port(listener["bind"], DEFAULT_PORTS[listener["type"]]) - except ValueError as e: - msg = f"Invalid port value in bind value: {listener['bind']}" - raise BrokerError(msg) from e - - instance: asyncio.Server | websockets.asyncio.server.Server | None = None - if listener["type"] == "tcp": - cb_partial = partial(self.stream_connected, listener_name=listener_name) - instance = await asyncio.start_server( - cb_partial, - address, - port, - reuse_address=True, - ssl=ssl_conection, - ) - self._servers[listener_name] = Server(listener_name, instance, max_connections) - elif listener["type"] == "ws": - cb_partial = partial(self.ws_connected, listener_name=listener_name) - instance = await websockets.serve( - cb_partial, - address, - port, - ssl=ssl_conection, - subprotocols=[websockets.Subprotocol("mqtt")], - ) - self._servers[listener_name] = Server(listener_name, instance, max_connections) - - self.logger.info(f"Listener '{listener_name}' bind to {listener['bind']} (max_connections={max_connections})") - + await self._start_listeners() self.transitions.starting_success() await self.plugins_manager.fire_event(EVENT_BROKER_POST_START) - - # Start broadcast loop self._broadcast_task = asyncio.ensure_future(self._broadcast_loop()) - self.logger.debug("Broker started") except Exception as e: self.logger.exception("Broker startup failed") @@ -315,48 +253,128 @@ class Broker: msg = f"Broker instance can't be started: {e}" raise BrokerError(msg) from e + async def _start_listeners(self) -> None: + """Start network listeners based on the configuration.""" + for listener_name, listener in self.listeners_config.items(): + if "bind" not in listener: + self.logger.debug(f"Listener configuration '{listener_name}' is not bound") + continue + + max_connections = listener.get("max_connections", -1) + ssl_context = self._create_ssl_context(listener) if listener.get("ssl", False) else None + + try: + address, port = self._split_bindaddr_port(listener["bind"], DEFAULT_PORTS[listener["type"]]) + except ValueError as e: + msg = f"Invalid port value in bind value: {listener['bind']}" + raise BrokerError(msg) from e + + instance = await self._create_server_instance(listener_name, listener["type"], address, port, ssl_context) + self._servers[listener_name] = Server(listener_name, instance, max_connections) + + self.logger.info(f"Listener '{listener_name}' bind to {listener['bind']} (max_connections={max_connections})") + + @classmethod + def _split_bindaddr_port(cls, port_str: str, default_port: int) -> tuple[str | None, int]: + """Split an address:port pair into separate IP address and port. with IPv6 special-case handling. + + - Address can be specified using one of the following methods: + - empty string - all interfaces default port + - 1883 - Port number only (listen all interfaces) + - :1883 - Port number only (listen all interfaces) + - 0.0.0.0:1883 - IPv4 address + - [::]:1883 - IPv6 address + """ + + def _parse_port(port_str: str) -> int: + port_str = port_str.removeprefix(":") + + if not port_str: + return default_port + + return int(port_str) + + if port_str.startswith("["): # IPv6 literal + try: + addr_end = port_str.index("]") + except ValueError as e: + msg = "Expecting '[' to be followed by ']'" + raise ValueError(msg) from e + + return (port_str[0 : addr_end + 1], _parse_port(port_str[addr_end + 1 :])) + + if ":" in port_str: + address, port_str = port_str.rsplit(":", 1) + return (address or None, _parse_port(port_str)) + + try: + return (None, _parse_port(port_str)) + except ValueError: + return (port_str, default_port) + + def _create_ssl_context(self, listener: dict[str, Any]) -> ssl.SSLContext: + """Create an SSL context for a listener.""" + try: + ssl_context = ssl.create_default_context( + ssl.Purpose.CLIENT_AUTH, + cafile=listener.get("cafile"), + capath=listener.get("capath"), + cadata=listener.get("cadata"), + ) + ssl_context.load_cert_chain(listener["certfile"], listener["keyfile"]) + ssl_context.verify_mode = ssl.CERT_OPTIONAL + return ssl_context + except KeyError as ke: + msg = f"'certfile' or 'keyfile' configuration parameter missing: {ke}" + raise BrokerError(msg) from ke + except FileNotFoundError as fnfe: + msg = f"Can't read cert files '{listener['certfile']}' or '{listener['keyfile']}' : {fnfe}" + raise BrokerError(msg) from fnfe + + async def _create_server_instance( + self, + listener_name: str, + listener_type: str, + address: str | None, + port: int, + ssl_context: ssl.SSLContext | None, + ) -> asyncio.Server | websockets.asyncio.server.Server: + """Create a server instance for a listener.""" + if listener_type == "tcp": + return await asyncio.start_server( + partial(self.stream_connected, listener_name=listener_name), + address, + port, + reuse_address=True, + ssl=ssl_context, + ) + if listener_type == "ws": + return await websockets.serve( + partial(self.ws_connected, listener_name=listener_name), + address, + port, + ssl=ssl_context, + subprotocols=[websockets.Subprotocol("mqtt")], + ) + msg = f"Unsupported listener type: {listener_type}" + raise BrokerError(msg) + async def shutdown(self) -> None: """Stop broker instance.""" - try: - # # Wait for all in-flight tasks to complete before stopping session handlers - # for client_id, (session, handler) in self._sessions.items(): - # if handler: - # self.logger.debug(f"Waiting for in-flight tasks to complete for session {client_id}") - # if session.inflight_out: - # # Directly use asyncio.sleep or another async operation in the loop - # await asyncio.gather( - # *(asyncio.sleep(0) for _ in session.inflight_out.values()), - # return_exceptions=True, - # ) - - # Stop all session handlers - for client_id, (_, handler) in self._sessions.items(): - if handler: - self.logger.debug(f"Stopping handler for session {client_id}") - await self._stop_handler(handler) - - # Clear subscriptions - for topic, subscriptions in self._subscriptions.items(): - self.logger.debug(f"Clearing subscriptions for topic '{topic}'") - for session, _ in subscriptions: - self._del_subscription(topic, session) - self._subscriptions.clear() - - # Clear retained messages - if self._retained_messages: - self.logger.debug(f"Clearing {len(self._retained_messages)} retained messages") - self._retained_messages.clear() - - self._sessions.clear() - self.transitions.shutdown() - except (MachineError, ValueError) as exc: - # Backwards compat: MachineError is raised by transitions < 0.5.0. - self.logger.debug(f"Invalid method call at this moment: {exc}") - raise - + self.logger.info("Shutting down broker...") # Fire broker_shutdown event to plugins await self.plugins_manager.fire_event(EVENT_BROKER_PRE_SHUTDOWN) + # Cleanup all sessions + for client_id in list(self._sessions.keys()): + await self._cleanup_session(client_id) + + # Clear retained messages + self.logger.debug(f"Clearing {len(self._retained_messages)} retained messages") + self._retained_messages.clear() + + self.transitions.shutdown() + await self._shutdown_broadcast_loop() for server in self._servers.values(): @@ -372,59 +390,89 @@ class Broker: await self.plugins_manager.fire_event(EVENT_BROKER_POST_SHUTDOWN) self.transitions.stopping_success() + async def _cleanup_session(self, client_id: str) -> None: + """Centralized cleanup logic for a session.""" + session, handler = self._sessions.pop(client_id, (None, None)) + + if handler: + self.logger.debug(f"Stopping handler for session {client_id}") + await self._stop_handler(handler) + if session: + self.logger.debug(f"Clearing all subscriptions for session {client_id}") + await self._del_all_subscriptions(session) + session.clear_queues() + async def internal_message_broadcast(self, topic: str, data: bytes, qos: int | None = None) -> None: return await self._broadcast_message(None, topic, data, qos) async def ws_connected(self, websocket: ServerConnection, listener_name: str) -> None: - await self.client_connected(listener_name, WebSocketsReader(websocket), WebSocketsWriter(websocket)) + await self._client_connected(listener_name, WebSocketsReader(websocket), WebSocketsWriter(websocket)) async def stream_connected(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter, listener_name: str) -> None: - await self.client_connected(listener_name, StreamReaderAdapter(reader), StreamWriterAdapter(writer)) + await self._client_connected(listener_name, StreamReaderAdapter(reader), StreamWriterAdapter(writer)) - async def client_connected(self, listener_name: str, reader: ReaderAdapter, writer: WriterAdapter) -> None: - # Wait for connection available on listener - server = self._servers.get(listener_name, None) + async def _client_connected(self, listener_name: str, reader: ReaderAdapter, writer: WriterAdapter) -> None: + """Handle a new client connection.""" + server = self._servers.get(listener_name) if not server: msg = f"Invalid listener name '{listener_name}'" raise BrokerError(msg) - await server.acquire_connection() + await server.acquire_connection() remote_info = writer.get_peer_info() if remote_info is None: - self.logger.warning("remote info could not get from peer info") + self.logger.warning("Remote info could not be retrieved from peer info") return remote_address, remote_port = remote_info self.logger.info(f"Connection from {remote_address}:{remote_port} on listener '{listener_name}'") + try: + handler, client_session = await self._initialize_client_session(reader, writer, remote_address, remote_port) + except (AMQTTError, MQTTError, NoDataError) as exc: + self.logger.warning(f"Failed to initialize client session: {exc}") + server.release_connection() + return + + try: + await self._handle_client_session(reader, writer, client_session, handler, server, listener_name) + except (AMQTTError, MQTTError, NoDataError) as exc: + self.logger.warning(f"Error while handling client session: {exc}") + finally: + self.logger.debug(f"{client_session.client_id} Client disconnected") + server.release_connection() + + async def _initialize_client_session( + self, + reader: ReaderAdapter, + writer: WriterAdapter, + remote_address: str, + remote_port: int, + ) -> tuple[BrokerProtocolHandler, Session]: + """Initialize a client session and protocol handler.""" # Wait for first packet and expect a CONNECT try: handler, client_session = await BrokerProtocolHandler.init_from_connect(reader, writer, self.plugins_manager) except AMQTTError as exc: self.logger.warning( f"[MQTT-3.1.0-1] {format_client_message(address=remote_address, port=remote_port)}:" - f"Can't read first packet an CONNECT: {exc}", + f" Can't read first packet as CONNECT: {exc}", ) - self.logger.debug("Connection closed") - server.release_connection() - return - except MQTTError: + raise AMQTTError(exc) from exc + except MQTTError as exc: self.logger.exception( f"Invalid connection from {format_client_message(address=remote_address, port=remote_port)}", ) await writer.close() - server.release_connection() - self.logger.debug("Connection closed") - return - except NoDataError as ne: - self.logger.error(f"No data from {format_client_message(address=remote_address, port=remote_port)} : {ne}") # noqa: TRY400 # cannot replace with exception else test fails - server.release_connection() - return + raise MQTTError(exc) from exc + except NoDataError as exc: + self.logger.error(f"No data from {format_client_message(address=remote_address, port=remote_port)} : {exc}") # noqa: TRY400 # cannot replace with exception else pytest fails + raise NoDataError(exc) from exc if client_session.clean_session: # Delete existing session and create a new one if client_session.client_id is not None and client_session.client_id != "": - self.delete_session(client_session.client_id) + await self._delete_session(client_session.client_id) else: client_session.client_id = gen_client_id() client_session.parent = 0 @@ -436,21 +484,32 @@ class Broker: else: client_session.parent = 0 + timeout_disconnect_delay = self.config.get("timeout-disconnect-delay", 0) + if client_session.keep_alive > 0 and isinstance(timeout_disconnect_delay, int): + client_session.keep_alive += timeout_disconnect_delay + + self.logger.debug(f"Keep-alive timeout={client_session.keep_alive}") + return handler, client_session + + async def _handle_client_session( + self, + reader: ReaderAdapter, + writer: WriterAdapter, + client_session: Session, + handler: BrokerProtocolHandler, + server: Server, + listener_name: str, + ) -> None: + """Handle the lifecycle of a client session.""" + authenticated = await self._authenticate(client_session, self.listeners_config[listener_name]) + if not authenticated: + await writer.close() + return + if client_session.client_id is None: msg = "Client ID was not correctly created/set." raise BrokerError(msg) - timeout_disconnect_delay = self.config.get("timeout-disconnect-delay") - if client_session.keep_alive > 0 and isinstance(timeout_disconnect_delay, int): - client_session.keep_alive += timeout_disconnect_delay - self.logger.debug(f"Keep-alive timeout={client_session.keep_alive}") - - authenticated = await self.authenticate(client_session, self.listeners_config[listener_name]) - if not authenticated: - await writer.close() - server.release_connection() # Delete client from connections list - return - while True: try: client_session.transitions.connect() @@ -469,14 +528,17 @@ class Broker: self._sessions[client_session.client_id] = (client_session, handler) await handler.mqtt_connack_authorize(authenticated) - await self.plugins_manager.fire_event(EVENT_BROKER_CLIENT_CONNECTED, client_id=client_session.client_id) self.logger.debug(f"{client_session.client_id} Start messages handling") await handler.start() self.logger.debug(f"Retained messages queue size: {client_session.retained_messages.qsize()}") - await self.publish_session_retained_messages(client_session) + await self._publish_session_retained_messages(client_session) + await self._client_message_loop(client_session, handler) + + async def _client_message_loop(self, client_session: Session, handler: BrokerProtocolHandler) -> None: + """Run the main loop to handle client messages.""" # Init and start loop for handling client messages (publish, subscribe/unsubscribe, disconnect) disconnect_waiter = asyncio.ensure_future(handler.wait_disconnect()) subscribe_waiter = asyncio.ensure_future(handler.get_next_pending_subscription()) @@ -495,141 +557,160 @@ class Broker: ], return_when=asyncio.FIRST_COMPLETED, ) - if disconnect_waiter in done: - result = disconnect_waiter.result() - self.logger.debug(f"{client_session.client_id} Result from wait_disconnect: {result}") - if result is None: - self.logger.debug(f"Will flag: {client_session.will_flag}") - if client_session.will_flag: - self.logger.debug( - f"Client {format_client_message(client_session)} disconnected abnormally, sending will message", - ) - await self._broadcast_message( - client_session, - client_session.will_topic, - client_session.will_message, - client_session.will_qos, - ) - if client_session.will_retain: - self.retain_message( - client_session, - client_session.will_topic, - client_session.will_message, - client_session.will_qos, - ) - self.logger.debug(f"{client_session.client_id} Disconnecting session") - await self._stop_handler(handler) - client_session.transitions.disconnect() - await self.plugins_manager.fire_event( - EVENT_BROKER_CLIENT_DISCONNECTED, - client_id=client_session.client_id, - ) - connected = False - # Recreate the disconnect_waiter task after processing - disconnect_waiter = asyncio.ensure_future(handler.wait_disconnect()) - if unsubscribe_waiter in done: - self.logger.debug(f"{client_session.client_id} handling unsubscription") - unsubscription = unsubscribe_waiter.result() - for topic in unsubscription.topics: - self._del_subscription(topic, client_session) - await self.plugins_manager.fire_event( - EVENT_BROKER_CLIENT_UNSUBSCRIBED, - client_id=client_session.client_id, - topic=topic, - ) - await handler.mqtt_acknowledge_unsubscription(unsubscription.packet_id) - # Recreate the unsubscribe_waiter task - unsubscribe_waiter = asyncio.ensure_future(handler.get_next_pending_unsubscription()) - if subscribe_waiter in done: - self.logger.debug(f"{client_session.client_id} handling subscription") - subscriptions = subscribe_waiter.result() - return_codes = [ - await self.add_subscription(subscription, client_session) for subscription in subscriptions.topics - ] - await handler.mqtt_acknowledge_subscription(subscriptions.packet_id, return_codes) - for index, subscription in enumerate(subscriptions.topics): - if return_codes[index] != AMQTT_MAGIC_VALUE_RET_SUBSCRIBED: - await self.plugins_manager.fire_event( - EVENT_BROKER_CLIENT_SUBSCRIBED, - client_id=client_session.client_id, - topic=subscription[0], - qos=subscription[1], - ) - await self.publish_retained_messages_for_subscription(subscription, client_session) - # Recreate the subscribe_waiter task + if disconnect_waiter in done: + connected = await self._handle_disconnect(client_session, handler, disconnect_waiter) + disconnect_waiter = asyncio.ensure_future(handler.wait_disconnect()) + + if subscribe_waiter in done: + await self._handle_subscription(client_session, handler, subscribe_waiter) subscribe_waiter = asyncio.ensure_future(handler.get_next_pending_subscription()) self.logger.debug(repr(self._subscriptions)) + + if unsubscribe_waiter in done: + await self._handle_unsubscription(client_session, handler, unsubscribe_waiter) + unsubscribe_waiter = asyncio.ensure_future(handler.get_next_pending_unsubscription()) + if wait_deliver in done: - if self.logger.isEnabledFor(logging.DEBUG): - self.logger.debug(f"{client_session.client_id} handling message delivery") - app_message = wait_deliver.result() - - if app_message is None: - self.logger.debug("app_message was empty!") - continue - - if not app_message.topic: - self.logger.warning( - f"[MQTT-4.7.3-1] - {client_session.client_id}" - " invalid TOPIC sent in PUBLISH message, closing connection", - ) + if not await self._handle_message_delivery(client_session, handler, wait_deliver): break - if "#" in app_message.topic or "+" in app_message.topic: - self.logger.warning( - f"[MQTT-3.3.2-2] - {client_session.client_id}" - " invalid TOPIC sent in PUBLISH message, closing connection", - ) - break - - permitted = await self.topic_filtering(client_session, topic=app_message.topic, action=Action.PUBLISH) - if not permitted: - self.logger.info( - f"{client_session.client_id} forbidden TOPIC {app_message.topic} sent in PUBLISH message.", - ) - else: - await self.plugins_manager.fire_event( - EVENT_BROKER_MESSAGE_RECEIVED, - client_id=client_session.client_id, - message=app_message, - ) - await self._broadcast_message(client_session, app_message.topic, app_message.data) - if app_message.publish_packet is not None and app_message.publish_packet.retain_flag: - self.retain_message(client_session, app_message.topic, app_message.data, app_message.qos) - # Recreate the wait_deliver task wait_deliver = asyncio.ensure_future(handler.mqtt_deliver_next_message()) + except asyncio.CancelledError: self.logger.debug("Client loop cancelled") break + disconnect_waiter.cancel() subscribe_waiter.cancel() unsubscribe_waiter.cancel() wait_deliver.cancel() - self.logger.debug(f"{client_session.client_id} Client disconnected") - server.release_connection() + async def _handle_disconnect( + self, + client_session: Session, + handler: BrokerProtocolHandler, + disconnect_waiter: asyncio.Future[Any], + ) -> bool: + """Handle client disconnection.""" + result = disconnect_waiter.result() + self.logger.debug(f"{client_session.client_id} Result from wait_disconnect: {result}") + if result is None: + self.logger.debug(f"Will flag: {client_session.will_flag}") + if client_session.will_flag: + self.logger.debug( + f"Client {format_client_message(client_session)} disconnected abnormally, sending will message", + ) + await self._broadcast_message( + client_session, + client_session.will_topic, + client_session.will_message, + client_session.will_qos, + ) + if client_session.will_retain: + self.retain_message( + client_session, + client_session.will_topic, + client_session.will_message, + client_session.will_qos, + ) + self.logger.debug(f"{client_session.client_id} Disconnecting session") + await self._stop_handler(handler) + client_session.transitions.disconnect() + await self.plugins_manager.fire_event(EVENT_BROKER_CLIENT_DISCONNECTED, client_id=client_session.client_id) + return False + return True - async def _init_handler(self, session: Session, reader: ReaderAdapter, writer: WriterAdapter) -> BrokerProtocolHandler: - """Create a BrokerProtocolHandler and attach to a session. + async def _handle_subscription( + self, + client_session: Session, + handler: BrokerProtocolHandler, + subscribe_waiter: asyncio.Future[Any], + ) -> None: + """Handle client subscription.""" + self.logger.debug(f"{client_session.client_id} handling subscription") + subscriptions = subscribe_waiter.result() + return_codes = [await self._add_subscription(subscription, client_session) for subscription in subscriptions.topics] + await handler.mqtt_acknowledge_subscription(subscriptions.packet_id, return_codes) + for index, subscription in enumerate(subscriptions.topics): + if return_codes[index] != AMQTT_MAGIC_VALUE_RET_SUBSCRIBED: + await self.plugins_manager.fire_event( + EVENT_BROKER_CLIENT_SUBSCRIBED, + client_id=client_session.client_id, + topic=subscription[0], + qos=subscription[1], + ) + await self._publish_retained_messages_for_subscription(subscription, client_session) - :return: - """ - handler = BrokerProtocolHandler(self.plugins_manager, loop=self._loop) - handler.attach(session, reader, writer) - return handler + async def _handle_unsubscription( + self, + client_session: Session, + handler: BrokerProtocolHandler, + unsubscribe_waiter: asyncio.Future[Any], + ) -> None: + """Handle client unsubscription.""" + self.logger.debug(f"{client_session.client_id} handling unsubscription") + unsubscription = unsubscribe_waiter.result() + for topic in unsubscription.topics: + self._del_subscription(topic, client_session) + await self.plugins_manager.fire_event( + EVENT_BROKER_CLIENT_UNSUBSCRIBED, + client_id=client_session.client_id, + topic=topic, + ) + await handler.mqtt_acknowledge_unsubscription(unsubscription.packet_id) + + async def _handle_message_delivery( + self, + client_session: Session, + handler: BrokerProtocolHandler, + wait_deliver: asyncio.Future[Any], + ) -> bool: + """Handle message delivery to the client.""" + self.logger.debug(f"{client_session.client_id} handling message delivery") + app_message = wait_deliver.result() + + if app_message is None: + self.logger.debug("app_message was empty!") + return True + if not app_message.topic: + self.logger.warning( + f"[MQTT-4.7.3-1] - {client_session.client_id} invalid TOPIC sent in PUBLISH message, closing connection", + ) + return False + if "#" in app_message.topic or "+" in app_message.topic: + self.logger.warning( + f"[MQTT-3.3.2-2] - {client_session.client_id} invalid TOPIC sent in PUBLISH message, closing connection", + ) + return False + + permitted = await self._topic_filtering(client_session, topic=app_message.topic, action=Action.PUBLISH) + if not permitted: + self.logger.info(f"{client_session.client_id} forbidden TOPIC {app_message.topic} sent in PUBLISH message.") + else: + await self.plugins_manager.fire_event( + EVENT_BROKER_MESSAGE_RECEIVED, + client_id=client_session.client_id, + message=app_message, + ) + await self._broadcast_message(client_session, app_message.topic, app_message.data) + if app_message.publish_packet and app_message.publish_packet.retain_flag: + self.retain_message(client_session, app_message.topic, app_message.data, app_message.qos) + return True + + # async def _init_handler(self, session: Session, reader: ReaderAdapter, writer: WriterAdapter) -> BrokerProtocolHandler: + # """Create a BrokerProtocolHandler and attach to a session.""" + # handler = BrokerProtocolHandler(self.plugins_manager, loop=self._loop) + # handler.attach(session, reader, writer) + # return handler async def _stop_handler(self, handler: BrokerProtocolHandler) -> None: - """Stop a running handler and detach if from the session. - - :param handler: - :return: - """ + """Stop a running handler and detach if from the session.""" try: await handler.stop() except Exception: self.logger.exception("Failed to stop handler") - async def authenticate(self, session: Session, _: dict[str, Any]) -> bool: + async def _authenticate(self, session: Session, _: dict[str, Any]) -> bool: """Call the authenticate method on registered plugins to test user authentication. User is considered authenticated if all plugins called returns True. @@ -658,7 +739,49 @@ class Broker: # If all plugins returned True, authentication is success return auth_result - async def topic_filtering(self, session: Session, topic: str, action: Action) -> bool: + def retain_message( + self, + source_session: Session | None, + topic_name: str | None, + data: bytes | bytearray | None, + qos: int | None = None, + ) -> None: + if data and topic_name is not None: + # If retained flag set, store the message for further subscriptions + self.logger.debug(f"Retaining message on topic {topic_name}") + self._retained_messages[topic_name] = RetainedApplicationMessage(source_session, topic_name, data, qos) + # [MQTT-3.3.1-10] + elif topic_name in self._retained_messages: + self.logger.debug(f"Clearing retained messages for topic '{topic_name}'") + del self._retained_messages[topic_name] + + async def _add_subscription(self, subscription: tuple[str, int], session: Session) -> int: + topic_filter, qos = subscription + if "#" in topic_filter and not topic_filter.endswith("#"): + # [MQTT-4.7.1-2] Wildcard character '#' is only allowed as last character in filter + return 0x80 + if topic_filter != "+" and "+" in topic_filter and ("/+" not in topic_filter and "+/" not in topic_filter): + # [MQTT-4.7.1-3] + wildcard character must occupy entire level + return 0x80 + # Check if the client is authorised to connect to the topic + if not await self._topic_filtering(session, topic_filter, Action.SUBSCRIBE): + return 0x80 + + # Ensure "max-qos" is an integer before using it + max_qos = self.config.get("max-qos", qos) + if not isinstance(max_qos, int): + max_qos = qos + + qos = min(qos, max_qos) + if topic_filter not in self._subscriptions: + self._subscriptions[topic_filter] = [] + if all(s.client_id != session.client_id for s, _ in self._subscriptions[topic_filter]): + self._subscriptions[topic_filter].append((session, qos)) + else: + self.logger.debug(f"Client {format_client_message(session=session)} has already subscribed to {topic_filter}") + return qos + + async def _topic_filtering(self, session: Session, topic: str, action: Action) -> bool: """Call the topic_filtering method on registered plugins to check that the subscription is allowed. User is considered allowed if all plugins called return True. @@ -690,48 +813,29 @@ class Broker: ) return all(result for result in results.values()) - def retain_message( - self, - source_session: Session | None, - topic_name: str | None, - data: bytes | bytearray | None, - qos: int | None = None, - ) -> None: - if data and topic_name is not None: - # If retained flag set, store the message for further subscriptions - self.logger.debug(f"Retaining message on topic {topic_name}") - self._retained_messages[topic_name] = RetainedApplicationMessage(source_session, topic_name, data, qos) - # [MQTT-3.3.1-10] - elif topic_name in self._retained_messages: - self.logger.debug(f"Clearing retained messages for topic '{topic_name}'") - del self._retained_messages[topic_name] + async def _delete_session(self, client_id: str) -> None: + """Delete an existing session data, for example due to clean session set in CONNECT.""" + session = self._sessions.pop(client_id, (None, None))[0] - # NOTE: issue #61 remove try block - async def add_subscription(self, subscription: tuple[str, int], session: Session) -> int: - topic_filter, qos = subscription - if "#" in topic_filter and not topic_filter.endswith("#"): - # [MQTT-4.7.1-2] Wildcard character '#' is only allowed as last character in filter - return 0x80 - if topic_filter != "+" and "+" in topic_filter and ("/+" not in topic_filter and "+/" not in topic_filter): - # [MQTT-4.7.1-3] + wildcard character must occupy entire level - return 0x80 - # Check if the client is authorised to connect to the topic - if not await self.topic_filtering(session, topic_filter, Action.SUBSCRIBE): - return 0x80 + if session is None: + self.logger.debug(f"Delete session : session {client_id} doesn't exist") + return + self.logger.debug(f"Deleted existing session {session!r}") - # Ensure "max-qos" is an integer before using it - max_qos = self.config.get("max-qos", qos) - if not isinstance(max_qos, int): - max_qos = qos + # Delete subscriptions + self.logger.debug(f"Deleting session {session!r} subscriptions") + await self._del_all_subscriptions(session) + session.clear_queues() - qos = min(qos, max_qos) - if topic_filter not in self._subscriptions: - self._subscriptions[topic_filter] = [] - if all(s.client_id != session.client_id for s, _ in self._subscriptions[topic_filter]): - self._subscriptions[topic_filter].append((session, qos)) - else: - self.logger.debug(f"Client {format_client_message(session=session)} has already subscribed to {topic_filter}") - return qos + async def _del_all_subscriptions(self, session: Session) -> None: + """Delete all topic subscriptions for a given session.""" + filter_queue: deque[str] = deque() + for topic in self._subscriptions: + if self._del_subscription(topic, session): + filter_queue.append(topic) + for topic in filter_queue: + if not self._subscriptions[topic]: + del self._subscriptions[topic] def _del_subscription(self, a_filter: str, session: Session) -> int: """Delete a session subscription on a given topic. @@ -752,32 +856,9 @@ class Broker: deleted += 1 break except KeyError: - # Unsubscribe topic not found in current subscribed topics - pass + self.logger.debug(f"Unsubscription on topic '{a_filter}' for client {format_client_message(session=session)}") return deleted - def _del_all_subscriptions(self, session: Session) -> None: - """Delete all topic subscriptions for a given session. - - :param session: - :return: - """ - filter_queue: deque[str] = deque() - for topic in self._subscriptions: - if self._del_subscription(topic, session): - filter_queue.append(topic) - for topic in filter_queue: - if not self._subscriptions[topic]: - del self._subscriptions[topic] - - def matches(self, topic: str, a_filter: str) -> bool: - if "#" not in a_filter and "+" not in a_filter: - # if filter doesn't contain wildcard, return exact match - return a_filter == topic - # else use regex - match_pattern = re.compile(re.escape(a_filter).replace("\\#", "?.*").replace("\\+", "[^/]*").lstrip("?")) - return bool(match_pattern.fullmatch(topic)) - async def _broadcast_loop(self) -> None: """Run the main loop to broadcast messages.""" running_tasks: deque[asyncio.Task[OutgoingApplicationMessage]] = self._tasks_queue @@ -824,7 +905,7 @@ class Broker: continue # Skip all subscriptions which do not match the topic - if not self.matches(broadcast["topic"], k_filter): + if not self._matches(broadcast["topic"], k_filter): self.logger.debug(f"Topic '{broadcast['topic']}' does not match filter '{k_filter}'") continue @@ -892,7 +973,7 @@ class Broker: broadcast["qos"] = force_qos await self._broadcast_queue.put(broadcast) - async def publish_session_retained_messages(self, session: Session) -> None: + async def _publish_session_retained_messages(self, session: Session) -> None: self.logger.debug( f"Publishing {session.retained_messages.qsize()}" f" messages retained for session {format_client_message(session=session)}", @@ -910,7 +991,7 @@ class Broker: if publish_tasks: await asyncio.wait(publish_tasks) - async def publish_retained_messages_for_subscription(self, subscription: tuple[str, int], session: Session) -> None: + async def _publish_retained_messages_for_subscription(self, subscription: tuple[str, int], session: Session) -> None: self.logger.debug( f"Begin broadcasting messages retained due to subscription on '{subscription[0]}'" f" from {format_client_message(session=session)}", @@ -920,7 +1001,7 @@ class Broker: topic_filter, qos = subscription for topic, retained in self._retained_messages.items(): self.logger.debug(f"matching : {topic} {topic_filter}") - if self.matches(topic, topic_filter): + if self._matches(topic, topic_filter): self.logger.debug(f"{topic} and {topic_filter} match") handler = self._get_handler(session) if handler: @@ -936,62 +1017,16 @@ class Broker: f" from {format_client_message(session=session)}", ) - def delete_session(self, client_id: str) -> None: - """Delete an existing session data, for example due to clean session set in CONNECT.""" - session = self._sessions.pop(client_id, (None, None))[0] - - if session is None: - self.logger.debug(f"Delete session : session {client_id} doesn't exist") - return - self.logger.debug(f"Deleted existing session {session!r}") - - # Delete subscriptions - self.logger.debug(f"Deleting session {session!r} subscriptions") - self._del_all_subscriptions(session) - session.clear_queues() + def _matches(self, topic: str, a_filter: str) -> bool: + if "#" not in a_filter and "+" not in a_filter: + # if filter doesn't contain wildcard, return exact match + return a_filter == topic + # else use regex + match_pattern = re.compile(re.escape(a_filter).replace("\\#", "?.*").replace("\\+", "[^/]*").lstrip("?")) + return bool(match_pattern.fullmatch(topic)) def _get_handler(self, session: Session) -> BrokerProtocolHandler | None: client_id = session.client_id if client_id: return self._sessions.get(client_id, (None, None))[1] return None - - @classmethod - def _split_bindaddr_port(cls, port_str: str, default_port: int) -> tuple[str | None, int]: - """Split an address:port pair into separate IP address and port. with IPv6 special-case handling. - - NOTE: issue #72 - - - Address can be specified using one of the following methods: - - empty string - all interfaces default port - - 1883 - Port number only (listen all interfaces) - - :1883 - Port number only (listen all interfaces) - - 0.0.0.0:1883 - IPv4 address - - [::]:1883 - IPv6 address - """ - - def _parse_port(port_str: str) -> int: - port_str = port_str.removeprefix(":") - - if not port_str: - return default_port - - return int(port_str) - - if port_str.startswith("["): # IPv6 literal - try: - addr_end = port_str.index("]") - except ValueError as e: - msg = "Expecting '[' to be followed by ']'" - raise ValueError(msg) from e - - return (port_str[0 : addr_end + 1], _parse_port(port_str[addr_end + 1 :])) - - if ":" in port_str: - address, port_str = port_str.rsplit(":", 1) - return (address or None, _parse_port(port_str)) - - try: - return (None, _parse_port(port_str)) - except ValueError: - return (port_str, default_port) diff --git a/amqtt/client.py b/amqtt/client.py index 78ab24c..bde782a 100644 --- a/amqtt/client.py +++ b/amqtt/client.py @@ -526,8 +526,7 @@ class MQTTClient: while self.client_tasks: task = self.client_tasks.popleft() if not task.done(): - # task.set_exception(ClientError("Connection lost")) - task.cancel() # NOTE: issue #153 + task.cancel() self.logger.debug("Monitoring broker disconnection") # Wait for disconnection from broker (like connection lost) diff --git a/amqtt/mqtt/protocol/client_handler.py b/amqtt/mqtt/protocol/client_handler.py index b0b5851..491390f 100644 --- a/amqtt/mqtt/protocol/client_handler.py +++ b/amqtt/mqtt/protocol/client_handler.py @@ -195,7 +195,6 @@ class ClientProtocolHandler(ProtocolHandler): self.logger.debug("Broker closed connection") if self._disconnect_waiter is not None and not self._disconnect_waiter.done(): self._disconnect_waiter.set_result(None) - # await self.stop() # NOTE: issue #119 async def wait_disconnect(self) -> None: if self._disconnect_waiter is not None: diff --git a/amqtt/mqtt/protocol/handler.py b/amqtt/mqtt/protocol/handler.py index e13c60c..89f50f7 100644 --- a/amqtt/mqtt/protocol/handler.py +++ b/amqtt/mqtt/protocol/handler.py @@ -67,15 +67,11 @@ class ProtocolHandler: self.writer: WriterAdapter | None = None self.plugins_manager: PluginManager = plugins_manager - # TODO: check how to update loop usage best - self._loop = loop if loop is not None else asyncio.get_event_loop_policy().get_event_loop() - # try: - # # Use the currently running loop if available - # self._loop = loop if loop is not None else asyncio.get_running_loop() - # except RuntimeError: - # # If no running loop is found, create a new one - # self._loop = asyncio.new_event_loop() - # asyncio.set_event_loop(self._loop) + try: + self._loop = loop if loop is not None else asyncio.get_running_loop() + except RuntimeError: + self._loop = asyncio.new_event_loop() + asyncio.set_event_loop(self._loop) self._reader_task: asyncio.Task[None] | None = None self._keepalive_task: asyncio.TimerHandle | None = None diff --git a/amqtt/plugins/manager.py b/amqtt/plugins/manager.py index 4e80cb7..ccb4029 100644 --- a/amqtt/plugins/manager.py +++ b/amqtt/plugins/manager.py @@ -27,8 +27,6 @@ def get_plugin_manager(namespace: str) -> "PluginManager | None": class BaseContext: def __init__(self) -> None: self.loop: asyncio.AbstractEventLoop | None = None - # TODO: change this usage - # self.logger: logging.Logger | None = None self.logger: logging.Logger = _LOGGER self.config: dict[str, Any] | None = None @@ -41,16 +39,11 @@ class PluginManager: """ def __init__(self, namespace: str, context: BaseContext | None, loop: asyncio.AbstractEventLoop | None = None) -> None: - # TODO: check how to update loop usage best - # self._loop = loop if loop is not None else asyncio.get_event_loop_policy().get_event_loop() - if loop is None: - try: - self._loop = asyncio.get_running_loop() - except RuntimeError: - self._loop = asyncio.new_event_loop() - # asyncio.set_event_loop(self._loop) - else: - self._loop = loop + try: + self._loop = loop if loop is not None else asyncio.get_running_loop() + except RuntimeError: + self._loop = asyncio.new_event_loop() + asyncio.set_event_loop(self._loop) self.logger = logging.getLogger(namespace) self.context = context if context is not None else BaseContext() diff --git a/pyproject.toml b/pyproject.toml index 2f80060..341f1b7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -158,7 +158,6 @@ timeout = 10 # ------------------------------------ MYPY ------------------------------------ [tool.mypy] -# mypy_path = "amqtt" exclude = ["^tests/.*", "^docs/.*", "^samples/.*"] follow_imports = "silent" show_error_codes = true @@ -235,10 +234,11 @@ never-returning-functions = ["sys.exit", "argparse.parse_error"] [tool.pylint.DESIGN] max-branches = 20 # too-many-branches -max-parents = 10 +max-parents = 10 # too-many-parents max-positional-arguments = 10 # too-many-positional-arguments -max-returns = 7 +max-returns = 7 # too-many-returns max-statements = 61 # too-many-statements +max-module-lines = 1500 # too-many-lines # ---------------------------------- COVERAGE ---------------------------------- [tool.coverage.run] diff --git a/tests/test_broker.py b/tests/test_broker.py index 0944f9b..70183d1 100644 --- a/tests/test_broker.py +++ b/tests/test_broker.py @@ -1,7 +1,7 @@ import asyncio import logging import socket -from unittest.mock import MagicMock, call +from unittest.mock import MagicMock, call, patch import psutil import pytest @@ -25,6 +25,7 @@ from amqtt.mqtt.connack import ConnackPacket from amqtt.mqtt.connect import ConnectPacket, ConnectPayload, ConnectVariableHeader from amqtt.mqtt.constants import QOS_0, QOS_1, QOS_2 from amqtt.mqtt.disconnect import DisconnectPacket +from amqtt.mqtt.protocol.broker_handler import BrokerProtocolHandler from amqtt.mqtt.pubcomp import PubcompPacket from amqtt.mqtt.publish import PublishPacket from amqtt.mqtt.pubrec import PubrecPacket @@ -677,7 +678,7 @@ def test_matches_multi_level_wildcard(broker): "sport/tennis", "sport/tennis/", ]: - assert not broker.matches(bad_topic, test_filter) + assert not broker._matches(bad_topic, test_filter) for good_topic in [ "sport/tennis/player1", @@ -685,7 +686,7 @@ def test_matches_multi_level_wildcard(broker): "sport/tennis/player1/ranking", "sport/tennis/player1/score/wimbledon", ]: - assert broker.matches(good_topic, test_filter) + assert broker._matches(good_topic, test_filter) def test_matches_single_level_wildcard(broker): @@ -696,37 +697,36 @@ def test_matches_single_level_wildcard(broker): "sport/tennis/player1/", "sport/tennis/player1/ranking", ]: - assert not broker.matches(bad_topic, test_filter) + assert not broker._matches(bad_topic, test_filter) for good_topic in [ "sport/tennis/", "sport/tennis/player1", "sport/tennis/player2", ]: - assert broker.matches(good_topic, test_filter) + assert broker._matches(good_topic, test_filter) -# @pytest.mark.asyncio -# async def test_broker_broadcast_cancellation(broker): -# topic = "test" -# data = b"data" -# qos = QOS_0 +@pytest.mark.asyncio +async def test_broker_broadcast_cancellation(broker): + topic = "test" + data = b"data" + qos = QOS_0 -# sub_client = MQTTClient() -# await sub_client.connect("mqtt://127.0.0.1") -# await sub_client.subscribe([(topic, qos)]) + sub_client = MQTTClient() + await sub_client.connect("mqtt://127.0.0.1") + await sub_client.subscribe([(topic, qos)]) -# with patch.object(BrokerProtocolHandler, "mqtt_publish", side_effect=asyncio.CancelledError) as mocked_mqtt_publish: -# await _client_publish(topic, data, qos) + with patch.object(BrokerProtocolHandler, "mqtt_publish", side_effect=asyncio.CancelledError) as mocked_mqtt_publish: + await _client_publish(topic, data, qos) -# # Second publish triggers the awaiting of first `mqtt_publish` task -# await _client_publish(topic, data, qos) -# await asyncio.sleep(0.01) + # Second publish triggers the awaiting of first `mqtt_publish` task + await _client_publish(topic, data, qos) + await asyncio.sleep(0.01) -# # `assert_awaited` does not exist in Python before `3.8` -# mocked_mqtt_publish.assert_awaited() + mocked_mqtt_publish.assert_awaited() -# # Ensure broadcast loop is still functional and can deliver the message -# await _client_publish(topic, data, qos) -# message = await asyncio.wait_for(sub_client.deliver_message(), timeout_duration=1) -# assert message + # Ensure broadcast loop is still functional and can deliver the message + await _client_publish(topic, data, qos) + message = await asyncio.wait_for(sub_client.deliver_message(), timeout=1) + assert message