Do some refactoring

pull/8/head
Nicolas Jouanin 2015-05-29 15:28:36 +02:00
rodzic 05ec91f481
commit 17ea5f8fc2
5 zmienionych plików z 18 dodań i 16 usunięć

Wyświetl plik

@ -25,7 +25,7 @@ class MessageType(Enum):
def get_message_type(byte): def get_message_type(byte):
return MessageType(byte) return MessageType(byte)
class Message: class FixedHeader:
def __init__(self, msg_type, length, dup_flag=False, qos=0, retain_flag=False): def __init__(self, msg_type, length, dup_flag=False, qos=0, retain_flag=False):
if isinstance(msg_type, int): if isinstance(msg_type, int):
enum_type = msg_type enum_type = msg_type
@ -35,4 +35,4 @@ class Message:
self.remainingLength = length self.remainingLength = length
self.dupFlag = dup_flag self.dupFlag = dup_flag
self.qos = qos self.qos = qos
self.retain = retain_flag self.retain = retain_flag

Wyświetl plik

@ -6,19 +6,19 @@ from hbmqtt.utils import (
bytes_to_hex_str, bytes_to_hex_str,
hex_to_int, hex_to_int,
) )
from hbmqtt.message import Message from hbmqtt.fixedheader import FixedHeader
from hbmqtt.streams.errors import FixedHeaderException from hbmqtt.streams.errors import FixedHeaderException
class BaseStream: class FixedHeaderStream:
def __init__(self): def __init__(self):
pass pass
def decode(self, reader): def decode(self, reader):
b1 = yield from reader.read(1) b1 = yield from reader.read(1)
msg_type = BaseStream.get_message_type(b1) msg_type = FixedHeaderStream.get_message_type(b1)
(dup_flag, qos, retain_flag) = BaseStream.get_flags(b1) (dup_flag, qos, retain_flag) = FixedHeaderStream.get_flags(b1)
remain_length = yield from self.decode_remaining_length(b1, reader) remain_length = yield from self.decode_remaining_length(b1, reader)
return Message(msg_type, remain_length, dup_flag, qos, retain_flag) return FixedHeader(msg_type, remain_length, dup_flag, qos, retain_flag)
@staticmethod @staticmethod
def get_message_type(byte): def get_message_type(byte):

Wyświetl plik

@ -0,0 +1 @@
__author__ = 'nico'

Wyświetl plik

@ -0,0 +1 @@
__author__ = 'nico'

Wyświetl plik

@ -3,19 +3,19 @@
# See the file license.txt for copying permission. # See the file license.txt for copying permission.
import unittest import unittest
import asyncio import asyncio
from hbmqtt.streams.base import BaseStream from hbmqtt.streams.fixed_header import FixedHeaderStream
from hbmqtt.streams.errors import FixedHeaderException from hbmqtt.streams.errors import FixedHeaderException
class TestBaseStream(unittest.TestCase): class TestFixedHeader(unittest.TestCase):
def setUp(self): def setUp(self):
self.loop = asyncio.new_event_loop() self.loop = asyncio.new_event_loop()
def test_get_message_type(self): def test_get_message_type(self):
m_type = BaseStream.get_message_type(b'\x10') m_type = FixedHeaderStream.get_message_type(b'\x10')
self.assertEqual(m_type, 1) self.assertEqual(m_type, 1)
def test_get_flags(self): def test_get_flags(self):
(dup_flag, qos, retain_flag) = BaseStream.get_flags(b'\x1f') (dup_flag, qos, retain_flag) = FixedHeaderStream.get_flags(b'\x1f')
self.assertTrue(dup_flag) self.assertTrue(dup_flag)
self.assertEqual(qos, 3) self.assertEqual(qos, 3)
self.assertTrue(retain_flag) self.assertTrue(retain_flag)
@ -23,34 +23,34 @@ class TestBaseStream(unittest.TestCase):
def test_decode_remaining_length1(self): def test_decode_remaining_length1(self):
stream = asyncio.StreamReader(loop=self.loop) stream = asyncio.StreamReader(loop=self.loop)
stream.feed_data(b'\x7f') stream.feed_data(b'\x7f')
s = BaseStream() s = FixedHeaderStream()
length = self.loop.run_until_complete(s.decode_remaining_length(stream)) length = self.loop.run_until_complete(s.decode_remaining_length(stream))
self.assertEqual(length, 127) self.assertEqual(length, 127)
def test_decode_remaining_length2(self): def test_decode_remaining_length2(self):
stream = asyncio.StreamReader(loop=self.loop) stream = asyncio.StreamReader(loop=self.loop)
stream.feed_data(b'\xff\x7f') stream.feed_data(b'\xff\x7f')
s = BaseStream() s = FixedHeaderStream()
length = self.loop.run_until_complete(s.decode_remaining_length(stream)) length = self.loop.run_until_complete(s.decode_remaining_length(stream))
self.assertEqual(length, 16383) self.assertEqual(length, 16383)
def test_decode_remaining_length3(self): def test_decode_remaining_length3(self):
stream = asyncio.StreamReader(loop=self.loop) stream = asyncio.StreamReader(loop=self.loop)
stream.feed_data(b'\xff\xff\x7f') stream.feed_data(b'\xff\xff\x7f')
s = BaseStream() s = FixedHeaderStream()
length = self.loop.run_until_complete(s.decode_remaining_length(stream)) length = self.loop.run_until_complete(s.decode_remaining_length(stream))
self.assertEqual(length, 2097151) self.assertEqual(length, 2097151)
def test_decode_remaining_length4(self): def test_decode_remaining_length4(self):
stream = asyncio.StreamReader(loop=self.loop) stream = asyncio.StreamReader(loop=self.loop)
stream.feed_data(b'\xff\xff\xff\x7f') stream.feed_data(b'\xff\xff\xff\x7f')
s = BaseStream() s = FixedHeaderStream()
length = self.loop.run_until_complete(s.decode_remaining_length(stream)) length = self.loop.run_until_complete(s.decode_remaining_length(stream))
self.assertEqual(length, 268435455) self.assertEqual(length, 268435455)
def test_decode_remaining_length5(self): def test_decode_remaining_length5(self):
stream = asyncio.StreamReader(loop=self.loop) stream = asyncio.StreamReader(loop=self.loop)
stream.feed_data(b'\xff\xff\xff\xff\x7f') stream.feed_data(b'\xff\xff\xff\xff\x7f')
s = BaseStream() s = FixedHeaderStream()
with self.assertRaises(FixedHeaderException): with self.assertRaises(FixedHeaderException):
self.loop.run_until_complete(s.decode_remaining_length(stream)) self.loop.run_until_complete(s.decode_remaining_length(stream))