From f3291eda2f6d6bec2044399e4d232f0313a499dc Mon Sep 17 00:00:00 2001 From: Andrew Mirsky Date: Sat, 7 Jun 2025 18:21:08 -0400 Subject: [PATCH] fixes Yakifo/amqtt#154 : will message is allowed to have zero length. and the StreamReader can return 'None' or zero and they have different implications. if reading value in order to read from a stream and the length is zero, return a NoDataError. Creating exception class to satisfy EM101 and TRY003 --- amqtt/codecs_amqtt.py | 8 ++++++-- amqtt/errors.py | 3 +++ tests/test_client.py | 32 ++++++++++++++++++++++++++++++++ 3 files changed, 41 insertions(+), 2 deletions(-) diff --git a/amqtt/codecs_amqtt.py b/amqtt/codecs_amqtt.py index f385155..3e425a3 100644 --- a/amqtt/codecs_amqtt.py +++ b/amqtt/codecs_amqtt.py @@ -2,7 +2,7 @@ import asyncio from struct import pack, unpack from amqtt.adapters import ReaderAdapter -from amqtt.errors import NoDataError +from amqtt.errors import NoDataError, ZeroLengthReadError def bytes_to_hex_str(data: bytes | bytearray) -> str: @@ -59,7 +59,7 @@ async def read_or_raise(reader: ReaderAdapter | asyncio.StreamReader, n: int = - data = await reader.read(n) except (asyncio.IncompleteReadError, ConnectionResetError, BrokenPipeError): data = None - if not data: + if data is None: msg = "No more data" raise NoDataError(msg) return data @@ -72,6 +72,8 @@ async def decode_string(reader: ReaderAdapter | asyncio.StreamReader) -> str: :return: string read from stream. """ length_bytes = await read_or_raise(reader, 2) + if len(length_bytes) < 1: + raise ZeroLengthReadError str_length = unpack("!H", length_bytes)[0] if str_length: byte_str = await read_or_raise(reader, str_length) @@ -90,6 +92,8 @@ async def decode_data_with_length(reader: ReaderAdapter | asyncio.StreamReader) :return: bytes read from stream (without length). """ length_bytes = await read_or_raise(reader, 2) + if len(length_bytes) < 1: + raise ZeroLengthReadError bytes_length = unpack("!H", length_bytes)[0] return await read_or_raise(reader, bytes_length) diff --git a/amqtt/errors.py b/amqtt/errors.py index 71d65c7..ca6cafc 100644 --- a/amqtt/errors.py +++ b/amqtt/errors.py @@ -13,6 +13,9 @@ class CodecError(Exception): class NoDataError(Exception): """Exceptions thrown by packet encode/decode functions.""" +class ZeroLengthReadError(NoDataError): + def __init__(self) -> None: + super().__init__("Decoding a string of length zero.") class BrokerError(Exception): """Exceptions thrown by broker.""" diff --git a/tests/test_client.py b/tests/test_client.py index 59547ce..4048584 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -295,3 +295,35 @@ async def test_client_publish_will_with_retain(broker_fixture, client_config): assert message3.topic == 'test/will/topic' assert message3.data == b'client ABC has disconnected' await client3.disconnect() + + +@pytest.mark.asyncio +async def test_client_with_will_empty_message(broker_fixture): + client_config = { + "broker": { + "uri": "mqtt://localhost:1883" + }, + "reconnect_max_interval": 5, + "will": { + "topic": "test/will/topic", + "retain": True, + "message": "", + "qos": 0 + }, + } + client1 = MQTTClient(client_id="client1", config=client_config) + await client1.connect() + + client2 = MQTTClient(client_id="client2") + await client2.connect('mqtt://localhost:1883') + await client2.subscribe([ + ("test/will/topic", QOS_0) + ]) + + await client1.disconnect() + + message = await client2.deliver_message(timeout_duration=1) + assert message.topic == 'test/will/topic' + assert message.data == b'' + + await client2.disconnect()