"""Tests for the credential-generation scripts and certificate-based client auth.

Covers the ``ca_creds`` / ``server_creds`` / ``device_creds`` scripts and
``amqtt.contrib.cert.UserAuthCertPlugin``, both in isolation and through a
TLS broker/client round trip on 127.0.0.1:8883.
"""
import asyncio
import logging
import shutil
import subprocess
import tempfile
from pathlib import Path
from unittest.mock import MagicMock

from OpenSSL import crypto
import pytest

from amqtt.broker import BrokerContext, Broker
from amqtt.client import MQTTClient
from amqtt.contrib.cert import UserAuthCertPlugin
from amqtt.errors import ConnectError
from amqtt.scripts.server_creds import server_creds as get_server_creds
from amqtt.scripts.device_creds import device_creds as get_device_creds
from amqtt.scripts.ca_creds import ca_creds as get_ca_creds
from amqtt.session import Session

logger = logging.getLogger(__name__)


@pytest.fixture
def temp_directory():
    """Yield a fresh temporary directory and delete it after the test.

    The directory holds generated private keys, so it must not be left behind.
    """
    temp_dir = Path(tempfile.mkdtemp(prefix="amqtt-test-"))
    logger.debug("using temporary directory %s", temp_dir)
    yield temp_dir
    # ignore_errors: a failed cleanup should never mask the test's own outcome.
    shutil.rmtree(temp_dir, ignore_errors=True)


@pytest.fixture
def ca_creds(temp_directory):
    """Generate a CA key/certificate pair in ``temp_directory``.

    Returns:
        ``(ca_key_path, ca_crt_path)``.
    """
    get_ca_creds(country='US', state="NY", locality="NYC", org_name="aMQTT", cn="aMQTT",
                 output_dir=str(temp_directory))
    ca_key = temp_directory / "ca.key"
    ca_crt = temp_directory / "ca.crt"
    return ca_key, ca_crt


@pytest.fixture
def server_creds(ca_creds, temp_directory):
    """Generate a server key/certificate signed by the ``ca_creds`` CA.

    Yields:
        ``(server_key_path, server_crt_path)``.
    """
    # Take the CA paths from the fixture rather than re-deriving them, so the
    # two can never drift apart (matches ``device_creds``).
    ca_key, ca_crt = ca_creds
    get_server_creds(country='US', org_name='aMQTT', cn='aMQTT', output_dir=str(temp_directory),
                     ca_key_fn=str(ca_key), ca_crt_fn=str(ca_crt))
    server_key = temp_directory / "server.key"
    server_crt = temp_directory / "server.crt"
    yield server_key, server_crt


@pytest.fixture
def device_creds(ca_creds, temp_directory):
    """Generate a device key/certificate for ``mydeviceid`` signed by the test CA.

    Yields:
        ``(device_key_path, device_crt_path)``.
    """
    ca_key, ca_crt = ca_creds
    get_device_creds(country='US', org_name='aMQTT', device_id="mydeviceid", uri='test.amqtt.io',
                     output_dir=str(temp_directory), ca_key_fn=str(ca_key), ca_crt_fn=str(ca_crt))
    yield temp_directory / "mydeviceid.key", temp_directory / "mydeviceid.crt"


def _tls_broker_config(server_key, server_crt, ca_crt):
    """Return a fresh broker config with one TLS listener and the cert-auth plugin enabled."""
    return {
        'listeners': {
            'default': {
                'type': 'tcp',
                'bind': '127.0.0.1:8883',
                'ssl': True,
                'keyfile': server_key,
                'certfile': server_crt,
                'cafile': ca_crt,
            }
        },
        'plugins': {
            'amqtt.plugins.logging_amqtt.PacketLoggerPlugin': {},
            'amqtt.contrib.cert.UserAuthCertPlugin': {'uri_domain': 'test.amqtt.io'},
        },
    }


def test_device_cert(temp_directory, ca_creds, server_creds, device_creds):
    """All credential files exist and the device cert carries the expected SANs."""
    ca_key, ca_crt = ca_creds
    server_key, server_crt = server_creds
    device_key, device_crt = device_creds

    for path in (ca_key, ca_crt, server_key, server_crt, device_key, device_crt):
        assert path.exists(), path

    openssl = shutil.which("openssl")
    if openssl is None:
        pytest.skip("openssl command-line tool is not available")
    # Argument list with no shell: the path is never re-parsed by /bin/sh.
    r = subprocess.run([openssl, "x509", "-in", str(device_crt), "-noout", "-text"],
                       capture_output=True, text=True, check=True)
    assert "URI:spiffe://test.amqtt.io/device/mydeviceid, DNS:mydeviceid.local" in r.stdout


@pytest.fixture
def ssl_object_mock(device_creds):
    """Yield a mock SSL object whose ``getpeercert`` returns the device cert as DER bytes."""
    _, device_crt = device_creds
    cert = crypto.load_certificate(crypto.FILETYPE_PEM, device_crt.read_bytes())
    mock = MagicMock()
    # Same DER bytes regardless of the arguments getpeercert is called with.
    mock.getpeercert.return_value = crypto.dump_certificate(crypto.FILETYPE_ASN1, cert)
    yield mock


