Implement client connection/disconnection

HBMQTT-13
pull/8/head
Nicolas Jouanin 2015-07-08 21:54:10 +02:00
rodzic f53ae9e10a
commit 15c63dc2e4
3 zmienionych plików z 72 dodań i 3 usunięć

Wyświetl plik

@ -5,6 +5,8 @@ import logging
import asyncio
from transitions import Machine, MachineError
from hbmqtt.session import Session
from hbmqtt.mqtt.protocol.broker_handler import BrokerProtocolHandler
_defaults = {
@ -32,7 +34,7 @@ class Broker:
self._loop = asyncio.get_event_loop()
self._server = None
self._handlers = []
self._init_states()
def _init_states(self):
@ -79,5 +81,20 @@ class Broker:
@asyncio.coroutine
def client_connected(self, reader, writer):
(remote_address, remote_port) = writer.get_extra_info('peername')
self.logger.info(repr(writer.get_extra_info('peername')))
extra_info = writer.get_extra_info('peername')
remote_address = extra_info[0]
remote_port = extra_info[1]
self.logger.debug("Connection from %s:%d" % (remote_address, remote_port))
new_session = Session()
new_session.remote_address = remote_address
new_session.remote_port = remote_port
new_session.reader = reader
new_session.writer = writer
handler = BrokerProtocolHandler(new_session, self._loop)
self._handlers.append(handler)
yield from handler.start()
self.logger.debug("Start messages handling")
yield from handler.wait_disconnect()
self.logger.debug("Wait for disconnect")
yield from handler.stop()

Wyświetl plik

@ -0,0 +1,43 @@
# Copyright (c) 2015 Nicolas JOUANIN
#
# See the file license.txt for copying permission.
import logging
import asyncio
from asyncio import futures
from hbmqtt.mqtt.protocol.handler import ProtocolHandler
from hbmqtt.mqtt.packet import MQTTFixedHeader
from hbmqtt.mqtt.packet import PacketType
from hbmqtt.mqtt.connect import ConnectVariableHeader, ConnectPacket, ConnectPayload
from hbmqtt.mqtt.connack import ConnackPacket
from hbmqtt.mqtt.disconnect import DisconnectPacket
from hbmqtt.mqtt.pingreq import PingReqPacket
from hbmqtt.mqtt.pingresp import PingRespPacket
from hbmqtt.mqtt.subscribe import SubscribePacket
from hbmqtt.mqtt.suback import SubackPacket
from hbmqtt.mqtt.unsubscribe import UnsubscribePacket
from hbmqtt.mqtt.unsuback import UnsubackPacket
from hbmqtt.session import Session
class BrokerProtocolHandler(ProtocolHandler):
def __init__(self, session: Session, loop=None):
super().__init__(session, loop)
self._disconnect_waiter = None
@asyncio.coroutine
def start(self):
yield from super().start()
@asyncio.coroutine
def stop(self):
yield from super().stop()
@asyncio.coroutine
def wait_disconnect(self):
if self._disconnect_waiter is None:
self._disconnect_waiter = futures.Future(loop=self._loop)
yield from self._disconnect_waiter
@asyncio.coroutine
def handle_disconnect(self, disconnect: DisconnectPacket):
if self._disconnect_waiter is not None:
self._disconnect_waiter.set_result(disconnect)

Wyświetl plik

@ -17,6 +17,7 @@ from hbmqtt.mqtt.pubrec import PubrecPacket
from hbmqtt.mqtt.pubcomp import PubcompPacket
from hbmqtt.mqtt.suback import SubackPacket
from hbmqtt.mqtt.unsuback import UnsubackPacket
from hbmqtt.mqtt.disconnect import DisconnectPacket
from hbmqtt.session import Session
from transitions import Machine
@ -63,7 +64,9 @@ class ProtocolHandler:
self._running = False
self.session.local_address, self.session.local_port = self.session.writer.get_extra_info('sockname')
extra_info = self.session.writer.get_extra_info('sockname')
self.session.local_address = extra_info[0]
self.session.local_port = extra_info[1]
self.incoming_queues = dict()
self.application_messages = asyncio.Queue()
@ -164,6 +167,8 @@ class ProtocolHandler:
asyncio.Task(self.handle_pingresp(packet))
elif packet.fixed_header.packet_type == PacketType.PUBLISH:
asyncio.Task(self.handle_publish(packet))
elif packet.fixed_header.packet_type == PacketType.DISCONNECT:
asyncio.Task(self.handle_disconnect(packet))
else:
self.logger.warn("Unhandled packet type: %s" % packet.fixed_header.packet_type)
else:
@ -238,6 +243,10 @@ class ProtocolHandler:
def handle_pingresp(self, pingresp: PingRespPacket):
pass
@asyncio.coroutine
def handle_disconnect(self, disconnect: DisconnectPacket):
pass
@asyncio.coroutine
def handle_puback(self, puback: PubackPacket):
packet_id = puback.variable_header.packet_id