import asyncio
import logging
from pathlib import Path
import sqlite3

import pytest
import aiosqlite
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
from sqlalchemy import select

from amqtt.broker import Broker, BrokerContext, RetainedApplicationMessage
from amqtt.client import MQTTClient
from amqtt.mqtt.constants import QOS_1
from amqtt.contrib.persistence import SessionDBPlugin, Subscription, StoredSession, RetainedMessage
from amqtt.session import Session

formatter = "[%(asctime)s] %(name)s {%(filename)s:%(lineno)d} %(levelname)s - %(message)s"
logging.basicConfig(level=logging.DEBUG, format=formatter)

logger = logging.getLogger(__name__)


@pytest.fixture
async def db_file():
    """Yield the path of the sqlite file used by a test and delete it afterwards.

    Raises:
        NotImplementedError: if a database file is already present before the test starts.
    """
    path = Path(__file__).parent / "amqtt.db"
    if path.exists():
        raise NotImplementedError("existing db file found, should it be cleaned up?")

    yield path

    if path.exists():
        path.unlink()


@pytest.fixture
async def broker_context():
    """Yield a `BrokerContext` wrapping a broker that is configured but never started."""
    cfg = {
        'listeners': {
            'default': {'type': 'tcp', 'bind': 'localhost:1883'},
        },
        'plugins': {},
    }

    context = BrokerContext(broker=Broker(config=cfg))
    yield context


@pytest.fixture
async def db_session_factory(db_file):
    """Yield a SQLAlchemy async session factory bound to the test database file."""
    engine = create_async_engine(f"sqlite+aiosqlite:///{db_file}")
    factory = async_sessionmaker(engine, expire_on_commit=False)

    yield factory

    # release pooled connections before the `db_file` fixture removes the file
    await engine.dispose()


@pytest.mark.asyncio
async def test_initialize_tables(db_file, broker_context):
    """`on_broker_pre_start` creates the db file and a `stored_sessions` table with columns."""
    broker_context.config = SessionDBPlugin.Config(file=db_file)
    session_db_plugin = SessionDBPlugin(broker_context)
    await session_db_plugin.on_broker_pre_start()

    assert db_file.exists()

    conn = sqlite3.connect(str(db_file))
    try:
        cursor = conn.cursor()
        table_name = 'stored_sessions'
        cursor.execute(f"PRAGMA table_info({table_name});")
        rows = cursor.fetchall()
    finally:
        # sqlite3's own context manager only manages transactions; close explicitly
        conn.close()

    # PRAGMA table_info rows are (cid, name, type, ...); index 1 is the column name
    column_names = [row[1] for row in rows]
    assert len(column_names) > 1


@pytest.mark.asyncio
async def test_create_stored_session(db_file, broker_context, db_session_factory):
    """`_get_or_create_session` returns a session carrying the requested client id."""
    broker_context.config = SessionDBPlugin.Config(file=db_file)
    session_db_plugin = SessionDBPlugin(broker_context)
    await session_db_plugin.on_broker_pre_start()

    async with db_session_factory() as db_session:
        async with db_session.begin():
            stored_session = await session_db_plugin._get_or_create_session(db_session, 'test_client_1')
            assert stored_session.client_id == 'test_client_1'

    async with aiosqlite.connect(str(db_file)) as db:
        async with await db.execute("SELECT * FROM stored_sessions") as cursor:
            async for row in cursor:
                assert row[1] == 'test_client_1'


@pytest.mark.asyncio
async def test_get_stored_session(db_file, broker_context, db_session_factory):
    """A row inserted directly into the db is returned with its subscriptions deserialized."""
    broker_context.config = SessionDBPlugin.Config(file=db_file)
    session_db_plugin = SessionDBPlugin(broker_context)
    await session_db_plugin.on_broker_pre_start()

    async with aiosqlite.connect(str(db_file)) as db:
        sql = """INSERT INTO stored_sessions (
            client_id, clean_session, will_flag,
            will_qos, keep_alive,
            retained, subscriptions
        ) VALUES (
            'test_client_1',
            1,
            0,
            1,
            60,
            '[]',
            '[{"topic":"sensors/#","qos":1}]'
        )"""
        await db.execute(sql)
        await db.commit()

    async with db_session_factory() as db_session:
        async with db_session.begin():
            stored_session = await session_db_plugin._get_or_create_session(db_session, 'test_client_1')
            assert stored_session.subscriptions == [Subscription(topic='sensors/#', qos=1)]


