Plugin: rebuild of session persistence (#256)

* rebuild the persistence plugin to handle storing and restoring sessions
* add a broker event that fires when a message is retained
* add retained-message logic to the persistence plugin, with test cases
* update documentation
* move the DataClassListJSON field to a common location for reuse
pull/264/head
Andrew Mirsky 2025-08-09 14:15:45 -04:00 committed by GitHub
parent 2a7aa11524
commit de40ca51d3
No key found in the database for this signature
GPG key ID: B5690EEEBB952194
21 changed files with 1115 additions and 225 deletions


@ -33,7 +33,7 @@ jobs:
python-version: "3.13"
- name: 🏗 Install the project
run: uv sync --locked --dev --extra contrib
run: uv sync --locked --dev --all-extras
- name: Run mypy
run: uv run --frozen mypy ${{ env.PROJECT_PATH }}/
@ -69,7 +69,7 @@ jobs:
python-version: ${{ matrix.python-version }}
- name: 🏗 Install the project
run: uv sync --locked --dev --extra contrib
run: uv sync --locked --dev --all-extras
- name: Run pytest
run: uv run --frozen pytest tests/ --cov=./ --cov-report=xml --junitxml=pytest-report.xml


@ -17,7 +17,7 @@
- Communication over TCP and/or websocket, including support for SSL/TLS
- Support QoS 0, QoS 1 and QoS 2 messages flow
- Client auto-reconnection on network lost
- Functionality expansion; plugins included: authentication and `$SYS` topic publishing
- Custom functionality expansion; plugins included: authentication, `$SYS` topic publishing, session persistence
## Installation


@ -108,7 +108,7 @@ class ExternalServer(Server):
class BrokerContext(BaseContext):
"""BrokerContext is used as the context passed to plugins interacting with the broker."""
"""Used to provide the server's context as well as public methods for accessing internal state."""
def __init__(self, broker: "Broker") -> None:
super().__init__()
@ -116,16 +116,21 @@ class BrokerContext(BaseContext):
self._broker_instance = broker
async def broadcast_message(self, topic: str, data: bytes, qos: int | None = None) -> None:
"""Send message to all client sessions subscribing to `topic`."""
await self._broker_instance.internal_message_broadcast(topic, data, qos)
def retain_message(self, topic_name: str, data: bytes | bytearray, qos: int | None = None) -> None:
self._broker_instance.retain_message(None, topic_name, data, qos)
async def retain_message(self, topic_name: str, data: bytes | bytearray, qos: int | None = None) -> None:
await self._broker_instance.retain_message(None, topic_name, data, qos)
@property
def sessions(self) -> Generator[Session]:
for session in self._broker_instance.sessions.values():
yield session[0]
def get_session(self, client_id: str) -> Session | None:
"""Return the session associated with `client_id`, if it exists."""
return self._broker_instance.sessions.get(client_id, (None, None))[0]
@property
def retained_messages(self) -> dict[str, RetainedApplicationMessage]:
return self._broker_instance.retained_messages
@ -134,6 +139,20 @@ class BrokerContext(BaseContext):
def subscriptions(self) -> dict[str, list[tuple[Session, int]]]:
return self._broker_instance.subscriptions
async def add_subscription(self, client_id: str, topic: str|None, qos: int|None) -> None:
"""Create a topic subscription for the given `client_id`.
If a client session doesn't exist for `client_id`, create a disconnected session.
If `topic` and `qos` are both `None`, only create the client session.
"""
if client_id not in self._broker_instance.sessions:
broker_handler, session = self._broker_instance.create_offline_session(client_id)
self._broker_instance._sessions[client_id] = (session, broker_handler) # noqa: SLF001
if topic is not None and qos is not None:
session, _ = self._broker_instance.sessions[client_id]
await self._broker_instance.add_subscription((topic, qos), session)
class Broker:
"""MQTT 3.1.1 compliant broker implementation.
@ -483,7 +502,17 @@ class Broker:
# Get session from cache
elif client_session.client_id in self._sessions:
self.logger.debug(f"Found old session {self._sessions[client_session.client_id]!r}")
client_session, _ = self._sessions[client_session.client_id]
# even though the session previously existed, the new connection can bring updated configuration and credentials
existing_client_session, _ = self._sessions[client_session.client_id]
existing_client_session.will_flag = client_session.will_flag
existing_client_session.will_message = client_session.will_message
existing_client_session.will_topic = client_session.will_topic
existing_client_session.will_qos = client_session.will_qos
existing_client_session.keep_alive = client_session.keep_alive
existing_client_session.username = client_session.username
existing_client_session.password = client_session.password
client_session = existing_client_session
client_session.parent = 1
else:
client_session.parent = 0
@ -495,6 +524,14 @@ class Broker:
self.logger.debug(f"Keep-alive timeout={client_session.keep_alive}")
return handler, client_session
def create_offline_session(self, client_id: str) -> tuple[BrokerProtocolHandler, Session]:
session = Session()
session.client_id = client_id
bph = BrokerProtocolHandler(self.plugins_manager, session)
session.transitions.disconnect()
return bph, session
async def _handle_client_session(
self,
reader: ReaderAdapter,
@ -635,7 +672,7 @@ class Broker:
client_session.will_qos,
)
if client_session.will_retain:
self.retain_message(
await self.retain_message(
client_session,
client_session.will_topic,
client_session.will_message,
@ -660,7 +697,7 @@ class Broker:
"""Handle client subscription."""
self.logger.debug(f"{client_session.client_id} handling subscription")
subscriptions = subscribe_waiter.result()
return_codes = [await self._add_subscription(subscription, client_session) for subscription in subscriptions.topics]
return_codes = [await self.add_subscription(subscription, client_session) for subscription in subscriptions.topics]
await handler.mqtt_acknowledge_subscription(subscriptions.packet_id, return_codes)
for index, subscription in enumerate(subscriptions.topics):
if return_codes[index] != AMQTT_MAGIC_VALUE_RET_SUBSCRIBED:
@ -730,7 +767,7 @@ class Broker:
)
await self._broadcast_message(client_session, app_message.topic, app_message.data)
if app_message.publish_packet and app_message.publish_packet.retain_flag:
self.retain_message(client_session, app_message.topic, app_message.data, app_message.qos)
await self.retain_message(client_session, app_message.topic, app_message.data, app_message.qos)
return True
async def _init_handler(self, session: Session, reader: ReaderAdapter, writer: WriterAdapter) -> BrokerProtocolHandler:
@ -774,7 +811,7 @@ class Broker:
return False
def retain_message(
async def retain_message(
self,
source_session: Session | None,
topic_name: str | None,
@ -785,12 +822,25 @@ class Broker:
# If retained flag set, store the message for further subscriptions
self.logger.debug(f"Retaining message on topic {topic_name}")
self._retained_messages[topic_name] = RetainedApplicationMessage(source_session, topic_name, data, qos)
await self.plugins_manager.fire_event(BrokerEvents.RETAINED_MESSAGE,
client_id=None,
retained_message=self._retained_messages[topic_name])
# [MQTT-3.3.1-10]
elif topic_name in self._retained_messages:
self.logger.debug(f"Clearing retained messages for topic '{topic_name}'")
cleared_message = self._retained_messages[topic_name]
cleared_message.data = b""
await self.plugins_manager.fire_event(BrokerEvents.RETAINED_MESSAGE,
client_id=None,
retained_message=cleared_message)
del self._retained_messages[topic_name]
async def _add_subscription(self, subscription: tuple[str, int], session: Session) -> int:
async def add_subscription(self, subscription: tuple[str, int], session: Session) -> int:
topic_filter, qos = subscription
if "#" in topic_filter and not topic_filter.endswith("#"):
# [MQTT-4.7.1-2] Wildcard character '#' is only allowed as last character in filter
@ -982,6 +1032,10 @@ class Broker:
retained_message = RetainedApplicationMessage(broadcast["session"], broadcast["topic"], broadcast["data"], qos)
await target_session.retained_messages.put(retained_message)
await self.plugins_manager.fire_event(BrokerEvents.RETAINED_MESSAGE,
client_id=target_session.client_id,
retained_message=retained_message)
if self.logger.isEnabledFor(logging.DEBUG):
self.logger.debug(f"target_session.retained_messages={target_session.retained_messages.qsize()}")


@ -0,0 +1,267 @@
from dataclasses import dataclass
import logging
from pathlib import Path
from sqlalchemy import Boolean, Integer, LargeBinary, Result, String, select
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
from amqtt.broker import BrokerContext, RetainedApplicationMessage
from amqtt.contrib import DataClassListJSON
from amqtt.errors import PluginError
from amqtt.plugins.base import BasePlugin
from amqtt.session import Session
logger = logging.getLogger(__name__)
class Base(DeclarativeBase):
pass
@dataclass
class RetainedMessage:
topic: str
data: str
qos: int
@dataclass
class Subscription:
topic: str
qos: int
class StoredSession(Base):
__tablename__ = "stored_sessions"
id: Mapped[int] = mapped_column(primary_key=True)
client_id: Mapped[str] = mapped_column(String)
clean_session: Mapped[bool | None] = mapped_column(Boolean, nullable=True)
will_flag: Mapped[bool] = mapped_column(Boolean, default=False, server_default="false")
will_message: Mapped[bytes | None] = mapped_column(LargeBinary, nullable=True, default=None)
will_qos: Mapped[int | None] = mapped_column(Integer, nullable=True, default=None)
will_retain: Mapped[bool | None] = mapped_column(Boolean, nullable=True, default=None)
will_topic: Mapped[str | None] = mapped_column(String, nullable=True, default=None)
keep_alive: Mapped[int] = mapped_column(Integer, default=0)
retained: Mapped[list[RetainedMessage]] = mapped_column(DataClassListJSON(RetainedMessage), default=list)
subscriptions: Mapped[list[Subscription]] = mapped_column(DataClassListJSON(Subscription), default=list)
class StoredMessage(Base):
__tablename__ = "stored_messages"
id: Mapped[int] = mapped_column(primary_key=True)
topic: Mapped[str] = mapped_column(String)
data: Mapped[bytes | None] = mapped_column(LargeBinary, nullable=True, default=None)
qos: Mapped[int] = mapped_column(Integer, default=0)
class SessionDBPlugin(BasePlugin[BrokerContext]):
"""Plugin to store session information and retained topic messages in the event that the broker terminates abnormally.
Configuration:
- file *(string)* path & filename to store the session db. default: `amqtt.db`
- clear_on_shutdown *(bool)* if the broker shuts down normally, don't retain any information. default: `True`
"""
def __init__(self, context: BrokerContext) -> None:
super().__init__(context)
# bypass the `test_plugins_correct_has_attr` until it can be updated
if not hasattr(self.config, "file"):
logger.warning("`Config` is missing a `file` attribute")
return
self._engine = create_async_engine(f"sqlite+aiosqlite:///{self.config.file}")
self._db_session_maker = async_sessionmaker(self._engine, expire_on_commit=False)
@staticmethod
async def _get_or_create_session(db_session: AsyncSession, client_id:str) -> StoredSession:
stmt = select(StoredSession).filter(StoredSession.client_id == client_id)
stored_session = await db_session.scalar(stmt)
if stored_session is None:
stored_session = StoredSession(client_id=client_id)
db_session.add(stored_session)
await db_session.flush()
return stored_session
@staticmethod
async def _get_or_create_message(db_session: AsyncSession, topic:str) -> StoredMessage:
stmt = select(StoredMessage).filter(StoredMessage.topic == topic)
stored_message = await db_session.scalar(stmt)
if stored_message is None:
stored_message = StoredMessage(topic=topic)
db_session.add(stored_message)
await db_session.flush()
return stored_message
async def on_broker_client_connected(self, client_id:str, client_session:Session) -> None:
"""Search to see if session already exists."""
# if client id doesn't exist, create (can ignore if session is anonymous)
# update session information (will, clean_session, etc)
# don't store session information for clean or anonymous sessions
if client_session.clean_session in (None, True) or client_session.is_anonymous:
return
async with self._db_session_maker() as db_session, db_session.begin():
stored_session = await self._get_or_create_session(db_session, client_id)
stored_session.clean_session = client_session.clean_session
stored_session.will_flag = client_session.will_flag
stored_session.will_message = client_session.will_message # type: ignore[assignment]
stored_session.will_qos = client_session.will_qos
stored_session.will_retain = client_session.will_retain
stored_session.will_topic = client_session.will_topic
stored_session.keep_alive = client_session.keep_alive
await db_session.flush()
async def on_broker_client_subscribed(self, client_id: str, topic: str, qos: int) -> None:
"""Create/update subscription if clean session = false."""
session = self.context.get_session(client_id)
if not session:
logger.warning(f"'{client_id}' is subscribing but doesn't have a session")
return
if session.clean_session:
return
async with self._db_session_maker() as db_session, db_session.begin():
# stored sessions shouldn't need to be created here, but we'll use the same helper...
stored_session = await self._get_or_create_session(db_session, client_id)
stored_session.subscriptions = [*stored_session.subscriptions, Subscription(topic, qos)]
await db_session.flush()
async def on_broker_client_unsubscribed(self, client_id: str, topic: str) -> None:
"""Remove subscription if clean session = false."""
async def on_broker_retained_message(self, *, client_id: str | None, retained_message: RetainedApplicationMessage) -> None:
"""Update to retained messages.
if retained_message.data is None or '', the message is being cleared
"""
# if client_id is valid, the retained message is for a disconnected client
if client_id is not None:
async with self._db_session_maker() as db_session, db_session.begin():
# stored sessions shouldn't need to be created here, but we'll use the same helper...
stored_session = await self._get_or_create_session(db_session, client_id)
stored_session.retained = [*stored_session.retained, RetainedMessage(retained_message.topic,
retained_message.data.decode(),
retained_message.qos or 0)]
await db_session.flush()
return
async with self._db_session_maker() as db_session, db_session.begin():
# if the retained message has data, we need to store/update for the topic
if retained_message.data:
client_message = await self._get_or_create_message(db_session, retained_message.topic)
client_message.data = retained_message.data # type: ignore[assignment]
client_message.qos = retained_message.qos or 0
await db_session.flush()
return
# if there is no data, clear the stored message (if exists) for the topic
stmt = select(StoredMessage).filter(StoredMessage.topic == retained_message.topic)
topic_message = await db_session.scalar(stmt)
if topic_message is not None:
await db_session.delete(topic_message)
await db_session.flush()
return
async def on_broker_pre_start(self) -> None:
"""Initialize the database and db connection."""
async with self._engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
async def on_broker_post_start(self) -> None:
"""Load subscriptions."""
if len(self.context.subscriptions) > 0:
msg = "SessionDBPlugin : broker shouldn't have any subscriptions yet"
raise PluginError(msg)
if len(list(self.context.sessions)) > 0:
msg = "SessionDBPlugin : broker shouldn't have any sessions yet"
raise PluginError(msg)
async with self._db_session_maker() as db_session, db_session.begin():
stmt = select(StoredSession)
stored_sessions = await db_session.execute(stmt)
restored_sessions = 0
for stored_session in stored_sessions.scalars():
await self.context.add_subscription(stored_session.client_id, None, None)
for subscription in stored_session.subscriptions:
await self.context.add_subscription(stored_session.client_id,
subscription.topic,
subscription.qos)
session = self.context.get_session(stored_session.client_id)
if not session:
continue
session.clean_session = stored_session.clean_session
session.will_flag = stored_session.will_flag
session.will_message = stored_session.will_message
session.will_qos = stored_session.will_qos
session.will_retain = stored_session.will_retain
session.will_topic = stored_session.will_topic
session.keep_alive = stored_session.keep_alive
for message in stored_session.retained:
retained_message = RetainedApplicationMessage(
source_session=None,
topic=message.topic,
data=message.data.encode(),
qos=message.qos
)
await session.retained_messages.put(retained_message)
restored_sessions += 1
stmt = select(StoredMessage)
stored_messages: Result[tuple[StoredMessage]] = await db_session.execute(stmt)
restored_messages = 0
retained_messages = self.context.retained_messages
for stored_message in stored_messages.scalars():
retained_messages[stored_message.topic] = (RetainedApplicationMessage(
source_session=None,
topic=stored_message.topic,
data=stored_message.data or b"",
qos=stored_message.qos
))
restored_messages += 1
logger.info(f"Retained messages restored: {restored_messages}")
logger.info(f"Restored {restored_sessions} sessions.")
async def on_broker_pre_shutdown(self) -> None:
"""Clean up the db connection."""
await self._engine.dispose()
async def on_broker_post_shutdown(self) -> None:
if self.config.clear_on_shutdown and self.config.file.exists():
self.config.file.unlink()
@dataclass
class Config:
"""Configuration variables."""
file: str | Path = "amqtt.db"
"""path & filename to store the sqlite session db."""
clear_on_shutdown: bool = True
"""if the broker shutdowns down normally, don't retain any information."""
def __post_init__(self) -> None:
"""Create `Path` from string path."""
if isinstance(self.file, str):
self.file = Path(self.file)


@ -31,4 +31,5 @@ class BrokerEvents(Events):
CLIENT_DISCONNECTED = "broker_client_disconnected"
CLIENT_SUBSCRIBED = "broker_client_subscribed"
CLIENT_UNSUBSCRIBED = "broker_client_unsubscribed"
RETAINED_MESSAGE = "broker_retained_message"
MESSAGE_RECEIVED = "broker_message_received"


@ -8,6 +8,8 @@ import copy
from importlib.metadata import EntryPoint, EntryPoints, entry_points
from inspect import iscoroutinefunction
import logging
import sys
import traceback
from typing import Any, Generic, NamedTuple, Optional, TypeAlias, TypeVar, cast
import warnings
@ -291,6 +293,15 @@ class PluginManager(Generic[C]):
return asyncio.ensure_future(coro)
def _clean_fired_events(self, future: asyncio.Future[Any]) -> None:
if self.logger.getEffectiveLevel() <= logging.DEBUG:
try:
future.result()
except asyncio.CancelledError:
self.logger.warning("fired event was cancelled")
# display plugin fault; don't allow it to cause a broker failure
except Exception as exc: # noqa: BLE001, pylint: disable=W0718
traceback.print_exception(type(exc), exc, exc.__traceback__, file=sys.stderr)
with contextlib.suppress(KeyError, ValueError):
self._fired_events.remove(future)


@ -1,85 +1,11 @@
import json
import sqlite3
from typing import Any
import warnings
from amqtt.contexts import BaseContext
from amqtt.session import Session
from amqtt.broker import BrokerContext
from amqtt.plugins.base import BasePlugin
class SQLitePlugin:
def __init__(self, context: BaseContext) -> None:
self.context: BaseContext = context
self.conn: sqlite3.Connection | None = None
self.cursor: sqlite3.Cursor | None = None
self.db_file: str | None = None
self.persistence_config: dict[str, Any]
class SQLitePlugin(BasePlugin[BrokerContext]):
if (
persistence_config := self.context.config.get("persistence") if self.context.config is not None else None
) is not None:
self.persistence_config = persistence_config
self.init_db()
else:
self.context.logger.warning("'persistence' section not found in context configuration")
def init_db(self) -> None:
self.db_file = self.persistence_config.get("file")
if not self.db_file:
self.context.logger.warning("'file' persistence parameter not found")
else:
try:
self.conn = sqlite3.connect(self.db_file)
self.cursor = self.conn.cursor()
self.context.logger.info(f"Database file '{self.db_file}' opened")
except Exception:
self.context.logger.exception(f"Error while initializing database '{self.db_file}'")
if self.cursor:
self.cursor.execute(
"CREATE TABLE IF NOT EXISTS session(client_id TEXT PRIMARY KEY, data BLOB)",
)
self.cursor.execute("PRAGMA table_info(session)")
columns = {col[1] for col in self.cursor.fetchall()}
required_columns = {"client_id", "data"}
if not required_columns.issubset(columns):
self.context.logger.error("Database schema for 'session' table is incompatible.")
async def save_session(self, session: Session) -> None:
if self.cursor and self.conn:
dump: str = json.dumps(session, default=str)
try:
self.cursor.execute(
"INSERT OR REPLACE INTO session (client_id, data) VALUES (?, ?)",
(session.client_id, dump),
)
self.conn.commit()
except Exception:
self.context.logger.exception(f"Failed saving session '{session}'")
async def find_session(self, client_id: str) -> Session | None:
if self.cursor:
row = self.cursor.execute(
"SELECT data FROM session where client_id=?",
(client_id,),
).fetchone()
return json.loads(row[0]) if row else None
return None
async def del_session(self, client_id: str) -> None:
if self.cursor and self.conn:
try:
exists = self.cursor.execute("SELECT 1 FROM session WHERE client_id=?", (client_id,)).fetchone()
if exists:
self.cursor.execute("DELETE FROM session where client_id=?", (client_id,))
self.conn.commit()
except Exception:
self.context.logger.exception(f"Failed deleting session with client_id '{client_id}'")
async def on_broker_post_shutdown(self) -> None:
if self.conn:
try:
self.conn.close()
self.context.logger.info(f"Database file '{self.db_file}' closed")
except Exception:
self.context.logger.exception("Error closing database connection")
finally:
self.conn = None
def __init__(self, context: BrokerContext) -> None:
super().__init__(context)
warnings.warn("SQLitePlugin is deprecated, use amqtt.contrib.persistence.SessionDBPlugin", stacklevel=1)


@ -112,7 +112,7 @@ class BrokerSysPlugin(BasePlugin[BrokerContext]):
"""Initialize statistics and start $SYS broadcasting."""
self._stats[STAT_START_TIME] = int(datetime.now(tz=UTC).timestamp())
version = f"aMQTT version {amqtt.__version__}"
self.context.retain_message(DOLLAR_SYS_ROOT + "version", version.encode())
await self.context.retain_message(DOLLAR_SYS_ROOT + "version", version.encode())
# Start $SYS topics management
try:


@ -106,6 +106,8 @@ All plugins are notified of events if the `BasePlugin` subclass implements one o
- `async def on_broker_client_connected(self, *, client_id:str) -> None`
- `async def on_broker_client_disconnected(self, *, client_id:str) -> None`
- `async def on_broker_retained_message(self, *, client_id: str | None, retained_message: RetainedApplicationMessage) -> None`
- `async def on_broker_client_subscribed(self, *, client_id: str, topic: str, qos: int) -> None`
- `async def on_broker_client_unsubscribed(self, *, client_id: str, topic: str) -> None`
@ -115,6 +117,11 @@ All plugins are notified of events if the `BasePlugin` subclass implements one o
- `async def on_mqtt_packet_received(self, *, packet: MQTTPacket[MQTTVariableHeader, MQTTPayload[MQTTVariableHeader], MQTTFixedHeader], session: Session | None = None) -> None`
!!! note "retained message event"
If the `client_id` is `None`, the message is retained for a topic.
If the `retained_message.data` is `None` or empty (`''`), the topic message is being cleared.
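As an illustration, a minimal plugin sketch that reacts to the retained-message event might look like the following (the class name and logging are hypothetical; the callback signature matches the list above):

```python
import logging

from amqtt.broker import BrokerContext, RetainedApplicationMessage
from amqtt.plugins.base import BasePlugin

logger = logging.getLogger(__name__)


class RetainedMessageLogger(BasePlugin[BrokerContext]):
    """Hypothetical plugin that logs retained-message events."""

    async def on_broker_retained_message(
        self, *, client_id: str | None, retained_message: RetainedApplicationMessage
    ) -> None:
        if client_id is not None:
            # message retained for a disconnected client's session
            logger.info("retained for offline client %s: %s", client_id, retained_message.topic)
        elif retained_message.data:
            # message retained for a topic
            logger.info("retained for topic %s", retained_message.topic)
        else:
            # an empty payload clears the retained message for the topic
            logger.info("cleared retained message for topic %s", retained_message.topic)
```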
## Authentication Plugins
In addition to receiving any of the event callbacks, a plugin which subclasses from `BaseAuthPlugin`


@ -1,4 +1,4 @@
# Existing Plugins
# Packaged Plugins
With the aMQTT plugins framework, one can add additional functionality without
having to rewrite core logic in the broker or client. Plugins can be loaded and configured using
@ -51,7 +51,6 @@ Authentication plugin allowing anonymous access.
extra:
class_style: "simple"
!!! danger
even if `allow_anonymous` is set to `false`, the plugin will still allow access if a username is provided by the client
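For example, with the configuration sketched below (mirroring the settings used in this repository's tests; the username and address are illustrative), a client that supplies any username is still accepted:

```python
import asyncio

from amqtt.broker import Broker
from amqtt.client import MQTTClient

config = {
    "listeners": {"default": {"type": "tcp", "bind": "127.0.0.1:1883"}},
    "plugins": {
        # anonymous access is disabled, but a client providing a username still connects
        "amqtt.plugins.authentication.AnonymousAuthPlugin": {"allow_anonymous": False},
    },
}


async def demo() -> None:
    broker = Broker(config=config)
    await broker.start()
    client = MQTTClient(config={"auto_reconnect": False})
    # accepted even though allow_anonymous is False, because a username is present
    await client.connect("mqtt://myUsername@127.0.0.1:1883/")
    await client.disconnect()
    await broker.shutdown()


if __name__ == "__main__":
    asyncio.run(demo())
```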


@ -0,0 +1,12 @@
# Session Persistence
`amqtt.contrib.persistence.SessionDBPlugin`
Plugin to store session information and retained topic messages in the event that the broker terminates abnormally.
::: amqtt.contrib.persistence.SessionDBPlugin.Config
options:
show_source: false
heading_level: 4
extra:
class_style: "simple"
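A minimal broker configuration enabling the plugin might look like the following sketch (values mirror the configuration used in the plugin's tests; the db path is illustrative):

```python
from amqtt.broker import Broker

config = {
    "listeners": {"default": {"type": "tcp", "bind": "127.0.0.1:1883"}},
    "plugins": {
        "amqtt.contrib.persistence.SessionDBPlugin": {
            "file": "amqtt.db",          # path & filename for the sqlite session db
            "clear_on_shutdown": False,  # keep the db on a normal shutdown so state survives a restart
        },
    },
}

broker = Broker(config=config)
```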


@ -10,7 +10,7 @@
- Communication over TCP and/or websocket, including support for SSL/TLS
- Support QoS 0, QoS 1 and QoS 2 messages flow
- Client auto-reconnection on network lost
- Functionality expansion; plugins included: authentication and `$SYS` topic publishing
- Custom functionality expansion; plugins included: authentication, `$SYS` topic publishing, session persistence
## Installation


@ -35,7 +35,7 @@ dependencies = [
"typer==0.15.4",
"aiohttp>=3.12.7",
"dacite>=1.9.2",
"psutil>=7.0.0",
"psutil>=7.0.0"
]
[dependency-groups]
@ -57,6 +57,7 @@ dev = [
"pytest>=8.3.5", # https://pypi.org/project/pytest
"ruff>=0.11.3", # https://pypi.org/project/ruff
"setuptools>=78.1.0",
"sqlalchemy[mypy]>=2.0.41",
"types-mock>=5.2.0.20250306", # https://pypi.org/project/types-mock
"types-PyYAML>=6.0.12.20250402", # https://pypi.org/project/types-PyYAML
"types-setuptools>=78.1.0.20250329", # https://pypi.org/project/types-setuptools
@ -84,7 +85,9 @@ docs = [
[project.optional-dependencies]
ci = ["coveralls==4.0.1"]
contrib = [
"sqlalchemy>=2.0.41",
"aiosqlite>=0.21.0",
"greenlet>=3.2.3",
"sqlalchemy[asyncio]>=2.0.41",
"argon2-cffi>=25.1.0",
"aiohttp>=3.12.13",
]


@ -1,3 +1,4 @@
import argparse
import asyncio
import logging
from pathlib import Path
@ -23,13 +24,13 @@ config = {
},
"auto_reconnect": False,
"check_hostname": False,
"certfile": "cert.pem",
"certfile": "",
}
client = MQTTClient(config=config)
async def test_coro() -> None:
async def test_coro(certfile: str) -> None:
config['certfile'] = certfile
client = MQTTClient(config=config)
await client.connect("mqtts://localhost:8883")
tasks = [
@ -45,7 +46,11 @@ async def test_coro() -> None:
def __main__():
formatter = "[%(asctime)s] {%(filename)s:%(lineno)d} %(levelname)s - %(message)s"
logging.basicConfig(level=logging.DEBUG, format=formatter)
asyncio.run(test_coro())
parser = argparse.ArgumentParser()
parser.add_argument('--cert', default='cert.pem', help="path & file to verify server's authenticity")
args = parser.parse_args()
asyncio.run(test_coro(args.cert))
if __name__ == "__main__":
__main__()


@ -1,4 +1,5 @@
import logging
import subprocess
from pathlib import Path
import tempfile
from typing import Any
@ -15,19 +16,46 @@ log = logging.getLogger(__name__)
pytest_plugins = ["pytest_logdog"]
test_config = {
"listeners": {
"default": {"type": "tcp", "bind": "127.0.0.1:1883", "max_connections": 15},
"ws": {"type": "ws", "bind": "127.0.0.1:8080", "max_connections": 15},
"wss": {"type": "ws", "bind": "127.0.0.1:8081", "max_connections": 15},
},
"sys_interval": 0,
"auth": {
"allow-anonymous": True,
}
}
@pytest.fixture
def rsa_keys():
tmp_dir = tempfile.TemporaryDirectory(prefix='amqtt-test-')
cert = Path(tmp_dir.name) / "cert.pem"
key = Path(tmp_dir.name) / "key.pem"
cmd = f'openssl req -x509 -nodes -days 365 -newkey rsa:2048 -keyout {key} -out {cert} -subj "/CN=localhost"'
subprocess.run(cmd, shell=True, capture_output=True, text=True)
yield cert, key
tmp_dir.cleanup()
@pytest.fixture
def test_config(rsa_keys):
certfile, keyfile = rsa_keys
yield {
"listeners": {
"default": {"type": "tcp", "bind": "127.0.0.1:1883", "max_connections": 15},
"mqtts": {
"type": "tcp",
"bind": "127.0.0.1:1884",
"max_connections": 15,
"ssl": True,
"certfile": certfile,
"keyfile": keyfile
},
"ws": {"type": "ws", "bind": "127.0.0.1:8080", "max_connections": 15},
"wss": {
"type": "ws",
"bind": "127.0.0.1:8081",
"max_connections": 15,
"ssl": True,
'certfile': certfile,
'keyfile': keyfile},
},
"sys_interval": 0,
"auth": {
"allow-anonymous": True,
}
}
test_config_acl: dict[str, int | dict[str, Any]] = {
"listeners": {
"default": {"type": "tcp", "bind": "127.0.0.1:1884", "max_connections": 10},
@ -64,7 +92,7 @@ def mock_plugin_manager():
@pytest.fixture
async def broker_fixture():
async def broker_fixture(test_config):
broker = Broker(test_config, plugin_namespace="amqtt.test.plugins")
await broker.start()
assert broker.transitions.is_started()
@ -78,7 +106,7 @@ async def broker_fixture():
@pytest.fixture
async def broker(mock_plugin_manager):
async def broker(mock_plugin_manager, test_config):
# just making sure the mock is in place before we start our broker
assert mock_plugin_manager is not None


@ -0,0 +1,543 @@
import asyncio
import logging
from pathlib import Path
import sqlite3
import pytest
import aiosqlite
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
from sqlalchemy import select
from amqtt.broker import Broker, BrokerContext, RetainedApplicationMessage
from amqtt.client import MQTTClient
from amqtt.mqtt.constants import QOS_1
from amqtt.contrib.persistence import SessionDBPlugin, Subscription, StoredSession, RetainedMessage
from amqtt.session import Session
formatter = "[%(asctime)s] %(name)s {%(filename)s:%(lineno)d} %(levelname)s - %(message)s"
logging.basicConfig(level=logging.DEBUG, format=formatter)
logger = logging.getLogger(__name__)
@pytest.fixture
async def db_file():
db_file = Path(__file__).parent / "amqtt.db"
if db_file.exists():
raise NotImplementedError("existing db file found, should it be cleaned up?")
yield db_file
if db_file.exists():
db_file.unlink()
@pytest.fixture
async def broker_context():
cfg = {
'listeners': { 'default': {'type': 'tcp', 'bind': 'localhost:1883' }},
'plugins': {}
}
context = BrokerContext(broker=Broker(config=cfg))
yield context
@pytest.fixture
async def db_session_factory(db_file):
engine = create_async_engine(f"sqlite+aiosqlite:///{str(db_file)}")
factory = async_sessionmaker(engine, expire_on_commit=False)
yield factory
@pytest.mark.asyncio
async def test_initialize_tables(db_file, broker_context):
broker_context.config = SessionDBPlugin.Config(file=db_file)
session_db_plugin = SessionDBPlugin(broker_context)
await session_db_plugin.on_broker_pre_start()
assert db_file.exists()
conn = sqlite3.connect(str(db_file))
cursor = conn.cursor()
table_name = 'stored_sessions'
cursor.execute(f"PRAGMA table_info({table_name});")
rows = cursor.fetchall()
column_names = [row[1] for row in rows]
assert len(column_names) > 1
@pytest.mark.asyncio
async def test_create_stored_session(db_file, broker_context, db_session_factory):
broker_context.config = SessionDBPlugin.Config(file=db_file)
session_db_plugin = SessionDBPlugin(broker_context)
await session_db_plugin.on_broker_pre_start()
async with db_session_factory() as db_session:
async with db_session.begin():
stored_session = await session_db_plugin._get_or_create_session(db_session, 'test_client_1')
assert stored_session.client_id == 'test_client_1'
async with aiosqlite.connect(str(db_file)) as db:
async with await db.execute("SELECT * FROM stored_sessions") as cursor:
async for row in cursor:
assert row[1] == 'test_client_1'
@pytest.mark.asyncio
async def test_get_stored_session(db_file, broker_context, db_session_factory):
broker_context.config = SessionDBPlugin.Config(file=db_file)
session_db_plugin = SessionDBPlugin(broker_context)
await session_db_plugin.on_broker_pre_start()
async with aiosqlite.connect(str(db_file)) as db:
sql = """INSERT INTO stored_sessions (
client_id, clean_session, will_flag,
will_qos, keep_alive,
retained, subscriptions
) VALUES (
'test_client_1',
1,
0,
1,
60,
'[]',
'[{"topic":"sensors/#","qos":1}]'
)"""
await db.execute(sql)
await db.commit()
async with db_session_factory() as db_session:
async with db_session.begin():
stored_session = await session_db_plugin._get_or_create_session(db_session, 'test_client_1')
assert stored_session.subscriptions == [Subscription(topic='sensors/#', qos=1)]
@pytest.mark.asyncio
async def test_update_stored_session(db_file, broker_context, db_session_factory):
broker_context.config = SessionDBPlugin.Config(file=db_file)
# create session for client id (without subscription)
await broker_context.add_subscription('test_client_1', None, None)
session = broker_context.get_session('test_client_1')
assert session is not None
session.clean_session = False
session_db_plugin = SessionDBPlugin(broker_context)
await session_db_plugin.on_broker_pre_start()
# initialize with stored client session
async with aiosqlite.connect(str(db_file)) as db:
sql = """INSERT INTO stored_sessions (
client_id, clean_session, will_flag,
will_qos, keep_alive,
retained, subscriptions
) VALUES (
'test_client_1',
1,
0,
1,
60,
'[]',
'[{"topic":"sensors/#","qos":1}]'
)"""
await db.execute(sql)
await db.commit()
await session_db_plugin.on_broker_client_subscribed(client_id='test_client_1', topic='my/topic', qos=2)
# verify that the stored session has been updated with the new subscription
has_stored_session = False
async with aiosqlite.connect(str(db_file)) as db:
async with await db.execute("SELECT * FROM stored_sessions") as cursor:
async for row in cursor:
assert row[1] == 'test_client_1'
assert row[-1] == '[{"topic": "sensors/#", "qos": 1}, {"topic": "my/topic", "qos": 2}]'
has_stored_session = True
assert has_stored_session, "stored session wasn't updated"
@pytest.mark.asyncio
async def test_client_connected_with_clean_session(db_file, broker_context, db_session_factory) -> None:
broker_context.config = SessionDBPlugin.Config(file=db_file)
session_db_plugin = SessionDBPlugin(broker_context)
await session_db_plugin.on_broker_pre_start()
session = Session()
session.client_id = 'test_client_connected'
session.is_anonymous = False
session.clean_session = True
await session_db_plugin.on_broker_client_connected(client_id='test_client_connected', client_session=session)
async with aiosqlite.connect(str(db_file)) as db_conn:
db_conn.row_factory = sqlite3.Row
async with await db_conn.execute("SELECT * FROM stored_sessions") as cursor:
assert len(await cursor.fetchall()) == 0
@pytest.mark.asyncio
async def test_client_connected_anonymous_session(db_file, broker_context, db_session_factory) -> None:
broker_context.config = SessionDBPlugin.Config(file=db_file)
session_db_plugin = SessionDBPlugin(broker_context)
await session_db_plugin.on_broker_pre_start()
session = Session()
session.is_anonymous = True
session.client_id = 'test_client_connected'
await session_db_plugin.on_broker_client_connected(client_id='test_client_connected', client_session=session)
async with aiosqlite.connect(str(db_file)) as db_conn:
db_conn.row_factory = sqlite3.Row # Set the row_factory
async with await db_conn.execute("SELECT * FROM stored_sessions") as cursor:
assert len(await cursor.fetchall()) == 0
@pytest.mark.asyncio
async def test_client_connected_and_stored_session(db_file, broker_context, db_session_factory) -> None:
broker_context.config = SessionDBPlugin.Config(file=db_file)
session_db_plugin = SessionDBPlugin(broker_context)
await session_db_plugin.on_broker_pre_start()
session = Session()
session.client_id = 'test_client_connected'
session.is_anonymous = False
session.clean_session = False
session.will_flag = True
session.will_qos = 1
session.will_topic = 'my/will/topic'
session.will_retain = False
session.will_message = b'test connected client has a last will (and testament) message'
session.keep_alive = 42
await session_db_plugin.on_broker_client_connected(client_id='test_client_connected', client_session=session)
has_stored_session = False
async with aiosqlite.connect(str(db_file)) as db_conn:
db_conn.row_factory = sqlite3.Row # Set the row_factory
async with await db_conn.execute("SELECT * FROM stored_sessions") as cursor:
for row in await cursor.fetchall():
assert row['client_id'] == 'test_client_connected'
assert row['clean_session'] == False
assert row['will_flag'] == True
assert row['will_qos'] == 1
assert row['will_topic'] == 'my/will/topic'
assert row['will_retain'] == False
assert row['will_message'] == b'test connected client has a last will (and testament) message'
assert row['keep_alive'] == 42
has_stored_session = True
assert has_stored_session, "client session wasn't stored"
@pytest.mark.asyncio
async def test_repopulate_stored_sessions(db_file, broker_context, db_session_factory) -> None:
broker_context.config = SessionDBPlugin.Config(file=db_file)
session_db_plugin = SessionDBPlugin(broker_context)
await session_db_plugin.on_broker_pre_start()
async with aiosqlite.connect(str(db_file)) as db:
sql = """INSERT INTO stored_sessions (
client_id, clean_session, will_flag,
will_qos, keep_alive,
retained, subscriptions
) VALUES (
'test_client_1',
1,
0,
1,
60,
'[{"topic":"sensors/#","data":"this message is retained when client reconnects","qos":1}]',
'[{"topic":"sensors/#","qos":1}]'
)"""
await db.execute(sql)
await db.commit()
await session_db_plugin.on_broker_post_start()
session = broker_context.get_session('test_client_1')
assert session is not None
assert session.retained_messages.qsize() == 1
assert 'sensors/#' in broker_context._broker_instance._subscriptions
# ugly: because _subscriptions maps each topic to a list of (session, qos) tuples
assert broker_context._broker_instance._subscriptions['sensors/#'][0][1] == 1
@pytest.mark.asyncio
async def test_client_retained_message(db_file, broker_context, db_session_factory) -> None:
broker_context.config = SessionDBPlugin.Config(file=db_file)
session_db_plugin = SessionDBPlugin(broker_context)
await session_db_plugin.on_broker_pre_start()
# add a session to the broker
await broker_context.add_subscription('test_retained_client', None, None)
# update the session so that it's retained
session = broker_context.get_session('test_retained_client')
assert session is not None
session.is_anonymous = False
session.clean_session = False
session.transitions.disconnect()
retained_message = RetainedApplicationMessage(source_session=session, topic='my/retained/topic', data=b'retain message for disconnected client', qos=2)
await session_db_plugin.on_broker_retained_message(client_id='test_retained_client', retained_message=retained_message)
async with db_session_factory() as db_session:
async with db_session.begin():
stmt = select(StoredSession).filter(StoredSession.client_id == 'test_retained_client')
stored_session = await db_session.scalar(stmt)
assert stored_session is not None
assert len(stored_session.retained) > 0
assert RetainedMessage(topic='my/retained/topic', data='retain message for disconnected client', qos=2) in stored_session.retained
@pytest.mark.asyncio
async def test_topic_retained_message(db_file, broker_context, db_session_factory) -> None:
broker_context.config = SessionDBPlugin.Config(file=db_file, clear_on_shutdown=False)
session_db_plugin = SessionDBPlugin(broker_context)
await session_db_plugin.on_broker_pre_start()
# add a session to the broker
await broker_context.add_subscription('test_retained_client', None, None)
# update the session so that it's retained
session = broker_context.get_session('test_retained_client')
assert session is not None
session.is_anonymous = False
session.clean_session = False
session.transitions.disconnect()
retained_message = RetainedApplicationMessage(source_session=session, topic='my/retained/topic', data=b'retained message', qos=2)
await session_db_plugin.on_broker_retained_message(client_id=None, retained_message=retained_message)
has_stored_message = False
async with aiosqlite.connect(str(db_file)) as db_conn:
db_conn.row_factory = sqlite3.Row # Set the row_factory
async with await db_conn.execute("SELECT * FROM stored_messages") as cursor:
# assert(len(await cursor.fetchall()) > 0)
for row in await cursor.fetchall():
assert row['topic'] == 'my/retained/topic'
assert row['data'] == b'retained message'
assert row['qos'] == 2
has_stored_message = True
assert has_stored_message, "retained topic message wasn't stored"
@pytest.mark.asyncio
async def test_topic_clear_retained_message(db_file, broker_context, db_session_factory) -> None:
broker_context.config = SessionDBPlugin.Config(file=db_file)
session_db_plugin = SessionDBPlugin(broker_context)
await session_db_plugin.on_broker_pre_start()
# add a session to the broker
await broker_context.add_subscription('test_retained_client', None, None)
# update the session so that it's retained
session = broker_context.get_session('test_retained_client')
assert session is not None
session.is_anonymous = False
session.clean_session = False
session.transitions.disconnect()
retained_message = RetainedApplicationMessage(source_session=session, topic='my/retained/topic', data=b'', qos=0)
await session_db_plugin.on_broker_retained_message(client_id=None, retained_message=retained_message)
async with aiosqlite.connect(str(db_file)) as db_conn:
db_conn.row_factory = sqlite3.Row
async with await db_conn.execute("SELECT * FROM stored_messages") as cursor:
assert(len(await cursor.fetchall()) == 0)
@pytest.mark.asyncio
async def test_restoring_retained_message(db_file, broker_context, db_session_factory) -> None:
broker_context.config = SessionDBPlugin.Config(file=db_file)
session_db_plugin = SessionDBPlugin(broker_context)
await session_db_plugin.on_broker_pre_start()
stmts = ("INSERT INTO stored_messages VALUES(1,'my/retained/topic1',X'72657461696e6564206d657373616765',2)",
"INSERT INTO stored_messages VALUES(2,'my/retained/topic2',X'72657461696e6564206d65737361676532',2)",
"INSERT INTO stored_messages VALUES(3,'my/retained/topic3',X'72657461696e6564206d65737361676533',2)")
async with aiosqlite.connect(str(db_file)) as db:
for stmt in stmts:
await db.execute(stmt)
await db.commit()
await session_db_plugin.on_broker_post_start()
assert len(broker_context.retained_messages) == 3
assert 'my/retained/topic1' in broker_context.retained_messages
assert 'my/retained/topic2' in broker_context.retained_messages
assert 'my/retained/topic3' in broker_context.retained_messages
@pytest.fixture
def db_config(db_file):
return {
'listeners': {
'default': {
'type': 'tcp',
'bind': '127.0.0.1:1883'
}
},
'plugins': {
'amqtt.plugins.authentication.AnonymousAuthPlugin': {'allow_anonymous': False},
'amqtt.contrib.persistence.SessionDBPlugin': {
'file': db_file,
'clear_on_shutdown': False
}
}
}
@pytest.mark.asyncio
async def test_broker_client_no_cleanup(db_file, db_config) -> None:
b1 = Broker(config=db_config)
await b1.start()
await asyncio.sleep(0.1)
c1 = MQTTClient(client_id='test_client1', config={'auto_reconnect':False})
await c1.connect("mqtt://myUsername@127.0.0.1:1883", cleansession=False)
# test that this message is retained for the topic upon restore
await c1.publish("my/retained/topic", b'retained message for topic my/retained/topic', retain=True)
await asyncio.sleep(0.2)
await c1.disconnect()
await b1.shutdown()
# new broker should load the previous broker's db file since clear_on_shutdown is false in config
b2 = Broker(config=db_config)
await b2.start()
await asyncio.sleep(0.1)
# upon subscribing to topic with retained message, it should be received
c2 = MQTTClient(client_id='test_client2', config={'auto_reconnect':False})
await c2.connect("mqtt://myOtherUsername@localhost:1883", cleansession=False)
await c2.subscribe([
('my/retained/topic', QOS_1)
])
msg = await c2.deliver_message(timeout_duration=1)
assert msg is not None
assert msg.topic == "my/retained/topic"
assert msg.data == b'retained message for topic my/retained/topic'
await c2.disconnect()
await b2.shutdown()
@pytest.mark.asyncio
async def test_broker_client_retain_subscription(db_file, db_config) -> None:
b1 = Broker(config=db_config)
await b1.start()
await asyncio.sleep(0.1)
c1 = MQTTClient(client_id='test_client1', config={'auto_reconnect':False})
await c1.connect("mqtt://myUsername@127.0.0.1:1883", cleansession=False)
# test to make sure the subscription is re-established upon reconnection after broker restart (clear_on_shutdown = False)
ret = await c1.subscribe([
('my/offline/topic', QOS_1)
])
assert ret == [QOS_1,]
await asyncio.sleep(0.2)
await c1.disconnect()
await asyncio.sleep(0.1)
await b1.shutdown()
# new broker should load the previous broker's db file
b2 = Broker(config=db_config)
await b2.start()
await asyncio.sleep(0.1)
# client1's subscription should have been restored, so when it connects, it should receive this message
c2 = MQTTClient(client_id='test_client2', config={'auto_reconnect':False})
await c2.connect("mqtt://myOtherUsername@localhost:1883", cleansession=False)
await c2.publish('my/offline/topic', b'standard message to be retained for offline clients')
await asyncio.sleep(0.1)
await c2.disconnect()
await c1.reconnect(cleansession=False)
await asyncio.sleep(0.1)
msg = await c1.deliver_message(timeout_duration=2)
assert msg is not None
assert msg.topic == "my/offline/topic"
assert msg.data == b'standard message to be retained for offline clients'
await c1.disconnect()
await b2.shutdown()
@pytest.mark.asyncio
async def test_broker_client_retain_message(db_file, db_config) -> None:
"""test to make sure that the retained message because client1 is offline,
gets sent when back online after broker restart."""
b1 = Broker(config=db_config)
await b1.start()
await asyncio.sleep(0.1)
c1 = MQTTClient(client_id='test_client1', config={'auto_reconnect':False})
await c1.connect("mqtt://myUsername@127.0.0.1:1883", cleansession=False)
# subscribe to a topic with QOS_1 so that we receive messages, even if we're disconnected when sent
ret = await c1.subscribe([
('my/offline/topic', QOS_1)
])
assert ret == [QOS_1,]
await asyncio.sleep(0.2)
# go offline
await c1.disconnect()
await asyncio.sleep(0.1)
# another client sends a message to previously subscribed to topic
c2 = MQTTClient(client_id='test_client2', config={'auto_reconnect':False})
await c2.connect("mqtt://myOtherUsername@localhost:1883", cleansession=False)
# this message should be delivered after broker stops and restarts (and client connects)
await c2.publish('my/offline/topic', b'standard message to be retained for offline clients')
await asyncio.sleep(0.1)
await c2.disconnect()
await asyncio.sleep(0.1)
await b1.shutdown()
# new broker should load the previous broker's db file since we declared clear_on_shutdown = False in config
b2 = Broker(config=db_config)
await b2.start()
await asyncio.sleep(0.1)
# when first client reconnects, it should receive the message that had been previously retained for it
await c1.reconnect(cleansession=False)
await asyncio.sleep(0.1)
msg = await c1.deliver_message(timeout_duration=2)
assert msg is not None
assert msg.topic == "my/offline/topic"
assert msg.data == b'standard message to be retained for offline clients'
# client should also receive a message if sent on this topic
await c1.publish("my/offline/topic", b"online message should also be received")
await asyncio.sleep(0.1)
msg = await c1.deliver_message(timeout_duration=2)
assert msg is not None
assert msg.topic == "my/offline/topic"
assert msg.data == b'online message should also be received'
await c1.disconnect()
await b2.shutdown()


@ -1,56 +0,0 @@
import asyncio
import logging
from pathlib import Path
import sqlite3
import unittest
from amqtt.contexts import BaseContext
from amqtt.plugins.persistence import SQLitePlugin
from amqtt.session import Session
formatter = "[%(asctime)s] %(name)s {%(filename)s:%(lineno)d} %(levelname)s - %(message)s"
logging.basicConfig(level=logging.DEBUG, format=formatter)
class TestSQLitePlugin(unittest.TestCase):
def setUp(self) -> None:
self.loop = asyncio.new_event_loop()
def test_create_tables(self) -> None:
dbfile = Path(__file__).resolve().parent / "test.db"
context = BaseContext()
context.logger = logging.getLogger(__name__)
context.config = {"persistence": {"file": str(dbfile)}} # Ensure string path for config
SQLitePlugin(context)
try:
conn = sqlite3.connect(str(dbfile)) # Convert Path to string for sqlite connection
cursor = conn.cursor()
rows = cursor.execute("SELECT name FROM sqlite_master WHERE type = 'table'")
tables = [row[0] for row in rows] # List comprehension for brevity
assert "session" in tables
finally:
conn.close()
def test_save_session(self) -> None:
dbfile = Path(__file__).resolve().parent / "test.db"
context = BaseContext()
context.logger = logging.getLogger(__name__)
context.config = {"persistence": {"file": str(dbfile)}} # Ensure string path for config
sql_plugin = SQLitePlugin(context)
s = Session()
s.client_id = "test_save_session"
self.loop.run_until_complete(sql_plugin.save_session(session=s))
try:
conn = sqlite3.connect(str(dbfile)) # Convert Path to string for sqlite connection
cursor = conn.cursor()
row = cursor.execute("SELECT client_id FROM session WHERE client_id = 'test_save_session'").fetchone()
assert row is not None
assert row[0] == s.client_id
finally:
conn.close()


@ -1,12 +1,11 @@
import asyncio
import inspect
import logging
from functools import partial
from importlib.metadata import EntryPoint
from logging import getLogger
from pathlib import Path
from types import ModuleType
from typing import Any, Callable, Coroutine
from unittest.mock import patch
import pytest
@ -15,13 +14,15 @@ from amqtt.broker import Broker, BrokerContext
from amqtt.client import MQTTClient
from amqtt.errors import PluginInitError, PluginImportError
from amqtt.events import MQTTEvents, BrokerEvents
from amqtt.mqtt.constants import QOS_0
from amqtt.mqtt.constants import QOS_0, QOS_1
from amqtt.plugins.base import BasePlugin
from amqtt.contexts import BaseContext
from amqtt.contrib.persistence import RetainedMessage
_INVALID_METHOD: str = "invalid_foo"
_PLUGIN: str = "Plugin"
logger = logging.getLogger(__name__)
class _TestContext(BaseContext):
def __init__(self) -> None:
@ -159,7 +160,7 @@ async def test_all_plugin_events():
client = MQTTClient()
await client.connect("mqtt://127.0.0.1:1883/")
await client.subscribe([('my/test/topic', QOS_0),])
await client.publish('test/topic', b'my test message')
await client.publish('test/topic', b'my test message', retain=True)
await client.unsubscribe(['my/test/topic',])
await client.disconnect()
await asyncio.sleep(1)
@ -170,3 +171,70 @@ async def test_all_plugin_events():
await asyncio.sleep(1)
assert all(test_plugin.test_flags.values()), f'event not received: {[event for event, value in test_plugin.test_flags.items() if not value]}'
class RetainedMessageEventPlugin(BasePlugin[BrokerContext]):
"""A plugin to verify all events get sent to plugins."""
def __init__(self, context: BaseContext) -> None:
super().__init__(context)
self.topic_retained_message_flag = False
self.session_retained_message_flag = False
self.topic_clear_retained_message_flag = False
async def on_broker_retained_message(self, *, client_id: str | None, retained_message: RetainedMessage) -> None:
"""retaining message event handler."""
if client_id:
session = self.context.get_session(client_id)
assert session.transitions.state != "connected"
logger.debug("retained message event fired for offline client")
self.session_retained_message_flag = True
else:
if not retained_message.data:
logger.debug("retained message event fired for clearing a topic")
self.topic_clear_retained_message_flag = True
else:
logger.debug("retained message event fired for setting a topic")
self.topic_retained_message_flag = True
@pytest.mark.asyncio
async def test_retained_message_plugin_event():
config = {
"listeners": {
"default": {"type": "tcp", "bind": "127.0.0.1:1883", "max_connections": 10},
},
'sys_interval': 1,
'plugins':[{'amqtt.plugins.authentication.AnonymousAuthPlugin': {'allow_anonymous': False}},
{'tests.plugins.test_plugins.RetainedMessageEventPlugin': {}}]
}
broker = Broker(plugin_namespace='tests.mock_plugins', config=config)
await broker.start()
await asyncio.sleep(0.1)
# make sure all expected events get triggered
client1 = MQTTClient(config={'auto_reconnect': False})
await client1.connect("mqtt://myUsername@127.0.0.1:1883/", cleansession=False)
await client1.subscribe([('test/topic', QOS_1),])
await client1.publish('test/retained', b'message should be retained for test/retained', retain=True)
await asyncio.sleep(0.1)
await client1.disconnect()
client2 = MQTTClient(config={'auto_reconnect': False})
await client2.connect("mqtt://myOtherUsername@127.0.0.1:1883/", cleansession=True)
await client2.publish('test/topic', b'message should be retained for myUsername since subscription was qos > 0')
await client2.publish('test/retained', b'', retain=True) # should clear previously retained message
await asyncio.sleep(0.1)
await client2.disconnect()
await asyncio.sleep(0.1)
# get the plugin so it doesn't get garbage-collected on shutdown
test_plugin = broker.plugins_manager.get_plugin('RetainedMessageEventPlugin')
await broker.shutdown()
await asyncio.sleep(0.1)
assert test_plugin.topic_retained_message_flag, "message to topic wasn't retained"
assert test_plugin.session_retained_message_flag, "message to disconnected client wasn't retained"
assert test_plugin.topic_clear_retained_message_flag, "message to retained topic wasn't cleared"


@ -15,20 +15,22 @@ logging.basicConfig(level=logging.ERROR, format=formatter)
log = logging.getLogger(__name__)
# @pytest.mark.asyncio
# async def test_connect_tcp():
# client = MQTTClient()
# await client.connect("mqtt://broker.hivemq.com:1883/")
# assert client.session is not None
# await client.disconnect()
#
#
# @pytest.mark.asyncio
# async def test_connect_tcp_secure(ca_file_fixture):
# client = MQTTClient(config={"check_hostname": False})
# await client.connect("mqtts://broker.hivemq.com:8883/")
# assert client.session is not None
# await client.disconnect()
@pytest.mark.asyncio
async def test_connect_tcp(broker_fixture):
client = MQTTClient()
await client.connect("mqtt://localhost:1883/")
assert client.session is not None
await client.disconnect()
@pytest.mark.asyncio
async def test_connect_tcp_secure(rsa_keys, broker_fixture):
certfile, _ = rsa_keys
client = MQTTClient(config={"check_hostname": False, "auto_reconnect": False})
# since we're using a self-signed certificate, we need to provide the server's certificate to verify its authenticity
await client.connect("mqtts://localhost:1884/", cafile=certfile)
assert client.session is not None
await client.disconnect()
@pytest.mark.asyncio
@ -62,9 +64,11 @@ async def test_reconnect_ws_retain_username_password(broker_fixture):
@pytest.mark.asyncio
async def test_connect_ws_secure(ca_file_fixture, broker_fixture):
client = MQTTClient()
await client.connect("ws://127.0.0.1:8081/", cafile=ca_file_fixture)
async def test_connect_ws_secure(rsa_keys, broker_fixture):
certfile, _ = rsa_keys
client = MQTTClient(config={"auto_reconnect": False})
# since we're using a self-signed certificate, we need to provide the server's certificate to verify its authenticity
await client.connect("wss://localhost:8081/", cafile=certfile)
assert client.session is not None
await client.disconnect()


@ -21,7 +21,7 @@ async def test_broker_acl():
broker_acl_script = Path(__file__).parent.parent / "samples/broker_acl.py"
process = subprocess.Popen(["python", broker_acl_script], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
# Send the interrupt signal
await asyncio.sleep(5)
await asyncio.sleep(2)
process.send_signal(signal.SIGINT)
stdout, stderr = process.communicate()
logger.debug(stderr.decode("utf-8"))
@ -34,7 +34,7 @@ async def test_broker_acl():
async def test_broker_simple():
broker_simple_script = Path(__file__).parent.parent / "samples/broker_simple.py"
process = subprocess.Popen(["python", broker_simple_script], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
await asyncio.sleep(5)
await asyncio.sleep(2)
# Send the interrupt signal
process.send_signal(signal.SIGINT)
@ -50,7 +50,7 @@ async def test_broker_simple():
async def test_broker_start():
broker_start_script = Path(__file__).parent.parent / "samples/broker_start.py"
process = subprocess.Popen(["python", broker_start_script], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
await asyncio.sleep(5)
await asyncio.sleep(2)
# Send the interrupt signal to stop broker
process.send_signal(signal.SIGINT)
@ -65,7 +65,7 @@ async def test_broker_start():
async def test_broker_taboo():
broker_taboo_script = Path(__file__).parent.parent / "samples/broker_taboo.py"
process = subprocess.Popen(["python", broker_taboo_script], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
await asyncio.sleep(5)
await asyncio.sleep(2)
# Send the interrupt signal to stop broker
process.send_signal(signal.SIGINT)
@ -86,7 +86,7 @@ async def test_client_keepalive():
keep_alive_script = Path(__file__).parent.parent / "samples/client_keepalive.py"
process = subprocess.Popen(["python", keep_alive_script], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
await asyncio.sleep(2)
await asyncio.sleep(1)
stdout, stderr = process.communicate()
assert "ERROR" not in stderr.decode("utf-8")
@ -111,28 +111,30 @@ async def test_client_publish():
await broker.shutdown()
broker_ssl_config = {
"listeners": {
"default": {
"type": "tcp",
"bind": "0.0.0.0:8883",
"ssl": True,
"certfile": "cert.pem",
"keyfile": "key.pem",
}
},
"auth": {
"allow-anonymous": True,
"plugins": ["auth_anonymous"]
@pytest.fixture
def broker_ssl_config(rsa_keys):
certfile, keyfile = rsa_keys
return {
"listeners": {
"default": {
"type": "tcp",
"bind": "0.0.0.0:8883",
"ssl": True,
"certfile": certfile,
"keyfile": keyfile,
}
}
},
"auth": {
"allow-anonymous": True,
"plugins": ["auth_anonymous"]
}
}
@pytest.mark.asyncio
async def test_client_publish_ssl():
async def test_client_publish_ssl(broker_ssl_config, rsa_keys):
certfile, _ = rsa_keys
# generate a self-signed certificate for this test
cmd = 'openssl req -x509 -nodes -days 365 -newkey rsa:2048 -keyout key.pem -out cert.pem -subj "/CN=localhost"'
subprocess.run(cmd, shell=True, capture_output=True, text=True)
# start a secure broker
broker = Broker(config=broker_ssl_config)
@ -140,7 +142,7 @@ async def test_client_publish_ssl():
await asyncio.sleep(2)
# run the sample
client_publish_ssl_script = Path(__file__).parent.parent / "samples/client_publish_ssl.py"
process = subprocess.Popen(["python", client_publish_ssl_script], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
process = subprocess.Popen(["python", client_publish_ssl_script, '--cert', certfile], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
await asyncio.sleep(2)
stdout, stderr = process.communicate()
@ -179,7 +181,7 @@ broker_ws_config = {
"auth": {
"allow-anonymous": True,
"plugins": ["auth_anonymous"]
}
}
}
@pytest.mark.asyncio
@ -205,13 +207,14 @@ broker_std_config = {
"listeners": {
"default": {
"type": "tcp",
"bind": "0.0.0.0:1883", }
"bind": "0.0.0.0:1883",
}
},
'sys_interval':2,
"auth": {
"allow-anonymous": True,
"plugins": ["auth_anonymous"]
}
}
}
@ -241,6 +244,7 @@ async def test_client_subscribe():
await broker.shutdown()
@pytest.mark.asyncio
async def test_client_subscribe_plugin_acl():
broker = Broker(config=broker_acl_config)
@ -249,7 +253,7 @@ async def test_client_subscribe_plugin_acl():
broker_simple_script = Path(__file__).parent.parent / "samples/client_subscribe_acl.py"
process = subprocess.Popen(["python", broker_simple_script], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
# Send the interrupt signal
await asyncio.sleep(5)
await asyncio.sleep(2)
process.send_signal(signal.SIGINT)
stdout, stderr = process.communicate()
logger.debug(stderr.decode("utf-8"))
@ -268,7 +272,7 @@ async def test_client_subscribe_plugin_taboo():
broker_simple_script = Path(__file__).parent.parent / "samples/client_subscribe_acl.py"
process = subprocess.Popen(["python", broker_simple_script], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
# Send the interrupt signal
await asyncio.sleep(5)
await asyncio.sleep(2)
process.send_signal(signal.SIGINT)
stdout, stderr = process.communicate()
logger.debug(stderr.decode("utf-8"))

uv.lock

@ -148,8 +148,10 @@ ci = [
]
contrib = [
{ name = "aiohttp" },
{ name = "aiosqlite" },
{ name = "argon2-cffi" },
{ name = "sqlalchemy" },
{ name = "greenlet" },
{ name = "sqlalchemy", extra = ["asyncio"] },
]
[package.dev-dependencies]
@ -171,6 +173,7 @@ dev = [
{ name = "pytest-timeout" },
{ name = "ruff" },
{ name = "setuptools" },
{ name = "sqlalchemy", extra = ["mypy"] },
{ name = "types-mock" },
{ name = "types-pyyaml" },
{ name = "types-setuptools" },
@ -197,13 +200,15 @@ docs = [
requires-dist = [
{ name = "aiohttp", specifier = ">=3.12.7" },
{ name = "aiohttp", marker = "extra == 'contrib'", specifier = ">=3.12.13" },
{ name = "aiosqlite", marker = "extra == 'contrib'", specifier = ">=0.21.0" },
{ name = "argon2-cffi", marker = "extra == 'contrib'", specifier = ">=25.1.0" },
{ name = "coveralls", marker = "extra == 'ci'", specifier = "==4.0.1" },
{ name = "dacite", specifier = ">=1.9.2" },
{ name = "greenlet", marker = "extra == 'contrib'", specifier = ">=3.2.3" },
{ name = "passlib", specifier = "==1.7.4" },
{ name = "psutil", specifier = ">=7.0.0" },
{ name = "pyyaml", specifier = "==6.0.2" },
{ name = "sqlalchemy", marker = "extra == 'contrib'", specifier = ">=2.0.41" },
{ name = "sqlalchemy", extras = ["asyncio"], marker = "extra == 'contrib'", specifier = ">=2.0.41" },
{ name = "transitions", specifier = "==0.9.2" },
{ name = "typer", specifier = "==0.15.4" },
{ name = "websockets", specifier = "==15.0.1" },
@ -229,6 +234,7 @@ dev = [
{ name = "pytest-timeout", specifier = ">=2.3.1" },
{ name = "ruff", specifier = ">=0.11.3" },
{ name = "setuptools", specifier = ">=78.1.0" },
{ name = "sqlalchemy", extras = ["mypy"], specifier = ">=2.0.41" },
{ name = "types-mock", specifier = ">=5.2.0.20250306" },
{ name = "types-pyyaml", specifier = ">=6.0.12.20250402" },
{ name = "types-setuptools", specifier = ">=78.1.0.20250329" },
@ -2415,6 +2421,14 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/1c/fc/9ba22f01b5cdacc8f5ed0d22304718d2c758fce3fd49a5372b886a86f37c/sqlalchemy-2.0.41-py3-none-any.whl", hash = "sha256:57df5dc6fdb5ed1a88a1ed2195fd31927e705cad62dedd86b46972752a80f576", size = 1911224, upload-time = "2025-05-14T17:39:42.154Z" },
]
[package.optional-dependencies]
asyncio = [
{ name = "greenlet" },
]
mypy = [
{ name = "mypy" },
]
[[package]]
name = "tomli"
version = "2.2.1"