diff --git a/amqtt/client.py b/amqtt/client.py index 8da9aba..1828f02 100644 --- a/amqtt/client.py +++ b/amqtt/client.py @@ -470,6 +470,7 @@ class MQTTClient: reader: StreamReaderAdapter | WebSocketsReader | None = None writer: StreamWriterAdapter | WebSocketsWriter | None = None self._connected_state.clear() + # Open connection if scheme in ("mqtt", "mqtts"): conn_reader, conn_writer = await asyncio.open_connection( @@ -489,11 +490,11 @@ class MQTTClient: ) reader = WebSocketsReader(websocket) writer = WebSocketsWriter(websocket) - - if reader is None or writer is None: - self.session.transitions.disconnect() - self.logger.warning("reader or writer not initialized") - msg = "reader or writer not initialized" + elif not self.session.broker_uri: + msg = "missing broker uri" + raise ClientError(msg) + else: + msg = f"incorrect scheme defined in uri: '{scheme!r}'" raise ClientError(msg) # Start MQTT protocol diff --git a/tests/test_client.py b/tests/test_client.py index 7d56274..8962d21 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -4,7 +4,7 @@ import logging import pytest from amqtt.client import MQTTClient -from amqtt.errors import ConnectError +from amqtt.errors import ClientError, ConnectError from amqtt.mqtt.constants import QOS_0, QOS_1, QOS_2 formatter = "[%(asctime)s] %(name)s {%(filename)s:%(lineno)d} %(levelname)s - %(message)s" @@ -301,4 +301,14 @@ async def test_client_publish_will_with_retain(broker_fixture, client_config): async def test_connect_broken_uri(): config = {"auto_reconnect": False} client = MQTTClient(config=config) - await client.connect('"mqtt://someplace') + with pytest.raises(ClientError): + await client.connect('"mqtt://someplace') + + +@pytest.mark.asyncio +async def test_connect_incorrect_scheme(): + config = {"auto_reconnect": False} + client = MQTTClient(config=config) + with pytest.raises(ClientError): + await client.connect('"mq://someplace') +