kopia lustrzana https://github.com/Yakifo/amqtt
Do some refactoring
rodzic
05ec91f481
commit
17ea5f8fc2
|
@ -25,7 +25,7 @@ class MessageType(Enum):
|
||||||
def get_message_type(byte):
|
def get_message_type(byte):
|
||||||
return MessageType(byte)
|
return MessageType(byte)
|
||||||
|
|
||||||
class Message:
|
class FixedHeader:
|
||||||
def __init__(self, msg_type, length, dup_flag=False, qos=0, retain_flag=False):
|
def __init__(self, msg_type, length, dup_flag=False, qos=0, retain_flag=False):
|
||||||
if isinstance(msg_type, int):
|
if isinstance(msg_type, int):
|
||||||
enum_type = msg_type
|
enum_type = msg_type
|
||||||
|
@ -35,4 +35,4 @@ class Message:
|
||||||
self.remainingLength = length
|
self.remainingLength = length
|
||||||
self.dupFlag = dup_flag
|
self.dupFlag = dup_flag
|
||||||
self.qos = qos
|
self.qos = qos
|
||||||
self.retain = retain_flag
|
self.retain = retain_flag
|
|
@ -6,19 +6,19 @@ from hbmqtt.utils import (
|
||||||
bytes_to_hex_str,
|
bytes_to_hex_str,
|
||||||
hex_to_int,
|
hex_to_int,
|
||||||
)
|
)
|
||||||
from hbmqtt.message import Message
|
from hbmqtt.fixedheader import FixedHeader
|
||||||
from hbmqtt.streams.errors import FixedHeaderException
|
from hbmqtt.streams.errors import FixedHeaderException
|
||||||
|
|
||||||
class BaseStream:
|
class FixedHeaderStream:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def decode(self, reader):
|
def decode(self, reader):
|
||||||
b1 = yield from reader.read(1)
|
b1 = yield from reader.read(1)
|
||||||
msg_type = BaseStream.get_message_type(b1)
|
msg_type = FixedHeaderStream.get_message_type(b1)
|
||||||
(dup_flag, qos, retain_flag) = BaseStream.get_flags(b1)
|
(dup_flag, qos, retain_flag) = FixedHeaderStream.get_flags(b1)
|
||||||
remain_length = yield from self.decode_remaining_length(b1, reader)
|
remain_length = yield from self.decode_remaining_length(b1, reader)
|
||||||
return Message(msg_type, remain_length, dup_flag, qos, retain_flag)
|
return FixedHeader(msg_type, remain_length, dup_flag, qos, retain_flag)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_message_type(byte):
|
def get_message_type(byte):
|
|
@ -0,0 +1 @@
|
||||||
|
__author__ = 'nico'
|
|
@ -0,0 +1 @@
|
||||||
|
__author__ = 'nico'
|
|
@ -3,19 +3,19 @@
|
||||||
# See the file license.txt for copying permission.
|
# See the file license.txt for copying permission.
|
||||||
import unittest
|
import unittest
|
||||||
import asyncio
|
import asyncio
|
||||||
from hbmqtt.streams.base import BaseStream
|
from hbmqtt.streams.fixed_header import FixedHeaderStream
|
||||||
from hbmqtt.streams.errors import FixedHeaderException
|
from hbmqtt.streams.errors import FixedHeaderException
|
||||||
|
|
||||||
class TestBaseStream(unittest.TestCase):
|
class TestFixedHeader(unittest.TestCase):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.loop = asyncio.new_event_loop()
|
self.loop = asyncio.new_event_loop()
|
||||||
|
|
||||||
def test_get_message_type(self):
|
def test_get_message_type(self):
|
||||||
m_type = BaseStream.get_message_type(b'\x10')
|
m_type = FixedHeaderStream.get_message_type(b'\x10')
|
||||||
self.assertEqual(m_type, 1)
|
self.assertEqual(m_type, 1)
|
||||||
|
|
||||||
def test_get_flags(self):
|
def test_get_flags(self):
|
||||||
(dup_flag, qos, retain_flag) = BaseStream.get_flags(b'\x1f')
|
(dup_flag, qos, retain_flag) = FixedHeaderStream.get_flags(b'\x1f')
|
||||||
self.assertTrue(dup_flag)
|
self.assertTrue(dup_flag)
|
||||||
self.assertEqual(qos, 3)
|
self.assertEqual(qos, 3)
|
||||||
self.assertTrue(retain_flag)
|
self.assertTrue(retain_flag)
|
||||||
|
@ -23,34 +23,34 @@ class TestBaseStream(unittest.TestCase):
|
||||||
def test_decode_remaining_length1(self):
|
def test_decode_remaining_length1(self):
|
||||||
stream = asyncio.StreamReader(loop=self.loop)
|
stream = asyncio.StreamReader(loop=self.loop)
|
||||||
stream.feed_data(b'\x7f')
|
stream.feed_data(b'\x7f')
|
||||||
s = BaseStream()
|
s = FixedHeaderStream()
|
||||||
length = self.loop.run_until_complete(s.decode_remaining_length(stream))
|
length = self.loop.run_until_complete(s.decode_remaining_length(stream))
|
||||||
self.assertEqual(length, 127)
|
self.assertEqual(length, 127)
|
||||||
|
|
||||||
def test_decode_remaining_length2(self):
|
def test_decode_remaining_length2(self):
|
||||||
stream = asyncio.StreamReader(loop=self.loop)
|
stream = asyncio.StreamReader(loop=self.loop)
|
||||||
stream.feed_data(b'\xff\x7f')
|
stream.feed_data(b'\xff\x7f')
|
||||||
s = BaseStream()
|
s = FixedHeaderStream()
|
||||||
length = self.loop.run_until_complete(s.decode_remaining_length(stream))
|
length = self.loop.run_until_complete(s.decode_remaining_length(stream))
|
||||||
self.assertEqual(length, 16383)
|
self.assertEqual(length, 16383)
|
||||||
|
|
||||||
def test_decode_remaining_length3(self):
|
def test_decode_remaining_length3(self):
|
||||||
stream = asyncio.StreamReader(loop=self.loop)
|
stream = asyncio.StreamReader(loop=self.loop)
|
||||||
stream.feed_data(b'\xff\xff\x7f')
|
stream.feed_data(b'\xff\xff\x7f')
|
||||||
s = BaseStream()
|
s = FixedHeaderStream()
|
||||||
length = self.loop.run_until_complete(s.decode_remaining_length(stream))
|
length = self.loop.run_until_complete(s.decode_remaining_length(stream))
|
||||||
self.assertEqual(length, 2097151)
|
self.assertEqual(length, 2097151)
|
||||||
|
|
||||||
def test_decode_remaining_length4(self):
|
def test_decode_remaining_length4(self):
|
||||||
stream = asyncio.StreamReader(loop=self.loop)
|
stream = asyncio.StreamReader(loop=self.loop)
|
||||||
stream.feed_data(b'\xff\xff\xff\x7f')
|
stream.feed_data(b'\xff\xff\xff\x7f')
|
||||||
s = BaseStream()
|
s = FixedHeaderStream()
|
||||||
length = self.loop.run_until_complete(s.decode_remaining_length(stream))
|
length = self.loop.run_until_complete(s.decode_remaining_length(stream))
|
||||||
self.assertEqual(length, 268435455)
|
self.assertEqual(length, 268435455)
|
||||||
|
|
||||||
def test_decode_remaining_length5(self):
|
def test_decode_remaining_length5(self):
|
||||||
stream = asyncio.StreamReader(loop=self.loop)
|
stream = asyncio.StreamReader(loop=self.loop)
|
||||||
stream.feed_data(b'\xff\xff\xff\xff\x7f')
|
stream.feed_data(b'\xff\xff\xff\xff\x7f')
|
||||||
s = BaseStream()
|
s = FixedHeaderStream()
|
||||||
with self.assertRaises(FixedHeaderException):
|
with self.assertRaises(FixedHeaderException):
|
||||||
self.loop.run_until_complete(s.decode_remaining_length(stream))
|
self.loop.run_until_complete(s.decode_remaining_length(stream))
|
Ładowanie…
Reference in New Issue