"""Tests for the shadow plugin.

Covers persistence of the ``Shadow`` model (via SQLAlchemy + aiosqlite) and the
handling of ``$shadow/<client>/<shadow>/...`` MQTT requests by ``ShadowPlugin``.
"""

import json
import sqlite3
import tempfile
from pathlib import Path
from unittest.mock import ANY, call, patch

import aiosqlite
import pytest
from jsonschema import validate
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine

from amqtt.broker import Broker, BrokerContext
from amqtt.contrib.shadows import ShadowPlugin
from amqtt.contrib.shadows.models import Shadow, ShadowUpdateError
from amqtt.contrib.shadows.states import MetaTimestamp, State, StateDocument
from amqtt.mqtt3.constants import QOS_0
from amqtt.session import IncomingApplicationMessage

# Import the schemas by name: a wildcard import from a ``test_*`` module can also
# pull in any public ``test_*`` functions it defines, which pytest would then
# collect (and run) a second time from this module.
from tests.contrib.test_shadows_schema import (
    delta_schema,
    get_accepted_schema,
    get_rejected_schema,
    update_accepted_schema,
    update_documents_schema,
    update_schema,
)


@pytest.fixture
def db_file():
    """Yield the path of a (not yet created) SQLite file inside a temporary directory.

    The directory, and therefore the database file, is removed after the test.
    """
    with tempfile.TemporaryDirectory() as temp_dir:
        # Use a fixed relative file name: joining ``temp_dir`` with an absolute
        # path (such as ``NamedTemporaryFile().name``) would escape the directory
        # and leave the database behind after cleanup.
        yield Path(temp_dir) / "shadows.db"


@pytest.fixture
def db_connection(db_file):
    """Return the SQLAlchemy async (aiosqlite) connection URL for ``db_file``."""
    return f"sqlite+aiosqlite:///{db_file}"


@pytest.fixture
async def db_session_maker(db_connection):
    """Yield an ``async_sessionmaker`` bound to the test database.

    The engine is disposed on teardown so pooled connections do not outlive the
    test (or the temporary directory holding the database file).
    """
    engine = create_async_engine(db_connection)
    session_maker = async_sessionmaker(engine, expire_on_commit=False)
    try:
        yield session_maker
    finally:
        await engine.dispose()


@pytest.fixture
async def shadow_plugin(db_connection):
    """Yield a ``ShadowPlugin`` configured for the test database.

    ``on_broker_pre_start`` is awaited before yielding. The model tests request
    this fixture without using it directly — presumably so the plugin creates the
    shadow tables first (TODO confirm).
    """
    cfg = ShadowPlugin.Config(connection=db_connection)
    ctx = BrokerContext(broker=Broker())
    ctx.config = cfg
    plugin = ShadowPlugin(ctx)
    await plugin.on_broker_pre_start()
    yield plugin


@pytest.mark.asyncio
async def test_shadow_find_latest_empty(db_session_maker, shadow_plugin):
    """No shadow is returned before any has been stored."""
    async with db_session_maker() as db_session, db_session.begin():
        shadow = await Shadow.latest_version(
            session=db_session, device_id="device123", name="myShadowName"
        )
        assert shadow is None


@pytest.mark.asyncio
async def test_shadow_create_new(db_file, db_session_maker, shadow_plugin):
    """A newly created shadow is written to ``shadows_shadow`` with an empty JSON state."""
    async with db_session_maker() as db_session, db_session.begin():
        db_session.add(Shadow(device_id="device123", name="myShadowName"))
        await db_session.commit()

    # Read back through a raw sqlite connection to check what actually reached disk.
    async with aiosqlite.connect(db_file) as db_conn:
        db_conn.row_factory = sqlite3.Row
        async with await db_conn.execute("SELECT * FROM shadows_shadow") as cursor:
            rows = await cursor.fetchall()
        assert rows, "Shadow was not created."
        for row in rows:
            assert row["name"] == "myShadowName"
            assert row["device_id"] == "device123"
            assert row["state"] == "{}"


