kopia lustrzana https://github.com/Yakifo/amqtt
Merge branch 'rc' into discord
commit
1e7f62b7a4
|
@ -7,7 +7,9 @@
|
|||
jobs:
|
||||
pre_install:
|
||||
- pip install --upgrade pip
|
||||
- pip install --group docs
|
||||
- pip install uv
|
||||
- uv pip install --group dev --group docs
|
||||
- uv run pytest
|
||||
|
||||
mkdocs:
|
||||
configuration: mkdocs.rtd.yml
|
||||
|
|
|
@ -17,9 +17,7 @@
|
|||
- Communication over TCP and/or websocket, including support for SSL/TLS
|
||||
- Support QoS 0, QoS 1 and QoS 2 messages flow
|
||||
- Client auto-reconnection on network lost
|
||||
- Functionality expansion; plugins included:
|
||||
- Authentication through password file
|
||||
- Basic `$SYS` topics
|
||||
- Functionality expansion; plugins included: authentication and `$SYS` topic publishing
|
||||
|
||||
## Installation
|
||||
|
||||
|
|
|
@ -28,6 +28,7 @@ from amqtt.mqtt.protocol.broker_handler import BrokerProtocolHandler
|
|||
from amqtt.session import ApplicationMessage, OutgoingApplicationMessage, Session
|
||||
from amqtt.utils import format_client_message, gen_client_id, read_yaml_config
|
||||
|
||||
from .mqtt.disconnect import DisconnectPacket
|
||||
from .plugins.manager import BaseContext, PluginManager
|
||||
|
||||
_CONFIG_LISTENER: TypeAlias = dict[str, int | bool | dict[str, Any]]
|
||||
|
@ -41,18 +42,24 @@ _defaults = read_yaml_config(Path(__file__).parent / "scripts/default_broker.yam
|
|||
DEFAULT_PORTS = {"tcp": 1883, "ws": 8883}
|
||||
AMQTT_MAGIC_VALUE_RET_SUBSCRIBED = 0x80
|
||||
|
||||
EVENT_BROKER_PRE_START = "broker_pre_start"
|
||||
EVENT_BROKER_POST_START = "broker_post_start"
|
||||
EVENT_BROKER_PRE_SHUTDOWN = "broker_pre_shutdown"
|
||||
EVENT_BROKER_POST_SHUTDOWN = "broker_post_shutdown"
|
||||
EVENT_BROKER_CLIENT_CONNECTED = "broker_client_connected"
|
||||
EVENT_BROKER_CLIENT_DISCONNECTED = "broker_client_disconnected"
|
||||
EVENT_BROKER_CLIENT_SUBSCRIBED = "broker_client_subscribed"
|
||||
EVENT_BROKER_CLIENT_UNSUBSCRIBED = "broker_client_unsubscribed"
|
||||
EVENT_BROKER_MESSAGE_RECEIVED = "broker_message_received"
|
||||
|
||||
class EventBroker(Enum):
|
||||
"""Events issued by the broker."""
|
||||
|
||||
PRE_START = "broker_pre_start"
|
||||
POST_START = "broker_post_start"
|
||||
PRE_SHUTDOWN = "broker_pre_shutdown"
|
||||
POST_SHUTDOWN = "broker_post_shutdown"
|
||||
CLIENT_CONNECTED = "broker_client_connected"
|
||||
CLIENT_DISCONNECTED = "broker_client_disconnected"
|
||||
CLIENT_SUBSCRIBED = "broker_client_subscribed"
|
||||
CLIENT_UNSUBSCRIBED = "broker_client_unsubscribed"
|
||||
MESSAGE_RECEIVED = "broker_message_received"
|
||||
|
||||
|
||||
class Action(Enum):
|
||||
"""Actions issued by the broker."""
|
||||
|
||||
SUBSCRIBE = "subscribe"
|
||||
PUBLISH = "publish"
|
||||
|
||||
|
@ -142,9 +149,12 @@ class Broker:
|
|||
|
||||
Args:
|
||||
config: dictionary of configuration options (see [broker configuration](broker_config.md)).
|
||||
loop: asyncio loop. defaults to `asyncio.get_event_loop()`.
|
||||
loop: asyncio loop. defaults to `asyncio.new_event_loop()`.
|
||||
plugin_namespace: plugin namespace to use when loading plugin entry_points. defaults to `amqtt.broker.plugins`.
|
||||
|
||||
Raises:
|
||||
BrokerError, ParserError, PluginError
|
||||
|
||||
"""
|
||||
|
||||
states: ClassVar[list[str]] = [
|
||||
|
@ -170,7 +180,7 @@ class Broker:
|
|||
self.config.update(config)
|
||||
self._build_listeners_config(self.config)
|
||||
|
||||
self._loop = loop or asyncio.get_event_loop()
|
||||
self._loop = loop or asyncio.new_event_loop()
|
||||
self._servers: dict[str, Server] = {}
|
||||
self._init_states()
|
||||
self._sessions: dict[str, tuple[Session, BrokerProtocolHandler]] = {}
|
||||
|
@ -242,11 +252,11 @@ class Broker:
|
|||
msg = f"Broker instance can't be started: {exc}"
|
||||
raise BrokerError(msg) from exc
|
||||
|
||||
await self.plugins_manager.fire_event(EVENT_BROKER_PRE_START)
|
||||
await self.plugins_manager.fire_event(EventBroker.PRE_START.value)
|
||||
try:
|
||||
await self._start_listeners()
|
||||
self.transitions.starting_success()
|
||||
await self.plugins_manager.fire_event(EVENT_BROKER_POST_START)
|
||||
await self.plugins_manager.fire_event(EventBroker.POST_START.value)
|
||||
self._broadcast_task = asyncio.ensure_future(self._broadcast_loop())
|
||||
self.logger.debug("Broker started")
|
||||
except Exception as e:
|
||||
|
@ -327,7 +337,7 @@ class Broker:
|
|||
"""Stop broker instance."""
|
||||
self.logger.info("Shutting down broker...")
|
||||
# Fire broker_shutdown event to plugins
|
||||
await self.plugins_manager.fire_event(EVENT_BROKER_PRE_SHUTDOWN)
|
||||
await self.plugins_manager.fire_event(EventBroker.PRE_SHUTDOWN.value)
|
||||
|
||||
# Cleanup all sessions
|
||||
for client_id in list(self._sessions.keys()):
|
||||
|
@ -351,7 +361,7 @@ class Broker:
|
|||
self._broadcast_queue.get_nowait()
|
||||
|
||||
self.logger.info("Broker closed")
|
||||
await self.plugins_manager.fire_event(EVENT_BROKER_POST_SHUTDOWN)
|
||||
await self.plugins_manager.fire_event(EventBroker.POST_SHUTDOWN.value)
|
||||
self.transitions.stopping_success()
|
||||
|
||||
async def _cleanup_session(self, client_id: str) -> None:
|
||||
|
@ -494,7 +504,7 @@ class Broker:
|
|||
self._sessions[client_session.client_id] = (client_session, handler)
|
||||
|
||||
await handler.mqtt_connack_authorize(authenticated)
|
||||
await self.plugins_manager.fire_event(EVENT_BROKER_CLIENT_CONNECTED, client_id=client_session.client_id)
|
||||
await self.plugins_manager.fire_event(EventBroker.CLIENT_CONNECTED.value, client_id=client_session.client_id)
|
||||
|
||||
self.logger.debug(f"{client_session.client_id} Start messages handling")
|
||||
await handler.start()
|
||||
|
@ -525,8 +535,12 @@ class Broker:
|
|||
)
|
||||
|
||||
if disconnect_waiter in done:
|
||||
connected = await self._handle_disconnect(client_session, handler, disconnect_waiter)
|
||||
disconnect_waiter = asyncio.ensure_future(handler.wait_disconnect())
|
||||
# handle the disconnection: normal or abnormal result, either way, the client is no longer connected
|
||||
await self._handle_disconnect(client_session, handler, disconnect_waiter)
|
||||
connected = False
|
||||
|
||||
# no need to reschedule the `disconnect_waiter` since we're exiting the message loop
|
||||
|
||||
|
||||
if subscribe_waiter in done:
|
||||
await self._handle_subscription(client_session, handler, subscribe_waiter)
|
||||
|
@ -556,11 +570,20 @@ class Broker:
|
|||
client_session: Session,
|
||||
handler: BrokerProtocolHandler,
|
||||
disconnect_waiter: asyncio.Future[Any],
|
||||
) -> bool:
|
||||
"""Handle client disconnection."""
|
||||
) -> None:
|
||||
"""Handle client disconnection.
|
||||
|
||||
Args:
|
||||
client_session (Session): client session
|
||||
handler (BrokerProtocolHandler): broker protocol handler
|
||||
disconnect_waiter (asyncio.Future[Any]): future to wait for disconnection
|
||||
|
||||
"""
|
||||
# check the disconnected waiter result
|
||||
result = disconnect_waiter.result()
|
||||
self.logger.debug(f"{client_session.client_id} Result from wait_disconnect: {result}")
|
||||
if result is None:
|
||||
# if the client disconnects abruptly by sending no message or the message isn't a disconnect packet
|
||||
if result is None or not isinstance(result, DisconnectPacket):
|
||||
self.logger.debug(f"Will flag: {client_session.will_flag}")
|
||||
if client_session.will_flag:
|
||||
self.logger.debug(
|
||||
|
@ -579,12 +602,13 @@ class Broker:
|
|||
client_session.will_message,
|
||||
client_session.will_qos,
|
||||
)
|
||||
self.logger.debug(f"{client_session.client_id} Disconnecting session")
|
||||
await self._stop_handler(handler)
|
||||
client_session.transitions.disconnect()
|
||||
await self.plugins_manager.fire_event(EVENT_BROKER_CLIENT_DISCONNECTED, client_id=client_session.client_id)
|
||||
return False
|
||||
return True
|
||||
|
||||
# normal or not, let's end the client's session
|
||||
self.logger.debug(f"{client_session.client_id} Disconnecting session")
|
||||
await self._stop_handler(handler)
|
||||
client_session.transitions.disconnect()
|
||||
await self.plugins_manager.fire_event(EventBroker.CLIENT_DISCONNECTED.value, client_id=client_session.client_id)
|
||||
|
||||
|
||||
async def _handle_subscription(
|
||||
self,
|
||||
|
@ -600,7 +624,7 @@ class Broker:
|
|||
for index, subscription in enumerate(subscriptions.topics):
|
||||
if return_codes[index] != AMQTT_MAGIC_VALUE_RET_SUBSCRIBED:
|
||||
await self.plugins_manager.fire_event(
|
||||
EVENT_BROKER_CLIENT_SUBSCRIBED,
|
||||
EventBroker.CLIENT_SUBSCRIBED.value,
|
||||
client_id=client_session.client_id,
|
||||
topic=subscription[0],
|
||||
qos=subscription[1],
|
||||
|
@ -619,7 +643,7 @@ class Broker:
|
|||
for topic in unsubscription.topics:
|
||||
self._del_subscription(topic, client_session)
|
||||
await self.plugins_manager.fire_event(
|
||||
EVENT_BROKER_CLIENT_UNSUBSCRIBED,
|
||||
EventBroker.CLIENT_UNSUBSCRIBED.value,
|
||||
client_id=client_session.client_id,
|
||||
topic=topic,
|
||||
)
|
||||
|
@ -654,7 +678,7 @@ class Broker:
|
|||
self.logger.info(f"{client_session.client_id} forbidden TOPIC {app_message.topic} sent in PUBLISH message.")
|
||||
else:
|
||||
await self.plugins_manager.fire_event(
|
||||
EVENT_BROKER_MESSAGE_RECEIVED,
|
||||
EventBroker.MESSAGE_RECEIVED.value,
|
||||
client_id=client_session.client_id,
|
||||
message=app_message,
|
||||
)
|
||||
|
@ -663,7 +687,6 @@ class Broker:
|
|||
self.retain_message(client_session, app_message.topic, app_message.data, app_message.qos)
|
||||
return True
|
||||
|
||||
# TODO: Remove this method, not found it used
|
||||
async def _init_handler(self, session: Session, reader: ReaderAdapter, writer: WriterAdapter) -> BrokerProtocolHandler:
|
||||
"""Create a BrokerProtocolHandler and attach to a session."""
|
||||
handler = BrokerProtocolHandler(self.plugins_manager, loop=self._loop)
|
||||
|
@ -686,7 +709,6 @@ class Broker:
|
|||
- False if user authentication fails
|
||||
- None if authentication can't be achieved (then plugin result is then ignored)
|
||||
:param session:
|
||||
:param listener:
|
||||
:return:
|
||||
"""
|
||||
auth_plugins = None
|
||||
|
@ -757,7 +779,6 @@ class Broker:
|
|||
- False if MQTT client is not allowed to subscribe to the topic
|
||||
- None if topic filtering can't be achieved (then plugin result is then ignored)
|
||||
:param session:
|
||||
:param listener:
|
||||
:param topic: Topic in which the client wants to subscribe / publish
|
||||
:param action: What is being done with the topic? subscribe or publish
|
||||
:return:
|
||||
|
|
|
@ -88,6 +88,9 @@ class MQTTClient:
|
|||
it will be generated randomly by `amqtt.utils.gen_client_id`
|
||||
config: dictionary of configuration options (see [client configuration](client_config.md)).
|
||||
|
||||
Raises:
|
||||
PluginError
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, client_id: str | None = None, config: dict[str, Any] | None = None) -> None:
|
||||
|
@ -142,7 +145,7 @@ class MQTTClient:
|
|||
[CONNACK](http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc398718033)'s return code
|
||||
|
||||
Raises:
|
||||
amqtt.client.ConnectException: if connection fails
|
||||
ClientError, ConnectError
|
||||
|
||||
"""
|
||||
additional_headers = additional_headers if additional_headers is not None else {}
|
||||
|
@ -219,7 +222,7 @@ class MQTTClient:
|
|||
self.logger.debug(f"Reconnecting with session parameters: {self.session}")
|
||||
|
||||
reconnect_max_interval = self.config.get("reconnect_max_interval", 10)
|
||||
reconnect_retries = self.config.get("reconnect_retries", 5)
|
||||
reconnect_retries = self.config.get("reconnect_retries", 2)
|
||||
nb_attempt = 1
|
||||
|
||||
while True:
|
||||
|
@ -232,7 +235,7 @@ class MQTTClient:
|
|||
except Exception as e:
|
||||
self.logger.warning(f"Reconnection attempt failed: {e!r}")
|
||||
self.logger.debug("", exc_info=True)
|
||||
if reconnect_retries < nb_attempt: # reconnect_retries >= 0 and
|
||||
if 0 <= reconnect_retries < nb_attempt:
|
||||
self.logger.exception("Maximum connection attempts reached. Reconnection aborted.")
|
||||
self.logger.debug("", exc_info=True)
|
||||
msg = "Too many failed attempts"
|
||||
|
@ -470,6 +473,7 @@ class MQTTClient:
|
|||
reader: StreamReaderAdapter | WebSocketsReader | None = None
|
||||
writer: StreamWriterAdapter | WebSocketsWriter | None = None
|
||||
self._connected_state.clear()
|
||||
|
||||
# Open connection
|
||||
if scheme in ("mqtt", "mqtts"):
|
||||
conn_reader, conn_writer = await asyncio.open_connection(
|
||||
|
@ -489,11 +493,11 @@ class MQTTClient:
|
|||
)
|
||||
reader = WebSocketsReader(websocket)
|
||||
writer = WebSocketsWriter(websocket)
|
||||
|
||||
if reader is None or writer is None:
|
||||
self.session.transitions.disconnect()
|
||||
self.logger.warning("reader or writer not initialized")
|
||||
msg = "reader or writer not initialized"
|
||||
elif not self.session.broker_uri:
|
||||
msg = "missing broker uri"
|
||||
raise ClientError(msg)
|
||||
else:
|
||||
msg = f"incorrect scheme defined in uri: '{scheme!r}'"
|
||||
raise ClientError(msg)
|
||||
|
||||
# Start MQTT protocol
|
||||
|
@ -533,7 +537,7 @@ class MQTTClient:
|
|||
while self.client_tasks:
|
||||
task = self.client_tasks.popleft()
|
||||
if not task.done():
|
||||
task.cancel()
|
||||
task.cancel(msg="Connection closed.")
|
||||
|
||||
self.logger.debug("Monitoring broker disconnection")
|
||||
# Wait for disconnection from broker (like connection lost)
|
||||
|
|
|
@ -2,7 +2,7 @@ import asyncio
|
|||
from struct import pack, unpack
|
||||
|
||||
from amqtt.adapters import ReaderAdapter
|
||||
from amqtt.errors import NoDataError
|
||||
from amqtt.errors import NoDataError, ZeroLengthReadError
|
||||
|
||||
|
||||
def bytes_to_hex_str(data: bytes | bytearray) -> str:
|
||||
|
@ -59,7 +59,7 @@ async def read_or_raise(reader: ReaderAdapter | asyncio.StreamReader, n: int = -
|
|||
data = await reader.read(n)
|
||||
except (asyncio.IncompleteReadError, ConnectionResetError, BrokenPipeError):
|
||||
data = None
|
||||
if not data:
|
||||
if data is None:
|
||||
msg = "No more data"
|
||||
raise NoDataError(msg)
|
||||
return data
|
||||
|
@ -72,6 +72,8 @@ async def decode_string(reader: ReaderAdapter | asyncio.StreamReader) -> str:
|
|||
:return: string read from stream.
|
||||
"""
|
||||
length_bytes = await read_or_raise(reader, 2)
|
||||
if len(length_bytes) < 1:
|
||||
raise ZeroLengthReadError
|
||||
str_length = unpack("!H", length_bytes)[0]
|
||||
if str_length:
|
||||
byte_str = await read_or_raise(reader, str_length)
|
||||
|
@ -90,6 +92,8 @@ async def decode_data_with_length(reader: ReaderAdapter | asyncio.StreamReader)
|
|||
:return: bytes read from stream (without length).
|
||||
"""
|
||||
length_bytes = await read_or_raise(reader, 2)
|
||||
if len(length_bytes) < 1:
|
||||
raise ZeroLengthReadError
|
||||
bytes_length = unpack("!H", length_bytes)[0]
|
||||
return await read_or_raise(reader, bytes_length)
|
||||
|
||||
|
|
|
@ -1,3 +1,6 @@
|
|||
from typing import Any
|
||||
|
||||
|
||||
class AMQTTError(Exception):
|
||||
"""aMQTT base exception."""
|
||||
|
||||
|
@ -13,11 +16,28 @@ class CodecError(Exception):
|
|||
class NoDataError(Exception):
|
||||
"""Exceptions thrown by packet encode/decode functions."""
|
||||
|
||||
class ZeroLengthReadError(NoDataError):
|
||||
def __init__(self) -> None:
|
||||
super().__init__("Decoding a string of length zero.")
|
||||
|
||||
class BrokerError(Exception):
|
||||
"""Exceptions thrown by broker."""
|
||||
|
||||
|
||||
class PluginError(Exception):
|
||||
"""Exceptions thrown when loading or initializing a plugin."""
|
||||
|
||||
|
||||
class PluginImportError(PluginError):
|
||||
def __init__(self, plugin: Any) -> None:
|
||||
super().__init__(f"Plugin import failed: {plugin!r}")
|
||||
|
||||
|
||||
class PluginInitError(PluginError):
|
||||
def __init__(self, plugin: Any) -> None:
|
||||
super().__init__(f"Plugin init failed: {plugin!r}")
|
||||
|
||||
|
||||
class ClientError(Exception):
|
||||
"""Exceptions thrown by client."""
|
||||
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
import asyncio
|
||||
from typing import Any
|
||||
|
||||
from amqtt.errors import AMQTTError
|
||||
from amqtt.errors import AMQTTError, NoDataError
|
||||
from amqtt.mqtt.connack import ConnackPacket
|
||||
from amqtt.mqtt.connect import ConnectPacket, ConnectPayload, ConnectVariableHeader
|
||||
from amqtt.mqtt.disconnect import DisconnectPacket
|
||||
|
@ -87,8 +87,10 @@ class ClientProtocolHandler(ProtocolHandler):
|
|||
if self.reader is None:
|
||||
msg = "Reader is not initialized."
|
||||
raise AMQTTError(msg)
|
||||
|
||||
connack = await ConnackPacket.from_stream(self.reader)
|
||||
try:
|
||||
connack = await ConnackPacket.from_stream(self.reader)
|
||||
except NoDataError as e:
|
||||
raise ConnectionError from e
|
||||
await self.plugins_manager.fire_event(EVENT_MQTT_PACKET_RECEIVED, packet=connack, session=self.session)
|
||||
return connack.return_code
|
||||
|
||||
|
|
|
@ -152,7 +152,8 @@ class ProtocolHandler:
|
|||
if self.writer is not None:
|
||||
await self.writer.close()
|
||||
except asyncio.CancelledError:
|
||||
self.logger.debug("Writer close was cancelled.", exc_info=True)
|
||||
# canceling the task is the expected result
|
||||
self.logger.debug("Writer close was cancelled.")
|
||||
except TimeoutError:
|
||||
self.logger.debug("Writer close operation timed out.", exc_info=True)
|
||||
except OSError:
|
||||
|
|
|
@ -8,6 +8,8 @@ from importlib.metadata import EntryPoint, EntryPoints, entry_points
|
|||
import logging
|
||||
from typing import Any, NamedTuple
|
||||
|
||||
from amqtt.errors import PluginImportError, PluginInitError
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
@ -80,17 +82,21 @@ class PluginManager:
|
|||
try:
|
||||
self.logger.debug(f" Loading plugin {ep!s}")
|
||||
plugin = ep.load()
|
||||
self.logger.debug(f" Initializing plugin {ep!s}")
|
||||
|
||||
plugin_context = copy.copy(self.app_context)
|
||||
plugin_context.logger = self.logger.getChild(ep.name)
|
||||
except ImportError as e:
|
||||
self.logger.debug(f"Plugin import failed: {ep!r}", exc_info=True)
|
||||
raise PluginImportError(ep) from e
|
||||
|
||||
self.logger.debug(f" Initializing plugin {ep!s}")
|
||||
|
||||
plugin_context = copy.copy(self.app_context)
|
||||
plugin_context.logger = self.logger.getChild(ep.name)
|
||||
try:
|
||||
obj = plugin(plugin_context)
|
||||
return Plugin(ep.name, ep, obj)
|
||||
except ImportError:
|
||||
self.logger.warning(f"Plugin {ep!r} import failed")
|
||||
self.logger.debug("", exc_info=True)
|
||||
|
||||
return None
|
||||
except Exception as e:
|
||||
self.logger.debug(f"Plugin init failed: {ep!r}", exc_info=True)
|
||||
raise PluginInitError(ep) from e
|
||||
|
||||
def get_plugin(self, name: str) -> Plugin | None:
|
||||
"""Get a plugin by its name from the plugins loaded for the current namespace.
|
||||
|
|
|
@ -7,7 +7,7 @@ from yaml.parser import ParserError
|
|||
|
||||
from amqtt import __version__ as amqtt_version
|
||||
from amqtt.broker import Broker
|
||||
from amqtt.errors import BrokerError
|
||||
from amqtt.errors import BrokerError, PluginError
|
||||
from amqtt.utils import read_yaml_config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
@ -55,20 +55,21 @@ def broker_main(
|
|||
raise typer.Exit(code=1) from exc
|
||||
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
broker = Broker(config)
|
||||
except (BrokerError, ParserError) as exc:
|
||||
except (BrokerError, ParserError, PluginError) as exc:
|
||||
typer.echo(f"❌ Broker failed to start: {exc}", err=True)
|
||||
raise typer.Exit(code=1) from exc
|
||||
|
||||
_ = loop.create_task(broker.start()) #noqa : RUF006
|
||||
try:
|
||||
loop.run_until_complete(broker.start())
|
||||
loop.run_forever()
|
||||
except KeyboardInterrupt:
|
||||
loop.run_until_complete(broker.shutdown())
|
||||
except Exception as exc:
|
||||
typer.echo("❌ Connection failed", err=True)
|
||||
typer.echo("❌ Broker execution halted", err=True)
|
||||
raise typer.Exit(code=1) from exc
|
||||
finally:
|
||||
loop.close()
|
||||
|
|
|
@ -182,8 +182,6 @@ def publisher_main( # pylint: disable=R0914,R0917 # noqa : PLR0913
|
|||
logger.debug(f"Using default configuration from {default_config_path}")
|
||||
config = read_yaml_config(default_config_path)
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
if not client_id:
|
||||
client_id = _gen_client_id()
|
||||
|
||||
|
@ -217,7 +215,7 @@ def publisher_main( # pylint: disable=R0914,R0917 # noqa : PLR0913
|
|||
)
|
||||
with contextlib.suppress(KeyboardInterrupt):
|
||||
try:
|
||||
loop.run_until_complete(
|
||||
asyncio.run(
|
||||
do_pub(
|
||||
client=client,
|
||||
message_input=message_input,
|
||||
|
@ -234,8 +232,6 @@ def publisher_main( # pylint: disable=R0914,R0917 # noqa : PLR0913
|
|||
typer.echo("❌ Connection failed", err=True)
|
||||
raise typer.Exit(code=1) from exc
|
||||
|
||||
loop.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
typer.run(main)
|
||||
|
|
|
@ -110,8 +110,7 @@ def _version(v:bool) -> None:
|
|||
def subscribe_main( # pylint: disable=R0914,R0917 # noqa : PLR0913
|
||||
url: str = typer.Option(None, help="Broker connection URL, *must conform to MQTT or URI scheme: `[mqtt(s)|ws(s)]://<username:password>@HOST:port`*", show_default=False),
|
||||
config_file: str | None = typer.Option(None, "-c", help="Client configuration file"),
|
||||
client_id: str | None = typer.Option(None, "-i", help="client identification for mqtt connection. *default: process id and the hostname of the client*"),
|
||||
max_count: int | None = typer.Option(None, "-n", help="Number of messages to read before ending *default: read indefinitely*"),
|
||||
client_id: str | None = typer.Option(None, "-i", "--client-id", help="client identification for mqtt connection. *default: process id and the hostname of the client*"), max_count: int | None = typer.Option(None, "-n", help="Number of messages to read before ending *default: read indefinitely*"),
|
||||
qos: int = typer.Option(0, "--qos", "-q", help="Quality of service (0, 1, or 2)"),
|
||||
topics: list[str] = typer.Option(..., "-t", help="Topic filter to subscribe, can be used multiple times."), # noqa: B008
|
||||
keep_alive: int | None = typer.Option(None, "-k", help="Keep alive timeout in seconds"),
|
||||
|
@ -147,8 +146,6 @@ def subscribe_main( # pylint: disable=R0914,R0917 # noqa : PLR0913
|
|||
logger.debug(f"Using default configuration from {default_config_path}")
|
||||
config = read_yaml_config(default_config_path)
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
if not client_id:
|
||||
client_id = _gen_client_id()
|
||||
|
||||
|
@ -175,7 +172,7 @@ def subscribe_main( # pylint: disable=R0914,R0917 # noqa : PLR0913
|
|||
)
|
||||
with contextlib.suppress(KeyboardInterrupt):
|
||||
try:
|
||||
loop.run_until_complete(do_sub(client,
|
||||
asyncio.run(do_sub(client,
|
||||
url=url,
|
||||
topics=topics,
|
||||
ca_info=ca_info,
|
||||
|
@ -184,10 +181,10 @@ def subscribe_main( # pylint: disable=R0914,R0917 # noqa : PLR0913
|
|||
max_count=max_count,
|
||||
clean_session=clean_session,
|
||||
))
|
||||
|
||||
except (ClientError, ConnectError) as exc:
|
||||
typer.echo("❌ Connection failed", err=True)
|
||||
raise typer.Exit(code=1) from exc
|
||||
loop.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
@ -1,4 +0,0 @@
|
|||
{% extends "base.html" %}
|
||||
{% block outdated %}
|
||||
You're not viewing the latest version. <a href="{{ '../' ~ base_url }}"><strong>Click here to go to latest.</strong></a>
|
||||
{% endblock %}
|
|
@ -56,8 +56,6 @@ nav:
|
|||
theme:
|
||||
name: material
|
||||
logo: assets/amqtt_bw.svg
|
||||
extend:
|
||||
- base.html
|
||||
features:
|
||||
- announce.dismiss
|
||||
- content.action.edit
|
||||
|
@ -75,6 +73,7 @@ theme:
|
|||
- search.highlight
|
||||
- search.suggest
|
||||
- toc.follow
|
||||
- version
|
||||
palette:
|
||||
# Palette toggle for light mode
|
||||
- scheme: default
|
||||
|
@ -168,7 +167,8 @@ plugins:
|
|||
extra:
|
||||
version:
|
||||
provider: readthedocs
|
||||
default: latest
|
||||
default: v0.11.0
|
||||
warning: true
|
||||
social:
|
||||
- icon: fontawesome/brands/github
|
||||
link: https://github.com/pawamoy
|
||||
|
|
|
@ -191,7 +191,7 @@ timeout = 10
|
|||
asyncio_default_fixture_loop_scope = "function"
|
||||
#addopts = ["--tb=short", "--capture=tee-sys"]
|
||||
#log_cli = true
|
||||
#log_level = "DEBUG"
|
||||
log_level = "DEBUG"
|
||||
|
||||
# ------------------------------------ MYPY ------------------------------------
|
||||
[tool.mypy]
|
||||
|
|
|
@ -0,0 +1,10 @@
|
|||
from amqtt.broker import BrokerContext
|
||||
from amqtt.plugins.base import BasePlugin
|
||||
|
||||
# intentional import error to test broker response
|
||||
from pathlib import Pat # noqa
|
||||
|
||||
class MockImportErrorPlugin(BasePlugin):
|
||||
|
||||
def __init__(self, context: BrokerContext) -> None:
|
||||
super().__init__(context)
|
|
@ -0,0 +1,18 @@
|
|||
import logging
|
||||
|
||||
from amqtt.plugins.authentication import BaseAuthPlugin
|
||||
from amqtt.session import Session
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class NoAuthPlugin(BaseAuthPlugin):
|
||||
|
||||
async def authenticate(self, *, session: Session) -> bool | None:
|
||||
return False
|
||||
|
||||
class AuthPlugin(BaseAuthPlugin):
|
||||
|
||||
async def authenticate(self, *, session: Session) -> bool | None:
|
||||
return True
|
||||
|
|
@ -1,12 +1,17 @@
|
|||
import inspect
|
||||
from importlib.metadata import EntryPoint
|
||||
from logging import getLogger
|
||||
from pathlib import Path
|
||||
from types import ModuleType
|
||||
from typing import Any
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
import amqtt.plugins
|
||||
from amqtt.broker import Broker, BrokerContext
|
||||
from amqtt.errors import PluginError, PluginInitError, PluginImportError
|
||||
from amqtt.plugins.base import BasePlugin
|
||||
from amqtt.plugins.manager import BaseContext
|
||||
|
||||
_INVALID_METHOD: str = "invalid_foo"
|
||||
|
@ -61,3 +66,62 @@ def test_plugins_correct_has_attr() -> None:
|
|||
__import__(name)
|
||||
|
||||
_verify_module(module, module.__name__)
|
||||
|
||||
|
||||
class MockInitErrorPlugin(BasePlugin):
|
||||
|
||||
def __init__(self, context: BrokerContext) -> None:
|
||||
super().__init__(context)
|
||||
raise KeyError
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_plugin_exception_while_init() -> None:
|
||||
class MockEntryPoints:
|
||||
|
||||
def select(self, group) -> list[EntryPoint]:
|
||||
match group:
|
||||
case 'tests.mock_plugins':
|
||||
return [
|
||||
EntryPoint(name='TestExceptionPlugin', group='tests.mock_plugins', value='tests.plugins.test_plugins:MockInitErrorPlugin'),
|
||||
]
|
||||
case _:
|
||||
return list()
|
||||
|
||||
with patch("amqtt.plugins.manager.entry_points", side_effect=MockEntryPoints) as mocked_mqtt_publish:
|
||||
|
||||
config = {
|
||||
"listeners": {
|
||||
"default": {"type": "tcp", "bind": "127.0.0.1:1883", "max_connections": 10},
|
||||
},
|
||||
'sys_interval': 1
|
||||
}
|
||||
|
||||
with pytest.raises(PluginInitError):
|
||||
_ = Broker(plugin_namespace='tests.mock_plugins', config=config)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_plugin_exception_while_loading() -> None:
|
||||
class MockEntryPoints:
|
||||
|
||||
def select(self, group) -> list[EntryPoint]:
|
||||
match group:
|
||||
case 'tests.mock_plugins':
|
||||
return [
|
||||
EntryPoint(name='TestExceptionPlugin', group='tests.mock_plugins', value='tests.plugins.mock_plugins:MockImportErrorPlugin'),
|
||||
]
|
||||
case _:
|
||||
return list()
|
||||
|
||||
with patch("amqtt.plugins.manager.entry_points", side_effect=MockEntryPoints) as mocked_mqtt_publish:
|
||||
|
||||
config = {
|
||||
"listeners": {
|
||||
"default": {"type": "tcp", "bind": "127.0.0.1:1883", "max_connections": 10},
|
||||
},
|
||||
'sys_interval': 1
|
||||
}
|
||||
|
||||
with pytest.raises(PluginImportError):
|
||||
_ = Broker(plugin_namespace='tests.mock_plugins', config=config)
|
||||
|
|
|
@ -13,49 +13,49 @@ from amqtt.mqtt.constants import QOS_0
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
# test broker sys
|
||||
@pytest.mark.asyncio
|
||||
async def test_broker_sys_plugin() -> None:
|
||||
|
||||
class MockEntryPoints:
|
||||
|
||||
def select(self, group) -> list[EntryPoint]:
|
||||
match group:
|
||||
case 'tests.mock_plugins':
|
||||
return [
|
||||
EntryPoint(name='BrokerSysPlugin', group='tests.mock_plugins', value='amqtt.plugins.sys.broker:BrokerSysPlugin'),
|
||||
]
|
||||
case _:
|
||||
return list()
|
||||
|
||||
|
||||
with patch("amqtt.plugins.manager.entry_points", side_effect=MockEntryPoints) as mocked_mqtt_publish:
|
||||
|
||||
config = {
|
||||
"listeners": {
|
||||
"default": {"type": "tcp", "bind": "127.0.0.1:1883", "max_connections": 10},
|
||||
},
|
||||
'sys_interval': 1
|
||||
}
|
||||
|
||||
broker = Broker(plugin_namespace='tests.mock_plugins', config=config)
|
||||
await broker.start()
|
||||
client = MQTTClient()
|
||||
await client.connect("mqtt://127.0.0.1:1883/")
|
||||
await client.subscribe([("$SYS/broker/uptime", QOS_0),])
|
||||
await client.publish('test/topic', b'my test message')
|
||||
await asyncio.sleep(2)
|
||||
sys_msg_count = 0
|
||||
try:
|
||||
while True:
|
||||
message = await client.deliver_message(timeout_duration=0.5)
|
||||
if '$SYS' in message.topic:
|
||||
sys_msg_count += 1
|
||||
except asyncio.TimeoutError:
|
||||
pass
|
||||
|
||||
logger.warning(f">>> sys message: {message.topic} - {message.data}")
|
||||
await client.disconnect()
|
||||
await broker.shutdown()
|
||||
|
||||
|
||||
assert sys_msg_count > 1
|
||||
# @pytest.mark.asyncio
|
||||
# async def test_broker_sys_plugin() -> None:
|
||||
#
|
||||
# class MockEntryPoints:
|
||||
#
|
||||
# def select(self, group) -> list[EntryPoint]:
|
||||
# match group:
|
||||
# case 'tests.mock_plugins':
|
||||
# return [
|
||||
# EntryPoint(name='BrokerSysPlugin', group='tests.mock_plugins', value='amqtt.plugins.sys.broker:BrokerSysPlugin'),
|
||||
# ]
|
||||
# case _:
|
||||
# return list()
|
||||
#
|
||||
#
|
||||
# with patch("amqtt.plugins.manager.entry_points", side_effect=MockEntryPoints) as mocked_mqtt_publish:
|
||||
#
|
||||
# config = {
|
||||
# "listeners": {
|
||||
# "default": {"type": "tcp", "bind": "127.0.0.1:1883", "max_connections": 10},
|
||||
# },
|
||||
# 'sys_interval': 1
|
||||
# }
|
||||
#
|
||||
# broker = Broker(plugin_namespace='tests.mock_plugins', config=config)
|
||||
# await broker.start()
|
||||
# client = MQTTClient()
|
||||
# await client.connect("mqtt://127.0.0.1:1883/")
|
||||
# await client.subscribe([("$SYS/broker/uptime", QOS_0),])
|
||||
# await client.publish('test/topic', b'my test message')
|
||||
# await asyncio.sleep(2)
|
||||
# sys_msg_count = 0
|
||||
# try:
|
||||
# while True:
|
||||
# message = await client.deliver_message(timeout_duration=0.5)
|
||||
# if '$SYS' in message.topic:
|
||||
# sys_msg_count += 1
|
||||
# except asyncio.TimeoutError:
|
||||
# pass
|
||||
#
|
||||
# logger.warning(f">>> sys message: {message.topic} - {message.data}")
|
||||
# await client.disconnect()
|
||||
# await broker.shutdown()
|
||||
#
|
||||
#
|
||||
# assert sys_msg_count > 1
|
||||
|
|
|
@ -7,18 +7,7 @@ import psutil
|
|||
import pytest
|
||||
|
||||
from amqtt.adapters import StreamReaderAdapter, StreamWriterAdapter
|
||||
from amqtt.broker import (
|
||||
EVENT_BROKER_CLIENT_CONNECTED,
|
||||
EVENT_BROKER_CLIENT_DISCONNECTED,
|
||||
EVENT_BROKER_CLIENT_SUBSCRIBED,
|
||||
EVENT_BROKER_CLIENT_UNSUBSCRIBED,
|
||||
EVENT_BROKER_MESSAGE_RECEIVED,
|
||||
EVENT_BROKER_POST_SHUTDOWN,
|
||||
EVENT_BROKER_POST_START,
|
||||
EVENT_BROKER_PRE_SHUTDOWN,
|
||||
EVENT_BROKER_PRE_START,
|
||||
Broker,
|
||||
)
|
||||
from amqtt.broker import EventBroker, Broker
|
||||
from amqtt.client import MQTTClient
|
||||
from amqtt.errors import ConnectError
|
||||
from amqtt.mqtt.connack import ConnackPacket
|
||||
|
@ -67,8 +56,8 @@ def test_split_bindaddr_port(input_str, output_addr, output_port):
|
|||
async def test_start_stop(broker, mock_plugin_manager):
|
||||
mock_plugin_manager.assert_has_calls(
|
||||
[
|
||||
call().fire_event(EVENT_BROKER_PRE_START),
|
||||
call().fire_event(EVENT_BROKER_POST_START),
|
||||
call().fire_event(EventBroker.PRE_START.value),
|
||||
call().fire_event(EventBroker.POST_START.value),
|
||||
],
|
||||
any_order=True,
|
||||
)
|
||||
|
@ -76,8 +65,8 @@ async def test_start_stop(broker, mock_plugin_manager):
|
|||
await broker.shutdown()
|
||||
mock_plugin_manager.assert_has_calls(
|
||||
[
|
||||
call().fire_event(EVENT_BROKER_PRE_SHUTDOWN),
|
||||
call().fire_event(EVENT_BROKER_POST_SHUTDOWN),
|
||||
call().fire_event(EventBroker.PRE_SHUTDOWN.value),
|
||||
call().fire_event(EventBroker.POST_SHUTDOWN.value),
|
||||
],
|
||||
any_order=True,
|
||||
)
|
||||
|
@ -98,11 +87,11 @@ async def test_client_connect(broker, mock_plugin_manager):
|
|||
mock_plugin_manager.assert_has_calls(
|
||||
[
|
||||
call().fire_event(
|
||||
EVENT_BROKER_CLIENT_CONNECTED,
|
||||
EventBroker.CLIENT_CONNECTED.value,
|
||||
client_id=client.session.client_id,
|
||||
),
|
||||
call().fire_event(
|
||||
EVENT_BROKER_CLIENT_DISCONNECTED,
|
||||
EventBroker.CLIENT_DISCONNECTED.value,
|
||||
client_id=client.session.client_id,
|
||||
),
|
||||
],
|
||||
|
@ -235,7 +224,7 @@ async def test_client_subscribe(broker, mock_plugin_manager):
|
|||
mock_plugin_manager.assert_has_calls(
|
||||
[
|
||||
call().fire_event(
|
||||
EVENT_BROKER_CLIENT_SUBSCRIBED,
|
||||
EventBroker.CLIENT_SUBSCRIBED.value,
|
||||
client_id=client.session.client_id,
|
||||
topic="/topic",
|
||||
qos=QOS_0,
|
||||
|
@ -272,7 +261,7 @@ async def test_client_subscribe_twice(broker, mock_plugin_manager):
|
|||
mock_plugin_manager.assert_has_calls(
|
||||
[
|
||||
call().fire_event(
|
||||
EVENT_BROKER_CLIENT_SUBSCRIBED,
|
||||
EventBroker.CLIENT_SUBSCRIBED.value,
|
||||
client_id=client.session.client_id,
|
||||
topic="/topic",
|
||||
qos=QOS_0,
|
||||
|
@ -306,13 +295,13 @@ async def test_client_unsubscribe(broker, mock_plugin_manager):
|
|||
mock_plugin_manager.assert_has_calls(
|
||||
[
|
||||
call().fire_event(
|
||||
EVENT_BROKER_CLIENT_SUBSCRIBED,
|
||||
EventBroker.CLIENT_SUBSCRIBED.value,
|
||||
client_id=client.session.client_id,
|
||||
topic="/topic",
|
||||
qos=QOS_0,
|
||||
),
|
||||
call().fire_event(
|
||||
EVENT_BROKER_CLIENT_UNSUBSCRIBED,
|
||||
EventBroker.CLIENT_UNSUBSCRIBED.value,
|
||||
client_id=client.session.client_id,
|
||||
topic="/topic",
|
||||
),
|
||||
|
@ -337,7 +326,7 @@ async def test_client_publish(broker, mock_plugin_manager):
|
|||
mock_plugin_manager.assert_has_calls(
|
||||
[
|
||||
call().fire_event(
|
||||
EVENT_BROKER_MESSAGE_RECEIVED,
|
||||
EventBroker.MESSAGE_RECEIVED.value,
|
||||
client_id=pub_client.session.client_id,
|
||||
message=ret_message,
|
||||
),
|
||||
|
@ -509,7 +498,7 @@ async def test_client_publish_big(broker, mock_plugin_manager):
|
|||
mock_plugin_manager.assert_has_calls(
|
||||
[
|
||||
call().fire_event(
|
||||
EVENT_BROKER_MESSAGE_RECEIVED,
|
||||
EventBroker.MESSAGE_RECEIVED.value,
|
||||
client_id=pub_client.session.client_id,
|
||||
message=ret_message,
|
||||
),
|
||||
|
@ -740,3 +729,16 @@ async def test_broker_broadcast_cancellation(broker):
|
|||
await _client_publish(topic, data, qos)
|
||||
message = await asyncio.wait_for(sub_client.deliver_message(), timeout=1)
|
||||
assert message
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_broker_socket_open_close(broker):
|
||||
|
||||
# check that https://github.com/Yakifo/amqtt/issues/86 is fixed
|
||||
|
||||
# mqtt 3.1 requires a connect packet, otherwise the socket connection is rejected
|
||||
static_connect_packet = b'\x10\x1b\x00\x04MQTT\x04\x02\x00<\x00\x0ftest-client-123'
|
||||
s = socket.create_connection(("127.0.0.1", 1883))
|
||||
s.send(static_connect_packet)
|
||||
await asyncio.sleep(0.1)
|
||||
s.close()
|
||||
|
|
|
@ -4,10 +4,12 @@ import os
|
|||
import signal
|
||||
import subprocess
|
||||
import tempfile
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
|
||||
from amqtt.broker import Broker
|
||||
from amqtt.mqtt.constants import QOS_0
|
||||
|
||||
formatter = "[%(asctime)s] %(name)s {%(filename)s:%(lineno)d} %(levelname)s - %(message)s"
|
||||
|
@ -158,38 +160,6 @@ async def test_publish_subscribe(broker):
|
|||
assert sub_proc.returncode == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pub_sub_retain(broker):
|
||||
"""Test various pub/sub will retain options."""
|
||||
# Test publishing with retain flag
|
||||
pub_proc = subprocess.run(
|
||||
[
|
||||
"amqtt_pub",
|
||||
"--url", "mqtt://127.0.0.1:1884",
|
||||
"-t", "topic/test",
|
||||
"-m", "standard message",
|
||||
"--will-topic", "topic/retain",
|
||||
"--will-message", "last will message",
|
||||
"--will-retain",
|
||||
],
|
||||
capture_output=True,
|
||||
)
|
||||
assert pub_proc.returncode == 0, f"publisher error code: {pub_proc.returncode}\n{pub_proc.stderr}"
|
||||
logger.debug("publisher succeeded")
|
||||
# Verify retained message is received by new subscriber
|
||||
sub_proc = subprocess.run(
|
||||
[
|
||||
"amqtt_sub",
|
||||
"--url", "mqtt://127.0.0.1:1884",
|
||||
"-t", "topic/retain",
|
||||
"-n", "1",
|
||||
],
|
||||
capture_output=True,
|
||||
)
|
||||
assert sub_proc.returncode == 0, f"subscriber error code: {sub_proc.returncode}\n{sub_proc.stderr}"
|
||||
assert "last will message" in str(sub_proc.stdout)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pub_errors(client_config_file):
|
||||
"""Test error handling in pub/sub tools."""
|
||||
|
@ -275,74 +245,3 @@ async def test_pub_client_config(broker, client_config_file):
|
|||
logger.debug(f"Stderr: {stderr.decode()}")
|
||||
|
||||
assert proc.returncode == 0, f"publisher error code: {proc.returncode}"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pub_client_config_will(broker, client_config_file):
|
||||
|
||||
# verifying client script functionality of will topic (publisher)
|
||||
# https://github.com/Yakifo/amqtt/issues/159
|
||||
await asyncio.sleep(1)
|
||||
client1 = MQTTClient(client_id="client1")
|
||||
await client1.connect('mqtt://localhost:1884')
|
||||
await client1.subscribe([
|
||||
("test/will/topic", QOS_0)
|
||||
])
|
||||
|
||||
cmd = ["amqtt_pub",
|
||||
"-t", "test/topic",
|
||||
"-m", "\"test of regular topic\"",
|
||||
"-c", client_config_file]
|
||||
|
||||
proc = await asyncio.create_subprocess_shell(
|
||||
" ".join(cmd), stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
|
||||
)
|
||||
stdout, stderr = await proc.communicate()
|
||||
logger.debug(f"Command: {cmd}")
|
||||
logger.debug(f"Stdout: {stdout.decode()}")
|
||||
logger.debug(f"Stderr: {stderr.decode()}")
|
||||
|
||||
message = await client1.deliver_message(timeout_duration=1)
|
||||
assert message.topic == 'test/will/topic'
|
||||
assert message.data == b'client ABC has disconnected'
|
||||
await client1.disconnect()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.timeout(20)
|
||||
async def test_sub_client_config_will(broker, client_config, client_config_file):
|
||||
|
||||
# verifying client script functionality of will topic (subscriber)
|
||||
# https://github.com/Yakifo/amqtt/issues/159
|
||||
|
||||
client1 = MQTTClient(client_id="client1")
|
||||
await client1.connect('mqtt://localhost:1884')
|
||||
await client1.subscribe([
|
||||
("test/will/topic", QOS_0)
|
||||
])
|
||||
|
||||
cmd = ["amqtt_sub",
|
||||
"-t", "test/topic",
|
||||
"-c", client_config_file,
|
||||
"-n", "1"]
|
||||
|
||||
proc = await asyncio.create_subprocess_shell(
|
||||
" ".join(cmd), stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
|
||||
)
|
||||
await asyncio.sleep(2)
|
||||
|
||||
# cause `amqtt_sub` to exit after receiving this one message
|
||||
await client1.publish("test/topic", b'my test message')
|
||||
|
||||
|
||||
# validate the 'will' message was received correctly
|
||||
message = await client1.deliver_message(timeout_duration=3)
|
||||
assert message.topic == 'test/will/topic'
|
||||
assert message.data == b'client ABC has disconnected'
|
||||
await client1.disconnect()
|
||||
|
||||
|
||||
stdout, stderr = await proc.communicate()
|
||||
logger.debug(f"Command: {cmd}")
|
||||
logger.debug(f"Stdout: {stdout.decode()}")
|
||||
logger.debug(f"Stderr: {stderr.decode()}")
|
||||
|
|
@ -1,10 +1,13 @@
|
|||
import asyncio
|
||||
import logging
|
||||
from importlib.metadata import EntryPoint
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from amqtt.broker import Broker
|
||||
from amqtt.client import MQTTClient
|
||||
from amqtt.errors import ConnectError
|
||||
from amqtt.errors import ClientError, ConnectError
|
||||
from amqtt.mqtt.constants import QOS_0, QOS_1, QOS_2
|
||||
|
||||
formatter = "[%(asctime)s] %(name)s {%(filename)s:%(lineno)d} %(levelname)s - %(message)s"
|
||||
|
@ -265,27 +268,121 @@ def client_config():
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_client_publish_will_with_retain(broker_fixture, client_config):
|
||||
async def test_client_will_with_clean_disconnect(broker_fixture):
|
||||
config = {
|
||||
"will": {
|
||||
"topic": "test/will/topic",
|
||||
"retain": False,
|
||||
"message": "client ABC has disconnected",
|
||||
"qos": 1
|
||||
},
|
||||
}
|
||||
|
||||
# verifying client functionality of will topic
|
||||
# https://github.com/Yakifo/amqtt/issues/159
|
||||
client1 = MQTTClient(client_id="client1", config=config)
|
||||
await client1.connect("mqtt://localhost:1883")
|
||||
|
||||
client1 = MQTTClient(client_id="client1")
|
||||
client2 = MQTTClient(client_id="client2")
|
||||
await client2.connect("mqtt://localhost:1883")
|
||||
await client2.subscribe(
|
||||
[
|
||||
("test/will/topic", QOS_0),
|
||||
]
|
||||
)
|
||||
|
||||
await client1.disconnect()
|
||||
await asyncio.sleep(1)
|
||||
|
||||
with pytest.raises(asyncio.TimeoutError):
|
||||
message = await client2.deliver_message(timeout_duration=2)
|
||||
# if we do get a message, make sure it's not a will message
|
||||
assert message.topic != "test/will/topic"
|
||||
|
||||
await client2.disconnect()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_client_will_with_abrupt_disconnect(broker_fixture):
|
||||
config = {
|
||||
"will": {
|
||||
"topic": "test/will/topic",
|
||||
"retain": False,
|
||||
"message": "client ABC has disconnected",
|
||||
"qos": 1
|
||||
},
|
||||
}
|
||||
|
||||
client1 = MQTTClient(client_id="client1", config=config)
|
||||
await client1.connect("mqtt://localhost:1883")
|
||||
|
||||
client2 = MQTTClient(client_id="client2")
|
||||
await client2.connect("mqtt://localhost:1883")
|
||||
await client2.subscribe(
|
||||
[
|
||||
("test/will/topic", QOS_0),
|
||||
]
|
||||
)
|
||||
|
||||
# instead of client.disconnect, call the necessary closing but without sending the disconnect packet
|
||||
await client1.cancel_tasks()
|
||||
if client1._disconnect_task and not client1._disconnect_task.done():
|
||||
client1._disconnect_task.cancel()
|
||||
client1._connected_state.clear()
|
||||
await client1._handler.stop()
|
||||
client1.session.transitions.disconnect()
|
||||
|
||||
await asyncio.sleep(1)
|
||||
|
||||
|
||||
message = await client2.deliver_message(timeout_duration=1)
|
||||
# make sure we receive the will message
|
||||
assert message.topic == "test/will/topic"
|
||||
assert message.data == b'client ABC has disconnected'
|
||||
|
||||
await client2.disconnect()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_client_retained_will_with_abrupt_disconnect(broker_fixture):
|
||||
|
||||
# verifying client functionality of retained will topic/message
|
||||
|
||||
config = {
|
||||
"will": {
|
||||
"topic": "test/will/topic",
|
||||
"retain": True,
|
||||
"message": "client ABC has disconnected",
|
||||
"qos": 1
|
||||
},
|
||||
}
|
||||
|
||||
# first client, connect with retained will message
|
||||
client1 = MQTTClient(client_id="client1", config=config)
|
||||
await client1.connect('mqtt://localhost:1883')
|
||||
await client1.subscribe([
|
||||
|
||||
client2 = MQTTClient(client_id="client2")
|
||||
await client2.connect('mqtt://localhost:1883')
|
||||
await client2.subscribe([
|
||||
("test/will/topic", QOS_0)
|
||||
])
|
||||
|
||||
client2 = MQTTClient(client_id="client2", config=client_config)
|
||||
await client2.connect('mqtt://localhost:1883')
|
||||
await client2.publish('my/topic', b'my message')
|
||||
await client2.disconnect()
|
||||
|
||||
message = await client1.deliver_message(timeout_duration=1)
|
||||
# let's abruptly disconnect client1
|
||||
await client1.cancel_tasks()
|
||||
if client1._disconnect_task and not client1._disconnect_task.done():
|
||||
client1._disconnect_task.cancel()
|
||||
client1._connected_state.clear()
|
||||
await client1._handler.stop()
|
||||
client1.session.transitions.disconnect()
|
||||
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
# make sure the client which is still connected that we get the 'will' message
|
||||
message = await client2.deliver_message(timeout_duration=1)
|
||||
assert message.topic == 'test/will/topic'
|
||||
assert message.data == b'client ABC has disconnected'
|
||||
await client1.disconnect()
|
||||
await client2.disconnect()
|
||||
|
||||
# make sure a client which is connected after client1 disconnected still receives the 'will' message from
|
||||
client3 = MQTTClient(client_id="client3")
|
||||
await client3.connect('mqtt://localhost:1883')
|
||||
await client3.subscribe([
|
||||
|
@ -295,3 +392,92 @@ async def test_client_publish_will_with_retain(broker_fixture, client_config):
|
|||
assert message3.topic == 'test/will/topic'
|
||||
assert message3.data == b'client ABC has disconnected'
|
||||
await client3.disconnect()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_client_abruptly_disconnecting_with_empty_will_message(broker_fixture):
|
||||
|
||||
config = {
|
||||
"will": {
|
||||
"topic": "test/will/topic",
|
||||
"retain": True,
|
||||
"message": "",
|
||||
"qos": 1
|
||||
},
|
||||
}
|
||||
client1 = MQTTClient(client_id="client1", config=config)
|
||||
await client1.connect('mqtt://localhost:1883')
|
||||
|
||||
client2 = MQTTClient(client_id="client2")
|
||||
await client2.connect('mqtt://localhost:1883')
|
||||
await client2.subscribe([
|
||||
("test/will/topic", QOS_0)
|
||||
])
|
||||
|
||||
# let's abruptly disconnect client1
|
||||
await client1.cancel_tasks()
|
||||
if client1._disconnect_task and not client1._disconnect_task.done():
|
||||
client1._disconnect_task.cancel()
|
||||
client1._connected_state.clear()
|
||||
await client1._handler.stop()
|
||||
client1.session.transitions.disconnect()
|
||||
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
message = await client2.deliver_message(timeout_duration=1)
|
||||
assert message.topic == 'test/will/topic'
|
||||
assert message.data == b''
|
||||
|
||||
await client2.disconnect()
|
||||
|
||||
|
||||
async def test_connect_broken_uri():
|
||||
config = {"auto_reconnect": False}
|
||||
client = MQTTClient(config=config)
|
||||
with pytest.raises(ClientError):
|
||||
await client.connect('"mqtt://someplace')
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_connect_incorrect_scheme():
|
||||
config = {"auto_reconnect": False}
|
||||
client = MQTTClient(config=config)
|
||||
with pytest.raises(ClientError):
|
||||
await client.connect('"mq://someplace')
|
||||
|
||||
|
||||
async def test_client_no_auth():
|
||||
|
||||
class MockEntryPoints:
|
||||
|
||||
def select(self, group) -> list[EntryPoint]:
|
||||
match group:
|
||||
case 'tests.mock_plugins':
|
||||
return [
|
||||
EntryPoint(name='auth_plugin', group='tests.mock_plugins', value='tests.plugins.mocks:NoAuthPlugin'),
|
||||
]
|
||||
case _:
|
||||
return list()
|
||||
|
||||
|
||||
with patch("amqtt.plugins.manager.entry_points", side_effect=MockEntryPoints) as mocked_mqtt_publish:
|
||||
|
||||
config = {
|
||||
"listeners": {
|
||||
"default": {"type": "tcp", "bind": "127.0.0.1:1883", "max_connections": 10},
|
||||
},
|
||||
'sys_interval': 1,
|
||||
'auth': {
|
||||
'plugins': ['auth_plugin', ]
|
||||
}
|
||||
}
|
||||
|
||||
client = MQTTClient(client_id="client1", config={'auto_reconnect': False})
|
||||
|
||||
broker = Broker(plugin_namespace='tests.mock_plugins', config=config)
|
||||
await broker.start()
|
||||
|
||||
with pytest.raises(ConnectError):
|
||||
await client.connect("mqtt://127.0.0.1:1883/")
|
||||
|
||||
await broker.shutdown()
|
||||
|
|
|
@ -6,8 +6,7 @@ from unittest.mock import MagicMock, call, patch
|
|||
import pytest
|
||||
from paho.mqtt import client as mqtt_client
|
||||
|
||||
from amqtt.broker import EVENT_BROKER_CLIENT_CONNECTED, EVENT_BROKER_CLIENT_DISCONNECTED, EVENT_BROKER_PRE_START, \
|
||||
EVENT_BROKER_POST_START
|
||||
from amqtt.broker import EventBroker
|
||||
from amqtt.client import MQTTClient
|
||||
from amqtt.mqtt.constants import QOS_1, QOS_2
|
||||
|
||||
|
@ -54,11 +53,11 @@ async def test_paho_connect(broker, mock_plugin_manager):
|
|||
broker.plugins_manager.assert_has_calls(
|
||||
[
|
||||
call.fire_event(
|
||||
EVENT_BROKER_CLIENT_CONNECTED,
|
||||
EventBroker.CLIENT_CONNECTED.value,
|
||||
client_id=client_id,
|
||||
),
|
||||
call.fire_event(
|
||||
EVENT_BROKER_CLIENT_DISCONNECTED,
|
||||
EventBroker.CLIENT_DISCONNECTED.value,
|
||||
client_id=client_id,
|
||||
),
|
||||
],
|
||||
|
|
Ładowanie…
Reference in New Issue