diff --git a/hbmqtt/codecs/utils.py b/hbmqtt/codecs/utils.py index 40eeac4..eb0836f 100644 --- a/hbmqtt/codecs/utils.py +++ b/hbmqtt/codecs/utils.py @@ -22,16 +22,18 @@ def bytes_to_int(data): """ return int.from_bytes(data, byteorder='big') -def int_to_bytes(int_value: int) -> bytes: +def int_to_bytes(int_value: int, length=-1) -> bytes: """ convert an integer to a sequence of bytes using big endian byte ordering :param int_value: integer value to convert + :param length: (optional) byte length :return: byte sequence """ - byte_length = ceil(int_value.bit_length()//8) - if byte_length == 0: - byte_length = 1 - return int_value.to_bytes(byte_length, byteorder='big') + if length == -1: + length = ceil(int_value.bit_length()//8) + if length == 0: + length = 1 + return int_value.to_bytes(length, byteorder='big') @asyncio.coroutine @@ -51,7 +53,7 @@ def read_or_raise(reader, n=-1): return data @asyncio.coroutine -def decode_string(reader): +def decode_string(reader) -> bytes: """ Read a string from a reader and decode it according to MQTT string specification :param reader: Stream reader @@ -60,4 +62,9 @@ def decode_string(reader): length_bytes = yield from read_or_raise(reader, 2) str_length = bytes_to_int(length_bytes) byte_str = yield from read_or_raise(reader, str_length) - return byte_str.decode(encoding='utf-8') \ No newline at end of file + return byte_str.decode(encoding='utf-8') + +def encode_string(string: str) -> bytes: + data = string.encode(encoding='utf-8') + data_length = len(data) + return int_to_bytes(data_length, 2) + data diff --git a/tests/codecs/test_utils.py b/tests/codecs/test_utils.py index f0fdc04..f7d480f 100644 --- a/tests/codecs/test_utils.py +++ b/tests/codecs/test_utils.py @@ -8,6 +8,7 @@ from hbmqtt.codecs.utils import ( bytes_to_hex_str, bytes_to_int, decode_string, + encode_string, ) @@ -25,8 +26,12 @@ class TestUtils(unittest.TestCase): ret = bytes_to_int(b'\xff\xff') self.assertEqual(ret, 65535) - def test_read_string(self): + def test_decode_string(self): stream = asyncio.StreamReader(loop=self.loop) stream.feed_data(b'\x00\x02AA') ret = self.loop.run_until_complete(decode_string(stream)) - self.assertEqual(ret, 'AA') \ No newline at end of file + self.assertEqual(ret, 'AA') + + def test_encode_string(self): + encoded = encode_string('AA') + self.assertEqual(b'\x00\x02AA', encoded) \ No newline at end of file