kopia lustrzana https://github.com/Yakifo/amqtt
Implement CONNECT handler
rodzic
df5ff26ece
commit
713c3aaaf6
|
@ -30,6 +30,7 @@ class Broker:
|
|||
self._loop_thread = None
|
||||
self._message_handlers = None
|
||||
self._codecs = None
|
||||
self._sessions = dict()
|
||||
|
||||
def _init_states(self):
|
||||
self.machine = Machine(states=Broker.states, initial='new')
|
||||
|
@ -100,6 +101,32 @@ class Broker:
|
|||
self.logger.debug("Broker started, ready to serve")
|
||||
loop.run_forever()
|
||||
|
||||
def discard_session(self, client_id):
|
||||
if client_id in self._sessions:
|
||||
del self._sessions[client_id]
|
||||
else:
|
||||
self.logger.warn("Asked to discard an unknown client ID session")
|
||||
|
||||
def create_session(self, remote_address, remote_port, client_id, clean_session):
|
||||
session = Session(remote_address, remote_port, client_id, clean_session)
|
||||
if client_id in self._sessions:
|
||||
raise BrokerException("Session already exists for client ID: %s", client_id)
|
||||
self.save_session(session)
|
||||
return session
|
||||
|
||||
def get_session(self, client_id):
|
||||
if client_id not in self._sessions:
|
||||
raise BrokerException("Unknown session for client ID: %s", client_id)
|
||||
else:
|
||||
return self._sessions[client_id]
|
||||
|
||||
def resume_session(self, session: Session):
|
||||
# TBD
|
||||
pass
|
||||
|
||||
def save_session(self, session: Session):
|
||||
self._sessions[session.client_id] = session
|
||||
|
||||
@asyncio.coroutine
|
||||
def _handle_message(self, message: MQTTMessage) -> MQTTMessage:
|
||||
handler = self._message_handlers[message.mqtt_header.message_type]
|
||||
|
@ -130,7 +157,8 @@ class Broker:
|
|||
request = yield from self._decode_message(header.message_type, reader)
|
||||
|
||||
(remote_address, remote_port) = writer.get_extra_info('peername')
|
||||
session = Session(remote_address, remote_port, request.client_id)
|
||||
request.remote_address = remote_address
|
||||
request.remote_address = remote_port
|
||||
|
||||
response = self._handle_message(request)
|
||||
encoded_response = yield from self._encode_message(response)
|
||||
|
|
|
@ -3,13 +3,69 @@
|
|||
# See the file license.txt for copying permission.
|
||||
|
||||
import asyncio
|
||||
from hbmqtt.message import MQTTMessage, ConnectMessage
|
||||
from hbmqtt.message import MQTTMessage, ConnectMessage, ConnackMessage
|
||||
from hbmqtt.broker.session import Session, ClientState
|
||||
from hbmqtt.errors import BrokerException
|
||||
|
||||
class ConnectHandler:
|
||||
def __init__(self, broker):
|
||||
self._broker = broker
|
||||
self.broker = broker
|
||||
|
||||
@asyncio.coroutine
|
||||
def handle(message: ConnectMessage) -> MQTTMessage:
|
||||
# TBD
|
||||
pass
|
||||
def handle(self, message: ConnectMessage) -> MQTTMessage:
|
||||
session = None
|
||||
response = None
|
||||
|
||||
# Check Protocol
|
||||
# protocol level (only MQTT 3.1.1 supported)
|
||||
if message.proto_level != 4:
|
||||
return ConnackMessage(False, ConnackMessage.ReturnCode.UNACCEPTABLE_PROTOCOL_VERSION)
|
||||
|
||||
# No client ID provided
|
||||
if message.client_id is None or message.client_id == "":
|
||||
if message.is_clean_session():
|
||||
# [MQTT-3.1.3-6] and [MQTT-3.1.3-7]
|
||||
message.client_id = self.gen_client_id()
|
||||
else:
|
||||
# [MQTT-3.1.3-8] : Identifier rejected
|
||||
return ConnackMessage(False, ConnackMessage.ReturnCode.IDENTIFIER_REJECTED)
|
||||
|
||||
if message.is_clean_session():
|
||||
try:
|
||||
self.broker.discard_session(message.client_id)
|
||||
except BrokerException:
|
||||
pass
|
||||
session = self.broker.create_session(message._remote_address, message._remote_port, message.client_id, message.is_clean_session())
|
||||
# [MQTT-3.2.2-1]
|
||||
response = ConnackMessage(False, 0)
|
||||
else:
|
||||
try:
|
||||
session = self.broker.get_session(message.client_id)
|
||||
if session.client_state == ClientState.CONNECTED:
|
||||
# [MQTT-3.1.4-2]
|
||||
# TODO : Add logging
|
||||
return ConnackMessage(False, ConnackMessage.ReturnCode.IDENTIFIER_REJECTED)
|
||||
else:
|
||||
# [MQTT-3.2.2-2]
|
||||
response = ConnackMessage(True, 0)
|
||||
except BrokerException:
|
||||
session = self.broker.create_session(message._remote_address, message._remote_port, message.client_id, message.is_clean_session())
|
||||
response = ConnackMessage(False, 0)
|
||||
if session.client_state == ClientState.DISCONNECTED:
|
||||
self.broker.resume_session(session)
|
||||
session.client_state = ClientState.CONNECTED
|
||||
|
||||
if message.is_will_flag():
|
||||
session.will_flag = True
|
||||
session.will_message = message.will_message
|
||||
session.will_qos = message.will_qos()
|
||||
session.will_retain = message.is_will_retain()
|
||||
|
||||
session.keep_alive = message.keep_alive
|
||||
|
||||
self.broker.save_session(session)
|
||||
return response
|
||||
|
||||
def gen_client_id(self):
|
||||
import uuid
|
||||
return uuid.uuid4()
|
|
@ -1,9 +1,21 @@
|
|||
# Copyright (c) 2015 Nicolas JOUANIN
|
||||
#
|
||||
# See the file license.txt for copying permission.
|
||||
import enum
|
||||
|
||||
class ClientState(enum):
|
||||
CONNECTED = 1
|
||||
DISCONNECTED = 2
|
||||
|
||||
class Session:
|
||||
def __init__(self, remote_address, remote_port, client_id):
|
||||
def __init__(self, remote_address, remote_port, client_id, clean_session):
|
||||
self.remote_address = remote_address
|
||||
self.remote_port = remote_port
|
||||
self.client_id = client_id
|
||||
self.clean_session = clean_session
|
||||
self.client_state = ClientState.CONNECTED
|
||||
self.will_flag = False
|
||||
self.will_message = None
|
||||
self.will_qos = None
|
||||
self.will_retain = None
|
||||
self.keep_alive = 0
|
||||
|
|
|
@ -36,9 +36,6 @@ class ConnectCodec:
|
|||
# protocol level (only MQTT 3.1.1 supported)
|
||||
protocol_level_byte = yield from read_or_raise(reader, 1)
|
||||
protocol_level = bytes_to_int(protocol_level_byte)
|
||||
if protocol_level != 4:
|
||||
raise ConnectException(
|
||||
'[MQTT-3.1.2-2] Unsupported protocol level %s' % bytes_to_hex_str(protocol_level_byte))
|
||||
|
||||
# flags
|
||||
flags_byte = yield from read_or_raise(reader, 1)
|
||||
|
|
|
@ -26,7 +26,7 @@ def get_message_type(byte):
|
|||
return MessageType(byte)
|
||||
|
||||
class MQTTHeader:
|
||||
def __init__(self, msg_type, flags, length):
|
||||
def __init__(self, msg_type, flags=0, length=0):
|
||||
if isinstance(msg_type, int):
|
||||
enum_type = msg_type
|
||||
else:
|
||||
|
@ -36,7 +36,7 @@ class MQTTHeader:
|
|||
self.flags = flags
|
||||
|
||||
class MQTTMessage:
|
||||
def __init__(self, header: MQTTHeader):
|
||||
def __init__(self, header):
|
||||
# MQTT header
|
||||
self.mqtt_header = header
|
||||
|
||||
|
@ -79,3 +79,19 @@ class ConnectMessage(MQTTMessage):
|
|||
|
||||
def will_qos(self):
|
||||
return (self.flags & 0x18) >> 3
|
||||
|
||||
|
||||
class ConnackMessage(MQTTMessage):
|
||||
def __init__(self, session_parent, return_code):
|
||||
header = MQTTHeader(MessageType.CONNACK)
|
||||
super().__init__(header)
|
||||
self.session_parent = session_parent
|
||||
self.return_code = return_code
|
||||
|
||||
class ReturnCode(enum):
|
||||
CONNECTION_ACCEPTED = 0
|
||||
UNACCEPTABLE_PROTOCOL_VERSION = 1
|
||||
IDENTIFIER_REJECTED = 2
|
||||
SERVER_UNAVAILABLE = 3
|
||||
BAD_USERNAME_PASSWORD = 4
|
||||
NOT_AUTHORIZED = 5
|
Ładowanie…
Reference in New Issue