Manage reader,writer in handler

pull/8/head
Nicolas Jouanin 2015-08-01 21:51:20 +02:00
rodzic 77a96faeeb
commit dd19fbb511
4 zmienionych plików z 38 dodań i 39 usunięć

Wyświetl plik

@ -10,8 +10,9 @@ from transitions import MachineError
from hbmqtt.utils import not_in_dict_or_none from hbmqtt.utils import not_in_dict_or_none
from hbmqtt.session import Session from hbmqtt.session import Session
from hbmqtt.mqtt.connack import ReturnCode from hbmqtt.mqtt.connack import CONNECTION_ACCEPTED
from hbmqtt.mqtt.protocol.client_handler import ClientProtocolHandler from hbmqtt.mqtt.protocol.client_handler import ClientProtocolHandler
from hbmqtt.adapters import StreamReaderAdapter, StreamWriterAdapter
_defaults = { _defaults = {
'keep_alive': 10, 'keep_alive': 10,
@ -189,15 +190,17 @@ class MQTTClient:
@asyncio.coroutine @asyncio.coroutine
def _connect_coro(self): def _connect_coro(self):
try: try:
self.session.reader, self.session.writer = \ conn_reader, conn_writer = \
yield from asyncio.open_connection(self.session.remote_address, self.session.remote_port) yield from asyncio.open_connection(self.session.remote_address, self.session.remote_port)
self._handler = ClientProtocolHandler(loop=self._loop) reader = StreamReaderAdapter(conn_reader)
writer = StreamWriterAdapter(conn_writer)
self._handler = ClientProtocolHandler(reader, writer, loop=self._loop)
self._handler.attach_to_session(self.session) self._handler.attach_to_session(self.session)
yield from self._handler.start() yield from self._handler.start()
return_code = yield from self._handler.mqtt_connect() return_code = yield from self._handler.mqtt_connect()
if return_code is not ReturnCode.CONNECTION_ACCEPTED: if return_code is not CONNECTION_ACCEPTED:
yield from self._handler.stop() yield from self._handler.stop()
self.session.machine.disconnect() self.session.machine.disconnect()
self.logger.warn("Connection rejected with code '%s'" % return_code) self.logger.warn("Connection rejected with code '%s'" % return_code)

Wyświetl plik

@ -1,12 +1,9 @@
# Copyright (c) 2015 Nicolas JOUANIN # Copyright (c) 2015 Nicolas JOUANIN
# #
# See the file license.txt for copying permission. # See the file license.txt for copying permission.
import logging
import asyncio
from asyncio import futures from asyncio import futures
from hbmqtt.mqtt.protocol.handler import ProtocolHandler from hbmqtt.mqtt.protocol.handler import ProtocolHandler
from hbmqtt.mqtt.packet import MQTTFixedHeader from hbmqtt.mqtt.packet import *
from hbmqtt.mqtt.packet import PacketType
from hbmqtt.mqtt.connect import ConnectVariableHeader, ConnectPacket, ConnectPayload from hbmqtt.mqtt.connect import ConnectVariableHeader, ConnectPacket, ConnectPayload
from hbmqtt.mqtt.connack import ConnackPacket from hbmqtt.mqtt.connack import ConnackPacket
from hbmqtt.mqtt.disconnect import DisconnectPacket from hbmqtt.mqtt.disconnect import DisconnectPacket
@ -16,10 +13,12 @@ from hbmqtt.mqtt.subscribe import SubscribePacket
from hbmqtt.mqtt.suback import SubackPacket from hbmqtt.mqtt.suback import SubackPacket
from hbmqtt.mqtt.unsubscribe import UnsubscribePacket from hbmqtt.mqtt.unsubscribe import UnsubscribePacket
from hbmqtt.mqtt.unsuback import UnsubackPacket from hbmqtt.mqtt.unsuback import UnsubackPacket
from hbmqtt.adapters import ReaderAdapter, WriterAdapter
class ClientProtocolHandler(ProtocolHandler): class ClientProtocolHandler(ProtocolHandler):
def __init__(self, loop=None): def __init__(self, reader: ReaderAdapter, writer: WriterAdapter, loop=None):
super().__init__(loop) super().__init__(reader, writer, loop)
self._ping_task = None self._ping_task = None
self._connack_waiter = None self._connack_waiter = None
self._pingresp_queue = asyncio.Queue() self._pingresp_queue = asyncio.Queue()
@ -127,7 +126,7 @@ class ClientProtocolHandler(ProtocolHandler):
else: else:
vh.will_flag = False vh.will_flag = False
header = MQTTFixedHeader(PacketType.CONNECT, 0x00) header = MQTTFixedHeader(CONNECT, 0x00)
packet = ConnectPacket(header, vh, payload) packet = ConnectPacket(header, vh, payload)
return packet return packet

Wyświetl plik

@ -4,10 +4,9 @@
import logging import logging
import asyncio import asyncio
from datetime import datetime from datetime import datetime
from hbmqtt.mqtt.packet import MQTTFixedHeader, MQTTPacket
from hbmqtt.mqtt import packet_class from hbmqtt.mqtt import packet_class
from hbmqtt.errors import NoDataException, HBMQTTException from hbmqtt.errors import NoDataException, HBMQTTException
from hbmqtt.mqtt.packet import PacketType from hbmqtt.mqtt.packet import *
from hbmqtt.mqtt.connack import ConnackPacket from hbmqtt.mqtt.connack import ConnackPacket
from hbmqtt.mqtt.connect import ConnectPacket from hbmqtt.mqtt.connect import ConnectPacket
from hbmqtt.mqtt.pingresp import PingRespPacket from hbmqtt.mqtt.pingresp import PingRespPacket
@ -22,6 +21,7 @@ from hbmqtt.mqtt.subscribe import SubscribePacket
from hbmqtt.mqtt.unsubscribe import UnsubscribePacket from hbmqtt.mqtt.unsubscribe import UnsubscribePacket
from hbmqtt.mqtt.unsuback import UnsubackPacket from hbmqtt.mqtt.unsuback import UnsubackPacket
from hbmqtt.mqtt.disconnect import DisconnectPacket from hbmqtt.mqtt.disconnect import DisconnectPacket
from hbmqtt.adapters import ReaderAdapter, WriterAdapter
from hbmqtt.session import Session from hbmqtt.session import Session
from hbmqtt.specs import * from hbmqtt.specs import *
from hbmqtt.mqtt.protocol.inflight import * from hbmqtt.mqtt.protocol.inflight import *
@ -32,9 +32,11 @@ class ProtocolHandler:
Class implementing the MQTT communication protocol using asyncio features Class implementing the MQTT communication protocol using asyncio features
""" """
def __init__(self, loop=None): def __init__(self, reader: ReaderAdapter, writer: WriterAdapter, loop=None):
self.logger = logging.getLogger(__name__) self.logger = logging.getLogger(__name__)
self.session = None self.session = None
self.reader = reader
self.writer = writer
if loop is None: if loop is None:
self._loop = asyncio.get_event_loop() self._loop = asyncio.get_event_loop()
else: else:
@ -52,9 +54,6 @@ class ProtocolHandler:
def attach_to_session(self, session: Session): def attach_to_session(self, session: Session):
self.session = session self.session = session
self.session.handler = self self.session.handler = self
extra_info = self.session.writer.get_extra_info('sockname')
self.session.local_address = extra_info[0]
self.session.local_port = extra_info[1]
def detach_from_session(self): def detach_from_session(self):
self.session.handler = None self.session.handler = None
@ -124,7 +123,7 @@ class ProtocolHandler:
@asyncio.coroutine @asyncio.coroutine
def stop(self): def stop(self):
self._running = False self._running = False
self.session.reader.feed_eof() #self.session.reader.feed_eof()
yield from self.outgoing_queue.put("STOP") yield from self.outgoing_queue.put("STOP")
yield from asyncio.wait([self._writer_task, self._reader_task], loop=self._loop) yield from asyncio.wait([self._writer_task, self._reader_task], loop=self._loop)
# Stop incoming messages flow waiter # Stop incoming messages flow waiter
@ -142,40 +141,40 @@ class ProtocolHandler:
keepalive_timeout = self.session.keep_alive keepalive_timeout = self.session.keep_alive
if keepalive_timeout <= 0: if keepalive_timeout <= 0:
keepalive_timeout = None keepalive_timeout = None
fixed_header = yield from asyncio.wait_for(MQTTFixedHeader.from_stream(self.session.reader), keepalive_timeout) fixed_header = yield from asyncio.wait_for(MQTTFixedHeader.from_stream(self.reader), keepalive_timeout)
if fixed_header: if fixed_header:
cls = packet_class(fixed_header) cls = packet_class(fixed_header)
packet = yield from cls.from_stream(self.session.reader, fixed_header=fixed_header) packet = yield from cls.from_stream(self.reader, fixed_header=fixed_header)
self.logger.debug("%s <-in-- %s" % (self.session.client_id, repr(packet))) self.logger.debug("%s <-in-- %s" % (self.session.client_id, repr(packet)))
task = None task = None
if packet.fixed_header.packet_type == PacketType.CONNACK: if packet.fixed_header.packet_type == CONNACK:
task = asyncio.Task(self.handle_connack(packet)) task = asyncio.Task(self.handle_connack(packet))
elif packet.fixed_header.packet_type == PacketType.SUBSCRIBE: elif packet.fixed_header.packet_type == SUBSCRIBE:
task = asyncio.Task(self.handle_subscribe(packet)) task = asyncio.Task(self.handle_subscribe(packet))
elif packet.fixed_header.packet_type == PacketType.UNSUBSCRIBE: elif packet.fixed_header.packet_type == UNSUBSCRIBE:
task = asyncio.Task(self.handle_unsubscribe(packet)) task = asyncio.Task(self.handle_unsubscribe(packet))
elif packet.fixed_header.packet_type == PacketType.SUBACK: elif packet.fixed_header.packet_type == SUBACK:
task = asyncio.Task(self.handle_suback(packet)) task = asyncio.Task(self.handle_suback(packet))
elif packet.fixed_header.packet_type == PacketType.UNSUBACK: elif packet.fixed_header.packet_type == UNSUBACK:
task = asyncio.Task(self.handle_unsuback(packet)) task = asyncio.Task(self.handle_unsuback(packet))
elif packet.fixed_header.packet_type == PacketType.PUBACK: elif packet.fixed_header.packet_type == PUBACK:
task = asyncio.Task(self.handle_puback(packet)) task = asyncio.Task(self.handle_puback(packet))
elif packet.fixed_header.packet_type == PacketType.PUBREC: elif packet.fixed_header.packet_type == PUBREC:
task = asyncio.Task(self.handle_pubrec(packet)) task = asyncio.Task(self.handle_pubrec(packet))
elif packet.fixed_header.packet_type == PacketType.PUBREL: elif packet.fixed_header.packet_type == PUBREL:
task = asyncio.Task(self.handle_pubrel(packet)) task = asyncio.Task(self.handle_pubrel(packet))
elif packet.fixed_header.packet_type == PacketType.PUBCOMP: elif packet.fixed_header.packet_type == PUBCOMP:
task = asyncio.Task(self.handle_pubcomp(packet)) task = asyncio.Task(self.handle_pubcomp(packet))
elif packet.fixed_header.packet_type == PacketType.PINGREQ: elif packet.fixed_header.packet_type == PINGREQ:
task = asyncio.Task(self.handle_pingreq(packet)) task = asyncio.Task(self.handle_pingreq(packet))
elif packet.fixed_header.packet_type == PacketType.PINGRESP: elif packet.fixed_header.packet_type == PINGRESP:
task = asyncio.Task(self.handle_pingresp(packet)) task = asyncio.Task(self.handle_pingresp(packet))
elif packet.fixed_header.packet_type == PacketType.PUBLISH: elif packet.fixed_header.packet_type == PUBLISH:
task = asyncio.Task(self.handle_publish(packet)) task = asyncio.Task(self.handle_publish(packet))
elif packet.fixed_header.packet_type == PacketType.DISCONNECT: elif packet.fixed_header.packet_type == DISCONNECT:
task = asyncio.Task(self.handle_disconnect(packet)) task = asyncio.Task(self.handle_disconnect(packet))
elif packet.fixed_header.packet_type == PacketType.CONNECT: elif packet.fixed_header.packet_type == CONNECT:
task = asyncio.Task(self.handle_connect(packet)) task = asyncio.Task(self.handle_connect(packet))
else: else:
self.logger.warn("%s Unhandled packet type: %s" % self.logger.warn("%s Unhandled packet type: %s" %
@ -184,7 +183,7 @@ class ProtocolHandler:
# Wait for message handling ends # Wait for message handling ends
asyncio.wait([task]) asyncio.wait([task])
else: else:
self.logger.debug("%s No more data, stopping reader coro" % self.session.client_id) self.logger.debug("%s No more data (EOF received), stopping reader coro" % self.session.client_id)
yield from self.handle_connection_closed() yield from self.handle_connection_closed()
break break
except asyncio.TimeoutError: except asyncio.TimeoutError:
@ -210,9 +209,9 @@ class ProtocolHandler:
if not isinstance(packet, MQTTPacket): if not isinstance(packet, MQTTPacket):
self.logger.debug("%s Writer interruption" % self.session.client_id) self.logger.debug("%s Writer interruption" % self.session.client_id)
break break
yield from packet.to_stream(self.session.writer) yield from packet.to_stream(self.writer)
self.logger.debug("%s -out-> %s" % (self.session.client_id, repr(packet))) self.logger.debug("%s -out-> %s" % (self.session.client_id, repr(packet)))
yield from self.session.writer.drain() yield from self.writer.drain()
except asyncio.TimeoutError as ce: except asyncio.TimeoutError as ce:
self.logger.debug("%s Output queue get timeout" % self.session.client_id) self.logger.debug("%s Output queue get timeout" % self.session.client_id)
if self._running: if self._running:

Wyświetl plik

@ -14,8 +14,6 @@ class Session:
self.writer = None self.writer = None
self.remote_address = None self.remote_address = None
self.remote_port = None self.remote_port = None
self.local_address = None
self.local_port = None
self.client_id = None self.client_id = None
self.clean_session = None self.clean_session = None
self.will_flag = False self.will_flag = False