@pytest.mark.asyncio
async def test_update_stored_session(db_file, broker_context, db_session_factory):
    """`on_broker_client_subscribed` appends the new subscription to the stored session row."""
    broker_context.config = SessionDBPlugin.Config(file=db_file)

    # create session for client id (without subscription)
    await broker_context.add_subscription('test_client_1', None, None)
    session = broker_context.get_session('test_client_1')
    assert session is not None
    session.clean_session = False

    session_db_plugin = SessionDBPlugin(broker_context)
    await session_db_plugin.on_broker_pre_start()

    # initialize with stored client session
    async with aiosqlite.connect(str(db_file)) as db:
        sql = """INSERT INTO stored_sessions (
            client_id, clean_session, will_flag,
            will_qos, keep_alive,
            retained, subscriptions
        ) VALUES (
            'test_client_1',
            1,
            0,
            1,
            60,
            '[]',
            '[{"topic":"sensors/#","qos":1}]'
        )"""
        await db.execute(sql)
        await db.commit()

    await session_db_plugin.on_broker_client_subscribed(client_id='test_client_1', topic='my/topic', qos=2)

    # verify that the stored session has been updated with the new subscription
    has_stored_session = False
    async with aiosqlite.connect(str(db_file)) as db:
        async with await db.execute("SELECT * FROM stored_sessions") as cursor:
            async for row in cursor:
                assert row[1] == 'test_client_1'
                assert row[-1] == '[{"topic": "sensors/#", "qos": 1}, {"topic": "my/topic", "qos": 2}]'
                has_stored_session = True
    assert has_stored_session, "stored session wasn't updated"


@pytest.mark.asyncio
async def test_client_connected_with_clean_session(db_file, broker_context, db_session_factory) -> None:
    """A client connecting with `clean_session=True` leaves no row in `stored_sessions`."""
    broker_context.config = SessionDBPlugin.Config(file=db_file)
    session_db_plugin = SessionDBPlugin(broker_context)
    await session_db_plugin.on_broker_pre_start()

    session = Session()
    session.client_id = 'test_client_connected'
    session.is_anonymous = False
    session.clean_session = True

    await session_db_plugin.on_broker_client_connected(client_id='test_client_connected', client_session=session)

    async with aiosqlite.connect(str(db_file)) as db_conn:
        db_conn.row_factory = sqlite3.Row

        async with await db_conn.execute("SELECT * FROM stored_sessions") as cursor:
            assert len(await cursor.fetchall()) == 0


@pytest.mark.asyncio
async def test_client_connected_anonymous_session(db_file, broker_context, db_session_factory) -> None:
    """An anonymous client connecting leaves no row in `stored_sessions`."""
    broker_context.config = SessionDBPlugin.Config(file=db_file)
    session_db_plugin = SessionDBPlugin(broker_context)
    await session_db_plugin.on_broker_pre_start()

    session = Session()
    session.is_anonymous = True
    session.client_id = 'test_client_connected'

    await session_db_plugin.on_broker_client_connected(client_id='test_client_connected', client_session=session)

    async with aiosqlite.connect(str(db_file)) as db_conn:
        db_conn.row_factory = sqlite3.Row  # rows become addressable by column name

        async with await db_conn.execute("SELECT * FROM stored_sessions") as cursor:
            assert len(await cursor.fetchall()) == 0


@pytest.mark.asyncio
async def test_client_connected_and_stored_session(db_file, broker_context, db_session_factory) -> None:
    """A non-anonymous, non-clean session is written to `stored_sessions` with its will and keep-alive."""
    broker_context.config = SessionDBPlugin.Config(file=db_file)
    session_db_plugin = SessionDBPlugin(broker_context)
    await session_db_plugin.on_broker_pre_start()

    session = Session()
    session.client_id = 'test_client_connected'
    session.is_anonymous = False
    session.clean_session = False
    session.will_flag = True
    session.will_qos = 1
    session.will_topic = 'my/will/topic'
    session.will_retain = False
    session.will_message = b'test connected client has a last will (and testament) message'
    session.keep_alive = 42

    await session_db_plugin.on_broker_client_connected(client_id='test_client_connected', client_session=session)

    has_stored_session = False
    async with aiosqlite.connect(str(db_file)) as db_conn:
        db_conn.row_factory = sqlite3.Row  # rows become addressable by column name

        async with await db_conn.execute("SELECT * FROM stored_sessions") as cursor:
            for row in await cursor.fetchall():
                assert row['client_id'] == 'test_client_connected'
                # flag columns are compared by equality (not identity) on purpose:
                # presumably the driver hands back 0/1 rather than bool -- TODO confirm
                assert row['clean_session'] == False  # noqa: E712
                assert row['will_flag'] == True  # noqa: E712
                assert row['will_qos'] == 1
                assert row['will_topic'] == 'my/will/topic'
                assert row['will_retain'] == False  # noqa: E712
                assert row['will_message'] == b'test connected client has a last will (and testament) message'
                assert row['keep_alive'] == 42
                has_stored_session = True
    assert has_stored_session, "client session wasn't stored"


