diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..77e81f9 --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +hbmqtt/__pycache__ +*.pyc diff --git a/hbmqtt/handlers/utils.py b/hbmqtt/codecs.py similarity index 97% rename from hbmqtt/handlers/utils.py rename to hbmqtt/codecs.py index 230d614..ed320ba 100644 --- a/hbmqtt/handlers/utils.py +++ b/hbmqtt/codecs.py @@ -5,7 +5,7 @@ import asyncio from asyncio import IncompleteReadError from math import ceil -from hbmqtt.broker.errors import NoDataException +from hbmqtt.handlers.errors import NoDataException def bytes_to_hex_str(data): diff --git a/hbmqtt/errors.py b/hbmqtt/errors.py index e8c0866..b7dac0d 100644 --- a/hbmqtt/errors.py +++ b/hbmqtt/errors.py @@ -9,4 +9,10 @@ class MQTTException(BaseException): """ Base class for all errors refering to MQTT specifications """ - pass \ No newline at end of file + pass + +class CodecException(HandlerException): + """ + Exceptions thrown by packet encode/decode functions + """ + pass diff --git a/hbmqtt/handlers/connack.py b/hbmqtt/handlers/connack.py index 1250dcf..822203b 100644 --- a/hbmqtt/handlers/connack.py +++ b/hbmqtt/handlers/connack.py @@ -4,7 +4,7 @@ from hbmqtt.handlers.packet import ResponsePacketHandler, MQTTHeader, PacketType from hbmqtt.messages.connack import ConnackPacket, ConnackVariableHeader, ReturnCode from hbmqtt.session import Session -from hbmqtt.handlers.utils import * +from hbmqtt.codecs import * from hbmqtt.errors import MQTTException diff --git a/hbmqtt/handlers/connect.py b/hbmqtt/handlers/connect.py index 190078a..5b4bba1 100644 --- a/hbmqtt/handlers/connect.py +++ b/hbmqtt/handlers/connect.py @@ -1,20 +1,22 @@ __author__ = 'nico' -import abc import asyncio + from hbmqtt.handlers.packet import RequestPacketHandler from hbmqtt.handlers.errors import HandlerException, NoDataException from hbmqtt.messages.packet import MQTTHeader, PacketType from hbmqtt.messages.connect import ConnectPacket, ConnectVariableHeader, ConnectPayload from hbmqtt.errors import MQTTException from hbmqtt.session import Session -from hbmqtt.handlers.utils import ( +from hbmqtt.codecs import ( read_or_raise, bytes_to_int, decode_string, int_to_bytes, encode_string ) + + class ConnectHandler(RequestPacketHandler): def __init__(self): super().__init__() diff --git a/hbmqtt/handlers/packet.py b/hbmqtt/handlers/packet.py index abebf15..fc71dcd 100644 --- a/hbmqtt/handlers/packet.py +++ b/hbmqtt/handlers/packet.py @@ -6,16 +6,17 @@ import asyncio import abc import sys import logging + from hbmqtt.messages.packet import MQTTPacket, MQTTHeader, PacketType -from hbmqtt.handlers.utils import int_to_bytes, bytes_to_int, read_or_raise, bytes_to_hex_str +from hbmqtt.codecs import int_to_bytes, bytes_to_int, read_or_raise, bytes_to_hex_str from hbmqtt.handlers.errors import CodecException, HandlerException from hbmqtt.session import Session from hbmqtt.errors import MQTTException -if sys.version_info >= (3,4): - import asyncio.ensure_future as async +if sys.version_info >= (3,4,4): + from asyncio import ensure_future as async else: - import asyncio.async as async + from asyncio import async class PacketHandler(metaclass=abc.ABCMeta): @@ -48,9 +49,8 @@ class PacketHandler(metaclass=abc.ABCMeta): writer.write(encoded_payload) yield from writer.drain() - @staticmethod @asyncio.coroutine - def read_packet_header(reader: asyncio.StreamReader) -> MQTTHeader: + def read_packet_header(self, reader: asyncio.StreamReader) -> MQTTHeader: """ Read and decode MQTT message fixed header from stream :return: FixedHeader instance diff --git a/hbmqtt/messages/connect.py b/hbmqtt/messages/connect.py index 1958351..a485766 100644 --- a/hbmqtt/messages/connect.py +++ b/hbmqtt/messages/connect.py @@ -1,9 +1,12 @@ # Copyright (c) 2015 Nicolas JOUANIN # # See the file license.txt for copying permission. -from hbmqtt.messages.packet import MQTTPacket, MQTTHeader, PacketType +import asyncio +from hbmqtt.messages.packet import MQTTPacket, MQTTFixedHeader, PacketType, MQTTVariableHeader +from hbmqtt.codecs import * +from hbmqtt.errors import MQTTException -class ConnectVariableHeader: +class ConnectVariableHeader(MQTTVariableHeader): USERNAME_FLAG = 0x80 PASSWORD_FLAG = 0x40 WILL_RETAIN_FLAG = 0x20 @@ -86,6 +89,44 @@ class ConnectVariableHeader: self.flags &= (0x00 << 3) self.flags |= (val << 3) + @classmethod + def from_stream(cls, reader: asyncio.StreamReader): + # protocol name + protocol_name = yield from decode_string(reader) + if protocol_name != "MQTT": + raise MQTTException('[MQTT-3.1.2-1] Incorrect protocol name: "%s"' % protocol_name) + + # protocol level (only MQTT 3.1.1 supported) + protocol_level_byte = yield from read_or_raise(reader, 1) + protocol_level = bytes_to_int(protocol_level_byte) + + # flags + flags_byte = yield from read_or_raise(reader, 1) + flags = bytes_to_int(flags_byte) + if flags & 0x01: + raise MQTTException('[MQTT-3.1.2-3] CONNECT reserved flag must be set to 0') + + # keep-alive + keep_alive_byte = yield from read_or_raise(reader, 2) + keep_alive = bytes_to_int(keep_alive_byte) + + return cls(flags, keep_alive, protocol_name, protocol_level) + + def to_bytes(self): + out = b'' + + # Protocol name + out += encode_string(self.proto_name) + # Protocol level + out += int_to_bytes(self.proto_level) + # flags + out += int_to_bytes(self.flags) + # keep alive + out += int_to_bytes(self.keep_alive, 2) + + return out + + class ConnectPayload: def __init__(self, client_id=None, will_topic=None, will_message=None, username=None, password=None): @@ -98,7 +139,7 @@ class ConnectPayload: class ConnectPacket(MQTTPacket): def __init__(self, vh: ConnectVariableHeader, payload: ConnectPayload): - header = MQTTHeader(PacketType.CONNECT, 0x00) + header = MQTTFixedHeader(PacketType.CONNECT, 0x00) super().__init__(header) self.variable_header = vh self.payload = payload diff --git a/hbmqtt/messages/packet.py b/hbmqtt/messages/packet.py index da93530..2344d77 100644 --- a/hbmqtt/messages/packet.py +++ b/hbmqtt/messages/packet.py @@ -3,6 +3,10 @@ # See the file license.txt for copying permission. from enum import Enum +from hbmqtt.errors import CodecException, MQTTException +from hbmqtt.codecs import * +import abc + class PacketType(Enum): RESERVED_0 = 0 CONNECT = 1 @@ -25,7 +29,7 @@ class PacketType(Enum): def get_packet_type(byte): return PacketType(byte) -class MQTTHeader: +class MQTTFixedHeader: def __init__(self, packet_type, flags=0, length=0): if isinstance(packet_type, int): enum_type = packet_type @@ -35,9 +39,155 @@ class MQTTHeader: self.remaining_length = length self.flags = flags + def to_bytes(self): + def encode_remaining_length(length: int): + encoded = b'' + while True: + length_byte = length % 0x80 + length //= 0x80 + if length > 0: + length_byte |= 0x80 + encoded += int_to_bytes(length_byte) + if length <= 0: + break + return encoded + + out = bytes(3) # MQTTHeader are at least 3 bytes long + packet_type = 0 + try: + packet_type = (self.packet_type.value << 4) | self.flags + out += int_to_bytes(packet_type) + except OverflowError: + raise CodecException('packet_type encoding exceed 1 byte length: value=%d', packet_type) + + encoded_length = encode_remaining_length(self.remaining_length) + out += encoded_length + + return out + + @asyncio.coroutine + def to_stream(self, writer: asyncio.StreamWriter): + writer.write(self.to_bytes()) + yield from writer.drain() + + @classmethod + @asyncio.coroutine + def from_stream(cls, reader: asyncio.StreamReader): + """ + Read and decode MQTT message fixed header from stream + :return: FixedHeader instance + """ + def decode_message_type(byte): + byte_type = (bytes_to_int(byte) & 0xf0) >> 4 + return PacketType(byte_type) + + def decode_flags(data): + byte = bytes_to_int(data) + return byte & 0x0f + + @asyncio.coroutine + def decode_remaining_length(): + """ + Decode message length according to MQTT specifications + :return: + """ + multiplier = 1 + value = 0 + length_bytes = b'' + while True: + encoded_byte = yield from read_or_raise(reader, 1) + length_bytes += encoded_byte + int_byte = bytes_to_int(encoded_byte) + value += (int_byte & 0x7f) * multiplier + if (int_byte & 0x80) == 0: + break + else: + multiplier *= 128 + if multiplier > 128*128*128: + raise MQTTException("Invalid remaining length bytes:%s" % bytes_to_hex_str(length_bytes)) + return value + + b1 = yield from read_or_raise(reader, 1) + msg_type = decode_message_type(b1) + if msg_type is PacketType.RESERVED_0 or msg_type is PacketType.RESERVED_15: + raise MQTTException("Usage of control packet type %s is forbidden" % msg_type) + flags = decode_flags(b1) + + remain_length = yield from decode_remaining_length() + return cls(msg_type, flags, remain_length) + + +class MQTTVariableHeader: + def __init__(self): + pass + + @asyncio.coroutine + def to_stream(self, writer: asyncio.StreamWriter): + writer.write(self.to_bytes()) + yield from writer.drain() + + @abc.abstractmethod + def to_bytes(self): + return + + @classmethod + @asyncio.coroutine + @abc.abstractclassmethod + def from_stream(cls, reader: asyncio.StreamReader): + return + + +class MQTTPayload: + def __init__(self): + pass + + @asyncio.coroutine + def to_stream(self, writer: asyncio.StreamWriter): + writer.write(self.to_bytes()) + yield from writer.drain() + + @abc.abstractmethod + def to_bytes(self): + return + + @classmethod + @asyncio.coroutine + @abc.abstractclassmethod + def from_stream(cls, reader: asyncio.StreamReader): + return + class MQTTPacket: - def __init__(self, fixed: MQTTHeader): + def __init__(self, fixed: MQTTFixedHeader, variable_header: MQTTVariableHeader=None, payload: MQTTPayload=None): self.fixed_header = fixed - self.variable_header = None - self.payload = None + self.variable_header = variable_header + self.payload = payload + + @asyncio.coroutine + def to_stream(self, writer: asyncio.StreamWriter): + writer.write(self.to_bytes()) + yield from writer.drain() + + def to_bytes(self): + if self.variable_header: + variable_header_bytes = self.variable_header.to_bytes() + else: + variable_header_bytes = b'' + if self.payload: + payload_bytes = self.payload.to_bytes() + else: + payload_bytes = b'' + + self.fixed_header.remaining_length = len(variable_header_bytes) + len(payload_bytes) + fixed_header_bytes = self.fixed_header.to_bytes() + + return fixed_header_bytes + variable_header_bytes + payload_bytes + + @classmethod + @asyncio.coroutine + def from_stream(cls, reader: asyncio.StreamReader): + fixed_header = yield from MQTTFixedHeader.from_stream(reader) + variable_header = yield from MQTTVariableHeader.from_stream(reader) + payload = yield from MQTTPayload.from_stream(reader) + + return cls(fixed_header, variable_header, payload) diff --git a/tests/codecs/test_utils.py b/tests/codecs/test_utils.py index 84027f0..494f7b9 100644 --- a/tests/codecs/test_utils.py +++ b/tests/codecs/test_utils.py @@ -4,7 +4,7 @@ import unittest import asyncio -from hbmqtt.handlers.utils import ( +from hbmqtt.codecs import ( bytes_to_hex_str, bytes_to_int, decode_string, diff --git a/tests/handlers/__init__.py b/tests/handlers/__init__.py new file mode 100644 index 0000000..e1bd617 --- /dev/null +++ b/tests/handlers/__init__.py @@ -0,0 +1 @@ +__author__ = 'nico' diff --git a/tests/codecs/test_header.py b/tests/handlers/packet.py similarity index 63% rename from tests/codecs/test_header.py rename to tests/handlers/packet.py index 9cbf616..5842b67 100644 --- a/tests/codecs/test_header.py +++ b/tests/handlers/packet.py @@ -4,18 +4,20 @@ import unittest import asyncio -from hbmqtt.codecs.header import MQTTHeaderCodec, MQTTHeaderException +from hbmqtt.handlers.packet import PacketHandler from hbmqtt.messages.packet import PacketType, MQTTHeader +from hbmqtt.errors import MQTTException -class TestMQTTHeaderCodec(unittest.TestCase): +class TestPacketHandler(unittest.TestCase): def setUp(self): self.loop = asyncio.new_event_loop() - def test_decode_ok(self): + def test_read_packet_header(self): + packet_handler = PacketHandler() stream = asyncio.StreamReader(loop=self.loop) stream.feed_data(b'\x10\x7f') - header = self.loop.run_until_complete(MQTTHeaderCodec.decode(stream)) + header = self.loop.run_until_complete(packet_handler.read_packet_header(stream)) self.assertEqual(header.message_type, PacketType.CONNECT) self.assertFalse(header.flags & 0x08) self.assertEqual((header.flags & 0x06) >> 1, 0) @@ -25,8 +27,8 @@ class TestMQTTHeaderCodec(unittest.TestCase): def test_decode_ok_with_length(self): stream = asyncio.StreamReader(loop=self.loop) stream.feed_data(b'\x10\xff\xff\xff\x7f') - header = self.loop.run_until_complete(MQTTHeaderCodec.decode(stream)) - self.assertEqual(header.message_type, PacketType.CONNECT) + header = self.loop.run_until_complete(PacketHandler.read_packet_header(stream)) + self.assertEqual(header.packet_type, PacketType.CONNECT) self.assertFalse(header.flags & 0x08) self.assertEqual((header.flags & 0x06) >> 1, 0) self.assertFalse(header.flags & 0x01) @@ -35,21 +37,21 @@ class TestMQTTHeaderCodec(unittest.TestCase): def test_decode_reserved(self): stream = asyncio.StreamReader(loop=self.loop) stream.feed_data(b'\x0f\x7f') - with self.assertRaises(MQTTHeaderException): - self.loop.run_until_complete(MQTTHeaderCodec.decode(stream)) + with self.assertRaises(MQTTException): + self.loop.run_until_complete(PacketHandler.read_packet_header(stream)) def test_decode_ko_with_length(self): stream = asyncio.StreamReader(loop=self.loop) stream.feed_data(b'\x10\xff\xff\xff\xff\x7f') - with self.assertRaises(MQTTHeaderException): - self.loop.run_until_complete(MQTTHeaderCodec.decode(stream)) + with self.assertRaises(MQTTException): + self.loop.run_until_complete(PacketHandler.read_packet_header(stream)) def test_encode(self): header = MQTTHeader(PacketType.CONNECT, 0x00, 0) - data = MQTTHeaderCodec.encode(header) + data = self.loop.run_until_complete(PacketHandler._encode_fixed_header(header)) self.assertEqual(data, b'\x10\x00') def test_encode_2(self): header = MQTTHeader(PacketType.CONNECT, 0x00, 268435455) - data = MQTTHeaderCodec.encode(header) + data = self.loop.run_until_complete(PacketHandler._encode_fixed_header(header)) self.assertEqual(data, b'\x10\xff\xff\xff\x7f')