diff --git a/hbmqtt/protocol.py b/hbmqtt/protocol.py index effef0a..42b1919 100644 --- a/hbmqtt/protocol.py +++ b/hbmqtt/protocol.py @@ -3,27 +3,10 @@ # See the file license.txt for copying permission. import logging import asyncio -import threading from hbmqtt.session import Session from hbmqtt.mqtt.packet import MQTTFixedHeader from hbmqtt.mqtt import packet_class - -# class ProtoThread(threading.Thread): -# def __init__(self, session: Session, loop: asyncio.BaseEventLoop): -# super().__init__(name="MQTT Protocol communication thread") -# self.logger = logging.getLogger(__name__) -# self._loop = loop -# self._session = session -# -# def run(self): -# asyncio.set_event_loop(self._loop) -# self._loop.call_soon(asyncio.async, self._read_protocol()) -# if not self._loop.is_running(): -# self._loop.run_forever() -# -# @asyncio.coroutine -# def _read_protocol(self): -# while true: +from hbmqtt.errors import NoDataException class ProtocolHandler: """ @@ -52,10 +35,13 @@ class ProtocolHandler: fixed_header = yield from MQTTFixedHeader.from_stream(self.session.reader) cls = packet_class(fixed_header) packet = yield from cls.from_stream(self.session.reader, fixed_header=fixed_header) - self.logger.debug(packet) + yield from self.session.incoming_queues[packet.fixed_header.packet_type].put(packet) except asyncio.CancelledError: self.logger.warn("Reader coro stopping") break + except NoDataException: + self.logger.debug("No more data to read") + break except Exception as e: self.logger.warn("Exception in reader coro: %s" % e) break @@ -63,13 +49,13 @@ class ProtocolHandler: @asyncio.coroutine def _writer_coro(self): self.logger.debug("Starting writer coro") - out_queue = self.session._out_queue + out_queue = self.session.outgoing_queue while True: try: packet = yield from out_queue.get() yield from packet.to_stream(self.session.writer) except asyncio.CancelledError: - self.logger.warn("Reader coro stopping") + self.logger.warn("Writer coro stopping") break except Exception as e: self.logger.warn("Exception in writer coro: %s" % e) diff --git a/hbmqtt/session.py b/hbmqtt/session.py index 1b9612b..74840a6 100644 --- a/hbmqtt/session.py +++ b/hbmqtt/session.py @@ -3,6 +3,7 @@ # See the file license.txt for copying permission. import asyncio from enum import Enum +from hbmqtt.mqtt.packet import PacketType class SessionState(Enum): NEW = 0 @@ -31,7 +32,10 @@ class Session: self.scheme = None self._packet_id = 0 - self._out_queue = asyncio.Queue() + self.incoming_queues = dict() + for p in PacketType: + self.incoming_queues[p] = asyncio.Queue() + self.outgoing_queue = asyncio.Queue() @property def next_packet_id(self): diff --git a/tests/test_protocol.py b/tests/test_protocol.py new file mode 100644 index 0000000..9c0b5a5 --- /dev/null +++ b/tests/test_protocol.py @@ -0,0 +1,46 @@ +# Copyright (c) 2015 Nicolas JOUANIN +# +# See the file license.txt for copying permission. +import unittest +import asyncio + +from hbmqtt.mqtt.connect import ConnectPacket, ConnectVariableHeader, ConnectPayload +from hbmqtt.mqtt.packet import MQTTFixedHeader, PacketType +from hbmqtt.errors import MQTTException +from hbmqtt.session import Session +from hbmqtt.protocol import ProtocolHandler +from hbmqtt.mqtt.packet import PacketType +import logging + +logging.basicConfig(level=logging.DEBUG) + +class ConnectPacketTest(unittest.TestCase): + def setUp(self): + self.loop = asyncio.new_event_loop() + + def test_read_loop(self): + data = b'\x10\x3e\x00\x04MQTT\x04\xce\x00\x00\x00\x0a0123456789\x00\x09WillTopic\x00\x0bWillMessage\x00\x04user\x00\x08password' + @asyncio.coroutine + def serve_test(reader, writer): + writer.write(data) + yield from writer.drain() + writer.close() + + loop = asyncio.get_event_loop() + coro = asyncio.start_server(serve_test, '127.0.0.1', 8888, loop=loop) + server = loop.run_until_complete(coro) + + @asyncio.coroutine + def client(): + S = Session() + S.reader, S.writer = yield from asyncio.open_connection('127.0.0.1', 8888, + loop=loop) + handler = ProtocolHandler(S, loop) + handler.start() + incoming_packet = yield from S.incoming_queues[PacketType.CONNECT].get() + handler.stop() + return incoming_packet + + packet = loop.run_until_complete(client()) + server.close() + self.assertEquals(packet.fixed_header.packet_type, PacketType.CONNECT) \ No newline at end of file