pull/8/head
Nico 2015-10-15 21:57:21 +02:00
rodzic 7ede67ed7e
commit 6a876be559
5 zmienionych plików z 27 dodań i 19 usunięć

Wyświetl plik

@ -3,7 +3,7 @@
# See the file license.txt for copying permission. # See the file license.txt for copying permission.
import asyncio import asyncio
from math import ceil from math import ceil
from struct import unpack from struct import pack, unpack
from hbmqtt.errors import NoDataException from hbmqtt.errors import NoDataException
@ -29,18 +29,18 @@ def bytes_to_int(data):
return data return data
def int_to_bytes(int_value: int, length=-1) -> bytes: def int_to_bytes(int_value: int, length: int) -> bytes:
""" """
convert an integer to a sequence of bytes using big endian byte ordering convert an integer to a sequence of bytes using big endian byte ordering
:param int_value: integer value to convert :param int_value: integer value to convert
:param length: (optional) byte length :param length: (optional) byte length
:return: byte sequence :return: byte sequence
""" """
if length == -1: if length == 1:
length = ceil(int_value.bit_length()//8) fmt = "!B"
if length == 0: elif length == 2:
length = 1 fmt = "!H"
return int_value.to_bytes(length, byteorder='big') return pack(fmt, int_value)
@asyncio.coroutine @asyncio.coroutine
@ -66,8 +66,11 @@ def decode_string(reader) -> bytes:
""" """
length_bytes = yield from read_or_raise(reader, 2) length_bytes = yield from read_or_raise(reader, 2)
str_length = unpack("!H", length_bytes) str_length = unpack("!H", length_bytes)
byte_str = yield from read_or_raise(reader, str_length[0]) if str_length[0]:
return byte_str.decode(encoding='utf-8') byte_str = yield from read_or_raise(reader, str_length[0])
return byte_str.decode(encoding='utf-8')
else:
return ''
@asyncio.coroutine @asyncio.coroutine

Wyświetl plik

@ -355,14 +355,14 @@ class ProtocolHandler:
def _reader_loop(self): def _reader_loop(self):
self.logger.debug("%s Starting reader coro" % self.session.client_id) self.logger.debug("%s Starting reader coro" % self.session.client_id)
running_tasks = collections.deque() running_tasks = collections.deque()
keepalive_timeout = self.session.keep_alive
if keepalive_timeout <= 0:
keepalive_timeout = None
while True: while True:
try: try:
self._reader_ready.set() self._reader_ready.set()
while running_tasks and running_tasks[0].done(): while running_tasks and running_tasks[0].done():
running_tasks.popleft() running_tasks.popleft()
keepalive_timeout = self.session.keep_alive
if keepalive_timeout <= 0:
keepalive_timeout = None
fixed_header = yield from asyncio.wait_for( fixed_header = yield from asyncio.wait_for(
MQTTFixedHeader.from_stream(self.reader), MQTTFixedHeader.from_stream(self.reader),
keepalive_timeout, loop=self._loop) keepalive_timeout, loop=self._loop)
@ -440,7 +440,6 @@ class ProtocolHandler:
self._keepalive_task = self._loop.call_later(self.keepalive_timeout, self.handle_write_timeout) self._keepalive_task = self._loop.call_later(self.keepalive_timeout, self.handle_write_timeout)
yield from self.plugins_manager.fire_event(EVENT_MQTT_PACKET_SENT, packet=packet, session=self.session) yield from self.plugins_manager.fire_event(EVENT_MQTT_PACKET_SENT, packet=packet, session=self.session)
self._loop.call_soon(self.on_packet_sent.send, packet)
except ConnectionResetError as cre: except ConnectionResetError as cre:
yield from self.handle_connection_closed() yield from self.handle_connection_closed()
raise raise

Wyświetl plik

@ -18,10 +18,10 @@ class PublishVariableHeader(MQTTVariableHeader):
return type(self).__name__ + '(topic={0}, packet_id={1})'.format(self.topic_name, self.packet_id) return type(self).__name__ + '(topic={0}, packet_id={1})'.format(self.topic_name, self.packet_id)
def to_bytes(self): def to_bytes(self):
out = b'' out = bytearray()
out += encode_string(self.topic_name) out.extend(encode_string(self.topic_name))
if self.packet_id is not None: if self.packet_id is not None:
out += int_to_bytes(self.packet_id, 2) out.extend(int_to_bytes(self.packet_id, 2))
return out return out
@classmethod @classmethod

Wyświetl plik

@ -40,6 +40,12 @@ def test_coro():
if __name__ == '__main__': if __name__ == '__main__':
formatter = "[%(asctime)s] :: %(levelname)s :: %(name)s :: %(message)s" formatter = "[%(asctime)s] :: %(levelname)s :: %(name)s :: %(message)s"
#formatter = "%(asctime)s :: %(levelname)s :: %(message)s" #formatter = "%(asctime)s :: %(levelname)s :: %(message)s"
logging.basicConfig(level=logging.DEBUG, format=formatter) logging.basicConfig(level=logging.INFO, format=formatter)
#import selectors
#selector = selectors.EpollSelector()
#loop = asyncio.SelectorEventLoop(selector)
#asyncio.set_event_loop(loop)
asyncio.get_event_loop().run_until_complete(test_coro()) asyncio.get_event_loop().run_until_complete(test_coro())
asyncio.get_event_loop().run_forever() asyncio.get_event_loop().run_forever()

Wyświetl plik

@ -30,6 +30,6 @@ class SubackPacketTest(unittest.TestCase):
SubackPayload.RETURN_CODE_02, SubackPayload.RETURN_CODE_02,
SubackPayload.RETURN_CODE_80 SubackPayload.RETURN_CODE_80
]) ])
publish = SubackPacket(variable_header=variable_header, payload=payload) suback = SubackPacket(variable_header=variable_header, payload=payload)
out = publish.to_bytes() out = suback.to_bytes()
self.assertEqual(out, b'\x90\x06\x00\x0a\x00\x01\x02\x80') self.assertEqual(out, b'\x90\x06\x00\x0a\x00\x01\x02\x80')