kopia lustrzana https://github.com/Yakifo/amqtt
rodzic
52fd303438
commit
753f347f5c
|
@ -1,90 +0,0 @@
|
||||||
# Copyright (c) 2015 Nicolas JOUANIN
|
|
||||||
#
|
|
||||||
# See the file license.txt for copying permission.
|
|
||||||
import logging
|
|
||||||
import asyncio
|
|
||||||
from hbmqtt.session import Session
|
|
||||||
from hbmqtt.mqtt.packet import MQTTFixedHeader
|
|
||||||
from hbmqtt.mqtt import packet_class
|
|
||||||
from hbmqtt.errors import NoDataException
|
|
||||||
|
|
||||||
class ProtocolHandler:
|
|
||||||
"""
|
|
||||||
Class implementing the MQTT communication protocol using asyncio features
|
|
||||||
"""
|
|
||||||
def __init__(self, session: Session, loop):
|
|
||||||
self.logger = logging.getLogger(__name__)
|
|
||||||
self.session = session
|
|
||||||
self._loop = loop
|
|
||||||
self._reader_task = None
|
|
||||||
self._writer_task = None
|
|
||||||
self._reader_ready = asyncio.Event(loop=self._loop)
|
|
||||||
self._writer_ready = asyncio.Event(loop=self._loop)
|
|
||||||
self._running = False
|
|
||||||
|
|
||||||
@asyncio.coroutine
|
|
||||||
def start(self):
|
|
||||||
self._running = True
|
|
||||||
self._reader_task = asyncio.async(self._reader_coro(), loop=self._loop)
|
|
||||||
self._writer_task = asyncio.async(self._writer_coro(), loop=self._loop)
|
|
||||||
yield from asyncio.wait([self._reader_ready.wait(), self._writer_ready.wait()], loop=self._loop)
|
|
||||||
self.logger.debug("Handler tasks started")
|
|
||||||
|
|
||||||
@asyncio.coroutine
|
|
||||||
def stop(self):
|
|
||||||
self._running = False
|
|
||||||
yield from asyncio.wait([self._writer_task], loop=self._loop)
|
|
||||||
|
|
||||||
|
|
||||||
@asyncio.coroutine
|
|
||||||
def _reader_coro(self):
|
|
||||||
self.logger.debug("Starting reader coro")
|
|
||||||
while self._running:
|
|
||||||
try:
|
|
||||||
self._reader_ready.set()
|
|
||||||
fixed_header = yield from asyncio.wait_for(MQTTFixedHeader.from_stream(self.session.reader), 60)
|
|
||||||
if fixed_header:
|
|
||||||
cls = packet_class(fixed_header)
|
|
||||||
packet = yield from cls.from_stream(self.session.reader, fixed_header=fixed_header)
|
|
||||||
yield from self.session.incoming_queues[packet.fixed_header.packet_type].put(packet)
|
|
||||||
else:
|
|
||||||
self.logger.debug("No data")
|
|
||||||
except asyncio.TimeoutError:
|
|
||||||
self.logger.warn("Input stream read timeout")
|
|
||||||
except NoDataException as nde:
|
|
||||||
self.logger.debug("No data available")
|
|
||||||
#break
|
|
||||||
except BaseException as e:
|
|
||||||
self.logger.warn("Exception in reader coro: %s" % e)
|
|
||||||
break
|
|
||||||
self.logger.debug("Reader coro stopped")
|
|
||||||
|
|
||||||
|
|
||||||
@asyncio.coroutine
|
|
||||||
def _writer_coro(self):
|
|
||||||
self.logger.debug("Starting writer coro")
|
|
||||||
out_queue = self.session.outgoing_queue
|
|
||||||
packet = None
|
|
||||||
while self._running:
|
|
||||||
try:
|
|
||||||
self._writer_ready.set()
|
|
||||||
packet = yield from asyncio.wait_for(out_queue.get(), 60)
|
|
||||||
self.logger.debug(packet)
|
|
||||||
yield from packet.to_stream(self.session.writer)
|
|
||||||
yield from self.session.writer.drain()
|
|
||||||
except asyncio.TimeoutError as ce:
|
|
||||||
self.logger.warn("Output queue get timeout")
|
|
||||||
except Exception as e:
|
|
||||||
self.logger.warn("Exception in writer coro: %s" % e)
|
|
||||||
break
|
|
||||||
self.logger.debug("Writer coro stopping")
|
|
||||||
# Flush queue before stopping
|
|
||||||
if not out_queue.empty():
|
|
||||||
while True:
|
|
||||||
try:
|
|
||||||
packet = out_queue.get_nowait()
|
|
||||||
self.logger.debug(packet)
|
|
||||||
yield from packet.to_stream(self.session.writer)
|
|
||||||
except asyncio.QueueEmpty:
|
|
||||||
break
|
|
||||||
self.logger.debug("Writer coro stopped")
|
|
|
@ -2,8 +2,12 @@
|
||||||
#
|
#
|
||||||
# See the file license.txt for copying permission.
|
# See the file license.txt for copying permission.
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import logging
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from hbmqtt.mqtt.packet import PacketType
|
from hbmqtt.mqtt.packet import PacketType
|
||||||
|
from hbmqtt.mqtt.packet import MQTTFixedHeader
|
||||||
|
from hbmqtt.mqtt import packet_class
|
||||||
|
from hbmqtt.errors import NoDataException
|
||||||
|
|
||||||
class SessionState(Enum):
|
class SessionState(Enum):
|
||||||
NEW = 0
|
NEW = 0
|
||||||
|
@ -11,7 +15,14 @@ class SessionState(Enum):
|
||||||
DISCONNECTED = 2
|
DISCONNECTED = 2
|
||||||
|
|
||||||
class Session:
|
class Session:
|
||||||
def __init__(self):
|
def __init__(self, loop):
|
||||||
|
self.logger = logging.getLogger(__name__)
|
||||||
|
self._loop = loop
|
||||||
|
self._reader_task = None
|
||||||
|
self._writer_task = None
|
||||||
|
self._reader_ready = asyncio.Event(loop=self._loop)
|
||||||
|
self._writer_ready = asyncio.Event(loop=self._loop)
|
||||||
|
|
||||||
self.state = SessionState.NEW
|
self.state = SessionState.NEW
|
||||||
self.reader = None
|
self.reader = None
|
||||||
self.writer = None
|
self.writer = None
|
||||||
|
@ -41,3 +52,69 @@ class Session:
|
||||||
def next_packet_id(self):
|
def next_packet_id(self):
|
||||||
self._packet_id += 1
|
self._packet_id += 1
|
||||||
return self._packet_id
|
return self._packet_id
|
||||||
|
|
||||||
|
@asyncio.coroutine
|
||||||
|
def start(self):
|
||||||
|
self._running = True
|
||||||
|
self._reader_task = asyncio.async(self._reader_coro(), loop=self._loop)
|
||||||
|
self._writer_task = asyncio.async(self._writer_coro(), loop=self._loop)
|
||||||
|
yield from asyncio.wait([self._reader_ready.wait(), self._writer_ready.wait()], loop=self._loop)
|
||||||
|
self.logger.debug("Handler tasks started")
|
||||||
|
|
||||||
|
@asyncio.coroutine
|
||||||
|
def stop(self):
|
||||||
|
self._running = False
|
||||||
|
yield from asyncio.wait([self._writer_task], loop=self._loop)
|
||||||
|
|
||||||
|
|
||||||
|
@asyncio.coroutine
|
||||||
|
def _reader_coro(self):
|
||||||
|
self.logger.debug("Starting reader coro")
|
||||||
|
while self._running:
|
||||||
|
try:
|
||||||
|
self._reader_ready.set()
|
||||||
|
fixed_header = yield from asyncio.wait_for(MQTTFixedHeader.from_stream(self.reader), 5)
|
||||||
|
if fixed_header:
|
||||||
|
cls = packet_class(fixed_header)
|
||||||
|
packet = yield from cls.from_stream(self.reader, fixed_header=fixed_header)
|
||||||
|
yield from self.incoming_queues[packet.fixed_header.packet_type].put(packet)
|
||||||
|
else:
|
||||||
|
self.logger.debug("No data")
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
self.logger.warn("Input stream read timeout")
|
||||||
|
except NoDataException as nde:
|
||||||
|
self.logger.debug("No data available")
|
||||||
|
except BaseException as e:
|
||||||
|
self.logger.warn("Exception in reader coro: %s" % e)
|
||||||
|
break
|
||||||
|
self.logger.debug("Reader coro stopped")
|
||||||
|
|
||||||
|
|
||||||
|
@asyncio.coroutine
|
||||||
|
def _writer_coro(self):
|
||||||
|
self.logger.debug("Starting writer coro")
|
||||||
|
out_queue = self.outgoing_queue
|
||||||
|
packet = None
|
||||||
|
while self._running:
|
||||||
|
try:
|
||||||
|
self._writer_ready.set()
|
||||||
|
packet = yield from asyncio.wait_for(out_queue.get(), 5)
|
||||||
|
self.logger.debug(packet)
|
||||||
|
yield from packet.to_stream(self.writer)
|
||||||
|
yield from self.writer.drain()
|
||||||
|
except asyncio.TimeoutError as ce:
|
||||||
|
self.logger.warn("Output queue get timeout")
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.warn("Exception in writer coro: %s" % e)
|
||||||
|
break
|
||||||
|
self.logger.debug("Writer coro stopping")
|
||||||
|
# Flush queue before stopping
|
||||||
|
if not out_queue.empty():
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
packet = out_queue.get_nowait()
|
||||||
|
self.logger.debug(packet)
|
||||||
|
yield from packet.to_stream(self.writer)
|
||||||
|
except asyncio.QueueEmpty:
|
||||||
|
break
|
||||||
|
self.logger.debug("Writer coro stopped")
|
||||||
|
|
|
@ -5,10 +5,7 @@ import unittest
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
from hbmqtt.mqtt.connect import ConnectPacket, ConnectVariableHeader, ConnectPayload
|
from hbmqtt.mqtt.connect import ConnectPacket, ConnectVariableHeader, ConnectPayload
|
||||||
from hbmqtt.mqtt.packet import MQTTFixedHeader, PacketType
|
|
||||||
from hbmqtt.errors import MQTTException
|
|
||||||
from hbmqtt.session import Session
|
from hbmqtt.session import Session
|
||||||
from hbmqtt.protocol import ProtocolHandler
|
|
||||||
from hbmqtt.mqtt.packet import PacketType
|
from hbmqtt.mqtt.packet import PacketType
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
@ -35,17 +32,18 @@ class ConnectPacketTest(unittest.TestCase):
|
||||||
|
|
||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
def client():
|
def client():
|
||||||
S = Session()
|
S = Session(loop)
|
||||||
S.reader, S.writer = yield from asyncio.open_connection('127.0.0.1', 8888,
|
S.reader, S.writer = yield from asyncio.open_connection('127.0.0.1', 8888,
|
||||||
loop=loop)
|
loop=loop)
|
||||||
handler = ProtocolHandler(S, loop)
|
yield from S.start()
|
||||||
yield from handler.start()
|
|
||||||
incoming_packet = yield from S.incoming_queues[PacketType.CONNECT].get()
|
incoming_packet = yield from S.incoming_queues[PacketType.CONNECT].get()
|
||||||
handler.stop()
|
S.writer.close()
|
||||||
|
yield from S.stop()
|
||||||
return incoming_packet
|
return incoming_packet
|
||||||
|
|
||||||
packet = loop.run_until_complete(client())
|
packet = loop.run_until_complete(client())
|
||||||
server.close()
|
server.close()
|
||||||
|
loop.stop()
|
||||||
self.assertEquals(packet.fixed_header.packet_type, PacketType.CONNECT)
|
self.assertEquals(packet.fixed_header.packet_type, PacketType.CONNECT)
|
||||||
|
|
||||||
def test_write_loop(self):
|
def test_write_loop(self):
|
||||||
|
@ -62,12 +60,11 @@ class ConnectPacketTest(unittest.TestCase):
|
||||||
|
|
||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
def client():
|
def client():
|
||||||
S = Session()
|
S = Session(loop)
|
||||||
S.reader, S.writer = yield from asyncio.open_connection('127.0.0.1', 8888, loop=loop)
|
S.reader, S.writer = yield from asyncio.open_connection('127.0.0.1', 8888, loop=loop)
|
||||||
handler = ProtocolHandler(S, loop)
|
yield from S.start()
|
||||||
yield from handler.start()
|
|
||||||
yield from S.outgoing_queue.put(test_packet)
|
yield from S.outgoing_queue.put(test_packet)
|
||||||
yield from handler.stop()
|
yield from S.stop()
|
||||||
S.writer.close()
|
S.writer.close()
|
||||||
|
|
||||||
# Start server
|
# Start server
|
Ładowanie…
Reference in New Issue