@pytest.mark.asyncio
async def test_shadow_create_find_empty_state(db_session_maker, shadow_plugin):
    """A shadow stored without a state is read back as version 1 with an empty document."""
    async with db_session_maker() as db_session, db_session.begin():
        db_session.add(Shadow(device_id="device123", name="myShadowName"))
        await db_session.commit()

    async with db_session_maker() as db_session, db_session.begin():
        shadow = await Shadow.latest_version(
            session=db_session, device_id="device123", name="myShadowName"
        )
        assert shadow is not None
        assert shadow.version == 1
        assert shadow.state == StateDocument()


@pytest.mark.asyncio
async def test_shadow_create_find_state_doc(db_session_maker, shadow_plugin):
    """A stored state document round-trips through the database."""
    state_doc = StateDocument(
        state=State(
            desired={"item1": "value1", "item2": "value2"},
            reported={"item3": "value3", "item4": "value4"},
        )
    )
    async with db_session_maker() as db_session, db_session.begin():
        shadow = Shadow(device_id="device123", name="myShadowName")
        shadow.state = state_doc
        db_session.add(shadow)
        await db_session.commit()

    def _timestamps_close(a, b):
        # Metadata timestamps may not match exactly after the round trip, so treat
        # values at most 2 apart as equal (units presumably seconds — TODO confirm).
        return abs(a.timestamp - b.timestamp) <= 2

    async with db_session_maker() as db_session, db_session.begin():
        shadow = await Shadow.latest_version(
            session=db_session, device_id="device123", name="myShadowName"
        )
        assert shadow is not None
        assert shadow.version == 1
        with patch.object(MetaTimestamp, "__eq__", new=_timestamps_close):
            assert shadow.state == state_doc


@pytest.mark.asyncio
async def test_shadow_update_existing_version_fails(db_session_maker, shadow_plugin):
    """Overwriting the state of an already-stored shadow version raises ``ShadowUpdateError``.

    This test was previously named ``test_shadow_update_state`` as well. The later
    test of that name replaced it in the module namespace, so it never ran.
    """
    state_doc = StateDocument(
        state=State(
            desired={"item1": "value1", "item2": "value2"},
            reported={"item3": "value3", "item4": "value4"},
        )
    )
    async with db_session_maker() as db_session, db_session.begin():
        shadow = Shadow(device_id="device123", name="myShadowName")
        shadow.state = state_doc
        db_session.add(shadow)
        await db_session.commit()

    async with db_session_maker() as db_session, db_session.begin():
        shadow = await Shadow.latest_version(
            session=db_session, device_id="device123", name="myShadowName"
        )
        assert shadow is not None
        shadow.state = StateDocument(
            state=State(
                desired={"item5": "value5", "item6": "value6"},
                reported={"item7": "value7", "item8": "value8"},
            )
        )
        with pytest.raises(ShadowUpdateError):
            await db_session.commit()


@pytest.mark.asyncio
async def test_shadow_update_state(db_session_maker, shadow_plugin):
    """Merging a partial state into a shadow stores it as a new version (version 2)."""
    state_doc = StateDocument(
        state=State(
            desired={"item1": "value1", "item2": "value2"},
            reported={"item3": "value3", "item4": "value4"},
        )
    )
    async with db_session_maker() as db_session, db_session.begin():
        shadow = Shadow(device_id="device123", name="myShadowName")
        shadow.state = state_doc
        db_session.add(shadow)
        await db_session.commit()

    async with db_session_maker() as db_session, db_session.begin():
        shadow = await Shadow.latest_version(
            session=db_session, device_id="device123", name="myShadowName"
        )
        assert shadow is not None
        shadow.state += StateDocument(
            state=State(desired={"item1": "value1a", "item6": "value6"})
        )
        await db_session.commit()

    async with db_session_maker() as db_session, db_session.begin():
        shadow_list = await Shadow.all(db_session, "device123", "myShadowName")
        assert len(shadow_list) == 2

    async with db_session_maker() as db_session, db_session.begin():
        shadow = await Shadow.latest_version(
            session=db_session, device_id="device123", name="myShadowName"
        )
        assert shadow is not None
        assert shadow.version == 2
        assert shadow.state.state.desired == {
            "item1": "value1a",
            "item2": "value2",
            "item6": "value6",
        }
        assert shadow.state.state.reported == {"item3": "value3", "item4": "value4"}


