kopia lustrzana https://github.com/Yakifo/amqtt
additional test cases for retained topic messages
rodzic
dc9816a54a
commit
b649ce406d
|
@ -108,9 +108,6 @@ class BrokerContext(BaseContext):
|
|||
async def broadcast_message(self, topic: str, data: bytes, qos: int | None = None) -> None:
|
||||
await self._broker_instance.internal_message_broadcast(topic, data, qos)
|
||||
|
||||
async def retain_message(self, topic_name: str, data: bytes | bytearray, qos: int | None = None) -> None:
|
||||
await self._broker_instance.retain_message(None, topic_name, data, qos)
|
||||
|
||||
@property
|
||||
def sessions(self) -> Generator[Session]:
|
||||
for session in self._broker_instance.sessions.values():
|
||||
|
|
|
@ -265,6 +265,22 @@ class SessionDBPlugin(BasePlugin[BrokerContext]):
|
|||
await session.retained_messages.put(retained_message)
|
||||
restored_sessions += 1
|
||||
|
||||
stmt = select(StoredMessage)
|
||||
stored_messages = await db_session.execute(stmt)
|
||||
|
||||
restored_messages = 0
|
||||
retained_messages = self.context.retained_messages
|
||||
for stored_message in stored_messages.scalars():
|
||||
retained_messages[stored_message.topic] = (RetainedApplicationMessage(
|
||||
source_session=None,
|
||||
topic=stored_message.topic,
|
||||
data=stored_message.data,
|
||||
qos=stored_message.qos
|
||||
))
|
||||
restored_messages += 1
|
||||
logger.info(f"Retained messages restored: {restored_messages}")
|
||||
|
||||
|
||||
logger.info(f"Restored {restored_sessions} sessions.")
|
||||
|
||||
async def on_broker_pre_shutdown(self) -> None:
|
||||
|
@ -280,7 +296,7 @@ class SessionDBPlugin(BasePlugin[BrokerContext]):
|
|||
class Config:
|
||||
"""Configuration variables."""
|
||||
|
||||
file: str | Path = "amqtt.sqlite3"
|
||||
file: str | Path = "amqtt.db"
|
||||
retain_interval: int = 5
|
||||
clear_on_shutdown: bool = True
|
||||
|
||||
|
|
|
@ -9,6 +9,8 @@ from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
|||
from sqlalchemy import select
|
||||
|
||||
from amqtt.broker import Broker, BrokerContext, RetainedApplicationMessage
|
||||
from amqtt.client import MQTTClient
|
||||
from amqtt.mqtt.constants import QOS_0, QOS_1
|
||||
from amqtt.plugins.persistence import SessionDBPlugin, Subscription, StoredSession, RetainedMessage, \
|
||||
StoredMessage
|
||||
from amqtt.session import Session
|
||||
|
@ -20,7 +22,7 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
@pytest.fixture
|
||||
async def db_file():
|
||||
db_file = Path(__file__).parent / "amqtt.lite"
|
||||
db_file = Path(__file__).parent / "amqtt.db"
|
||||
yield db_file
|
||||
if db_file.exists():
|
||||
db_file.unlink()
|
||||
|
@ -297,7 +299,7 @@ async def test_client_retained_message(db_file, broker_context, db_session_facto
|
|||
@pytest.mark.asyncio
|
||||
async def test_topic_retained_message(db_file, broker_context, db_session_factory) -> None:
|
||||
|
||||
broker_context.config = SessionDBPlugin.Config(file=db_file)
|
||||
broker_context.config = SessionDBPlugin.Config(file=db_file, clear_on_shutdown=False)
|
||||
session_db_plugin = SessionDBPlugin(broker_context)
|
||||
await session_db_plugin.on_broker_pre_start()
|
||||
|
||||
|
@ -356,11 +358,32 @@ async def test_topic_clear_retained_message(db_file, broker_context, db_session_
|
|||
assert(len(await cursor.fetchall()) == 0)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_restoring_retained_message(db_file, broker_context, db_session_factory) -> None:
|
||||
|
||||
broker_context.config = SessionDBPlugin.Config(file=db_file)
|
||||
session_db_plugin = SessionDBPlugin(broker_context)
|
||||
await session_db_plugin.on_broker_pre_start()
|
||||
|
||||
stmts = ("INSERT INTO stored_messages VALUES(1,'my/retained/topic1',X'72657461696e6564206d657373616765',2)",
|
||||
"INSERT INTO stored_messages VALUES(2,'my/retained/topic2',X'72657461696e6564206d65737361676532',2)",
|
||||
"INSERT INTO stored_messages VALUES(3,'my/retained/topic3',X'72657461696e6564206d65737361676533',2)")
|
||||
|
||||
async with aiosqlite.connect(str(db_file)) as db:
|
||||
for stmt in stmts:
|
||||
await db.execute(stmt)
|
||||
await db.commit()
|
||||
|
||||
await session_db_plugin.on_broker_post_start()
|
||||
|
||||
assert len(broker_context.retained_messages) == 3
|
||||
assert 'my/retained/topic1' in broker_context.retained_messages
|
||||
assert 'my/retained/topic2' in broker_context.retained_messages
|
||||
assert 'my/retained/topic3' in broker_context.retained_messages
|
||||
|
||||
|
||||
# @pytest.mark.asyncio
|
||||
# async def test_create_stored_session() -> None:
|
||||
# async def test_full_broker_and_client() -> None:
|
||||
#
|
||||
# cfg = {
|
||||
# 'listeners': {
|
||||
|
@ -370,8 +393,10 @@ async def test_topic_clear_retained_message(db_file, broker_context, db_session_
|
|||
# }
|
||||
# },
|
||||
# 'plugins': {
|
||||
# 'amqtt.plugins.authentication.AnonymousAuthPlugin': {'allow-anonymous': True},
|
||||
# 'amqtt.plugins.persistence.SessionDBPlugin': {}
|
||||
# 'amqtt.plugins.authentication.AnonymousAuthPlugin': {'allow_anonymous': False},
|
||||
# 'amqtt.plugins.persistence.SessionDBPlugin': {
|
||||
# 'clean_on_shutdown': False,
|
||||
# }
|
||||
# }
|
||||
# }
|
||||
#
|
||||
|
@ -379,64 +404,8 @@ async def test_topic_clear_retained_message(db_file, broker_context, db_session_
|
|||
# await b.start()
|
||||
# await asyncio.sleep(1)
|
||||
#
|
||||
# c = MQTTClient(client_id='test_client1', config={'auto_reconnect':False})
|
||||
# await c.connect(cleansession=False)
|
||||
# await c.subscribe(
|
||||
# [
|
||||
# ('my/topic', QOS_0)
|
||||
# ]
|
||||
# )
|
||||
# c1 = MQTTClient(client_id='test_client1', config={'auto_reconnect':False})
|
||||
# await c1.connect("mqtt://myUsername@127.0.0.1:1883", cleansession=False)
|
||||
#
|
||||
# await c.disconnect()
|
||||
# await asyncio.sleep(2)
|
||||
# await b.shutdown()
|
||||
# await asyncio.sleep(1)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
"""
|
||||
|
||||
def test_create_tables(self) -> None:
|
||||
dbfile = Path(__file__).resolve().parent / "test.db"
|
||||
|
||||
context = BaseContext()
|
||||
context.logger = logging.getLogger(__name__)
|
||||
context.config = {"persistence": {"file": str(dbfile)}} # Ensure string path for config
|
||||
SQLitePlugin(context)
|
||||
|
||||
try:
|
||||
conn = sqlite3.connect(str(dbfile)) # Convert Path to string for sqlite connection
|
||||
cursor = conn.cursor()
|
||||
rows = cursor.execute("SELECT name FROM sqlite_master WHERE type = 'table'")
|
||||
tables = [row[0] for row in rows] # List comprehension for brevity
|
||||
assert "session" in tables
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
def test_save_session(self) -> None:
|
||||
dbfile = Path(__file__).resolve().parent / "test.db"
|
||||
|
||||
context = BaseContext()
|
||||
context.logger = logging.getLogger(__name__)
|
||||
context.config = {"persistence": {"file": str(dbfile)}} # Ensure string path for config
|
||||
sql_plugin = SQLitePlugin(context)
|
||||
|
||||
s = Session()
|
||||
s.client_id = "test_save_session"
|
||||
|
||||
self.loop.run_until_complete(sql_plugin.save_session(session=s))
|
||||
|
||||
try:
|
||||
conn = sqlite3.connect(str(dbfile)) # Convert Path to string for sqlite connection
|
||||
cursor = conn.cursor()
|
||||
row = cursor.execute("SELECT client_id FROM session WHERE client_id = 'test_save_session'").fetchone()
|
||||
assert row is not None
|
||||
assert row[0] == s.client_id
|
||||
finally:
|
||||
conn.close()
|
||||
"""
|
||||
# await c1.publish("my/topic", b'my retained message', retain=True)
|
||||
# await c1.disconnect()
|
||||
|
|
Ładowanie…
Reference in New Issue