Merge pull request #213 from shipmints/misc-fixes

Misc fixes
pull/5/head
Nicolas 2020-04-10 20:47:06 +04:00 zatwierdzone przez GitHub
commit 07c4c70f06
Nie znaleziono w bazie danych klucza dla tego podpisu
ID klucza GPG: 4AEE18F83AFDEB23
6 zmienionych plików z 46 dodań i 25 usunięć

Wyświetl plik

@ -158,12 +158,15 @@ class StreamWriterAdapter(WriterAdapter):
def __init__(self, writer: StreamWriter): def __init__(self, writer: StreamWriter):
self.logger = logging.getLogger(__name__) self.logger = logging.getLogger(__name__)
self._writer = writer self._writer = writer
self.is_closed = False # StreamWriter has no test for closed...we use our own
def write(self, data): def write(self, data):
if not self.is_closed:
self._writer.write(data) self._writer.write(data)
@asyncio.coroutine @asyncio.coroutine
def drain(self): def drain(self):
if not self.is_closed:
yield from self._writer.drain() yield from self._writer.drain()
def get_peer_info(self): def get_peer_info(self):
@ -172,10 +175,14 @@ class StreamWriterAdapter(WriterAdapter):
@asyncio.coroutine @asyncio.coroutine
def close(self): def close(self):
if not self.is_closed:
self.is_closed = True # we first mark this closed so yields below don't cause races with waiting writes
yield from self._writer.drain() yield from self._writer.drain()
if self._writer.can_write_eof(): if self._writer.can_write_eof():
self._writer.write_eof() self._writer.write_eof()
self._writer.close() self._writer.close()
try: yield from self._writer.wait_closed() # py37+
except AttributeError: pass
class BufferReader(ReaderAdapter): class BufferReader(ReaderAdapter):

Wyświetl plik

@ -692,7 +692,9 @@ class Broker:
try: try:
while True: while True:
while running_tasks and running_tasks[0].done(): while running_tasks and running_tasks[0].done():
running_tasks.popleft() task = running_tasks.popleft()
try: task.result() # make asyncio happy and collect results
except Exception: pass
broadcast = yield from self._broadcast_queue.get() broadcast = yield from self._broadcast_queue.get()
if self.logger.isEnabledFor(logging.DEBUG): if self.logger.isEnabledFor(logging.DEBUG):
self.logger.debug("broadcasting %r" % broadcast) self.logger.debug("broadcasting %r" % broadcast)
@ -705,6 +707,7 @@ class Broker:
if 'qos' in broadcast: if 'qos' in broadcast:
qos = broadcast['qos'] qos = broadcast['qos']
if target_session.transitions.state == 'connected': if target_session.transitions.state == 'connected':
if self.logger.isEnabledFor(logging.DEBUG):
self.logger.debug("broadcasting application message from %s on topic '%s' to %s" % self.logger.debug("broadcasting application message from %s on topic '%s' to %s" %
(format_client_message(session=broadcast['session']), (format_client_message(session=broadcast['session']),
broadcast['topic'], format_client_message(session=target_session))) broadcast['topic'], format_client_message(session=target_session)))
@ -713,17 +716,21 @@ class Broker:
handler.mqtt_publish(broadcast['topic'], broadcast['data'], qos, retain=False), handler.mqtt_publish(broadcast['topic'], broadcast['data'], qos, retain=False),
loop=self._loop) loop=self._loop)
running_tasks.append(task) running_tasks.append(task)
else: elif qos is not None and qos > 0:
if self.logger.isEnabledFor(logging.DEBUG):
self.logger.debug("retaining application message from %s on topic '%s' to client '%s'" % self.logger.debug("retaining application message from %s on topic '%s' to client '%s'" %
(format_client_message(session=broadcast['session']), (format_client_message(session=broadcast['session']),
broadcast['topic'], format_client_message(session=target_session))) broadcast['topic'], format_client_message(session=target_session)))
retained_message = RetainedApplicationMessage( retained_message = RetainedApplicationMessage(
broadcast['session'], broadcast['topic'], broadcast['data'], qos) broadcast['session'], broadcast['topic'], broadcast['data'], qos)
yield from target_session.retained_messages.put(retained_message) yield from target_session.retained_messages.put(retained_message)
if self.logger.isEnabledFor(logging.DEBUG):
self.logger.debug(f'target_session.retained_messages={target_session.retained_messages.qsize()}')
except CancelledError: except CancelledError:
# Wait until current broadcasting tasks end # Wait until current broadcasting tasks end
if running_tasks: if running_tasks:
yield from asyncio.wait(running_tasks, loop=self._loop) yield from asyncio.wait(running_tasks, loop=self._loop)
raise # reraise per CancelledError semantics
@asyncio.coroutine @asyncio.coroutine
def _broadcast_message(self, session, topic, data, force_qos=None): def _broadcast_message(self, session, topic, data, force_qos=None):

Wyświetl plik

