Work in progress (tons of errors remaining)

pull/8/head
Nicolas Jouanin 2015-06-11 23:12:29 +02:00
rodzic 65eb6e68f8
commit 8eff6bcc15
11 zmienionych plików z 235 dodań i 31 usunięć

2
.gitignore vendored 100644
Wyświetl plik

@ -0,0 +1,2 @@
hbmqtt/__pycache__
*.pyc

Wyświetl plik

@ -5,7 +5,7 @@ import asyncio
from asyncio import IncompleteReadError
from math import ceil
from hbmqtt.broker.errors import NoDataException
from hbmqtt.handlers.errors import NoDataException
def bytes_to_hex_str(data):

Wyświetl plik

@ -9,4 +9,10 @@ class MQTTException(BaseException):
"""
Base class for all errors refering to MQTT specifications
"""
pass
pass
class CodecException(HandlerException):
"""
Exceptions thrown by packet encode/decode functions
"""
pass

Wyświetl plik

@ -4,7 +4,7 @@
from hbmqtt.handlers.packet import ResponsePacketHandler, MQTTHeader, PacketType
from hbmqtt.messages.connack import ConnackPacket, ConnackVariableHeader, ReturnCode
from hbmqtt.session import Session
from hbmqtt.handlers.utils import *
from hbmqtt.codecs import *
from hbmqtt.errors import MQTTException

Wyświetl plik

