Change flags management

pull/8/head
Nicolas Jouanin 2015-05-29 22:18:52 +02:00
rodzic 3048c1b836
commit a25ba48c3b
3 zmienionych plików z 24 dodań i 20 usunięć

Wyświetl plik

@ -26,13 +26,11 @@ def get_message_type(byte):
return MessageType(byte) return MessageType(byte)
class FixedHeader: class FixedHeader:
def __init__(self, msg_type, length, dup_flag=False, qos=0, retain_flag=False): def __init__(self, msg_type, flags, length):
if isinstance(msg_type, int): if isinstance(msg_type, int):
enum_type = msg_type enum_type = msg_type
else: else:
enum_type = get_message_type(msg_type) enum_type = get_message_type(msg_type)
self.message_type = enum_type self.message_type = enum_type
self.remainingLength = length self.remainingLength = length
self.dup_flag = dup_flag self.flags = flags
self.qos = qos
self.retain_flag = retain_flag

Wyświetl plik

@ -7,7 +7,10 @@ from hbmqtt.utils import (
hex_to_int, hex_to_int,
) )
from hbmqtt.message import FixedHeader, MessageType from hbmqtt.message import FixedHeader, MessageType
from hbmqtt.streams.errors import FixedHeaderException from hbmqtt.streams.errors import StreamException, NoDataException
class FixedHeaderException(StreamException):
pass
class FixedHeaderStream: class FixedHeaderStream:
def __init__(self): def __init__(self):
@ -15,12 +18,15 @@ class FixedHeaderStream:
def decode(self, reader) -> FixedHeader: def decode(self, reader) -> FixedHeader:
b1 = yield from reader.read(1) b1 = yield from reader.read(1)
if not b1:
raise NoDataException
msg_type = FixedHeaderStream.get_message_type(b1) msg_type = FixedHeaderStream.get_message_type(b1)
if msg_type is MessageType.RESERVED_0 or msg_type is MessageType.RESERVED_15: if msg_type is MessageType.RESERVED_0 or msg_type is MessageType.RESERVED_15:
raise FixedHeaderException("Usage of control packet type %s is forbidden" % msg_type) raise FixedHeaderException("Usage of control packet type %s is forbidden" % msg_type)
(dup_flag, qos, retain_flag) = FixedHeaderStream.get_flags(b1) flags = FixedHeaderStream.get_flags(b1)
remain_length = yield from self.decode_remaining_length(reader) remain_length = yield from self.decode_remaining_length(reader)
return FixedHeader(msg_type, remain_length, dup_flag, qos, retain_flag) return FixedHeader(msg_type, flags, remain_length)
@staticmethod @staticmethod
def get_message_type(byte): def get_message_type(byte):
@ -30,10 +36,7 @@ class FixedHeaderStream:
@staticmethod @staticmethod
def get_flags(data): def get_flags(data):
byte = hex_to_int(data) byte = hex_to_int(data)
b3 = True if (byte & 0x08) >> 3 else False return byte & 0x0f
b21 = (byte & 0x06) >> 1
b0 = True if (byte & 0x01) else False
return b3, b21, b0
@asyncio.coroutine @asyncio.coroutine
def decode_remaining_length(self, reader): def decode_remaining_length(self, reader):
@ -42,6 +45,8 @@ class FixedHeaderStream:
length_bytes = b'' length_bytes = b''
while True: while True:
encoded_byte = yield from reader.read(1) encoded_byte = yield from reader.read(1)
if not encoded_byte:
raise NoDataException
length_bytes += encoded_byte length_bytes += encoded_byte
int_byte = hex_to_int(encoded_byte) int_byte = hex_to_int(encoded_byte)
value += (int_byte & 0x7f) * multiplier value += (int_byte & 0x7f) * multiplier

Wyświetl plik

@ -3,8 +3,7 @@
# See the file license.txt for copying permission. # See the file license.txt for copying permission.
import unittest import unittest
import asyncio import asyncio
from hbmqtt.streams.fixed_header import FixedHeaderStream from hbmqtt.streams.fixed_header import FixedHeaderStream, FixedHeaderException
from hbmqtt.streams.errors import FixedHeaderException
from hbmqtt.message import MessageType from hbmqtt.message import MessageType
class TestFixedHeader(unittest.TestCase): class TestFixedHeader(unittest.TestCase):
@ -16,10 +15,12 @@ class TestFixedHeader(unittest.TestCase):
self.assertEqual(m_type, MessageType.CONNECT) self.assertEqual(m_type, MessageType.CONNECT)
def test_get_flags(self): def test_get_flags(self):
(dup_flag, qos, retain_flag) = FixedHeaderStream.get_flags(b'\x1f') flags = FixedHeaderStream.get_flags(b'\x1f')
self.assertTrue(dup_flag) self.assertTrue(flags & 0x08)
self.assertEqual(qos, 3) self.assertTrue(flags & 0x04)
self.assertTrue(retain_flag) self.assertTrue(flags & 0x02)
self.assertTrue(flags & 0x01)
self.assertFalse(flags & 0x10)
def test_decode_remaining_length1(self): def test_decode_remaining_length1(self):
stream = asyncio.StreamReader(loop=self.loop) stream = asyncio.StreamReader(loop=self.loop)
@ -62,9 +63,9 @@ class TestFixedHeader(unittest.TestCase):
s = FixedHeaderStream() s = FixedHeaderStream()
header = self.loop.run_until_complete(s.decode(stream)) header = self.loop.run_until_complete(s.decode(stream))
self.assertEqual(header.message_type, MessageType.CONNECT) self.assertEqual(header.message_type, MessageType.CONNECT)
self.assertFalse(header.dup_flag) self.assertFalse(header.flags & 0x08)
self.assertEqual(header.qos, 0) self.assertEqual((header.flags & 0x06) >> 1 , 0)
self.assertFalse(header.retain_flag) self.assertFalse(header.flags & 0x01)
def test_decode_ko(self): def test_decode_ko(self):
stream = asyncio.StreamReader(loop=self.loop) stream = asyncio.StreamReader(loop=self.loop)