@ -182,7 +182,7 @@ class MQTTClient:
:return: :return:
""" """
try: try:
while True: while self.client_tasks:
task = self.client_tasks.pop() task = self.client_tasks.pop()
task.cancel() task.cancel()
except IndexError as err: except IndexError as err:
@ -349,16 +349,16 @@ class MQTTClient:
self.client_tasks.append(deliver_task) self.client_tasks.append(deliver_task)
self.logger.debug("Waiting message delivery") self.logger.debug("Waiting message delivery")
done, pending = yield from asyncio.wait([deliver_task], loop=self._loop, return_when=asyncio.FIRST_EXCEPTION, timeout=timeout) done, pending = yield from asyncio.wait([deliver_task], loop=self._loop, return_when=asyncio.FIRST_EXCEPTION, timeout=timeout)
if self.client_tasks:
self.client_tasks.pop()
if deliver_task in done: if deliver_task in done:
if deliver_task.exception() is not None: if deliver_task.exception() is not None:
# deliver_task raised an exception, pass it on to our caller # deliver_task raised an exception, pass it on to our caller
raise deliver_task.exception() raise deliver_task.exception()
self.client_tasks.pop()
return deliver_task.result() return deliver_task.result()
else: else:
#timeout occured before message received #timeout occured before message received
deliver_task.cancel() deliver_task.cancel()
self.client_tasks.pop()
raise asyncio.TimeoutError raise asyncio.TimeoutError
@asyncio.coroutine @asyncio.coroutine
@ -456,7 +456,7 @@ class MQTTClient:
while self.client_tasks: while self.client_tasks:
task = self.client_tasks.popleft() task = self.client_tasks.popleft()
if not task.done(): if not task.done():
task.set_exception(ClientException("Connection lost")) task.cancel()
self.logger.debug("Watch broker disconnection") self.logger.debug("Watch broker disconnection")
# Wait for disconnection from broker (like connection lost) # Wait for disconnection from broker (like connection lost)

Wyświetl plik

@ -49,7 +49,10 @@ def read_or_raise(reader, n=-1):
:param n: number of bytes to read :param n: number of bytes to read
:return: bytes read :return: bytes read
""" """
try:
data = yield from reader.read(n) data = yield from reader.read(n)
except (asyncio.IncompleteReadError, ConnectionResetError, BrokenPipeError):
data = None
if not data: if not data:
raise NoDataException("No more data") raise NoDataException("No more data")
return data return data

Wyświetl plik

@ -417,7 +417,7 @@ class ProtocolHandler:
if task: if task:
running_tasks.append(task) running_tasks.append(task)
else: else:
self.logger.debug("%s No more data (EOF received), stopping reader coro" % self.session.client_id) self.logger.debug("No more data (EOF received), stopping reader coro")
break break
except MQTTException: except MQTTException:
self.logger.debug("Message discarded") self.logger.debug("Message discarded")
@ -425,10 +425,10 @@ class ProtocolHandler:
self.logger.debug("Task cancelled, reader loop ending") self.logger.debug("Task cancelled, reader loop ending")
break break
except asyncio.TimeoutError: except asyncio.TimeoutError:
self.logger.debug("%s Input stream read timeout" % self.session.client_id) self.logger.debug("Input stream read timeout")
self.handle_read_timeout() self.handle_read_timeout()
except NoDataException: except NoDataException:
self.logger.debug("%s No data available" % self.session.client_id) self.logger.debug("No data available")
except BaseException as e: except BaseException as e:
self.logger.warning("%s Unhandled exception in reader coro: %r" % (type(self).__name__, e)) self.logger.warning("%s Unhandled exception in reader coro: %r" % (type(self).__name__, e))
break break
@ -436,7 +436,7 @@ class ProtocolHandler:
running_tasks.popleft().cancel() running_tasks.popleft().cancel()
yield from self.handle_connection_closed() yield from self.handle_connection_closed()
self._reader_stopped.set() self._reader_stopped.set()
self.logger.debug("%s Reader coro stopped" % self.session.client_id) self.logger.debug("Reader coro stopped")
yield from self.stop() yield from self.stop()
@asyncio.coroutine @asyncio.coroutine
@ -449,8 +449,9 @@ class ProtocolHandler:
self._keepalive_task = self._loop.call_later(self.keepalive_timeout, self.handle_write_timeout) self._keepalive_task = self._loop.call_later(self.keepalive_timeout, self.handle_write_timeout)
yield from self.plugins_manager.fire_event(EVENT_MQTT_PACKET_SENT, packet=packet, session=self.session) yield from self.plugins_manager.fire_event(EVENT_MQTT_PACKET_SENT, packet=packet, session=self.session)
except ConnectionResetError as cre: except (ConnectionResetError, BrokenPipeError):
yield from self.handle_connection_closed() yield from self.handle_connection_closed()
except asyncio.CancelledError:
raise raise
except BaseException as e: except BaseException as e:
self.logger.warning("Unhandled exception: %s" % e) self.logger.warning("Unhandled exception: %s" % e)
@ -458,6 +459,8 @@ class ProtocolHandler:
@asyncio.coroutine @asyncio.coroutine
def mqtt_deliver_next_message(self): def mqtt_deliver_next_message(self):
if not self._is_attached():
return None
if self.logger.isEnabledFor(logging.DEBUG): if self.logger.isEnabledFor(logging.DEBUG):
self.logger.debug("%d message(s) available for delivery" % self.session.delivered_message_queue.qsize()) self.logger.debug("%d message(s) available for delivery" % self.session.delivered_message_queue.qsize())
try: try:

Wyświetl plik

@ -149,6 +149,7 @@ class PluginManager:
if wait: if wait:
if tasks: if tasks:
yield from asyncio.wait(tasks, loop=self._loop) yield from asyncio.wait(tasks, loop=self._loop)
if self.logger.isEnabledFor(logging.DEBUG):
self.logger.debug("Plugins len(_fired_events)=%d" % (len(self._fired_events))) self.logger.debug("Plugins len(_fired_events)=%d" % (len(self._fired_events)))
@asyncio.coroutine @asyncio.coroutine