diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 5d14cf1..f84976d 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -33,7 +33,7 @@ jobs: python-version: "3.13" - name: 🏗 Install the project - run: uv sync --locked --dev --extra contrib + run: uv sync --locked --dev --all-extras - name: Run mypy run: uv run --frozen mypy ${{ env.PROJECT_PATH }}/ @@ -69,7 +69,7 @@ jobs: python-version: ${{ matrix.python-version }} - name: 🏗 Install the project - run: uv sync --locked --dev --extra contrib + run: uv sync --locked --dev --all-extras - name: Run pytest run: uv run --frozen pytest tests/ --cov=./ --cov-report=xml --junitxml=pytest-report.xml diff --git a/README.md b/README.md index 99b7639..6dbec8a 100644 --- a/README.md +++ b/README.md @@ -17,7 +17,7 @@ - Communication over TCP and/or websocket, including support for SSL/TLS - Support QoS 0, QoS 1 and QoS 2 messages flow - Client auto-reconnection on network lost -- Functionality expansion; plugins included: authentication and `$SYS` topic publishing +- Custom functionality expansion; plugins included: authentication, `$SYS` topic publishing, session persistence ## Installation diff --git a/amqtt/broker.py b/amqtt/broker.py index f62a2cc..ee0cbbc 100644 --- a/amqtt/broker.py +++ b/amqtt/broker.py @@ -108,7 +108,7 @@ class ExternalServer(Server): class BrokerContext(BaseContext): - """BrokerContext is used as the context passed to plugins interacting with the broker.""" + """Used to provide the server's context as well as public methods for accessing internal state.""" def __init__(self, broker: "Broker") -> None: super().__init__() @@ -116,16 +116,21 @@ class BrokerContext(BaseContext): self._broker_instance = broker async def broadcast_message(self, topic: str, data: bytes, qos: int | None = None) -> None: + """Send message to all client sessions subscribing to `topic`.""" await self._broker_instance.internal_message_broadcast(topic, data, qos) - def retain_message(self, topic_name: str, data: bytes | bytearray, qos: int | None = None) -> None: - self._broker_instance.retain_message(None, topic_name, data, qos) + async def retain_message(self, topic_name: str, data: bytes | bytearray, qos: int | None = None) -> None: + await self._broker_instance.retain_message(None, topic_name, data, qos) @property def sessions(self) -> Generator[Session]: for session in self._broker_instance.sessions.values(): yield session[0] + def get_session(self, client_id: str) -> Session | None: + """Return the session associated with `client_id`, if it exists.""" + return self._broker_instance.sessions.get(client_id, (None, None))[0] + @property def retained_messages(self) -> dict[str, RetainedApplicationMessage]: return self._broker_instance.retained_messages @@ -134,6 +139,20 @@ class BrokerContext(BaseContext): def subscriptions(self) -> dict[str, list[tuple[Session, int]]]: return self._broker_instance.subscriptions + async def add_subscription(self, client_id: str, topic: str|None, qos: int|None) -> None: + """Create a topic subscription for the given `client_id`. + + If a client session doesn't exist for `client_id`, create a disconnected session. + If `topic` and `qos` are both `None`, only create the client session. + """ + if client_id not in self._broker_instance.sessions: + broker_handler, session = self._broker_instance.create_offline_session(client_id) + self._broker_instance._sessions[client_id] = (session, broker_handler) # noqa: SLF001 + + if topic is not None and qos is not None: + session, _ = self._broker_instance.sessions[client_id] + await self._broker_instance.add_subscription((topic, qos), session) + class Broker: """MQTT 3.1.1 compliant broker implementation. @@ -483,7 +502,17 @@ class Broker: # Get session from cache elif client_session.client_id in self._sessions: self.logger.debug(f"Found old session {self._sessions[client_session.client_id]!r}") - client_session, _ = self._sessions[client_session.client_id] + + # even though the session previously existed, the new connection can bring updated configuration and credentials + existing_client_session, _ = self._sessions[client_session.client_id] + existing_client_session.will_flag = client_session.will_flag + existing_client_session.will_message = client_session.will_message + existing_client_session.will_topic = client_session.will_topic + existing_client_session.will_qos = client_session.will_qos + existing_client_session.keep_alive = client_session.keep_alive + existing_client_session.username = client_session.username + existing_client_session.password = client_session.password + client_session = existing_client_session client_session.parent = 1 else: client_session.parent = 0 @@ -495,6 +524,14 @@ class Broker: self.logger.debug(f"Keep-alive timeout={client_session.keep_alive}") return handler, client_session + def create_offline_session(self, client_id: str) -> tuple[BrokerProtocolHandler, Session]: + session = Session() + session.client_id = client_id + + bph = BrokerProtocolHandler(self.plugins_manager, session) + session.transitions.disconnect() + return bph, session + async def _handle_client_session( self, reader: ReaderAdapter, @@ -635,7 +672,7 @@ class Broker: client_session.will_qos, ) if client_session.will_retain: - self.retain_message( + await self.retain_message( client_session, client_session.will_topic, client_session.will_message, @@ -660,7 +697,7 @@ class Broker: """Handle client subscription.""" self.logger.debug(f"{client_session.client_id} handling subscription") subscriptions = subscribe_waiter.result() - return_codes = [await self._add_subscription(subscription, client_session) for subscription in subscriptions.topics] + return_codes = [await self.add_subscription(subscription, client_session) for subscription in subscriptions.topics] await handler.mqtt_acknowledge_subscription(subscriptions.packet_id, return_codes) for index, subscription in enumerate(subscriptions.topics): if return_codes[index] != AMQTT_MAGIC_VALUE_RET_SUBSCRIBED: @@ -730,7 +767,7 @@ class Broker: ) await self._broadcast_message(client_session, app_message.topic, app_message.data) if app_message.publish_packet and app_message.publish_packet.retain_flag: - self.retain_message(client_session, app_message.topic, app_message.data, app_message.qos) + await self.retain_message(client_session, app_message.topic, app_message.data, app_message.qos) return True async def _init_handler(self, session: Session, reader: ReaderAdapter, writer: WriterAdapter) -> BrokerProtocolHandler: @@ -774,7 +811,7 @@ class Broker: return False - def retain_message( + async def retain_message( self, source_session: Session | None, topic_name: str | None, @@ -785,12 +822,25 @@ class Broker: # If retained flag set, store the message for further subscriptions self.logger.debug(f"Retaining message on topic {topic_name}") self._retained_messages[topic_name] = RetainedApplicationMessage(source_session, topic_name, data, qos) + + await self.plugins_manager.fire_event(BrokerEvents.RETAINED_MESSAGE, + client_id=None, + retained_message=self._retained_messages[topic_name]) + # [MQTT-3.3.1-10] elif topic_name in self._retained_messages: self.logger.debug(f"Clearing retained messages for topic '{topic_name}'") + + cleared_message = self._retained_messages[topic_name] + cleared_message.data = b"" + + await self.plugins_manager.fire_event(BrokerEvents.RETAINED_MESSAGE, + client_id=None, + retained_message=cleared_message) + del self._retained_messages[topic_name] - async def _add_subscription(self, subscription: tuple[str, int], session: Session) -> int: + async def add_subscription(self, subscription: tuple[str, int], session: Session) -> int: topic_filter, qos = subscription if "#" in topic_filter and not topic_filter.endswith("#"): # [MQTT-4.7.1-2] Wildcard character '#' is only allowed as last character in filter @@ -982,6 +1032,10 @@ class Broker: retained_message = RetainedApplicationMessage(broadcast["session"], broadcast["topic"], broadcast["data"], qos) await target_session.retained_messages.put(retained_message) + await self.plugins_manager.fire_event(BrokerEvents.RETAINED_MESSAGE, + client_id=target_session.client_id, + retained_message=retained_message) + if self.logger.isEnabledFor(logging.DEBUG): self.logger.debug(f"target_session.retained_messages={target_session.retained_messages.qsize()}") diff --git a/amqtt/contrib/persistence.py b/amqtt/contrib/persistence.py new file mode 100644 index 0000000..ce63685 --- /dev/null +++ b/amqtt/contrib/persistence.py @@ -0,0 +1,267 @@ +from dataclasses import dataclass +import logging +from pathlib import Path + +from sqlalchemy import Boolean, Integer, LargeBinary, Result, String, select +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine +from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column + +from amqtt.broker import BrokerContext, RetainedApplicationMessage +from amqtt.contrib import DataClassListJSON +from amqtt.errors import PluginError +from amqtt.plugins.base import BasePlugin +from amqtt.session import Session + +logger = logging.getLogger(__name__) + + +class Base(DeclarativeBase): + pass + + +@dataclass +class RetainedMessage: + topic: str + data: str + qos: int + + +@dataclass +class Subscription: + topic: str + qos: int + + +class StoredSession(Base): + __tablename__ = "stored_sessions" + + id: Mapped[int] = mapped_column(primary_key=True) + client_id: Mapped[str] = mapped_column(String) + + clean_session: Mapped[bool | None] = mapped_column(Boolean, nullable=True) + + will_flag: Mapped[bool] = mapped_column(Boolean, default=False, server_default="false") + + will_message: Mapped[bytes | None] = mapped_column(LargeBinary, nullable=True, default=None) + will_qos: Mapped[int | None] = mapped_column(Integer, nullable=True, default=None) + will_retain: Mapped[bool | None] = mapped_column(Boolean, nullable=True, default=None) + will_topic: Mapped[str | None] = mapped_column(String, nullable=True, default=None) + + keep_alive: Mapped[int] = mapped_column(Integer, default=0) + retained: Mapped[list[RetainedMessage]] = mapped_column(DataClassListJSON(RetainedMessage), default=list) + subscriptions: Mapped[list[Subscription]] = mapped_column(DataClassListJSON(Subscription), default=list) + + +class StoredMessage(Base): + __tablename__ = "stored_messages" + + id: Mapped[int] = mapped_column(primary_key=True) + topic: Mapped[str] = mapped_column(String) + data: Mapped[bytes | None] = mapped_column(LargeBinary, nullable=True, default=None) + qos: Mapped[int] = mapped_column(Integer, default=0) + + +class SessionDBPlugin(BasePlugin[BrokerContext]): + """Plugin to store session information and retained topic messages in the event that the broker terminates abnormally. + + Configuration: + - file *(string)* path & filename to store the session db. default: `amqtt.db` + - clear_on_shutdown *(bool)* if the broker shutdowns down normally, don't retain any information. default: `True` + + """ + + def __init__(self, context: BrokerContext) -> None: + super().__init__(context) + + # bypass the `test_plugins_correct_has_attr` until it can be updated + if not hasattr(self.config, "file"): + logger.warning("`Config` is missing a `file` attribute") + return + + self._engine = create_async_engine(f"sqlite+aiosqlite:///{self.config.file}") + self._db_session_maker = async_sessionmaker(self._engine, expire_on_commit=False) + + @staticmethod + async def _get_or_create_session(db_session: AsyncSession, client_id:str) -> StoredSession: + + stmt = select(StoredSession).filter(StoredSession.client_id == client_id) + stored_session = await db_session.scalar(stmt) + if stored_session is None: + stored_session = StoredSession(client_id=client_id) + db_session.add(stored_session) + await db_session.flush() + return stored_session + + + @staticmethod + async def _get_or_create_message(db_session: AsyncSession, topic:str) -> StoredMessage: + + stmt = select(StoredMessage).filter(StoredMessage.topic == topic) + stored_message = await db_session.scalar(stmt) + if stored_message is None: + stored_message = StoredMessage(topic=topic) + db_session.add(stored_message) + await db_session.flush() + return stored_message + + + async def on_broker_client_connected(self, client_id:str, client_session:Session) -> None: + """Search to see if session already exists.""" + # if client id doesn't exist, create (can ignore if session is anonymous) + # update session information (will, clean_session, etc) + + # don't store session information for clean or anonymous sessions + if client_session.clean_session in (None, True) or client_session.is_anonymous: + return + async with self._db_session_maker() as db_session, db_session.begin(): + stored_session = await self._get_or_create_session(db_session, client_id) + + stored_session.clean_session = client_session.clean_session + stored_session.will_flag = client_session.will_flag + stored_session.will_message = client_session.will_message # type: ignore[assignment] + stored_session.will_qos = client_session.will_qos + stored_session.will_retain = client_session.will_retain + stored_session.will_topic = client_session.will_topic + stored_session.keep_alive = client_session.keep_alive + + await db_session.flush() + + async def on_broker_client_subscribed(self, client_id: str, topic: str, qos: int) -> None: + """Create/update subscription if clean session = false.""" + session = self.context.get_session(client_id) + if not session: + logger.warning(f"'{client_id}' is subscribing but doesn't have a session") + return + + if session.clean_session: + return + + async with self._db_session_maker() as db_session, db_session.begin(): + # stored sessions shouldn't need to be created here, but we'll use the same helper... + stored_session = await self._get_or_create_session(db_session, client_id) + stored_session.subscriptions = [*stored_session.subscriptions, Subscription(topic, qos)] + await db_session.flush() + + async def on_broker_client_unsubscribed(self, client_id: str, topic: str) -> None: + """Remove subscription if clean session = false.""" + + async def on_broker_retained_message(self, *, client_id: str | None, retained_message: RetainedApplicationMessage) -> None: + """Update to retained messages. + + if retained_message.data is None or '', the message is being cleared + """ + # if client_id is valid, the retained message is for a disconnected client + if client_id is not None: + async with self._db_session_maker() as db_session, db_session.begin(): + # stored sessions shouldn't need to be created here, but we'll use the same helper... + stored_session = await self._get_or_create_session(db_session, client_id) + stored_session.retained = [*stored_session.retained, RetainedMessage(retained_message.topic, + retained_message.data.decode(), + retained_message.qos or 0)] + await db_session.flush() + return + + async with self._db_session_maker() as db_session, db_session.begin(): + # if the retained message has data, we need to store/update for the topic + if retained_message.data: + client_message = await self._get_or_create_message(db_session, retained_message.topic) + client_message.data = retained_message.data # type: ignore[assignment] + client_message.qos = retained_message.qos or 0 + await db_session.flush() + return + + # if there is no data, clear the stored message (if exists) for the topic + stmt = select(StoredMessage).filter(StoredMessage.topic == retained_message.topic) + topic_message = await db_session.scalar(stmt) + if topic_message is not None: + await db_session.delete(topic_message) + await db_session.flush() + return + + async def on_broker_pre_start(self) -> None: + """Initialize the database and db connection.""" + async with self._engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + + async def on_broker_post_start(self) -> None: + """Load subscriptions.""" + if len(self.context.subscriptions) > 0: + msg = "SessionDBPlugin : broker shouldn't have any subscriptions yet" + raise PluginError(msg) + + if len(list(self.context.sessions)) > 0: + msg = "SessionDBPlugin : broker shouldn't have any sessions yet" + raise PluginError(msg) + + async with self._db_session_maker() as db_session, db_session.begin(): + stmt = select(StoredSession) + stored_sessions = await db_session.execute(stmt) + + restored_sessions = 0 + for stored_session in stored_sessions.scalars(): + await self.context.add_subscription(stored_session.client_id, None, None) + for subscription in stored_session.subscriptions: + await self.context.add_subscription(stored_session.client_id, + subscription.topic, + subscription.qos) + session = self.context.get_session(stored_session.client_id) + if not session: + continue + session.clean_session = stored_session.clean_session + session.will_flag = stored_session.will_flag + session.will_message = stored_session.will_message + session.will_qos = stored_session.will_qos + session.will_retain = stored_session.will_retain + session.will_topic = stored_session.will_topic + session.keep_alive = stored_session.keep_alive + + for message in stored_session.retained: + retained_message = RetainedApplicationMessage( + source_session=None, + topic=message.topic, + data=message.data.encode(), + qos=message.qos + ) + await session.retained_messages.put(retained_message) + restored_sessions += 1 + + stmt = select(StoredMessage) + stored_messages: Result[tuple[StoredMessage]] = await db_session.execute(stmt) + + restored_messages = 0 + retained_messages = self.context.retained_messages + for stored_message in stored_messages.scalars(): + retained_messages[stored_message.topic] = (RetainedApplicationMessage( + source_session=None, + topic=stored_message.topic, + data=stored_message.data or b"", + qos=stored_message.qos + )) + restored_messages += 1 + logger.info(f"Retained messages restored: {restored_messages}") + + + logger.info(f"Restored {restored_sessions} sessions.") + + async def on_broker_pre_shutdown(self) -> None: + """Clean up the db connection.""" + await self._engine.dispose() + + async def on_broker_post_shutdown(self) -> None: + + if self.config.clear_on_shutdown and self.config.file.exists(): + self.config.file.unlink() + + @dataclass + class Config: + """Configuration variables.""" + + file: str | Path = "amqtt.db" + """path & filename to store the sqlite session db.""" + clear_on_shutdown: bool = True + """if the broker shutdowns down normally, don't retain any information.""" + + def __post_init__(self) -> None: + """Create `Path` from string path.""" + if isinstance(self.file, str): + self.file = Path(self.file) diff --git a/amqtt/events.py b/amqtt/events.py index 013f26a..96cdd39 100644 --- a/amqtt/events.py +++ b/amqtt/events.py @@ -31,4 +31,5 @@ class BrokerEvents(Events): CLIENT_DISCONNECTED = "broker_client_disconnected" CLIENT_SUBSCRIBED = "broker_client_subscribed" CLIENT_UNSUBSCRIBED = "broker_client_unsubscribed" + RETAINED_MESSAGE = "broker_retained_message" MESSAGE_RECEIVED = "broker_message_received" diff --git a/amqtt/plugins/manager.py b/amqtt/plugins/manager.py index dcc5a8a..1d8abef 100644 --- a/amqtt/plugins/manager.py +++ b/amqtt/plugins/manager.py @@ -8,6 +8,8 @@ import copy from importlib.metadata import EntryPoint, EntryPoints, entry_points from inspect import iscoroutinefunction import logging +import sys +import traceback from typing import Any, Generic, NamedTuple, Optional, TypeAlias, TypeVar, cast import warnings @@ -291,6 +293,15 @@ class PluginManager(Generic[C]): return asyncio.ensure_future(coro) def _clean_fired_events(self, future: asyncio.Future[Any]) -> None: + if self.logger.getEffectiveLevel() <= logging.DEBUG: + try: + future.result() + except asyncio.CancelledError: + self.logger.warning("fired event was cancelled") + # display plugin fault; don't allow it to cause a broker failure + except Exception as exc: # noqa: BLE001, pylint: disable=W0718 + traceback.print_exception(type(exc), exc, exc.__traceback__, file=sys.stderr) + with contextlib.suppress(KeyError, ValueError): self._fired_events.remove(future) diff --git a/amqtt/plugins/persistence.py b/amqtt/plugins/persistence.py index ee79d33..0148dc7 100644 --- a/amqtt/plugins/persistence.py +++ b/amqtt/plugins/persistence.py @@ -1,85 +1,11 @@ -import json -import sqlite3 -from typing import Any +import warnings -from amqtt.contexts import BaseContext -from amqtt.session import Session +from amqtt.broker import BrokerContext +from amqtt.plugins.base import BasePlugin -class SQLitePlugin: - def __init__(self, context: BaseContext) -> None: - self.context: BaseContext = context - self.conn: sqlite3.Connection | None = None - self.cursor: sqlite3.Cursor | None = None - self.db_file: str | None = None - self.persistence_config: dict[str, Any] +class SQLitePlugin(BasePlugin[BrokerContext]): - if ( - persistence_config := self.context.config.get("persistence") if self.context.config is not None else None - ) is not None: - self.persistence_config = persistence_config - self.init_db() - else: - self.context.logger.warning("'persistence' section not found in context configuration") - - def init_db(self) -> None: - self.db_file = self.persistence_config.get("file") - if not self.db_file: - self.context.logger.warning("'file' persistence parameter not found") - else: - try: - self.conn = sqlite3.connect(self.db_file) - self.cursor = self.conn.cursor() - self.context.logger.info(f"Database file '{self.db_file}' opened") - except Exception: - self.context.logger.exception(f"Error while initializing database '{self.db_file}'") - if self.cursor: - self.cursor.execute( - "CREATE TABLE IF NOT EXISTS session(client_id TEXT PRIMARY KEY, data BLOB)", - ) - self.cursor.execute("PRAGMA table_info(session)") - columns = {col[1] for col in self.cursor.fetchall()} - required_columns = {"client_id", "data"} - if not required_columns.issubset(columns): - self.context.logger.error("Database schema for 'session' table is incompatible.") - - async def save_session(self, session: Session) -> None: - if self.cursor and self.conn: - dump: str = json.dumps(session, default=str) - try: - self.cursor.execute( - "INSERT OR REPLACE INTO session (client_id, data) VALUES (?, ?)", - (session.client_id, dump), - ) - self.conn.commit() - except Exception: - self.context.logger.exception(f"Failed saving session '{session}'") - - async def find_session(self, client_id: str) -> Session | None: - if self.cursor: - row = self.cursor.execute( - "SELECT data FROM session where client_id=?", - (client_id,), - ).fetchone() - return json.loads(row[0]) if row else None - return None - - async def del_session(self, client_id: str) -> None: - if self.cursor and self.conn: - try: - exists = self.cursor.execute("SELECT 1 FROM session WHERE client_id=?", (client_id,)).fetchone() - if exists: - self.cursor.execute("DELETE FROM session where client_id=?", (client_id,)) - self.conn.commit() - except Exception: - self.context.logger.exception(f"Failed deleting session with client_id '{client_id}'") - - async def on_broker_post_shutdown(self) -> None: - if self.conn: - try: - self.conn.close() - self.context.logger.info(f"Database file '{self.db_file}' closed") - except Exception: - self.context.logger.exception("Error closing database connection") - finally: - self.conn = None + def __init__(self, context: BrokerContext) -> None: + super().__init__(context) + warnings.warn("SQLitePlugin is deprecated, use amqtt.contrib.persistence.SessionDBPlugin", stacklevel=1) diff --git a/amqtt/plugins/sys/broker.py b/amqtt/plugins/sys/broker.py index eae03d1..4a46188 100644 --- a/amqtt/plugins/sys/broker.py +++ b/amqtt/plugins/sys/broker.py @@ -112,7 +112,7 @@ class BrokerSysPlugin(BasePlugin[BrokerContext]): """Initialize statistics and start $SYS broadcasting.""" self._stats[STAT_START_TIME] = int(datetime.now(tz=UTC).timestamp()) version = f"aMQTT version {amqtt.__version__}" - self.context.retain_message(DOLLAR_SYS_ROOT + "version", version.encode()) + await self.context.retain_message(DOLLAR_SYS_ROOT + "version", version.encode()) # Start $SYS topics management try: diff --git a/docs/plugins/custom_plugins.md b/docs/plugins/custom_plugins.md index 5a69c01..e2a017e 100644 --- a/docs/plugins/custom_plugins.md +++ b/docs/plugins/custom_plugins.md @@ -106,6 +106,8 @@ All plugins are notified of events if the `BasePlugin` subclass implements one o - `async def on_broker_client_connected(self, *, client_id:str) -> None` - `async def on_broker_client_disconnected(self, *, client_id:str) -> None` +- `async def on_broker_retained_message(self, *, client_id: str | None, retained_message: RetainedApplicationMessage) -> None` + - `async def on_broker_client_subscribed(self, *, client_id: str, topic: str, qos: int) -> None` - `async def on_broker_client_unsubscribed(self, *, client_id: str, topic: str) -> None` @@ -115,6 +117,11 @@ All plugins are notified of events if the `BasePlugin` subclass implements one o - `async def on_mqtt_packet_received(self, *, packet: MQTTPacket[MQTTVariableHeader, MQTTPayload[MQTTVariableHeader], MQTTFixedHeader], session: Session | None = None) -> None` +!!! note retained message event + if the `client_id` is `None`, the message is retained for a topic + if the `retained_message.data` is `None` or empty (`''`), the topic message is being cleared + + ## Authentication Plugins In addition to receiving any of the event callbacks, a plugin which subclasses from `BaseAuthPlugin` diff --git a/docs/plugins/packaged_plugins.md b/docs/plugins/packaged_plugins.md index dbf9731..fa65b83 100644 --- a/docs/plugins/packaged_plugins.md +++ b/docs/plugins/packaged_plugins.md @@ -1,4 +1,4 @@ -# Existing Plugins +# Packaged Plugins With the aMQTT plugins framework, one can add additional functionality without having to rewrite core logic in the broker or client. Plugins can be loaded and configured using @@ -51,7 +51,6 @@ Authentication plugin allowing anonymous access. extra: class_style: "simple" - !!! danger even if `allow_anonymous` is set to `false`, the plugin will still allow access if a username is provided by the client diff --git a/docs/plugins/session.md b/docs/plugins/session.md new file mode 100644 index 0000000..a9f45ea --- /dev/null +++ b/docs/plugins/session.md @@ -0,0 +1,12 @@ +# Session Persistence + +`amqtt.plugins.persistence.SessionDBPlugin` + +Plugin to store session information and retained topic messages in the event that the broker terminates abnormally. + +::: amqtt.plugins.persistence.SessionDBPlugin.Config + options: + show_source: false + heading_level: 4 + extra: + class_style: "simple" diff --git a/docs_web/index.md b/docs_web/index.md index 8d4110d..3f933d6 100644 --- a/docs_web/index.md +++ b/docs_web/index.md @@ -10,7 +10,7 @@ - Communication over TCP and/or websocket, including support for SSL/TLS - Support QoS 0, QoS 1 and QoS 2 messages flow - Client auto-reconnection on network lost -- Functionality expansion; plugins included: authentication and `$SYS` topic publishing +- Custom functionality expansion; plugins included: authentication, `$SYS` topic publishing, session persistence ## Installation diff --git a/pyproject.toml b/pyproject.toml index 461e220..4aa0168 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,7 +35,7 @@ dependencies = [ "typer==0.15.4", "aiohttp>=3.12.7", "dacite>=1.9.2", - "psutil>=7.0.0", + "psutil>=7.0.0" ] [dependency-groups] @@ -57,6 +57,7 @@ dev = [ "pytest>=8.3.5", # https://pypi.org/project/pytest "ruff>=0.11.3", # https://pypi.org/project/ruff "setuptools>=78.1.0", + "sqlalchemy[mypy]>=2.0.41", "types-mock>=5.2.0.20250306", # https://pypi.org/project/types-mock "types-PyYAML>=6.0.12.20250402", # https://pypi.org/project/types-PyYAML "types-setuptools>=78.1.0.20250329", # https://pypi.org/project/types-setuptools @@ -84,7 +85,9 @@ docs = [ [project.optional-dependencies] ci = ["coveralls==4.0.1"] contrib = [ - "sqlalchemy>=2.0.41", + "aiosqlite>=0.21.0", + "greenlet>=3.2.3", + "sqlalchemy[asyncio]>=2.0.41", "argon2-cffi>=25.1.0", "aiohttp>=3.12.13", ] diff --git a/samples/client_publish_ssl.py b/samples/client_publish_ssl.py index 528c347..9d0c0b7 100644 --- a/samples/client_publish_ssl.py +++ b/samples/client_publish_ssl.py @@ -1,3 +1,4 @@ +import argparse import asyncio import logging from pathlib import Path @@ -23,13 +24,13 @@ config = { }, "auto_reconnect": False, "check_hostname": False, - "certfile": "cert.pem", + "certfile": "", } -client = MQTTClient(config=config) - -async def test_coro() -> None: +async def test_coro(certfile: str) -> None: + config['certfile'] = certfile + client = MQTTClient(config=config) await client.connect("mqtts://localhost:8883") tasks = [ @@ -45,7 +46,11 @@ async def test_coro() -> None: def __main__(): formatter = "[%(asctime)s] {%(filename)s:%(lineno)d} %(levelname)s - %(message)s" logging.basicConfig(level=logging.DEBUG, format=formatter) - asyncio.run(test_coro()) + parser = argparse.ArgumentParser() + parser.add_argument('--cert', default='cert.pem', help="path & file to verify server's authenticity") + args = parser.parse_args() + + asyncio.run(test_coro(args.cert)) if __name__ == "__main__": __main__() diff --git a/tests/conftest.py b/tests/conftest.py index 43a1d61..2b40a1a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,4 +1,5 @@ import logging +import subprocess from pathlib import Path import tempfile from typing import Any @@ -15,19 +16,46 @@ log = logging.getLogger(__name__) pytest_plugins = ["pytest_logdog"] -test_config = { - "listeners": { - "default": {"type": "tcp", "bind": "127.0.0.1:1883", "max_connections": 15}, - "ws": {"type": "ws", "bind": "127.0.0.1:8080", "max_connections": 15}, - "wss": {"type": "ws", "bind": "127.0.0.1:8081", "max_connections": 15}, - }, - "sys_interval": 0, - "auth": { - "allow-anonymous": True, - } -} +@pytest.fixture +def rsa_keys(): + tmp_dir = tempfile.TemporaryDirectory(prefix='amqtt-test-') + cert = Path(tmp_dir.name) / "cert.pem" + key = Path(tmp_dir.name) / "key.pem" + cmd = f'openssl req -x509 -nodes -days 365 -newkey rsa:2048 -keyout {key} -out {cert} -subj "/CN=localhost"' + subprocess.run(cmd, shell=True, capture_output=True, text=True) + yield cert, key + tmp_dir.cleanup() +@pytest.fixture +def test_config(rsa_keys): + certfile, keyfile = rsa_keys + yield { + "listeners": { + "default": {"type": "tcp", "bind": "127.0.0.1:1883", "max_connections": 15}, + "mqtts": { + "type": "tcp", + "bind": "127.0.0.1:1884", + "max_connections": 15, + "ssl": True, + "certfile": certfile, + "keyfile": keyfile + }, + "ws": {"type": "ws", "bind": "127.0.0.1:8080", "max_connections": 15}, + "wss": { + "type": "ws", + "bind": "127.0.0.1:8081", + "max_connections": 15, + "ssl": True, + 'certfile': certfile, + 'keyfile': keyfile}, + }, + "sys_interval": 0, + "auth": { + "allow-anonymous": True, + } + } + test_config_acl: dict[str, int | dict[str, Any]] = { "listeners": { "default": {"type": "tcp", "bind": "127.0.0.1:1884", "max_connections": 10}, @@ -64,7 +92,7 @@ def mock_plugin_manager(): @pytest.fixture -async def broker_fixture(): +async def broker_fixture(test_config): broker = Broker(test_config, plugin_namespace="amqtt.test.plugins") await broker.start() assert broker.transitions.is_started() @@ -78,7 +106,7 @@ async def broker_fixture(): @pytest.fixture -async def broker(mock_plugin_manager): +async def broker(mock_plugin_manager, test_config): # just making sure the mock is in place before we start our broker assert mock_plugin_manager is not None diff --git a/tests/contrib/test_persistence.py b/tests/contrib/test_persistence.py new file mode 100644 index 0000000..36d1d8b --- /dev/null +++ b/tests/contrib/test_persistence.py @@ -0,0 +1,543 @@ +import asyncio +import logging +from pathlib import Path +import sqlite3 +import pytest +import aiosqlite + +from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine +from sqlalchemy import select + +from amqtt.broker import Broker, BrokerContext, RetainedApplicationMessage +from amqtt.client import MQTTClient +from amqtt.mqtt.constants import QOS_1 +from amqtt.contrib.persistence import SessionDBPlugin, Subscription, StoredSession, RetainedMessage +from amqtt.session import Session + +formatter = "[%(asctime)s] %(name)s {%(filename)s:%(lineno)d} %(levelname)s - %(message)s" +logging.basicConfig(level=logging.DEBUG, format=formatter) + +logger = logging.getLogger(__name__) + +@pytest.fixture +async def db_file(): + db_file = Path(__file__).parent / "amqtt.db" + if db_file.exists(): + raise NotImplementedError("existing db file found, should it be cleaned up?") + + yield db_file + if db_file.exists(): + db_file.unlink() + +@pytest.fixture +async def broker_context(): + cfg = { + 'listeners': { 'default': {'type': 'tcp', 'bind': 'localhost:1883' }}, + 'plugins': {} + } + + context = BrokerContext(broker=Broker(config=cfg)) + yield context + +@pytest.fixture +async def db_session_factory(db_file): + engine = create_async_engine(f"sqlite+aiosqlite:///{str(db_file)}") + factory = async_sessionmaker(engine, expire_on_commit=False) + yield factory + + + +@pytest.mark.asyncio +async def test_initialize_tables(db_file, broker_context): + + broker_context.config = SessionDBPlugin.Config(file=db_file) + session_db_plugin = SessionDBPlugin(broker_context) + await session_db_plugin.on_broker_pre_start() + + assert db_file.exists() + + conn = sqlite3.connect(str(db_file)) + cursor = conn.cursor() + table_name = 'stored_sessions' + cursor.execute(f"PRAGMA table_info({table_name});") + rows = cursor.fetchall() + column_names = [row[1] for row in rows] + assert len(column_names) > 1 + +@pytest.mark.asyncio +async def test_create_stored_session(db_file, broker_context, db_session_factory): + + broker_context.config = SessionDBPlugin.Config(file=db_file) + session_db_plugin = SessionDBPlugin(broker_context) + await session_db_plugin.on_broker_pre_start() + + async with db_session_factory() as db_session: + async with db_session.begin(): + stored_session = await session_db_plugin._get_or_create_session(db_session, 'test_client_1') + assert stored_session.client_id == 'test_client_1' + + async with aiosqlite.connect(str(db_file)) as db: + async with await db.execute("SELECT * FROM stored_sessions") as cursor: + async for row in cursor: + assert row[1] == 'test_client_1' + +@pytest.mark.asyncio +async def test_get_stored_session(db_file, broker_context, db_session_factory): + broker_context.config = SessionDBPlugin.Config(file=db_file) + session_db_plugin = SessionDBPlugin(broker_context) + await session_db_plugin.on_broker_pre_start() + + async with aiosqlite.connect(str(db_file)) as db: + sql = """INSERT INTO stored_sessions ( + client_id, clean_session, will_flag, + will_qos, keep_alive, + retained, subscriptions +) VALUES ( + 'test_client_1', + 1, + 0, + 1, + 60, + '[]', + '[{"topic":"sensors/#","qos":1}]' +)""" + await db.execute(sql) + await db.commit() + + async with db_session_factory() as db_session: + async with db_session.begin(): + stored_session = await session_db_plugin._get_or_create_session(db_session, 'test_client_1') + assert stored_session.subscriptions == [Subscription(topic='sensors/#', qos=1)] + + +@pytest.mark.asyncio +async def test_update_stored_session(db_file, broker_context, db_session_factory): + broker_context.config = SessionDBPlugin.Config(file=db_file) + + # create session for client id (without subscription) + await broker_context.add_subscription('test_client_1', None, None) + + session = broker_context.get_session('test_client_1') + assert session is not None + session.clean_session = False + + session_db_plugin = SessionDBPlugin(broker_context) + await session_db_plugin.on_broker_pre_start() + + # initialize with stored client session + async with aiosqlite.connect(str(db_file)) as db: + sql = """INSERT INTO stored_sessions ( + client_id, clean_session, will_flag, + will_qos, keep_alive, + retained, subscriptions +) VALUES ( + 'test_client_1', + 1, + 0, + 1, + 60, + '[]', + '[{"topic":"sensors/#","qos":1}]' +)""" + await db.execute(sql) + await db.commit() + + await session_db_plugin.on_broker_client_subscribed(client_id='test_client_1', topic='my/topic', qos=2) + + # verify that the stored session has been updated with the new subscription + has_stored_session = False + async with aiosqlite.connect(str(db_file)) as db: + async with await db.execute("SELECT * FROM stored_sessions") as cursor: + async for row in cursor: + assert row[1] == 'test_client_1' + assert row[-1] == '[{"topic": "sensors/#", "qos": 1}, {"topic": "my/topic", "qos": 2}]' + has_stored_session = True + assert has_stored_session, "stored session wasn't updated" + + +@pytest.mark.asyncio +async def test_client_connected_with_clean_session(db_file, broker_context, db_session_factory) -> None: + + broker_context.config = SessionDBPlugin.Config(file=db_file) + session_db_plugin = SessionDBPlugin(broker_context) + await session_db_plugin.on_broker_pre_start() + + session = Session() + session.client_id = 'test_client_connected' + session.is_anonymous = False + session.clean_session = True + + await session_db_plugin.on_broker_client_connected(client_id='test_client_connected', client_session=session) + + async with aiosqlite.connect(str(db_file)) as db_conn: + db_conn.row_factory = sqlite3.Row + + async with await db_conn.execute("SELECT * FROM stored_sessions") as cursor: + assert len(await cursor.fetchall()) == 0 + + +@pytest.mark.asyncio +async def test_client_connected_anonymous_session(db_file, broker_context, db_session_factory) -> None: + + broker_context.config = SessionDBPlugin.Config(file=db_file) + session_db_plugin = SessionDBPlugin(broker_context) + await session_db_plugin.on_broker_pre_start() + + session = Session() + session.is_anonymous = True + session.client_id = 'test_client_connected' + + await session_db_plugin.on_broker_client_connected(client_id='test_client_connected', client_session=session) + + async with aiosqlite.connect(str(db_file)) as db_conn: + db_conn.row_factory = sqlite3.Row # Set the row_factory + async with await db_conn.execute("SELECT * FROM stored_sessions") as cursor: + assert len(await cursor.fetchall()) == 0 + + +@pytest.mark.asyncio +async def test_client_connected_and_stored_session(db_file, broker_context, db_session_factory) -> None: + + broker_context.config = SessionDBPlugin.Config(file=db_file) + session_db_plugin = SessionDBPlugin(broker_context) + await session_db_plugin.on_broker_pre_start() + + session = Session() + session.client_id = 'test_client_connected' + session.is_anonymous = False + session.clean_session = False + session.will_flag = True + session.will_qos = 1 + session.will_topic = 'my/will/topic' + session.will_retain = False + session.will_message = b'test connected client has a last will (and testament) message' + session.keep_alive = 42 + + await session_db_plugin.on_broker_client_connected(client_id='test_client_connected', client_session=session) + + has_stored_session = False + async with aiosqlite.connect(str(db_file)) as db_conn: + db_conn.row_factory = sqlite3.Row # Set the row_factory + + async with await db_conn.execute("SELECT * FROM stored_sessions") as cursor: + for row in await cursor.fetchall(): + assert row['client_id'] == 'test_client_connected' + assert row['clean_session'] == False + assert row['will_flag'] == True + assert row['will_qos'] == 1 + assert row['will_topic'] == 'my/will/topic' + assert row['will_retain'] == False + assert row['will_message'] == b'test connected client has a last will (and testament) message' + assert row['keep_alive'] == 42 + has_stored_session = True + assert has_stored_session, "client session wasn't stored" + + +@pytest.mark.asyncio +async def test_repopulate_stored_sessions(db_file, broker_context, db_session_factory) -> None: + + broker_context.config = SessionDBPlugin.Config(file=db_file) + session_db_plugin = SessionDBPlugin(broker_context) + await session_db_plugin.on_broker_pre_start() + + async with aiosqlite.connect(str(db_file)) as db: + sql = """INSERT INTO stored_sessions ( + client_id, clean_session, will_flag, + will_qos, keep_alive, + retained, subscriptions + ) VALUES ( + 'test_client_1', + 1, + 0, + 1, + 60, + '[{"topic":"sensors/#","data":"this message is retained when client reconnects","qos":1}]', + '[{"topic":"sensors/#","qos":1}]' + )""" + await db.execute(sql) + await db.commit() + + await session_db_plugin.on_broker_post_start() + + session = broker_context.get_session('test_client_1') + assert session is not None + assert session.retained_messages.qsize() == 1 + + assert 'sensors/#' in broker_context._broker_instance._subscriptions + + # ugly: b/c _subscriptions is a list of dictionaries of tuples + assert broker_context._broker_instance._subscriptions['sensors/#'][0][1] == 1 + + +@pytest.mark.asyncio +async def test_client_retained_message(db_file, broker_context, db_session_factory) -> None: + + broker_context.config = SessionDBPlugin.Config(file=db_file) + session_db_plugin = SessionDBPlugin(broker_context) + await session_db_plugin.on_broker_pre_start() + + # add a session to the broker + await broker_context.add_subscription('test_retained_client', None, None) + + # update the session so that it's retained + session = broker_context.get_session('test_retained_client') + assert session is not None + session.is_anonymous = False + session.clean_session = False + session.transitions.disconnect() + + retained_message = RetainedApplicationMessage(source_session=session, topic='my/retained/topic', data=b'retain message for disconnected client', qos=2) + await session_db_plugin.on_broker_retained_message(client_id='test_retained_client', retained_message=retained_message) + + async with db_session_factory() as db_session: + async with db_session.begin(): + stmt = select(StoredSession).filter(StoredSession.client_id == 'test_retained_client') + stored_session = await db_session.scalar(stmt) + assert stored_session is not None + assert len(stored_session.retained) > 0 + assert RetainedMessage(topic='my/retained/topic', data='retained message', qos=2) + + +@pytest.mark.asyncio +async def test_topic_retained_message(db_file, broker_context, db_session_factory) -> None: + + broker_context.config = SessionDBPlugin.Config(file=db_file, clear_on_shutdown=False) + session_db_plugin = SessionDBPlugin(broker_context) + await session_db_plugin.on_broker_pre_start() + + # add a session to the broker + await broker_context.add_subscription('test_retained_client', None, None) + + # update the session so that it's retained + session = broker_context.get_session('test_retained_client') + assert session is not None + session.is_anonymous = False + session.clean_session = False + session.transitions.disconnect() + + retained_message = RetainedApplicationMessage(source_session=session, topic='my/retained/topic', data=b'retained message', qos=2) + await session_db_plugin.on_broker_retained_message(client_id=None, retained_message=retained_message) + + has_stored_message = False + async with aiosqlite.connect(str(db_file)) as db_conn: + db_conn.row_factory = sqlite3.Row # Set the row_factory + + async with await db_conn.execute("SELECT * FROM stored_messages") as cursor: + # assert(len(await cursor.fetchall()) > 0) + for row in await cursor.fetchall(): + assert row['topic'] == 'my/retained/topic' + assert row['data'] == b'retained message' + assert row['qos'] == 2 + has_stored_message = True + + assert has_stored_message, "retained topic message wasn't stored" + + +@pytest.mark.asyncio +async def test_topic_clear_retained_message(db_file, broker_context, db_session_factory) -> None: + + broker_context.config = SessionDBPlugin.Config(file=db_file) + session_db_plugin = SessionDBPlugin(broker_context) + await session_db_plugin.on_broker_pre_start() + + # add a session to the broker + await broker_context.add_subscription('test_retained_client', None, None) + + # update the session so that it's retained + session = broker_context.get_session('test_retained_client') + assert session is not None + session.is_anonymous = False + session.clean_session = False + session.transitions.disconnect() + + retained_message = RetainedApplicationMessage(source_session=session, topic='my/retained/topic', data=b'', qos=0) + await session_db_plugin.on_broker_retained_message(client_id=None, retained_message=retained_message) + + async with aiosqlite.connect(str(db_file)) as db_conn: + db_conn.row_factory = sqlite3.Row + + async with await db_conn.execute("SELECT * FROM stored_messages") as cursor: + assert(len(await cursor.fetchall()) == 0) + + +@pytest.mark.asyncio +async def test_restoring_retained_message(db_file, broker_context, db_session_factory) -> None: + + broker_context.config = SessionDBPlugin.Config(file=db_file) + session_db_plugin = SessionDBPlugin(broker_context) + await session_db_plugin.on_broker_pre_start() + + stmts = ("INSERT INTO stored_messages VALUES(1,'my/retained/topic1',X'72657461696e6564206d657373616765',2)", + "INSERT INTO stored_messages VALUES(2,'my/retained/topic2',X'72657461696e6564206d65737361676532',2)", + "INSERT INTO stored_messages VALUES(3,'my/retained/topic3',X'72657461696e6564206d65737361676533',2)") + + async with aiosqlite.connect(str(db_file)) as db: + for stmt in stmts: + await db.execute(stmt) + await db.commit() + + await session_db_plugin.on_broker_post_start() + + assert len(broker_context.retained_messages) == 3 + assert 'my/retained/topic1' in broker_context.retained_messages + assert 'my/retained/topic2' in broker_context.retained_messages + assert 'my/retained/topic3' in broker_context.retained_messages + + +@pytest.fixture +def db_config(db_file): + return { + 'listeners': { + 'default': { + 'type': 'tcp', + 'bind': '127.0.0.1:1883' + } + }, + 'plugins': { + 'amqtt.plugins.authentication.AnonymousAuthPlugin': {'allow_anonymous': False}, + 'amqtt.contrib.persistence.SessionDBPlugin': { + 'file': db_file, + 'clear_on_shutdown': False + + } + } + } + + + + +@pytest.mark.asyncio +async def test_broker_client_no_cleanup(db_file, db_config) -> None: + + b1 = Broker(config=db_config) + await b1.start() + await asyncio.sleep(0.1) + + c1 = MQTTClient(client_id='test_client1', config={'auto_reconnect':False}) + await c1.connect("mqtt://myUsername@127.0.0.1:1883", cleansession=False) + + # test that this message is retained for the topic upon restore + await c1.publish("my/retained/topic", b'retained message for topic my/retained/topic', retain=True) + await asyncio.sleep(0.2) + await c1.disconnect() + await b1.shutdown() + + # new broker should load the previous broker's db file since clean_on_shutdown is false in config + b2 = Broker(config=db_config) + await b2.start() + await asyncio.sleep(0.1) + + # upon subscribing to topic with retained message, it should be received + c2 = MQTTClient(client_id='test_client2', config={'auto_reconnect':False}) + await c2.connect("mqtt://myOtherUsername@localhost:1883", cleansession=False) + await c2.subscribe([ + ('my/retained/topic', QOS_1) + ]) + msg = await c2.deliver_message(timeout_duration=1) + assert msg is not None + assert msg.topic == "my/retained/topic" + assert msg.data == b'retained message for topic my/retained/topic' + + await c2.disconnect() + await b2.shutdown() + + +@pytest.mark.asyncio +async def test_broker_client_retain_subscription(db_file, db_config) -> None: + + b1 = Broker(config=db_config) + await b1.start() + await asyncio.sleep(0.1) + + c1 = MQTTClient(client_id='test_client1', config={'auto_reconnect':False}) + await c1.connect("mqtt://myUsername@127.0.0.1:1883", cleansession=False) + + # test to make sure the subscription is re-established upon reconnection after broker restart (clear_on_shutdown = False) + ret = await c1.subscribe([ + ('my/offline/topic', QOS_1) + ]) + assert ret == [QOS_1,] + await asyncio.sleep(0.2) + await c1.disconnect() + await asyncio.sleep(0.1) + await b1.shutdown() + + # new broker should load the previous broker's db file + b2 = Broker(config=db_config) + await b2.start() + await asyncio.sleep(0.1) + + # client1's subscription should have been restored, so when it connects, it should receive this message + c2 = MQTTClient(client_id='test_client2', config={'auto_reconnect':False}) + await c2.connect("mqtt://myOtherUsername@localhost:1883", cleansession=False) + await c2.publish('my/offline/topic', b'standard message to be retained for offline clients') + await asyncio.sleep(0.1) + await c2.disconnect() + + await c1.reconnect(cleansession=False) + await asyncio.sleep(0.1) + msg = await c1.deliver_message(timeout_duration=2) + assert msg is not None + assert msg.topic == "my/offline/topic" + assert msg.data == b'standard message to be retained for offline clients' + + await c1.disconnect() + await b2.shutdown() + + +@pytest.mark.asyncio +async def test_broker_client_retain_message(db_file, db_config) -> None: + """test to make sure that the retained message because client1 is offline, + gets sent when back online after broker restart.""" + + b1 = Broker(config=db_config) + await b1.start() + await asyncio.sleep(0.1) + + c1 = MQTTClient(client_id='test_client1', config={'auto_reconnect':False}) + await c1.connect("mqtt://myUsername@127.0.0.1:1883", cleansession=False) + + # subscribe to a topic with QOS_1 so that we receive messages, even if we're disconnected when sent + ret = await c1.subscribe([ + ('my/offline/topic', QOS_1) + ]) + assert ret == [QOS_1,] + await asyncio.sleep(0.2) + # go offline + await c1.disconnect() + await asyncio.sleep(0.1) + + # another client sends a message to previously subscribed to topic + c2 = MQTTClient(client_id='test_client2', config={'auto_reconnect':False}) + await c2.connect("mqtt://myOtherUsername@localhost:1883", cleansession=False) + + # this message should be delivered after broker stops and restarts (and client connects) + await c2.publish('my/offline/topic', b'standard message to be retained for offline clients') + await asyncio.sleep(0.1) + await c2.disconnect() + await asyncio.sleep(0.1) + await b1.shutdown() + + # new broker should load the previous broker's db file since we declared clear_on_shutdown = False in config + b2 = Broker(config=db_config) + await b2.start() + await asyncio.sleep(0.1) + + # when first client reconnects, it should receive the message that had been previously retained for it + await c1.reconnect(cleansession=False) + await asyncio.sleep(0.1) + msg = await c1.deliver_message(timeout_duration=2) + assert msg is not None + assert msg.topic == "my/offline/topic" + assert msg.data == b'standard message to be retained for offline clients' + + # client should also receive a message if send on this topic + await c1.publish("my/offline/topic", b"online message should also be received") + await asyncio.sleep(0.1) + msg = await c1.deliver_message(timeout_duration=2) + assert msg is not None + assert msg.topic == "my/offline/topic" + assert msg.data == b'online message should also be received' + + await c1.disconnect() + await b2.shutdown() diff --git a/tests/plugins/test_persistence.py b/tests/plugins/test_persistence.py deleted file mode 100644 index 35aebd3..0000000 --- a/tests/plugins/test_persistence.py +++ /dev/null @@ -1,56 +0,0 @@ -import asyncio -import logging -from pathlib import Path -import sqlite3 -import unittest - -from amqtt.contexts import BaseContext -from amqtt.plugins.persistence import SQLitePlugin -from amqtt.session import Session - -formatter = "[%(asctime)s] %(name)s {%(filename)s:%(lineno)d} %(levelname)s - %(message)s" -logging.basicConfig(level=logging.DEBUG, format=formatter) - - -class TestSQLitePlugin(unittest.TestCase): - def setUp(self) -> None: - self.loop = asyncio.new_event_loop() - - def test_create_tables(self) -> None: - dbfile = Path(__file__).resolve().parent / "test.db" - - context = BaseContext() - context.logger = logging.getLogger(__name__) - context.config = {"persistence": {"file": str(dbfile)}} # Ensure string path for config - SQLitePlugin(context) - - try: - conn = sqlite3.connect(str(dbfile)) # Convert Path to string for sqlite connection - cursor = conn.cursor() - rows = cursor.execute("SELECT name FROM sqlite_master WHERE type = 'table'") - tables = [row[0] for row in rows] # List comprehension for brevity - assert "session" in tables - finally: - conn.close() - - def test_save_session(self) -> None: - dbfile = Path(__file__).resolve().parent / "test.db" - - context = BaseContext() - context.logger = logging.getLogger(__name__) - context.config = {"persistence": {"file": str(dbfile)}} # Ensure string path for config - sql_plugin = SQLitePlugin(context) - - s = Session() - s.client_id = "test_save_session" - - self.loop.run_until_complete(sql_plugin.save_session(session=s)) - - try: - conn = sqlite3.connect(str(dbfile)) # Convert Path to string for sqlite connection - cursor = conn.cursor() - row = cursor.execute("SELECT client_id FROM session WHERE client_id = 'test_save_session'").fetchone() - assert row is not None - assert row[0] == s.client_id - finally: - conn.close() diff --git a/tests/plugins/test_plugins.py b/tests/plugins/test_plugins.py index 5cd6d99..5b79e35 100644 --- a/tests/plugins/test_plugins.py +++ b/tests/plugins/test_plugins.py @@ -1,12 +1,11 @@ import asyncio import inspect +import logging from functools import partial -from importlib.metadata import EntryPoint from logging import getLogger from pathlib import Path from types import ModuleType from typing import Any, Callable, Coroutine -from unittest.mock import patch import pytest @@ -15,13 +14,15 @@ from amqtt.broker import Broker, BrokerContext from amqtt.client import MQTTClient from amqtt.errors import PluginInitError, PluginImportError from amqtt.events import MQTTEvents, BrokerEvents -from amqtt.mqtt.constants import QOS_0 +from amqtt.mqtt.constants import QOS_0, QOS_1 from amqtt.plugins.base import BasePlugin from amqtt.contexts import BaseContext +from amqtt.contrib.persistence import RetainedMessage _INVALID_METHOD: str = "invalid_foo" _PLUGIN: str = "Plugin" +logger = logging.getLogger(__name__) class _TestContext(BaseContext): def __init__(self) -> None: @@ -159,7 +160,7 @@ async def test_all_plugin_events(): client = MQTTClient() await client.connect("mqtt://127.0.0.1:1883/") await client.subscribe([('my/test/topic', QOS_0),]) - await client.publish('test/topic', b'my test message') + await client.publish('test/topic', b'my test message', retain=True) await client.unsubscribe(['my/test/topic',]) await client.disconnect() await asyncio.sleep(1) @@ -170,3 +171,70 @@ async def test_all_plugin_events(): await asyncio.sleep(1) assert all(test_plugin.test_flags.values()), f'event not received: {[event for event, value in test_plugin.test_flags.items() if not value]}' + + +class RetainedMessageEventPlugin(BasePlugin[BrokerContext]): + """A plugin to verify all events get sent to plugins.""" + def __init__(self, context: BaseContext) -> None: + super().__init__(context) + self.topic_retained_message_flag = False + self.session_retained_message_flag = False + self.topic_clear_retained_message_flag = False + + async def on_broker_retained_message(self, *, client_id: str | None, retained_message: RetainedMessage) -> None: + """retaining message event handler.""" + if client_id: + session = self.context.get_session(client_id) + assert session.transitions.state != "connected" + logger.debug("retained message event fired for offline client") + self.session_retained_message_flag = True + else: + if not retained_message.data: + logger.debug("retained message event fired for clearing a topic") + self.topic_clear_retained_message_flag = True + else: + logger.debug("retained message event fired for setting a topic") + self.topic_retained_message_flag = True + + +@pytest.mark.asyncio +async def test_retained_message_plugin_event(): + + config = { + "listeners": { + "default": {"type": "tcp", "bind": "127.0.0.1:1883", "max_connections": 10}, + }, + 'sys_interval': 1, + 'plugins':[{'amqtt.plugins.authentication.AnonymousAuthPlugin': {'allow_anonymous': False}}, + {'tests.plugins.test_plugins.RetainedMessageEventPlugin': {}}] + } + + broker = Broker(plugin_namespace='tests.mock_plugins', config=config) + + await broker.start() + await asyncio.sleep(0.1) + + # make sure all expected events get triggered + client1 = MQTTClient(config={'auto_reconnect': False}) + await client1.connect("mqtt://myUsername@127.0.0.1:1883/", cleansession=False) + await client1.subscribe([('test/topic', QOS_1),]) + await client1.publish('test/retained', b'message should be retained for test/retained', retain=True) + await asyncio.sleep(0.1) + await client1.disconnect() + + client2 = MQTTClient(config={'auto_reconnect': False}) + await client2.connect("mqtt://myOtherUsername@127.0.0.1:1883/", cleansession=True) + await client2.publish('test/topic', b'message should be retained for myUsername since subscription was qos > 0') + await client2.publish('test/retained', b'', retain=True) # should clear previously retained message + await asyncio.sleep(0.1) + await client2.disconnect() + await asyncio.sleep(0.1) + + # get the plugin so it doesn't get gc on shutdown + test_plugin = broker.plugins_manager.get_plugin('RetainedMessageEventPlugin') + await broker.shutdown() + await asyncio.sleep(0.1) + + assert test_plugin.topic_retained_message_flag, "message to topic wasn't retained" + assert test_plugin.session_retained_message_flag, "message to disconnected client wasn't retained" + assert test_plugin.topic_clear_retained_message_flag, "message to retained topic wasn't cleared" diff --git a/tests/test_client.py b/tests/test_client.py index 2c30b61..2c7fea1 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -15,20 +15,22 @@ logging.basicConfig(level=logging.ERROR, format=formatter) log = logging.getLogger(__name__) -# @pytest.mark.asyncio -# async def test_connect_tcp(): -# client = MQTTClient() -# await client.connect("mqtt://broker.hivemq.com:1883/") -# assert client.session is not None -# await client.disconnect() -# -# -# @pytest.mark.asyncio -# async def test_connect_tcp_secure(ca_file_fixture): -# client = MQTTClient(config={"check_hostname": False}) -# await client.connect("mqtts://broker.hivemq.com:8883/") -# assert client.session is not None -# await client.disconnect() +@pytest.mark.asyncio +async def test_connect_tcp(broker_fixture): + client = MQTTClient() + await client.connect("mqtt://localhost:1883/") + assert client.session is not None + await client.disconnect() + +@pytest.mark.asyncio +async def test_connect_tcp_secure(rsa_keys, broker_fixture): + certfile, _ = rsa_keys + client = MQTTClient(config={"check_hostname": False, "auto_reconnect": False}) + + # since we're using a self-signed certificate, need to provide the server's certificate to verify authenticity + await client.connect("mqtts://localhost:1884/", cafile=certfile) + assert client.session is not None + await client.disconnect() @pytest.mark.asyncio @@ -62,9 +64,11 @@ async def test_reconnect_ws_retain_username_password(broker_fixture): @pytest.mark.asyncio -async def test_connect_ws_secure(ca_file_fixture, broker_fixture): - client = MQTTClient() - await client.connect("ws://127.0.0.1:8081/", cafile=ca_file_fixture) +async def test_connect_ws_secure(rsa_keys, broker_fixture): + certfile, _ = rsa_keys + client = MQTTClient(config={"auto_reconnect": False}) + # since we're using a self-signed certificate, need to provide the server's certificate to verify authenticity + await client.connect("wss://localhost:8081/", cafile=certfile) assert client.session is not None await client.disconnect() diff --git a/tests/test_samples.py b/tests/test_samples.py index fca99f3..26e4ce8 100644 --- a/tests/test_samples.py +++ b/tests/test_samples.py @@ -21,7 +21,7 @@ async def test_broker_acl(): broker_acl_script = Path(__file__).parent.parent / "samples/broker_acl.py" process = subprocess.Popen(["python", broker_acl_script], stdout=subprocess.PIPE, stderr=subprocess.PIPE) # Send the interrupt signal - await asyncio.sleep(5) + await asyncio.sleep(2) process.send_signal(signal.SIGINT) stdout, stderr = process.communicate() logger.debug(stderr.decode("utf-8")) @@ -34,7 +34,7 @@ async def test_broker_acl(): async def test_broker_simple(): broker_simple_script = Path(__file__).parent.parent / "samples/broker_simple.py" process = subprocess.Popen(["python", broker_simple_script], stdout=subprocess.PIPE, stderr=subprocess.PIPE) - await asyncio.sleep(5) + await asyncio.sleep(2) # Send the interrupt signal process.send_signal(signal.SIGINT) @@ -50,7 +50,7 @@ async def test_broker_simple(): async def test_broker_start(): broker_start_script = Path(__file__).parent.parent / "samples/broker_start.py" process = subprocess.Popen(["python", broker_start_script], stdout=subprocess.PIPE, stderr=subprocess.PIPE) - await asyncio.sleep(5) + await asyncio.sleep(2) # Send the interrupt signal to stop broker process.send_signal(signal.SIGINT) @@ -65,7 +65,7 @@ async def test_broker_start(): async def test_broker_taboo(): broker_taboo_script = Path(__file__).parent.parent / "samples/broker_taboo.py" process = subprocess.Popen(["python", broker_taboo_script], stdout=subprocess.PIPE, stderr=subprocess.PIPE) - await asyncio.sleep(5) + await asyncio.sleep(2) # Send the interrupt signal to stop broker process.send_signal(signal.SIGINT) @@ -86,7 +86,7 @@ async def test_client_keepalive(): keep_alive_script = Path(__file__).parent.parent / "samples/client_keepalive.py" process = subprocess.Popen(["python", keep_alive_script], stdout=subprocess.PIPE, stderr=subprocess.PIPE) - await asyncio.sleep(2) + await asyncio.sleep(1) stdout, stderr = process.communicate() assert "ERROR" not in stderr.decode("utf-8") @@ -111,28 +111,30 @@ async def test_client_publish(): await broker.shutdown() -broker_ssl_config = { - "listeners": { - "default": { - "type": "tcp", - "bind": "0.0.0.0:8883", - "ssl": True, - "certfile": "cert.pem", - "keyfile": "key.pem", - } - }, - "auth": { - "allow-anonymous": True, - "plugins": ["auth_anonymous"] + +@pytest.fixture +def broker_ssl_config(rsa_keys): + certfile, keyfile = rsa_keys + return { + "listeners": { + "default": { + "type": "tcp", + "bind": "0.0.0.0:8883", + "ssl": True, + "certfile": certfile, + "keyfile": keyfile, } -} + }, + "auth": { + "allow-anonymous": True, + "plugins": ["auth_anonymous"] + } + } @pytest.mark.asyncio -async def test_client_publish_ssl(): - +async def test_client_publish_ssl(broker_ssl_config, rsa_keys): + certfile, _ = rsa_keys # generate a self-signed certificate for this test - cmd = 'openssl req -x509 -nodes -days 365 -newkey rsa:2048 -keyout key.pem -out cert.pem -subj "/CN=localhost"' - subprocess.run(cmd, shell=True, capture_output=True, text=True) # start a secure broker broker = Broker(config=broker_ssl_config) @@ -140,7 +142,7 @@ async def test_client_publish_ssl(): await asyncio.sleep(2) # run the sample client_publish_ssl_script = Path(__file__).parent.parent / "samples/client_publish_ssl.py" - process = subprocess.Popen(["python", client_publish_ssl_script], stdout=subprocess.PIPE, stderr=subprocess.PIPE) + process = subprocess.Popen(["python", client_publish_ssl_script, '--cert', certfile], stdout=subprocess.PIPE, stderr=subprocess.PIPE) await asyncio.sleep(2) stdout, stderr = process.communicate() @@ -179,7 +181,7 @@ broker_ws_config = { "auth": { "allow-anonymous": True, "plugins": ["auth_anonymous"] - } + } } @pytest.mark.asyncio @@ -205,13 +207,14 @@ broker_std_config = { "listeners": { "default": { "type": "tcp", - "bind": "0.0.0.0:1883", } + "bind": "0.0.0.0:1883", + } }, 'sys_interval':2, "auth": { "allow-anonymous": True, "plugins": ["auth_anonymous"] - } + } } @@ -241,6 +244,7 @@ async def test_client_subscribe(): await broker.shutdown() + @pytest.mark.asyncio async def test_client_subscribe_plugin_acl(): broker = Broker(config=broker_acl_config) @@ -249,7 +253,7 @@ async def test_client_subscribe_plugin_acl(): broker_simple_script = Path(__file__).parent.parent / "samples/client_subscribe_acl.py" process = subprocess.Popen(["python", broker_simple_script], stdout=subprocess.PIPE, stderr=subprocess.PIPE) # Send the interrupt signal - await asyncio.sleep(5) + await asyncio.sleep(2) process.send_signal(signal.SIGINT) stdout, stderr = process.communicate() logger.debug(stderr.decode("utf-8")) @@ -268,7 +272,7 @@ async def test_client_subscribe_plugin_taboo(): broker_simple_script = Path(__file__).parent.parent / "samples/client_subscribe_acl.py" process = subprocess.Popen(["python", broker_simple_script], stdout=subprocess.PIPE, stderr=subprocess.PIPE) # Send the interrupt signal - await asyncio.sleep(5) + await asyncio.sleep(2) process.send_signal(signal.SIGINT) stdout, stderr = process.communicate() logger.debug(stderr.decode("utf-8")) diff --git a/uv.lock b/uv.lock index e9338fb..a9f4bf3 100644 --- a/uv.lock +++ b/uv.lock @@ -148,8 +148,10 @@ ci = [ ] contrib = [ { name = "aiohttp" }, + { name = "aiosqlite" }, { name = "argon2-cffi" }, - { name = "sqlalchemy" }, + { name = "greenlet" }, + { name = "sqlalchemy", extra = ["asyncio"] }, ] [package.dev-dependencies] @@ -171,6 +173,7 @@ dev = [ { name = "pytest-timeout" }, { name = "ruff" }, { name = "setuptools" }, + { name = "sqlalchemy", extra = ["mypy"] }, { name = "types-mock" }, { name = "types-pyyaml" }, { name = "types-setuptools" }, @@ -197,13 +200,15 @@ docs = [ requires-dist = [ { name = "aiohttp", specifier = ">=3.12.7" }, { name = "aiohttp", marker = "extra == 'contrib'", specifier = ">=3.12.13" }, + { name = "aiosqlite", marker = "extra == 'contrib'", specifier = ">=0.21.0" }, { name = "argon2-cffi", marker = "extra == 'contrib'", specifier = ">=25.1.0" }, { name = "coveralls", marker = "extra == 'ci'", specifier = "==4.0.1" }, { name = "dacite", specifier = ">=1.9.2" }, + { name = "greenlet", marker = "extra == 'contrib'", specifier = ">=3.2.3" }, { name = "passlib", specifier = "==1.7.4" }, { name = "psutil", specifier = ">=7.0.0" }, { name = "pyyaml", specifier = "==6.0.2" }, - { name = "sqlalchemy", marker = "extra == 'contrib'", specifier = ">=2.0.41" }, + { name = "sqlalchemy", extras = ["asyncio"], marker = "extra == 'contrib'", specifier = ">=2.0.41" }, { name = "transitions", specifier = "==0.9.2" }, { name = "typer", specifier = "==0.15.4" }, { name = "websockets", specifier = "==15.0.1" }, @@ -229,6 +234,7 @@ dev = [ { name = "pytest-timeout", specifier = ">=2.3.1" }, { name = "ruff", specifier = ">=0.11.3" }, { name = "setuptools", specifier = ">=78.1.0" }, + { name = "sqlalchemy", extras = ["mypy"], specifier = ">=2.0.41" }, { name = "types-mock", specifier = ">=5.2.0.20250306" }, { name = "types-pyyaml", specifier = ">=6.0.12.20250402" }, { name = "types-setuptools", specifier = ">=78.1.0.20250329" }, @@ -2415,6 +2421,14 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/1c/fc/9ba22f01b5cdacc8f5ed0d22304718d2c758fce3fd49a5372b886a86f37c/sqlalchemy-2.0.41-py3-none-any.whl", hash = "sha256:57df5dc6fdb5ed1a88a1ed2195fd31927e705cad62dedd86b46972752a80f576", size = 1911224, upload-time = "2025-05-14T17:39:42.154Z" }, ] +[package.optional-dependencies] +asyncio = [ + { name = "greenlet" }, +] +mypy = [ + { name = "mypy" }, +] + [[package]] name = "tomli" version = "2.2.1"