additional test cases for retained topic messages

pull/256/head
Andrew Mirsky 2025-07-07 12:09:47 -04:00
rodzic dc9816a54a
commit b649ce406d
Nie znaleziono w bazie danych klucza dla tego podpisu
ID klucza GPG: A98E67635CDF2C39
3 zmienionych plików z 51 dodań i 69 usunięć

Wyświetl plik

@ -108,9 +108,6 @@ class BrokerContext(BaseContext):
async def broadcast_message(self, topic: str, data: bytes, qos: int | None = None) -> None:
await self._broker_instance.internal_message_broadcast(topic, data, qos)
async def retain_message(self, topic_name: str, data: bytes | bytearray, qos: int | None = None) -> None:
await self._broker_instance.retain_message(None, topic_name, data, qos)
@property
def sessions(self) -> Generator[Session]:
for session in self._broker_instance.sessions.values():

Wyświetl plik

@ -265,6 +265,22 @@ class SessionDBPlugin(BasePlugin[BrokerContext]):
await session.retained_messages.put(retained_message)
restored_sessions += 1
stmt = select(StoredMessage)
stored_messages = await db_session.execute(stmt)
restored_messages = 0
retained_messages = self.context.retained_messages
for stored_message in stored_messages.scalars():
retained_messages[stored_message.topic] = (RetainedApplicationMessage(
source_session=None,
topic=stored_message.topic,
data=stored_message.data,
qos=stored_message.qos
))
restored_messages += 1
logger.info(f"Retained messages restored: {restored_messages}")
logger.info(f"Restored {restored_sessions} sessions.")
async def on_broker_pre_shutdown(self) -> None:
@ -280,7 +296,7 @@ class SessionDBPlugin(BasePlugin[BrokerContext]):
class Config:
"""Configuration variables."""
file: str | Path = "amqtt.sqlite3"
file: str | Path = "amqtt.db"
retain_interval: int = 5
clear_on_shutdown: bool = True

Wyświetl plik

@ -9,6 +9,8 @@ from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
from sqlalchemy import select
from amqtt.broker import Broker, BrokerContext, RetainedApplicationMessage
from amqtt.client import MQTTClient
from amqtt.mqtt.constants import QOS_0, QOS_1
from amqtt.plugins.persistence import SessionDBPlugin, Subscription, StoredSession, RetainedMessage, \
StoredMessage
from amqtt.session import Session
@ -20,7 +22,7 @@ logger = logging.getLogger(__name__)
@pytest.fixture
async def db_file():
db_file = Path(__file__).parent / "amqtt.lite"
db_file = Path(__file__).parent / "amqtt.db"
yield db_file
if db_file.exists():
db_file.unlink()
@ -297,7 +299,7 @@ async def test_client_retained_message(db_file, broker_context, db_session_facto
@pytest.mark.asyncio
async def test_topic_retained_message(db_file, broker_context, db_session_factory) -> None:
broker_context.config = SessionDBPlugin.Config(file=db_file)
broker_context.config = SessionDBPlugin.Config(file=db_file, clear_on_shutdown=False)
session_db_plugin = SessionDBPlugin(broker_context)
await session_db_plugin.on_broker_pre_start()
@ -356,11 +358,32 @@ async def test_topic_clear_retained_message(db_file, broker_context, db_session_
assert(len(await cursor.fetchall()) == 0)
@pytest.mark.asyncio
async def test_restoring_retained_message(db_file, broker_context, db_session_factory) -> None:
broker_context.config = SessionDBPlugin.Config(file=db_file)
session_db_plugin = SessionDBPlugin(broker_context)
await session_db_plugin.on_broker_pre_start()
stmts = ("INSERT INTO stored_messages VALUES(1,'my/retained/topic1',X'72657461696e6564206d657373616765',2)",
"INSERT INTO stored_messages VALUES(2,'my/retained/topic2',X'72657461696e6564206d65737361676532',2)",
"INSERT INTO stored_messages VALUES(3,'my/retained/topic3',X'72657461696e6564206d65737361676533',2)")
async with aiosqlite.connect(str(db_file)) as db:
for stmt in stmts:
await db.execute(stmt)
await db.commit()
await session_db_plugin.on_broker_post_start()
assert len(broker_context.retained_messages) == 3
assert 'my/retained/topic1' in broker_context.retained_messages
assert 'my/retained/topic2' in broker_context.retained_messages
assert 'my/retained/topic3' in broker_context.retained_messages
# @pytest.mark.asyncio
# async def test_create_stored_session() -> None:
# async def test_full_broker_and_client() -> None:
#
# cfg = {
# 'listeners': {
@ -370,8 +393,10 @@ async def test_topic_clear_retained_message(db_file, broker_context, db_session_
# }
# },
# 'plugins': {
# 'amqtt.plugins.authentication.AnonymousAuthPlugin': {'allow-anonymous': True},
# 'amqtt.plugins.persistence.SessionDBPlugin': {}
# 'amqtt.plugins.authentication.AnonymousAuthPlugin': {'allow_anonymous': False},
# 'amqtt.plugins.persistence.SessionDBPlugin': {
# 'clean_on_shutdown': False,
# }
# }
# }
#
@ -379,64 +404,8 @@ async def test_topic_clear_retained_message(db_file, broker_context, db_session_
# await b.start()
# await asyncio.sleep(1)
#
# c = MQTTClient(client_id='test_client1', config={'auto_reconnect':False})
# await c.connect(cleansession=False)
# await c.subscribe(
# [
# ('my/topic', QOS_0)
# ]
# )
# c1 = MQTTClient(client_id='test_client1', config={'auto_reconnect':False})
# await c1.connect("mqtt://myUsername@127.0.0.1:1883", cleansession=False)
#
# await c.disconnect()
# await asyncio.sleep(2)
# await b.shutdown()
# await asyncio.sleep(1)
"""
def test_create_tables(self) -> None:
dbfile = Path(__file__).resolve().parent / "test.db"
context = BaseContext()
context.logger = logging.getLogger(__name__)
context.config = {"persistence": {"file": str(dbfile)}} # Ensure string path for config
SQLitePlugin(context)
try:
conn = sqlite3.connect(str(dbfile)) # Convert Path to string for sqlite connection
cursor = conn.cursor()
rows = cursor.execute("SELECT name FROM sqlite_master WHERE type = 'table'")
tables = [row[0] for row in rows] # List comprehension for brevity
assert "session" in tables
finally:
conn.close()
def test_save_session(self) -> None:
dbfile = Path(__file__).resolve().parent / "test.db"
context = BaseContext()
context.logger = logging.getLogger(__name__)
context.config = {"persistence": {"file": str(dbfile)}} # Ensure string path for config
sql_plugin = SQLitePlugin(context)
s = Session()
s.client_id = "test_save_session"
self.loop.run_until_complete(sql_plugin.save_session(session=s))
try:
conn = sqlite3.connect(str(dbfile)) # Convert Path to string for sqlite connection
cursor = conn.cursor()
row = cursor.execute("SELECT client_id FROM session WHERE client_id = 'test_save_session'").fetchone()
assert row is not None
assert row[0] == s.client_id
finally:
conn.close()
"""
# await c1.publish("my/topic", b'my retained message', retain=True)
# await c1.disconnect()