diff --git a/hbmqtt/codecs/header.py b/hbmqtt/codecs/header.py index ed4725e..5436832 100644 --- a/hbmqtt/codecs/header.py +++ b/hbmqtt/codecs/header.py @@ -66,29 +66,31 @@ class MQTTHeaderCodec: @staticmethod @asyncio.coroutine - def encode(header: MQTTHeader, writer): - def encode_remaining_length(length:int): + def encode(header: MQTTHeader) -> bytes: + def encode_remaining_length(length: int): encoded = b'' while True: length_byte = length % 0x80 - length /= 0x80 + length //= 0x80 if length > 0: length_byte |= 0x80 - encoded += int_to_bytes(length_byte, 1) + encoded += int_to_bytes(length_byte) if length <= 0: break return encoded + out = b'' packet_type = 0 try: - packet_type = (header.message_type.value << 4) & header.flags - encoded_type = int_to_bytes(packet_type, 1) - writer.write(encoded_type) + packet_type = (header.message_type.value << 4) | header.flags + out += int_to_bytes(packet_type) except OverflowError: raise CodecException('packet_size encoding exceed 1 byte length: value=%d', packet_type) try: encoded_length = encode_remaining_length(header.remaining_length) - writer.write(encoded_length) + out += encoded_length except OverflowError: raise CodecException('message length encoding exceed 1 byte length: value=%d', header.remaining_length) + + return out diff --git a/hbmqtt/codecs/utils.py b/hbmqtt/codecs/utils.py index 9283139..2c7256b 100644 --- a/hbmqtt/codecs/utils.py +++ b/hbmqtt/codecs/utils.py @@ -21,14 +21,13 @@ def bytes_to_int(data): """ return int.from_bytes(data, byteorder='big') -def int_to_bytes(int_value:int, length) -> bytes: +def int_to_bytes(int_value:int) -> bytes: """ convert an integer to a sequence of bytes using big endian byte ordering :param int_value: integer value to convert - :param length: byte sequence length :return: byte sequence """ - int_value.to_bytes(length, byteorder='big') + int_value.to_bytes(int_value.bit_length(), byteorder='big') @asyncio.coroutine diff --git a/tests/codecs/test_header.py b/tests/codecs/test_header.py index f139ac1..1a48c31 100644 --- a/tests/codecs/test_header.py +++ b/tests/codecs/test_header.py @@ -4,7 +4,9 @@ import unittest import asyncio from hbmqtt.codecs.header import MQTTHeaderCodec, MQTTHeaderException -from hbmqtt.message import MessageType +from hbmqtt.message import MessageType, MQTTHeader +from hbmqtt.codecs.utils import bytes_to_hex_str + class TestMQTTHeaderCodec(unittest.TestCase): def setUp(self): @@ -41,3 +43,13 @@ class TestMQTTHeaderCodec(unittest.TestCase): stream.feed_data(b'\x10\xff\xff\xff\xff\x7f') with self.assertRaises(MQTTHeaderException): self.loop.run_until_complete(MQTTHeaderCodec.decode(stream)) + + def test_encode(self): + header = MQTTHeader(MessageType.CONNECT, 0x00, 0) + data = self.loop.run_until_complete(MQTTHeaderCodec.encode(header)) + self.assertEqual(data, b'\x10\x00') + + def test_encode_2(self): + header = MQTTHeader(MessageType.CONNECT, 0x00, 268435455) + data = self.loop.run_until_complete(MQTTHeaderCodec.encode(header)) + self.assertEqual(data, b'\x10\xff\xff\xff\x7f')