Merge session and protocol classes

Move client test
pull/8/head
Nicolas Jouanin 2015-06-26 22:00:26 +02:00
rodzic 52fd303438
commit 753f347f5c
4 zmienionych plików z 86 dodań i 102 usunięć

Wyświetl plik

@ -1,90 +0,0 @@
# Copyright (c) 2015 Nicolas JOUANIN
#
# See the file license.txt for copying permission.
import logging
import asyncio
from hbmqtt.session import Session
from hbmqtt.mqtt.packet import MQTTFixedHeader
from hbmqtt.mqtt import packet_class
from hbmqtt.errors import NoDataException
class ProtocolHandler:
"""
Class implementing the MQTT communication protocol using asyncio features
"""
def __init__(self, session: Session, loop):
self.logger = logging.getLogger(__name__)
self.session = session
self._loop = loop
self._reader_task = None
self._writer_task = None
self._reader_ready = asyncio.Event(loop=self._loop)
self._writer_ready = asyncio.Event(loop=self._loop)
self._running = False
@asyncio.coroutine
def start(self):
self._running = True
self._reader_task = asyncio.async(self._reader_coro(), loop=self._loop)
self._writer_task = asyncio.async(self._writer_coro(), loop=self._loop)
yield from asyncio.wait([self._reader_ready.wait(), self._writer_ready.wait()], loop=self._loop)
self.logger.debug("Handler tasks started")
@asyncio.coroutine
def stop(self):
self._running = False
yield from asyncio.wait([self._writer_task], loop=self._loop)
@asyncio.coroutine
def _reader_coro(self):
self.logger.debug("Starting reader coro")
while self._running:
try:
self._reader_ready.set()
fixed_header = yield from asyncio.wait_for(MQTTFixedHeader.from_stream(self.session.reader), 60)
if fixed_header:
cls = packet_class(fixed_header)
packet = yield from cls.from_stream(self.session.reader, fixed_header=fixed_header)
yield from self.session.incoming_queues[packet.fixed_header.packet_type].put(packet)
else:
self.logger.debug("No data")
except asyncio.TimeoutError:
self.logger.warn("Input stream read timeout")
except NoDataException as nde:
self.logger.debug("No data available")
#break
except BaseException as e:
self.logger.warn("Exception in reader coro: %s" % e)
break
self.logger.debug("Reader coro stopped")
@asyncio.coroutine
def _writer_coro(self):
self.logger.debug("Starting writer coro")
out_queue = self.session.outgoing_queue
packet = None
while self._running:
try:
self._writer_ready.set()
packet = yield from asyncio.wait_for(out_queue.get(), 60)
self.logger.debug(packet)
yield from packet.to_stream(self.session.writer)
yield from self.session.writer.drain()
except asyncio.TimeoutError as ce:
self.logger.warn("Output queue get timeout")
except Exception as e:
self.logger.warn("Exception in writer coro: %s" % e)
break
self.logger.debug("Writer coro stopping")
# Flush queue before stopping
if not out_queue.empty():
while True:
try:
packet = out_queue.get_nowait()
self.logger.debug(packet)
yield from packet.to_stream(self.session.writer)
except asyncio.QueueEmpty:
break
self.logger.debug("Writer coro stopped")

Wyświetl plik

