From 753f347f5c700a64608b1ce4556f6e2fc094e70d Mon Sep 17 00:00:00 2001 From: Nicolas Jouanin Date: Fri, 26 Jun 2015 22:00:26 +0200 Subject: [PATCH] Merge session and protocol classes Move client test --- hbmqtt/protocol.py | 90 --------------------- hbmqtt/session.py | 79 +++++++++++++++++- {tests/client => hbmqtt}/test_client.py | 0 tests/{test_protocol.py => test_session.py} | 19 ++--- 4 files changed, 86 insertions(+), 102 deletions(-) delete mode 100644 hbmqtt/protocol.py rename {tests/client => hbmqtt}/test_client.py (100%) rename tests/{test_protocol.py => test_session.py} (85%) diff --git a/hbmqtt/protocol.py b/hbmqtt/protocol.py deleted file mode 100644 index f27d5b9..0000000 --- a/hbmqtt/protocol.py +++ /dev/null @@ -1,90 +0,0 @@ -# Copyright (c) 2015 Nicolas JOUANIN -# -# See the file license.txt for copying permission. -import logging -import asyncio -from hbmqtt.session import Session -from hbmqtt.mqtt.packet import MQTTFixedHeader -from hbmqtt.mqtt import packet_class -from hbmqtt.errors import NoDataException - -class ProtocolHandler: - """ - Class implementing the MQTT communication protocol using asyncio features - """ - def __init__(self, session: Session, loop): - self.logger = logging.getLogger(__name__) - self.session = session - self._loop = loop - self._reader_task = None - self._writer_task = None - self._reader_ready = asyncio.Event(loop=self._loop) - self._writer_ready = asyncio.Event(loop=self._loop) - self._running = False - - @asyncio.coroutine - def start(self): - self._running = True - self._reader_task = asyncio.async(self._reader_coro(), loop=self._loop) - self._writer_task = asyncio.async(self._writer_coro(), loop=self._loop) - yield from asyncio.wait([self._reader_ready.wait(), self._writer_ready.wait()], loop=self._loop) - self.logger.debug("Handler tasks started") - - @asyncio.coroutine - def stop(self): - self._running = False - yield from asyncio.wait([self._writer_task], loop=self._loop) - - - @asyncio.coroutine - def _reader_coro(self): - self.logger.debug("Starting reader coro") - while self._running: - try: - self._reader_ready.set() - fixed_header = yield from asyncio.wait_for(MQTTFixedHeader.from_stream(self.session.reader), 60) - if fixed_header: - cls = packet_class(fixed_header) - packet = yield from cls.from_stream(self.session.reader, fixed_header=fixed_header) - yield from self.session.incoming_queues[packet.fixed_header.packet_type].put(packet) - else: - self.logger.debug("No data") - except asyncio.TimeoutError: - self.logger.warn("Input stream read timeout") - except NoDataException as nde: - self.logger.debug("No data available") - #break - except BaseException as e: - self.logger.warn("Exception in reader coro: %s" % e) - break - self.logger.debug("Reader coro stopped") - - - @asyncio.coroutine - def _writer_coro(self): - self.logger.debug("Starting writer coro") - out_queue = self.session.outgoing_queue - packet = None - while self._running: - try: - self._writer_ready.set() - packet = yield from asyncio.wait_for(out_queue.get(), 60) - self.logger.debug(packet) - yield from packet.to_stream(self.session.writer) - yield from self.session.writer.drain() - except asyncio.TimeoutError as ce: - self.logger.warn("Output queue get timeout") - except Exception as e: - self.logger.warn("Exception in writer coro: %s" % e) - break - self.logger.debug("Writer coro stopping") - # Flush queue before stopping - if not out_queue.empty(): - while True: - try: - packet = out_queue.get_nowait() - self.logger.debug(packet) - yield from packet.to_stream(self.session.writer) - except asyncio.QueueEmpty: - break - self.logger.debug("Writer coro stopped") diff --git a/hbmqtt/session.py b/hbmqtt/session.py index 74840a6..91fd813 100644 --- a/hbmqtt/session.py +++ b/hbmqtt/session.py @@ -2,8 +2,12 @@ # # See the file license.txt for copying permission. import asyncio +import logging from enum import Enum from hbmqtt.mqtt.packet import PacketType +from hbmqtt.mqtt.packet import MQTTFixedHeader +from hbmqtt.mqtt import packet_class +from hbmqtt.errors import NoDataException class SessionState(Enum): NEW = 0 @@ -11,7 +15,14 @@ class SessionState(Enum): DISCONNECTED = 2 class Session: - def __init__(self): + def __init__(self, loop): + self.logger = logging.getLogger(__name__) + self._loop = loop + self._reader_task = None + self._writer_task = None + self._reader_ready = asyncio.Event(loop=self._loop) + self._writer_ready = asyncio.Event(loop=self._loop) + self.state = SessionState.NEW self.reader = None self.writer = None @@ -41,3 +52,69 @@ class Session: def next_packet_id(self): self._packet_id += 1 return self._packet_id + + @asyncio.coroutine + def start(self): + self._running = True + self._reader_task = asyncio.async(self._reader_coro(), loop=self._loop) + self._writer_task = asyncio.async(self._writer_coro(), loop=self._loop) + yield from asyncio.wait([self._reader_ready.wait(), self._writer_ready.wait()], loop=self._loop) + self.logger.debug("Handler tasks started") + + @asyncio.coroutine + def stop(self): + self._running = False + yield from asyncio.wait([self._writer_task], loop=self._loop) + + + @asyncio.coroutine + def _reader_coro(self): + self.logger.debug("Starting reader coro") + while self._running: + try: + self._reader_ready.set() + fixed_header = yield from asyncio.wait_for(MQTTFixedHeader.from_stream(self.reader), 5) + if fixed_header: + cls = packet_class(fixed_header) + packet = yield from cls.from_stream(self.reader, fixed_header=fixed_header) + yield from self.incoming_queues[packet.fixed_header.packet_type].put(packet) + else: + self.logger.debug("No data") + except asyncio.TimeoutError: + self.logger.warn("Input stream read timeout") + except NoDataException as nde: + self.logger.debug("No data available") + except BaseException as e: + self.logger.warn("Exception in reader coro: %s" % e) + break + self.logger.debug("Reader coro stopped") + + + @asyncio.coroutine + def _writer_coro(self): + self.logger.debug("Starting writer coro") + out_queue = self.outgoing_queue + packet = None + while self._running: + try: + self._writer_ready.set() + packet = yield from asyncio.wait_for(out_queue.get(), 5) + self.logger.debug(packet) + yield from packet.to_stream(self.writer) + yield from self.writer.drain() + except asyncio.TimeoutError as ce: + self.logger.warn("Output queue get timeout") + except Exception as e: + self.logger.warn("Exception in writer coro: %s" % e) + break + self.logger.debug("Writer coro stopping") + # Flush queue before stopping + if not out_queue.empty(): + while True: + try: + packet = out_queue.get_nowait() + self.logger.debug(packet) + yield from packet.to_stream(self.writer) + except asyncio.QueueEmpty: + break + self.logger.debug("Writer coro stopped") diff --git a/tests/client/test_client.py b/hbmqtt/test_client.py similarity index 100% rename from tests/client/test_client.py rename to hbmqtt/test_client.py diff --git a/tests/test_protocol.py b/tests/test_session.py similarity index 85% rename from tests/test_protocol.py rename to tests/test_session.py index cd59d6b..1dbd5f5 100644 --- a/tests/test_protocol.py +++ b/tests/test_session.py @@ -5,10 +5,7 @@ import unittest import asyncio from hbmqtt.mqtt.connect import ConnectPacket, ConnectVariableHeader, ConnectPayload -from hbmqtt.mqtt.packet import MQTTFixedHeader, PacketType -from hbmqtt.errors import MQTTException from hbmqtt.session import Session -from hbmqtt.protocol import ProtocolHandler from hbmqtt.mqtt.packet import PacketType import logging @@ -35,17 +32,18 @@ class ConnectPacketTest(unittest.TestCase): @asyncio.coroutine def client(): - S = Session() + S = Session(loop) S.reader, S.writer = yield from asyncio.open_connection('127.0.0.1', 8888, loop=loop) - handler = ProtocolHandler(S, loop) - yield from handler.start() + yield from S.start() incoming_packet = yield from S.incoming_queues[PacketType.CONNECT].get() - handler.stop() + S.writer.close() + yield from S.stop() return incoming_packet packet = loop.run_until_complete(client()) server.close() + loop.stop() self.assertEquals(packet.fixed_header.packet_type, PacketType.CONNECT) def test_write_loop(self): @@ -62,12 +60,11 @@ class ConnectPacketTest(unittest.TestCase): @asyncio.coroutine def client(): - S = Session() + S = Session(loop) S.reader, S.writer = yield from asyncio.open_connection('127.0.0.1', 8888, loop=loop) - handler = ProtocolHandler(S, loop) - yield from handler.start() + yield from S.start() yield from S.outgoing_queue.put(test_packet) - yield from handler.stop() + yield from S.stop() S.writer.close() # Start server