amqtt/tests/contrib/test_cert.py

215 wiersze
6.7 KiB
Python

import asyncio
import logging
import shutil
import subprocess
import tempfile
from pathlib import Path
from unittest.mock import MagicMock
from OpenSSL import crypto
import pytest
from amqtt.broker import BrokerContext, Broker
from amqtt.client import MQTTClient
from amqtt.contrib.cert import UserAuthCertPlugin
from amqtt.errors import ConnectError
from amqtt.scripts.server_creds import server_creds as get_server_creds
from amqtt.scripts.device_creds import device_creds as get_device_creds
from amqtt.scripts.ca_creds import ca_creds as get_ca_creds
from amqtt.session import Session
logger = logging.getLogger(__name__)
@pytest.fixture
def temp_directory():
temp_dir = Path(tempfile.mkdtemp(prefix="amqtt-test-"))
yield temp_dir
logger.critical(temp_dir)
# shutil.rmtree(temp_dir)
@pytest.fixture
def ca_creds(temp_directory):
get_ca_creds(country='US', state="NY", locality="NYC", org_name="aMQTT", cn="aMQTT", output_dir=str(temp_directory))
ca_key = temp_directory / "ca.key"
ca_crt = temp_directory / "ca.crt"
return ca_key, ca_crt
@pytest.fixture
def server_creds(ca_creds, temp_directory):
ca_key = temp_directory / "ca.key"
ca_crt = temp_directory / "ca.crt"
get_server_creds(country='US', org_name='aMQTT', cn='aMQTT',
output_dir=str(temp_directory),
ca_key_fn=str(ca_key), ca_crt_fn=str(ca_crt))
server_key = temp_directory / "server.key"
server_crt = temp_directory / "server.crt"
yield server_key, server_crt
@pytest.fixture
def device_creds(ca_creds, temp_directory):
ca_key, ca_crt = ca_creds
get_device_creds(country='US', org_name='aMQTT',
device_id="mydeviceid", uri='test.amqtt.io',
output_dir=str(temp_directory),
ca_key_fn=str(ca_key), ca_crt_fn=str(ca_crt))
yield temp_directory / "mydeviceid.key", temp_directory / "mydeviceid.crt"
def test_device_cert(temp_directory, ca_creds, server_creds, device_creds):
ca_key, ca_crt = ca_creds
server_key, server_crt = server_creds
device_key, device_crt = device_creds
assert ca_key.exists()
assert ca_crt.exists()
assert server_key.exists()
assert server_crt.exists()
assert device_key.exists()
assert device_crt.exists()
r = subprocess.run(f"openssl x509 -in {str(device_crt)} -noout -text", shell=True, capture_output=True, text=True, check=True)
assert "URI:spiffe://test.amqtt.io/device/mydeviceid, DNS:mydeviceid.local" in r.stdout
@pytest.fixture
def ssl_object_mock(device_creds):
device_key, device_crt = device_creds
with device_crt.open("rb") as f:
cert = crypto.load_certificate(crypto.FILETYPE_PEM, f.read())
mock = MagicMock()
mock.getpeercert.return_value = crypto.dump_certificate(crypto.FILETYPE_ASN1, cert)
yield mock
@pytest.mark.parametrize("uri_domain,client_id,expected_result", [
('test.amqtt.io', 'mydeviceid', True),
('test.amqtt.io', 'notmydeviceid', False),
('other.amqtt.io', 'mydeviceid', False),
])
@pytest.mark.asyncio
async def test_cert_plugin(ssl_object_mock, uri_domain, client_id, expected_result):
empty_cfg = {
'listeners': {'default': {'type':'tcp', 'bind':'127.0.0.1:1883'}},
'plugins': {}
}
bc = BrokerContext(broker=Broker(config=empty_cfg))
bc.config = UserAuthCertPlugin.Config(uri_domain=uri_domain)
cert_auth_plugin = UserAuthCertPlugin(bc)
s = Session()
s.client_id = client_id
s.ssl_object = ssl_object_mock
assert await cert_auth_plugin.authenticate(session=s) == expected_result
@pytest.mark.asyncio
async def test_client_broker_cert_authentication(ca_creds, server_creds, device_creds):
ca_key, ca_crt = ca_creds
server_key, server_crt = server_creds
device_key, device_crt = device_creds
broker_config = {
'listeners': {
'default': {
'type':'tcp',
'bind':'127.0.0.1:8883',
'ssl': True,
'keyfile': server_key,
'certfile': server_crt,
'cafile': ca_crt,
}
},
'plugins': {
'amqtt.plugins.logging_amqtt.PacketLoggerPlugin':{},
'amqtt.contrib.cert.UserAuthCertPlugin': {'uri_domain': 'test.amqtt.io'},
}
}
b = Broker(config=broker_config)
await b.start()
await asyncio.sleep(1)
client_config = {
'auto_reconnect': False,
'broker': {
'cafile': ca_crt,
'certfile': device_crt,
'keyfile': device_key
}
}
c = MQTTClient(config=client_config, client_id='mydeviceid')
await c.connect('mqtts://127.0.0.1:8883')
await asyncio.sleep(0.1)
assert 'mydeviceid' in b._sessions
s, _ = b._sessions['mydeviceid']
assert s.transitions.state == "connected"
await asyncio.sleep(0.1)
await c.disconnect()
await asyncio.sleep(0.1)
await b.shutdown()
def ssl_error_logger(loop, context):
logger.critical("Asyncio SSL error:", context.get("message"))
exc = repr(context.get("exception"))
assert "exception" not in context, f"Exception: {exc}"
@pytest.mark.asyncio
async def test_client_broker_wrong_certs(ca_creds, server_creds, device_creds):
loop = asyncio.get_event_loop()
loop.set_exception_handler(ssl_error_logger)
loop.set_debug(True)
ca_key, ca_crt = ca_creds
server_key, server_crt = server_creds
device_key, device_crt = device_creds
broker_config = {
'listeners': {
'default': {
'type':'tcp',
'bind':'127.0.0.1:8883',
'ssl': True,
'keyfile': server_key,
'certfile': server_crt,
'cafile': ca_crt,
}
},
'plugins': {
'amqtt.plugins.logging_amqtt.PacketLoggerPlugin':{},
'amqtt.contrib.cert.UserAuthCertPlugin': {'uri_domain': 'test.amqtt.io'},
}
}
b = Broker(config=broker_config)
await b.start()
await asyncio.sleep(1)
# generate a different ca certificate and make sure the connection fails
temp_dir = Path(tempfile.mkdtemp(prefix="amqtt-test-"))
get_ca_creds(country='US', state="NY", locality="NYC", org_name="aMQTT", cn="aMQTT", output_dir=str(temp_dir))
wrong_ca_crt = temp_dir / 'ca.crt'
client_config = {
'auto_reconnect': False,
'connection': {
'cafile': wrong_ca_crt,
'certfile': device_crt,
'keyfile': device_key,
}
}
c = MQTTClient(config=client_config, client_id='mydeviceid')
with pytest.raises(ConnectError, match='.+?SSL: CERTIFICATE_VERIFY_FAILED.+?'):
await c.connect('mqtts://127.0.0.1:8883')
await b.shutdown()