Add publish payload data

pull/8/head
Nicolas Jouanin 2015-06-18 19:30:13 +02:00
rodzic dd25134902
commit a22809d2b7
2 zmienionych plików z 30 dodań i 7 usunięć

Wyświetl plik

@ -1,7 +1,7 @@
# Copyright (c) 2015 Nicolas JOUANIN # Copyright (c) 2015 Nicolas JOUANIN
# #
# See the file license.txt for copying permission. # See the file license.txt for copying permission.
from hbmqtt.mqtt.packet import MQTTPacket, MQTTFixedHeader, PacketType, MQTTVariableHeader from hbmqtt.mqtt.packet import MQTTPacket, MQTTFixedHeader, PacketType, MQTTVariableHeader, MQTTPayload
from hbmqtt.errors import HBMQTTException, MQTTException from hbmqtt.errors import HBMQTTException, MQTTException
from hbmqtt.codecs import * from hbmqtt.codecs import *
@ -79,10 +79,25 @@ class PublishVariableHeader(MQTTVariableHeader):
return cls(topic_name, packet_id) return cls(topic_name, packet_id)
class PublishPayload(MQTTPayload):
def __init__(self, data: bytes=None):
super().__init__()
self.data = data
def to_bytes(self, fixed_header: MQTTFixedHeader, variable_header: MQTTVariableHeader):
return self.data
@classmethod
def from_stream(cls, reader: asyncio.StreamReader, fixed_header: MQTTFixedHeader,
variable_header: MQTTVariableHeader):
data = yield from reader.read()
return cls(data)
class PublishPacket(MQTTPacket): class PublishPacket(MQTTPacket):
FIXED_HEADER = PublishFixedHeader FIXED_HEADER = PublishFixedHeader
VARIABLE_HEADER = PublishVariableHeader VARIABLE_HEADER = PublishVariableHeader
PAYLOAD = None PAYLOAD = PublishPayload
def __init__(self, fixed: PublishFixedHeader=None, variable_header: PublishVariableHeader=None, payload=None): def __init__(self, fixed: PublishFixedHeader=None, variable_header: PublishVariableHeader=None, payload=None):
if fixed is None: if fixed is None:
@ -94,4 +109,9 @@ class PublishPacket(MQTTPacket):
super().__init__(header) super().__init__(header)
self.variable_header = variable_header self.variable_header = variable_header
self.payload = None self.payload = payload
@classmethod
def build(cls, topic_name: str, packet_id: int=None):
v_header = PublishVariableHeader(topic_name, packet_id)
return PublishPacket(variable_header=v_header)

Wyświetl plik

@ -3,7 +3,7 @@
# See the file license.txt for copying permission. # See the file license.txt for copying permission.
import unittest import unittest
from hbmqtt.mqtt.publish import PublishPacket, PublishVariableHeader from hbmqtt.mqtt.publish import PublishPacket, PublishVariableHeader, PublishPayload
from hbmqtt.codecs import * from hbmqtt.codecs import *
class PublishPacketTest(unittest.TestCase): class PublishPacketTest(unittest.TestCase):
@ -11,18 +11,21 @@ class PublishPacketTest(unittest.TestCase):
self.loop = asyncio.new_event_loop() self.loop = asyncio.new_event_loop()
def test_from_stream(self): def test_from_stream(self):
data = b'\x3f\x09\x00\x05topic\x00\x0a' data = b'\x3f\x09\x00\x05topic\x00\x0a0123456789'
stream = asyncio.StreamReader(loop=self.loop) stream = asyncio.StreamReader(loop=self.loop)
stream.feed_data(data) stream.feed_data(data)
stream.feed_eof()
message = self.loop.run_until_complete(PublishPacket.from_stream(stream)) message = self.loop.run_until_complete(PublishPacket.from_stream(stream))
self.assertEqual(message.variable_header.topic_name, 'topic') self.assertEqual(message.variable_header.topic_name, 'topic')
self.assertEqual(message.variable_header.packet_id, 10) self.assertEqual(message.variable_header.packet_id, 10)
self.assertEqual(message.fixed_header.qos, 0x03) self.assertEqual(message.fixed_header.qos, 0x03)
self.assertTrue(message.fixed_header.dup_flag) self.assertTrue(message.fixed_header.dup_flag)
self.assertTrue(message.fixed_header.retain_flag) self.assertTrue(message.fixed_header.retain_flag)
self.assertTrue(message.payload.data, b'0123456789')
def test_to_stream(self): def test_to_stream(self):
variable_header = PublishVariableHeader('topic', 10) variable_header = PublishVariableHeader('topic', 10)
publish = PublishPacket(variable_header=variable_header) payload = PublishPayload(b'0123456789')
publish = PublishPacket(variable_header=variable_header, payload=payload)
out = publish.to_bytes() out = publish.to_bytes()
self.assertEqual(out, b'\x30\x09\x00\x05topic\x00\x0a') self.assertEqual(out, b'\x30\x13\x00\x05topic\x00\x0a0123456789')