Manage reader,writer in handler

pull/8/head
Nicolas Jouanin 2015-08-01 21:51:20 +02:00
rodzic 77a96faeeb
commit dd19fbb511
4 zmienionych plików z 38 dodań i 39 usunięć

Wyświetl plik

@ -10,8 +10,9 @@ from transitions import MachineError
from hbmqtt.utils import not_in_dict_or_none
from hbmqtt.session import Session
from hbmqtt.mqtt.connack import ReturnCode
from hbmqtt.mqtt.connack import CONNECTION_ACCEPTED
from hbmqtt.mqtt.protocol.client_handler import ClientProtocolHandler
from hbmqtt.adapters import StreamReaderAdapter, StreamWriterAdapter
_defaults = {
'keep_alive': 10,
@ -189,15 +190,17 @@ class MQTTClient:
@asyncio.coroutine
def _connect_coro(self):
try:
self.session.reader, self.session.writer = \
conn_reader, conn_writer = \
yield from asyncio.open_connection(self.session.remote_address, self.session.remote_port)
self._handler = ClientProtocolHandler(loop=self._loop)
reader = StreamReaderAdapter(conn_reader)
writer = StreamWriterAdapter(conn_writer)
self._handler = ClientProtocolHandler(reader, writer, loop=self._loop)
self._handler.attach_to_session(self.session)
yield from self._handler.start()
return_code = yield from self._handler.mqtt_connect()
if return_code is not ReturnCode.CONNECTION_ACCEPTED:
if return_code is not CONNECTION_ACCEPTED:
yield from self._handler.stop()
self.session.machine.disconnect()
self.logger.warn("Connection rejected with code '%s'" % return_code)

Wyświetl plik

@ -1,12 +1,9 @@
# Copyright (c) 2015 Nicolas JOUANIN
#
# See the file license.txt for copying permission.
import logging
import asyncio
from asyncio import futures
from hbmqtt.mqtt.protocol.handler import ProtocolHandler
from hbmqtt.mqtt.packet import MQTTFixedHeader
from hbmqtt.mqtt.packet import PacketType
from hbmqtt.mqtt.packet import *
from hbmqtt.mqtt.connect import ConnectVariableHeader, ConnectPacket, ConnectPayload
from hbmqtt.mqtt.connack import ConnackPacket
from hbmqtt.mqtt.disconnect import DisconnectPacket
@ -16,10 +13,12 @@ from hbmqtt.mqtt.subscribe import SubscribePacket
from hbmqtt.mqtt.suback import SubackPacket
from hbmqtt.mqtt.unsubscribe import UnsubscribePacket
from hbmqtt.mqtt.unsuback import UnsubackPacket
from hbmqtt.adapters import ReaderAdapter, WriterAdapter
class ClientProtocolHandler(ProtocolHandler):
def __init__(self, loop=None):
super().__init__(loop)
def __init__(self, reader: ReaderAdapter, writer: WriterAdapter, loop=None):
super().__init__(reader, writer, loop)
self._ping_task = None
self._connack_waiter = None
self._pingresp_queue = asyncio.Queue()
@ -127,7 +126,7 @@ class ClientProtocolHandler(ProtocolHandler):
else:
vh.will_flag = False
header = MQTTFixedHeader(PacketType.CONNECT, 0x00)
header = MQTTFixedHeader(CONNECT, 0x00)
packet = ConnectPacket(header, vh, payload)
return packet

Wyświetl plik

@ -4,10 +4,9 @@
import logging
import asyncio
from datetime import datetime
from hbmqtt.mqtt.packet import MQTTFixedHeader, MQTTPacket
from hbmqtt.mqtt import packet_class
from hbmqtt.errors import NoDataException, HBMQTTException
from hbmqtt.mqtt.packet import PacketType
from hbmqtt.mqtt.packet import *
from hbmqtt.mqtt.connack import ConnackPacket
from hbmqtt.mqtt.connect import ConnectPacket
from hbmqtt.mqtt.pingresp import PingRespPacket
@ -22,6 +21,7 @@ from hbmqtt.mqtt.subscribe import SubscribePacket
from hbmqtt.mqtt.unsubscribe import UnsubscribePacket
from hbmqtt.mqtt.unsuback import UnsubackPacket
from hbmqtt.mqtt.disconnect import DisconnectPacket
from hbmqtt.adapters import ReaderAdapter, WriterAdapter
from hbmqtt.session import Session
from hbmqtt.specs import *
from hbmqtt.mqtt.protocol.inflight import *
@ -32,9 +32,11 @@ class ProtocolHandler:
Class implementing the MQTT communication protocol using asyncio features
"""
def __init__(self, loop=None):
def __init__(self, reader: ReaderAdapter, writer: WriterAdapter, loop=None):
self.logger = logging.getLogger(__name__)
self.session = None
self.reader = reader
self.writer = writer
if loop is None:
self._loop = asyncio.get_event_loop()
else:
@ -52,9 +54,6 @@ class ProtocolHandler:
def attach_to_session(self, session: Session):
self.session = session
self.session.handler = self
extra_info = self.session.writer.get_extra_info('sockname')
self.session.local_address = extra_info[0]
self.session.local_port = extra_info[1]
def detach_from_session(self):
self.session.handler = None
@ -124,7 +123,7 @@ class ProtocolHandler:
@asyncio.coroutine
def stop(self):
self._running = False
self.session.reader.feed_eof()
#self.session.reader.feed_eof()
yield from self.outgoing_queue.put("STOP")
yield from asyncio.wait([self._writer_task, self._reader_task], loop=self._loop)
# Stop incoming messages flow waiter
@ -142,40 +141,40 @@ class ProtocolHandler:
keepalive_timeout = self.session.keep_alive
if keepalive_timeout <= 0:
keepalive_timeout = None
fixed_header = yield from asyncio.wait_for(MQTTFixedHeader.from_stream(self.session.reader), keepalive_timeout)
fixed_header = yield from asyncio.wait_for(MQTTFixedHeader.from_stream(self.reader), keepalive_timeout)
if fixed_header:
cls = packet_class(fixed_header)
packet = yield from cls.from_stream(self.session.reader, fixed_header=fixed_header)
packet = yield from cls.from_stream(self.reader, fixed_header=fixed_header)
self.logger.debug("%s <-in-- %s" % (self.session.client_id, repr(packet)))
task = None
if packet.fixed_header.packet_type == PacketType.CONNACK:
if packet.fixed_header.packet_type == CONNACK:
task = asyncio.Task(self.handle_connack(packet))
elif packet.fixed_header.packet_type == PacketType.SUBSCRIBE:
elif packet.fixed_header.packet_type == SUBSCRIBE:
task = asyncio.Task(self.handle_subscribe(packet))
elif packet.fixed_header.packet_type == PacketType.UNSUBSCRIBE:
elif packet.fixed_header.packet_type == UNSUBSCRIBE:
task = asyncio.Task(self.handle_unsubscribe(packet))
elif packet.fixed_header.packet_type == PacketType.SUBACK:
elif packet.fixed_header.packet_type == SUBACK:
task = asyncio.Task(self.handle_suback(packet))
elif packet.fixed_header.packet_type == PacketType.UNSUBACK:
elif packet.fixed_header.packet_type == UNSUBACK:
task = asyncio.Task(self.handle_unsuback(packet))
elif packet.fixed_header.packet_type == PacketType.PUBACK:
elif packet.fixed_header.packet_type == PUBACK:
task = asyncio.Task(self.handle_puback(packet))
elif packet.fixed_header.packet_type == PacketType.PUBREC:
elif packet.fixed_header.packet_type == PUBREC:
task = asyncio.Task(self.handle_pubrec(packet))
elif packet.fixed_header.packet_type == PacketType.PUBREL:
elif packet.fixed_header.packet_type == PUBREL:
task = asyncio.Task(self.handle_pubrel(packet))
elif packet.fixed_header.packet_type == PacketType.PUBCOMP:
elif packet.fixed_header.packet_type == PUBCOMP:
task = asyncio.Task(self.handle_pubcomp(packet))
elif packet.fixed_header.packet_type == PacketType.PINGREQ:
elif packet.fixed_header.packet_type == PINGREQ:
task = asyncio.Task(self.handle_pingreq(packet))
elif packet.fixed_header.packet_type == PacketType.PINGRESP:
elif packet.fixed_header.packet_type == PINGRESP:
task = asyncio.Task(self.handle_pingresp(packet))
elif packet.fixed_header.packet_type == PacketType.PUBLISH:
elif packet.fixed_header.packet_type == PUBLISH:
task = asyncio.Task(self.handle_publish(packet))
elif packet.fixed_header.packet_type == PacketType.DISCONNECT:
elif packet.fixed_header.packet_type == DISCONNECT:
task = asyncio.Task(self.handle_disconnect(packet))
elif packet.fixed_header.packet_type == PacketType.CONNECT:
elif packet.fixed_header.packet_type == CONNECT:
task = asyncio.Task(self.handle_connect(packet))
else:
self.logger.warn("%s Unhandled packet type: %s" %
@ -184,7 +183,7 @@ class ProtocolHandler:
# Wait for message handling ends
asyncio.wait([task])
else:
self.logger.debug("%s No more data, stopping reader coro" % self.session.client_id)
self.logger.debug("%s No more data (EOF received), stopping reader coro" % self.session.client_id)
yield from self.handle_connection_closed()
break
except asyncio.TimeoutError:
@ -210,9 +209,9 @@ class ProtocolHandler:
if not isinstance(packet, MQTTPacket):
self.logger.debug("%s Writer interruption" % self.session.client_id)
break
yield from packet.to_stream(self.session.writer)
yield from packet.to_stream(self.writer)
self.logger.debug("%s -out-> %s" % (self.session.client_id, repr(packet)))
yield from self.session.writer.drain()
yield from self.writer.drain()
except asyncio.TimeoutError as ce:
self.logger.debug("%s Output queue get timeout" % self.session.client_id)
if self._running:

Wyświetl plik

@ -14,8 +14,6 @@ class Session:
self.writer = None
self.remote_address = None
self.remote_port = None
self.local_address = None
self.local_port = None
self.client_id = None
self.clean_session = None
self.will_flag = False