diff --git a/amqtt/codecs_amqtt.py b/amqtt/codecs_amqtt.py index f385155..3e425a3 100644 --- a/amqtt/codecs_amqtt.py +++ b/amqtt/codecs_amqtt.py @@ -2,7 +2,7 @@ import asyncio from struct import pack, unpack from amqtt.adapters import ReaderAdapter -from amqtt.errors import NoDataError +from amqtt.errors import NoDataError, ZeroLengthReadError def bytes_to_hex_str(data: bytes | bytearray) -> str: @@ -59,7 +59,7 @@ async def read_or_raise(reader: ReaderAdapter | asyncio.StreamReader, n: int = - data = await reader.read(n) except (asyncio.IncompleteReadError, ConnectionResetError, BrokenPipeError): data = None - if not data: + if data is None: msg = "No more data" raise NoDataError(msg) return data @@ -72,6 +72,8 @@ async def decode_string(reader: ReaderAdapter | asyncio.StreamReader) -> str: :return: string read from stream. """ length_bytes = await read_or_raise(reader, 2) + if len(length_bytes) < 1: + raise ZeroLengthReadError str_length = unpack("!H", length_bytes)[0] if str_length: byte_str = await read_or_raise(reader, str_length) @@ -90,6 +92,8 @@ async def decode_data_with_length(reader: ReaderAdapter | asyncio.StreamReader) :return: bytes read from stream (without length). """ length_bytes = await read_or_raise(reader, 2) + if len(length_bytes) < 1: + raise ZeroLengthReadError bytes_length = unpack("!H", length_bytes)[0] return await read_or_raise(reader, bytes_length) diff --git a/amqtt/errors.py b/amqtt/errors.py index 71d65c7..ca6cafc 100644 --- a/amqtt/errors.py +++ b/amqtt/errors.py @@ -13,6 +13,9 @@ class CodecError(Exception): class NoDataError(Exception): """Exceptions thrown by packet encode/decode functions.""" +class ZeroLengthReadError(NoDataError): + def __init__(self) -> None: + super().__init__("Decoding a string of length zero.") class BrokerError(Exception): """Exceptions thrown by broker."""