From b649ce406d76116b1456bdece702e32b641a26d4 Mon Sep 17 00:00:00 2001 From: Andrew Mirsky Date: Mon, 7 Jul 2025 12:09:47 -0400 Subject: [PATCH] additional test cases for retained topic messages --- amqtt/broker.py | 3 - amqtt/plugins/persistence.py | 18 +++++- tests/plugins/test_persistence.py | 99 +++++++++++-------------------- 3 files changed, 51 insertions(+), 69 deletions(-) diff --git a/amqtt/broker.py b/amqtt/broker.py index 1435f4b..80245ab 100644 --- a/amqtt/broker.py +++ b/amqtt/broker.py @@ -108,9 +108,6 @@ class BrokerContext(BaseContext): async def broadcast_message(self, topic: str, data: bytes, qos: int | None = None) -> None: await self._broker_instance.internal_message_broadcast(topic, data, qos) - async def retain_message(self, topic_name: str, data: bytes | bytearray, qos: int | None = None) -> None: - await self._broker_instance.retain_message(None, topic_name, data, qos) - @property def sessions(self) -> Generator[Session]: for session in self._broker_instance.sessions.values(): diff --git a/amqtt/plugins/persistence.py b/amqtt/plugins/persistence.py index 8fdb655..10ace15 100644 --- a/amqtt/plugins/persistence.py +++ b/amqtt/plugins/persistence.py @@ -265,6 +265,22 @@ class SessionDBPlugin(BasePlugin[BrokerContext]): await session.retained_messages.put(retained_message) restored_sessions += 1 + stmt = select(StoredMessage) + stored_messages = await db_session.execute(stmt) + + restored_messages = 0 + retained_messages = self.context.retained_messages + for stored_message in stored_messages.scalars(): + retained_messages[stored_message.topic] = (RetainedApplicationMessage( + source_session=None, + topic=stored_message.topic, + data=stored_message.data, + qos=stored_message.qos + )) + restored_messages += 1 + logger.info(f"Retained messages restored: {restored_messages}") + + logger.info(f"Restored {restored_sessions} sessions.") async def on_broker_pre_shutdown(self) -> None: @@ -280,7 +296,7 @@ class SessionDBPlugin(BasePlugin[BrokerContext]): class Config: """Configuration variables.""" - file: str | Path = "amqtt.sqlite3" + file: str | Path = "amqtt.db" retain_interval: int = 5 clear_on_shutdown: bool = True diff --git a/tests/plugins/test_persistence.py b/tests/plugins/test_persistence.py index edebc25..230661f 100644 --- a/tests/plugins/test_persistence.py +++ b/tests/plugins/test_persistence.py @@ -9,6 +9,8 @@ from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine from sqlalchemy import select from amqtt.broker import Broker, BrokerContext, RetainedApplicationMessage +from amqtt.client import MQTTClient +from amqtt.mqtt.constants import QOS_0, QOS_1 from amqtt.plugins.persistence import SessionDBPlugin, Subscription, StoredSession, RetainedMessage, \ StoredMessage from amqtt.session import Session @@ -20,7 +22,7 @@ logger = logging.getLogger(__name__) @pytest.fixture async def db_file(): - db_file = Path(__file__).parent / "amqtt.lite" + db_file = Path(__file__).parent / "amqtt.db" yield db_file if db_file.exists(): db_file.unlink() @@ -297,7 +299,7 @@ async def test_client_retained_message(db_file, broker_context, db_session_facto @pytest.mark.asyncio async def test_topic_retained_message(db_file, broker_context, db_session_factory) -> None: - broker_context.config = SessionDBPlugin.Config(file=db_file) + broker_context.config = SessionDBPlugin.Config(file=db_file, clear_on_shutdown=False) session_db_plugin = SessionDBPlugin(broker_context) await session_db_plugin.on_broker_pre_start() @@ -356,11 +358,32 @@ async def test_topic_clear_retained_message(db_file, broker_context, db_session_ assert(len(await cursor.fetchall()) == 0) +@pytest.mark.asyncio +async def test_restoring_retained_message(db_file, broker_context, db_session_factory) -> None: + broker_context.config = SessionDBPlugin.Config(file=db_file) + session_db_plugin = SessionDBPlugin(broker_context) + await session_db_plugin.on_broker_pre_start() + + stmts = ("INSERT INTO stored_messages VALUES(1,'my/retained/topic1',X'72657461696e6564206d657373616765',2)", + "INSERT INTO stored_messages VALUES(2,'my/retained/topic2',X'72657461696e6564206d65737361676532',2)", + "INSERT INTO stored_messages VALUES(3,'my/retained/topic3',X'72657461696e6564206d65737361676533',2)") + + async with aiosqlite.connect(str(db_file)) as db: + for stmt in stmts: + await db.execute(stmt) + await db.commit() + + await session_db_plugin.on_broker_post_start() + + assert len(broker_context.retained_messages) == 3 + assert 'my/retained/topic1' in broker_context.retained_messages + assert 'my/retained/topic2' in broker_context.retained_messages + assert 'my/retained/topic3' in broker_context.retained_messages # @pytest.mark.asyncio -# async def test_create_stored_session() -> None: +# async def test_full_broker_and_client() -> None: # # cfg = { # 'listeners': { @@ -370,8 +393,10 @@ async def test_topic_clear_retained_message(db_file, broker_context, db_session_ # } # }, # 'plugins': { -# 'amqtt.plugins.authentication.AnonymousAuthPlugin': {'allow-anonymous': True}, -# 'amqtt.plugins.persistence.SessionDBPlugin': {} +# 'amqtt.plugins.authentication.AnonymousAuthPlugin': {'allow_anonymous': False}, +# 'amqtt.plugins.persistence.SessionDBPlugin': { +# 'clean_on_shutdown': False, +# } # } # } # @@ -379,64 +404,8 @@ async def test_topic_clear_retained_message(db_file, broker_context, db_session_ # await b.start() # await asyncio.sleep(1) # -# c = MQTTClient(client_id='test_client1', config={'auto_reconnect':False}) -# await c.connect(cleansession=False) -# await c.subscribe( -# [ -# ('my/topic', QOS_0) -# ] -# ) +# c1 = MQTTClient(client_id='test_client1', config={'auto_reconnect':False}) +# await c1.connect("mqtt://myUsername@127.0.0.1:1883", cleansession=False) # -# await c.disconnect() -# await asyncio.sleep(2) -# await b.shutdown() -# await asyncio.sleep(1) - - - - - - - - -""" - -def test_create_tables(self) -> None: - dbfile = Path(__file__).resolve().parent / "test.db" - - context = BaseContext() - context.logger = logging.getLogger(__name__) - context.config = {"persistence": {"file": str(dbfile)}} # Ensure string path for config - SQLitePlugin(context) - - try: - conn = sqlite3.connect(str(dbfile)) # Convert Path to string for sqlite connection - cursor = conn.cursor() - rows = cursor.execute("SELECT name FROM sqlite_master WHERE type = 'table'") - tables = [row[0] for row in rows] # List comprehension for brevity - assert "session" in tables - finally: - conn.close() - -def test_save_session(self) -> None: - dbfile = Path(__file__).resolve().parent / "test.db" - - context = BaseContext() - context.logger = logging.getLogger(__name__) - context.config = {"persistence": {"file": str(dbfile)}} # Ensure string path for config - sql_plugin = SQLitePlugin(context) - - s = Session() - s.client_id = "test_save_session" - - self.loop.run_until_complete(sql_plugin.save_session(session=s)) - - try: - conn = sqlite3.connect(str(dbfile)) # Convert Path to string for sqlite connection - cursor = conn.cursor() - row = cursor.execute("SELECT client_id FROM session WHERE client_id = 'test_save_session'").fetchone() - assert row is not None - assert row[0] == s.client_id - finally: - conn.close() -""" \ No newline at end of file +# await c1.publish("my/topic", b'my retained message', retain=True) +# await c1.disconnect()