diff --git a/amqtt/client.py b/amqtt/client.py index 5154a98..a11c5af 100644 --- a/amqtt/client.py +++ b/amqtt/client.py @@ -456,6 +456,8 @@ class MQTTClient: # if not self._handler: self._handler = ClientProtocolHandler(self.plugins_manager) + connection_timeout = self.config.get('connection_timeout', None) + if secure: sc = ssl.create_default_context( ssl.Purpose.SERVER_AUTH, @@ -476,21 +478,24 @@ class MQTTClient: # Open connection if scheme in ("mqtt", "mqtts"): - conn_reader, conn_writer = await asyncio.open_connection( + conn_reader, conn_writer = await asyncio.wait_for( + asyncio.open_connection( self.session.remote_address, self.session.remote_port, **kwargs, - ) + ), timeout=connection_timeout) reader = StreamReaderAdapter(conn_reader) writer = StreamWriterAdapter(conn_writer) elif scheme in ("ws", "wss") and self.session.broker_uri: - websocket: ClientConnection = await websockets.connect( - self.session.broker_uri, - subprotocols=[websockets.Subprotocol("mqtt")], - additional_headers=self.additional_headers, - **kwargs, - ) + websocket: ClientConnection = await asyncio.wait_for( + websockets.connect( + self.session.broker_uri, + subprotocols=[websockets.Subprotocol("mqtt")], + additional_headers=self.additional_headers, + **kwargs, + ), timeout=connection_timeout) + reader = WebSocketsReader(websocket) writer = WebSocketsWriter(websocket) elif not self.session.broker_uri: diff --git a/docs/references/client_config.md b/docs/references/client_config.md index 24985d1..e061ac5 100644 --- a/docs/references/client_config.md +++ b/docs/references/client_config.md @@ -25,6 +25,11 @@ Default retain value to messages published. Defaults to `false`. Enable or disable auto-reconnect if connection with the broker is interrupted. Defaults to `false`. + +### `connect_timeout` *(int)* + +If specified, the number of seconds before a connection times out + ### `reconnect_retries` *(int)* Maximum reconnection retries. Defaults to `2`. Negative value will cause client to reconnect infinitely. diff --git a/tests/test_client.py b/tests/test_client.py index 8c426d7..28f5c49 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -446,6 +446,15 @@ async def test_connect_incorrect_scheme(): await client.connect('"mq://someplace') +@pytest.mark.asyncio +@pytest.mark.timeout(3) +async def test_connect_timeout(): + config = {"auto_reconnect": False, "connection_timeout": 2} + client = MQTTClient(config=config) + with pytest.raises(ClientError): + await client.connect("mqtt://localhost:8888") + + async def test_client_no_auth(): class MockEntryPoints: