diff --git a/hbmqtt/mqtt/packet.py b/hbmqtt/mqtt/packet.py index 7d3f0b0..3e30f50 100644 --- a/hbmqtt/mqtt/packet.py +++ b/hbmqtt/mqtt/packet.py @@ -134,6 +134,24 @@ class MQTTVariableHeader(metaclass=abc.ABCMeta): def from_stream(cls, reader: asyncio.StreamReader, fixed_header: MQTTFixedHeader): return +class PacketIdVariableHeader(MQTTVariableHeader): + def __init__(self, packet_id): + super().__init__() + self.packet_id = packet_id + + def to_bytes(self): + out = b'' + out += int_to_bytes(self.packet_id, 2) + return out + + @classmethod + def from_stream(cls, reader: ReaderAdapter, fixed_header: MQTTFixedHeader): + packet_id = yield from decode_packet_id(reader) + return cls(packet_id) + + def __repr__(self): + return type(self).__name__ + '(packet_id={0})'.format(self.packet_id) + class MQTTPayload(metaclass=abc.ABCMeta): def __init__(self): @@ -212,25 +230,29 @@ class MQTTPacket: def bytes_length(self): return len(self.to_bytes()) + def __getattr__(self, name): + """ + This method is implemented in order to facilitate access to packet data structure + attribute is first searched in packet then in fixed_header, variable_header and payload + example : packet.packet_id is equivalent to packet.variable_header.packet_id + :param name: name of attribute the packet to get + :return: the value of the attribute found. Raise AttributeError otherwise. + """ + try: + return getattr(self.fixed_header, name) + except AttributeError: + pass + try: + return getattr(self.variable_header, name) + except AttributeError: + pass + try: + return getattr(self.payload, name) + except AttributeError: + pass + raise AttributeError("Attribute '%s' not found in packet data structure" % name) + + def __repr__(self): return type(self).__name__ + '(fixed={0!r}, variable={1!r}, payload={2!r})'.\ format(self.fixed_header, self.variable_header, self.payload) - - -class PacketIdVariableHeader(MQTTVariableHeader): - def __init__(self, packet_id): - super().__init__() - self.packet_id = packet_id - - def to_bytes(self): - out = b'' - out += int_to_bytes(self.packet_id, 2) - return out - - @classmethod - def from_stream(cls, reader: ReaderAdapter, fixed_header: MQTTFixedHeader): - packet_id = yield from decode_packet_id(reader) - return cls(packet_id) - - def __repr__(self): - return type(self).__name__ + '(packet_id={0})'.format(self.packet_id) diff --git a/tests/mqtt/test_connect.py b/tests/mqtt/test_connect.py index e2f4878..cf0d0ef 100644 --- a/tests/mqtt/test_connect.py +++ b/tests/mqtt/test_connect.py @@ -89,3 +89,36 @@ class ConnectPacketTest(unittest.TestCase): message = ConnectPacket(header, variable_header, payload) encoded = message.to_bytes() self.assertEqual(encoded, b'\x10\x3e\x00\x04MQTT\x04\xce\x00\x00\x00\x0a0123456789\x00\x09WillTopic\x00\x0bWillMessage\x00\x04user\x00\x08password') + + def test_getattr_ok(self): + data = b'\x10\x3e\x00\x04MQTT\x04\xce\x00\x00\x00\x0a0123456789\x00\x09WillTopic\x00\x0bWillMessage\x00\x04user\x00\x08password' + stream = BufferReader(data) + message = self.loop.run_until_complete(ConnectPacket.from_stream(stream)) + self.assertEqual(message.variable_header.proto_name, "MQTT") + self.assertEqual(message.proto_name, "MQTT") + self.assertEqual(message.variable_header.proto_level, 4) + self.assertEqual(message.proto_level, 4) + self.assertTrue(message.variable_header.username_flag) + self.assertTrue(message.username_flag) + self.assertTrue(message.variable_header.password_flag) + self.assertTrue(message.password_flag) + self.assertFalse(message.variable_header.will_retain_flag) + self.assertFalse(message.will_retain_flag) + self.assertEqual(message.variable_header.will_qos, 1) + self.assertEqual(message.will_qos, 1) + self.assertTrue(message.variable_header.will_flag) + self.assertTrue(message.will_flag) + self.assertTrue(message.variable_header.clean_session_flag) + self.assertTrue(message.clean_session_flag) + self.assertFalse(message.variable_header.reserved_flag) + self.assertFalse(message.reserved_flag) + self.assertEqual(message.payload.client_id, '0123456789') + self.assertEqual(message.client_id, '0123456789') + self.assertEqual(message.payload.will_topic, 'WillTopic') + self.assertEqual(message.will_topic, 'WillTopic') + self.assertEqual(message.payload.will_message, b'WillMessage') + self.assertEqual(message.will_message, b'WillMessage') + self.assertEqual(message.payload.username, 'user') + self.assertEqual(message.username, 'user') + self.assertEqual(message.payload.password, 'password') + self.assertEqual(message.password, 'password')