@pytest.mark.asyncio
async def test_shadow_plugin_get_rejected(shadow_plugin):
    """A ``get`` for a shadow that was never updated is answered on ``get/rejected``."""
    with patch.object(BrokerContext, "broadcast_message", return_value=None) as mock_method:
        msg = IncomingApplicationMessage(
            packet_id=1,
            topic="$shadow/myClient123/myShadow/get",
            qos=QOS_0,
            data=json.dumps({}).encode("utf-8"),
            retain=False,
        )
        await shadow_plugin.on_broker_message_received(client_id="myClient123", message=msg)

        mock_method.assert_called()
        topic, message = mock_method.call_args[0]
        assert topic == "$shadow/myClient123/myShadow/get/rejected"
        validate(instance=json.loads(message.decode("utf-8")), schema=get_rejected_schema)


@pytest.mark.asyncio
async def test_shadow_plugin_update_accepted(shadow_plugin):
    """An ``update`` broadcasts accepted/documents/delta/iota messages that match their schemas."""
    with patch.object(BrokerContext, "broadcast_message", return_value=None) as mock_method:
        update_payload = {"state": {"desired": {"item1": "value1", "item2": "value2"}}}
        validate(instance=update_payload, schema=update_schema)

        msg = IncomingApplicationMessage(
            packet_id=1,
            topic="$shadow/myClient123/myShadow/update",
            qos=QOS_0,
            data=json.dumps(update_payload).encode("utf-8"),
            retain=False,
        )
        await shadow_plugin.on_broker_message_received(client_id="myClient123", message=msg)

        accepted_call = call("$shadow/myClient123/myShadow/update/accepted", ANY)
        document_call = call("$shadow/myClient123/myShadow/update/documents", ANY)
        delta_call = call("$shadow/myClient123/myShadow/update/delta", ANY)
        iota_call = call("$shadow/myClient123/myShadow/update/iota", ANY)

        mock_method.assert_has_calls(
            [accepted_call, document_call, delta_call, iota_call],
            any_order=True,
        )

        # ``call`` objects containing ANY are unhashable, so map expected calls to
        # schemas with a list of pairs rather than a dict.
        # NOTE(review): iota payloads are validated against delta_schema — confirm
        # that is the intended schema.
        schema_for_call = [
            (accepted_call, update_accepted_schema),
            (document_call, update_documents_schema),
            (delta_call, delta_schema),
            (iota_call, delta_schema),
        ]
        for actual in mock_method.call_args_list:
            schema = next(
                (schema for expected, schema in schema_for_call if actual == expected),
                None,
            )
            if schema is None:
                pytest.fail(f"unknown call made to broadcast: {actual}")
            validate(instance=json.loads(actual.args[1].decode("utf-8")), schema=schema)


@pytest.mark.asyncio
async def test_shadow_plugin_get_accepted(shadow_plugin):
    """After an ``update``, a ``get`` is answered on ``get/accepted`` with a schema-valid payload."""
    with patch.object(BrokerContext, "broadcast_message", return_value=None) as mock_method:
        update_payload = {"state": {"desired": {"item1": "value1", "item2": "value2"}}}
        update_msg = IncomingApplicationMessage(
            packet_id=1,
            topic="$shadow/myClient123/myShadow/update",
            qos=QOS_0,
            data=json.dumps(update_payload).encode("utf-8"),
            retain=False,
        )
        await shadow_plugin.on_broker_message_received(client_id="myClient123", message=update_msg)

        # Only inspect broadcasts produced by the ``get`` request.
        mock_method.reset_mock()

        get_msg = IncomingApplicationMessage(
            packet_id=1,
            topic="$shadow/myClient123/myShadow/get",
            qos=QOS_0,
            data=json.dumps({}).encode("utf-8"),
            retain=False,
        )
        await shadow_plugin.on_broker_message_received(client_id="myClient123", message=get_msg)

        get_accepted = call("$shadow/myClient123/myShadow/get/accepted", ANY)
        mock_method.assert_has_calls([get_accepted])

        matching = [actual for actual in mock_method.call_args_list if actual == get_accepted]
        assert matching, "could not find the broadcast call for get accepted"
        for actual in matching:
            validate(instance=json.loads(actual.args[1].decode("utf-8")), schema=get_accepted_schema)