From cdf6f0dfff39e5620f7e931140b47dda7ac4465f Mon Sep 17 00:00:00 2001 From: Andrew Mirsky Date: Sat, 26 Jul 2025 17:36:20 -0400 Subject: [PATCH] post_init was not running on client config. cafile, keyfile, certfile are part of the connection section --- amqtt/client.py | 13 ++++++------- amqtt/contexts.py | 13 +++++++++++-- tests/contrib/test_cert.py | 19 ++++++++++++------- 3 files changed, 29 insertions(+), 16 deletions(-) diff --git a/amqtt/client.py b/amqtt/client.py index a579dad..3330dc2 100644 --- a/amqtt/client.py +++ b/amqtt/client.py @@ -462,15 +462,14 @@ class MQTTClient: sc = ssl.create_default_context( ssl.Purpose.SERVER_AUTH, cafile=self.session.cafile - ) - if "certfile" in self.config and "keyfile" in self.config: - sc.load_cert_chain(certfile=self.config["certfile"], keyfile=self.config["keyfile"]) - if "cafile" in self.config: - sc.load_verify_locations(cafile=self.config["cafile"]) - if "check_hostname" in self.config and isinstance(self.config["check_hostname"], bool): - sc.check_hostname = self.config["check_hostname"] + if self.config.connection.certfile and self.config.connection.keyfile: + sc.load_cert_chain(certfile=self.config.connection.certfile, keyfile=self.config.connection.keyfile) + if self.config.connection.cafile: + sc.load_verify_locations(cafile=self.config.connection.cafile) + if self.config.check_hostname is not None: + sc.check_hostname = self.config.check_hostname sc.verify_mode = ssl.CERT_REQUIRED kwargs["ssl"] = sc diff --git a/amqtt/contexts.py b/amqtt/contexts.py index 18812dc..a51149c 100644 --- a/amqtt/contexts.py +++ b/amqtt/contexts.py @@ -325,8 +325,10 @@ class ClientConfig(Dictable): topics: dict[str, TopicConfig] | None = field(default_factory=dict) """Specify the topics and what flags should be set for messages published to them.""" broker: ConnectionConfig | None = field(default_factory=ConnectionConfig) + """*Deprecated* Configuration for connecting to the broker. Use `connection` field instead.""" + connection: ConnectionConfig | None = field(default_factory=ConnectionConfig) """Configuration for connecting to the broker. See - [ConnectionConfig](./#amqtt.contexts.ConnectionConfig) for more information.""" + [ConnectionConfig](./#amqtt.contexts.ConnectionConfig) for more information.""" plugins: dict[str, Any] | list[dict[str, Any]] | None = field(default_factory=default_client_plugins) """The dictionary has a key of the dotted-module path of a class derived from `BasePlugin`; the value is a dictionary of configuration options for that plugin. See [Plugins](http://localhost:8000/custom_plugins/) @@ -337,12 +339,19 @@ class ClientConfig(Dictable): """Message, topic and flags that should be sent to if the client disconnects. See [WillConfig](./#amqtt.contexts.WillConfig)""" - def __post__init__(self) -> None: + def __post_init__(self) -> None: """Check config for errors and transform fields for easier use.""" if self.default_qos is not None and (self.default_qos < QOS_0 or self.default_qos > QOS_2): msg = "Client config: default QoS must be 0, 1 or 2." raise ValueError(msg) + if self.broker is not None: + self.connection = self.broker + + if bool(not self.connection.keyfile) ^ bool(not self.connection.certfile): + raise ValueError("Connection key and certificate files are _both_ required.") + + @classmethod def from_dict(cls, d: dict[str, Any] | None) -> "ClientConfig": """Create a client config from a dictionary.""" diff --git a/tests/contrib/test_cert.py b/tests/contrib/test_cert.py index cc3513f..27f51ee 100644 --- a/tests/contrib/test_cert.py +++ b/tests/contrib/test_cert.py @@ -138,9 +138,11 @@ async def test_client_broker_cert_authentication(ca_creds, server_creds, device_ client_config = { 'auto_reconnect': False, - 'cafile': ca_crt, - 'certfile': device_crt, - 'keyfile': device_key, + 'broker': { + 'cafile': ca_crt, + 'certfile': device_crt, + 'keyfile': device_key + } } c = MQTTClient(config=client_config, client_id='mydeviceid') @@ -159,7 +161,8 @@ async def test_client_broker_cert_authentication(ca_creds, server_creds, device_ def ssl_error_logger(loop, context): logger.critical("Asyncio SSL error:", context.get("message")) - assert "exception" not in context, f"Exception: {repr(context["exception"])}" + exc = repr(context.get("exception")) + assert "exception" not in context, f"Exception: {exc}" @pytest.mark.asyncio @@ -198,9 +201,11 @@ async def test_client_broker_wrong_certs(ca_creds, server_creds, device_creds): wrong_ca_crt = temp_dir / 'ca.crt' client_config = { 'auto_reconnect': False, - 'cafile': wrong_ca_crt, - 'certfile': device_crt, - 'keyfile': device_key, + 'connection': { + 'cafile': wrong_ca_crt, + 'certfile': device_crt, + 'keyfile': device_key, + } } c = MQTTClient(config=client_config, client_id='mydeviceid')