kopia lustrzana https://github.com/Yakifo/amqtt
Disconnection and message handling refactoring
rodzic
9cc8f01c02
commit
cbdb97aefc
|
@ -213,39 +213,43 @@ class Broker:
|
|||
client_session.machine.connect()
|
||||
handler = BrokerProtocolHandler(self._loop)
|
||||
handler.attach_to_session(client_session)
|
||||
self.logger.debug("Start messages handling")
|
||||
self.logger.debug("%s Start messages handling" % client_session.client_id)
|
||||
yield from handler.start()
|
||||
yield from self.publish_session_retained_messages(client_session)
|
||||
self.logger.debug("Wait for disconnect")
|
||||
self.logger.debug("%s Wait for disconnect" % client_session.client_id)
|
||||
|
||||
connected = True
|
||||
wait_disconnect = asyncio.Task(handler.wait_disconnect())
|
||||
wait_subscription = asyncio.Task(handler.get_next_pending_subscription())
|
||||
wait_unsubscription = asyncio.Task(handler.get_next_pending_unsubscription())
|
||||
wait_deliver = asyncio.Task(handler.mqtt_deliver_next_message())
|
||||
disconnect_event = False
|
||||
while connected:
|
||||
done, pending = yield from asyncio.wait(
|
||||
[wait_disconnect, wait_subscription, wait_unsubscription, wait_deliver],
|
||||
return_when=asyncio.FIRST_COMPLETED)
|
||||
if wait_disconnect in done:
|
||||
result = wait_disconnect.result()
|
||||
self.logger.debug("Result from wait_diconnect: %s" % result)
|
||||
if result is None:
|
||||
self.logger.debug("Will flag: %s" % client_session.will_flag)
|
||||
#Connection closed anormally, send will message
|
||||
if client_session.will_flag:
|
||||
self.logger.debug("Client %s disconnected abnormally, sending will message" %
|
||||
format_client_message(client_session))
|
||||
yield from self.broadcast_application_message(
|
||||
client_session, client_session.will_topic,
|
||||
client_session.will_message,
|
||||
client_session.will_qos)
|
||||
if client_session.will_retain:
|
||||
self.retain_message(client_session,
|
||||
client_session.will_topic,
|
||||
client_session.will_message,
|
||||
client_session.will_qos)
|
||||
connected = False
|
||||
if not disconnect_event:
|
||||
result = wait_disconnect.result()
|
||||
self.logger.debug("%s Result from wait_diconnect: %s" % (client_session.client_id, result))
|
||||
if result is None:
|
||||
self.logger.debug("Will flag: %s" % client_session.will_flag)
|
||||
#Connection closed anormally, send will message
|
||||
if client_session.will_flag:
|
||||
self.logger.debug("Client %s disconnected abnormally, sending will message" %
|
||||
format_client_message(client_session))
|
||||
yield from self.broadcast_application_message(
|
||||
client_session, client_session.will_topic,
|
||||
client_session.will_message,
|
||||
client_session.will_qos)
|
||||
if client_session.will_retain:
|
||||
self.retain_message(client_session,
|
||||
client_session.will_topic,
|
||||
client_session.will_message,
|
||||
client_session.will_qos)
|
||||
disconnect_event = True
|
||||
if not (wait_unsubscription.done() or wait_subscription.done() or wait_deliver.done):
|
||||
connected = False
|
||||
if wait_unsubscription in done:
|
||||
unsubscription = wait_unsubscription.result()
|
||||
for topic in unsubscription['topics']:
|
||||
|
@ -275,7 +279,7 @@ class Broker:
|
|||
wait_unsubscription.cancel()
|
||||
wait_deliver.cancel()
|
||||
|
||||
self.logger.debug("Client disconnecting")
|
||||
self.logger.debug("%s Client disconnecting" % client_session.client_id)
|
||||
try:
|
||||
yield from handler.stop()
|
||||
except Exception as e:
|
||||
|
@ -285,7 +289,7 @@ class Broker:
|
|||
handler = None
|
||||
client_session.machine.disconnect()
|
||||
writer.close()
|
||||
self.logger.debug("Session disconnected")
|
||||
self.logger.debug("%s Session disconnected" % client_session.client_id)
|
||||
|
||||
@asyncio.coroutine
|
||||
def check_connect(self, connect: ConnectPacket):
|
||||
|
@ -306,12 +310,12 @@ class Broker:
|
|||
def retain_message(self, source_session, topic_name, data, qos=None):
|
||||
if data is not None and data != b'':
|
||||
# If retained flag set, store the message for further subscriptions
|
||||
self.logger.debug("Retaining message on topic %s" % topic_name)
|
||||
self.logger.debug("%s Retaining message on topic %s" % (source_session.client_id, topic_name))
|
||||
retained_message = RetainedApplicationMessage(source_session, topic_name, data, qos)
|
||||
self._global_retained_messages[topic_name] = retained_message
|
||||
else:
|
||||
# [MQTT-3.3.1-10]
|
||||
self.logger.debug("Clear retained messages for topic '%s'" % topic_name)
|
||||
self.logger.debug("%s Clear retained messages for topic '%s'" % (source_session.client_id, topic_name))
|
||||
del self._global_retained_messages[topic_name]
|
||||
|
||||
def add_subscription(self, subscription, session):
|
||||
|
@ -399,6 +403,9 @@ class Broker:
|
|||
asyncio.wait(publish_tasks)
|
||||
except Exception as e:
|
||||
self.logger.warn("Message broadcasting failed: %s", e)
|
||||
self.logger.debug("End Broadcasting message from %s on topic %s" %
|
||||
(format_client_message(session=source_session), topic)
|
||||
)
|
||||
|
||||
@asyncio.coroutine
|
||||
def publish_session_retained_messages(self, session):
|
||||
|
|
|
@ -48,7 +48,7 @@ class BrokerProtocolHandler(ProtocolHandler):
|
|||
|
||||
@asyncio.coroutine
|
||||
def handle_disconnect(self, disconnect):
|
||||
if self._disconnect_waiter is not None:
|
||||
if self._disconnect_waiter is not None and not self._disconnect_waiter.done():
|
||||
self._disconnect_waiter.set_result(disconnect)
|
||||
|
||||
@asyncio.coroutine
|
||||
|
@ -59,8 +59,8 @@ class BrokerProtocolHandler(ProtocolHandler):
|
|||
def handle_connect(self, connect: ConnectPacket):
|
||||
# Broker handler shouldn't received CONNECT message during messages handling
|
||||
# as CONNECT messages are managed by the broker on client connection
|
||||
self.logger.error('[MQTT-3.1.0-2] %s : CONNECT message received during messages handling' %
|
||||
(format_client_message(self.session)))
|
||||
self.logger.error('%s [MQTT-3.1.0-2] %s : CONNECT message received during messages handling' %
|
||||
(self.session.client_id, format_client_message(self.session)))
|
||||
if self._disconnect_waiter is not None and not self._disconnect_waiter.done():
|
||||
self._disconnect_waiter.set_result(None)
|
||||
|
||||
|
|
|
@ -97,12 +97,12 @@ class ProtocolHandler:
|
|||
self._writer_task = asyncio.Task(self._writer_coro(), loop=self._loop)
|
||||
yield from asyncio.wait(
|
||||
[self._reader_ready.wait(), self._writer_ready.wait()], loop=self._loop)
|
||||
self.logger.debug("Handler tasks started")
|
||||
self.logger.debug("%s Handler tasks started" % self.session.client_id)
|
||||
|
||||
@asyncio.coroutine
|
||||
def mqtt_publish(self, topic, message, packet_id, dup, qos, retain):
|
||||
if packet_id in self.session.inflight_out:
|
||||
self.logger.warn("A message with the same packet ID is already in flight")
|
||||
self.logger.warn("%s A message with the same packet ID is already in flight" % self.session.client_id)
|
||||
packet = PublishPacket.build(topic, message, packet_id, dup, qos, retain)
|
||||
yield from self.outgoing_queue.put(packet)
|
||||
inflight_message = InFlightMessage(packet, qos)
|
||||
|
@ -151,7 +151,7 @@ class ProtocolHandler:
|
|||
|
||||
@asyncio.coroutine
|
||||
def _reader_coro(self):
|
||||
self.logger.debug("Starting reader coro")
|
||||
self.logger.debug("%s Starting reader coro" % self.session.client_id)
|
||||
while self._running:
|
||||
try:
|
||||
self._reader_ready.set()
|
||||
|
@ -162,55 +162,56 @@ class ProtocolHandler:
|
|||
if fixed_header:
|
||||
cls = packet_class(fixed_header)
|
||||
packet = yield from cls.from_stream(self.session.reader, fixed_header=fixed_header)
|
||||
self.logger.debug(" <-in-- " + repr(packet))
|
||||
self.logger.debug("%s <-in-- %s" % (self.session.client_id, repr(packet)))
|
||||
|
||||
if packet.fixed_header.packet_type == PacketType.CONNACK:
|
||||
asyncio.Task(self.handle_connack(packet))
|
||||
yield from self.handle_connack(packet)
|
||||
elif packet.fixed_header.packet_type == PacketType.SUBSCRIBE:
|
||||
asyncio.Task(self.handle_subscribe(packet))
|
||||
yield from self.handle_subscribe(packet)
|
||||
elif packet.fixed_header.packet_type == PacketType.UNSUBSCRIBE:
|
||||
asyncio.Task(self.handle_unsubscribe(packet))
|
||||
yield from self.handle_unsubscribe(packet)
|
||||
elif packet.fixed_header.packet_type == PacketType.SUBACK:
|
||||
asyncio.Task(self.handle_suback(packet))
|
||||
yield from self.handle_suback(packet)
|
||||
elif packet.fixed_header.packet_type == PacketType.UNSUBACK:
|
||||
asyncio.Task(self.handle_unsuback(packet))
|
||||
yield from self.handle_unsuback(packet)
|
||||
elif packet.fixed_header.packet_type == PacketType.PUBACK:
|
||||
asyncio.Task(self.handle_puback(packet))
|
||||
yield from self.handle_puback(packet)
|
||||
elif packet.fixed_header.packet_type == PacketType.PUBREC:
|
||||
asyncio.Task(self.handle_pubrec(packet))
|
||||
yield from self.handle_pubrec(packet)
|
||||
elif packet.fixed_header.packet_type == PacketType.PUBREL:
|
||||
asyncio.Task(self.handle_pubrel(packet))
|
||||
yield from self.handle_pubrel(packet)
|
||||
elif packet.fixed_header.packet_type == PacketType.PUBCOMP:
|
||||
asyncio.Task(self.handle_pubcomp(packet))
|
||||
yield from self.handle_pubcomp(packet)
|
||||
elif packet.fixed_header.packet_type == PacketType.PINGREQ:
|
||||
asyncio.Task(self.handle_pingreq(packet))
|
||||
yield from self.handle_pingreq(packet)
|
||||
elif packet.fixed_header.packet_type == PacketType.PINGRESP:
|
||||
asyncio.Task(self.handle_pingresp(packet))
|
||||
yield from self.handle_pingresp(packet)
|
||||
elif packet.fixed_header.packet_type == PacketType.PUBLISH:
|
||||
asyncio.Task(self.handle_publish(packet))
|
||||
yield from self.handle_publish(packet)
|
||||
elif packet.fixed_header.packet_type == PacketType.DISCONNECT:
|
||||
asyncio.Task(self.handle_disconnect(packet))
|
||||
yield from self.handle_disconnect(packet)
|
||||
elif packet.fixed_header.packet_type == PacketType.CONNECT:
|
||||
asyncio.Task(self.handle_connect(packet))
|
||||
yield from self.handle_connect(packet)
|
||||
else:
|
||||
self.logger.warn("Unhandled packet type: %s" % packet.fixed_header.packet_type)
|
||||
self.logger.warn("%s Unhandled packet type: %s" %
|
||||
(self.session.client_id, packet.fixed_header.packet_type))
|
||||
else:
|
||||
self.logger.debug("No more data, stopping reader coro")
|
||||
self.logger.debug("%s No more data, stopping reader coro" % self.session.client_id)
|
||||
yield from self.handle_connection_closed()
|
||||
break
|
||||
except asyncio.TimeoutError:
|
||||
self.logger.debug("Input stream read timeout")
|
||||
self.logger.debug("%s Input stream read timeout" % self.session.client_id)
|
||||
self.handle_read_timeout()
|
||||
except NoDataException as nde:
|
||||
self.logger.debug("No data available")
|
||||
self.logger.debug("%s No data available" % self.session.client_id)
|
||||
except Exception as e:
|
||||
self.logger.warn("Unhandled exception in reader coro: %s" % e)
|
||||
self.logger.warn("%s Unhandled exception in reader coro: %s" % (self.session.client_id, e))
|
||||
break
|
||||
self.logger.debug("Reader coro stopped")
|
||||
self.logger.debug("%s Reader coro stopped" % self.session.client_id)
|
||||
|
||||
@asyncio.coroutine
|
||||
def _writer_coro(self):
|
||||
self.logger.debug("Starting writer coro")
|
||||
self.logger.debug("%s Starting writer coro" % self.session.client_id)
|
||||
while self._running:
|
||||
try:
|
||||
self._writer_ready.set()
|
||||
|
@ -219,23 +220,22 @@ class ProtocolHandler:
|
|||
keepalive_timeout = None
|
||||
packet = yield from asyncio.wait_for(self.outgoing_queue.get(), keepalive_timeout)
|
||||
if not isinstance(packet, MQTTPacket):
|
||||
self.logger.debug("Writer interruption")
|
||||
self.logger.debug("%s Writer interruption" % self.session.client_id)
|
||||
break
|
||||
yield from packet.to_stream(self.session.writer)
|
||||
self.logger.debug(" -out-> " + repr(packet))
|
||||
self.logger.debug("%s -out-> %s" % (self.session.client_id, repr(packet)))
|
||||
yield from self.session.writer.drain()
|
||||
#self.outgoing_queue.task_done() # to be used with Python 3.5
|
||||
except asyncio.TimeoutError as ce:
|
||||
self.logger.debug("Output queue get timeout")
|
||||
self.logger.debug("%s Output queue get timeout" % self.session.client_id)
|
||||
if self._running:
|
||||
self.handle_write_timeout()
|
||||
except ConnectionResetError as cre:
|
||||
yield from self.handle_connection_closed()
|
||||
break
|
||||
except Exception as e:
|
||||
self.logger.warn("Unhandled exception in writer coro: %s" % e)
|
||||
self.logger.warn("%sUnhandled exception in writer coro: %s" % (self.session.client_id, e))
|
||||
break
|
||||
self.logger.debug("Writer coro stopping")
|
||||
self.logger.debug("%s Writer coro stopping" % self.session.client_id)
|
||||
# Flush queue before stopping
|
||||
if not self.outgoing_queue.empty():
|
||||
while True:
|
||||
|
@ -244,12 +244,12 @@ class ProtocolHandler:
|
|||
if not isinstance(packet, MQTTPacket):
|
||||
break
|
||||
yield from packet.to_stream(self.session.writer)
|
||||
self.logger.debug(" -out-> " + repr(packet))
|
||||
self.logger.debug("%s -out-> %s" % (self.session.client_id, repr(packet)))
|
||||
except asyncio.QueueEmpty:
|
||||
break
|
||||
except Exception as e:
|
||||
self.logger.warn("Unhandled exception in writer coro: %s" % e)
|
||||
self.logger.debug("Writer coro stopped")
|
||||
self.logger.warn("%s Unhandled exception in writer coro: %s" % (self.session.client_id, e))
|
||||
self.logger.debug("%s Writer coro stopped" % self.session.client_id)
|
||||
|
||||
@asyncio.coroutine
|
||||
def mqtt_deliver_next_message(self):
|
||||
|
@ -257,50 +257,50 @@ class ProtocolHandler:
|
|||
return inflight_message
|
||||
|
||||
def handle_write_timeout(self):
|
||||
self.logger.warn('write timeout unhandled')
|
||||
self.logger.warn('%s write timeout unhandled' % self.session.client_id)
|
||||
|
||||
def handle_read_timeout(self):
|
||||
self.logger.warn('read timeout unhandled')
|
||||
self.logger.warn('%s read timeout unhandled' % self.session.client_id)
|
||||
|
||||
@asyncio.coroutine
|
||||
def handle_connack(self, connack: ConnackPacket):
|
||||
self.logger.warn('CONNACK unhandled')
|
||||
self.logger.warn('%s CONNACK unhandled' % self.session.client_id)
|
||||
|
||||
@asyncio.coroutine
|
||||
def handle_connect(self, connect: ConnectPacket):
|
||||
self.logger.warn('CONNECT unhandled')
|
||||
self.logger.warn('%s CONNECT unhandled' % self.session.client_id)
|
||||
|
||||
@asyncio.coroutine
|
||||
def handle_subscribe(self, subscribe: SubscribePacket):
|
||||
self.logger.warn('SUBSCRIBE unhandled')
|
||||
self.logger.warn('%s SUBSCRIBE unhandled' % self.session.client_id)
|
||||
|
||||
@asyncio.coroutine
|
||||
def handle_unsubscribe(self, subscribe: UnsubscribePacket):
|
||||
self.logger.warn('UNSUBSCRIBE unhandled')
|
||||
self.logger.warn('%s UNSUBSCRIBE unhandled' % self.session.client_id)
|
||||
|
||||
@asyncio.coroutine
|
||||
def handle_suback(self, suback: SubackPacket):
|
||||
self.logger.warn('SUBACK unhandled')
|
||||
self.logger.warn('%s SUBACK unhandled' % self.session.client_id)
|
||||
|
||||
@asyncio.coroutine
|
||||
def handle_unsuback(self, unsuback: UnsubackPacket):
|
||||
self.logger.warn('UNSUBACK unhandled')
|
||||
self.logger.warn('%s UNSUBACK unhandled' % self.session.client_id)
|
||||
|
||||
@asyncio.coroutine
|
||||
def handle_pingresp(self, pingresp: PingRespPacket):
|
||||
self.logger.warn('PINGRESP unhandled')
|
||||
self.logger.warn('%s PINGRESP unhandled' % self.session.client_id)
|
||||
|
||||
@asyncio.coroutine
|
||||
def handle_pingreq(self, pingreq: PingReqPacket):
|
||||
self.logger.warn('PINGREQ unhandled')
|
||||
self.logger.warn('%s PINGREQ unhandled' % self.session.client_id)
|
||||
|
||||
@asyncio.coroutine
|
||||
def handle_disconnect(self, disconnect: DisconnectPacket):
|
||||
self.logger.warn('DISCONNECT unhandled')
|
||||
self.logger.warn('%s DISCONNECT unhandled' % self.session.client_id)
|
||||
|
||||
@asyncio.coroutine
|
||||
def handle_connection_closed(self):
|
||||
self.logger.warn('Connection closed unhandled')
|
||||
self.logger.warn('%s Connection closed unhandled' % self.session.client_id)
|
||||
|
||||
@asyncio.coroutine
|
||||
def handle_puback(self, puback: PubackPacket):
|
||||
|
@ -309,7 +309,8 @@ class ProtocolHandler:
|
|||
waiter = self._puback_waiters[packet_id]
|
||||
waiter.set_result(puback)
|
||||
except KeyError as ke:
|
||||
self.logger.warn("Received PUBACK for unknown pending subscription with Id: %s" % packet_id)
|
||||
self.logger.warn("%s Received PUBACK for unknown pending subscription with Id: %s" %
|
||||
(self.session.client_id, packet_id))
|
||||
|
||||
@asyncio.coroutine
|
||||
def handle_pubrec(self, pubrec: PubrecPacket):
|
||||
|
|
Ładowanie…
Reference in New Issue