kopia lustrzana https://github.com/Yakifo/amqtt
				
				
				
			fixes amqtt/Yakifo#210 : when reconnect is false, authentication failure causes NoDataError instead of ConnectError
							rodzic
							
								
									06053ce7ee
								
							
						
					
					
						commit
						e7882a3755
					
				| 
						 | 
				
			
			@ -1,7 +1,7 @@
 | 
			
		|||
import asyncio
 | 
			
		||||
from typing import Any
 | 
			
		||||
 | 
			
		||||
from amqtt.errors import AMQTTError
 | 
			
		||||
from amqtt.errors import AMQTTError, NoDataError
 | 
			
		||||
from amqtt.mqtt.connack import ConnackPacket
 | 
			
		||||
from amqtt.mqtt.connect import ConnectPacket, ConnectPayload, ConnectVariableHeader
 | 
			
		||||
from amqtt.mqtt.disconnect import DisconnectPacket
 | 
			
		||||
| 
						 | 
				
			
			@ -87,8 +87,10 @@ class ClientProtocolHandler(ProtocolHandler):
 | 
			
		|||
        if self.reader is None:
 | 
			
		||||
            msg = "Reader is not initialized."
 | 
			
		||||
            raise AMQTTError(msg)
 | 
			
		||||
 | 
			
		||||
        connack = await ConnackPacket.from_stream(self.reader)
 | 
			
		||||
        try:
 | 
			
		||||
            connack = await ConnackPacket.from_stream(self.reader)
 | 
			
		||||
        except NoDataError as e:
 | 
			
		||||
            raise ConnectionError from e
 | 
			
		||||
        await self.plugins_manager.fire_event(EVENT_MQTT_PACKET_RECEIVED, packet=connack, session=self.session)
 | 
			
		||||
        return connack.return_code
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -0,0 +1,18 @@
 | 
			
		|||
import logging
 | 
			
		||||
 | 
			
		||||
from amqtt.plugins.authentication import BaseAuthPlugin
 | 
			
		||||
from amqtt.session import Session
 | 
			
		||||
 | 
			
		||||
logger = logging.getLogger(__name__)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class NoAuthPlugin(BaseAuthPlugin):
 | 
			
		||||
 | 
			
		||||
    async def authenticate(self, *, session: Session) -> bool | None:
 | 
			
		||||
        return False
 | 
			
		||||
 | 
			
		||||
class AuthPlugin(BaseAuthPlugin):
 | 
			
		||||
 | 
			
		||||
    async def authenticate(self, *, session: Session) -> bool | None:
 | 
			
		||||
        return True
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -1,8 +1,11 @@
 | 
			
		|||
import asyncio
 | 
			
		||||
import logging
 | 
			
		||||
from importlib.metadata import EntryPoint
 | 
			
		||||
from unittest.mock import patch
 | 
			
		||||
 | 
			
		||||
import pytest
 | 
			
		||||
 | 
			
		||||
from amqtt.broker import Broker
 | 
			
		||||
from amqtt.client import MQTTClient
 | 
			
		||||
from amqtt.errors import ConnectError
 | 
			
		||||
from amqtt.mqtt.constants import QOS_0, QOS_1, QOS_2
 | 
			
		||||
| 
						 | 
				
			
			@ -295,3 +298,42 @@ async def test_client_publish_will_with_retain(broker_fixture, client_config):
 | 
			
		|||
    assert message3.topic == 'test/will/topic'
 | 
			
		||||
    assert message3.data == b'client ABC has disconnected'
 | 
			
		||||
    await client3.disconnect()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@pytest.mark.asyncio
 | 
			
		||||
async def test_client_no_auth():
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
    class MockEntryPoints:
 | 
			
		||||
 | 
			
		||||
        def select(self, group) -> list[EntryPoint]:
 | 
			
		||||
            match group:
 | 
			
		||||
                case 'tests.mock_plugins':
 | 
			
		||||
                    return [
 | 
			
		||||
                            EntryPoint(name='auth_plugin', group='tests.mock_plugins', value='tests.plugins.mocks:NoAuthPlugin'),
 | 
			
		||||
                        ]
 | 
			
		||||
                case _:
 | 
			
		||||
                    return list()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
    with patch("amqtt.plugins.manager.entry_points", side_effect=MockEntryPoints) as mocked_mqtt_publish:
 | 
			
		||||
 | 
			
		||||
        config = {
 | 
			
		||||
            "listeners": {
 | 
			
		||||
                "default": {"type": "tcp", "bind": "127.0.0.1:1883", "max_connections": 10},
 | 
			
		||||
            },
 | 
			
		||||
            'sys_interval': 1,
 | 
			
		||||
            'auth': {
 | 
			
		||||
                'plugins': ['auth_plugin', ]
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        client = MQTTClient(client_id="client1", config={'auto_reconnect': False})
 | 
			
		||||
 | 
			
		||||
        broker = Broker(plugin_namespace='tests.mock_plugins', config=config)
 | 
			
		||||
        await broker.start()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
        with pytest.raises(ConnectError):
 | 
			
		||||
            await client.connect("mqtt://127.0.0.1:1883/")
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Ładowanie…
	
		Reference in New Issue