kopia lustrzana https://github.com/Yakifo/amqtt
Work in progress (tons of errors remaining)
rodzic
65eb6e68f8
commit
8eff6bcc15
|
@ -0,0 +1,2 @@
|
|||
hbmqtt/__pycache__
|
||||
*.pyc
|
|
@ -5,7 +5,7 @@ import asyncio
|
|||
from asyncio import IncompleteReadError
|
||||
from math import ceil
|
||||
|
||||
from hbmqtt.broker.errors import NoDataException
|
||||
from hbmqtt.handlers.errors import NoDataException
|
||||
|
||||
|
||||
def bytes_to_hex_str(data):
|
|
@ -9,4 +9,10 @@ class MQTTException(BaseException):
|
|||
"""
|
||||
Base class for all errors refering to MQTT specifications
|
||||
"""
|
||||
pass
|
||||
pass
|
||||
|
||||
class CodecException(HandlerException):
|
||||
"""
|
||||
Exceptions thrown by packet encode/decode functions
|
||||
"""
|
||||
pass
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
from hbmqtt.handlers.packet import ResponsePacketHandler, MQTTHeader, PacketType
|
||||
from hbmqtt.messages.connack import ConnackPacket, ConnackVariableHeader, ReturnCode
|
||||
from hbmqtt.session import Session
|
||||
from hbmqtt.handlers.utils import *
|
||||
from hbmqtt.codecs import *
|
||||
from hbmqtt.errors import MQTTException
|
||||
|
||||
|
||||
|
|
|
@ -1,20 +1,22 @@
|
|||
__author__ = 'nico'
|
||||
|
||||
import abc
|
||||
import asyncio
|
||||
|
||||
from hbmqtt.handlers.packet import RequestPacketHandler
|
||||
from hbmqtt.handlers.errors import HandlerException, NoDataException
|
||||
from hbmqtt.messages.packet import MQTTHeader, PacketType
|
||||
from hbmqtt.messages.connect import ConnectPacket, ConnectVariableHeader, ConnectPayload
|
||||
from hbmqtt.errors import MQTTException
|
||||
from hbmqtt.session import Session
|
||||
from hbmqtt.handlers.utils import (
|
||||
from hbmqtt.codecs import (
|
||||
read_or_raise,
|
||||
bytes_to_int,
|
||||
decode_string,
|
||||
int_to_bytes,
|
||||
encode_string
|
||||
)
|
||||
|
||||
|
||||
class ConnectHandler(RequestPacketHandler):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
|
|
@ -6,16 +6,17 @@ import asyncio
|
|||
import abc
|
||||
import sys
|
||||
import logging
|
||||
|
||||
from hbmqtt.messages.packet import MQTTPacket, MQTTHeader, PacketType
|
||||
from hbmqtt.handlers.utils import int_to_bytes, bytes_to_int, read_or_raise, bytes_to_hex_str
|
||||
from hbmqtt.codecs import int_to_bytes, bytes_to_int, read_or_raise, bytes_to_hex_str
|
||||
from hbmqtt.handlers.errors import CodecException, HandlerException
|
||||
from hbmqtt.session import Session
|
||||
from hbmqtt.errors import MQTTException
|
||||
|
||||
if sys.version_info >= (3,4):
|
||||
import asyncio.ensure_future as async
|
||||
if sys.version_info >= (3,4,4):
|
||||
from asyncio import ensure_future as async
|
||||
else:
|
||||
import asyncio.async as async
|
||||
from asyncio import async
|
||||
|
||||
|
||||
class PacketHandler(metaclass=abc.ABCMeta):
|
||||
|
@ -48,9 +49,8 @@ class PacketHandler(metaclass=abc.ABCMeta):
|
|||
writer.write(encoded_payload)
|
||||
yield from writer.drain()
|
||||
|
||||
@staticmethod
|
||||
@asyncio.coroutine
|
||||
def read_packet_header(reader: asyncio.StreamReader) -> MQTTHeader:
|
||||
def read_packet_header(self, reader: asyncio.StreamReader) -> MQTTHeader:
|
||||
"""
|
||||
Read and decode MQTT message fixed header from stream
|
||||
:return: FixedHeader instance
|
||||
|
|
|
@ -1,9 +1,12 @@
|
|||
# Copyright (c) 2015 Nicolas JOUANIN
|
||||
#
|
||||
# See the file license.txt for copying permission.
|
||||
from hbmqtt.messages.packet import MQTTPacket, MQTTHeader, PacketType
|
||||
import asyncio
|
||||
from hbmqtt.messages.packet import MQTTPacket, MQTTFixedHeader, PacketType, MQTTVariableHeader
|
||||
from hbmqtt.codecs import *
|
||||
from hbmqtt.errors import MQTTException
|
||||
|
||||
class ConnectVariableHeader:
|
||||
class ConnectVariableHeader(MQTTVariableHeader):
|
||||
USERNAME_FLAG = 0x80
|
||||
PASSWORD_FLAG = 0x40
|
||||
WILL_RETAIN_FLAG = 0x20
|
||||
|
@ -86,6 +89,44 @@ class ConnectVariableHeader:
|
|||
self.flags &= (0x00 << 3)
|
||||
self.flags |= (val << 3)
|
||||
|
||||
@classmethod
|
||||
def from_stream(cls, reader: asyncio.StreamReader):
|
||||
# protocol name
|
||||
protocol_name = yield from decode_string(reader)
|
||||
if protocol_name != "MQTT":
|
||||
raise MQTTException('[MQTT-3.1.2-1] Incorrect protocol name: "%s"' % protocol_name)
|
||||
|
||||
# protocol level (only MQTT 3.1.1 supported)
|
||||
protocol_level_byte = yield from read_or_raise(reader, 1)
|
||||
protocol_level = bytes_to_int(protocol_level_byte)
|
||||
|
||||
# flags
|
||||
flags_byte = yield from read_or_raise(reader, 1)
|
||||
flags = bytes_to_int(flags_byte)
|
||||
if flags & 0x01:
|
||||
raise MQTTException('[MQTT-3.1.2-3] CONNECT reserved flag must be set to 0')
|
||||
|
||||
# keep-alive
|
||||
keep_alive_byte = yield from read_or_raise(reader, 2)
|
||||
keep_alive = bytes_to_int(keep_alive_byte)
|
||||
|
||||
return cls(flags, keep_alive, protocol_name, protocol_level)
|
||||
|
||||
def to_bytes(self):
|
||||
out = b''
|
||||
|
||||
# Protocol name
|
||||
out += encode_string(self.proto_name)
|
||||
# Protocol level
|
||||
out += int_to_bytes(self.proto_level)
|
||||
# flags
|
||||
out += int_to_bytes(self.flags)
|
||||
# keep alive
|
||||
out += int_to_bytes(self.keep_alive, 2)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
|
||||
class ConnectPayload:
|
||||
def __init__(self, client_id=None, will_topic=None, will_message=None, username=None, password=None):
|
||||
|
@ -98,7 +139,7 @@ class ConnectPayload:
|
|||
|
||||
class ConnectPacket(MQTTPacket):
|
||||
def __init__(self, vh: ConnectVariableHeader, payload: ConnectPayload):
|
||||
header = MQTTHeader(PacketType.CONNECT, 0x00)
|
||||
header = MQTTFixedHeader(PacketType.CONNECT, 0x00)
|
||||
super().__init__(header)
|
||||
self.variable_header = vh
|
||||
self.payload = payload
|
||||
|
|
|
@ -3,6 +3,10 @@
|
|||
# See the file license.txt for copying permission.
|
||||
from enum import Enum
|
||||
|
||||
from hbmqtt.errors import CodecException, MQTTException
|
||||
from hbmqtt.codecs import *
|
||||
import abc
|
||||
|
||||
class PacketType(Enum):
|
||||
RESERVED_0 = 0
|
||||
CONNECT = 1
|
||||
|
@ -25,7 +29,7 @@ class PacketType(Enum):
|
|||
def get_packet_type(byte):
|
||||
return PacketType(byte)
|
||||
|
||||
class MQTTHeader:
|
||||
class MQTTFixedHeader:
|
||||
def __init__(self, packet_type, flags=0, length=0):
|
||||
if isinstance(packet_type, int):
|
||||
enum_type = packet_type
|
||||
|
@ -35,9 +39,155 @@ class MQTTHeader:
|
|||
self.remaining_length = length
|
||||
self.flags = flags
|
||||
|
||||
def to_bytes(self):
|
||||
def encode_remaining_length(length: int):
|
||||
encoded = b''
|
||||
while True:
|
||||
length_byte = length % 0x80
|
||||
length //= 0x80
|
||||
if length > 0:
|
||||
length_byte |= 0x80
|
||||
encoded += int_to_bytes(length_byte)
|
||||
if length <= 0:
|
||||
break
|
||||
return encoded
|
||||
|
||||
out = bytes(3) # MQTTHeader are at least 3 bytes long
|
||||
packet_type = 0
|
||||
try:
|
||||
packet_type = (self.packet_type.value << 4) | self.flags
|
||||
out += int_to_bytes(packet_type)
|
||||
except OverflowError:
|
||||
raise CodecException('packet_type encoding exceed 1 byte length: value=%d', packet_type)
|
||||
|
||||
encoded_length = encode_remaining_length(self.remaining_length)
|
||||
out += encoded_length
|
||||
|
||||
return out
|
||||
|
||||
@asyncio.coroutine
|
||||
def to_stream(self, writer: asyncio.StreamWriter):
|
||||
writer.write(self.to_bytes())
|
||||
yield from writer.drain()
|
||||
|
||||
@classmethod
|
||||
@asyncio.coroutine
|
||||
def from_stream(cls, reader: asyncio.StreamReader):
|
||||
"""
|
||||
Read and decode MQTT message fixed header from stream
|
||||
:return: FixedHeader instance
|
||||
"""
|
||||
def decode_message_type(byte):
|
||||
byte_type = (bytes_to_int(byte) & 0xf0) >> 4
|
||||
return PacketType(byte_type)
|
||||
|
||||
def decode_flags(data):
|
||||
byte = bytes_to_int(data)
|
||||
return byte & 0x0f
|
||||
|
||||
@asyncio.coroutine
|
||||
def decode_remaining_length():
|
||||
"""
|
||||
Decode message length according to MQTT specifications
|
||||
:return:
|
||||
"""
|
||||
multiplier = 1
|
||||
value = 0
|
||||
length_bytes = b''
|
||||
while True:
|
||||
encoded_byte = yield from read_or_raise(reader, 1)
|
||||
length_bytes += encoded_byte
|
||||
int_byte = bytes_to_int(encoded_byte)
|
||||
value += (int_byte & 0x7f) * multiplier
|
||||
if (int_byte & 0x80) == 0:
|
||||
break
|
||||
else:
|
||||
multiplier *= 128
|
||||
if multiplier > 128*128*128:
|
||||
raise MQTTException("Invalid remaining length bytes:%s" % bytes_to_hex_str(length_bytes))
|
||||
return value
|
||||
|
||||
b1 = yield from read_or_raise(reader, 1)
|
||||
msg_type = decode_message_type(b1)
|
||||
if msg_type is PacketType.RESERVED_0 or msg_type is PacketType.RESERVED_15:
|
||||
raise MQTTException("Usage of control packet type %s is forbidden" % msg_type)
|
||||
flags = decode_flags(b1)
|
||||
|
||||
remain_length = yield from decode_remaining_length()
|
||||
return cls(msg_type, flags, remain_length)
|
||||
|
||||
|
||||
class MQTTVariableHeader:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@asyncio.coroutine
|
||||
def to_stream(self, writer: asyncio.StreamWriter):
|
||||
writer.write(self.to_bytes())
|
||||
yield from writer.drain()
|
||||
|
||||
@abc.abstractmethod
|
||||
def to_bytes(self):
|
||||
return
|
||||
|
||||
@classmethod
|
||||
@asyncio.coroutine
|
||||
@abc.abstractclassmethod
|
||||
def from_stream(cls, reader: asyncio.StreamReader):
|
||||
return
|
||||
|
||||
|
||||
class MQTTPayload:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@asyncio.coroutine
|
||||
def to_stream(self, writer: asyncio.StreamWriter):
|
||||
writer.write(self.to_bytes())
|
||||
yield from writer.drain()
|
||||
|
||||
@abc.abstractmethod
|
||||
def to_bytes(self):
|
||||
return
|
||||
|
||||
@classmethod
|
||||
@asyncio.coroutine
|
||||
@abc.abstractclassmethod
|
||||
def from_stream(cls, reader: asyncio.StreamReader):
|
||||
return
|
||||
|
||||
|
||||
class MQTTPacket:
|
||||
def __init__(self, fixed: MQTTHeader):
|
||||
def __init__(self, fixed: MQTTFixedHeader, variable_header: MQTTVariableHeader=None, payload: MQTTPayload=None):
|
||||
self.fixed_header = fixed
|
||||
self.variable_header = None
|
||||
self.payload = None
|
||||
self.variable_header = variable_header
|
||||
self.payload = payload
|
||||
|
||||
@asyncio.coroutine
|
||||
def to_stream(self, writer: asyncio.StreamWriter):
|
||||
writer.write(self.to_bytes())
|
||||
yield from writer.drain()
|
||||
|
||||
def to_bytes(self):
|
||||
if self.variable_header:
|
||||
variable_header_bytes = self.variable_header.to_bytes()
|
||||
else:
|
||||
variable_header_bytes = b''
|
||||
if self.payload:
|
||||
payload_bytes = self.payload.to_bytes()
|
||||
else:
|
||||
payload_bytes = b''
|
||||
|
||||
self.fixed_header.remaining_length = len(variable_header_bytes) + len(payload_bytes)
|
||||
fixed_header_bytes = self.fixed_header.to_bytes()
|
||||
|
||||
return fixed_header_bytes + variable_header_bytes + payload_bytes
|
||||
|
||||
@classmethod
|
||||
@asyncio.coroutine
|
||||
def from_stream(cls, reader: asyncio.StreamReader):
|
||||
fixed_header = yield from MQTTFixedHeader.from_stream(reader)
|
||||
variable_header = yield from MQTTVariableHeader.from_stream(reader)
|
||||
payload = yield from MQTTPayload.from_stream(reader)
|
||||
|
||||
return cls(fixed_header, variable_header, payload)
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
import unittest
|
||||
import asyncio
|
||||
|
||||
from hbmqtt.handlers.utils import (
|
||||
from hbmqtt.codecs import (
|
||||
bytes_to_hex_str,
|
||||
bytes_to_int,
|
||||
decode_string,
|
||||
|
|
|
@ -0,0 +1 @@
|
|||
__author__ = 'nico'
|
|
@ -4,18 +4,20 @@
|
|||
import unittest
|
||||
import asyncio
|
||||
|
||||
from hbmqtt.codecs.header import MQTTHeaderCodec, MQTTHeaderException
|
||||
from hbmqtt.handlers.packet import PacketHandler
|
||||
from hbmqtt.messages.packet import PacketType, MQTTHeader
|
||||
from hbmqtt.errors import MQTTException
|
||||
|
||||
|
||||
class TestMQTTHeaderCodec(unittest.TestCase):
|
||||
class TestPacketHandler(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.loop = asyncio.new_event_loop()
|
||||
|
||||
def test_decode_ok(self):
|
||||
def test_read_packet_header(self):
|
||||
packet_handler = PacketHandler()
|
||||
stream = asyncio.StreamReader(loop=self.loop)
|
||||
stream.feed_data(b'\x10\x7f')
|
||||
header = self.loop.run_until_complete(MQTTHeaderCodec.decode(stream))
|
||||
header = self.loop.run_until_complete(packet_handler.read_packet_header(stream))
|
||||
self.assertEqual(header.message_type, PacketType.CONNECT)
|
||||
self.assertFalse(header.flags & 0x08)
|
||||
self.assertEqual((header.flags & 0x06) >> 1, 0)
|
||||
|
@ -25,8 +27,8 @@ class TestMQTTHeaderCodec(unittest.TestCase):
|
|||
def test_decode_ok_with_length(self):
|
||||
stream = asyncio.StreamReader(loop=self.loop)
|
||||
stream.feed_data(b'\x10\xff\xff\xff\x7f')
|
||||
header = self.loop.run_until_complete(MQTTHeaderCodec.decode(stream))
|
||||
self.assertEqual(header.message_type, PacketType.CONNECT)
|
||||
header = self.loop.run_until_complete(PacketHandler.read_packet_header(stream))
|
||||
self.assertEqual(header.packet_type, PacketType.CONNECT)
|
||||
self.assertFalse(header.flags & 0x08)
|
||||
self.assertEqual((header.flags & 0x06) >> 1, 0)
|
||||
self.assertFalse(header.flags & 0x01)
|
||||
|
@ -35,21 +37,21 @@ class TestMQTTHeaderCodec(unittest.TestCase):
|
|||
def test_decode_reserved(self):
|
||||
stream = asyncio.StreamReader(loop=self.loop)
|
||||
stream.feed_data(b'\x0f\x7f')
|
||||
with self.assertRaises(MQTTHeaderException):
|
||||
self.loop.run_until_complete(MQTTHeaderCodec.decode(stream))
|
||||
with self.assertRaises(MQTTException):
|
||||
self.loop.run_until_complete(PacketHandler.read_packet_header(stream))
|
||||
|
||||
def test_decode_ko_with_length(self):
|
||||
stream = asyncio.StreamReader(loop=self.loop)
|
||||
stream.feed_data(b'\x10\xff\xff\xff\xff\x7f')
|
||||
with self.assertRaises(MQTTHeaderException):
|
||||
self.loop.run_until_complete(MQTTHeaderCodec.decode(stream))
|
||||
with self.assertRaises(MQTTException):
|
||||
self.loop.run_until_complete(PacketHandler.read_packet_header(stream))
|
||||
|
||||
def test_encode(self):
|
||||
header = MQTTHeader(PacketType.CONNECT, 0x00, 0)
|
||||
data = MQTTHeaderCodec.encode(header)
|
||||
data = self.loop.run_until_complete(PacketHandler._encode_fixed_header(header))
|
||||
self.assertEqual(data, b'\x10\x00')
|
||||
|
||||
def test_encode_2(self):
|
||||
header = MQTTHeader(PacketType.CONNECT, 0x00, 268435455)
|
||||
data = MQTTHeaderCodec.encode(header)
|
||||
data = self.loop.run_until_complete(PacketHandler._encode_fixed_header(header))
|
||||
self.assertEqual(data, b'\x10\xff\xff\xff\x7f')
|
Ładowanie…
Reference in New Issue