@pytest.mark.asyncio
async def test_repopulate_stored_sessions(db_file, broker_context, db_session_factory) -> None:
    """`on_broker_post_start` restores a stored session, its retained message and its subscription."""
    broker_context.config = SessionDBPlugin.Config(file=db_file)
    session_db_plugin = SessionDBPlugin(broker_context)
    await session_db_plugin.on_broker_pre_start()

    async with aiosqlite.connect(str(db_file)) as db:
        sql = """INSERT INTO stored_sessions (
            client_id, clean_session, will_flag,
            will_qos, keep_alive,
            retained, subscriptions
        ) VALUES (
            'test_client_1',
            1,
            0,
            1,
            60,
            '[{"topic":"sensors/#","data":"this message is retained when client reconnects","qos":1}]',
            '[{"topic":"sensors/#","qos":1}]'
        )"""
        await db.execute(sql)
        await db.commit()

    await session_db_plugin.on_broker_post_start()

    session = broker_context.get_session('test_client_1')
    assert session is not None
    assert session.retained_messages.qsize() == 1

    assert 'sensors/#' in broker_context._broker_instance._subscriptions

    # ugly: b/c _subscriptions is a list of dictionaries of tuples
    assert broker_context._broker_instance._subscriptions['sensors/#'][0][1] == 1


@pytest.mark.asyncio
async def test_client_retained_message(db_file, broker_context, db_session_factory) -> None:
    """A message retained for a disconnected client is recorded on that client's stored session."""
    broker_context.config = SessionDBPlugin.Config(file=db_file)
    session_db_plugin = SessionDBPlugin(broker_context)
    await session_db_plugin.on_broker_pre_start()

    # add a session to the broker
    await broker_context.add_subscription('test_retained_client', None, None)

    # update the session so that it's retained
    session = broker_context.get_session('test_retained_client')
    assert session is not None
    session.is_anonymous = False
    session.clean_session = False
    session.transitions.disconnect()

    retained_message = RetainedApplicationMessage(source_session=session,
                                                  topic='my/retained/topic',
                                                  data=b'retain message for disconnected client',
                                                  qos=2)
    await session_db_plugin.on_broker_retained_message(client_id='test_retained_client',
                                                       retained_message=retained_message)

    async with db_session_factory() as db_session:
        async with db_session.begin():
            stmt = select(StoredSession).filter(StoredSession.client_id == 'test_retained_client')
            stored_session = await db_session.scalar(stmt)
            assert stored_session is not None
            assert len(stored_session.retained) > 0
            # Check the stored entries themselves; asserting a freshly constructed
            # RetainedMessage is always truthy and verifies nothing.
            # NOTE(review): how the payload is encoded in storage is not visible from
            # this file, so only topic and qos are compared -- confirm and add `data`.
            assert any(message.topic == 'my/retained/topic' and message.qos == 2
                       for message in stored_session.retained), \
                "retained message for the client wasn't stored"


@pytest.mark.asyncio
async def test_topic_retained_message(db_file, broker_context, db_session_factory) -> None:
    """A retained message reported with `client_id=None` is written to `stored_messages`."""
    broker_context.config = SessionDBPlugin.Config(file=db_file, clear_on_shutdown=False)
    session_db_plugin = SessionDBPlugin(broker_context)
    await session_db_plugin.on_broker_pre_start()

    # add a session to the broker
    await broker_context.add_subscription('test_retained_client', None, None)

    # update the session so that it's retained
    session = broker_context.get_session('test_retained_client')
    assert session is not None
    session.is_anonymous = False
    session.clean_session = False
    session.transitions.disconnect()

    retained_message = RetainedApplicationMessage(source_session=session,
                                                  topic='my/retained/topic',
                                                  data=b'retained message',
                                                  qos=2)
    await session_db_plugin.on_broker_retained_message(client_id=None, retained_message=retained_message)

    has_stored_message = False
    async with aiosqlite.connect(str(db_file)) as db_conn:
        db_conn.row_factory = sqlite3.Row  # rows become addressable by column name

        async with await db_conn.execute("SELECT * FROM stored_messages") as cursor:
            for row in await cursor.fetchall():
                assert row['topic'] == 'my/retained/topic'
                assert row['data'] == b'retained message'
                assert row['qos'] == 2
                has_stored_message = True

    assert has_stored_message, "retained topic message wasn't stored"


