kopia lustrzana https://github.com/Yakifo/amqtt
Merge branch 'release/0.2'
commit
cfceb2e0ab
|
@ -2,4 +2,4 @@
|
|||
#
|
||||
# See the file license.txt for copying permission.
|
||||
|
||||
VERSION = (0, 1, 0, 'final', 0)
|
||||
VERSION = (0, 2, 0, 'final', 0)
|
||||
|
|
|
@ -19,7 +19,6 @@ _defaults = {
|
|||
'ping_delay': 1,
|
||||
'default_qos': 0,
|
||||
'default_retain': False,
|
||||
'inflight-polling-interval': 1,
|
||||
'subscriptions-polling-interval': 1,
|
||||
}
|
||||
|
||||
|
@ -66,8 +65,8 @@ class MQTTClient:
|
|||
:return:
|
||||
"""
|
||||
self.logger = logging.getLogger(__name__)
|
||||
self.config = config.copy()
|
||||
self.config.update(_defaults)
|
||||
self.config = _defaults
|
||||
self.config.update(config)
|
||||
if client_id is not None:
|
||||
self.client_id = client_id
|
||||
else:
|
||||
|
@ -98,8 +97,9 @@ class MQTTClient:
|
|||
self.session = self._initsession(host, port, username, password, uri, cleansession)
|
||||
self.logger.debug("Connect with session parameters: %s" % self.session)
|
||||
|
||||
yield from self._connect_coro()
|
||||
return_code = yield from self._connect_coro()
|
||||
self.machine.connect_success()
|
||||
return return_code
|
||||
except MachineError:
|
||||
msg = "Connect call incompatible with client current state '%s'" % self.machine.current_state
|
||||
self.logger.warn(msg)
|
||||
|
@ -173,12 +173,16 @@ class MQTTClient:
|
|||
|
||||
@asyncio.coroutine
|
||||
def subscribe(self, topics):
|
||||
yield from self._handler.mqtt_subscribe(topics, self.session.next_packet_id)
|
||||
return (yield from self._handler.mqtt_subscribe(topics, self.session.next_packet_id))
|
||||
|
||||
@asyncio.coroutine
|
||||
def unsubscribe(self, topics):
|
||||
yield from self._handler.mqtt_unsubscribe(topics, self.session.next_packet_id)
|
||||
|
||||
@asyncio.coroutine
|
||||
def deliver_message(self):
|
||||
return (yield from self._handler.mqtt_deliver_next_message())
|
||||
|
||||
@asyncio.coroutine
|
||||
def _connect_coro(self):
|
||||
try:
|
||||
|
@ -194,6 +198,7 @@ class MQTTClient:
|
|||
|
||||
self.session.state = SessionState.CONNECTED
|
||||
self.logger.debug("connected to %s:%s" % (self.session.remote_address, self.session.remote_port))
|
||||
return return_code
|
||||
except Exception as e:
|
||||
self.session.state = SessionState.DISCONNECTED
|
||||
raise e
|
||||
|
|
|
@ -3,27 +3,38 @@
|
|||
# See the file license.txt for copying permission.
|
||||
import logging
|
||||
import asyncio
|
||||
from asyncio import futures
|
||||
from hbmqtt.mqtt.packet import MQTTFixedHeader
|
||||
from hbmqtt.mqtt import packet_class
|
||||
from hbmqtt.errors import NoDataException
|
||||
from hbmqtt.mqtt.packet import PacketType
|
||||
from hbmqtt.mqtt.connect import ConnectVariableHeader, ConnectPacket, ConnectPayload
|
||||
from hbmqtt.mqtt.connack import ConnackPacket
|
||||
from hbmqtt.mqtt.disconnect import DisconnectPacket
|
||||
from hbmqtt.mqtt.pingreq import PingReqPacket
|
||||
from hbmqtt.mqtt.pingresp import PingRespPacket
|
||||
from hbmqtt.mqtt.publish import PublishPacket
|
||||
from hbmqtt.mqtt.pubrel import PubrelPacket
|
||||
from hbmqtt.mqtt.puback import PubackPacket
|
||||
from hbmqtt.mqtt.pubrec import PubrecPacket
|
||||
from hbmqtt.mqtt.pubcomp import PubcompPacket
|
||||
from hbmqtt.mqtt.subscribe import SubscribePacket
|
||||
from hbmqtt.mqtt.suback import SubackPacket
|
||||
from hbmqtt.mqtt.unsubscribe import UnsubscribePacket
|
||||
from hbmqtt.mqtt.unsuback import UnsubackPacket
|
||||
from hbmqtt.session import Session
|
||||
from blinker import Signal
|
||||
from transitions import Machine, MachineError
|
||||
|
||||
class InFlightMessage:
|
||||
states = ['new', 'published', 'acknowledged', 'received', 'released', 'completed']
|
||||
|
||||
def __init__(self, packet_id, qos):
|
||||
self.packet_id = packet_id
|
||||
def __init__(self, packet, qos):
|
||||
self.packet = packet
|
||||
self.qos = qos
|
||||
self.puback = None
|
||||
self.pubrec = None
|
||||
self.pubcomp = None
|
||||
self.pubrel = None
|
||||
self._init_states()
|
||||
|
||||
def _init_states(self):
|
||||
|
@ -41,8 +52,6 @@ class ProtocolHandler:
|
|||
"""
|
||||
Class implementing the MQTT communication protocol using asyncio features
|
||||
"""
|
||||
packet_sent = Signal()
|
||||
packet_received = Signal()
|
||||
|
||||
def __init__(self, session: Session, config, loop=None):
|
||||
self.logger = logging.getLogger(__name__)
|
||||
|
@ -54,87 +63,81 @@ class ProtocolHandler:
|
|||
self._loop = loop
|
||||
self._reader_task = None
|
||||
self._writer_task = None
|
||||
self._inflight_task = None
|
||||
self._reader_ready = asyncio.Event(loop=self._loop)
|
||||
self._writer_ready = asyncio.Event(loop=self._loop)
|
||||
self._inflight_ready = asyncio.Event(loop=self._loop)
|
||||
self._inflight_changed = asyncio.Condition(loop=self._loop)
|
||||
|
||||
self._running = False
|
||||
|
||||
self.session.local_address, self.session.local_port = self.session.writer.get_extra_info('sockname')
|
||||
|
||||
self.incoming_queues = dict()
|
||||
self.application_messages = asyncio.Queue()
|
||||
for p in PacketType:
|
||||
self.incoming_queues[p] = asyncio.Queue()
|
||||
self.outgoing_queue = asyncio.Queue()
|
||||
self.inflight_messages = dict()
|
||||
self._puback_waiters = dict()
|
||||
self._pubrec_waiters = dict()
|
||||
self._pubrel_waiters = dict()
|
||||
self._pubcomp_waiters = dict()
|
||||
self.delivered_message = asyncio.Queue()
|
||||
|
||||
@asyncio.coroutine
|
||||
def start(self):
|
||||
self._running = True
|
||||
self._reader_task = asyncio.async(self._reader_coro(), loop=self._loop)
|
||||
self._writer_task = asyncio.async(self._writer_coro(), loop=self._loop)
|
||||
self._inflight_task = asyncio.async(self._inflight_coro(), loop=self._loop)
|
||||
yield from asyncio.wait(
|
||||
[self._reader_ready.wait(), self._writer_ready.wait(), self._inflight_ready.wait()], loop=self._loop)
|
||||
[self._reader_ready.wait(), self._writer_ready.wait()], loop=self._loop)
|
||||
self.logger.debug("Handler tasks started")
|
||||
|
||||
@asyncio.coroutine
|
||||
def mqtt_publish(self, topic, message, packet_id, dup, qos, retain):
|
||||
def qos_0_predicate():
|
||||
ret = False
|
||||
try:
|
||||
if self.inflight_messages.get(packet_id).state == 'published':
|
||||
ret = True
|
||||
#self.logger.debug("qos_0 predicate return %s" % ret)
|
||||
return ret
|
||||
except KeyError:
|
||||
return False
|
||||
|
||||
def qos_1_predicate():
|
||||
ret = False
|
||||
try:
|
||||
if self.inflight_messages.get(packet_id).state == 'acknowledged':
|
||||
ret = True
|
||||
#self.logger.debug("qos_1 predicate return %s" % ret)
|
||||
return ret
|
||||
except KeyError:
|
||||
return False
|
||||
|
||||
def qos_2_predicate():
|
||||
ret = False
|
||||
try:
|
||||
if self.inflight_messages.get(packet_id).state == 'completed':
|
||||
ret = True
|
||||
#self.logger.debug("qos_1 predicate return %s" % ret)
|
||||
return ret
|
||||
except KeyError:
|
||||
return False
|
||||
|
||||
if packet_id in self.inflight_messages:
|
||||
if packet_id in self.session.inflight_out:
|
||||
self.logger.warn("A message with the same packet ID is already in flight")
|
||||
packet = PublishPacket.build(topic, message, packet_id, dup, qos, retain)
|
||||
yield from self.outgoing_queue.put(packet)
|
||||
inflight_message = InFlightMessage(packet.variable_header.packet_id, qos)
|
||||
inflight_message = InFlightMessage(packet, qos)
|
||||
self.session.inflight_out[packet.variable_header.packet_id] = inflight_message
|
||||
|
||||
inflight_message.publish()
|
||||
self.inflight_messages[packet.variable_header.packet_id] = inflight_message
|
||||
yield from self._inflight_changed.acquire()
|
||||
if qos == 0x00:
|
||||
yield from self._inflight_changed.wait_for(qos_0_predicate)
|
||||
if qos == 0x01:
|
||||
yield from self._inflight_changed.wait_for(qos_1_predicate)
|
||||
waiter = futures.Future(loop=self._loop)
|
||||
self._puback_waiters[packet_id] = waiter
|
||||
yield from waiter
|
||||
inflight_message.puback = waiter.result()
|
||||
inflight_message.acknowledge()
|
||||
del self._puback_waiters[packet_id]
|
||||
if qos == 0x02:
|
||||
yield from self._inflight_changed.wait_for(qos_2_predicate)
|
||||
self.inflight_messages.pop(packet.variable_header.packet_id)
|
||||
self._inflight_changed.release()
|
||||
return packet
|
||||
# Wait for PUBREC
|
||||
waiter = futures.Future(loop=self._loop)
|
||||
self._pubrec_waiters[packet_id] = waiter
|
||||
yield from waiter
|
||||
inflight_message.pubrec = waiter.result()
|
||||
del self._pubrec_waiters[packet_id]
|
||||
inflight_message.receive()
|
||||
|
||||
# Send pubrel
|
||||
pubrel = PubrelPacket.build(packet_id)
|
||||
yield from self.outgoing_queue.put(pubrel)
|
||||
inflight_message.pubrel = pubrel
|
||||
inflight_message.release()
|
||||
|
||||
# Wait for pubcomp
|
||||
waiter = futures.Future(loop=self._loop)
|
||||
self._pubcomp_waiters[packet_id] = waiter
|
||||
yield from waiter
|
||||
inflight_message.pubcomp = waiter.result()
|
||||
del self._pubcomp_waiters[packet_id]
|
||||
inflight_message.complete()
|
||||
|
||||
del self.session.inflight_out[packet_id]
|
||||
return inflight_message
|
||||
|
||||
@asyncio.coroutine
|
||||
def stop(self):
|
||||
self._running = False
|
||||
self.session.reader.feed_eof()
|
||||
yield from asyncio.wait([self._inflight_task, self._writer_task, self._reader_task], loop=self._loop)
|
||||
yield from asyncio.wait([self._writer_task, self._reader_task], loop=self._loop)
|
||||
|
||||
@asyncio.coroutine
|
||||
def _reader_coro(self):
|
||||
|
@ -147,8 +150,27 @@ class ProtocolHandler:
|
|||
cls = packet_class(fixed_header)
|
||||
packet = yield from cls.from_stream(self.session.reader, fixed_header=fixed_header)
|
||||
self.logger.debug(" <-in-- " + repr(packet))
|
||||
yield from self.incoming_queues[packet.fixed_header.packet_type].put(packet)
|
||||
self.packet_received.send(packet)
|
||||
|
||||
if packet.fixed_header.packet_type == PacketType.CONNACK:
|
||||
asyncio.Task(self.handle_connack(packet))
|
||||
elif packet.fixed_header.packet_type == PacketType.SUBACK:
|
||||
asyncio.Task(self.handle_suback(packet))
|
||||
elif packet.fixed_header.packet_type == PacketType.UNSUBACK:
|
||||
asyncio.Task(self.handle_unsuback(packet))
|
||||
elif packet.fixed_header.packet_type == PacketType.PUBACK:
|
||||
asyncio.Task(self.handle_puback(packet))
|
||||
elif packet.fixed_header.packet_type == PacketType.PUBREC:
|
||||
asyncio.Task(self.handle_pubrec(packet))
|
||||
elif packet.fixed_header.packet_type == PacketType.PUBREL:
|
||||
asyncio.Task(self.handle_pubrel(packet))
|
||||
elif packet.fixed_header.packet_type == PacketType.PUBCOMP:
|
||||
asyncio.Task(self.handle_pubcomp(packet))
|
||||
elif packet.fixed_header.packet_type == PacketType.PINGRESP:
|
||||
asyncio.Task(self.handle_pingresp(packet))
|
||||
elif packet.fixed_header.packet_type == PacketType.PUBLISH:
|
||||
asyncio.Task(self.handle_publish(packet))
|
||||
else:
|
||||
self.logger.warn("Unhandled packet type: %s" % packet.fixed_header.packet_type)
|
||||
else:
|
||||
self.logger.debug("No more data, stopping reader coro")
|
||||
break
|
||||
|
@ -164,16 +186,20 @@ class ProtocolHandler:
|
|||
@asyncio.coroutine
|
||||
def _writer_coro(self):
|
||||
self.logger.debug("Starting writer coro")
|
||||
keepalive_timeout = self.session.keep_alive - self.config['ping_delay']
|
||||
while self._running:
|
||||
try:
|
||||
self._writer_ready.set()
|
||||
packet = yield from asyncio.wait_for(self.outgoing_queue.get(), 5)
|
||||
packet = yield from asyncio.wait_for(self.outgoing_queue.get(), keepalive_timeout)
|
||||
yield from packet.to_stream(self.session.writer)
|
||||
self.logger.debug(" -out-> " + repr(packet))
|
||||
yield from self.session.writer.drain()
|
||||
self.packet_sent.send(packet)
|
||||
#self.outgoing_queue.task_done() # to be used with Python 3.5
|
||||
except asyncio.TimeoutError as ce:
|
||||
self.logger.debug("Output queue get timeout")
|
||||
if self._running:
|
||||
self.logger.debug("PING for keepalive")
|
||||
self.handle_keepalive()
|
||||
except Exception as e:
|
||||
self.logger.warn("Unhandled exception in writer coro: %s" % e)
|
||||
break
|
||||
|
@ -192,168 +218,144 @@ class ProtocolHandler:
|
|||
self.logger.debug("Writer coro stopped")
|
||||
|
||||
@asyncio.coroutine
|
||||
def _inflight_coro(self):
|
||||
self.logger.debug("Starting in-flight messages polling coro")
|
||||
while self._running:
|
||||
self._inflight_ready.set()
|
||||
yield from asyncio.sleep(self.config['inflight-polling-interval'])
|
||||
self.logger.debug("in-flight polling coro wake-up")
|
||||
try:
|
||||
while not self.incoming_queues[PacketType.PUBACK].empty():
|
||||
packet = self.incoming_queues[PacketType.PUBACK].get_nowait()
|
||||
packet_id = packet.variable_header.packet_id
|
||||
inflight_message = self.inflight_messages.get(packet_id)
|
||||
inflight_message.acknowledge()
|
||||
self.logger.debug("Message with packet Id=%s acknowledged" % packet_id)
|
||||
def mqtt_deliver_next_message(self):
|
||||
inflight_message = yield from self.delivered_message.get()
|
||||
return inflight_message
|
||||
|
||||
while not self.incoming_queues[PacketType.PUBREC].empty():
|
||||
packet = self.incoming_queues[PacketType.PUBREC].get_nowait()
|
||||
packet_id = packet.variable_header.packet_id
|
||||
inflight_message = self.inflight_messages.get(packet_id)
|
||||
inflight_message.receive()
|
||||
self.logger.debug("Message with packet Id=%s received" % packet_id)
|
||||
def handle_keepalive(self):
|
||||
pass
|
||||
|
||||
rel_packet = PubrelPacket.build(packet_id)
|
||||
yield from self.outgoing_queue.put(rel_packet)
|
||||
inflight_message.release()
|
||||
self.logger.debug("Message with packet Id=%s released" % packet_id)
|
||||
@asyncio.coroutine
|
||||
def handle_connack(self, connack: ConnackPacket):
|
||||
pass
|
||||
|
||||
while not self.incoming_queues[PacketType.PUBCOMP].empty():
|
||||
packet = self.incoming_queues[PacketType.PUBCOMP].get_nowait()
|
||||
packet_id = packet.variable_header.packet_id
|
||||
inflight_message = self.inflight_messages.get(packet_id)
|
||||
inflight_message.complete()
|
||||
self.logger.debug("Message with packet Id=%s completed" % packet_id)
|
||||
@asyncio.coroutine
|
||||
def handle_suback(self, suback: SubackPacket):
|
||||
pass
|
||||
|
||||
yield from self._inflight_changed.acquire()
|
||||
self._inflight_changed.notify_all()
|
||||
self._inflight_changed.release()
|
||||
except KeyError:
|
||||
self.logger.warn("Received %s for unknown inflight message Id %d" % (packet.fixed_header.packet_type, packet_id))
|
||||
except MachineError as me:
|
||||
self.logger.warn("Packet type incompatible with message QOS: %s" % me)
|
||||
self.logger.debug("In-flight messages polling coro stopped")
|
||||
@asyncio.coroutine
|
||||
def handle_unsuback(self, unsuback: UnsubackPacket):
|
||||
pass
|
||||
|
||||
@asyncio.coroutine
|
||||
def handle_pingresp(self, pingresp: PingRespPacket):
|
||||
pass
|
||||
|
||||
class Subscription:
|
||||
states = ['new', 'subscribed', 'acknowledged']
|
||||
@asyncio.coroutine
|
||||
def handle_puback(self, puback: PubackPacket):
|
||||
packet_id = puback.variable_header.packet_id
|
||||
try:
|
||||
waiter = self._puback_waiters[packet_id]
|
||||
waiter.set_result(puback)
|
||||
except KeyError as ke:
|
||||
self.logger.warn("Received PUBACK for unknown pending subscription with Id: %s" % packet_id)
|
||||
|
||||
def __init__(self, packet_id, topics):
|
||||
self.topics = topics
|
||||
self.packet_id = packet_id
|
||||
self._init_states()
|
||||
@asyncio.coroutine
|
||||
def handle_pubrec(self, pubrec: PubrecPacket):
|
||||
packet_id = pubrec.variable_header.packet_id
|
||||
try:
|
||||
waiter = self._pubrec_waiters[packet_id]
|
||||
waiter.set_result(pubrec)
|
||||
except KeyError as ke:
|
||||
self.logger.warn("Received PUBREC for unknown pending subscription with Id: %s" % packet_id)
|
||||
|
||||
def _init_states(self):
|
||||
self.machine = Machine(model=self, states=Subscription.states, initial='new')
|
||||
self.machine.add_transition(trigger='subscribe', source='new', dest='subscribed')
|
||||
self.machine.add_transition(trigger='acknowledge', source='subscribed', dest='acknowledged')
|
||||
@asyncio.coroutine
|
||||
def handle_pubcomp(self, pubcomp: PubcompPacket):
|
||||
packet_id = pubcomp.variable_header.packet_id
|
||||
try:
|
||||
waiter = self._pubcomp_waiters[packet_id]
|
||||
waiter.set_result(pubcomp)
|
||||
except KeyError as ke:
|
||||
self.logger.warn("Received PUBCOMP for unknown pending subscription with Id: %s" % packet_id)
|
||||
|
||||
class UnSubscription:
|
||||
states = ['new', 'unsubscribed', 'acknowledged']
|
||||
@asyncio.coroutine
|
||||
def handle_pubrel(self, pubrel: PubrecPacket):
|
||||
packet_id = pubrel.variable_header.packet_id
|
||||
try:
|
||||
waiter = self._pubrel_waiters[packet_id]
|
||||
waiter.set_result(pubrel)
|
||||
except KeyError as ke:
|
||||
self.logger.warn("Received PUBREL for unknown pending subscription with Id: %s" % packet_id)
|
||||
|
||||
def __init__(self, packet_id, topics):
|
||||
self.topics = topics
|
||||
self.packet_id = packet_id
|
||||
self._init_states()
|
||||
|
||||
def _init_states(self):
|
||||
self.machine = Machine(model=self, states=UnSubscription.states, initial='new')
|
||||
self.machine.add_transition(trigger='unsubscribe', source='new', dest='unsubscribed')
|
||||
self.machine.add_transition(trigger='acknowledge', source='unsubscribed', dest='acknowledged')
|
||||
@asyncio.coroutine
|
||||
def handle_publish(self, publish : PublishPacket):
|
||||
inflight_message = None
|
||||
packet_id = publish.variable_header.packet_id
|
||||
qos = (publish.fixed_header.flags >> 1) & 0x03
|
||||
if packet_id in self.session.inflight_in:
|
||||
inflight_message = self.session.inflight_in[packet_id]
|
||||
else:
|
||||
inflight_message = InFlightMessage(publish, qos)
|
||||
self.session.inflight_in[packet_id] = inflight_message
|
||||
inflight_message.publish()
|
||||
|
||||
if qos == 1:
|
||||
puback = PubackPacket.build(packet_id)
|
||||
yield from self.outgoing_queue.put(puback)
|
||||
inflight_message.acknowledge()
|
||||
if qos == 2:
|
||||
pubrec = PubrecPacket.build(packet_id)
|
||||
yield from self.outgoing_queue.put(pubrec)
|
||||
inflight_message.receive()
|
||||
waiter = futures.Future(loop=self._loop)
|
||||
self._pubrel_waiters[packet_id] = waiter
|
||||
yield from waiter
|
||||
inflight_message.pubrel = waiter.result()
|
||||
del self._pubrel_waiters[packet_id]
|
||||
inflight_message.release()
|
||||
pubcomp = PubcompPacket.build(packet_id)
|
||||
yield from self.outgoing_queue.put(pubcomp)
|
||||
inflight_message.complete()
|
||||
yield from self.delivered_message.put(inflight_message)
|
||||
del self.session.inflight_in[packet_id]
|
||||
|
||||
class ClientProtocolHandler(ProtocolHandler):
|
||||
def __init__(self, session: Session, config, loop=None):
|
||||
super().__init__(session, config, loop)
|
||||
self._ping_task = None
|
||||
self.subscriptions = dict()
|
||||
self._subscription_task = None
|
||||
self._subscriptions_changed = asyncio.Condition(loop=self._loop)
|
||||
self._subscriptions_ready = asyncio.Event(loop=self._loop)
|
||||
self._connack_waiter = None
|
||||
self._pingresp_queue = asyncio.Queue()
|
||||
self._subscriptions_waiter = dict()
|
||||
self._unsubscriptions_waiter = dict()
|
||||
|
||||
@asyncio.coroutine
|
||||
def start(self):
|
||||
yield from super().start()
|
||||
self.packet_sent.connect(self._do_keepalive)
|
||||
self._subscription_task = asyncio.async(self._subscriptions_coro(), loop=self._loop)
|
||||
yield from asyncio.wait([self._subscriptions_ready.wait()], loop=self._loop)
|
||||
|
||||
@asyncio.coroutine
|
||||
def stop(self):
|
||||
yield from super().stop()
|
||||
yield from asyncio.wait([self._subscription_task], loop=self._loop)
|
||||
if self._ping_task:
|
||||
try:
|
||||
self._ping_task.cancel()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def _do_keepalive(self, message):
|
||||
if self._ping_task:
|
||||
try:
|
||||
self._ping_task.cancel()
|
||||
except Exception:
|
||||
pass
|
||||
next_ping = self.session.keep_alive - self.config['ping_delay']
|
||||
if next_ping > 0:
|
||||
self.logger.debug('Next ping in %d seconds if no new messages between' % next_ping)
|
||||
self._ping_task = self._loop.call_later(next_ping, asyncio.async, self.mqtt_ping())
|
||||
|
||||
def _subscriptions_coro(self):
|
||||
self.logger.debug("Starting subscriptions polling coro")
|
||||
while self._running:
|
||||
self._subscriptions_ready.set()
|
||||
yield from asyncio.sleep(self.config['subscriptions-polling-interval'])
|
||||
self.logger.debug("Subscriptions polling coro wake-up")
|
||||
try:
|
||||
while not self.incoming_queues[PacketType.SUBACK].empty():
|
||||
packet = self.incoming_queues[PacketType.SUBACK].get_nowait()
|
||||
packet_id = packet.variable_header.packet_id
|
||||
subscription = self.subscriptions.get(packet_id)
|
||||
for i in range(len(subscription.topics)):
|
||||
subscription.topics[i]['return_code'] = packet.payload.return_codes[i]
|
||||
subscription.acknowledge()
|
||||
self.logger.debug("Subscription with packet Id=%s acknowledged" % packet_id)
|
||||
|
||||
while not self.incoming_queues[PacketType.UNSUBACK].empty():
|
||||
packet = self.incoming_queues[PacketType.UNSUBACK].get_nowait()
|
||||
packet_id = packet.variable_header.packet_id
|
||||
subscription = self.subscriptions.get(packet_id)
|
||||
subscription.acknowledge()
|
||||
self.logger.debug("Unsubscription with packet Id=%s acknowledged" % packet_id)
|
||||
|
||||
yield from self._subscriptions_changed.acquire()
|
||||
self._subscriptions_changed.notify_all()
|
||||
self._subscriptions_changed.release()
|
||||
except KeyError:
|
||||
self.logger.warn("Received %s for unknown subscription message Id %d" % (packet.fixed_header.packet_type, packet_id))
|
||||
except MachineError as me:
|
||||
self.logger.warn("Packet type incompatible with message QOS: %s" % me)
|
||||
self.logger.debug("Subscriptions polling coro stopped")
|
||||
def handle_keepalive(self):
|
||||
self._ping_task = self._loop.call_soon(asyncio.async, self.mqtt_ping())
|
||||
|
||||
@asyncio.coroutine
|
||||
def mqtt_subscribe(self, topics, packet_id):
|
||||
"""
|
||||
|
||||
:param topics: array of topics [{'filter':'/a/b', 'qos': 0x00}, ...]
|
||||
:return:
|
||||
"""
|
||||
def acknowledged_predicate():
|
||||
if self.subscriptions[subscribe.variable_header.packet_id].state == 'acknowledged':
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
subscribe = SubscribePacket.build(topics, packet_id)
|
||||
yield from self.outgoing_queue.put(subscribe)
|
||||
subscription = Subscription(subscribe.variable_header.packet_id, topics)
|
||||
subscription.subscribe()
|
||||
self.subscriptions[subscribe.variable_header.packet_id] = subscription
|
||||
yield from self._subscriptions_changed.acquire()
|
||||
yield from self._subscriptions_changed.wait_for(acknowledged_predicate)
|
||||
subscription = self.subscriptions.pop(subscribe.variable_header.packet_id)
|
||||
self._subscriptions_changed.release()
|
||||
return subscription
|
||||
waiter = futures.Future(loop=self._loop)
|
||||
self._subscriptions_waiter[subscribe.variable_header.packet_id] = waiter
|
||||
return_codes = yield from waiter
|
||||
del self._subscriptions_waiter[subscribe.variable_header.packet_id]
|
||||
return return_codes
|
||||
|
||||
@asyncio.coroutine
|
||||
def handle_suback(self, suback: SubackPacket):
|
||||
packet_id = suback.variable_header.packet_id
|
||||
try:
|
||||
waiter = self._subscriptions_waiter.get(packet_id)
|
||||
waiter.set_result(suback.payload.return_codes)
|
||||
except KeyError as ke:
|
||||
self.logger.warn("Received SUBACK for unknown pending subscription with Id: %s" % packet_id)
|
||||
|
||||
@asyncio.coroutine
|
||||
def mqtt_unsubscribe(self, topics, packet_id):
|
||||
|
@ -362,23 +364,21 @@ class ClientProtocolHandler(ProtocolHandler):
|
|||
:param topics: array of topics ['/a/b', ...]
|
||||
:return:
|
||||
"""
|
||||
def acknowledged_predicate():
|
||||
if self.subscriptions[unsubscribe.variable_header.packet_id].state == 'acknowledged':
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
unsubscribe = UnsubscribePacket.build(topics, packet_id)
|
||||
yield from self.outgoing_queue.put(unsubscribe)
|
||||
subscription = UnSubscription(unsubscribe.variable_header.packet_id, topics)
|
||||
subscription.unsubscribe()
|
||||
self.subscriptions[unsubscribe.variable_header.packet_id] = subscription
|
||||
self.subscriptions[unsubscribe.variable_header.packet_id] = subscription
|
||||
yield from self._subscriptions_changed.acquire()
|
||||
yield from self._subscriptions_changed.wait_for(acknowledged_predicate)
|
||||
subscription = self.subscriptions.pop(unsubscribe.variable_header.packet_id)
|
||||
self._subscriptions_changed.release()
|
||||
return subscription
|
||||
waiter = futures.Future(loop=self._loop)
|
||||
self._unsubscriptions_waiter[unsubscribe.variable_header.packet_id] = waiter
|
||||
yield from waiter
|
||||
del self._unsubscriptions_waiter[unsubscribe.variable_header.packet_id]
|
||||
|
||||
@asyncio.coroutine
|
||||
def handle_unsuback(self, unsuback: UnsubackPacket):
|
||||
packet_id = unsuback.variable_header.packet_id
|
||||
try:
|
||||
waiter = self._unsubscriptions_waiter.get(packet_id)
|
||||
waiter.set_result(None)
|
||||
except KeyError as ke:
|
||||
self.logger.warn("Received UNSUBACK for unknown pending subscription with Id: %s" % packet_id)
|
||||
|
||||
@asyncio.coroutine
|
||||
def mqtt_connect(self):
|
||||
|
@ -416,19 +416,28 @@ class ClientProtocolHandler(ProtocolHandler):
|
|||
|
||||
packet = build_connect_packet(self.session)
|
||||
yield from self.outgoing_queue.put(packet)
|
||||
connack = yield from self.incoming_queues[PacketType.CONNACK].get()
|
||||
self._connack_waiter = futures.Future(loop=self._loop)
|
||||
return (yield from self._connack_waiter)
|
||||
|
||||
return connack.variable_header.return_code
|
||||
@asyncio.coroutine
|
||||
def handle_connack(self, connack: ConnackPacket):
|
||||
self._connack_waiter.set_result(connack.variable_header.return_code)
|
||||
|
||||
@asyncio.coroutine
|
||||
def mqtt_disconnect(self):
|
||||
# yield from self.outgoing_queue.join() To be used in Python 3.5
|
||||
disconnect_packet = DisconnectPacket()
|
||||
yield from self.outgoing_queue.put(disconnect_packet)
|
||||
self._ping_task.cancel()
|
||||
self._connack_waiter = None
|
||||
|
||||
@asyncio.coroutine
|
||||
def mqtt_ping(self):
|
||||
self.logger.debug("Pinging ...")
|
||||
ping_packet = PingReqPacket()
|
||||
yield from self.outgoing_queue.put(ping_packet)
|
||||
yield from self.incoming_queues[PacketType.PINGRESP].get()
|
||||
self._pingresp_waiter = futures.Future(loop=self._loop)
|
||||
resp = yield from self._pingresp_queue.get()
|
||||
return resp
|
||||
|
||||
@asyncio.coroutine
|
||||
def handle_pingresp(self, pingresp: PingRespPacket):
|
||||
yield from self._pingresp_queue.put(pingresp)
|
||||
|
|
|
@ -19,3 +19,9 @@ class PubackPacket(MQTTPacket):
|
|||
super().__init__(header)
|
||||
self.variable_header = variable_header
|
||||
self.payload = None
|
||||
|
||||
@classmethod
|
||||
def build(cls, packet_id: int):
|
||||
v_header = PacketIdVariableHeader(packet_id)
|
||||
packet = PubackPacket(variable_header=v_header, payload=None)
|
||||
return packet
|
||||
|
|
|
@ -19,3 +19,9 @@ class PubcompPacket(MQTTPacket):
|
|||
super().__init__(header)
|
||||
self.variable_header = variable_header
|
||||
self.payload = None
|
||||
|
||||
@classmethod
|
||||
def build(cls, packet_id: int):
|
||||
v_header = PacketIdVariableHeader(packet_id)
|
||||
packet = PubcompPacket(variable_header=v_header, payload=None)
|
||||
return packet
|
||||
|
|
|
@ -6,54 +6,6 @@ from hbmqtt.errors import HBMQTTException, MQTTException
|
|||
from hbmqtt.codecs import *
|
||||
|
||||
|
||||
class PublishFixedHeader(MQTTFixedHeader):
|
||||
DUP_FLAG = 0x08
|
||||
RETAIN_FLAG = 0x01
|
||||
QOS_FLAG = 0x06
|
||||
|
||||
def set_flags(self, dup_flag=False, qos=0, retain_flag=False):
|
||||
self.dup_flag = dup_flag
|
||||
self.retain_flag = retain_flag
|
||||
self.qos = qos
|
||||
|
||||
def _set_flag(self, val, mask):
|
||||
if val:
|
||||
self.flags |= mask
|
||||
else:
|
||||
self.flags &= ~mask
|
||||
|
||||
def _get_flag(self, mask):
|
||||
if self.flags & mask:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
@property
|
||||
def dup_flag(self) -> bool:
|
||||
return self._get_flag(self.DUP_FLAG)
|
||||
|
||||
@dup_flag.setter
|
||||
def dup_flag(self, val: bool):
|
||||
self._set_flag(val, self.DUP_FLAG)
|
||||
|
||||
@property
|
||||
def retain_flag(self) -> bool:
|
||||
return self._get_flag(self.RETAIN_FLAG)
|
||||
|
||||
@retain_flag.setter
|
||||
def retain_flag(self, val: bool):
|
||||
self._set_flag(val, self.RETAIN_FLAG)
|
||||
|
||||
@property
|
||||
def qos(self):
|
||||
return (self.flags & self.QOS_FLAG) >> 1
|
||||
|
||||
@qos.setter
|
||||
def qos(self, val: int):
|
||||
self.flags &= (0x00 << 1)
|
||||
self.flags |= (val << 1)
|
||||
|
||||
|
||||
class PublishVariableHeader(MQTTVariableHeader):
|
||||
def __init__(self, topic_name: str, packet_id: int=None):
|
||||
super().__init__()
|
||||
|
@ -73,9 +25,10 @@ class PublishVariableHeader(MQTTVariableHeader):
|
|||
return out
|
||||
|
||||
@classmethod
|
||||
def from_stream(cls, reader: asyncio.StreamReader, fixed_header: PublishFixedHeader):
|
||||
def from_stream(cls, reader: asyncio.StreamReader, fixed_header: MQTTFixedHeader):
|
||||
topic_name = yield from decode_string(reader)
|
||||
if fixed_header.qos:
|
||||
has_qos = (fixed_header.flags >> 1) & 0x03
|
||||
if has_qos:
|
||||
packet_id = yield from decode_packet_id(reader)
|
||||
else:
|
||||
packet_id = None
|
||||
|
@ -93,18 +46,25 @@ class PublishPayload(MQTTPayload):
|
|||
@classmethod
|
||||
def from_stream(cls, reader: asyncio.StreamReader, fixed_header: MQTTFixedHeader,
|
||||
variable_header: MQTTVariableHeader):
|
||||
data = yield from reader.read()
|
||||
data = yield from reader.read(fixed_header.remaining_length-variable_header.bytes_length)
|
||||
return cls(data)
|
||||
|
||||
def __repr__(self):
|
||||
return type(self).__name__ + '(data={0!r})'.format(repr(self.data))
|
||||
|
||||
|
||||
class PublishPacket(MQTTPacket):
|
||||
FIXED_HEADER = PublishFixedHeader
|
||||
VARIABLE_HEADER = PublishVariableHeader
|
||||
PAYLOAD = PublishPayload
|
||||
|
||||
def __init__(self, fixed: PublishFixedHeader=None, variable_header: PublishVariableHeader=None, payload=None):
|
||||
DUP_FLAG = 0x08
|
||||
RETAIN_FLAG = 0x01
|
||||
QOS_FLAG = 0x06
|
||||
|
||||
|
||||
def __init__(self, fixed: MQTTFixedHeader=None, variable_header: PublishVariableHeader=None, payload=None):
|
||||
if fixed is None:
|
||||
header = PublishFixedHeader(PacketType.PUBLISH, 0x00)
|
||||
header = MQTTFixedHeader(PacketType.PUBLISH, 0x00)
|
||||
else:
|
||||
if fixed.packet_type is not PacketType.PUBLISH:
|
||||
raise HBMQTTException("Invalid fixed packet type %s for PublishPacket init" % fixed.packet_type)
|
||||
|
@ -114,12 +74,54 @@ class PublishPacket(MQTTPacket):
|
|||
self.variable_header = variable_header
|
||||
self.payload = payload
|
||||
|
||||
def set_flags(self, dup_flag=False, qos=0, retain_flag=False):
|
||||
self.dup_flag = dup_flag
|
||||
self.retain_flag = retain_flag
|
||||
self.qos = qos
|
||||
|
||||
def _set_header_flag(self, val, mask):
|
||||
if val:
|
||||
self.fixed_header.flags |= mask
|
||||
else:
|
||||
self.fixed_header.flags &= ~mask
|
||||
|
||||
def _get_header_flag(self, mask):
|
||||
if self.fixed_header.flags & mask:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
@property
|
||||
def dup_flag(self) -> bool:
|
||||
return self._get_header_flag(self.DUP_FLAG)
|
||||
|
||||
@dup_flag.setter
|
||||
def dup_flag(self, val: bool):
|
||||
self._set_header_flag(val, self.DUP_FLAG)
|
||||
|
||||
@property
|
||||
def retain_flag(self) -> bool:
|
||||
return self._get_header_flag(self.RETAIN_FLAG)
|
||||
|
||||
@retain_flag.setter
|
||||
def retain_flag(self, val: bool):
|
||||
self._set_header_flag(val, self.RETAIN_FLAG)
|
||||
|
||||
@property
|
||||
def qos(self):
|
||||
return (self.fixed_header.flags & self.QOS_FLAG) >> 1
|
||||
|
||||
@qos.setter
|
||||
def qos(self, val: int):
|
||||
self.fixed_header.flags &= (0x00 << 1)
|
||||
self.fixed_header.flags |= (val << 1)
|
||||
|
||||
@classmethod
|
||||
def build(cls, topic_name: str, message:bytes, packet_id: int, dup_flag, qos, retain):
|
||||
v_header = PublishVariableHeader(topic_name, packet_id)
|
||||
payload = PublishPayload(message)
|
||||
packet = PublishPacket(variable_header=v_header, payload=payload)
|
||||
packet.fixed_header.dup_flag = dup_flag
|
||||
packet.fixed_header.retain_flag = retain
|
||||
packet.fixed_header.qos = qos
|
||||
packet.dup_flag = dup_flag
|
||||
packet.retain_flag = retain
|
||||
packet.qos = qos
|
||||
return packet
|
||||
|
|
|
@ -19,3 +19,9 @@ class PubrecPacket(MQTTPacket):
|
|||
super().__init__(header)
|
||||
self.variable_header = variable_header
|
||||
self.payload = None
|
||||
|
||||
@classmethod
|
||||
def build(cls, packet_id: int):
|
||||
v_header = PacketIdVariableHeader(packet_id)
|
||||
packet = PubrecPacket(variable_header=v_header, payload=None)
|
||||
return packet
|
||||
|
|
|
@ -27,6 +27,9 @@ class Session:
|
|||
self.scheme = None
|
||||
self._packet_id = 0
|
||||
|
||||
self.inflight_out = dict()
|
||||
self.inflight_in = dict()
|
||||
|
||||
@property
|
||||
def next_packet_id(self):
|
||||
self._packet_id += 1
|
||||
|
|
16
readme.md
16
readme.md
|
@ -2,5 +2,19 @@
|
|||
|
||||
HBMQTT is an open source [MQTT](http://www.mqtt.org) broker written with Python using asynchronous I/O.
|
||||
|
||||
## Getting started
|
||||
|
||||
hbmqtt is deployed on [Pypi](https://pypi.python.org/pypi/hbmqtt) and can installed simply using `pip` :
|
||||
|
||||
$ pip install hbmqtt
|
||||
|
||||
### Client
|
||||
|
||||
MQTT client resides in the `Client` class. The examples scripts in `samples/` sub-directory to know hos to use if for connecting, subscribing or publishing on a MQTT broker.
|
||||
|
||||
## Build status
|
||||
[](https://travis-ci.org/beerfactory/hbmqtt)
|
||||
[](https://travis-ci.org/beerfactory/hbmqtt)
|
||||
|
||||
## Support
|
||||
|
||||
Support is available on the [project forum](http://forum.beerfactory.org/c/hbmqtt). Issues can be reported directly on Beerfactory project [Jira instance](http://community.beerfactory.org/jira/).
|
|
@ -0,0 +1,30 @@
|
|||
import logging
|
||||
from hbmqtt.client._client import MQTTClient
|
||||
import asyncio
|
||||
|
||||
#
|
||||
# This sample shows a client running idle.
|
||||
# Meanwhile, keepalive is managed through PING messages sent every 5 seconds
|
||||
#
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
config = {
|
||||
'keep_alive': 5,
|
||||
'ping_delay': 1,
|
||||
}
|
||||
C = MQTTClient(config=config)
|
||||
|
||||
@asyncio.coroutine
|
||||
def test_coro():
|
||||
yield from C.connect(uri='mqtt://iot.eclipse.org:1883/', username=None, password=None)
|
||||
yield from asyncio.sleep(18)
|
||||
|
||||
yield from C.disconnect()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
formatter = "[%(asctime)s] {%(filename)s:%(lineno)d} %(levelname)s - %(message)s"
|
||||
logging.basicConfig(level=logging.DEBUG, format=formatter)
|
||||
asyncio.get_event_loop().run_until_complete(test_coro())
|
|
@ -0,0 +1,32 @@
|
|||
import logging
|
||||
from hbmqtt.client._client import MQTTClient
|
||||
import asyncio
|
||||
|
||||
|
||||
#
|
||||
# This sample shows how to publish messages to broker using different QOS
|
||||
# Debug outputs shows the message flows
|
||||
#
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
C = MQTTClient()
|
||||
|
||||
@asyncio.coroutine
|
||||
def test_coro():
|
||||
yield from C.connect(uri='mqtt://iot.eclipse.org:1883/', username=None, password=None)
|
||||
tasks = [
|
||||
asyncio.async(C.publish('a/b', b'TEST MESSAGE WITH QOS_0')),
|
||||
asyncio.async(C.publish('a/b', b'TEST MESSAGE WITH QOS_1', qos=0x01)),
|
||||
asyncio.async(C.publish('a/b', b'TEST MESSAGE WITH QOS_2', qos=0x02)),
|
||||
]
|
||||
yield from asyncio.wait(tasks)
|
||||
|
||||
logger.info("messages published")
|
||||
yield from C.disconnect()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
formatter = "[%(asctime)s] {%(filename)s:%(lineno)d} %(levelname)s - %(message)s"
|
||||
logging.basicConfig(level=logging.DEBUG, format=formatter)
|
||||
asyncio.get_event_loop().run_until_complete(test_coro())
|
|
@ -0,0 +1,36 @@
|
|||
import logging
|
||||
from hbmqtt.client._client import MQTTClient
|
||||
import asyncio
|
||||
|
||||
|
||||
#
|
||||
# This sample shows how to subscbribe a topic and receive data from incoming messages
|
||||
# It subscribes to '$SYS/broker/uptime' topic and displays the first ten values returned
|
||||
# by the broker.
|
||||
#
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
C = MQTTClient()
|
||||
|
||||
@asyncio.coroutine
|
||||
def test_coro():
|
||||
yield from C.connect(uri='mqtt://iot.eclipse.org:1883/', username=None, password=None)
|
||||
# Subscribe to '$SYS/broker/uptime' with QOS=1
|
||||
yield from C.subscribe([
|
||||
{'filter': '$SYS/broker/uptime', 'qos': 0x01},
|
||||
])
|
||||
logger.info("Subscribed")
|
||||
for i in range (1,10):
|
||||
ret = yield from C.deliver_message()
|
||||
data = ret.packet.payload.data
|
||||
logger.info(str(data))
|
||||
yield from C.unsubscribe(['$SYS/broker/uptime'])
|
||||
logger.info("UnSubscribed")
|
||||
yield from C.disconnect()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
formatter = "[%(asctime)s] {%(filename)s:%(lineno)d} %(levelname)s - %(message)s"
|
||||
logging.basicConfig(level=logging.INFO, format=formatter)
|
||||
asyncio.get_event_loop().run_until_complete(test_coro())
|
|
@ -1,51 +0,0 @@
|
|||
import logging
|
||||
from hbmqtt.client._client import MQTTClient
|
||||
import asyncio
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
C = MQTTClient()
|
||||
|
||||
@asyncio.coroutine
|
||||
def test_coro():
|
||||
yield from C.connect(uri='mqtt://iot.eclipse.org:1883/', username='testuser', password="passwd")
|
||||
tasks = [
|
||||
asyncio.async(C.publish('a/b', b'0123456789')),
|
||||
asyncio.async(C.publish('a/b', b'0', qos=0x01)),
|
||||
asyncio.async(C.publish('a/b', b'1', qos=0x01)),
|
||||
asyncio.async(C.publish('a/b', b'2', qos=0x01)),
|
||||
asyncio.async(C.publish('a/b', b'3', qos=0x01)),
|
||||
asyncio.async(C.publish('a/b', b'4', qos=0x01)),
|
||||
asyncio.async(C.publish('a/b', b'5', qos=0x01)),
|
||||
asyncio.async(C.publish('a/b', b'6', qos=0x01)),
|
||||
asyncio.async(C.publish('a/b', b'7', qos=0x01)),
|
||||
asyncio.async(C.publish('a/b', b'8', qos=0x01)),
|
||||
asyncio.async(C.publish('a/b', b'9', qos=0x01)),
|
||||
asyncio.async(C.publish('a/b', b'0', qos=0x02)),
|
||||
asyncio.async(C.publish('a/b', b'1', qos=0x02)),
|
||||
asyncio.async(C.publish('a/b', b'2', qos=0x02)),
|
||||
asyncio.async(C.publish('a/b', b'3', qos=0x02)),
|
||||
asyncio.async(C.publish('a/b', b'4', qos=0x02)),
|
||||
asyncio.async(C.publish('a/b', b'5', qos=0x02)),
|
||||
asyncio.async(C.publish('a/b', b'6', qos=0x02)),
|
||||
asyncio.async(C.publish('a/b', b'7', qos=0x02)),
|
||||
asyncio.async(C.publish('a/b', b'8', qos=0x02)),
|
||||
asyncio.async(C.publish('a/b', b'9', qos=0x02)),
|
||||
]
|
||||
yield from asyncio.wait(tasks)
|
||||
logger.info("messages published")
|
||||
yield from C.subscribe([
|
||||
{'filter': 'a/b', 'qos': 0x01},
|
||||
{'filter': 'c/d', 'qos': 0x02}
|
||||
])
|
||||
logger.info("Subscribed")
|
||||
yield from C.unsubscribe(['a/b', 'c/d'])
|
||||
logger.info("Unsubscribed")
|
||||
|
||||
yield from C.disconnect()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
formatter = "[%(asctime)s] {%(filename)s:%(lineno)d} %(levelname)s - %(message)s"
|
||||
logging.basicConfig(level=logging.DEBUG, format=formatter)
|
||||
asyncio.get_event_loop().run_until_complete(test_coro())
|
|
@ -1,91 +0,0 @@
|
|||
# Copyright (c) 2015 Nicolas JOUANIN
|
||||
#
|
||||
# See the file license.txt for copying permission.
|
||||
import unittest
|
||||
import asyncio
|
||||
|
||||
from hbmqtt.mqtt.connect import ConnectPacket, ConnectVariableHeader, ConnectPayload
|
||||
from hbmqtt.mqtt.protocol import ProtocolHandler
|
||||
from hbmqtt.session import Session
|
||||
from hbmqtt.mqtt.packet import PacketType
|
||||
import logging
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
|
||||
ret_packet = None
|
||||
|
||||
config = {
|
||||
'keep_alive': 10,
|
||||
'ping_delay': 1,
|
||||
'default_qos': 0,
|
||||
'default_retain': False,
|
||||
'inflight-polling-interval': 1,
|
||||
'subscriptions-polling-interval': 1,
|
||||
}
|
||||
|
||||
class ConnectPacketTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.loop = asyncio.new_event_loop()
|
||||
self.logger = logging.getLogger(__name__)
|
||||
|
||||
def test_read_loop(self):
|
||||
data = b'\x10\x3e\x00\x04MQTT\x04\xce\x00\x00\x00\x0a0123456789\x00\x09WillTopic\x00\x0bWillMessage\x00\x04user\x00\x08password'
|
||||
@asyncio.coroutine
|
||||
def serve_test(reader, writer):
|
||||
writer.write(data)
|
||||
yield from writer.drain()
|
||||
writer.close()
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
coro = asyncio.start_server(serve_test, '127.0.0.1', 8888, loop=loop)
|
||||
server = loop.run_until_complete(coro)
|
||||
|
||||
@asyncio.coroutine
|
||||
def client():
|
||||
S = Session()
|
||||
S.reader, S.writer = yield from asyncio.open_connection('127.0.0.1', 8888,
|
||||
loop=loop)
|
||||
h = ProtocolHandler(S, config)
|
||||
yield from h.start()
|
||||
incoming_packet = yield from h.incoming_queues[PacketType.CONNECT].get()
|
||||
yield from h.stop()
|
||||
return incoming_packet
|
||||
|
||||
packet = loop.run_until_complete(client())
|
||||
server.close()
|
||||
self.assertEquals(packet.fixed_header.packet_type, PacketType.CONNECT)
|
||||
|
||||
def test_write_loop(self):
|
||||
test_packet = ConnectPacket(vh=ConnectVariableHeader(), payload=ConnectPayload('Id', 'WillTopic', 'WillMessage', 'user', 'password'))
|
||||
event=asyncio.Event()
|
||||
|
||||
@asyncio.coroutine
|
||||
def serve_test(reader, writer):
|
||||
global ret_packet
|
||||
packet = yield from ConnectPacket.from_stream(reader)
|
||||
ret_packet = packet
|
||||
writer.close()
|
||||
event.set()
|
||||
|
||||
@asyncio.coroutine
|
||||
def client():
|
||||
S = Session()
|
||||
S.reader, S.writer = yield from asyncio.open_connection('127.0.0.1', 8888, loop=loop)
|
||||
h = ProtocolHandler(S, config)
|
||||
yield from h.start()
|
||||
yield from h.outgoing_queue.put(test_packet)
|
||||
yield from h.stop()
|
||||
|
||||
# Start server
|
||||
loop = asyncio.get_event_loop()
|
||||
coro = asyncio.start_server(serve_test, '127.0.0.1', 8888, loop=loop)
|
||||
server = loop.run_until_complete(coro)
|
||||
|
||||
# Schedule client
|
||||
loop.call_soon(asyncio.async, client())
|
||||
|
||||
# Wait for server to complete client request
|
||||
loop.run_until_complete(asyncio.wait([event.wait()]))
|
||||
server.close()
|
||||
self.logger.info(ret_packet)
|
||||
self.assertEquals(ret_packet.fixed_header.packet_type, PacketType.CONNECT)
|
|
@ -10,22 +10,40 @@ class PublishPacketTest(unittest.TestCase):
|
|||
def setUp(self):
|
||||
self.loop = asyncio.new_event_loop()
|
||||
|
||||
def test_from_stream(self):
|
||||
data = b'\x3f\x09\x00\x05topic\x00\x0a0123456789'
|
||||
def test_from_stream_qos_0(self):
|
||||
data = b'\x31\x11\x00\x05topic0123456789'
|
||||
stream = asyncio.StreamReader(loop=self.loop)
|
||||
stream.feed_data(data)
|
||||
stream.feed_eof()
|
||||
message = self.loop.run_until_complete(PublishPacket.from_stream(stream))
|
||||
self.assertEqual(message.variable_header.topic_name, 'topic')
|
||||
self.assertEqual(message.variable_header.packet_id, None)
|
||||
self.assertFalse((message.fixed_header.flags >> 1) & 0x03)
|
||||
self.assertTrue(message.fixed_header.flags & 0x01)
|
||||
self.assertTrue(message.payload.data, b'0123456789')
|
||||
|
||||
def test_from_stream_qos_2(self):
|
||||
data = b'\x37\x13\x00\x05topic\x00\x0a0123456789'
|
||||
stream = asyncio.StreamReader(loop=self.loop)
|
||||
stream.feed_data(data)
|
||||
stream.feed_eof()
|
||||
message = self.loop.run_until_complete(PublishPacket.from_stream(stream))
|
||||
self.assertEqual(message.variable_header.topic_name, 'topic')
|
||||
self.assertEqual(message.variable_header.packet_id, 10)
|
||||
self.assertEqual(message.fixed_header.qos, 0x03)
|
||||
self.assertTrue(message.fixed_header.dup_flag)
|
||||
self.assertTrue(message.fixed_header.retain_flag)
|
||||
self.assertTrue((message.fixed_header.flags >> 1) & 0x03)
|
||||
self.assertTrue(message.fixed_header.flags & 0x01)
|
||||
self.assertTrue(message.payload.data, b'0123456789')
|
||||
|
||||
def test_to_stream(self):
|
||||
def test_to_stream_no_packet_id(self):
|
||||
variable_header = PublishVariableHeader('topic', None)
|
||||
payload = PublishPayload(b'0123456789')
|
||||
publish = PublishPacket(variable_header=variable_header, payload=payload)
|
||||
out = publish.to_bytes()
|
||||
self.assertEqual(out, b'\x30\x11\x00\x05topic0123456789')
|
||||
|
||||
def test_to_stream_packet(self):
|
||||
variable_header = PublishVariableHeader('topic', 10)
|
||||
payload = PublishPayload(b'0123456789')
|
||||
publish = PublishPacket(variable_header=variable_header, payload=payload)
|
||||
out = publish.to_bytes()
|
||||
self.assertEqual(out, b'\x30\x13\x00\x05topic\x00\x0a0123456789')
|
||||
self.assertEqual(out, b'\x30\x13\x00\x05topic\00\x0a0123456789')
|
||||
|
|
Ładowanie…
Reference in New Issue