refactor: modularize broker listener startup and improve session management

- Extracted listener startup logic into a separate method `_start_listeners` for better readability and maintainability.
- Created helper methods for SSL context creation.
- Improved session cleanup during shutdown by centralizing logic in `_cleanup_session`.
- Refactored message handling and subscription management into dedicated methods for clarity.
- Updated tests to reflect changes in method visibility for matching topics.
pull/165/head
MVladislav 2025-04-06 19:03:30 +02:00
rodzic 3e0902cc8b
commit 12d5cb6866
7 zmienionych plików z 439 dodań i 417 usunięć

Wyświetl plik

@ -242,72 +242,10 @@ class Broker:
await self.plugins_manager.fire_event(EVENT_BROKER_PRE_START)
try:
# Start network listeners
for listener_name, listener in self.listeners_config.items():
if "bind" not in listener:
self.logger.debug(f"Listener configuration '{listener_name}' is not bound")
continue
max_connections = listener.get("max_connections", -1)
ssl_conection = None
ssl_active = listener.get("ssl", False) # accept string "on" / "off" or boolean
if isinstance(ssl_active, str):
ssl_active = ssl_active.upper() == "ON"
if ssl_active:
try:
ssl_conection = ssl.create_default_context(
ssl.Purpose.CLIENT_AUTH,
cafile=listener.get("cafile"),
capath=listener.get("capath"),
cadata=listener.get("cadata"),
)
ssl_conection.load_cert_chain(listener["certfile"], listener["keyfile"])
ssl_conection.verify_mode = ssl.CERT_OPTIONAL
except KeyError as ke:
msg = f"'certfile' or 'keyfile' configuration parameter missing: {ke}"
raise BrokerError(msg) from ke
except FileNotFoundError as fnfe:
msg = f"Can't read cert files '{listener['certfile']}' or '{listener['keyfile']}' : {fnfe}"
raise BrokerError(msg) from fnfe
try:
address, port = self._split_bindaddr_port(listener["bind"], DEFAULT_PORTS[listener["type"]])
except ValueError as e:
msg = f"Invalid port value in bind value: {listener['bind']}"
raise BrokerError(msg) from e
instance: asyncio.Server | websockets.asyncio.server.Server | None = None
if listener["type"] == "tcp":
cb_partial = partial(self.stream_connected, listener_name=listener_name)
instance = await asyncio.start_server(
cb_partial,
address,
port,
reuse_address=True,
ssl=ssl_conection,
)
self._servers[listener_name] = Server(listener_name, instance, max_connections)
elif listener["type"] == "ws":
cb_partial = partial(self.ws_connected, listener_name=listener_name)
instance = await websockets.serve(
cb_partial,
address,
port,
ssl=ssl_conection,
subprotocols=[websockets.Subprotocol("mqtt")],
)
self._servers[listener_name] = Server(listener_name, instance, max_connections)
self.logger.info(f"Listener '{listener_name}' bind to {listener['bind']} (max_connections={max_connections})")
await self._start_listeners()
self.transitions.starting_success()
await self.plugins_manager.fire_event(EVENT_BROKER_POST_START)
# Start broadcast loop
self._broadcast_task = asyncio.ensure_future(self._broadcast_loop())
self.logger.debug("Broker started")
except Exception as e:
self.logger.exception("Broker startup failed")
@ -315,48 +253,128 @@ class Broker:
msg = f"Broker instance can't be started: {e}"
raise BrokerError(msg) from e
async def _start_listeners(self) -> None:
"""Start network listeners based on the configuration."""
for listener_name, listener in self.listeners_config.items():
if "bind" not in listener:
self.logger.debug(f"Listener configuration '{listener_name}' is not bound")
continue
max_connections = listener.get("max_connections", -1)
ssl_context = self._create_ssl_context(listener) if listener.get("ssl", False) else None
try:
address, port = self._split_bindaddr_port(listener["bind"], DEFAULT_PORTS[listener["type"]])
except ValueError as e:
msg = f"Invalid port value in bind value: {listener['bind']}"
raise BrokerError(msg) from e
instance = await self._create_server_instance(listener_name, listener["type"], address, port, ssl_context)
self._servers[listener_name] = Server(listener_name, instance, max_connections)
self.logger.info(f"Listener '{listener_name}' bind to {listener['bind']} (max_connections={max_connections})")
@classmethod
def _split_bindaddr_port(cls, port_str: str, default_port: int) -> tuple[str | None, int]:
"""Split an address:port pair into separate IP address and port. with IPv6 special-case handling.
- Address can be specified using one of the following methods:
- empty string - all interfaces default port
- 1883 - Port number only (listen all interfaces)
- :1883 - Port number only (listen all interfaces)
- 0.0.0.0:1883 - IPv4 address
- [::]:1883 - IPv6 address
"""
def _parse_port(port_str: str) -> int:
port_str = port_str.removeprefix(":")
if not port_str:
return default_port
return int(port_str)
if port_str.startswith("["): # IPv6 literal
try:
addr_end = port_str.index("]")
except ValueError as e:
msg = "Expecting '[' to be followed by ']'"
raise ValueError(msg) from e
return (port_str[0 : addr_end + 1], _parse_port(port_str[addr_end + 1 :]))
if ":" in port_str:
address, port_str = port_str.rsplit(":", 1)
return (address or None, _parse_port(port_str))
try:
return (None, _parse_port(port_str))
except ValueError:
return (port_str, default_port)
def _create_ssl_context(self, listener: dict[str, Any]) -> ssl.SSLContext:
"""Create an SSL context for a listener."""
try:
ssl_context = ssl.create_default_context(
ssl.Purpose.CLIENT_AUTH,
cafile=listener.get("cafile"),
capath=listener.get("capath"),
cadata=listener.get("cadata"),
)
ssl_context.load_cert_chain(listener["certfile"], listener["keyfile"])
ssl_context.verify_mode = ssl.CERT_OPTIONAL
return ssl_context
except KeyError as ke:
msg = f"'certfile' or 'keyfile' configuration parameter missing: {ke}"
raise BrokerError(msg) from ke
except FileNotFoundError as fnfe:
msg = f"Can't read cert files '{listener['certfile']}' or '{listener['keyfile']}' : {fnfe}"
raise BrokerError(msg) from fnfe
async def _create_server_instance(
self,
listener_name: str,
listener_type: str,
address: str | None,
port: int,
ssl_context: ssl.SSLContext | None,
) -> asyncio.Server | websockets.asyncio.server.Server:
"""Create a server instance for a listener."""
if listener_type == "tcp":
return await asyncio.start_server(
partial(self.stream_connected, listener_name=listener_name),
address,
port,
reuse_address=True,
ssl=ssl_context,
)
if listener_type == "ws":
return await websockets.serve(
partial(self.ws_connected, listener_name=listener_name),
address,
port,
ssl=ssl_context,
subprotocols=[websockets.Subprotocol("mqtt")],
)
msg = f"Unsupported listener type: {listener_type}"
raise BrokerError(msg)
async def shutdown(self) -> None:
"""Stop broker instance."""
try:
# # Wait for all in-flight tasks to complete before stopping session handlers
# for client_id, (session, handler) in self._sessions.items():
# if handler:
# self.logger.debug(f"Waiting for in-flight tasks to complete for session {client_id}")
# if session.inflight_out:
# # Directly use asyncio.sleep or another async operation in the loop
# await asyncio.gather(
# *(asyncio.sleep(0) for _ in session.inflight_out.values()),
# return_exceptions=True,
# )
# Stop all session handlers
for client_id, (_, handler) in self._sessions.items():
if handler:
self.logger.debug(f"Stopping handler for session {client_id}")
await self._stop_handler(handler)
# Clear subscriptions
for topic, subscriptions in self._subscriptions.items():
self.logger.debug(f"Clearing subscriptions for topic '{topic}'")
for session, _ in subscriptions:
self._del_subscription(topic, session)
self._subscriptions.clear()
# Clear retained messages
if self._retained_messages:
self.logger.debug(f"Clearing {len(self._retained_messages)} retained messages")
self._retained_messages.clear()
self._sessions.clear()
self.transitions.shutdown()
except (MachineError, ValueError) as exc:
# Backwards compat: MachineError is raised by transitions < 0.5.0.
self.logger.debug(f"Invalid method call at this moment: {exc}")
raise
self.logger.info("Shutting down broker...")
# Fire broker_shutdown event to plugins
await self.plugins_manager.fire_event(EVENT_BROKER_PRE_SHUTDOWN)
# Cleanup all sessions
for client_id in list(self._sessions.keys()):
await self._cleanup_session(client_id)
# Clear retained messages
self.logger.debug(f"Clearing {len(self._retained_messages)} retained messages")
self._retained_messages.clear()
self.transitions.shutdown()
await self._shutdown_broadcast_loop()
for server in self._servers.values():
@ -372,59 +390,89 @@ class Broker:
await self.plugins_manager.fire_event(EVENT_BROKER_POST_SHUTDOWN)
self.transitions.stopping_success()
async def _cleanup_session(self, client_id: str) -> None:
"""Centralized cleanup logic for a session."""
session, handler = self._sessions.pop(client_id, (None, None))
if handler:
self.logger.debug(f"Stopping handler for session {client_id}")
await self._stop_handler(handler)
if session:
self.logger.debug(f"Clearing all subscriptions for session {client_id}")
await self._del_all_subscriptions(session)
session.clear_queues()
async def internal_message_broadcast(self, topic: str, data: bytes, qos: int | None = None) -> None:
return await self._broadcast_message(None, topic, data, qos)
async def ws_connected(self, websocket: ServerConnection, listener_name: str) -> None:
await self.client_connected(listener_name, WebSocketsReader(websocket), WebSocketsWriter(websocket))
await self._client_connected(listener_name, WebSocketsReader(websocket), WebSocketsWriter(websocket))
async def stream_connected(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter, listener_name: str) -> None:
await self.client_connected(listener_name, StreamReaderAdapter(reader), StreamWriterAdapter(writer))
await self._client_connected(listener_name, StreamReaderAdapter(reader), StreamWriterAdapter(writer))
async def client_connected(self, listener_name: str, reader: ReaderAdapter, writer: WriterAdapter) -> None:
# Wait for connection available on listener
server = self._servers.get(listener_name, None)
async def _client_connected(self, listener_name: str, reader: ReaderAdapter, writer: WriterAdapter) -> None:
"""Handle a new client connection."""
server = self._servers.get(listener_name)
if not server:
msg = f"Invalid listener name '{listener_name}'"
raise BrokerError(msg)
await server.acquire_connection()
await server.acquire_connection()
remote_info = writer.get_peer_info()
if remote_info is None:
self.logger.warning("remote info could not get from peer info")
self.logger.warning("Remote info could not be retrieved from peer info")
return
remote_address, remote_port = remote_info
self.logger.info(f"Connection from {remote_address}:{remote_port} on listener '{listener_name}'")
try:
handler, client_session = await self._initialize_client_session(reader, writer, remote_address, remote_port)
except (AMQTTError, MQTTError, NoDataError) as exc:
self.logger.warning(f"Failed to initialize client session: {exc}")
server.release_connection()
return
try:
await self._handle_client_session(reader, writer, client_session, handler, server, listener_name)
except (AMQTTError, MQTTError, NoDataError) as exc:
self.logger.warning(f"Error while handling client session: {exc}")
finally:
self.logger.debug(f"{client_session.client_id} Client disconnected")
server.release_connection()
async def _initialize_client_session(
self,
reader: ReaderAdapter,
writer: WriterAdapter,
remote_address: str,
remote_port: int,
) -> tuple[BrokerProtocolHandler, Session]:
"""Initialize a client session and protocol handler."""
# Wait for first packet and expect a CONNECT
try:
handler, client_session = await BrokerProtocolHandler.init_from_connect(reader, writer, self.plugins_manager)
except AMQTTError as exc:
self.logger.warning(
f"[MQTT-3.1.0-1] {format_client_message(address=remote_address, port=remote_port)}:"
f"Can't read first packet an CONNECT: {exc}",
f" Can't read first packet as CONNECT: {exc}",
)
self.logger.debug("Connection closed")
server.release_connection()
return
except MQTTError:
raise AMQTTError(exc) from exc
except MQTTError as exc:
self.logger.exception(
f"Invalid connection from {format_client_message(address=remote_address, port=remote_port)}",
)
await writer.close()
server.release_connection()
self.logger.debug("Connection closed")
return
except NoDataError as ne:
self.logger.error(f"No data from {format_client_message(address=remote_address, port=remote_port)} : {ne}") # noqa: TRY400 # cannot replace with exception else test fails
server.release_connection()
return
raise MQTTError(exc) from exc
except NoDataError as exc:
self.logger.error(f"No data from {format_client_message(address=remote_address, port=remote_port)} : {exc}") # noqa: TRY400 # cannot replace with exception else pytest fails
raise NoDataError(exc) from exc
if client_session.clean_session:
# Delete existing session and create a new one
if client_session.client_id is not None and client_session.client_id != "":
self.delete_session(client_session.client_id)
await self._delete_session(client_session.client_id)
else:
client_session.client_id = gen_client_id()
client_session.parent = 0
@ -436,21 +484,32 @@ class Broker:
else:
client_session.parent = 0
timeout_disconnect_delay = self.config.get("timeout-disconnect-delay", 0)
if client_session.keep_alive > 0 and isinstance(timeout_disconnect_delay, int):
client_session.keep_alive += timeout_disconnect_delay
self.logger.debug(f"Keep-alive timeout={client_session.keep_alive}")
return handler, client_session
async def _handle_client_session(
self,
reader: ReaderAdapter,
writer: WriterAdapter,
client_session: Session,
handler: BrokerProtocolHandler,
server: Server,
listener_name: str,
) -> None:
"""Handle the lifecycle of a client session."""
authenticated = await self._authenticate(client_session, self.listeners_config[listener_name])
if not authenticated:
await writer.close()
return
if client_session.client_id is None:
msg = "Client ID was not correctly created/set."
raise BrokerError(msg)
timeout_disconnect_delay = self.config.get("timeout-disconnect-delay")
if client_session.keep_alive > 0 and isinstance(timeout_disconnect_delay, int):
client_session.keep_alive += timeout_disconnect_delay
self.logger.debug(f"Keep-alive timeout={client_session.keep_alive}")
authenticated = await self.authenticate(client_session, self.listeners_config[listener_name])
if not authenticated:
await writer.close()
server.release_connection() # Delete client from connections list
return
while True:
try:
client_session.transitions.connect()
@ -469,14 +528,17 @@ class Broker:
self._sessions[client_session.client_id] = (client_session, handler)
await handler.mqtt_connack_authorize(authenticated)
await self.plugins_manager.fire_event(EVENT_BROKER_CLIENT_CONNECTED, client_id=client_session.client_id)
self.logger.debug(f"{client_session.client_id} Start messages handling")
await handler.start()
self.logger.debug(f"Retained messages queue size: {client_session.retained_messages.qsize()}")
await self.publish_session_retained_messages(client_session)
await self._publish_session_retained_messages(client_session)
await self._client_message_loop(client_session, handler)
async def _client_message_loop(self, client_session: Session, handler: BrokerProtocolHandler) -> None:
"""Run the main loop to handle client messages."""
# Init and start loop for handling client messages (publish, subscribe/unsubscribe, disconnect)
disconnect_waiter = asyncio.ensure_future(handler.wait_disconnect())
subscribe_waiter = asyncio.ensure_future(handler.get_next_pending_subscription())
@ -495,141 +557,160 @@ class Broker:
],
return_when=asyncio.FIRST_COMPLETED,
)
if disconnect_waiter in done:
result = disconnect_waiter.result()
self.logger.debug(f"{client_session.client_id} Result from wait_disconnect: {result}")
if result is None:
self.logger.debug(f"Will flag: {client_session.will_flag}")
if client_session.will_flag:
self.logger.debug(
f"Client {format_client_message(client_session)} disconnected abnormally, sending will message",
)
await self._broadcast_message(
client_session,
client_session.will_topic,
client_session.will_message,
client_session.will_qos,
)
if client_session.will_retain:
self.retain_message(
client_session,
client_session.will_topic,
client_session.will_message,
client_session.will_qos,
)
self.logger.debug(f"{client_session.client_id} Disconnecting session")
await self._stop_handler(handler)
client_session.transitions.disconnect()
await self.plugins_manager.fire_event(
EVENT_BROKER_CLIENT_DISCONNECTED,
client_id=client_session.client_id,
)
connected = False
# Recreate the disconnect_waiter task after processing
disconnect_waiter = asyncio.ensure_future(handler.wait_disconnect())
if unsubscribe_waiter in done:
self.logger.debug(f"{client_session.client_id} handling unsubscription")
unsubscription = unsubscribe_waiter.result()
for topic in unsubscription.topics:
self._del_subscription(topic, client_session)
await self.plugins_manager.fire_event(
EVENT_BROKER_CLIENT_UNSUBSCRIBED,
client_id=client_session.client_id,
topic=topic,
)
await handler.mqtt_acknowledge_unsubscription(unsubscription.packet_id)
# Recreate the unsubscribe_waiter task
unsubscribe_waiter = asyncio.ensure_future(handler.get_next_pending_unsubscription())
if subscribe_waiter in done:
self.logger.debug(f"{client_session.client_id} handling subscription")
subscriptions = subscribe_waiter.result()
return_codes = [
await self.add_subscription(subscription, client_session) for subscription in subscriptions.topics
]
await handler.mqtt_acknowledge_subscription(subscriptions.packet_id, return_codes)
for index, subscription in enumerate(subscriptions.topics):
if return_codes[index] != AMQTT_MAGIC_VALUE_RET_SUBSCRIBED:
await self.plugins_manager.fire_event(
EVENT_BROKER_CLIENT_SUBSCRIBED,
client_id=client_session.client_id,
topic=subscription[0],
qos=subscription[1],
)
await self.publish_retained_messages_for_subscription(subscription, client_session)
# Recreate the subscribe_waiter task
if disconnect_waiter in done:
connected = await self._handle_disconnect(client_session, handler, disconnect_waiter)
disconnect_waiter = asyncio.ensure_future(handler.wait_disconnect())
if subscribe_waiter in done:
await self._handle_subscription(client_session, handler, subscribe_waiter)
subscribe_waiter = asyncio.ensure_future(handler.get_next_pending_subscription())
self.logger.debug(repr(self._subscriptions))
if unsubscribe_waiter in done:
await self._handle_unsubscription(client_session, handler, unsubscribe_waiter)
unsubscribe_waiter = asyncio.ensure_future(handler.get_next_pending_unsubscription())
if wait_deliver in done:
if self.logger.isEnabledFor(logging.DEBUG):
self.logger.debug(f"{client_session.client_id} handling message delivery")
app_message = wait_deliver.result()
if app_message is None:
self.logger.debug("app_message was empty!")
continue
if not app_message.topic:
self.logger.warning(
f"[MQTT-4.7.3-1] - {client_session.client_id}"
" invalid TOPIC sent in PUBLISH message, closing connection",
)
if not await self._handle_message_delivery(client_session, handler, wait_deliver):
break
if "#" in app_message.topic or "+" in app_message.topic:
self.logger.warning(
f"[MQTT-3.3.2-2] - {client_session.client_id}"
" invalid TOPIC sent in PUBLISH message, closing connection",
)
break
permitted = await self.topic_filtering(client_session, topic=app_message.topic, action=Action.PUBLISH)
if not permitted:
self.logger.info(
f"{client_session.client_id} forbidden TOPIC {app_message.topic} sent in PUBLISH message.",
)
else:
await self.plugins_manager.fire_event(
EVENT_BROKER_MESSAGE_RECEIVED,
client_id=client_session.client_id,
message=app_message,
)
await self._broadcast_message(client_session, app_message.topic, app_message.data)
if app_message.publish_packet is not None and app_message.publish_packet.retain_flag:
self.retain_message(client_session, app_message.topic, app_message.data, app_message.qos)
# Recreate the wait_deliver task
wait_deliver = asyncio.ensure_future(handler.mqtt_deliver_next_message())
except asyncio.CancelledError:
self.logger.debug("Client loop cancelled")
break
disconnect_waiter.cancel()
subscribe_waiter.cancel()
unsubscribe_waiter.cancel()
wait_deliver.cancel()
self.logger.debug(f"{client_session.client_id} Client disconnected")
server.release_connection()
async def _handle_disconnect(
self,
client_session: Session,
handler: BrokerProtocolHandler,
disconnect_waiter: asyncio.Future[Any],
) -> bool:
"""Handle client disconnection."""
result = disconnect_waiter.result()
self.logger.debug(f"{client_session.client_id} Result from wait_disconnect: {result}")
if result is None:
self.logger.debug(f"Will flag: {client_session.will_flag}")
if client_session.will_flag:
self.logger.debug(
f"Client {format_client_message(client_session)} disconnected abnormally, sending will message",
)
await self._broadcast_message(
client_session,
client_session.will_topic,
client_session.will_message,
client_session.will_qos,
)
if client_session.will_retain:
self.retain_message(
client_session,
client_session.will_topic,
client_session.will_message,
client_session.will_qos,
)
self.logger.debug(f"{client_session.client_id} Disconnecting session")
await self._stop_handler(handler)
client_session.transitions.disconnect()
await self.plugins_manager.fire_event(EVENT_BROKER_CLIENT_DISCONNECTED, client_id=client_session.client_id)
return False
return True
async def _init_handler(self, session: Session, reader: ReaderAdapter, writer: WriterAdapter) -> BrokerProtocolHandler:
"""Create a BrokerProtocolHandler and attach to a session.
async def _handle_subscription(
self,
client_session: Session,
handler: BrokerProtocolHandler,
subscribe_waiter: asyncio.Future[Any],
) -> None:
"""Handle client subscription."""
self.logger.debug(f"{client_session.client_id} handling subscription")
subscriptions = subscribe_waiter.result()
return_codes = [await self._add_subscription(subscription, client_session) for subscription in subscriptions.topics]
await handler.mqtt_acknowledge_subscription(subscriptions.packet_id, return_codes)
for index, subscription in enumerate(subscriptions.topics):
if return_codes[index] != AMQTT_MAGIC_VALUE_RET_SUBSCRIBED:
await self.plugins_manager.fire_event(
EVENT_BROKER_CLIENT_SUBSCRIBED,
client_id=client_session.client_id,
topic=subscription[0],
qos=subscription[1],
)
await self._publish_retained_messages_for_subscription(subscription, client_session)
:return:
"""
handler = BrokerProtocolHandler(self.plugins_manager, loop=self._loop)
handler.attach(session, reader, writer)
return handler
async def _handle_unsubscription(
self,
client_session: Session,
handler: BrokerProtocolHandler,
unsubscribe_waiter: asyncio.Future[Any],
) -> None:
"""Handle client unsubscription."""
self.logger.debug(f"{client_session.client_id} handling unsubscription")
unsubscription = unsubscribe_waiter.result()
for topic in unsubscription.topics:
self._del_subscription(topic, client_session)
await self.plugins_manager.fire_event(
EVENT_BROKER_CLIENT_UNSUBSCRIBED,
client_id=client_session.client_id,
topic=topic,
)
await handler.mqtt_acknowledge_unsubscription(unsubscription.packet_id)
async def _handle_message_delivery(
self,
client_session: Session,
handler: BrokerProtocolHandler,
wait_deliver: asyncio.Future[Any],
) -> bool:
"""Handle message delivery to the client."""
self.logger.debug(f"{client_session.client_id} handling message delivery")
app_message = wait_deliver.result()
if app_message is None:
self.logger.debug("app_message was empty!")
return True
if not app_message.topic:
self.logger.warning(
f"[MQTT-4.7.3-1] - {client_session.client_id} invalid TOPIC sent in PUBLISH message, closing connection",
)
return False
if "#" in app_message.topic or "+" in app_message.topic:
self.logger.warning(
f"[MQTT-3.3.2-2] - {client_session.client_id} invalid TOPIC sent in PUBLISH message, closing connection",
)
return False
permitted = await self._topic_filtering(client_session, topic=app_message.topic, action=Action.PUBLISH)
if not permitted:
self.logger.info(f"{client_session.client_id} forbidden TOPIC {app_message.topic} sent in PUBLISH message.")
else:
await self.plugins_manager.fire_event(
EVENT_BROKER_MESSAGE_RECEIVED,
client_id=client_session.client_id,
message=app_message,
)
await self._broadcast_message(client_session, app_message.topic, app_message.data)
if app_message.publish_packet and app_message.publish_packet.retain_flag:
self.retain_message(client_session, app_message.topic, app_message.data, app_message.qos)
return True
# async def _init_handler(self, session: Session, reader: ReaderAdapter, writer: WriterAdapter) -> BrokerProtocolHandler:
# """Create a BrokerProtocolHandler and attach to a session."""
# handler = BrokerProtocolHandler(self.plugins_manager, loop=self._loop)
# handler.attach(session, reader, writer)
# return handler
async def _stop_handler(self, handler: BrokerProtocolHandler) -> None:
"""Stop a running handler and detach if from the session.
:param handler:
:return:
"""
"""Stop a running handler and detach if from the session."""
try:
await handler.stop()
except Exception:
self.logger.exception("Failed to stop handler")
async def authenticate(self, session: Session, _: dict[str, Any]) -> bool:
async def _authenticate(self, session: Session, _: dict[str, Any]) -> bool:
"""Call the authenticate method on registered plugins to test user authentication.
User is considered authenticated if all plugins called returns True.
@ -658,7 +739,49 @@ class Broker:
# If all plugins returned True, authentication is success
return auth_result
async def topic_filtering(self, session: Session, topic: str, action: Action) -> bool:
def retain_message(
self,
source_session: Session | None,
topic_name: str | None,
data: bytes | bytearray | None,
qos: int | None = None,
) -> None:
if data and topic_name is not None:
# If retained flag set, store the message for further subscriptions
self.logger.debug(f"Retaining message on topic {topic_name}")
self._retained_messages[topic_name] = RetainedApplicationMessage(source_session, topic_name, data, qos)
# [MQTT-3.3.1-10]
elif topic_name in self._retained_messages:
self.logger.debug(f"Clearing retained messages for topic '{topic_name}'")
del self._retained_messages[topic_name]
async def _add_subscription(self, subscription: tuple[str, int], session: Session) -> int:
topic_filter, qos = subscription
if "#" in topic_filter and not topic_filter.endswith("#"):
# [MQTT-4.7.1-2] Wildcard character '#' is only allowed as last character in filter
return 0x80
if topic_filter != "+" and "+" in topic_filter and ("/+" not in topic_filter and "+/" not in topic_filter):
# [MQTT-4.7.1-3] + wildcard character must occupy entire level
return 0x80
# Check if the client is authorised to connect to the topic
if not await self._topic_filtering(session, topic_filter, Action.SUBSCRIBE):
return 0x80
# Ensure "max-qos" is an integer before using it
max_qos = self.config.get("max-qos", qos)
if not isinstance(max_qos, int):
max_qos = qos
qos = min(qos, max_qos)
if topic_filter not in self._subscriptions:
self._subscriptions[topic_filter] = []
if all(s.client_id != session.client_id for s, _ in self._subscriptions[topic_filter]):
self._subscriptions[topic_filter].append((session, qos))
else:
self.logger.debug(f"Client {format_client_message(session=session)} has already subscribed to {topic_filter}")
return qos
async def _topic_filtering(self, session: Session, topic: str, action: Action) -> bool:
"""Call the topic_filtering method on registered plugins to check that the subscription is allowed.
User is considered allowed if all plugins called return True.
@ -690,48 +813,29 @@ class Broker:
)
return all(result for result in results.values())
def retain_message(
self,
source_session: Session | None,
topic_name: str | None,
data: bytes | bytearray | None,
qos: int | None = None,
) -> None:
if data and topic_name is not None:
# If retained flag set, store the message for further subscriptions
self.logger.debug(f"Retaining message on topic {topic_name}")
self._retained_messages[topic_name] = RetainedApplicationMessage(source_session, topic_name, data, qos)
# [MQTT-3.3.1-10]
elif topic_name in self._retained_messages:
self.logger.debug(f"Clearing retained messages for topic '{topic_name}'")
del self._retained_messages[topic_name]
async def _delete_session(self, client_id: str) -> None:
"""Delete an existing session data, for example due to clean session set in CONNECT."""
session = self._sessions.pop(client_id, (None, None))[0]
# NOTE: issue #61 remove try block
async def add_subscription(self, subscription: tuple[str, int], session: Session) -> int:
topic_filter, qos = subscription
if "#" in topic_filter and not topic_filter.endswith("#"):
# [MQTT-4.7.1-2] Wildcard character '#' is only allowed as last character in filter
return 0x80
if topic_filter != "+" and "+" in topic_filter and ("/+" not in topic_filter and "+/" not in topic_filter):
# [MQTT-4.7.1-3] + wildcard character must occupy entire level
return 0x80
# Check if the client is authorised to connect to the topic
if not await self.topic_filtering(session, topic_filter, Action.SUBSCRIBE):
return 0x80
if session is None:
self.logger.debug(f"Delete session : session {client_id} doesn't exist")
return
self.logger.debug(f"Deleted existing session {session!r}")
# Ensure "max-qos" is an integer before using it
max_qos = self.config.get("max-qos", qos)
if not isinstance(max_qos, int):
max_qos = qos
# Delete subscriptions
self.logger.debug(f"Deleting session {session!r} subscriptions")
await self._del_all_subscriptions(session)
session.clear_queues()
qos = min(qos, max_qos)
if topic_filter not in self._subscriptions:
self._subscriptions[topic_filter] = []
if all(s.client_id != session.client_id for s, _ in self._subscriptions[topic_filter]):
self._subscriptions[topic_filter].append((session, qos))
else:
self.logger.debug(f"Client {format_client_message(session=session)} has already subscribed to {topic_filter}")
return qos
async def _del_all_subscriptions(self, session: Session) -> None:
"""Delete all topic subscriptions for a given session."""
filter_queue: deque[str] = deque()
for topic in self._subscriptions:
if self._del_subscription(topic, session):
filter_queue.append(topic)
for topic in filter_queue:
if not self._subscriptions[topic]:
del self._subscriptions[topic]
def _del_subscription(self, a_filter: str, session: Session) -> int:
"""Delete a session subscription on a given topic.
@ -752,32 +856,9 @@ class Broker:
deleted += 1
break
except KeyError:
# Unsubscribe topic not found in current subscribed topics
pass
self.logger.debug(f"Unsubscription on topic '{a_filter}' for client {format_client_message(session=session)}")
return deleted
def _del_all_subscriptions(self, session: Session) -> None:
"""Delete all topic subscriptions for a given session.
:param session:
:return:
"""
filter_queue: deque[str] = deque()
for topic in self._subscriptions:
if self._del_subscription(topic, session):
filter_queue.append(topic)
for topic in filter_queue:
if not self._subscriptions[topic]:
del self._subscriptions[topic]
def matches(self, topic: str, a_filter: str) -> bool:
if "#" not in a_filter and "+" not in a_filter:
# if filter doesn't contain wildcard, return exact match
return a_filter == topic
# else use regex
match_pattern = re.compile(re.escape(a_filter).replace("\\#", "?.*").replace("\\+", "[^/]*").lstrip("?"))
return bool(match_pattern.fullmatch(topic))
async def _broadcast_loop(self) -> None:
"""Run the main loop to broadcast messages."""
running_tasks: deque[asyncio.Task[OutgoingApplicationMessage]] = self._tasks_queue
@ -824,7 +905,7 @@ class Broker:
continue
# Skip all subscriptions which do not match the topic
if not self.matches(broadcast["topic"], k_filter):
if not self._matches(broadcast["topic"], k_filter):
self.logger.debug(f"Topic '{broadcast['topic']}' does not match filter '{k_filter}'")
continue
@ -892,7 +973,7 @@ class Broker:
broadcast["qos"] = force_qos
await self._broadcast_queue.put(broadcast)
async def publish_session_retained_messages(self, session: Session) -> None:
async def _publish_session_retained_messages(self, session: Session) -> None:
self.logger.debug(
f"Publishing {session.retained_messages.qsize()}"
f" messages retained for session {format_client_message(session=session)}",
@ -910,7 +991,7 @@ class Broker:
if publish_tasks:
await asyncio.wait(publish_tasks)
async def publish_retained_messages_for_subscription(self, subscription: tuple[str, int], session: Session) -> None:
async def _publish_retained_messages_for_subscription(self, subscription: tuple[str, int], session: Session) -> None:
self.logger.debug(
f"Begin broadcasting messages retained due to subscription on '{subscription[0]}'"
f" from {format_client_message(session=session)}",
@ -920,7 +1001,7 @@ class Broker:
topic_filter, qos = subscription
for topic, retained in self._retained_messages.items():
self.logger.debug(f"matching : {topic} {topic_filter}")
if self.matches(topic, topic_filter):
if self._matches(topic, topic_filter):
self.logger.debug(f"{topic} and {topic_filter} match")
handler = self._get_handler(session)
if handler:
@ -936,62 +1017,16 @@ class Broker:
f" from {format_client_message(session=session)}",
)
def delete_session(self, client_id: str) -> None:
"""Delete an existing session data, for example due to clean session set in CONNECT."""
session = self._sessions.pop(client_id, (None, None))[0]
if session is None:
self.logger.debug(f"Delete session : session {client_id} doesn't exist")
return
self.logger.debug(f"Deleted existing session {session!r}")
# Delete subscriptions
self.logger.debug(f"Deleting session {session!r} subscriptions")
self._del_all_subscriptions(session)
session.clear_queues()
def _matches(self, topic: str, a_filter: str) -> bool:
if "#" not in a_filter and "+" not in a_filter:
# if filter doesn't contain wildcard, return exact match
return a_filter == topic
# else use regex
match_pattern = re.compile(re.escape(a_filter).replace("\\#", "?.*").replace("\\+", "[^/]*").lstrip("?"))
return bool(match_pattern.fullmatch(topic))
def _get_handler(self, session: Session) -> BrokerProtocolHandler | None:
client_id = session.client_id
if client_id:
return self._sessions.get(client_id, (None, None))[1]
return None
@classmethod
def _split_bindaddr_port(cls, port_str: str, default_port: int) -> tuple[str | None, int]:
"""Split an address:port pair into separate IP address and port. with IPv6 special-case handling.
NOTE: issue #72
- Address can be specified using one of the following methods:
- empty string - all interfaces default port
- 1883 - Port number only (listen all interfaces)
- :1883 - Port number only (listen all interfaces)
- 0.0.0.0:1883 - IPv4 address
- [::]:1883 - IPv6 address
"""
def _parse_port(port_str: str) -> int:
port_str = port_str.removeprefix(":")
if not port_str:
return default_port
return int(port_str)
if port_str.startswith("["): # IPv6 literal
try:
addr_end = port_str.index("]")
except ValueError as e:
msg = "Expecting '[' to be followed by ']'"
raise ValueError(msg) from e
return (port_str[0 : addr_end + 1], _parse_port(port_str[addr_end + 1 :]))
if ":" in port_str:
address, port_str = port_str.rsplit(":", 1)
return (address or None, _parse_port(port_str))
try:
return (None, _parse_port(port_str))
except ValueError:
return (port_str, default_port)

Wyświetl plik

@ -526,8 +526,7 @@ class MQTTClient:
while self.client_tasks:
task = self.client_tasks.popleft()
if not task.done():
# task.set_exception(ClientError("Connection lost"))
task.cancel() # NOTE: issue #153
task.cancel()
self.logger.debug("Monitoring broker disconnection")
# Wait for disconnection from broker (like connection lost)

Wyświetl plik

@ -195,7 +195,6 @@ class ClientProtocolHandler(ProtocolHandler):
self.logger.debug("Broker closed connection")
if self._disconnect_waiter is not None and not self._disconnect_waiter.done():
self._disconnect_waiter.set_result(None)
# await self.stop() # NOTE: issue #119
async def wait_disconnect(self) -> None:
if self._disconnect_waiter is not None:

Wyświetl plik

@ -67,15 +67,11 @@ class ProtocolHandler:
self.writer: WriterAdapter | None = None
self.plugins_manager: PluginManager = plugins_manager
# TODO: check how to update loop usage best
self._loop = loop if loop is not None else asyncio.get_event_loop_policy().get_event_loop()
# try:
# # Use the currently running loop if available
# self._loop = loop if loop is not None else asyncio.get_running_loop()
# except RuntimeError:
# # If no running loop is found, create a new one
# self._loop = asyncio.new_event_loop()
# asyncio.set_event_loop(self._loop)
try:
self._loop = loop if loop is not None else asyncio.get_running_loop()
except RuntimeError:
self._loop = asyncio.new_event_loop()
asyncio.set_event_loop(self._loop)
self._reader_task: asyncio.Task[None] | None = None
self._keepalive_task: asyncio.TimerHandle | None = None

Wyświetl plik

@ -27,8 +27,6 @@ def get_plugin_manager(namespace: str) -> "PluginManager | None":
class BaseContext:
def __init__(self) -> None:
self.loop: asyncio.AbstractEventLoop | None = None
# TODO: change this usage
# self.logger: logging.Logger | None = None
self.logger: logging.Logger = _LOGGER
self.config: dict[str, Any] | None = None
@ -41,16 +39,11 @@ class PluginManager:
"""
def __init__(self, namespace: str, context: BaseContext | None, loop: asyncio.AbstractEventLoop | None = None) -> None:
# TODO: check how to update loop usage best
# self._loop = loop if loop is not None else asyncio.get_event_loop_policy().get_event_loop()
if loop is None:
try:
self._loop = asyncio.get_running_loop()
except RuntimeError:
self._loop = asyncio.new_event_loop()
# asyncio.set_event_loop(self._loop)
else:
self._loop = loop
try:
self._loop = loop if loop is not None else asyncio.get_running_loop()
except RuntimeError:
self._loop = asyncio.new_event_loop()
asyncio.set_event_loop(self._loop)
self.logger = logging.getLogger(namespace)
self.context = context if context is not None else BaseContext()