@pytest.mark.parametrize("uri_domain,client_id,expected_result", [
    ('test.amqtt.io', 'mydeviceid', True),
    ('test.amqtt.io', 'notmydeviceid', False),
    ('other.amqtt.io', 'mydeviceid', False),
])
@pytest.mark.asyncio
async def test_cert_plugin(ssl_object_mock, uri_domain, client_id, expected_result):
    """The plugin accepts only when both the URI domain and the client id match the cert."""
    minimal_cfg = {
        'listeners': {'default': {'type': 'tcp', 'bind': '127.0.0.1:1883'}},
        'plugins': {}
    }
    bc = BrokerContext(broker=Broker(config=minimal_cfg))
    bc.config = UserAuthCertPlugin.Config(uri_domain=uri_domain)
    cert_auth_plugin = UserAuthCertPlugin(bc)

    s = Session()
    s.client_id = client_id
    s.ssl_object = ssl_object_mock

    assert await cert_auth_plugin.authenticate(session=s) == expected_result


@pytest.mark.asyncio
async def test_client_broker_cert_authentication(ca_creds, server_creds, device_creds):
    """A client presenting the CA-signed device certificate connects to the TLS broker."""
    _, ca_crt = ca_creds
    server_key, server_crt = server_creds
    device_key, device_crt = device_creds

    b = Broker(config=_tls_broker_config(server_key, server_crt, ca_crt))
    await b.start()
    try:
        await asyncio.sleep(1)

        client_config = {
            'auto_reconnect': False,
            'broker': {
                'cafile': ca_crt,
                'certfile': device_crt,
                'keyfile': device_key,
            },
        }
        c = MQTTClient(config=client_config, client_id='mydeviceid')
        await c.connect('mqtts://127.0.0.1:8883')
        try:
            await asyncio.sleep(0.1)

            assert 'mydeviceid' in b._sessions
            s, _ = b._sessions['mydeviceid']
            assert s.transitions.state == "connected"
            await asyncio.sleep(0.1)
        finally:
            await c.disconnect()
            await asyncio.sleep(0.1)
    finally:
        # Always shut the broker down, so a failed assertion does not leave it
        # (and presumably its 127.0.0.1:8883 listener) running into the next test.
        await b.shutdown()


def ssl_error_logger(loop, context):
    """Event-loop exception handler that logs the message and asserts no exception occurred.

    NOTE(review): asyncio catches exceptions raised by a custom exception
    handler and reports them through its default handler, so this assert is
    logged rather than failing the test directly.
    """
    # Lazy %-formatting: the original passed the message as an extra argument
    # with no placeholder, which made logging report a formatting error.
    logger.critical("Asyncio SSL error: %s", context.get("message"))
    exc = repr(context.get("exception"))
    assert "exception" not in context, f"Exception: {exc}"


@pytest.mark.asyncio
async def test_client_broker_wrong_certs(ca_creds, server_creds, device_creds):
    """Connecting with a CA that did not sign the server cert fails TLS verification."""
    loop = asyncio.get_running_loop()
    previous_handler = loop.get_exception_handler()
    previous_debug = loop.get_debug()
    loop.set_exception_handler(ssl_error_logger)
    loop.set_debug(True)
    try:
        _, ca_crt = ca_creds
        server_key, server_crt = server_creds
        device_key, device_crt = device_creds

        b = Broker(config=_tls_broker_config(server_key, server_crt, ca_crt))
        await b.start()
        try:
            await asyncio.sleep(1)

            # Generate a different CA certificate and make sure the connection
            # fails. A TemporaryDirectory removes that CA's key afterwards.
            with tempfile.TemporaryDirectory(prefix="amqtt-test-") as wrong_dir_name:
                wrong_dir = Path(wrong_dir_name)
                get_ca_creds(country='US', state="NY", locality="NYC", org_name="aMQTT", cn="aMQTT",
                             output_dir=str(wrong_dir))
                wrong_ca_crt = wrong_dir / 'ca.crt'

                # NOTE(review): the successful-connection test passes TLS
                # settings under 'broker', this one under 'connection'. Confirm
                # which key MQTTClient honors; if it ignores 'connection', this
                # test exercises "no CA configured" rather than "wrong CA".
                client_config = {
                    'auto_reconnect': False,
                    'connection': {
                        'cafile': wrong_ca_crt,
                        'certfile': device_crt,
                        'keyfile': device_key,
                    },
                }
                c = MQTTClient(config=client_config, client_id='mydeviceid')
                with pytest.raises(ConnectError, match='.+?SSL: CERTIFICATE_VERIFY_FAILED.+?'):
                    await c.connect('mqtts://127.0.0.1:8883')
        finally:
            await b.shutdown()
    finally:
        # Put the loop back the way we found it, so the handler and debug mode
        # do not leak into whatever else runs on this loop.
        loop.set_exception_handler(previous_handler)
        loop.set_debug(previous_debug)