kopia lustrzana https://github.com/Yakifo/amqtt
pull/8/head
rodzic
3f7e3a2801
commit
cc0454d335
|
@ -199,7 +199,7 @@ class ConnectPacket(MQTTPacket):
|
|||
VARIABLE_HEADER = ConnectVariableHeader
|
||||
PAYLOAD = ConnectPayload
|
||||
|
||||
def __init__(self, fixed: MQTTFixedHeader, vh: ConnectVariableHeader, payload: ConnectPayload):
|
||||
def __init__(self, fixed: MQTTFixedHeader=None, vh: ConnectVariableHeader=None, payload: ConnectPayload=None):
|
||||
if fixed is None:
|
||||
header = MQTTFixedHeader(PacketType.CONNECT, 0x00)
|
||||
else:
|
||||
|
|
|
@ -12,17 +12,23 @@ class ProtocolHandler:
|
|||
"""
|
||||
Class implementing the MQTT communication protocol using asyncio features
|
||||
"""
|
||||
def __init__(self, session: Session, loop: asyncio.BaseEventLoop):
|
||||
def __init__(self, session: Session, loop):
|
||||
self.logger = logging.getLogger(__name__)
|
||||
self.session = session
|
||||
self._loop = loop
|
||||
self._reader_task = None
|
||||
self._writer_task = None
|
||||
self._reader_ready = asyncio.Event(loop=self._loop)
|
||||
self._writer_ready = asyncio.Event(loop=self._loop)
|
||||
|
||||
@asyncio.coroutine
|
||||
def start(self):
|
||||
self._reader_task = asyncio.async(self._writer_coro(), loop=self._loop)
|
||||
self._writer_task = asyncio.async(self._reader_coro(), loop=self._loop)
|
||||
self._reader_task = asyncio.async(self._reader_coro(), loop=self._loop)
|
||||
self._writer_task = asyncio.async(self._writer_coro(), loop=self._loop)
|
||||
yield from asyncio.wait([self._reader_ready.wait(), self._writer_ready.wait()], loop=self._loop)
|
||||
self.logger.debug("Handler tasks started")
|
||||
|
||||
@asyncio.coroutine
|
||||
def stop(self):
|
||||
self._reader_task.cancel()
|
||||
self._writer_task.cancel()
|
||||
|
@ -32,12 +38,13 @@ class ProtocolHandler:
|
|||
self.logger.debug("Starting reader coro")
|
||||
while True:
|
||||
try:
|
||||
self._reader_ready.set()
|
||||
fixed_header = yield from MQTTFixedHeader.from_stream(self.session.reader)
|
||||
cls = packet_class(fixed_header)
|
||||
packet = yield from cls.from_stream(self.session.reader, fixed_header=fixed_header)
|
||||
yield from self.session.incoming_queues[packet.fixed_header.packet_type].put(packet)
|
||||
except asyncio.CancelledError:
|
||||
self.logger.warn("Reader coro stopping")
|
||||
self.logger.warn("Reader coro stopped")
|
||||
break
|
||||
except NoDataException:
|
||||
self.logger.debug("No more data to read")
|
||||
|
@ -52,10 +59,21 @@ class ProtocolHandler:
|
|||
out_queue = self.session.outgoing_queue
|
||||
while True:
|
||||
try:
|
||||
self._writer_ready.set()
|
||||
packet = yield from out_queue.get()
|
||||
self.logger.debug(packet)
|
||||
yield from packet.to_stream(self.session.writer)
|
||||
except asyncio.CancelledError:
|
||||
self.logger.warn("Writer coro stopping")
|
||||
# Flush queue
|
||||
while True:
|
||||
try:
|
||||
packet = out_queue.get_nowait()
|
||||
self.logger.debug(packet)
|
||||
yield from packet.to_stream(self.session.writer)
|
||||
except asyncio.QueueEmpty:
|
||||
break
|
||||
self.logger.warn("Writer coro stopped")
|
||||
break
|
||||
except Exception as e:
|
||||
self.logger.warn("Exception in writer coro: %s" % e)
|
||||
|
|
|
@ -17,6 +17,7 @@ logging.basicConfig(level=logging.DEBUG)
|
|||
class ConnectPacketTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.loop = asyncio.new_event_loop()
|
||||
self.logger = logging.getLogger(__name__)
|
||||
|
||||
def test_read_loop(self):
|
||||
data = b'\x10\x3e\x00\x04MQTT\x04\xce\x00\x00\x00\x0a0123456789\x00\x09WillTopic\x00\x0bWillMessage\x00\x04user\x00\x08password'
|
||||
|
@ -36,11 +37,46 @@ class ConnectPacketTest(unittest.TestCase):
|
|||
S.reader, S.writer = yield from asyncio.open_connection('127.0.0.1', 8888,
|
||||
loop=loop)
|
||||
handler = ProtocolHandler(S, loop)
|
||||
handler.start()
|
||||
yield from handler.start()
|
||||
incoming_packet = yield from S.incoming_queues[PacketType.CONNECT].get()
|
||||
handler.stop()
|
||||
return incoming_packet
|
||||
|
||||
packet = loop.run_until_complete(client())
|
||||
server.close()
|
||||
self.assertEquals(packet.fixed_header.packet_type, PacketType.CONNECT)
|
||||
self.assertEquals(packet.fixed_header.packet_type, PacketType.CONNECT)
|
||||
|
||||
def test_write_loop(self):
|
||||
data_ref = b'\x10\x3e\x00\x04MQTT\x04\xce\x00\x00\x00\x0a0123456789\x00\x09WillTopic\x00\x0bWillMessage\x00\x04user\x00\x08password'
|
||||
packet = None
|
||||
@asyncio.coroutine
|
||||
def serve_test(reader, writer):
|
||||
global packet
|
||||
packet = yield from ConnectPacket.from_stream(reader)
|
||||
self.logger.info("data=" + repr(packet))
|
||||
writer.close()
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
coro = asyncio.start_server(serve_test, '127.0.0.1', 8888, loop=loop)
|
||||
server = loop.run_until_complete(coro)
|
||||
|
||||
S = Session()
|
||||
@asyncio.coroutine
|
||||
def client():
|
||||
S.reader, S.writer = yield from asyncio.open_connection('127.0.0.1', 8888,
|
||||
loop=loop)
|
||||
handler = ProtocolHandler(S, loop)
|
||||
#yield from handler.start()
|
||||
packet = ConnectPacket(vh=ConnectVariableHeader(), payload=ConnectPayload('Id', 'WillTopic', 'WillMessage', 'user', 'password'))
|
||||
self.logger.debug(packet)
|
||||
S.outgoing_queue.put_nowait(packet)
|
||||
#S.outgoing_queue.put_nowait(packet)
|
||||
#yield from S.outgoing_queue.put(packet)
|
||||
self.logger.debug("Messages in queue: %d" % S.outgoing_queue.qsize())
|
||||
yield from handler.stop()
|
||||
|
||||
loop.run_until_complete(client())
|
||||
loop.run_forever()
|
||||
server.close()
|
||||
print(packet)
|
||||
#self.assertEquals(packet.fixed_header.packet_type, PacketType.CONNECT)
|
Ładowanie…
Reference in New Issue