diff --git a/hbmqtt/client.py b/hbmqtt/client.py index 04d2918..91e23f4 100644 --- a/hbmqtt/client.py +++ b/hbmqtt/client.py @@ -10,8 +10,9 @@ from transitions import MachineError from hbmqtt.utils import not_in_dict_or_none from hbmqtt.session import Session -from hbmqtt.mqtt.connack import ReturnCode +from hbmqtt.mqtt.connack import CONNECTION_ACCEPTED from hbmqtt.mqtt.protocol.client_handler import ClientProtocolHandler +from hbmqtt.adapters import StreamReaderAdapter, StreamWriterAdapter _defaults = { 'keep_alive': 10, @@ -189,15 +190,17 @@ class MQTTClient: @asyncio.coroutine def _connect_coro(self): try: - self.session.reader, self.session.writer = \ + conn_reader, conn_writer = \ yield from asyncio.open_connection(self.session.remote_address, self.session.remote_port) - self._handler = ClientProtocolHandler(loop=self._loop) + reader = StreamReaderAdapter(conn_reader) + writer = StreamWriterAdapter(conn_writer) + self._handler = ClientProtocolHandler(reader, writer, loop=self._loop) self._handler.attach_to_session(self.session) yield from self._handler.start() return_code = yield from self._handler.mqtt_connect() - if return_code is not ReturnCode.CONNECTION_ACCEPTED: + if return_code is not CONNECTION_ACCEPTED: yield from self._handler.stop() self.session.machine.disconnect() self.logger.warn("Connection rejected with code '%s'" % return_code) diff --git a/hbmqtt/mqtt/protocol/client_handler.py b/hbmqtt/mqtt/protocol/client_handler.py index e13e49f..fff50f7 100644 --- a/hbmqtt/mqtt/protocol/client_handler.py +++ b/hbmqtt/mqtt/protocol/client_handler.py @@ -1,12 +1,9 @@ # Copyright (c) 2015 Nicolas JOUANIN # # See the file license.txt for copying permission. -import logging -import asyncio from asyncio import futures from hbmqtt.mqtt.protocol.handler import ProtocolHandler -from hbmqtt.mqtt.packet import MQTTFixedHeader -from hbmqtt.mqtt.packet import PacketType +from hbmqtt.mqtt.packet import * from hbmqtt.mqtt.connect import ConnectVariableHeader, ConnectPacket, ConnectPayload from hbmqtt.mqtt.connack import ConnackPacket from hbmqtt.mqtt.disconnect import DisconnectPacket @@ -16,10 +13,12 @@ from hbmqtt.mqtt.subscribe import SubscribePacket from hbmqtt.mqtt.suback import SubackPacket from hbmqtt.mqtt.unsubscribe import UnsubscribePacket from hbmqtt.mqtt.unsuback import UnsubackPacket +from hbmqtt.adapters import ReaderAdapter, WriterAdapter + class ClientProtocolHandler(ProtocolHandler): - def __init__(self, loop=None): - super().__init__(loop) + def __init__(self, reader: ReaderAdapter, writer: WriterAdapter, loop=None): + super().__init__(reader, writer, loop) self._ping_task = None self._connack_waiter = None self._pingresp_queue = asyncio.Queue() @@ -127,7 +126,7 @@ class ClientProtocolHandler(ProtocolHandler): else: vh.will_flag = False - header = MQTTFixedHeader(PacketType.CONNECT, 0x00) + header = MQTTFixedHeader(CONNECT, 0x00) packet = ConnectPacket(header, vh, payload) return packet diff --git a/hbmqtt/mqtt/protocol/handler.py b/hbmqtt/mqtt/protocol/handler.py index d1d00e6..f5f69fe 100644 --- a/hbmqtt/mqtt/protocol/handler.py +++ b/hbmqtt/mqtt/protocol/handler.py @@ -4,10 +4,9 @@ import logging import asyncio from datetime import datetime -from hbmqtt.mqtt.packet import MQTTFixedHeader, MQTTPacket from hbmqtt.mqtt import packet_class from hbmqtt.errors import NoDataException, HBMQTTException -from hbmqtt.mqtt.packet import PacketType +from hbmqtt.mqtt.packet import * from hbmqtt.mqtt.connack import ConnackPacket from hbmqtt.mqtt.connect import ConnectPacket from hbmqtt.mqtt.pingresp import PingRespPacket @@ -22,6 +21,7 @@ from hbmqtt.mqtt.subscribe import SubscribePacket from hbmqtt.mqtt.unsubscribe import UnsubscribePacket from hbmqtt.mqtt.unsuback import UnsubackPacket from hbmqtt.mqtt.disconnect import DisconnectPacket +from hbmqtt.adapters import ReaderAdapter, WriterAdapter from hbmqtt.session import Session from hbmqtt.specs import * from hbmqtt.mqtt.protocol.inflight import * @@ -32,9 +32,11 @@ class ProtocolHandler: Class implementing the MQTT communication protocol using asyncio features """ - def __init__(self, loop=None): + def __init__(self, reader: ReaderAdapter, writer: WriterAdapter, loop=None): self.logger = logging.getLogger(__name__) self.session = None + self.reader = reader + self.writer = writer if loop is None: self._loop = asyncio.get_event_loop() else: @@ -52,9 +54,6 @@ class ProtocolHandler: def attach_to_session(self, session: Session): self.session = session self.session.handler = self - extra_info = self.session.writer.get_extra_info('sockname') - self.session.local_address = extra_info[0] - self.session.local_port = extra_info[1] def detach_from_session(self): self.session.handler = None @@ -124,7 +123,7 @@ class ProtocolHandler: @asyncio.coroutine def stop(self): self._running = False - self.session.reader.feed_eof() + #self.session.reader.feed_eof() yield from self.outgoing_queue.put("STOP") yield from asyncio.wait([self._writer_task, self._reader_task], loop=self._loop) # Stop incoming messages flow waiter @@ -142,40 +141,40 @@ class ProtocolHandler: keepalive_timeout = self.session.keep_alive if keepalive_timeout <= 0: keepalive_timeout = None - fixed_header = yield from asyncio.wait_for(MQTTFixedHeader.from_stream(self.session.reader), keepalive_timeout) + fixed_header = yield from asyncio.wait_for(MQTTFixedHeader.from_stream(self.reader), keepalive_timeout) if fixed_header: cls = packet_class(fixed_header) - packet = yield from cls.from_stream(self.session.reader, fixed_header=fixed_header) + packet = yield from cls.from_stream(self.reader, fixed_header=fixed_header) self.logger.debug("%s <-in-- %s" % (self.session.client_id, repr(packet))) task = None - if packet.fixed_header.packet_type == PacketType.CONNACK: + if packet.fixed_header.packet_type == CONNACK: task = asyncio.Task(self.handle_connack(packet)) - elif packet.fixed_header.packet_type == PacketType.SUBSCRIBE: + elif packet.fixed_header.packet_type == SUBSCRIBE: task = asyncio.Task(self.handle_subscribe(packet)) - elif packet.fixed_header.packet_type == PacketType.UNSUBSCRIBE: + elif packet.fixed_header.packet_type == UNSUBSCRIBE: task = asyncio.Task(self.handle_unsubscribe(packet)) - elif packet.fixed_header.packet_type == PacketType.SUBACK: + elif packet.fixed_header.packet_type == SUBACK: task = asyncio.Task(self.handle_suback(packet)) - elif packet.fixed_header.packet_type == PacketType.UNSUBACK: + elif packet.fixed_header.packet_type == UNSUBACK: task = asyncio.Task(self.handle_unsuback(packet)) - elif packet.fixed_header.packet_type == PacketType.PUBACK: + elif packet.fixed_header.packet_type == PUBACK: task = asyncio.Task(self.handle_puback(packet)) - elif packet.fixed_header.packet_type == PacketType.PUBREC: + elif packet.fixed_header.packet_type == PUBREC: task = asyncio.Task(self.handle_pubrec(packet)) - elif packet.fixed_header.packet_type == PacketType.PUBREL: + elif packet.fixed_header.packet_type == PUBREL: task = asyncio.Task(self.handle_pubrel(packet)) - elif packet.fixed_header.packet_type == PacketType.PUBCOMP: + elif packet.fixed_header.packet_type == PUBCOMP: task = asyncio.Task(self.handle_pubcomp(packet)) - elif packet.fixed_header.packet_type == PacketType.PINGREQ: + elif packet.fixed_header.packet_type == PINGREQ: task = asyncio.Task(self.handle_pingreq(packet)) - elif packet.fixed_header.packet_type == PacketType.PINGRESP: + elif packet.fixed_header.packet_type == PINGRESP: task = asyncio.Task(self.handle_pingresp(packet)) - elif packet.fixed_header.packet_type == PacketType.PUBLISH: + elif packet.fixed_header.packet_type == PUBLISH: task = asyncio.Task(self.handle_publish(packet)) - elif packet.fixed_header.packet_type == PacketType.DISCONNECT: + elif packet.fixed_header.packet_type == DISCONNECT: task = asyncio.Task(self.handle_disconnect(packet)) - elif packet.fixed_header.packet_type == PacketType.CONNECT: + elif packet.fixed_header.packet_type == CONNECT: task = asyncio.Task(self.handle_connect(packet)) else: self.logger.warn("%s Unhandled packet type: %s" % @@ -184,7 +183,7 @@ class ProtocolHandler: # Wait for message handling ends asyncio.wait([task]) else: - self.logger.debug("%s No more data, stopping reader coro" % self.session.client_id) + self.logger.debug("%s No more data (EOF received), stopping reader coro" % self.session.client_id) yield from self.handle_connection_closed() break except asyncio.TimeoutError: @@ -210,9 +209,9 @@ class ProtocolHandler: if not isinstance(packet, MQTTPacket): self.logger.debug("%s Writer interruption" % self.session.client_id) break - yield from packet.to_stream(self.session.writer) + yield from packet.to_stream(self.writer) self.logger.debug("%s -out-> %s" % (self.session.client_id, repr(packet))) - yield from self.session.writer.drain() + yield from self.writer.drain() except asyncio.TimeoutError as ce: self.logger.debug("%s Output queue get timeout" % self.session.client_id) if self._running: diff --git a/hbmqtt/session.py b/hbmqtt/session.py index be3a7d3..63eac68 100644 --- a/hbmqtt/session.py +++ b/hbmqtt/session.py @@ -14,8 +14,6 @@ class Session: self.writer = None self.remote_address = None self.remote_port = None - self.local_address = None - self.local_port = None self.client_id = None self.clean_session = None self.will_flag = False