kopia lustrzana https://github.com/Yakifo/amqtt
rodzic
752ca73af2
commit
0181795192
|
@ -6,7 +6,7 @@ import asyncio
|
|||
|
||||
from transitions import Machine, MachineError
|
||||
from hbmqtt.session import Session
|
||||
from hbmqtt.mqtt.protocol.broker_handler import BrokerProtocolHandler
|
||||
from hbmqtt.mqtt.protocol.broker_handler import BrokerProtocolHandler, Subscription
|
||||
from hbmqtt.mqtt.connect import ConnectPacket
|
||||
from hbmqtt.mqtt.connack import ConnackPacket, ReturnCode
|
||||
from hbmqtt.errors import HBMQTTException
|
||||
|
@ -41,6 +41,7 @@ class Broker:
|
|||
self._handlers = []
|
||||
self._init_states()
|
||||
self._sessions = dict()
|
||||
self._topics = dict()
|
||||
|
||||
def _init_states(self):
|
||||
self.machine = Machine(states=Broker.states, initial='new')
|
||||
|
@ -192,7 +193,24 @@ class Broker:
|
|||
self.logger.debug("Start messages handling")
|
||||
yield from handler.start()
|
||||
self.logger.debug("Wait for disconnect")
|
||||
yield from handler.wait_disconnect()
|
||||
|
||||
connected = True
|
||||
wait_disconnect = asyncio.Task(handler.wait_disconnect())
|
||||
wait_subscription = asyncio.Task(handler.get_next_pending_subscription())
|
||||
while connected:
|
||||
done, pending = yield from asyncio.wait([wait_disconnect, wait_subscription],
|
||||
return_when=asyncio.FIRST_COMPLETED)
|
||||
if wait_disconnect in done:
|
||||
connected = False
|
||||
wait_subscription.cancel()
|
||||
elif wait_subscription in done:
|
||||
subscription = wait_subscription.result()
|
||||
return_codes = []
|
||||
for topic in subscription.topics:
|
||||
return_codes.append(self.add_subscription(topic, handler))
|
||||
yield from handler.mqtt_acknowledge_subscription(subscription.packet_id, return_codes)
|
||||
wait_subscription = asyncio.Task(handler.get_next_pending_subscription())
|
||||
|
||||
self.logger.debug("Client disconnected")
|
||||
yield from handler.stop()
|
||||
new_session.machine.disconnect()
|
||||
|
@ -212,3 +230,16 @@ class Broker:
|
|||
def authenticate(self, session: Session):
|
||||
# TODO : Handle client authentication here
|
||||
return True
|
||||
|
||||
def add_subscription(self, topic, handler):
|
||||
try:
|
||||
filter = topic['filter']
|
||||
qos = topic['qos']
|
||||
if 'max-qos' in self.config and qos > self.config['max-qos']:
|
||||
qos = self.config['max-qos']
|
||||
if filter not in self._topics:
|
||||
self._topics[filter] = []
|
||||
self._topics[filter].append({'handler': handler, 'qos': qos})
|
||||
return qos
|
||||
except KeyError:
|
||||
return 0x80
|
||||
|
|
|
@ -9,13 +9,23 @@ from hbmqtt.mqtt.connect import ConnectVariableHeader, ConnectPacket, ConnectPay
|
|||
from hbmqtt.mqtt.disconnect import DisconnectPacket
|
||||
from hbmqtt.mqtt.pingreq import PingReqPacket
|
||||
from hbmqtt.mqtt.pingresp import PingRespPacket
|
||||
from hbmqtt.mqtt.subscribe import SubscribePacket
|
||||
from hbmqtt.mqtt.suback import SubackPacket
|
||||
from hbmqtt.session import Session
|
||||
from hbmqtt.utils import format_client_message
|
||||
|
||||
|
||||
class Subscription:
|
||||
def __init__(self, packet_id, topics):
|
||||
self.packet_id = packet_id
|
||||
self.topics = topics
|
||||
|
||||
|
||||
class BrokerProtocolHandler(ProtocolHandler):
|
||||
def __init__(self, session: Session, loop=None):
|
||||
super().__init__(session, loop)
|
||||
self._disconnect_waiter = None
|
||||
self._pending_subscriptions = asyncio.Queue()
|
||||
|
||||
@asyncio.coroutine
|
||||
def start(self):
|
||||
|
@ -46,3 +56,18 @@ class BrokerProtocolHandler(ProtocolHandler):
|
|||
@asyncio.coroutine
|
||||
def handle_pingreq(self, pingreq: PingReqPacket):
|
||||
yield from self.outgoing_queue.put(PingRespPacket.build())
|
||||
|
||||
@asyncio.coroutine
|
||||
def handle_subscribe(self, subscribe: SubscribePacket):
|
||||
subscription = Subscription(subscribe.variable_header.packet_id, subscribe.payload.topics)
|
||||
yield from self._pending_subscriptions.put(subscription)
|
||||
|
||||
@asyncio.coroutine
|
||||
def get_next_pending_subscription(self):
|
||||
subscription = yield from self._pending_subscriptions.get()
|
||||
return subscription
|
||||
|
||||
@asyncio.coroutine
|
||||
def mqtt_acknowledge_subscription(self, packet_id, return_codes):
|
||||
suback = SubackPacket.build(packet_id, return_codes)
|
||||
yield from self.outgoing_queue.put(suback)
|
Ładowanie…
Reference in New Issue