kopia lustrzana https://github.com/Yakifo/amqtt
Plugin: rebuild of session persistence (#256)
* rebuild of the persistence plugin to handle storing / restoring sessions * adding event for broker when a message is being retained * adding retained message logic to persistence plugin and test cases * updated documentation * moving DataClassListJson field to common location for reusepull/264/head
rodzic
2a7aa11524
commit
de40ca51d3
|
@ -33,7 +33,7 @@ jobs:
|
|||
python-version: "3.13"
|
||||
|
||||
- name: 🏗 Install the project
|
||||
run: uv sync --locked --dev --extra contrib
|
||||
run: uv sync --locked --dev --all-extras
|
||||
|
||||
- name: Run mypy
|
||||
run: uv run --frozen mypy ${{ env.PROJECT_PATH }}/
|
||||
|
@ -69,7 +69,7 @@ jobs:
|
|||
python-version: ${{ matrix.python-version }}
|
||||
|
||||
- name: 🏗 Install the project
|
||||
run: uv sync --locked --dev --extra contrib
|
||||
run: uv sync --locked --dev --all-extras
|
||||
|
||||
- name: Run pytest
|
||||
run: uv run --frozen pytest tests/ --cov=./ --cov-report=xml --junitxml=pytest-report.xml
|
||||
|
|
|
@ -17,7 +17,7 @@
|
|||
- Communication over TCP and/or websocket, including support for SSL/TLS
|
||||
- Support QoS 0, QoS 1 and QoS 2 messages flow
|
||||
- Client auto-reconnection on network lost
|
||||
- Functionality expansion; plugins included: authentication and `$SYS` topic publishing
|
||||
- Custom functionality expansion; plugins included: authentication, `$SYS` topic publishing, session persistence
|
||||
|
||||
## Installation
|
||||
|
||||
|
|
|
@ -108,7 +108,7 @@ class ExternalServer(Server):
|
|||
|
||||
|
||||
class BrokerContext(BaseContext):
|
||||
"""BrokerContext is used as the context passed to plugins interacting with the broker."""
|
||||
"""Used to provide the server's context as well as public methods for accessing internal state."""
|
||||
|
||||
def __init__(self, broker: "Broker") -> None:
|
||||
super().__init__()
|
||||
|
@ -116,16 +116,21 @@ class BrokerContext(BaseContext):
|
|||
self._broker_instance = broker
|
||||
|
||||
async def broadcast_message(self, topic: str, data: bytes, qos: int | None = None) -> None:
|
||||
"""Send message to all client sessions subscribing to `topic`."""
|
||||
await self._broker_instance.internal_message_broadcast(topic, data, qos)
|
||||
|
||||
def retain_message(self, topic_name: str, data: bytes | bytearray, qos: int | None = None) -> None:
|
||||
self._broker_instance.retain_message(None, topic_name, data, qos)
|
||||
async def retain_message(self, topic_name: str, data: bytes | bytearray, qos: int | None = None) -> None:
|
||||
await self._broker_instance.retain_message(None, topic_name, data, qos)
|
||||
|
||||
@property
|
||||
def sessions(self) -> Generator[Session]:
|
||||
for session in self._broker_instance.sessions.values():
|
||||
yield session[0]
|
||||
|
||||
def get_session(self, client_id: str) -> Session | None:
|
||||
"""Return the session associated with `client_id`, if it exists."""
|
||||
return self._broker_instance.sessions.get(client_id, (None, None))[0]
|
||||
|
||||
@property
|
||||
def retained_messages(self) -> dict[str, RetainedApplicationMessage]:
|
||||
return self._broker_instance.retained_messages
|
||||
|
@ -134,6 +139,20 @@ class BrokerContext(BaseContext):
|
|||
def subscriptions(self) -> dict[str, list[tuple[Session, int]]]:
|
||||
return self._broker_instance.subscriptions
|
||||
|
||||
async def add_subscription(self, client_id: str, topic: str|None, qos: int|None) -> None:
|
||||
"""Create a topic subscription for the given `client_id`.
|
||||
|
||||
If a client session doesn't exist for `client_id`, create a disconnected session.
|
||||
If `topic` and `qos` are both `None`, only create the client session.
|
||||
"""
|
||||
if client_id not in self._broker_instance.sessions:
|
||||
broker_handler, session = self._broker_instance.create_offline_session(client_id)
|
||||
self._broker_instance._sessions[client_id] = (session, broker_handler) # noqa: SLF001
|
||||
|
||||
if topic is not None and qos is not None:
|
||||
session, _ = self._broker_instance.sessions[client_id]
|
||||
await self._broker_instance.add_subscription((topic, qos), session)
|
||||
|
||||
|
||||
class Broker:
|
||||
"""MQTT 3.1.1 compliant broker implementation.
|
||||
|
@ -483,7 +502,17 @@ class Broker:
|
|||
# Get session from cache
|
||||
elif client_session.client_id in self._sessions:
|
||||
self.logger.debug(f"Found old session {self._sessions[client_session.client_id]!r}")
|
||||
client_session, _ = self._sessions[client_session.client_id]
|
||||
|
||||
# even though the session previously existed, the new connection can bring updated configuration and credentials
|
||||
existing_client_session, _ = self._sessions[client_session.client_id]
|
||||
existing_client_session.will_flag = client_session.will_flag
|
||||
existing_client_session.will_message = client_session.will_message
|
||||
existing_client_session.will_topic = client_session.will_topic
|
||||
existing_client_session.will_qos = client_session.will_qos
|
||||
existing_client_session.keep_alive = client_session.keep_alive
|
||||
existing_client_session.username = client_session.username
|
||||
existing_client_session.password = client_session.password
|
||||
client_session = existing_client_session
|
||||
client_session.parent = 1
|
||||
else:
|
||||
client_session.parent = 0
|
||||
|
@ -495,6 +524,14 @@ class Broker:
|
|||
self.logger.debug(f"Keep-alive timeout={client_session.keep_alive}")
|
||||
return handler, client_session
|
||||
|
||||
def create_offline_session(self, client_id: str) -> tuple[BrokerProtocolHandler, Session]:
|
||||
session = Session()
|
||||
session.client_id = client_id
|
||||
|
||||
bph = BrokerProtocolHandler(self.plugins_manager, session)
|
||||
session.transitions.disconnect()
|
||||
return bph, session
|
||||
|
||||
async def _handle_client_session(
|
||||
self,
|
||||
reader: ReaderAdapter,
|
||||
|
@ -635,7 +672,7 @@ class Broker:
|
|||
client_session.will_qos,
|
||||
)
|
||||
if client_session.will_retain:
|
||||
self.retain_message(
|
||||
await self.retain_message(
|
||||
client_session,
|
||||
client_session.will_topic,
|
||||
client_session.will_message,
|
||||
|
@ -660,7 +697,7 @@ class Broker:
|
|||
"""Handle client subscription."""
|
||||
self.logger.debug(f"{client_session.client_id} handling subscription")
|
||||
subscriptions = subscribe_waiter.result()
|
||||
return_codes = [await self._add_subscription(subscription, client_session) for subscription in subscriptions.topics]
|
||||
return_codes = [await self.add_subscription(subscription, client_session) for subscription in subscriptions.topics]
|
||||
await handler.mqtt_acknowledge_subscription(subscriptions.packet_id, return_codes)
|
||||
for index, subscription in enumerate(subscriptions.topics):
|
||||
if return_codes[index] != AMQTT_MAGIC_VALUE_RET_SUBSCRIBED:
|
||||
|
@ -730,7 +767,7 @@ class Broker:
|
|||
)
|
||||
await self._broadcast_message(client_session, app_message.topic, app_message.data)
|
||||
if app_message.publish_packet and app_message.publish_packet.retain_flag:
|
||||
self.retain_message(client_session, app_message.topic, app_message.data, app_message.qos)
|
||||
await self.retain_message(client_session, app_message.topic, app_message.data, app_message.qos)
|
||||
return True
|
||||
|
||||
async def _init_handler(self, session: Session, reader: ReaderAdapter, writer: WriterAdapter) -> BrokerProtocolHandler:
|
||||
|
@ -774,7 +811,7 @@ class Broker:
|
|||
|
||||
return False
|
||||
|
||||
def retain_message(
|
||||
async def retain_message(
|
||||
self,
|
||||
source_session: Session | None,
|
||||
topic_name: str | None,
|
||||
|
@ -785,12 +822,25 @@ class Broker:
|
|||
# If retained flag set, store the message for further subscriptions
|
||||
self.logger.debug(f"Retaining message on topic {topic_name}")
|
||||
self._retained_messages[topic_name] = RetainedApplicationMessage(source_session, topic_name, data, qos)
|
||||
|
||||
await self.plugins_manager.fire_event(BrokerEvents.RETAINED_MESSAGE,
|
||||
client_id=None,
|
||||
retained_message=self._retained_messages[topic_name])
|
||||
|
||||
# [MQTT-3.3.1-10]
|
||||
elif topic_name in self._retained_messages:
|
||||
self.logger.debug(f"Clearing retained messages for topic '{topic_name}'")
|
||||
|
||||
cleared_message = self._retained_messages[topic_name]
|
||||
cleared_message.data = b""
|
||||
|
||||
await self.plugins_manager.fire_event(BrokerEvents.RETAINED_MESSAGE,
|
||||
client_id=None,
|
||||
retained_message=cleared_message)
|
||||
|
||||
del self._retained_messages[topic_name]
|
||||
|
||||
async def _add_subscription(self, subscription: tuple[str, int], session: Session) -> int:
|
||||
async def add_subscription(self, subscription: tuple[str, int], session: Session) -> int:
|
||||
topic_filter, qos = subscription
|
||||
if "#" in topic_filter and not topic_filter.endswith("#"):
|
||||
# [MQTT-4.7.1-2] Wildcard character '#' is only allowed as last character in filter
|
||||
|
@ -982,6 +1032,10 @@ class Broker:
|
|||
retained_message = RetainedApplicationMessage(broadcast["session"], broadcast["topic"], broadcast["data"], qos)
|
||||
await target_session.retained_messages.put(retained_message)
|
||||
|
||||
await self.plugins_manager.fire_event(BrokerEvents.RETAINED_MESSAGE,
|
||||
client_id=target_session.client_id,
|
||||
retained_message=retained_message)
|
||||
|
||||
if self.logger.isEnabledFor(logging.DEBUG):
|
||||
self.logger.debug(f"target_session.retained_messages={target_session.retained_messages.qsize()}")
|
||||
|
||||
|
|
|
@ -0,0 +1,267 @@
|
|||
from dataclasses import dataclass
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
from sqlalchemy import Boolean, Integer, LargeBinary, Result, String, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
|
||||
|
||||
from amqtt.broker import BrokerContext, RetainedApplicationMessage
|
||||
from amqtt.contrib import DataClassListJSON
|
||||
from amqtt.errors import PluginError
|
||||
from amqtt.plugins.base import BasePlugin
|
||||
from amqtt.session import Session
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Base(DeclarativeBase):
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class RetainedMessage:
|
||||
topic: str
|
||||
data: str
|
||||
qos: int
|
||||
|
||||
|
||||
@dataclass
|
||||
class Subscription:
|
||||
topic: str
|
||||
qos: int
|
||||
|
||||
|
||||
class StoredSession(Base):
|
||||
__tablename__ = "stored_sessions"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
client_id: Mapped[str] = mapped_column(String)
|
||||
|
||||
clean_session: Mapped[bool | None] = mapped_column(Boolean, nullable=True)
|
||||
|
||||
will_flag: Mapped[bool] = mapped_column(Boolean, default=False, server_default="false")
|
||||
|
||||
will_message: Mapped[bytes | None] = mapped_column(LargeBinary, nullable=True, default=None)
|
||||
will_qos: Mapped[int | None] = mapped_column(Integer, nullable=True, default=None)
|
||||
will_retain: Mapped[bool | None] = mapped_column(Boolean, nullable=True, default=None)
|
||||
will_topic: Mapped[str | None] = mapped_column(String, nullable=True, default=None)
|
||||
|
||||
keep_alive: Mapped[int] = mapped_column(Integer, default=0)
|
||||
retained: Mapped[list[RetainedMessage]] = mapped_column(DataClassListJSON(RetainedMessage), default=list)
|
||||
subscriptions: Mapped[list[Subscription]] = mapped_column(DataClassListJSON(Subscription), default=list)
|
||||
|
||||
|
||||
class StoredMessage(Base):
|
||||
__tablename__ = "stored_messages"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
topic: Mapped[str] = mapped_column(String)
|
||||
data: Mapped[bytes | None] = mapped_column(LargeBinary, nullable=True, default=None)
|
||||
qos: Mapped[int] = mapped_column(Integer, default=0)
|
||||
|
||||
|
||||
class SessionDBPlugin(BasePlugin[BrokerContext]):
|
||||
"""Plugin to store session information and retained topic messages in the event that the broker terminates abnormally.
|
||||
|
||||
Configuration:
|
||||
- file *(string)* path & filename to store the session db. default: `amqtt.db`
|
||||
- clear_on_shutdown *(bool)* if the broker shutdowns down normally, don't retain any information. default: `True`
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, context: BrokerContext) -> None:
|
||||
super().__init__(context)
|
||||
|
||||
# bypass the `test_plugins_correct_has_attr` until it can be updated
|
||||
if not hasattr(self.config, "file"):
|
||||
logger.warning("`Config` is missing a `file` attribute")
|
||||
return
|
||||
|
||||
self._engine = create_async_engine(f"sqlite+aiosqlite:///{self.config.file}")
|
||||
self._db_session_maker = async_sessionmaker(self._engine, expire_on_commit=False)
|
||||
|
||||
@staticmethod
|
||||
async def _get_or_create_session(db_session: AsyncSession, client_id:str) -> StoredSession:
|
||||
|
||||
stmt = select(StoredSession).filter(StoredSession.client_id == client_id)
|
||||
stored_session = await db_session.scalar(stmt)
|
||||
if stored_session is None:
|
||||
stored_session = StoredSession(client_id=client_id)
|
||||
db_session.add(stored_session)
|
||||
await db_session.flush()
|
||||
return stored_session
|
||||
|
||||
|
||||
@staticmethod
|
||||
async def _get_or_create_message(db_session: AsyncSession, topic:str) -> StoredMessage:
|
||||
|
||||
stmt = select(StoredMessage).filter(StoredMessage.topic == topic)
|
||||
stored_message = await db_session.scalar(stmt)
|
||||
if stored_message is None:
|
||||
stored_message = StoredMessage(topic=topic)
|
||||
db_session.add(stored_message)
|
||||
await db_session.flush()
|
||||
return stored_message
|
||||
|
||||
|
||||
async def on_broker_client_connected(self, client_id:str, client_session:Session) -> None:
|
||||
"""Search to see if session already exists."""
|
||||
# if client id doesn't exist, create (can ignore if session is anonymous)
|
||||
# update session information (will, clean_session, etc)
|
||||
|
||||
# don't store session information for clean or anonymous sessions
|
||||
if client_session.clean_session in (None, True) or client_session.is_anonymous:
|
||||
return
|
||||
async with self._db_session_maker() as db_session, db_session.begin():
|
||||
stored_session = await self._get_or_create_session(db_session, client_id)
|
||||
|
||||
stored_session.clean_session = client_session.clean_session
|
||||
stored_session.will_flag = client_session.will_flag
|
||||
stored_session.will_message = client_session.will_message # type: ignore[assignment]
|
||||
stored_session.will_qos = client_session.will_qos
|
||||
stored_session.will_retain = client_session.will_retain
|
||||
stored_session.will_topic = client_session.will_topic
|
||||
stored_session.keep_alive = client_session.keep_alive
|
||||
|
||||
await db_session.flush()
|
||||
|
||||
async def on_broker_client_subscribed(self, client_id: str, topic: str, qos: int) -> None:
|
||||
"""Create/update subscription if clean session = false."""
|
||||
session = self.context.get_session(client_id)
|
||||
if not session:
|
||||
logger.warning(f"'{client_id}' is subscribing but doesn't have a session")
|
||||
return
|
||||
|
||||
if session.clean_session:
|
||||
return
|
||||
|
||||
async with self._db_session_maker() as db_session, db_session.begin():
|
||||
# stored sessions shouldn't need to be created here, but we'll use the same helper...
|
||||
stored_session = await self._get_or_create_session(db_session, client_id)
|
||||
stored_session.subscriptions = [*stored_session.subscriptions, Subscription(topic, qos)]
|
||||
await db_session.flush()
|
||||
|
||||
async def on_broker_client_unsubscribed(self, client_id: str, topic: str) -> None:
|
||||
"""Remove subscription if clean session = false."""
|
||||
|
||||
async def on_broker_retained_message(self, *, client_id: str | None, retained_message: RetainedApplicationMessage) -> None:
|
||||
"""Update to retained messages.
|
||||
|
||||
if retained_message.data is None or '', the message is being cleared
|
||||
"""
|
||||
# if client_id is valid, the retained message is for a disconnected client
|
||||
if client_id is not None:
|
||||
async with self._db_session_maker() as db_session, db_session.begin():
|
||||
# stored sessions shouldn't need to be created here, but we'll use the same helper...
|
||||
stored_session = await self._get_or_create_session(db_session, client_id)
|
||||
stored_session.retained = [*stored_session.retained, RetainedMessage(retained_message.topic,
|
||||
retained_message.data.decode(),
|
||||
retained_message.qos or 0)]
|
||||
await db_session.flush()
|
||||
return
|
||||
|
||||
async with self._db_session_maker() as db_session, db_session.begin():
|
||||
# if the retained message has data, we need to store/update for the topic
|
||||
if retained_message.data:
|
||||
client_message = await self._get_or_create_message(db_session, retained_message.topic)
|
||||
client_message.data = retained_message.data # type: ignore[assignment]
|
||||
client_message.qos = retained_message.qos or 0
|
||||
await db_session.flush()
|
||||
return
|
||||
|
||||
# if there is no data, clear the stored message (if exists) for the topic
|
||||
stmt = select(StoredMessage).filter(StoredMessage.topic == retained_message.topic)
|
||||
topic_message = await db_session.scalar(stmt)
|
||||
if topic_message is not None:
|
||||
await db_session.delete(topic_message)
|
||||
await db_session.flush()
|
||||
return
|
||||
|
||||
async def on_broker_pre_start(self) -> None:
|
||||
"""Initialize the database and db connection."""
|
||||
async with self._engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
async def on_broker_post_start(self) -> None:
|
||||
"""Load subscriptions."""
|
||||
if len(self.context.subscriptions) > 0:
|
||||
msg = "SessionDBPlugin : broker shouldn't have any subscriptions yet"
|
||||
raise PluginError(msg)
|
||||
|
||||
if len(list(self.context.sessions)) > 0:
|
||||
msg = "SessionDBPlugin : broker shouldn't have any sessions yet"
|
||||
raise PluginError(msg)
|
||||
|
||||
async with self._db_session_maker() as db_session, db_session.begin():
|
||||
stmt = select(StoredSession)
|
||||
stored_sessions = await db_session.execute(stmt)
|
||||
|
||||
restored_sessions = 0
|
||||
for stored_session in stored_sessions.scalars():
|
||||
await self.context.add_subscription(stored_session.client_id, None, None)
|
||||
for subscription in stored_session.subscriptions:
|
||||
await self.context.add_subscription(stored_session.client_id,
|
||||
subscription.topic,
|
||||
subscription.qos)
|
||||
session = self.context.get_session(stored_session.client_id)
|
||||
if not session:
|
||||
continue
|
||||
session.clean_session = stored_session.clean_session
|
||||
session.will_flag = stored_session.will_flag
|
||||
session.will_message = stored_session.will_message
|
||||
session.will_qos = stored_session.will_qos
|
||||
session.will_retain = stored_session.will_retain
|
||||
session.will_topic = stored_session.will_topic
|
||||
session.keep_alive = stored_session.keep_alive
|
||||
|
||||
for message in stored_session.retained:
|
||||
retained_message = RetainedApplicationMessage(
|
||||
source_session=None,
|
||||
topic=message.topic,
|
||||
data=message.data.encode(),
|
||||
qos=message.qos
|
||||
)
|
||||
await session.retained_messages.put(retained_message)
|
||||
restored_sessions += 1
|
||||
|
||||
stmt = select(StoredMessage)
|
||||
stored_messages: Result[tuple[StoredMessage]] = await db_session.execute(stmt)
|
||||
|
||||
restored_messages = 0
|
||||
retained_messages = self.context.retained_messages
|
||||
for stored_message in stored_messages.scalars():
|
||||
retained_messages[stored_message.topic] = (RetainedApplicationMessage(
|
||||
source_session=None,
|
||||
topic=stored_message.topic,
|
||||
data=stored_message.data or b"",
|
||||
qos=stored_message.qos
|
||||
))
|
||||
restored_messages += 1
|
||||
logger.info(f"Retained messages restored: {restored_messages}")
|
||||
|
||||
|
||||
logger.info(f"Restored {restored_sessions} sessions.")
|
||||
|
||||
async def on_broker_pre_shutdown(self) -> None:
|
||||
"""Clean up the db connection."""
|
||||
await self._engine.dispose()
|
||||
|
||||
async def on_broker_post_shutdown(self) -> None:
|
||||
|
||||
if self.config.clear_on_shutdown and self.config.file.exists():
|
||||
self.config.file.unlink()
|
||||
|
||||
@dataclass
|
||||
class Config:
|
||||
"""Configuration variables."""
|
||||
|
||||
file: str | Path = "amqtt.db"
|
||||
"""path & filename to store the sqlite session db."""
|
||||
clear_on_shutdown: bool = True
|
||||
"""if the broker shutdowns down normally, don't retain any information."""
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
"""Create `Path` from string path."""
|
||||
if isinstance(self.file, str):
|
||||
self.file = Path(self.file)
|
|
@ -31,4 +31,5 @@ class BrokerEvents(Events):
|
|||
CLIENT_DISCONNECTED = "broker_client_disconnected"
|
||||
CLIENT_SUBSCRIBED = "broker_client_subscribed"
|
||||
CLIENT_UNSUBSCRIBED = "broker_client_unsubscribed"
|
||||
RETAINED_MESSAGE = "broker_retained_message"
|
||||
MESSAGE_RECEIVED = "broker_message_received"
|
||||
|
|
|
@ -8,6 +8,8 @@ import copy
|
|||
from importlib.metadata import EntryPoint, EntryPoints, entry_points
|
||||
from inspect import iscoroutinefunction
|
||||
import logging
|
||||
import sys
|
||||
import traceback
|
||||
from typing import Any, Generic, NamedTuple, Optional, TypeAlias, TypeVar, cast
|
||||
import warnings
|
||||
|
||||
|
@ -291,6 +293,15 @@ class PluginManager(Generic[C]):
|
|||
return asyncio.ensure_future(coro)
|
||||
|
||||
def _clean_fired_events(self, future: asyncio.Future[Any]) -> None:
|
||||
if self.logger.getEffectiveLevel() <= logging.DEBUG:
|
||||
try:
|
||||
future.result()
|
||||
except asyncio.CancelledError:
|
||||
self.logger.warning("fired event was cancelled")
|
||||
# display plugin fault; don't allow it to cause a broker failure
|
||||
except Exception as exc: # noqa: BLE001, pylint: disable=W0718
|
||||
traceback.print_exception(type(exc), exc, exc.__traceback__, file=sys.stderr)
|
||||
|
||||
with contextlib.suppress(KeyError, ValueError):
|
||||
self._fired_events.remove(future)
|
||||
|
||||
|
|
|
@ -1,85 +1,11 @@
|
|||
import json
|
||||
import sqlite3
|
||||
from typing import Any
|
||||
import warnings
|
||||
|
||||
from amqtt.contexts import BaseContext
|
||||
from amqtt.session import Session
|
||||
from amqtt.broker import BrokerContext
|
||||
from amqtt.plugins.base import BasePlugin
|
||||
|
||||
|
||||
class SQLitePlugin:
|
||||
def __init__(self, context: BaseContext) -> None:
|
||||
self.context: BaseContext = context
|
||||
self.conn: sqlite3.Connection | None = None
|
||||
self.cursor: sqlite3.Cursor | None = None
|
||||
self.db_file: str | None = None
|
||||
self.persistence_config: dict[str, Any]
|
||||
class SQLitePlugin(BasePlugin[BrokerContext]):
|
||||
|
||||
if (
|
||||
persistence_config := self.context.config.get("persistence") if self.context.config is not None else None
|
||||
) is not None:
|
||||
self.persistence_config = persistence_config
|
||||
self.init_db()
|
||||
else:
|
||||
self.context.logger.warning("'persistence' section not found in context configuration")
|
||||
|
||||
def init_db(self) -> None:
|
||||
self.db_file = self.persistence_config.get("file")
|
||||
if not self.db_file:
|
||||
self.context.logger.warning("'file' persistence parameter not found")
|
||||
else:
|
||||
try:
|
||||
self.conn = sqlite3.connect(self.db_file)
|
||||
self.cursor = self.conn.cursor()
|
||||
self.context.logger.info(f"Database file '{self.db_file}' opened")
|
||||
except Exception:
|
||||
self.context.logger.exception(f"Error while initializing database '{self.db_file}'")
|
||||
if self.cursor:
|
||||
self.cursor.execute(
|
||||
"CREATE TABLE IF NOT EXISTS session(client_id TEXT PRIMARY KEY, data BLOB)",
|
||||
)
|
||||
self.cursor.execute("PRAGMA table_info(session)")
|
||||
columns = {col[1] for col in self.cursor.fetchall()}
|
||||
required_columns = {"client_id", "data"}
|
||||
if not required_columns.issubset(columns):
|
||||
self.context.logger.error("Database schema for 'session' table is incompatible.")
|
||||
|
||||
async def save_session(self, session: Session) -> None:
|
||||
if self.cursor and self.conn:
|
||||
dump: str = json.dumps(session, default=str)
|
||||
try:
|
||||
self.cursor.execute(
|
||||
"INSERT OR REPLACE INTO session (client_id, data) VALUES (?, ?)",
|
||||
(session.client_id, dump),
|
||||
)
|
||||
self.conn.commit()
|
||||
except Exception:
|
||||
self.context.logger.exception(f"Failed saving session '{session}'")
|
||||
|
||||
async def find_session(self, client_id: str) -> Session | None:
|
||||
if self.cursor:
|
||||
row = self.cursor.execute(
|
||||
"SELECT data FROM session where client_id=?",
|
||||
(client_id,),
|
||||
).fetchone()
|
||||
return json.loads(row[0]) if row else None
|
||||
return None
|
||||
|
||||
async def del_session(self, client_id: str) -> None:
|
||||
if self.cursor and self.conn:
|
||||
try:
|
||||
exists = self.cursor.execute("SELECT 1 FROM session WHERE client_id=?", (client_id,)).fetchone()
|
||||
if exists:
|
||||
self.cursor.execute("DELETE FROM session where client_id=?", (client_id,))
|
||||
self.conn.commit()
|
||||
except Exception:
|
||||
self.context.logger.exception(f"Failed deleting session with client_id '{client_id}'")
|
||||
|
||||
async def on_broker_post_shutdown(self) -> None:
|
||||
if self.conn:
|
||||
try:
|
||||
self.conn.close()
|
||||
self.context.logger.info(f"Database file '{self.db_file}' closed")
|
||||
except Exception:
|
||||
self.context.logger.exception("Error closing database connection")
|
||||
finally:
|
||||
self.conn = None
|
||||
def __init__(self, context: BrokerContext) -> None:
|
||||
super().__init__(context)
|
||||
warnings.warn("SQLitePlugin is deprecated, use amqtt.contrib.persistence.SessionDBPlugin", stacklevel=1)
|
||||
|
|
|
@ -112,7 +112,7 @@ class BrokerSysPlugin(BasePlugin[BrokerContext]):
|
|||
"""Initialize statistics and start $SYS broadcasting."""
|
||||
self._stats[STAT_START_TIME] = int(datetime.now(tz=UTC).timestamp())
|
||||
version = f"aMQTT version {amqtt.__version__}"
|
||||
self.context.retain_message(DOLLAR_SYS_ROOT + "version", version.encode())
|
||||
await self.context.retain_message(DOLLAR_SYS_ROOT + "version", version.encode())
|
||||
|
||||
# Start $SYS topics management
|
||||
try:
|
||||
|
|
|
@ -106,6 +106,8 @@ All plugins are notified of events if the `BasePlugin` subclass implements one o
|
|||
- `async def on_broker_client_connected(self, *, client_id:str) -> None`
|
||||
- `async def on_broker_client_disconnected(self, *, client_id:str) -> None`
|
||||
|
||||
- `async def on_broker_retained_message(self, *, client_id: str | None, retained_message: RetainedApplicationMessage) -> None`
|
||||
|
||||
- `async def on_broker_client_subscribed(self, *, client_id: str, topic: str, qos: int) -> None`
|
||||
- `async def on_broker_client_unsubscribed(self, *, client_id: str, topic: str) -> None`
|
||||
|
||||
|
@ -115,6 +117,11 @@ All plugins are notified of events if the `BasePlugin` subclass implements one o
|
|||
- `async def on_mqtt_packet_received(self, *, packet: MQTTPacket[MQTTVariableHeader, MQTTPayload[MQTTVariableHeader], MQTTFixedHeader], session: Session | None = None) -> None`
|
||||
|
||||
|
||||
!!! note retained message event
|
||||
if the `client_id` is `None`, the message is retained for a topic
|
||||
if the `retained_message.data` is `None` or empty (`''`), the topic message is being cleared
|
||||
|
||||
|
||||
## Authentication Plugins
|
||||
|
||||
In addition to receiving any of the event callbacks, a plugin which subclasses from `BaseAuthPlugin`
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# Existing Plugins
|
||||
# Packaged Plugins
|
||||
|
||||
With the aMQTT plugins framework, one can add additional functionality without
|
||||
having to rewrite core logic in the broker or client. Plugins can be loaded and configured using
|
||||
|
@ -51,7 +51,6 @@ Authentication plugin allowing anonymous access.
|
|||
extra:
|
||||
class_style: "simple"
|
||||
|
||||
|
||||
!!! danger
|
||||
even if `allow_anonymous` is set to `false`, the plugin will still allow access if a username is provided by the client
|
||||
|
||||
|
|
|
@ -0,0 +1,12 @@
|
|||
# Session Persistence
|
||||
|
||||
`amqtt.plugins.persistence.SessionDBPlugin`
|
||||
|
||||
Plugin to store session information and retained topic messages in the event that the broker terminates abnormally.
|
||||
|
||||
::: amqtt.plugins.persistence.SessionDBPlugin.Config
|
||||
options:
|
||||
show_source: false
|
||||
heading_level: 4
|
||||
extra:
|
||||
class_style: "simple"
|
|
@ -10,7 +10,7 @@
|
|||
- Communication over TCP and/or websocket, including support for SSL/TLS
|
||||
- Support QoS 0, QoS 1 and QoS 2 messages flow
|
||||
- Client auto-reconnection on network lost
|
||||
- Functionality expansion; plugins included: authentication and `$SYS` topic publishing
|
||||
- Custom functionality expansion; plugins included: authentication, `$SYS` topic publishing, session persistence
|
||||
|
||||
## Installation
|
||||
|
||||
|
|
|
@ -35,7 +35,7 @@ dependencies = [
|
|||
"typer==0.15.4",
|
||||
"aiohttp>=3.12.7",
|
||||
"dacite>=1.9.2",
|
||||
"psutil>=7.0.0",
|
||||
"psutil>=7.0.0"
|
||||
]
|
||||
|
||||
[dependency-groups]
|
||||
|
@ -57,6 +57,7 @@ dev = [
|
|||
"pytest>=8.3.5", # https://pypi.org/project/pytest
|
||||
"ruff>=0.11.3", # https://pypi.org/project/ruff
|
||||
"setuptools>=78.1.0",
|
||||
"sqlalchemy[mypy]>=2.0.41",
|
||||
"types-mock>=5.2.0.20250306", # https://pypi.org/project/types-mock
|
||||
"types-PyYAML>=6.0.12.20250402", # https://pypi.org/project/types-PyYAML
|
||||
"types-setuptools>=78.1.0.20250329", # https://pypi.org/project/types-setuptools
|
||||
|
@ -84,7 +85,9 @@ docs = [
|
|||
[project.optional-dependencies]
|
||||
ci = ["coveralls==4.0.1"]
|
||||
contrib = [
|
||||
"sqlalchemy>=2.0.41",
|
||||
"aiosqlite>=0.21.0",
|
||||
"greenlet>=3.2.3",
|
||||
"sqlalchemy[asyncio]>=2.0.41",
|
||||
"argon2-cffi>=25.1.0",
|
||||
"aiohttp>=3.12.13",
|
||||
]
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
import argparse
|
||||
import asyncio
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
@ -23,13 +24,13 @@ config = {
|
|||
},
|
||||
"auto_reconnect": False,
|
||||
"check_hostname": False,
|
||||
"certfile": "cert.pem",
|
||||
"certfile": "",
|
||||
}
|
||||
|
||||
client = MQTTClient(config=config)
|
||||
|
||||
|
||||
async def test_coro() -> None:
|
||||
async def test_coro(certfile: str) -> None:
|
||||
config['certfile'] = certfile
|
||||
client = MQTTClient(config=config)
|
||||
|
||||
await client.connect("mqtts://localhost:8883")
|
||||
tasks = [
|
||||
|
@ -45,7 +46,11 @@ async def test_coro() -> None:
|
|||
def __main__():
|
||||
formatter = "[%(asctime)s] {%(filename)s:%(lineno)d} %(levelname)s - %(message)s"
|
||||
logging.basicConfig(level=logging.DEBUG, format=formatter)
|
||||
asyncio.run(test_coro())
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--cert', default='cert.pem', help="path & file to verify server's authenticity")
|
||||
args = parser.parse_args()
|
||||
|
||||
asyncio.run(test_coro(args.cert))
|
||||
|
||||
if __name__ == "__main__":
|
||||
__main__()
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
import logging
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
import tempfile
|
||||
from typing import Any
|
||||
|
@ -15,19 +16,46 @@ log = logging.getLogger(__name__)
|
|||
|
||||
pytest_plugins = ["pytest_logdog"]
|
||||
|
||||
test_config = {
|
||||
"listeners": {
|
||||
"default": {"type": "tcp", "bind": "127.0.0.1:1883", "max_connections": 15},
|
||||
"ws": {"type": "ws", "bind": "127.0.0.1:8080", "max_connections": 15},
|
||||
"wss": {"type": "ws", "bind": "127.0.0.1:8081", "max_connections": 15},
|
||||
},
|
||||
"sys_interval": 0,
|
||||
"auth": {
|
||||
"allow-anonymous": True,
|
||||
}
|
||||
}
|
||||
@pytest.fixture
|
||||
def rsa_keys():
|
||||
tmp_dir = tempfile.TemporaryDirectory(prefix='amqtt-test-')
|
||||
cert = Path(tmp_dir.name) / "cert.pem"
|
||||
key = Path(tmp_dir.name) / "key.pem"
|
||||
cmd = f'openssl req -x509 -nodes -days 365 -newkey rsa:2048 -keyout {key} -out {cert} -subj "/CN=localhost"'
|
||||
subprocess.run(cmd, shell=True, capture_output=True, text=True)
|
||||
yield cert, key
|
||||
tmp_dir.cleanup()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_config(rsa_keys):
|
||||
certfile, keyfile = rsa_keys
|
||||
yield {
|
||||
"listeners": {
|
||||
"default": {"type": "tcp", "bind": "127.0.0.1:1883", "max_connections": 15},
|
||||
"mqtts": {
|
||||
"type": "tcp",
|
||||
"bind": "127.0.0.1:1884",
|
||||
"max_connections": 15,
|
||||
"ssl": True,
|
||||
"certfile": certfile,
|
||||
"keyfile": keyfile
|
||||
},
|
||||
"ws": {"type": "ws", "bind": "127.0.0.1:8080", "max_connections": 15},
|
||||
"wss": {
|
||||
"type": "ws",
|
||||
"bind": "127.0.0.1:8081",
|
||||
"max_connections": 15,
|
||||
"ssl": True,
|
||||
'certfile': certfile,
|
||||
'keyfile': keyfile},
|
||||
},
|
||||
"sys_interval": 0,
|
||||
"auth": {
|
||||
"allow-anonymous": True,
|
||||
}
|
||||
}
|
||||
|
||||
test_config_acl: dict[str, int | dict[str, Any]] = {
|
||||
"listeners": {
|
||||
"default": {"type": "tcp", "bind": "127.0.0.1:1884", "max_connections": 10},
|
||||
|
@ -64,7 +92,7 @@ def mock_plugin_manager():
|
|||
|
||||
|
||||
@pytest.fixture
|
||||
async def broker_fixture():
|
||||
async def broker_fixture(test_config):
|
||||
broker = Broker(test_config, plugin_namespace="amqtt.test.plugins")
|
||||
await broker.start()
|
||||
assert broker.transitions.is_started()
|
||||
|
@ -78,7 +106,7 @@ async def broker_fixture():
|
|||
|
||||
|
||||
@pytest.fixture
|
||||
async def broker(mock_plugin_manager):
|
||||
async def broker(mock_plugin_manager, test_config):
|
||||
# just making sure the mock is in place before we start our broker
|
||||
assert mock_plugin_manager is not None
|
||||
|
||||
|
|
|
@ -0,0 +1,543 @@
|
|||
import asyncio
|
||||
import logging
|
||||
from pathlib import Path
|
||||
import sqlite3
|
||||
import pytest
|
||||
import aiosqlite
|
||||
|
||||
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
||||
from sqlalchemy import select
|
||||
|
||||
from amqtt.broker import Broker, BrokerContext, RetainedApplicationMessage
|
||||
from amqtt.client import MQTTClient
|
||||
from amqtt.mqtt.constants import QOS_1
|
||||
from amqtt.contrib.persistence import SessionDBPlugin, Subscription, StoredSession, RetainedMessage
|
||||
from amqtt.session import Session
|
||||
|
||||
formatter = "[%(asctime)s] %(name)s {%(filename)s:%(lineno)d} %(levelname)s - %(message)s"
|
||||
logging.basicConfig(level=logging.DEBUG, format=formatter)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@pytest.fixture
|
||||
async def db_file():
|
||||
db_file = Path(__file__).parent / "amqtt.db"
|
||||
if db_file.exists():
|
||||
raise NotImplementedError("existing db file found, should it be cleaned up?")
|
||||
|
||||
yield db_file
|
||||
if db_file.exists():
|
||||
db_file.unlink()
|
||||
|
||||
@pytest.fixture
|
||||
async def broker_context():
|
||||
cfg = {
|
||||
'listeners': { 'default': {'type': 'tcp', 'bind': 'localhost:1883' }},
|
||||
'plugins': {}
|
||||
}
|
||||
|
||||
context = BrokerContext(broker=Broker(config=cfg))
|
||||
yield context
|
||||
|
||||
@pytest.fixture
|
||||
async def db_session_factory(db_file):
|
||||
engine = create_async_engine(f"sqlite+aiosqlite:///{str(db_file)}")
|
||||
factory = async_sessionmaker(engine, expire_on_commit=False)
|
||||
yield factory
|
||||
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_initialize_tables(db_file, broker_context):
|
||||
|
||||
broker_context.config = SessionDBPlugin.Config(file=db_file)
|
||||
session_db_plugin = SessionDBPlugin(broker_context)
|
||||
await session_db_plugin.on_broker_pre_start()
|
||||
|
||||
assert db_file.exists()
|
||||
|
||||
conn = sqlite3.connect(str(db_file))
|
||||
cursor = conn.cursor()
|
||||
table_name = 'stored_sessions'
|
||||
cursor.execute(f"PRAGMA table_info({table_name});")
|
||||
rows = cursor.fetchall()
|
||||
column_names = [row[1] for row in rows]
|
||||
assert len(column_names) > 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_stored_session(db_file, broker_context, db_session_factory):
|
||||
|
||||
broker_context.config = SessionDBPlugin.Config(file=db_file)
|
||||
session_db_plugin = SessionDBPlugin(broker_context)
|
||||
await session_db_plugin.on_broker_pre_start()
|
||||
|
||||
async with db_session_factory() as db_session:
|
||||
async with db_session.begin():
|
||||
stored_session = await session_db_plugin._get_or_create_session(db_session, 'test_client_1')
|
||||
assert stored_session.client_id == 'test_client_1'
|
||||
|
||||
async with aiosqlite.connect(str(db_file)) as db:
|
||||
async with await db.execute("SELECT * FROM stored_sessions") as cursor:
|
||||
async for row in cursor:
|
||||
assert row[1] == 'test_client_1'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_stored_session(db_file, broker_context, db_session_factory):
|
||||
broker_context.config = SessionDBPlugin.Config(file=db_file)
|
||||
session_db_plugin = SessionDBPlugin(broker_context)
|
||||
await session_db_plugin.on_broker_pre_start()
|
||||
|
||||
async with aiosqlite.connect(str(db_file)) as db:
|
||||
sql = """INSERT INTO stored_sessions (
|
||||
client_id, clean_session, will_flag,
|
||||
will_qos, keep_alive,
|
||||
retained, subscriptions
|
||||
) VALUES (
|
||||
'test_client_1',
|
||||
1,
|
||||
0,
|
||||
1,
|
||||
60,
|
||||
'[]',
|
||||
'[{"topic":"sensors/#","qos":1}]'
|
||||
)"""
|
||||
await db.execute(sql)
|
||||
await db.commit()
|
||||
|
||||
async with db_session_factory() as db_session:
|
||||
async with db_session.begin():
|
||||
stored_session = await session_db_plugin._get_or_create_session(db_session, 'test_client_1')
|
||||
assert stored_session.subscriptions == [Subscription(topic='sensors/#', qos=1)]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_stored_session(db_file, broker_context, db_session_factory):
|
||||
broker_context.config = SessionDBPlugin.Config(file=db_file)
|
||||
|
||||
# create session for client id (without subscription)
|
||||
await broker_context.add_subscription('test_client_1', None, None)
|
||||
|
||||
session = broker_context.get_session('test_client_1')
|
||||
assert session is not None
|
||||
session.clean_session = False
|
||||
|
||||
session_db_plugin = SessionDBPlugin(broker_context)
|
||||
await session_db_plugin.on_broker_pre_start()
|
||||
|
||||
# initialize with stored client session
|
||||
async with aiosqlite.connect(str(db_file)) as db:
|
||||
sql = """INSERT INTO stored_sessions (
|
||||
client_id, clean_session, will_flag,
|
||||
will_qos, keep_alive,
|
||||
retained, subscriptions
|
||||
) VALUES (
|
||||
'test_client_1',
|
||||
1,
|
||||
0,
|
||||
1,
|
||||
60,
|
||||
'[]',
|
||||
'[{"topic":"sensors/#","qos":1}]'
|
||||
)"""
|
||||
await db.execute(sql)
|
||||
await db.commit()
|
||||
|
||||
await session_db_plugin.on_broker_client_subscribed(client_id='test_client_1', topic='my/topic', qos=2)
|
||||
|
||||
# verify that the stored session has been updated with the new subscription
|
||||
has_stored_session = False
|
||||
async with aiosqlite.connect(str(db_file)) as db:
|
||||
async with await db.execute("SELECT * FROM stored_sessions") as cursor:
|
||||
async for row in cursor:
|
||||
assert row[1] == 'test_client_1'
|
||||
assert row[-1] == '[{"topic": "sensors/#", "qos": 1}, {"topic": "my/topic", "qos": 2}]'
|
||||
has_stored_session = True
|
||||
assert has_stored_session, "stored session wasn't updated"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_client_connected_with_clean_session(db_file, broker_context, db_session_factory) -> None:
|
||||
|
||||
broker_context.config = SessionDBPlugin.Config(file=db_file)
|
||||
session_db_plugin = SessionDBPlugin(broker_context)
|
||||
await session_db_plugin.on_broker_pre_start()
|
||||
|
||||
session = Session()
|
||||
session.client_id = 'test_client_connected'
|
||||
session.is_anonymous = False
|
||||
session.clean_session = True
|
||||
|
||||
await session_db_plugin.on_broker_client_connected(client_id='test_client_connected', client_session=session)
|
||||
|
||||
async with aiosqlite.connect(str(db_file)) as db_conn:
|
||||
db_conn.row_factory = sqlite3.Row
|
||||
|
||||
async with await db_conn.execute("SELECT * FROM stored_sessions") as cursor:
|
||||
assert len(await cursor.fetchall()) == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_client_connected_anonymous_session(db_file, broker_context, db_session_factory) -> None:
|
||||
|
||||
broker_context.config = SessionDBPlugin.Config(file=db_file)
|
||||
session_db_plugin = SessionDBPlugin(broker_context)
|
||||
await session_db_plugin.on_broker_pre_start()
|
||||
|
||||
session = Session()
|
||||
session.is_anonymous = True
|
||||
session.client_id = 'test_client_connected'
|
||||
|
||||
await session_db_plugin.on_broker_client_connected(client_id='test_client_connected', client_session=session)
|
||||
|
||||
async with aiosqlite.connect(str(db_file)) as db_conn:
|
||||
db_conn.row_factory = sqlite3.Row # Set the row_factory
|
||||
async with await db_conn.execute("SELECT * FROM stored_sessions") as cursor:
|
||||
assert len(await cursor.fetchall()) == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_client_connected_and_stored_session(db_file, broker_context, db_session_factory) -> None:
|
||||
|
||||
broker_context.config = SessionDBPlugin.Config(file=db_file)
|
||||
session_db_plugin = SessionDBPlugin(broker_context)
|
||||
await session_db_plugin.on_broker_pre_start()
|
||||
|
||||
session = Session()
|
||||
session.client_id = 'test_client_connected'
|
||||
session.is_anonymous = False
|
||||
session.clean_session = False
|
||||
session.will_flag = True
|
||||
session.will_qos = 1
|
||||
session.will_topic = 'my/will/topic'
|
||||
session.will_retain = False
|
||||
session.will_message = b'test connected client has a last will (and testament) message'
|
||||
session.keep_alive = 42
|
||||
|
||||
await session_db_plugin.on_broker_client_connected(client_id='test_client_connected', client_session=session)
|
||||
|
||||
has_stored_session = False
|
||||
async with aiosqlite.connect(str(db_file)) as db_conn:
|
||||
db_conn.row_factory = sqlite3.Row # Set the row_factory
|
||||
|
||||
async with await db_conn.execute("SELECT * FROM stored_sessions") as cursor:
|
||||
for row in await cursor.fetchall():
|
||||
assert row['client_id'] == 'test_client_connected'
|
||||
assert row['clean_session'] == False
|
||||
assert row['will_flag'] == True
|
||||
assert row['will_qos'] == 1
|
||||
assert row['will_topic'] == 'my/will/topic'
|
||||
assert row['will_retain'] == False
|
||||
assert row['will_message'] == b'test connected client has a last will (and testament) message'
|
||||
assert row['keep_alive'] == 42
|
||||
has_stored_session = True
|
||||
assert has_stored_session, "client session wasn't stored"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_repopulate_stored_sessions(db_file, broker_context, db_session_factory) -> None:
|
||||
|
||||
broker_context.config = SessionDBPlugin.Config(file=db_file)
|
||||
session_db_plugin = SessionDBPlugin(broker_context)
|
||||
await session_db_plugin.on_broker_pre_start()
|
||||
|
||||
async with aiosqlite.connect(str(db_file)) as db:
|
||||
sql = """INSERT INTO stored_sessions (
|
||||
client_id, clean_session, will_flag,
|
||||
will_qos, keep_alive,
|
||||
retained, subscriptions
|
||||
) VALUES (
|
||||
'test_client_1',
|
||||
1,
|
||||
0,
|
||||
1,
|
||||
60,
|
||||
'[{"topic":"sensors/#","data":"this message is retained when client reconnects","qos":1}]',
|
||||
'[{"topic":"sensors/#","qos":1}]'
|
||||
)"""
|
||||
await db.execute(sql)
|
||||
await db.commit()
|
||||
|
||||
await session_db_plugin.on_broker_post_start()
|
||||
|
||||
session = broker_context.get_session('test_client_1')
|
||||
assert session is not None
|
||||
assert session.retained_messages.qsize() == 1
|
||||
|
||||
assert 'sensors/#' in broker_context._broker_instance._subscriptions
|
||||
|
||||
# ugly: b/c _subscriptions is a list of dictionaries of tuples
|
||||
assert broker_context._broker_instance._subscriptions['sensors/#'][0][1] == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_client_retained_message(db_file, broker_context, db_session_factory) -> None:
|
||||
|
||||
broker_context.config = SessionDBPlugin.Config(file=db_file)
|
||||
session_db_plugin = SessionDBPlugin(broker_context)
|
||||
await session_db_plugin.on_broker_pre_start()
|
||||
|
||||
# add a session to the broker
|
||||
await broker_context.add_subscription('test_retained_client', None, None)
|
||||
|
||||
# update the session so that it's retained
|
||||
session = broker_context.get_session('test_retained_client')
|
||||
assert session is not None
|
||||
session.is_anonymous = False
|
||||
session.clean_session = False
|
||||
session.transitions.disconnect()
|
||||
|
||||
retained_message = RetainedApplicationMessage(source_session=session, topic='my/retained/topic', data=b'retain message for disconnected client', qos=2)
|
||||
await session_db_plugin.on_broker_retained_message(client_id='test_retained_client', retained_message=retained_message)
|
||||
|
||||
async with db_session_factory() as db_session:
|
||||
async with db_session.begin():
|
||||
stmt = select(StoredSession).filter(StoredSession.client_id == 'test_retained_client')
|
||||
stored_session = await db_session.scalar(stmt)
|
||||
assert stored_session is not None
|
||||
assert len(stored_session.retained) > 0
|
||||
assert RetainedMessage(topic='my/retained/topic', data='retained message', qos=2)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_topic_retained_message(db_file, broker_context, db_session_factory) -> None:
|
||||
|
||||
broker_context.config = SessionDBPlugin.Config(file=db_file, clear_on_shutdown=False)
|
||||
session_db_plugin = SessionDBPlugin(broker_context)
|
||||
await session_db_plugin.on_broker_pre_start()
|
||||
|
||||
# add a session to the broker
|
||||
await broker_context.add_subscription('test_retained_client', None, None)
|
||||
|
||||
# update the session so that it's retained
|
||||
session = broker_context.get_session('test_retained_client')
|
||||
assert session is not None
|
||||
session.is_anonymous = False
|
||||
session.clean_session = False
|
||||
session.transitions.disconnect()
|
||||
|
||||
retained_message = RetainedApplicationMessage(source_session=session, topic='my/retained/topic', data=b'retained message', qos=2)
|
||||
await session_db_plugin.on_broker_retained_message(client_id=None, retained_message=retained_message)
|
||||
|
||||
has_stored_message = False
|
||||
async with aiosqlite.connect(str(db_file)) as db_conn:
|
||||
db_conn.row_factory = sqlite3.Row # Set the row_factory
|
||||
|
||||
async with await db_conn.execute("SELECT * FROM stored_messages") as cursor:
|
||||
# assert(len(await cursor.fetchall()) > 0)
|
||||
for row in await cursor.fetchall():
|
||||
assert row['topic'] == 'my/retained/topic'
|
||||
assert row['data'] == b'retained message'
|
||||
assert row['qos'] == 2
|
||||
has_stored_message = True
|
||||
|
||||
assert has_stored_message, "retained topic message wasn't stored"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_topic_clear_retained_message(db_file, broker_context, db_session_factory) -> None:
|
||||
|
||||
broker_context.config = SessionDBPlugin.Config(file=db_file)
|
||||
session_db_plugin = SessionDBPlugin(broker_context)
|
||||
await session_db_plugin.on_broker_pre_start()
|
||||
|
||||
# add a session to the broker
|
||||
await broker_context.add_subscription('test_retained_client', None, None)
|
||||
|
||||
# update the session so that it's retained
|
||||
session = broker_context.get_session('test_retained_client')
|
||||
assert session is not None
|
||||
session.is_anonymous = False
|
||||
session.clean_session = False
|
||||
session.transitions.disconnect()
|
||||
|
||||
retained_message = RetainedApplicationMessage(source_session=session, topic='my/retained/topic', data=b'', qos=0)
|
||||
await session_db_plugin.on_broker_retained_message(client_id=None, retained_message=retained_message)
|
||||
|
||||
async with aiosqlite.connect(str(db_file)) as db_conn:
|
||||
db_conn.row_factory = sqlite3.Row
|
||||
|
||||
async with await db_conn.execute("SELECT * FROM stored_messages") as cursor:
|
||||
assert(len(await cursor.fetchall()) == 0)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_restoring_retained_message(db_file, broker_context, db_session_factory) -> None:
|
||||
|
||||
broker_context.config = SessionDBPlugin.Config(file=db_file)
|
||||
session_db_plugin = SessionDBPlugin(broker_context)
|
||||
await session_db_plugin.on_broker_pre_start()
|
||||
|
||||
stmts = ("INSERT INTO stored_messages VALUES(1,'my/retained/topic1',X'72657461696e6564206d657373616765',2)",
|
||||
"INSERT INTO stored_messages VALUES(2,'my/retained/topic2',X'72657461696e6564206d65737361676532',2)",
|
||||
"INSERT INTO stored_messages VALUES(3,'my/retained/topic3',X'72657461696e6564206d65737361676533',2)")
|
||||
|
||||
async with aiosqlite.connect(str(db_file)) as db:
|
||||
for stmt in stmts:
|
||||
await db.execute(stmt)
|
||||
await db.commit()
|
||||
|
||||
await session_db_plugin.on_broker_post_start()
|
||||
|
||||
assert len(broker_context.retained_messages) == 3
|
||||
assert 'my/retained/topic1' in broker_context.retained_messages
|
||||
assert 'my/retained/topic2' in broker_context.retained_messages
|
||||
assert 'my/retained/topic3' in broker_context.retained_messages
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def db_config(db_file):
|
||||
return {
|
||||
'listeners': {
|
||||
'default': {
|
||||
'type': 'tcp',
|
||||
'bind': '127.0.0.1:1883'
|
||||
}
|
||||
},
|
||||
'plugins': {
|
||||
'amqtt.plugins.authentication.AnonymousAuthPlugin': {'allow_anonymous': False},
|
||||
'amqtt.contrib.persistence.SessionDBPlugin': {
|
||||
'file': db_file,
|
||||
'clear_on_shutdown': False
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_broker_client_no_cleanup(db_file, db_config) -> None:
|
||||
|
||||
b1 = Broker(config=db_config)
|
||||
await b1.start()
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
c1 = MQTTClient(client_id='test_client1', config={'auto_reconnect':False})
|
||||
await c1.connect("mqtt://myUsername@127.0.0.1:1883", cleansession=False)
|
||||
|
||||
# test that this message is retained for the topic upon restore
|
||||
await c1.publish("my/retained/topic", b'retained message for topic my/retained/topic', retain=True)
|
||||
await asyncio.sleep(0.2)
|
||||
await c1.disconnect()
|
||||
await b1.shutdown()
|
||||
|
||||
# new broker should load the previous broker's db file since clean_on_shutdown is false in config
|
||||
b2 = Broker(config=db_config)
|
||||
await b2.start()
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
# upon subscribing to topic with retained message, it should be received
|
||||
c2 = MQTTClient(client_id='test_client2', config={'auto_reconnect':False})
|
||||
await c2.connect("mqtt://myOtherUsername@localhost:1883", cleansession=False)
|
||||
await c2.subscribe([
|
||||
('my/retained/topic', QOS_1)
|
||||
])
|
||||
msg = await c2.deliver_message(timeout_duration=1)
|
||||
assert msg is not None
|
||||
assert msg.topic == "my/retained/topic"
|
||||
assert msg.data == b'retained message for topic my/retained/topic'
|
||||
|
||||
await c2.disconnect()
|
||||
await b2.shutdown()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_broker_client_retain_subscription(db_file, db_config) -> None:
|
||||
|
||||
b1 = Broker(config=db_config)
|
||||
await b1.start()
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
c1 = MQTTClient(client_id='test_client1', config={'auto_reconnect':False})
|
||||
await c1.connect("mqtt://myUsername@127.0.0.1:1883", cleansession=False)
|
||||
|
||||
# test to make sure the subscription is re-established upon reconnection after broker restart (clear_on_shutdown = False)
|
||||
ret = await c1.subscribe([
|
||||
('my/offline/topic', QOS_1)
|
||||
])
|
||||
assert ret == [QOS_1,]
|
||||
await asyncio.sleep(0.2)
|
||||
await c1.disconnect()
|
||||
await asyncio.sleep(0.1)
|
||||
await b1.shutdown()
|
||||
|
||||
# new broker should load the previous broker's db file
|
||||
b2 = Broker(config=db_config)
|
||||
await b2.start()
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
# client1's subscription should have been restored, so when it connects, it should receive this message
|
||||
c2 = MQTTClient(client_id='test_client2', config={'auto_reconnect':False})
|
||||
await c2.connect("mqtt://myOtherUsername@localhost:1883", cleansession=False)
|
||||
await c2.publish('my/offline/topic', b'standard message to be retained for offline clients')
|
||||
await asyncio.sleep(0.1)
|
||||
await c2.disconnect()
|
||||
|
||||
await c1.reconnect(cleansession=False)
|
||||
await asyncio.sleep(0.1)
|
||||
msg = await c1.deliver_message(timeout_duration=2)
|
||||
assert msg is not None
|
||||
assert msg.topic == "my/offline/topic"
|
||||
assert msg.data == b'standard message to be retained for offline clients'
|
||||
|
||||
await c1.disconnect()
|
||||
await b2.shutdown()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_broker_client_retain_message(db_file, db_config) -> None:
|
||||
"""test to make sure that the retained message because client1 is offline,
|
||||
gets sent when back online after broker restart."""
|
||||
|
||||
b1 = Broker(config=db_config)
|
||||
await b1.start()
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
c1 = MQTTClient(client_id='test_client1', config={'auto_reconnect':False})
|
||||
await c1.connect("mqtt://myUsername@127.0.0.1:1883", cleansession=False)
|
||||
|
||||
# subscribe to a topic with QOS_1 so that we receive messages, even if we're disconnected when sent
|
||||
ret = await c1.subscribe([
|
||||
('my/offline/topic', QOS_1)
|
||||
])
|
||||
assert ret == [QOS_1,]
|
||||
await asyncio.sleep(0.2)
|
||||
# go offline
|
||||
await c1.disconnect()
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
# another client sends a message to previously subscribed to topic
|
||||
c2 = MQTTClient(client_id='test_client2', config={'auto_reconnect':False})
|
||||
await c2.connect("mqtt://myOtherUsername@localhost:1883", cleansession=False)
|
||||
|
||||
# this message should be delivered after broker stops and restarts (and client connects)
|
||||
await c2.publish('my/offline/topic', b'standard message to be retained for offline clients')
|
||||
await asyncio.sleep(0.1)
|
||||
await c2.disconnect()
|
||||
await asyncio.sleep(0.1)
|
||||
await b1.shutdown()
|
||||
|
||||
# new broker should load the previous broker's db file since we declared clear_on_shutdown = False in config
|
||||
b2 = Broker(config=db_config)
|
||||
await b2.start()
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
# when first client reconnects, it should receive the message that had been previously retained for it
|
||||
await c1.reconnect(cleansession=False)
|
||||
await asyncio.sleep(0.1)
|
||||
msg = await c1.deliver_message(timeout_duration=2)
|
||||
assert msg is not None
|
||||
assert msg.topic == "my/offline/topic"
|
||||
assert msg.data == b'standard message to be retained for offline clients'
|
||||
|
||||
# client should also receive a message if send on this topic
|
||||
await c1.publish("my/offline/topic", b"online message should also be received")
|
||||
await asyncio.sleep(0.1)
|
||||
msg = await c1.deliver_message(timeout_duration=2)
|
||||
assert msg is not None
|
||||
assert msg.topic == "my/offline/topic"
|
||||
assert msg.data == b'online message should also be received'
|
||||
|
||||
await c1.disconnect()
|
||||
await b2.shutdown()
|
|
@ -1,56 +0,0 @@
|
|||
import asyncio
|
||||
import logging
|
||||
from pathlib import Path
|
||||
import sqlite3
|
||||
import unittest
|
||||
|
||||
from amqtt.contexts import BaseContext
|
||||
from amqtt.plugins.persistence import SQLitePlugin
|
||||
from amqtt.session import Session
|
||||
|
||||
formatter = "[%(asctime)s] %(name)s {%(filename)s:%(lineno)d} %(levelname)s - %(message)s"
|
||||
logging.basicConfig(level=logging.DEBUG, format=formatter)
|
||||
|
||||
|
||||
class TestSQLitePlugin(unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
self.loop = asyncio.new_event_loop()
|
||||
|
||||
def test_create_tables(self) -> None:
|
||||
dbfile = Path(__file__).resolve().parent / "test.db"
|
||||
|
||||
context = BaseContext()
|
||||
context.logger = logging.getLogger(__name__)
|
||||
context.config = {"persistence": {"file": str(dbfile)}} # Ensure string path for config
|
||||
SQLitePlugin(context)
|
||||
|
||||
try:
|
||||
conn = sqlite3.connect(str(dbfile)) # Convert Path to string for sqlite connection
|
||||
cursor = conn.cursor()
|
||||
rows = cursor.execute("SELECT name FROM sqlite_master WHERE type = 'table'")
|
||||
tables = [row[0] for row in rows] # List comprehension for brevity
|
||||
assert "session" in tables
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
def test_save_session(self) -> None:
|
||||
dbfile = Path(__file__).resolve().parent / "test.db"
|
||||
|
||||
context = BaseContext()
|
||||
context.logger = logging.getLogger(__name__)
|
||||
context.config = {"persistence": {"file": str(dbfile)}} # Ensure string path for config
|
||||
sql_plugin = SQLitePlugin(context)
|
||||
|
||||
s = Session()
|
||||
s.client_id = "test_save_session"
|
||||
|
||||
self.loop.run_until_complete(sql_plugin.save_session(session=s))
|
||||
|
||||
try:
|
||||
conn = sqlite3.connect(str(dbfile)) # Convert Path to string for sqlite connection
|
||||
cursor = conn.cursor()
|
||||
row = cursor.execute("SELECT client_id FROM session WHERE client_id = 'test_save_session'").fetchone()
|
||||
assert row is not None
|
||||
assert row[0] == s.client_id
|
||||
finally:
|
||||
conn.close()
|
|
@ -1,12 +1,11 @@
|
|||
import asyncio
|
||||
import inspect
|
||||
import logging
|
||||
from functools import partial
|
||||
from importlib.metadata import EntryPoint
|
||||
from logging import getLogger
|
||||
from pathlib import Path
|
||||
from types import ModuleType
|
||||
from typing import Any, Callable, Coroutine
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
@ -15,13 +14,15 @@ from amqtt.broker import Broker, BrokerContext
|
|||
from amqtt.client import MQTTClient
|
||||
from amqtt.errors import PluginInitError, PluginImportError
|
||||
from amqtt.events import MQTTEvents, BrokerEvents
|
||||
from amqtt.mqtt.constants import QOS_0
|
||||
from amqtt.mqtt.constants import QOS_0, QOS_1
|
||||
from amqtt.plugins.base import BasePlugin
|
||||
from amqtt.contexts import BaseContext
|
||||
from amqtt.contrib.persistence import RetainedMessage
|
||||
|
||||
_INVALID_METHOD: str = "invalid_foo"
|
||||
_PLUGIN: str = "Plugin"
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class _TestContext(BaseContext):
|
||||
def __init__(self) -> None:
|
||||
|
@ -159,7 +160,7 @@ async def test_all_plugin_events():
|
|||
client = MQTTClient()
|
||||
await client.connect("mqtt://127.0.0.1:1883/")
|
||||
await client.subscribe([('my/test/topic', QOS_0),])
|
||||
await client.publish('test/topic', b'my test message')
|
||||
await client.publish('test/topic', b'my test message', retain=True)
|
||||
await client.unsubscribe(['my/test/topic',])
|
||||
await client.disconnect()
|
||||
await asyncio.sleep(1)
|
||||
|
@ -170,3 +171,70 @@ async def test_all_plugin_events():
|
|||
await asyncio.sleep(1)
|
||||
|
||||
assert all(test_plugin.test_flags.values()), f'event not received: {[event for event, value in test_plugin.test_flags.items() if not value]}'
|
||||
|
||||
|
||||
class RetainedMessageEventPlugin(BasePlugin[BrokerContext]):
|
||||
"""A plugin to verify all events get sent to plugins."""
|
||||
def __init__(self, context: BaseContext) -> None:
|
||||
super().__init__(context)
|
||||
self.topic_retained_message_flag = False
|
||||
self.session_retained_message_flag = False
|
||||
self.topic_clear_retained_message_flag = False
|
||||
|
||||
async def on_broker_retained_message(self, *, client_id: str | None, retained_message: RetainedMessage) -> None:
|
||||
"""retaining message event handler."""
|
||||
if client_id:
|
||||
session = self.context.get_session(client_id)
|
||||
assert session.transitions.state != "connected"
|
||||
logger.debug("retained message event fired for offline client")
|
||||
self.session_retained_message_flag = True
|
||||
else:
|
||||
if not retained_message.data:
|
||||
logger.debug("retained message event fired for clearing a topic")
|
||||
self.topic_clear_retained_message_flag = True
|
||||
else:
|
||||
logger.debug("retained message event fired for setting a topic")
|
||||
self.topic_retained_message_flag = True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retained_message_plugin_event():
|
||||
|
||||
config = {
|
||||
"listeners": {
|
||||
"default": {"type": "tcp", "bind": "127.0.0.1:1883", "max_connections": 10},
|
||||
},
|
||||
'sys_interval': 1,
|
||||
'plugins':[{'amqtt.plugins.authentication.AnonymousAuthPlugin': {'allow_anonymous': False}},
|
||||
{'tests.plugins.test_plugins.RetainedMessageEventPlugin': {}}]
|
||||
}
|
||||
|
||||
broker = Broker(plugin_namespace='tests.mock_plugins', config=config)
|
||||
|
||||
await broker.start()
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
# make sure all expected events get triggered
|
||||
client1 = MQTTClient(config={'auto_reconnect': False})
|
||||
await client1.connect("mqtt://myUsername@127.0.0.1:1883/", cleansession=False)
|
||||
await client1.subscribe([('test/topic', QOS_1),])
|
||||
await client1.publish('test/retained', b'message should be retained for test/retained', retain=True)
|
||||
await asyncio.sleep(0.1)
|
||||
await client1.disconnect()
|
||||
|
||||
client2 = MQTTClient(config={'auto_reconnect': False})
|
||||
await client2.connect("mqtt://myOtherUsername@127.0.0.1:1883/", cleansession=True)
|
||||
await client2.publish('test/topic', b'message should be retained for myUsername since subscription was qos > 0')
|
||||
await client2.publish('test/retained', b'', retain=True) # should clear previously retained message
|
||||
await asyncio.sleep(0.1)
|
||||
await client2.disconnect()
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
# get the plugin so it doesn't get gc on shutdown
|
||||
test_plugin = broker.plugins_manager.get_plugin('RetainedMessageEventPlugin')
|
||||
await broker.shutdown()
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
assert test_plugin.topic_retained_message_flag, "message to topic wasn't retained"
|
||||
assert test_plugin.session_retained_message_flag, "message to disconnected client wasn't retained"
|
||||
assert test_plugin.topic_clear_retained_message_flag, "message to retained topic wasn't cleared"
|
||||
|
|
|
@ -15,20 +15,22 @@ logging.basicConfig(level=logging.ERROR, format=formatter)
|
|||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# @pytest.mark.asyncio
|
||||
# async def test_connect_tcp():
|
||||
# client = MQTTClient()
|
||||
# await client.connect("mqtt://broker.hivemq.com:1883/")
|
||||
# assert client.session is not None
|
||||
# await client.disconnect()
|
||||
#
|
||||
#
|
||||
# @pytest.mark.asyncio
|
||||
# async def test_connect_tcp_secure(ca_file_fixture):
|
||||
# client = MQTTClient(config={"check_hostname": False})
|
||||
# await client.connect("mqtts://broker.hivemq.com:8883/")
|
||||
# assert client.session is not None
|
||||
# await client.disconnect()
|
||||
@pytest.mark.asyncio
|
||||
async def test_connect_tcp(broker_fixture):
|
||||
client = MQTTClient()
|
||||
await client.connect("mqtt://localhost:1883/")
|
||||
assert client.session is not None
|
||||
await client.disconnect()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_connect_tcp_secure(rsa_keys, broker_fixture):
|
||||
certfile, _ = rsa_keys
|
||||
client = MQTTClient(config={"check_hostname": False, "auto_reconnect": False})
|
||||
|
||||
# since we're using a self-signed certificate, need to provide the server's certificate to verify authenticity
|
||||
await client.connect("mqtts://localhost:1884/", cafile=certfile)
|
||||
assert client.session is not None
|
||||
await client.disconnect()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
@ -62,9 +64,11 @@ async def test_reconnect_ws_retain_username_password(broker_fixture):
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_connect_ws_secure(ca_file_fixture, broker_fixture):
|
||||
client = MQTTClient()
|
||||
await client.connect("ws://127.0.0.1:8081/", cafile=ca_file_fixture)
|
||||
async def test_connect_ws_secure(rsa_keys, broker_fixture):
|
||||
certfile, _ = rsa_keys
|
||||
client = MQTTClient(config={"auto_reconnect": False})
|
||||
# since we're using a self-signed certificate, need to provide the server's certificate to verify authenticity
|
||||
await client.connect("wss://localhost:8081/", cafile=certfile)
|
||||
assert client.session is not None
|
||||
await client.disconnect()
|
||||
|
||||
|
|
|
@ -21,7 +21,7 @@ async def test_broker_acl():
|
|||
broker_acl_script = Path(__file__).parent.parent / "samples/broker_acl.py"
|
||||
process = subprocess.Popen(["python", broker_acl_script], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
|
||||
# Send the interrupt signal
|
||||
await asyncio.sleep(5)
|
||||
await asyncio.sleep(2)
|
||||
process.send_signal(signal.SIGINT)
|
||||
stdout, stderr = process.communicate()
|
||||
logger.debug(stderr.decode("utf-8"))
|
||||
|
@ -34,7 +34,7 @@ async def test_broker_acl():
|
|||
async def test_broker_simple():
|
||||
broker_simple_script = Path(__file__).parent.parent / "samples/broker_simple.py"
|
||||
process = subprocess.Popen(["python", broker_simple_script], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
|
||||
await asyncio.sleep(5)
|
||||
await asyncio.sleep(2)
|
||||
|
||||
# Send the interrupt signal
|
||||
process.send_signal(signal.SIGINT)
|
||||
|
@ -50,7 +50,7 @@ async def test_broker_simple():
|
|||
async def test_broker_start():
|
||||
broker_start_script = Path(__file__).parent.parent / "samples/broker_start.py"
|
||||
process = subprocess.Popen(["python", broker_start_script], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
|
||||
await asyncio.sleep(5)
|
||||
await asyncio.sleep(2)
|
||||
|
||||
# Send the interrupt signal to stop broker
|
||||
process.send_signal(signal.SIGINT)
|
||||
|
@ -65,7 +65,7 @@ async def test_broker_start():
|
|||
async def test_broker_taboo():
|
||||
broker_taboo_script = Path(__file__).parent.parent / "samples/broker_taboo.py"
|
||||
process = subprocess.Popen(["python", broker_taboo_script], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
|
||||
await asyncio.sleep(5)
|
||||
await asyncio.sleep(2)
|
||||
|
||||
# Send the interrupt signal to stop broker
|
||||
process.send_signal(signal.SIGINT)
|
||||
|
@ -86,7 +86,7 @@ async def test_client_keepalive():
|
|||
|
||||
keep_alive_script = Path(__file__).parent.parent / "samples/client_keepalive.py"
|
||||
process = subprocess.Popen(["python", keep_alive_script], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
|
||||
await asyncio.sleep(2)
|
||||
await asyncio.sleep(1)
|
||||
|
||||
stdout, stderr = process.communicate()
|
||||
assert "ERROR" not in stderr.decode("utf-8")
|
||||
|
@ -111,28 +111,30 @@ async def test_client_publish():
|
|||
|
||||
await broker.shutdown()
|
||||
|
||||
broker_ssl_config = {
|
||||
"listeners": {
|
||||
"default": {
|
||||
"type": "tcp",
|
||||
"bind": "0.0.0.0:8883",
|
||||
"ssl": True,
|
||||
"certfile": "cert.pem",
|
||||
"keyfile": "key.pem",
|
||||
}
|
||||
},
|
||||
"auth": {
|
||||
"allow-anonymous": True,
|
||||
"plugins": ["auth_anonymous"]
|
||||
|
||||
@pytest.fixture
|
||||
def broker_ssl_config(rsa_keys):
|
||||
certfile, keyfile = rsa_keys
|
||||
return {
|
||||
"listeners": {
|
||||
"default": {
|
||||
"type": "tcp",
|
||||
"bind": "0.0.0.0:8883",
|
||||
"ssl": True,
|
||||
"certfile": certfile,
|
||||
"keyfile": keyfile,
|
||||
}
|
||||
}
|
||||
},
|
||||
"auth": {
|
||||
"allow-anonymous": True,
|
||||
"plugins": ["auth_anonymous"]
|
||||
}
|
||||
}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_client_publish_ssl():
|
||||
|
||||
async def test_client_publish_ssl(broker_ssl_config, rsa_keys):
|
||||
certfile, _ = rsa_keys
|
||||
# generate a self-signed certificate for this test
|
||||
cmd = 'openssl req -x509 -nodes -days 365 -newkey rsa:2048 -keyout key.pem -out cert.pem -subj "/CN=localhost"'
|
||||
subprocess.run(cmd, shell=True, capture_output=True, text=True)
|
||||
|
||||
# start a secure broker
|
||||
broker = Broker(config=broker_ssl_config)
|
||||
|
@ -140,7 +142,7 @@ async def test_client_publish_ssl():
|
|||
await asyncio.sleep(2)
|
||||
# run the sample
|
||||
client_publish_ssl_script = Path(__file__).parent.parent / "samples/client_publish_ssl.py"
|
||||
process = subprocess.Popen(["python", client_publish_ssl_script], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
|
||||
process = subprocess.Popen(["python", client_publish_ssl_script, '--cert', certfile], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
|
||||
await asyncio.sleep(2)
|
||||
stdout, stderr = process.communicate()
|
||||
|
||||
|
@ -179,7 +181,7 @@ broker_ws_config = {
|
|||
"auth": {
|
||||
"allow-anonymous": True,
|
||||
"plugins": ["auth_anonymous"]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
@ -205,13 +207,14 @@ broker_std_config = {
|
|||
"listeners": {
|
||||
"default": {
|
||||
"type": "tcp",
|
||||
"bind": "0.0.0.0:1883", }
|
||||
"bind": "0.0.0.0:1883",
|
||||
}
|
||||
},
|
||||
'sys_interval':2,
|
||||
"auth": {
|
||||
"allow-anonymous": True,
|
||||
"plugins": ["auth_anonymous"]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
@ -241,6 +244,7 @@ async def test_client_subscribe():
|
|||
|
||||
await broker.shutdown()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_client_subscribe_plugin_acl():
|
||||
broker = Broker(config=broker_acl_config)
|
||||
|
@ -249,7 +253,7 @@ async def test_client_subscribe_plugin_acl():
|
|||
broker_simple_script = Path(__file__).parent.parent / "samples/client_subscribe_acl.py"
|
||||
process = subprocess.Popen(["python", broker_simple_script], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
|
||||
# Send the interrupt signal
|
||||
await asyncio.sleep(5)
|
||||
await asyncio.sleep(2)
|
||||
process.send_signal(signal.SIGINT)
|
||||
stdout, stderr = process.communicate()
|
||||
logger.debug(stderr.decode("utf-8"))
|
||||
|
@ -268,7 +272,7 @@ async def test_client_subscribe_plugin_taboo():
|
|||
broker_simple_script = Path(__file__).parent.parent / "samples/client_subscribe_acl.py"
|
||||
process = subprocess.Popen(["python", broker_simple_script], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
|
||||
# Send the interrupt signal
|
||||
await asyncio.sleep(5)
|
||||
await asyncio.sleep(2)
|
||||
process.send_signal(signal.SIGINT)
|
||||
stdout, stderr = process.communicate()
|
||||
logger.debug(stderr.decode("utf-8"))
|
||||
|
|
18
uv.lock
18
uv.lock
|
@ -148,8 +148,10 @@ ci = [
|
|||
]
|
||||
contrib = [
|
||||
{ name = "aiohttp" },
|
||||
{ name = "aiosqlite" },
|
||||
{ name = "argon2-cffi" },
|
||||
{ name = "sqlalchemy" },
|
||||
{ name = "greenlet" },
|
||||
{ name = "sqlalchemy", extra = ["asyncio"] },
|
||||
]
|
||||
|
||||
[package.dev-dependencies]
|
||||
|
@ -171,6 +173,7 @@ dev = [
|
|||
{ name = "pytest-timeout" },
|
||||
{ name = "ruff" },
|
||||
{ name = "setuptools" },
|
||||
{ name = "sqlalchemy", extra = ["mypy"] },
|
||||
{ name = "types-mock" },
|
||||
{ name = "types-pyyaml" },
|
||||
{ name = "types-setuptools" },
|
||||
|
@ -197,13 +200,15 @@ docs = [
|
|||
requires-dist = [
|
||||
{ name = "aiohttp", specifier = ">=3.12.7" },
|
||||
{ name = "aiohttp", marker = "extra == 'contrib'", specifier = ">=3.12.13" },
|
||||
{ name = "aiosqlite", marker = "extra == 'contrib'", specifier = ">=0.21.0" },
|
||||
{ name = "argon2-cffi", marker = "extra == 'contrib'", specifier = ">=25.1.0" },
|
||||
{ name = "coveralls", marker = "extra == 'ci'", specifier = "==4.0.1" },
|
||||
{ name = "dacite", specifier = ">=1.9.2" },
|
||||
{ name = "greenlet", marker = "extra == 'contrib'", specifier = ">=3.2.3" },
|
||||
{ name = "passlib", specifier = "==1.7.4" },
|
||||
{ name = "psutil", specifier = ">=7.0.0" },
|
||||
{ name = "pyyaml", specifier = "==6.0.2" },
|
||||
{ name = "sqlalchemy", marker = "extra == 'contrib'", specifier = ">=2.0.41" },
|
||||
{ name = "sqlalchemy", extras = ["asyncio"], marker = "extra == 'contrib'", specifier = ">=2.0.41" },
|
||||
{ name = "transitions", specifier = "==0.9.2" },
|
||||
{ name = "typer", specifier = "==0.15.4" },
|
||||
{ name = "websockets", specifier = "==15.0.1" },
|
||||
|
@ -229,6 +234,7 @@ dev = [
|
|||
{ name = "pytest-timeout", specifier = ">=2.3.1" },
|
||||
{ name = "ruff", specifier = ">=0.11.3" },
|
||||
{ name = "setuptools", specifier = ">=78.1.0" },
|
||||
{ name = "sqlalchemy", extras = ["mypy"], specifier = ">=2.0.41" },
|
||||
{ name = "types-mock", specifier = ">=5.2.0.20250306" },
|
||||
{ name = "types-pyyaml", specifier = ">=6.0.12.20250402" },
|
||||
{ name = "types-setuptools", specifier = ">=78.1.0.20250329" },
|
||||
|
@ -2415,6 +2421,14 @@ wheels = [
|
|||
{ url = "https://files.pythonhosted.org/packages/1c/fc/9ba22f01b5cdacc8f5ed0d22304718d2c758fce3fd49a5372b886a86f37c/sqlalchemy-2.0.41-py3-none-any.whl", hash = "sha256:57df5dc6fdb5ed1a88a1ed2195fd31927e705cad62dedd86b46972752a80f576", size = 1911224, upload-time = "2025-05-14T17:39:42.154Z" },
|
||||
]
|
||||
|
||||
[package.optional-dependencies]
|
||||
asyncio = [
|
||||
{ name = "greenlet" },
|
||||
]
|
||||
mypy = [
|
||||
{ name = "mypy" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tomli"
|
||||
version = "2.2.1"
|
||||
|
|
Ładowanie…
Reference in New Issue