@ -2,8 +2,12 @@
#
# See the file license.txt for copying permission.
import asyncio
import logging
from enum import Enum
from hbmqtt.mqtt.packet import PacketType
from hbmqtt.mqtt.packet import MQTTFixedHeader
from hbmqtt.mqtt import packet_class
from hbmqtt.errors import NoDataException
class SessionState(Enum):
NEW = 0
@ -11,7 +15,14 @@ class SessionState(Enum):
DISCONNECTED = 2
class Session:
def __init__(self):
def __init__(self, loop):
self.logger = logging.getLogger(__name__)
self._loop = loop
self._reader_task = None
self._writer_task = None
self._reader_ready = asyncio.Event(loop=self._loop)
self._writer_ready = asyncio.Event(loop=self._loop)
self.state = SessionState.NEW
self.reader = None
self.writer = None
@ -41,3 +52,69 @@ class Session:
def next_packet_id(self):
self._packet_id += 1
return self._packet_id
@asyncio.coroutine
def start(self):
self._running = True
self._reader_task = asyncio.async(self._reader_coro(), loop=self._loop)
self._writer_task = asyncio.async(self._writer_coro(), loop=self._loop)
yield from asyncio.wait([self._reader_ready.wait(), self._writer_ready.wait()], loop=self._loop)
self.logger.debug("Handler tasks started")
@asyncio.coroutine
def stop(self):
self._running = False
yield from asyncio.wait([self._writer_task], loop=self._loop)
@asyncio.coroutine
def _reader_coro(self):
self.logger.debug("Starting reader coro")
while self._running:
try:
self._reader_ready.set()
fixed_header = yield from asyncio.wait_for(MQTTFixedHeader.from_stream(self.reader), 5)
if fixed_header:
cls = packet_class(fixed_header)
packet = yield from cls.from_stream(self.reader, fixed_header=fixed_header)
yield from self.incoming_queues[packet.fixed_header.packet_type].put(packet)
else:
self.logger.debug("No data")
except asyncio.TimeoutError:
self.logger.warn("Input stream read timeout")
except NoDataException as nde:
self.logger.debug("No data available")
except BaseException as e:
self.logger.warn("Exception in reader coro: %s" % e)
break
self.logger.debug("Reader coro stopped")
@asyncio.coroutine
def _writer_coro(self):
self.logger.debug("Starting writer coro")
out_queue = self.outgoing_queue
packet = None
while self._running:
try:
self._writer_ready.set()
packet = yield from asyncio.wait_for(out_queue.get(), 5)
self.logger.debug(packet)
yield from packet.to_stream(self.writer)
yield from self.writer.drain()
except asyncio.TimeoutError as ce:
self.logger.warn("Output queue get timeout")
except Exception as e:
self.logger.warn("Exception in writer coro: %s" % e)
break
self.logger.debug("Writer coro stopping")
# Flush queue before stopping
if not out_queue.empty():
while True:
try:
packet = out_queue.get_nowait()
self.logger.debug(packet)
yield from packet.to_stream(self.writer)
except asyncio.QueueEmpty:
break
self.logger.debug("Writer coro stopped")

Wyświetl plik

@ -5,10 +5,7 @@ import unittest
import asyncio
from hbmqtt.mqtt.connect import ConnectPacket, ConnectVariableHeader, ConnectPayload
from hbmqtt.mqtt.packet import MQTTFixedHeader, PacketType
from hbmqtt.errors import MQTTException
from hbmqtt.session import Session
from hbmqtt.protocol import ProtocolHandler
from hbmqtt.mqtt.packet import PacketType
import logging
@ -35,17 +32,18 @@ class ConnectPacketTest(unittest.TestCase):
@asyncio.coroutine
def client():
S = Session()
S = Session(loop)
S.reader, S.writer = yield from asyncio.open_connection('127.0.0.1', 8888,
loop=loop)
handler = ProtocolHandler(S, loop)
yield from handler.start()
yield from S.start()
incoming_packet = yield from S.incoming_queues[PacketType.CONNECT].get()
handler.stop()
S.writer.close()
yield from S.stop()
return incoming_packet
packet = loop.run_until_complete(client())
server.close()
loop.stop()
self.assertEquals(packet.fixed_header.packet_type, PacketType.CONNECT)
def test_write_loop(self):
@ -62,12 +60,11 @@ class ConnectPacketTest(unittest.TestCase):
@asyncio.coroutine
def client():
S = Session()
S = Session(loop)
S.reader, S.writer = yield from asyncio.open_connection('127.0.0.1', 8888, loop=loop)
handler = ProtocolHandler(S, loop)
yield from handler.start()
yield from S.start()
yield from S.outgoing_queue.put(test_packet)
yield from handler.stop()
yield from S.stop()
S.writer.close()
# Start server