Wyświetl plik

@ -158,7 +158,6 @@ timeout = 10
# ------------------------------------ MYPY ------------------------------------
[tool.mypy]
# mypy_path = "amqtt"
exclude = ["^tests/.*", "^docs/.*", "^samples/.*"]
follow_imports = "silent"
show_error_codes = true
@ -235,10 +234,11 @@ never-returning-functions = ["sys.exit", "argparse.parse_error"]
[tool.pylint.DESIGN]
max-branches = 20 # too-many-branches
max-parents = 10
max-parents = 10 # too-many-parents
max-positional-arguments = 10 # too-many-positional-arguments
max-returns = 7
max-returns = 7 # too-many-returns
max-statements = 61 # too-many-statements
max-module-lines = 1500 # too-many-lines
# ---------------------------------- COVERAGE ----------------------------------
[tool.coverage.run]

Wyświetl plik

@ -1,7 +1,7 @@
import asyncio
import logging
import socket
from unittest.mock import MagicMock, call
from unittest.mock import MagicMock, call, patch
import psutil
import pytest
@ -25,6 +25,7 @@ from amqtt.mqtt.connack import ConnackPacket
from amqtt.mqtt.connect import ConnectPacket, ConnectPayload, ConnectVariableHeader
from amqtt.mqtt.constants import QOS_0, QOS_1, QOS_2
from amqtt.mqtt.disconnect import DisconnectPacket
from amqtt.mqtt.protocol.broker_handler import BrokerProtocolHandler
from amqtt.mqtt.pubcomp import PubcompPacket
from amqtt.mqtt.publish import PublishPacket
from amqtt.mqtt.pubrec import PubrecPacket
@ -677,7 +678,7 @@ def test_matches_multi_level_wildcard(broker):
"sport/tennis",
"sport/tennis/",
]:
assert not broker.matches(bad_topic, test_filter)
assert not broker._matches(bad_topic, test_filter)
for good_topic in [
"sport/tennis/player1",
@ -685,7 +686,7 @@ def test_matches_multi_level_wildcard(broker):
"sport/tennis/player1/ranking",
"sport/tennis/player1/score/wimbledon",
]:
assert broker.matches(good_topic, test_filter)
assert broker._matches(good_topic, test_filter)
def test_matches_single_level_wildcard(broker):
@ -696,37 +697,36 @@ def test_matches_single_level_wildcard(broker):
"sport/tennis/player1/",
"sport/tennis/player1/ranking",
]:
assert not broker.matches(bad_topic, test_filter)
assert not broker._matches(bad_topic, test_filter)
for good_topic in [
"sport/tennis/",
"sport/tennis/player1",
"sport/tennis/player2",
]:
assert broker.matches(good_topic, test_filter)
assert broker._matches(good_topic, test_filter)
# @pytest.mark.asyncio
# async def test_broker_broadcast_cancellation(broker):
# topic = "test"
# data = b"data"
# qos = QOS_0
@pytest.mark.asyncio
async def test_broker_broadcast_cancellation(broker):
topic = "test"
data = b"data"
qos = QOS_0
# sub_client = MQTTClient()
# await sub_client.connect("mqtt://127.0.0.1")
# await sub_client.subscribe([(topic, qos)])
sub_client = MQTTClient()
await sub_client.connect("mqtt://127.0.0.1")
await sub_client.subscribe([(topic, qos)])
# with patch.object(BrokerProtocolHandler, "mqtt_publish", side_effect=asyncio.CancelledError) as mocked_mqtt_publish:
# await _client_publish(topic, data, qos)
with patch.object(BrokerProtocolHandler, "mqtt_publish", side_effect=asyncio.CancelledError) as mocked_mqtt_publish:
await _client_publish(topic, data, qos)
# # Second publish triggers the awaiting of first `mqtt_publish` task
# await _client_publish(topic, data, qos)
# await asyncio.sleep(0.01)
# Second publish triggers the awaiting of first `mqtt_publish` task
await _client_publish(topic, data, qos)
await asyncio.sleep(0.01)
# # `assert_awaited` does not exist in Python before `3.8`
# mocked_mqtt_publish.assert_awaited()
mocked_mqtt_publish.assert_awaited()
# # Ensure broadcast loop is still functional and can deliver the message
# await _client_publish(topic, data, qos)
# message = await asyncio.wait_for(sub_client.deliver_message(), timeout_duration=1)
# assert message
# Ensure broadcast loop is still functional and can deliver the message
await _client_publish(topic, data, qos)
message = await asyncio.wait_for(sub_client.deliver_message(), timeout=1)
assert message