From d3f09dc4edcab26fe9fdbe55e8c932ffa381df74 Mon Sep 17 00:00:00 2001 From: Nicolas Jouanin Date: Thu, 28 May 2015 23:18:42 +0200 Subject: [PATCH] Init message streaming for message encoding/decoding --- hbmqtt/errors.py | 2 +- hbmqtt/message.py | 38 ++++++++++++++++++++++ hbmqtt/streams/__init__.py | 3 ++ hbmqtt/streams/base.py | 51 +++++++++++++++++++++++++++++ hbmqtt/streams/errors.py | 10 ++++++ tests/errors/test_base_stream.py | 56 ++++++++++++++++++++++++++++++++ 6 files changed, 159 insertions(+), 1 deletion(-) create mode 100644 hbmqtt/message.py create mode 100644 hbmqtt/streams/__init__.py create mode 100644 hbmqtt/streams/base.py create mode 100644 hbmqtt/streams/errors.py create mode 100644 tests/errors/test_base_stream.py diff --git a/hbmqtt/errors.py b/hbmqtt/errors.py index 5fa240e..87e4a21 100644 --- a/hbmqtt/errors.py +++ b/hbmqtt/errors.py @@ -3,4 +3,4 @@ # See the file license.txt for copying permission. class BrokerException(BaseException): - pass \ No newline at end of file + pass diff --git a/hbmqtt/message.py b/hbmqtt/message.py new file mode 100644 index 0000000..38a53da --- /dev/null +++ b/hbmqtt/message.py @@ -0,0 +1,38 @@ +# Copyright (c) 2015 Nicolas JOUANIN +# +# See the file license.txt for copying permission. +from enum import Enum + +class MessageType(Enum): + RESERVED_0 = 0 + CONNECT = 1 + CONNACK = 2 + PUBLISH = 3 + PUBACK = 4 + PUBREC = 5 + PUBREL = 6 + PUBCOMP = 7 + SUBSCRIBE = 8 + SUBACK = 9 + UNSUBSCRIBE = 10 + UNSUBACK = 11 + PINGREQ = 12 + PINGRESP = 13 + DISCONNECT = 14 + RESERVED_15 = 15 + + +def get_message_type(byte): + return MessageType(byte) + +class Message: + def __init__(self, msg_type, length, dup_flag=False, qos=0, retain_flag=False): + if isinstance(msg_type, int): + enum_type = msg_type + else: + enum_type = get_message_type(msg_type) + self.message_type = enum_type + self.remainingLength = length + self.dupFlag = dup_flag + self.qos = qos + self.retain = retain_flag \ No newline at end of file diff --git a/hbmqtt/streams/__init__.py b/hbmqtt/streams/__init__.py new file mode 100644 index 0000000..86831ba --- /dev/null +++ b/hbmqtt/streams/__init__.py @@ -0,0 +1,3 @@ +# Copyright (c) 2015 Nicolas JOUANIN +# +# See the file license.txt for copying permission. diff --git a/hbmqtt/streams/base.py b/hbmqtt/streams/base.py new file mode 100644 index 0000000..23a18ff --- /dev/null +++ b/hbmqtt/streams/base.py @@ -0,0 +1,51 @@ +# Copyright (c) 2015 Nicolas JOUANIN +# +# See the file license.txt for copying permission. +import asyncio +from hbmqtt.utils import ( + bytes_to_hex_str, + hex_to_int, +) +from hbmqtt.message import Message +from hbmqtt.streams.errors import FixedHeaderException + +class BaseStream: + def __init__(self): + pass + + def decode(self, reader): + b1 = yield from reader.read(1) + msg_type = BaseStream.get_message_type(b1) + (dup_flag, qos, retain_flag) = BaseStream.get_flags(b1) + remain_length = yield from self.decode_remaining_length(b1, reader) + return Message(msg_type, remain_length, dup_flag, qos, retain_flag) + + @staticmethod + def get_message_type(byte): + return (hex_to_int(byte) & 0xf0) >> 4 + + @staticmethod + def get_flags(data): + byte = hex_to_int(data) + b3 = True if (byte & 0x08) >> 3 else False + b21 = (byte & 0x06) >> 1 + b0 = True if (byte & 0x01) else False + return b3, b21, b0 + + @asyncio.coroutine + def decode_remaining_length(self, reader): + multiplier = 1 + value = 0 + length_bytes = b'' + while True: + encoded_byte = yield from reader.read(1) + length_bytes += encoded_byte + int_byte = hex_to_int(encoded_byte) + value += (int_byte & 0x7f) * multiplier + if (int_byte & 0x80) == 0: + break + else: + multiplier *= 128 + if multiplier > 128*128*128: + raise FixedHeaderException("Invalid remaining length bytes:%s" % bytes_to_hex_str(length_bytes)) + return value diff --git a/hbmqtt/streams/errors.py b/hbmqtt/streams/errors.py new file mode 100644 index 0000000..6b2536c --- /dev/null +++ b/hbmqtt/streams/errors.py @@ -0,0 +1,10 @@ +# Copyright (c) 2015 Nicolas JOUANIN +# +# See the file license.txt for copying permission. + +class StreamException(BaseException): + pass + + +class FixedHeaderException(StreamException): + pass diff --git a/tests/errors/test_base_stream.py b/tests/errors/test_base_stream.py new file mode 100644 index 0000000..440a6bc --- /dev/null +++ b/tests/errors/test_base_stream.py @@ -0,0 +1,56 @@ +# Copyright (c) 2015 Nicolas JOUANIN +# +# See the file license.txt for copying permission. +import unittest +import asyncio +from hbmqtt.streams.base import BaseStream +from hbmqtt.streams.errors import FixedHeaderException + +class TestBaseStream(unittest.TestCase): + def setUp(self): + self.loop = asyncio.new_event_loop() + + def test_get_message_type(self): + m_type = BaseStream.get_message_type(b'\x10') + self.assertEqual(m_type, 1) + + def test_get_flags(self): + (dup_flag, qos, retain_flag) = BaseStream.get_flags(b'\x1f') + self.assertTrue(dup_flag) + self.assertEqual(qos, 3) + self.assertTrue(retain_flag) + + def test_decode_remaining_length1(self): + stream = asyncio.StreamReader(loop=self.loop) + stream.feed_data(b'\x7f') + s = BaseStream() + length = self.loop.run_until_complete(s.decode_remaining_length(stream)) + self.assertEqual(length, 127) + + def test_decode_remaining_length2(self): + stream = asyncio.StreamReader(loop=self.loop) + stream.feed_data(b'\xff\x7f') + s = BaseStream() + length = self.loop.run_until_complete(s.decode_remaining_length(stream)) + self.assertEqual(length, 16383) + + def test_decode_remaining_length3(self): + stream = asyncio.StreamReader(loop=self.loop) + stream.feed_data(b'\xff\xff\x7f') + s = BaseStream() + length = self.loop.run_until_complete(s.decode_remaining_length(stream)) + self.assertEqual(length, 2097151) + + def test_decode_remaining_length4(self): + stream = asyncio.StreamReader(loop=self.loop) + stream.feed_data(b'\xff\xff\xff\x7f') + s = BaseStream() + length = self.loop.run_until_complete(s.decode_remaining_length(stream)) + self.assertEqual(length, 268435455) + + def test_decode_remaining_length5(self): + stream = asyncio.StreamReader(loop=self.loop) + stream.feed_data(b'\xff\xff\xff\xff\x7f') + s = BaseStream() + with self.assertRaises(FixedHeaderException): + self.loop.run_until_complete(s.decode_remaining_length(stream))