@pytest.mark.asyncio
async def test_topic_clear_retained_message(db_file, broker_context, db_session_factory) -> None:
    """A retained message with an empty payload leaves `stored_messages` empty."""
    broker_context.config = SessionDBPlugin.Config(file=db_file)
    session_db_plugin = SessionDBPlugin(broker_context)
    await session_db_plugin.on_broker_pre_start()

    # add a session to the broker
    await broker_context.add_subscription('test_retained_client', None, None)

    # update the session so that it's retained
    session = broker_context.get_session('test_retained_client')
    assert session is not None
    session.is_anonymous = False
    session.clean_session = False
    session.transitions.disconnect()

    retained_message = RetainedApplicationMessage(source_session=session,
                                                  topic='my/retained/topic',
                                                  data=b'',
                                                  qos=0)
    await session_db_plugin.on_broker_retained_message(client_id=None, retained_message=retained_message)

    async with aiosqlite.connect(str(db_file)) as db_conn:
        db_conn.row_factory = sqlite3.Row

        async with await db_conn.execute("SELECT * FROM stored_messages") as cursor:
            assert len(await cursor.fetchall()) == 0


@pytest.mark.asyncio
async def test_restoring_retained_message(db_file, broker_context, db_session_factory) -> None:
    """Rows in `stored_messages` are loaded into the broker's retained messages on post-start."""
    broker_context.config = SessionDBPlugin.Config(file=db_file)
    session_db_plugin = SessionDBPlugin(broker_context)
    await session_db_plugin.on_broker_pre_start()

    # the X'..' blobs are the hex-encoded payloads of the three retained messages
    stmts = ("INSERT INTO stored_messages VALUES(1,'my/retained/topic1',X'72657461696e6564206d657373616765',2)",
             "INSERT INTO stored_messages VALUES(2,'my/retained/topic2',X'72657461696e6564206d65737361676532',2)",
             "INSERT INTO stored_messages VALUES(3,'my/retained/topic3',X'72657461696e6564206d65737361676533',2)")

    async with aiosqlite.connect(str(db_file)) as db:
        for stmt in stmts:
            await db.execute(stmt)
        await db.commit()

    await session_db_plugin.on_broker_post_start()

    assert len(broker_context.retained_messages) == 3
    assert 'my/retained/topic1' in broker_context.retained_messages
    assert 'my/retained/topic2' in broker_context.retained_messages
    assert 'my/retained/topic3' in broker_context.retained_messages


@pytest.fixture
def db_config(db_file):
    """Return a broker config that enables `SessionDBPlugin` on the test db file.

    Anonymous access is disabled and `clear_on_shutdown` is off, so a second
    broker built from the same config sees the first broker's database.
    """
    return {
        'listeners': {
            'default': {
                'type': 'tcp',
                'bind': '127.0.0.1:1883',
            },
        },
        'plugins': {
            'amqtt.plugins.authentication.AnonymousAuthPlugin': {'allow_anonymous': False},
            'amqtt.contrib.persistence.SessionDBPlugin': {
                'file': db_file,
                'clear_on_shutdown': False,
            },
        },
    }


@pytest.mark.asyncio
async def test_broker_client_no_cleanup(db_file, db_config) -> None:
    """A topic's retained message survives a broker restart and reaches a new subscriber."""
    b1 = Broker(config=db_config)
    await b1.start()
    await asyncio.sleep(0.1)

    c1 = MQTTClient(client_id='test_client1', config={'auto_reconnect': False})
    await c1.connect("mqtt://myUsername@127.0.0.1:1883", cleansession=False)

    # test that this message is retained for the topic upon restore
    await c1.publish("my/retained/topic", b'retained message for topic my/retained/topic', retain=True)
    await asyncio.sleep(0.2)
    await c1.disconnect()
    await b1.shutdown()

    # new broker should load the previous broker's db file since clear_on_shutdown is false in config
    b2 = Broker(config=db_config)
    await b2.start()
    await asyncio.sleep(0.1)

    # upon subscribing to topic with retained message, it should be received
    c2 = MQTTClient(client_id='test_client2', config={'auto_reconnect': False})
    await c2.connect("mqtt://myOtherUsername@localhost:1883", cleansession=False)
    await c2.subscribe([
        ('my/retained/topic', QOS_1),
    ])
    msg = await c2.deliver_message(timeout_duration=1)
    assert msg is not None
    assert msg.topic == "my/retained/topic"
    assert msg.data == b'retained message for topic my/retained/topic'

    await c2.disconnect()
    await b2.shutdown()


