From 15c63dc2e4b41de1e03a3e3583157f2628b8d508 Mon Sep 17 00:00:00 2001 From: Nicolas Jouanin Date: Wed, 8 Jul 2015 21:54:10 +0200 Subject: [PATCH] Implement client connection/disconnection HBMQTT-13 --- hbmqtt/broker.py | 21 +++++++++++-- hbmqtt/mqtt/protocol/broker_handler.py | 43 ++++++++++++++++++++++++++ hbmqtt/mqtt/protocol/handler.py | 11 ++++++- 3 files changed, 72 insertions(+), 3 deletions(-) create mode 100644 hbmqtt/mqtt/protocol/broker_handler.py diff --git a/hbmqtt/broker.py b/hbmqtt/broker.py index 66a0516..ca0e3e2 100644 --- a/hbmqtt/broker.py +++ b/hbmqtt/broker.py @@ -5,6 +5,8 @@ import logging import asyncio from transitions import Machine, MachineError +from hbmqtt.session import Session +from hbmqtt.mqtt.protocol.broker_handler import BrokerProtocolHandler _defaults = { @@ -32,7 +34,7 @@ class Broker: self._loop = asyncio.get_event_loop() self._server = None - + self._handlers = [] self._init_states() def _init_states(self): @@ -79,5 +81,20 @@ class Broker: @asyncio.coroutine def client_connected(self, reader, writer): - (remote_address, remote_port) = writer.get_extra_info('peername') + self.logger.info(repr(writer.get_extra_info('peername'))) + extra_info = writer.get_extra_info('peername') + remote_address = extra_info[0] + remote_port = extra_info[1] self.logger.debug("Connection from %s:%d" % (remote_address, remote_port)) + new_session = Session() + new_session.remote_address = remote_address + new_session.remote_port = remote_port + new_session.reader = reader + new_session.writer = writer + handler = BrokerProtocolHandler(new_session, self._loop) + self._handlers.append(handler) + yield from handler.start() + self.logger.debug("Start messages handling") + yield from handler.wait_disconnect() + self.logger.debug("Wait for disconnect") + yield from handler.stop() \ No newline at end of file diff --git a/hbmqtt/mqtt/protocol/broker_handler.py b/hbmqtt/mqtt/protocol/broker_handler.py new file mode 100644 index 0000000..5db2c2c --- /dev/null +++ b/hbmqtt/mqtt/protocol/broker_handler.py @@ -0,0 +1,43 @@ +# Copyright (c) 2015 Nicolas JOUANIN +# +# See the file license.txt for copying permission. +import logging +import asyncio +from asyncio import futures +from hbmqtt.mqtt.protocol.handler import ProtocolHandler +from hbmqtt.mqtt.packet import MQTTFixedHeader +from hbmqtt.mqtt.packet import PacketType +from hbmqtt.mqtt.connect import ConnectVariableHeader, ConnectPacket, ConnectPayload +from hbmqtt.mqtt.connack import ConnackPacket +from hbmqtt.mqtt.disconnect import DisconnectPacket +from hbmqtt.mqtt.pingreq import PingReqPacket +from hbmqtt.mqtt.pingresp import PingRespPacket +from hbmqtt.mqtt.subscribe import SubscribePacket +from hbmqtt.mqtt.suback import SubackPacket +from hbmqtt.mqtt.unsubscribe import UnsubscribePacket +from hbmqtt.mqtt.unsuback import UnsubackPacket +from hbmqtt.session import Session + +class BrokerProtocolHandler(ProtocolHandler): + def __init__(self, session: Session, loop=None): + super().__init__(session, loop) + self._disconnect_waiter = None + + @asyncio.coroutine + def start(self): + yield from super().start() + + @asyncio.coroutine + def stop(self): + yield from super().stop() + + @asyncio.coroutine + def wait_disconnect(self): + if self._disconnect_waiter is None: + self._disconnect_waiter = futures.Future(loop=self._loop) + yield from self._disconnect_waiter + + @asyncio.coroutine + def handle_disconnect(self, disconnect: DisconnectPacket): + if self._disconnect_waiter is not None: + self._disconnect_waiter.set_result(disconnect) diff --git a/hbmqtt/mqtt/protocol/handler.py b/hbmqtt/mqtt/protocol/handler.py index 0b17ee6..6551bc9 100644 --- a/hbmqtt/mqtt/protocol/handler.py +++ b/hbmqtt/mqtt/protocol/handler.py @@ -17,6 +17,7 @@ from hbmqtt.mqtt.pubrec import PubrecPacket from hbmqtt.mqtt.pubcomp import PubcompPacket from hbmqtt.mqtt.suback import SubackPacket from hbmqtt.mqtt.unsuback import UnsubackPacket +from hbmqtt.mqtt.disconnect import DisconnectPacket from hbmqtt.session import Session from transitions import Machine @@ -63,7 +64,9 @@ class ProtocolHandler: self._running = False - self.session.local_address, self.session.local_port = self.session.writer.get_extra_info('sockname') + extra_info = self.session.writer.get_extra_info('sockname') + self.session.local_address = extra_info[0] + self.session.local_port = extra_info[1] self.incoming_queues = dict() self.application_messages = asyncio.Queue() @@ -164,6 +167,8 @@ class ProtocolHandler: asyncio.Task(self.handle_pingresp(packet)) elif packet.fixed_header.packet_type == PacketType.PUBLISH: asyncio.Task(self.handle_publish(packet)) + elif packet.fixed_header.packet_type == PacketType.DISCONNECT: + asyncio.Task(self.handle_disconnect(packet)) else: self.logger.warn("Unhandled packet type: %s" % packet.fixed_header.packet_type) else: @@ -238,6 +243,10 @@ class ProtocolHandler: def handle_pingresp(self, pingresp: PingRespPacket): pass + @asyncio.coroutine + def handle_disconnect(self, disconnect: DisconnectPacket): + pass + @asyncio.coroutine def handle_puback(self, puback: PubackPacket): packet_id = puback.variable_header.packet_id