diff --git a/hbmqtt/client/_client.py b/hbmqtt/client/_client.py index 8e7fe3a..6987627 100644 --- a/hbmqtt/client/_client.py +++ b/hbmqtt/client/_client.py @@ -18,6 +18,8 @@ from hbmqtt.mqtt.pubrel import PubrelPacket from hbmqtt.mqtt.pubcomp import PubcompPacket from hbmqtt.mqtt.pingreq import PingReqPacket from hbmqtt.mqtt.pingresp import PingRespPacket +from hbmqtt.mqtt.subscribe import SubscribePacket +from hbmqtt.mqtt.suback import SubackPacket from hbmqtt.errors import MQTTException _defaults = { @@ -220,6 +222,18 @@ class MQTTClient: raise MQTTException("[MQTT-4.3.2-2] Pubcomp packet packet_id doesn't match pubrel packet") self._keep_alive() + @asyncio.coroutine + def subscribe(self, topics): + subscribe = SubscribePacket.build(topics, self._session.next_packet_id) + yield from subscribe.to_stream(self._session.writer) + self.logger.debug(" -out-> " + repr(subscribe)) + + suback = yield from SubackPacket.from_stream(self._session.reader) + self.logger.debug(" <-in-- " + repr(suback)) + if suback.variable_header.packet_id != subscribe.variable_header.packet_id: + raise MQTTException("[MQTT-4.3.2-2] Suback packet packet_id doesn't match subscribe packet") + self._keep_alive() + @asyncio.coroutine def _connect_coro(self): try: diff --git a/hbmqtt/mqtt/packet.py b/hbmqtt/mqtt/packet.py index dc106e4..448d828 100644 --- a/hbmqtt/mqtt/packet.py +++ b/hbmqtt/mqtt/packet.py @@ -132,8 +132,12 @@ class MQTTVariableHeader(metaclass=abc.ABCMeta): yield from writer.drain() @abc.abstractmethod - def to_bytes(self): - return + def to_bytes(self) -> bytes: + pass + + @property + def bytes_length(self): + return len(self.to_bytes()) @classmethod @asyncio.coroutine diff --git a/hbmqtt/mqtt/publish.py b/hbmqtt/mqtt/publish.py index ba22f18..607ca1a 100644 --- a/hbmqtt/mqtt/publish.py +++ b/hbmqtt/mqtt/publish.py @@ -122,4 +122,4 @@ class PublishPacket(MQTTPacket): packet.fixed_header.dup_flag = dup_flag packet.fixed_header.retain_flag = retain packet.fixed_header.qos = qos - return packet \ No newline at end of file + return packet diff --git a/hbmqtt/mqtt/suback.py b/hbmqtt/mqtt/suback.py index 382a202..8482b97 100644 --- a/hbmqtt/mqtt/suback.py +++ b/hbmqtt/mqtt/suback.py @@ -15,6 +15,9 @@ class SubackPayload(MQTTPayload): super().__init__() self.return_codes = return_codes + def __repr__(self): + return type(self).__name__ + '(return_codes={0})'.format(repr(self.return_codes)) + def to_bytes(self, fixed_header: MQTTFixedHeader, variable_header: MQTTVariableHeader): out = b'' for return_code in self.return_codes: @@ -26,12 +29,13 @@ class SubackPayload(MQTTPayload): def from_stream(cls, reader: asyncio.StreamReader, fixed_header: MQTTFixedHeader, variable_header: MQTTVariableHeader): return_codes = [] - while True: + bytes_to_read = fixed_header.remaining_length - variable_header.bytes_length + for i in range(0, bytes_to_read): try: return_code_byte = yield from read_or_raise(reader, 1) return_code = bytes_to_int(return_code_byte) return_codes.append(return_code) - except NoDataException: + except NoDataException as e: break return cls(return_codes) diff --git a/hbmqtt/mqtt/subscribe.py b/hbmqtt/mqtt/subscribe.py index f56394e..83aea7b 100644 --- a/hbmqtt/mqtt/subscribe.py +++ b/hbmqtt/mqtt/subscribe.py @@ -40,7 +40,7 @@ class SubscribePacket(MQTTPacket): def __init__(self, fixed: MQTTFixedHeader=None, variable_header: PacketIdVariableHeader=None, payload=None): if fixed is None: - header = MQTTFixedHeader(PacketType.SUBSCRIBE, 0x00) + header = MQTTFixedHeader(PacketType.SUBSCRIBE, 0x02) # [MQTT-3.8.1-1] else: if fixed.packet_type is not PacketType.SUBSCRIBE: raise HBMQTTException("Invalid fixed packet type %s for SubscribePacket init" % fixed.packet_type) @@ -49,3 +49,9 @@ class SubscribePacket(MQTTPacket): super().__init__(header) self.variable_header = variable_header self.payload = payload + + @classmethod + def build(cls, topics, packet_id): + v_header = PacketIdVariableHeader(packet_id) + payload = SubscribePayload(topics) + return SubscribePacket(variable_header=v_header, payload=payload) \ No newline at end of file diff --git a/tests/client/test_client.py b/tests/client/test_client.py index 80d3f71..b99d35c 100644 --- a/tests/client/test_client.py +++ b/tests/client/test_client.py @@ -6,10 +6,16 @@ C=MQTTClient() @asyncio.coroutine def test_coro(): - yield from C.connect(uri='mqtt://localhost:1883/', username='testuser', password="passwd") + yield from C.connect(uri='mqtt://iot.eclipse.org:1883/', username='testuser', password="passwd") yield from asyncio.sleep(1) yield from C.publish('a/b', b'0123456789') - yield from asyncio.sleep(10) + yield from C.publish('a/b', b'0123456789', qos=0x01) + yield from C.publish('a/b', b'0123456789', qos=0x02) + yield from C.subscribe([ + {'filter': 'a/b', 'qos': 0x01}, + {'filter': 'c/d', 'qos': 0x02} + ]) + #yield from asyncio.sleep(10) yield from C.disconnect() if __name__ == '__main__':