@pytest.mark.asyncio
async def test_broker_client_retain_subscription(db_file, db_config) -> None:
    """A client's subscription survives a broker restart and still delivers after reconnecting."""
    b1 = Broker(config=db_config)
    await b1.start()
    await asyncio.sleep(0.1)

    c1 = MQTTClient(client_id='test_client1', config={'auto_reconnect': False})
    await c1.connect("mqtt://myUsername@127.0.0.1:1883", cleansession=False)

    # test to make sure the subscription is re-established upon reconnection
    # after broker restart (clear_on_shutdown = False)
    ret = await c1.subscribe([
        ('my/offline/topic', QOS_1),
    ])
    assert ret == [QOS_1]
    await asyncio.sleep(0.2)
    await c1.disconnect()
    await asyncio.sleep(0.1)
    await b1.shutdown()

    # new broker should load the previous broker's db file
    b2 = Broker(config=db_config)
    await b2.start()
    await asyncio.sleep(0.1)

    # client1's subscription should have been restored, so when it connects, it should receive this message
    c2 = MQTTClient(client_id='test_client2', config={'auto_reconnect': False})
    await c2.connect("mqtt://myOtherUsername@localhost:1883", cleansession=False)
    await c2.publish('my/offline/topic', b'standard message to be retained for offline clients')
    await asyncio.sleep(0.1)
    await c2.disconnect()

    await c1.reconnect(cleansession=False)
    await asyncio.sleep(0.1)
    msg = await c1.deliver_message(timeout_duration=2)
    assert msg is not None
    assert msg.topic == "my/offline/topic"
    assert msg.data == b'standard message to be retained for offline clients'

    await c1.disconnect()
    await b2.shutdown()


@pytest.mark.asyncio
async def test_broker_client_retain_message(db_file, db_config) -> None:
    """A message held for an offline client is delivered when it reconnects after a broker restart."""
    b1 = Broker(config=db_config)
    await b1.start()
    await asyncio.sleep(0.1)

    c1 = MQTTClient(client_id='test_client1', config={'auto_reconnect': False})
    await c1.connect("mqtt://myUsername@127.0.0.1:1883", cleansession=False)

    # subscribe to a topic with QOS_1 so that we receive messages, even if we're disconnected when sent
    ret = await c1.subscribe([
        ('my/offline/topic', QOS_1),
    ])
    assert ret == [QOS_1]
    await asyncio.sleep(0.2)

    # go offline
    await c1.disconnect()
    await asyncio.sleep(0.1)

    # another client sends a message to previously subscribed to topic
    c2 = MQTTClient(client_id='test_client2', config={'auto_reconnect': False})
    await c2.connect("mqtt://myOtherUsername@localhost:1883", cleansession=False)

    # this message should be delivered after broker stops and restarts (and client connects)
    await c2.publish('my/offline/topic', b'standard message to be retained for offline clients')
    await asyncio.sleep(0.1)
    await c2.disconnect()
    await asyncio.sleep(0.1)
    await b1.shutdown()

    # new broker should load the previous broker's db file since we declared clear_on_shutdown = False in config
    b2 = Broker(config=db_config)
    await b2.start()
    await asyncio.sleep(0.1)

    # when first client reconnects, it should receive the message that had been previously retained for it
    await c1.reconnect(cleansession=False)
    await asyncio.sleep(0.1)
    msg = await c1.deliver_message(timeout_duration=2)
    assert msg is not None
    assert msg.topic == "my/offline/topic"
    assert msg.data == b'standard message to be retained for offline clients'

    # client should also receive a message if sent on this topic
    await c1.publish("my/offline/topic", b"online message should also be received")
    await asyncio.sleep(0.1)
    msg = await c1.deliver_message(timeout_duration=2)
    assert msg is not None
    assert msg.topic == "my/offline/topic"
    assert msg.data == b'online message should also be received'

    await c1.disconnect()
    await b2.shutdown()