diff --git a/hbmqtt/broker.py b/hbmqtt/broker.py index bd1203a..2d2a497 100644 --- a/hbmqtt/broker.py +++ b/hbmqtt/broker.py @@ -10,6 +10,7 @@ from hbmqtt.codecs.header import MQTTHeaderCodec from hbmqtt.codecs.errors import CodecException from hbmqtt.codecs.connect import ConnectMessage from hbmqtt.message import MessageType +from hbmqtt.errors import MQTTException class BrokerProtocol(asyncio.Protocol): @@ -100,13 +101,20 @@ def init_message_codecs(): def client_connected(reader, writer): (remote_address, remote_port) = writer.get_extra_info('peername') codecs = init_message_codecs() + first_packet = True while True: try: # Read fixed header - fixed_header = yield from MQTTHeaderCodec.decode(reader) + header = yield from MQTTHeaderCodec.decode(reader) + if first_packet and header.message_type != MessageType.CONNECT: + raise MQTTException("[MQTT-3.1.0-1] First Packet sent from the Client MUST be a CONNECT Packet") + if not first_packet and header.message_type == MessageType.CONNECT: + raise MQTTException("[MQTT-3.1.0-2] Client can only send the CONNECT Packet once over a Network Connection") + first_packet = False # Find message decoder and decode - codec = codecs[fixed_header.message_type] - message = yield from codec.decode(fixed_header, reader) - except CodecException: + codec = codecs[header.message_type] + message = yield from codec.decode(header, reader) + + except MQTTException: #End connection break