@ -1,20 +1,22 @@
__author__ = 'nico'
import abc
import asyncio
from hbmqtt.handlers.packet import RequestPacketHandler
from hbmqtt.handlers.errors import HandlerException, NoDataException
from hbmqtt.messages.packet import MQTTHeader, PacketType
from hbmqtt.messages.connect import ConnectPacket, ConnectVariableHeader, ConnectPayload
from hbmqtt.errors import MQTTException
from hbmqtt.session import Session
from hbmqtt.handlers.utils import (
from hbmqtt.codecs import (
read_or_raise,
bytes_to_int,
decode_string,
int_to_bytes,
encode_string
)
class ConnectHandler(RequestPacketHandler):
def __init__(self):
super().__init__()

Wyświetl plik

@ -6,16 +6,17 @@ import asyncio
import abc
import sys
import logging
from hbmqtt.messages.packet import MQTTPacket, MQTTHeader, PacketType
from hbmqtt.handlers.utils import int_to_bytes, bytes_to_int, read_or_raise, bytes_to_hex_str
from hbmqtt.codecs import int_to_bytes, bytes_to_int, read_or_raise, bytes_to_hex_str
from hbmqtt.handlers.errors import CodecException, HandlerException
from hbmqtt.session import Session
from hbmqtt.errors import MQTTException
if sys.version_info >= (3,4):
import asyncio.ensure_future as async
if sys.version_info >= (3,4,4):
from asyncio import ensure_future as async
else:
import asyncio.async as async
from asyncio import async
class PacketHandler(metaclass=abc.ABCMeta):
@ -48,9 +49,8 @@ class PacketHandler(metaclass=abc.ABCMeta):
writer.write(encoded_payload)
yield from writer.drain()
@staticmethod
@asyncio.coroutine
def read_packet_header(reader: asyncio.StreamReader) -> MQTTHeader:
def read_packet_header(self, reader: asyncio.StreamReader) -> MQTTHeader:
"""
Read and decode MQTT message fixed header from stream
:return: FixedHeader instance

Wyświetl plik

@ -1,9 +1,12 @@
# Copyright (c) 2015 Nicolas JOUANIN
#
# See the file license.txt for copying permission.
from hbmqtt.messages.packet import MQTTPacket, MQTTHeader, PacketType
import asyncio
from hbmqtt.messages.packet import MQTTPacket, MQTTFixedHeader, PacketType, MQTTVariableHeader
from hbmqtt.codecs import *
from hbmqtt.errors import MQTTException
class ConnectVariableHeader:
class ConnectVariableHeader(MQTTVariableHeader):
USERNAME_FLAG = 0x80
PASSWORD_FLAG = 0x40
WILL_RETAIN_FLAG = 0x20
@ -86,6 +89,44 @@ class ConnectVariableHeader:
self.flags &= (0x00 << 3)
self.flags |= (val << 3)
@classmethod
def from_stream(cls, reader: asyncio.StreamReader):
# protocol name
protocol_name = yield from decode_string(reader)
if protocol_name != "MQTT":
raise MQTTException('[MQTT-3.1.2-1] Incorrect protocol name: "%s"' % protocol_name)
# protocol level (only MQTT 3.1.1 supported)
protocol_level_byte = yield from read_or_raise(reader, 1)
protocol_level = bytes_to_int(protocol_level_byte)
# flags
flags_byte = yield from read_or_raise(reader, 1)
flags = bytes_to_int(flags_byte)
if flags & 0x01:
raise MQTTException('[MQTT-3.1.2-3] CONNECT reserved flag must be set to 0')
# keep-alive
keep_alive_byte = yield from read_or_raise(reader, 2)
keep_alive = bytes_to_int(keep_alive_byte)
return cls(flags, keep_alive, protocol_name, protocol_level)
def to_bytes(self):
out = b''
# Protocol name
out += encode_string(self.proto_name)
# Protocol level
out += int_to_bytes(self.proto_level)
# flags
out += int_to_bytes(self.flags)
# keep alive
out += int_to_bytes(self.keep_alive, 2)
return out
class ConnectPayload:
def __init__(self, client_id=None, will_topic=None, will_message=None, username=None, password=None):
@ -98,7 +139,7 @@ class ConnectPayload:
class ConnectPacket(MQTTPacket):
def __init__(self, vh: ConnectVariableHeader, payload: ConnectPayload):
header = MQTTHeader(PacketType.CONNECT, 0x00)
header = MQTTFixedHeader(PacketType.CONNECT, 0x00)
super().__init__(header)
self.variable_header = vh
self.payload = payload

Wyświetl plik

@ -3,6 +3,10 @@
# See the file license.txt for copying permission.
from enum import Enum
from hbmqtt.errors import CodecException, MQTTException
from hbmqtt.codecs import *
import abc
class PacketType(Enum):
RESERVED_0 = 0
CONNECT = 1
@ -25,7 +29,7 @@ class PacketType(Enum):
def get_packet_type(byte):
return PacketType(byte)
class MQTTHeader:
class MQTTFixedHeader:
def __init__(self, packet_type, flags=0, length=0):
if isinstance(packet_type, int):
enum_type = packet_type
@ -35,9 +39,155 @@ class MQTTHeader:
self.remaining_length = length
self.flags = flags
def to_bytes(self):
def encode_remaining_length(length: int):
encoded = b''
while True:
length_byte = length % 0x80
length //= 0x80
if length > 0:
length_byte |= 0x80
encoded += int_to_bytes(length_byte)
if length <= 0:
break
return encoded
out = bytes(3) # MQTTHeader are at least 3 bytes long
packet_type = 0
try:
packet_type = (self.packet_type.value << 4) | self.flags
out += int_to_bytes(packet_type)
except OverflowError:
raise CodecException('packet_type encoding exceed 1 byte length: value=%d', packet_type)
encoded_length = encode_remaining_length(self.remaining_length)
out += encoded_length
return out
@asyncio.coroutine
def to_stream(self, writer: asyncio.StreamWriter):
writer.write(self.to_bytes())
yield from writer.drain()
@classmethod
@asyncio.coroutine
def from_stream(cls, reader: asyncio.StreamReader):
"""
Read and decode MQTT message fixed header from stream
:return: FixedHeader instance
"""
def decode_message_type(byte):
byte_type = (bytes_to_int(byte) & 0xf0) >> 4
return PacketType(byte_type)
def decode_flags(data):
byte = bytes_to_int(data)
return byte & 0x0f
@asyncio.coroutine
def decode_remaining_length():
"""
Decode message length according to MQTT specifications
:return:
"""
multiplier = 1
value = 0
length_bytes = b''
while True:
encoded_byte = yield from read_or_raise(reader, 1)
length_bytes += encoded_byte
int_byte = bytes_to_int(encoded_byte)
value += (int_byte & 0x7f) * multiplier
if (int_byte & 0x80) == 0:
break
else:
multiplier *= 128
if multiplier > 128*128*128:
raise MQTTException("Invalid remaining length bytes:%s" % bytes_to_hex_str(length_bytes))
return value
b1 = yield from read_or_raise(reader, 1)
msg_type = decode_message_type(b1)
if msg_type is PacketType.RESERVED_0 or msg_type is PacketType.RESERVED_15:
raise MQTTException("Usage of control packet type %s is forbidden" % msg_type)
flags = decode_flags(b1)
remain_length = yield from decode_remaining_length()
return cls(msg_type, flags, remain_length)
class MQTTVariableHeader:
def __init__(self):
pass
@asyncio.coroutine
def to_stream(self, writer: asyncio.StreamWriter):
writer.write(self.to_bytes())
yield from writer.drain()
@abc.abstractmethod
def to_bytes(self):
return
@classmethod
@asyncio.coroutine
@abc.abstractclassmethod
def from_stream(cls, reader: asyncio.StreamReader):
return
class MQTTPayload:
def __init__(self):
pass
@asyncio.coroutine
def to_stream(self, writer: asyncio.StreamWriter):
writer.write(self.to_bytes())
yield from writer.drain()
@abc.abstractmethod
def to_bytes(self):
return
@classmethod
@asyncio.coroutine
@abc.abstractclassmethod
def from_stream(cls, reader: asyncio.StreamReader):
return
class MQTTPacket:
def __init__(self, fixed: MQTTHeader):
def __init__(self, fixed: MQTTFixedHeader, variable_header: MQTTVariableHeader=None, payload: MQTTPayload=None):
self.fixed_header = fixed
self.variable_header = None
self.payload = None
self.variable_header = variable_header
self.payload = payload
@asyncio.coroutine
def to_stream(self, writer: asyncio.StreamWriter):
writer.write(self.to_bytes())
yield from writer.drain()
def to_bytes(self):
if self.variable_header:
variable_header_bytes = self.variable_header.to_bytes()
else:
variable_header_bytes = b''
if self.payload:
payload_bytes = self.payload.to_bytes()
else:
payload_bytes = b''
self.fixed_header.remaining_length = len(variable_header_bytes) + len(payload_bytes)
fixed_header_bytes = self.fixed_header.to_bytes()
return fixed_header_bytes + variable_header_bytes + payload_bytes
@classmethod
@asyncio.coroutine
def from_stream(cls, reader: asyncio.StreamReader):
fixed_header = yield from MQTTFixedHeader.from_stream(reader)
variable_header = yield from MQTTVariableHeader.from_stream(reader)
payload = yield from MQTTPayload.from_stream(reader)
return cls(fixed_header, variable_header, payload)

Wyświetl plik

@ -4,7 +4,7 @@
import unittest
import asyncio
from hbmqtt.handlers.utils import (
from hbmqtt.codecs import (
bytes_to_hex_str,
bytes_to_int,
decode_string,

Wyświetl plik

@ -0,0 +1 @@
__author__ = 'nico'

Wyświetl plik

@ -4,18 +4,20 @@
import unittest
import asyncio
from hbmqtt.codecs.header import MQTTHeaderCodec, MQTTHeaderException
from hbmqtt.handlers.packet import PacketHandler
from hbmqtt.messages.packet import PacketType, MQTTHeader
from hbmqtt.errors import MQTTException
class TestMQTTHeaderCodec(unittest.TestCase):
class TestPacketHandler(unittest.TestCase):
def setUp(self):
self.loop = asyncio.new_event_loop()
def test_decode_ok(self):
def test_read_packet_header(self):
packet_handler = PacketHandler()
stream = asyncio.StreamReader(loop=self.loop)
stream.feed_data(b'\x10\x7f')
header = self.loop.run_until_complete(MQTTHeaderCodec.decode(stream))
header = self.loop.run_until_complete(packet_handler.read_packet_header(stream))
self.assertEqual(header.message_type, PacketType.CONNECT)
self.assertFalse(header.flags & 0x08)
self.assertEqual((header.flags & 0x06) >> 1, 0)
@ -25,8 +27,8 @@ class TestMQTTHeaderCodec(unittest.TestCase):
def test_decode_ok_with_length(self):
stream = asyncio.StreamReader(loop=self.loop)
stream.feed_data(b'\x10\xff\xff\xff\x7f')
header = self.loop.run_until_complete(MQTTHeaderCodec.decode(stream))
self.assertEqual(header.message_type, PacketType.CONNECT)
header = self.loop.run_until_complete(PacketHandler.read_packet_header(stream))
self.assertEqual(header.packet_type, PacketType.CONNECT)
self.assertFalse(header.flags & 0x08)
self.assertEqual((header.flags & 0x06) >> 1, 0)
self.assertFalse(header.flags & 0x01)
@ -35,21 +37,21 @@ class TestMQTTHeaderCodec(unittest.TestCase):
def test_decode_reserved(self):
stream = asyncio.StreamReader(loop=self.loop)
stream.feed_data(b'\x0f\x7f')
with self.assertRaises(MQTTHeaderException):
self.loop.run_until_complete(MQTTHeaderCodec.decode(stream))
with self.assertRaises(MQTTException):
self.loop.run_until_complete(PacketHandler.read_packet_header(stream))
def test_decode_ko_with_length(self):
stream = asyncio.StreamReader(loop=self.loop)
stream.feed_data(b'\x10\xff\xff\xff\xff\x7f')
with self.assertRaises(MQTTHeaderException):
self.loop.run_until_complete(MQTTHeaderCodec.decode(stream))
with self.assertRaises(MQTTException):
self.loop.run_until_complete(PacketHandler.read_packet_header(stream))
def test_encode(self):
header = MQTTHeader(PacketType.CONNECT, 0x00, 0)
data = MQTTHeaderCodec.encode(header)
data = self.loop.run_until_complete(PacketHandler._encode_fixed_header(header))
self.assertEqual(data, b'\x10\x00')
def test_encode_2(self):
header = MQTTHeader(PacketType.CONNECT, 0x00, 268435455)
data = MQTTHeaderCodec.encode(header)
data = self.loop.run_until_complete(PacketHandler._encode_fixed_header(header))
self.assertEqual(data, b'\x10\xff\xff\xff\x7f')