Merge branch 'rc' into discord

pull/214/head
Andrew Mirsky 2025-06-14 10:17:28 -04:00
commit 1e7f62b7a4
Nie znaleziono w bazie danych klucza dla tego podpisu
ID klucza GPG: A98E67635CDF2C39
23 zmienionych plików z 499 dodań i 273 usunięć

Wyświetl plik

@ -7,7 +7,9 @@
jobs:
pre_install:
- pip install --upgrade pip
- pip install --group docs
- pip install uv
- uv pip install --group dev --group docs
- uv run pytest
mkdocs:
configuration: mkdocs.rtd.yml

Wyświetl plik

@ -17,9 +17,7 @@
- Communication over TCP and/or websocket, including support for SSL/TLS
- Support QoS 0, QoS 1 and QoS 2 messages flow
- Client auto-reconnection on network lost
- Functionality expansion; plugins included:
- Authentication through password file
- Basic `$SYS` topics
- Functionality expansion; plugins included: authentication and `$SYS` topic publishing
## Installation

Wyświetl plik

@ -28,6 +28,7 @@ from amqtt.mqtt.protocol.broker_handler import BrokerProtocolHandler
from amqtt.session import ApplicationMessage, OutgoingApplicationMessage, Session
from amqtt.utils import format_client_message, gen_client_id, read_yaml_config
from .mqtt.disconnect import DisconnectPacket
from .plugins.manager import BaseContext, PluginManager
_CONFIG_LISTENER: TypeAlias = dict[str, int | bool | dict[str, Any]]
@ -41,18 +42,24 @@ _defaults = read_yaml_config(Path(__file__).parent / "scripts/default_broker.yam
DEFAULT_PORTS = {"tcp": 1883, "ws": 8883}
AMQTT_MAGIC_VALUE_RET_SUBSCRIBED = 0x80
EVENT_BROKER_PRE_START = "broker_pre_start"
EVENT_BROKER_POST_START = "broker_post_start"
EVENT_BROKER_PRE_SHUTDOWN = "broker_pre_shutdown"
EVENT_BROKER_POST_SHUTDOWN = "broker_post_shutdown"
EVENT_BROKER_CLIENT_CONNECTED = "broker_client_connected"
EVENT_BROKER_CLIENT_DISCONNECTED = "broker_client_disconnected"
EVENT_BROKER_CLIENT_SUBSCRIBED = "broker_client_subscribed"
EVENT_BROKER_CLIENT_UNSUBSCRIBED = "broker_client_unsubscribed"
EVENT_BROKER_MESSAGE_RECEIVED = "broker_message_received"
class EventBroker(Enum):
"""Events issued by the broker."""
PRE_START = "broker_pre_start"
POST_START = "broker_post_start"
PRE_SHUTDOWN = "broker_pre_shutdown"
POST_SHUTDOWN = "broker_post_shutdown"
CLIENT_CONNECTED = "broker_client_connected"
CLIENT_DISCONNECTED = "broker_client_disconnected"
CLIENT_SUBSCRIBED = "broker_client_subscribed"
CLIENT_UNSUBSCRIBED = "broker_client_unsubscribed"
MESSAGE_RECEIVED = "broker_message_received"
class Action(Enum):
"""Actions issued by the broker."""
SUBSCRIBE = "subscribe"
PUBLISH = "publish"
@ -142,9 +149,12 @@ class Broker:
Args:
config: dictionary of configuration options (see [broker configuration](broker_config.md)).
loop: asyncio loop. defaults to `asyncio.get_event_loop()`.
loop: asyncio loop. defaults to `asyncio.new_event_loop()`.
plugin_namespace: plugin namespace to use when loading plugin entry_points. defaults to `amqtt.broker.plugins`.
Raises:
BrokerError, ParserError, PluginError
"""
states: ClassVar[list[str]] = [
@ -170,7 +180,7 @@ class Broker:
self.config.update(config)
self._build_listeners_config(self.config)
self._loop = loop or asyncio.get_event_loop()
self._loop = loop or asyncio.new_event_loop()
self._servers: dict[str, Server] = {}
self._init_states()
self._sessions: dict[str, tuple[Session, BrokerProtocolHandler]] = {}
@ -242,11 +252,11 @@ class Broker:
msg = f"Broker instance can't be started: {exc}"
raise BrokerError(msg) from exc
await self.plugins_manager.fire_event(EVENT_BROKER_PRE_START)
await self.plugins_manager.fire_event(EventBroker.PRE_START.value)
try:
await self._start_listeners()
self.transitions.starting_success()
await self.plugins_manager.fire_event(EVENT_BROKER_POST_START)
await self.plugins_manager.fire_event(EventBroker.POST_START.value)
self._broadcast_task = asyncio.ensure_future(self._broadcast_loop())
self.logger.debug("Broker started")
except Exception as e:
@ -327,7 +337,7 @@ class Broker:
"""Stop broker instance."""
self.logger.info("Shutting down broker...")
# Fire broker_shutdown event to plugins
await self.plugins_manager.fire_event(EVENT_BROKER_PRE_SHUTDOWN)
await self.plugins_manager.fire_event(EventBroker.PRE_SHUTDOWN.value)
# Cleanup all sessions
for client_id in list(self._sessions.keys()):
@ -351,7 +361,7 @@ class Broker:
self._broadcast_queue.get_nowait()
self.logger.info("Broker closed")
await self.plugins_manager.fire_event(EVENT_BROKER_POST_SHUTDOWN)
await self.plugins_manager.fire_event(EventBroker.POST_SHUTDOWN.value)
self.transitions.stopping_success()
async def _cleanup_session(self, client_id: str) -> None:
@ -494,7 +504,7 @@ class Broker:
self._sessions[client_session.client_id] = (client_session, handler)
await handler.mqtt_connack_authorize(authenticated)
await self.plugins_manager.fire_event(EVENT_BROKER_CLIENT_CONNECTED, client_id=client_session.client_id)
await self.plugins_manager.fire_event(EventBroker.CLIENT_CONNECTED.value, client_id=client_session.client_id)
self.logger.debug(f"{client_session.client_id} Start messages handling")
await handler.start()
@ -525,8 +535,12 @@ class Broker:
)
if disconnect_waiter in done:
connected = await self._handle_disconnect(client_session, handler, disconnect_waiter)
disconnect_waiter = asyncio.ensure_future(handler.wait_disconnect())
# handle the disconnection: normal or abnormal result, either way, the client is no longer connected
await self._handle_disconnect(client_session, handler, disconnect_waiter)
connected = False
# no need to reschedule the `disconnect_waiter` since we're exiting the message loop
if subscribe_waiter in done:
await self._handle_subscription(client_session, handler, subscribe_waiter)
@ -556,11 +570,20 @@ class Broker:
client_session: Session,
handler: BrokerProtocolHandler,
disconnect_waiter: asyncio.Future[Any],
) -> bool:
"""Handle client disconnection."""
) -> None:
"""Handle client disconnection.
Args:
client_session (Session): client session
handler (BrokerProtocolHandler): broker protocol handler
disconnect_waiter (asyncio.Future[Any]): future to wait for disconnection
"""
# check the disconnected waiter result
result = disconnect_waiter.result()
self.logger.debug(f"{client_session.client_id} Result from wait_disconnect: {result}")
if result is None:
# if the client disconnects abruptly by sending no message or the message isn't a disconnect packet
if result is None or not isinstance(result, DisconnectPacket):
self.logger.debug(f"Will flag: {client_session.will_flag}")
if client_session.will_flag:
self.logger.debug(
@ -579,12 +602,13 @@ class Broker:
client_session.will_message,
client_session.will_qos,
)
self.logger.debug(f"{client_session.client_id} Disconnecting session")
await self._stop_handler(handler)
client_session.transitions.disconnect()
await self.plugins_manager.fire_event(EVENT_BROKER_CLIENT_DISCONNECTED, client_id=client_session.client_id)
return False
return True
# normal or not, let's end the client's session
self.logger.debug(f"{client_session.client_id} Disconnecting session")
await self._stop_handler(handler)
client_session.transitions.disconnect()
await self.plugins_manager.fire_event(EventBroker.CLIENT_DISCONNECTED.value, client_id=client_session.client_id)
async def _handle_subscription(
self,
@ -600,7 +624,7 @@ class Broker:
for index, subscription in enumerate(subscriptions.topics):
if return_codes[index] != AMQTT_MAGIC_VALUE_RET_SUBSCRIBED:
await self.plugins_manager.fire_event(
EVENT_BROKER_CLIENT_SUBSCRIBED,
EventBroker.CLIENT_SUBSCRIBED.value,
client_id=client_session.client_id,
topic=subscription[0],
qos=subscription[1],
@ -619,7 +643,7 @@ class Broker:
for topic in unsubscription.topics:
self._del_subscription(topic, client_session)
await self.plugins_manager.fire_event(
EVENT_BROKER_CLIENT_UNSUBSCRIBED,
EventBroker.CLIENT_UNSUBSCRIBED.value,
client_id=client_session.client_id,
topic=topic,
)
@ -654,7 +678,7 @@ class Broker:
self.logger.info(f"{client_session.client_id} forbidden TOPIC {app_message.topic} sent in PUBLISH message.")
else:
await self.plugins_manager.fire_event(
EVENT_BROKER_MESSAGE_RECEIVED,
EventBroker.MESSAGE_RECEIVED.value,
client_id=client_session.client_id,
message=app_message,
)
@ -663,7 +687,6 @@ class Broker:
self.retain_message(client_session, app_message.topic, app_message.data, app_message.qos)
return True
# TODO: Remove this method, not found it used
async def _init_handler(self, session: Session, reader: ReaderAdapter, writer: WriterAdapter) -> BrokerProtocolHandler:
"""Create a BrokerProtocolHandler and attach to a session."""
handler = BrokerProtocolHandler(self.plugins_manager, loop=self._loop)
@ -686,7 +709,6 @@ class Broker:
- False if user authentication fails
- None if authentication can't be achieved (then plugin result is then ignored)
:param session:
:param listener:
:return:
"""
auth_plugins = None
@ -757,7 +779,6 @@ class Broker:
- False if MQTT client is not allowed to subscribe to the topic
- None if topic filtering can't be achieved (then plugin result is then ignored)
:param session:
:param listener:
:param topic: Topic in which the client wants to subscribe / publish
:param action: What is being done with the topic? subscribe or publish
:return:

Wyświetl plik

@ -88,6 +88,9 @@ class MQTTClient:
it will be generated randomly by `amqtt.utils.gen_client_id`
config: dictionary of configuration options (see [client configuration](client_config.md)).
Raises:
PluginError
"""
def __init__(self, client_id: str | None = None, config: dict[str, Any] | None = None) -> None:
@ -142,7 +145,7 @@ class MQTTClient:
[CONNACK](http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc398718033)'s return code
Raises:
amqtt.client.ConnectException: if connection fails
ClientError, ConnectError
"""
additional_headers = additional_headers if additional_headers is not None else {}
@ -219,7 +222,7 @@ class MQTTClient:
self.logger.debug(f"Reconnecting with session parameters: {self.session}")
reconnect_max_interval = self.config.get("reconnect_max_interval", 10)
reconnect_retries = self.config.get("reconnect_retries", 5)
reconnect_retries = self.config.get("reconnect_retries", 2)
nb_attempt = 1
while True:
@ -232,7 +235,7 @@ class MQTTClient:
except Exception as e:
self.logger.warning(f"Reconnection attempt failed: {e!r}")
self.logger.debug("", exc_info=True)
if reconnect_retries < nb_attempt: # reconnect_retries >= 0 and
if 0 <= reconnect_retries < nb_attempt:
self.logger.exception("Maximum connection attempts reached. Reconnection aborted.")
self.logger.debug("", exc_info=True)
msg = "Too many failed attempts"
@ -470,6 +473,7 @@ class MQTTClient:
reader: StreamReaderAdapter | WebSocketsReader | None = None
writer: StreamWriterAdapter | WebSocketsWriter | None = None
self._connected_state.clear()
# Open connection
if scheme in ("mqtt", "mqtts"):
conn_reader, conn_writer = await asyncio.open_connection(
@ -489,11 +493,11 @@ class MQTTClient:
)
reader = WebSocketsReader(websocket)
writer = WebSocketsWriter(websocket)
if reader is None or writer is None:
self.session.transitions.disconnect()
self.logger.warning("reader or writer not initialized")
msg = "reader or writer not initialized"
elif not self.session.broker_uri:
msg = "missing broker uri"
raise ClientError(msg)
else:
msg = f"incorrect scheme defined in uri: '{scheme!r}'"
raise ClientError(msg)
# Start MQTT protocol
@ -533,7 +537,7 @@ class MQTTClient:
while self.client_tasks:
task = self.client_tasks.popleft()
if not task.done():
task.cancel()
task.cancel(msg="Connection closed.")
self.logger.debug("Monitoring broker disconnection")
# Wait for disconnection from broker (like connection lost)

Wyświetl plik

@ -2,7 +2,7 @@ import asyncio
from struct import pack, unpack
from amqtt.adapters import ReaderAdapter
from amqtt.errors import NoDataError
from amqtt.errors import NoDataError, ZeroLengthReadError
def bytes_to_hex_str(data: bytes | bytearray) -> str:
@ -59,7 +59,7 @@ async def read_or_raise(reader: ReaderAdapter | asyncio.StreamReader, n: int = -
data = await reader.read(n)
except (asyncio.IncompleteReadError, ConnectionResetError, BrokenPipeError):
data = None
if not data:
if data is None:
msg = "No more data"
raise NoDataError(msg)
return data
@ -72,6 +72,8 @@ async def decode_string(reader: ReaderAdapter | asyncio.StreamReader) -> str:
:return: string read from stream.
"""
length_bytes = await read_or_raise(reader, 2)
if len(length_bytes) < 1:
raise ZeroLengthReadError
str_length = unpack("!H", length_bytes)[0]
if str_length:
byte_str = await read_or_raise(reader, str_length)
@ -90,6 +92,8 @@ async def decode_data_with_length(reader: ReaderAdapter | asyncio.StreamReader)
:return: bytes read from stream (without length).
"""
length_bytes = await read_or_raise(reader, 2)
if len(length_bytes) < 1:
raise ZeroLengthReadError
bytes_length = unpack("!H", length_bytes)[0]
return await read_or_raise(reader, bytes_length)

Wyświetl plik

@ -1,3 +1,6 @@
from typing import Any
class AMQTTError(Exception):
"""aMQTT base exception."""
@ -13,11 +16,28 @@ class CodecError(Exception):
class NoDataError(Exception):
"""Exceptions thrown by packet encode/decode functions."""
class ZeroLengthReadError(NoDataError):
def __init__(self) -> None:
super().__init__("Decoding a string of length zero.")
class BrokerError(Exception):
"""Exceptions thrown by broker."""
class PluginError(Exception):
"""Exceptions thrown when loading or initializing a plugin."""
class PluginImportError(PluginError):
def __init__(self, plugin: Any) -> None:
super().__init__(f"Plugin import failed: {plugin!r}")
class PluginInitError(PluginError):
def __init__(self, plugin: Any) -> None:
super().__init__(f"Plugin init failed: {plugin!r}")
class ClientError(Exception):
"""Exceptions thrown by client."""

Wyświetl plik

@ -1,7 +1,7 @@
import asyncio
from typing import Any
from amqtt.errors import AMQTTError
from amqtt.errors import AMQTTError, NoDataError
from amqtt.mqtt.connack import ConnackPacket
from amqtt.mqtt.connect import ConnectPacket, ConnectPayload, ConnectVariableHeader
from amqtt.mqtt.disconnect import DisconnectPacket
@ -87,8 +87,10 @@ class ClientProtocolHandler(ProtocolHandler):
if self.reader is None:
msg = "Reader is not initialized."
raise AMQTTError(msg)
connack = await ConnackPacket.from_stream(self.reader)
try:
connack = await ConnackPacket.from_stream(self.reader)
except NoDataError as e:
raise ConnectionError from e
await self.plugins_manager.fire_event(EVENT_MQTT_PACKET_RECEIVED, packet=connack, session=self.session)
return connack.return_code

Wyświetl plik

@ -152,7 +152,8 @@ class ProtocolHandler:
if self.writer is not None:
await self.writer.close()
except asyncio.CancelledError:
self.logger.debug("Writer close was cancelled.", exc_info=True)
# canceling the task is the expected result
self.logger.debug("Writer close was cancelled.")
except TimeoutError:
self.logger.debug("Writer close operation timed out.", exc_info=True)
except OSError:

Wyświetl plik

@ -8,6 +8,8 @@ from importlib.metadata import EntryPoint, EntryPoints, entry_points
import logging
from typing import Any, NamedTuple
from amqtt.errors import PluginImportError, PluginInitError
_LOGGER = logging.getLogger(__name__)
@ -80,17 +82,21 @@ class PluginManager:
try:
self.logger.debug(f" Loading plugin {ep!s}")
plugin = ep.load()
self.logger.debug(f" Initializing plugin {ep!s}")
plugin_context = copy.copy(self.app_context)
plugin_context.logger = self.logger.getChild(ep.name)
except ImportError as e:
self.logger.debug(f"Plugin import failed: {ep!r}", exc_info=True)
raise PluginImportError(ep) from e
self.logger.debug(f" Initializing plugin {ep!s}")
plugin_context = copy.copy(self.app_context)
plugin_context.logger = self.logger.getChild(ep.name)
try:
obj = plugin(plugin_context)
return Plugin(ep.name, ep, obj)
except ImportError:
self.logger.warning(f"Plugin {ep!r} import failed")
self.logger.debug("", exc_info=True)
return None
except Exception as e:
self.logger.debug(f"Plugin init failed: {ep!r}", exc_info=True)
raise PluginInitError(ep) from e
def get_plugin(self, name: str) -> Plugin | None:
"""Get a plugin by its name from the plugins loaded for the current namespace.

Wyświetl plik

@ -7,7 +7,7 @@ from yaml.parser import ParserError
from amqtt import __version__ as amqtt_version
from amqtt.broker import Broker
from amqtt.errors import BrokerError
from amqtt.errors import BrokerError, PluginError
from amqtt.utils import read_yaml_config
logger = logging.getLogger(__name__)
@ -55,20 +55,21 @@ def broker_main(
raise typer.Exit(code=1) from exc
loop = asyncio.get_event_loop()
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
broker = Broker(config)
except (BrokerError, ParserError) as exc:
except (BrokerError, ParserError, PluginError) as exc:
typer.echo(f"❌ Broker failed to start: {exc}", err=True)
raise typer.Exit(code=1) from exc
_ = loop.create_task(broker.start()) #noqa : RUF006
try:
loop.run_until_complete(broker.start())
loop.run_forever()
except KeyboardInterrupt:
loop.run_until_complete(broker.shutdown())
except Exception as exc:
typer.echo("Connection failed", err=True)
typer.echo("Broker execution halted", err=True)
raise typer.Exit(code=1) from exc
finally:
loop.close()

Wyświetl plik

@ -182,8 +182,6 @@ def publisher_main( # pylint: disable=R0914,R0917 # noqa : PLR0913
logger.debug(f"Using default configuration from {default_config_path}")
config = read_yaml_config(default_config_path)
loop = asyncio.get_event_loop()
if not client_id:
client_id = _gen_client_id()
@ -217,7 +215,7 @@ def publisher_main( # pylint: disable=R0914,R0917 # noqa : PLR0913
)
with contextlib.suppress(KeyboardInterrupt):
try:
loop.run_until_complete(
asyncio.run(
do_pub(
client=client,
message_input=message_input,
@ -234,8 +232,6 @@ def publisher_main( # pylint: disable=R0914,R0917 # noqa : PLR0913
typer.echo("❌ Connection failed", err=True)
raise typer.Exit(code=1) from exc
loop.close()
if __name__ == "__main__":
typer.run(main)

Wyświetl plik

@ -110,8 +110,7 @@ def _version(v:bool) -> None:
def subscribe_main( # pylint: disable=R0914,R0917 # noqa : PLR0913
url: str = typer.Option(None, help="Broker connection URL, *must conform to MQTT or URI scheme: `[mqtt(s)|ws(s)]://<username:password>@HOST:port`*", show_default=False),
config_file: str | None = typer.Option(None, "-c", help="Client configuration file"),
client_id: str | None = typer.Option(None, "-i", help="client identification for mqtt connection. *default: process id and the hostname of the client*"),
max_count: int | None = typer.Option(None, "-n", help="Number of messages to read before ending *default: read indefinitely*"),
client_id: str | None = typer.Option(None, "-i", "--client-id", help="client identification for mqtt connection. *default: process id and the hostname of the client*"), max_count: int | None = typer.Option(None, "-n", help="Number of messages to read before ending *default: read indefinitely*"),
qos: int = typer.Option(0, "--qos", "-q", help="Quality of service (0, 1, or 2)"),
topics: list[str] = typer.Option(..., "-t", help="Topic filter to subscribe, can be used multiple times."), # noqa: B008
keep_alive: int | None = typer.Option(None, "-k", help="Keep alive timeout in seconds"),
@ -147,8 +146,6 @@ def subscribe_main( # pylint: disable=R0914,R0917 # noqa : PLR0913
logger.debug(f"Using default configuration from {default_config_path}")
config = read_yaml_config(default_config_path)
loop = asyncio.get_event_loop()
if not client_id:
client_id = _gen_client_id()
@ -175,7 +172,7 @@ def subscribe_main( # pylint: disable=R0914,R0917 # noqa : PLR0913
)
with contextlib.suppress(KeyboardInterrupt):
try:
loop.run_until_complete(do_sub(client,
asyncio.run(do_sub(client,
url=url,
topics=topics,
ca_info=ca_info,
@ -184,10 +181,10 @@ def subscribe_main( # pylint: disable=R0914,R0917 # noqa : PLR0913
max_count=max_count,
clean_session=clean_session,
))
except (ClientError, ConnectError) as exc:
typer.echo("❌ Connection failed", err=True)
raise typer.Exit(code=1) from exc
loop.close()
if __name__ == "__main__":

Wyświetl plik

@ -1,4 +0,0 @@
{% extends "base.html" %}
{% block outdated %}
You're not viewing the latest version. <a href="{{ '../' ~ base_url }}"><strong>Click here to go to latest.</strong></a>
{% endblock %}

Wyświetl plik

@ -56,8 +56,6 @@ nav:
theme:
name: material
logo: assets/amqtt_bw.svg
extend:
- base.html
features:
- announce.dismiss
- content.action.edit
@ -75,6 +73,7 @@ theme:
- search.highlight
- search.suggest
- toc.follow
- version
palette:
# Palette toggle for light mode
- scheme: default
@ -168,7 +167,8 @@ plugins:
extra:
version:
provider: readthedocs
default: latest
default: v0.11.0
warning: true
social:
- icon: fontawesome/brands/github
link: https://github.com/pawamoy

Wyświetl plik

@ -191,7 +191,7 @@ timeout = 10
asyncio_default_fixture_loop_scope = "function"
#addopts = ["--tb=short", "--capture=tee-sys"]
#log_cli = true
#log_level = "DEBUG"
log_level = "DEBUG"
# ------------------------------------ MYPY ------------------------------------
[tool.mypy]

Wyświetl plik

@ -0,0 +1,10 @@
from amqtt.broker import BrokerContext
from amqtt.plugins.base import BasePlugin
# intentional import error to test broker response
from pathlib import Pat # noqa
class MockImportErrorPlugin(BasePlugin):
def __init__(self, context: BrokerContext) -> None:
super().__init__(context)

Wyświetl plik

@ -0,0 +1,18 @@
import logging
from amqtt.plugins.authentication import BaseAuthPlugin
from amqtt.session import Session
logger = logging.getLogger(__name__)
class NoAuthPlugin(BaseAuthPlugin):
async def authenticate(self, *, session: Session) -> bool | None:
return False
class AuthPlugin(BaseAuthPlugin):
async def authenticate(self, *, session: Session) -> bool | None:
return True

Wyświetl plik

@ -1,12 +1,17 @@
import inspect
from importlib.metadata import EntryPoint
from logging import getLogger
from pathlib import Path
from types import ModuleType
from typing import Any
from unittest.mock import patch
import pytest
import amqtt.plugins
from amqtt.broker import Broker, BrokerContext
from amqtt.errors import PluginError, PluginInitError, PluginImportError
from amqtt.plugins.base import BasePlugin
from amqtt.plugins.manager import BaseContext
_INVALID_METHOD: str = "invalid_foo"
@ -61,3 +66,62 @@ def test_plugins_correct_has_attr() -> None:
__import__(name)
_verify_module(module, module.__name__)
class MockInitErrorPlugin(BasePlugin):
def __init__(self, context: BrokerContext) -> None:
super().__init__(context)
raise KeyError
@pytest.mark.asyncio
async def test_plugin_exception_while_init() -> None:
class MockEntryPoints:
def select(self, group) -> list[EntryPoint]:
match group:
case 'tests.mock_plugins':
return [
EntryPoint(name='TestExceptionPlugin', group='tests.mock_plugins', value='tests.plugins.test_plugins:MockInitErrorPlugin'),
]
case _:
return list()
with patch("amqtt.plugins.manager.entry_points", side_effect=MockEntryPoints) as mocked_mqtt_publish:
config = {
"listeners": {
"default": {"type": "tcp", "bind": "127.0.0.1:1883", "max_connections": 10},
},
'sys_interval': 1
}
with pytest.raises(PluginInitError):
_ = Broker(plugin_namespace='tests.mock_plugins', config=config)
@pytest.mark.asyncio
async def test_plugin_exception_while_loading() -> None:
class MockEntryPoints:
def select(self, group) -> list[EntryPoint]:
match group:
case 'tests.mock_plugins':
return [
EntryPoint(name='TestExceptionPlugin', group='tests.mock_plugins', value='tests.plugins.mock_plugins:MockImportErrorPlugin'),
]
case _:
return list()
with patch("amqtt.plugins.manager.entry_points", side_effect=MockEntryPoints) as mocked_mqtt_publish:
config = {
"listeners": {
"default": {"type": "tcp", "bind": "127.0.0.1:1883", "max_connections": 10},
},
'sys_interval': 1
}
with pytest.raises(PluginImportError):
_ = Broker(plugin_namespace='tests.mock_plugins', config=config)

Wyświetl plik

@ -13,49 +13,49 @@ from amqtt.mqtt.constants import QOS_0
logger = logging.getLogger(__name__)
# test broker sys
@pytest.mark.asyncio
async def test_broker_sys_plugin() -> None:
class MockEntryPoints:
def select(self, group) -> list[EntryPoint]:
match group:
case 'tests.mock_plugins':
return [
EntryPoint(name='BrokerSysPlugin', group='tests.mock_plugins', value='amqtt.plugins.sys.broker:BrokerSysPlugin'),
]
case _:
return list()
with patch("amqtt.plugins.manager.entry_points", side_effect=MockEntryPoints) as mocked_mqtt_publish:
config = {
"listeners": {
"default": {"type": "tcp", "bind": "127.0.0.1:1883", "max_connections": 10},
},
'sys_interval': 1
}
broker = Broker(plugin_namespace='tests.mock_plugins', config=config)
await broker.start()
client = MQTTClient()
await client.connect("mqtt://127.0.0.1:1883/")
await client.subscribe([("$SYS/broker/uptime", QOS_0),])
await client.publish('test/topic', b'my test message')
await asyncio.sleep(2)
sys_msg_count = 0
try:
while True:
message = await client.deliver_message(timeout_duration=0.5)
if '$SYS' in message.topic:
sys_msg_count += 1
except asyncio.TimeoutError:
pass
logger.warning(f">>> sys message: {message.topic} - {message.data}")
await client.disconnect()
await broker.shutdown()
assert sys_msg_count > 1
# @pytest.mark.asyncio
# async def test_broker_sys_plugin() -> None:
#
# class MockEntryPoints:
#
# def select(self, group) -> list[EntryPoint]:
# match group:
# case 'tests.mock_plugins':
# return [
# EntryPoint(name='BrokerSysPlugin', group='tests.mock_plugins', value='amqtt.plugins.sys.broker:BrokerSysPlugin'),
# ]
# case _:
# return list()
#
#
# with patch("amqtt.plugins.manager.entry_points", side_effect=MockEntryPoints) as mocked_mqtt_publish:
#
# config = {
# "listeners": {
# "default": {"type": "tcp", "bind": "127.0.0.1:1883", "max_connections": 10},
# },
# 'sys_interval': 1
# }
#
# broker = Broker(plugin_namespace='tests.mock_plugins', config=config)
# await broker.start()
# client = MQTTClient()
# await client.connect("mqtt://127.0.0.1:1883/")
# await client.subscribe([("$SYS/broker/uptime", QOS_0),])
# await client.publish('test/topic', b'my test message')
# await asyncio.sleep(2)
# sys_msg_count = 0
# try:
# while True:
# message = await client.deliver_message(timeout_duration=0.5)
# if '$SYS' in message.topic:
# sys_msg_count += 1
# except asyncio.TimeoutError:
# pass
#
# logger.warning(f">>> sys message: {message.topic} - {message.data}")
# await client.disconnect()
# await broker.shutdown()
#
#
# assert sys_msg_count > 1

Wyświetl plik

@ -7,18 +7,7 @@ import psutil
import pytest
from amqtt.adapters import StreamReaderAdapter, StreamWriterAdapter
from amqtt.broker import (
EVENT_BROKER_CLIENT_CONNECTED,
EVENT_BROKER_CLIENT_DISCONNECTED,
EVENT_BROKER_CLIENT_SUBSCRIBED,
EVENT_BROKER_CLIENT_UNSUBSCRIBED,
EVENT_BROKER_MESSAGE_RECEIVED,
EVENT_BROKER_POST_SHUTDOWN,
EVENT_BROKER_POST_START,
EVENT_BROKER_PRE_SHUTDOWN,
EVENT_BROKER_PRE_START,
Broker,
)
from amqtt.broker import EventBroker, Broker
from amqtt.client import MQTTClient
from amqtt.errors import ConnectError
from amqtt.mqtt.connack import ConnackPacket
@ -67,8 +56,8 @@ def test_split_bindaddr_port(input_str, output_addr, output_port):
async def test_start_stop(broker, mock_plugin_manager):
mock_plugin_manager.assert_has_calls(
[
call().fire_event(EVENT_BROKER_PRE_START),
call().fire_event(EVENT_BROKER_POST_START),
call().fire_event(EventBroker.PRE_START.value),
call().fire_event(EventBroker.POST_START.value),
],
any_order=True,
)
@ -76,8 +65,8 @@ async def test_start_stop(broker, mock_plugin_manager):
await broker.shutdown()
mock_plugin_manager.assert_has_calls(
[
call().fire_event(EVENT_BROKER_PRE_SHUTDOWN),
call().fire_event(EVENT_BROKER_POST_SHUTDOWN),
call().fire_event(EventBroker.PRE_SHUTDOWN.value),
call().fire_event(EventBroker.POST_SHUTDOWN.value),
],
any_order=True,
)
@ -98,11 +87,11 @@ async def test_client_connect(broker, mock_plugin_manager):
mock_plugin_manager.assert_has_calls(
[
call().fire_event(
EVENT_BROKER_CLIENT_CONNECTED,
EventBroker.CLIENT_CONNECTED.value,
client_id=client.session.client_id,
),
call().fire_event(
EVENT_BROKER_CLIENT_DISCONNECTED,
EventBroker.CLIENT_DISCONNECTED.value,
client_id=client.session.client_id,
),
],
@ -235,7 +224,7 @@ async def test_client_subscribe(broker, mock_plugin_manager):
mock_plugin_manager.assert_has_calls(
[
call().fire_event(
EVENT_BROKER_CLIENT_SUBSCRIBED,
EventBroker.CLIENT_SUBSCRIBED.value,
client_id=client.session.client_id,
topic="/topic",
qos=QOS_0,
@ -272,7 +261,7 @@ async def test_client_subscribe_twice(broker, mock_plugin_manager):
mock_plugin_manager.assert_has_calls(
[
call().fire_event(
EVENT_BROKER_CLIENT_SUBSCRIBED,
EventBroker.CLIENT_SUBSCRIBED.value,
client_id=client.session.client_id,
topic="/topic",
qos=QOS_0,
@ -306,13 +295,13 @@ async def test_client_unsubscribe(broker, mock_plugin_manager):
mock_plugin_manager.assert_has_calls(
[
call().fire_event(
EVENT_BROKER_CLIENT_SUBSCRIBED,
EventBroker.CLIENT_SUBSCRIBED.value,
client_id=client.session.client_id,
topic="/topic",
qos=QOS_0,
),
call().fire_event(
EVENT_BROKER_CLIENT_UNSUBSCRIBED,
EventBroker.CLIENT_UNSUBSCRIBED.value,
client_id=client.session.client_id,
topic="/topic",
),
@ -337,7 +326,7 @@ async def test_client_publish(broker, mock_plugin_manager):
mock_plugin_manager.assert_has_calls(
[
call().fire_event(
EVENT_BROKER_MESSAGE_RECEIVED,
EventBroker.MESSAGE_RECEIVED.value,
client_id=pub_client.session.client_id,
message=ret_message,
),
@ -509,7 +498,7 @@ async def test_client_publish_big(broker, mock_plugin_manager):
mock_plugin_manager.assert_has_calls(
[
call().fire_event(
EVENT_BROKER_MESSAGE_RECEIVED,
EventBroker.MESSAGE_RECEIVED.value,
client_id=pub_client.session.client_id,
message=ret_message,
),
@ -740,3 +729,16 @@ async def test_broker_broadcast_cancellation(broker):
await _client_publish(topic, data, qos)
message = await asyncio.wait_for(sub_client.deliver_message(), timeout=1)
assert message
@pytest.mark.asyncio
async def test_broker_socket_open_close(broker):
# check that https://github.com/Yakifo/amqtt/issues/86 is fixed
# mqtt 3.1 requires a connect packet, otherwise the socket connection is rejected
static_connect_packet = b'\x10\x1b\x00\x04MQTT\x04\x02\x00<\x00\x0ftest-client-123'
s = socket.create_connection(("127.0.0.1", 1883))
s.send(static_connect_packet)
await asyncio.sleep(0.1)
s.close()

Wyświetl plik

@ -4,10 +4,12 @@ import os
import signal
import subprocess
import tempfile
from unittest.mock import patch
import pytest
import yaml
from amqtt.broker import Broker
from amqtt.mqtt.constants import QOS_0
formatter = "[%(asctime)s] %(name)s {%(filename)s:%(lineno)d} %(levelname)s - %(message)s"
@ -158,38 +160,6 @@ async def test_publish_subscribe(broker):
assert sub_proc.returncode == 0
@pytest.mark.asyncio
async def test_pub_sub_retain(broker):
"""Test various pub/sub will retain options."""
# Test publishing with retain flag
pub_proc = subprocess.run(
[
"amqtt_pub",
"--url", "mqtt://127.0.0.1:1884",
"-t", "topic/test",
"-m", "standard message",
"--will-topic", "topic/retain",
"--will-message", "last will message",
"--will-retain",
],
capture_output=True,
)
assert pub_proc.returncode == 0, f"publisher error code: {pub_proc.returncode}\n{pub_proc.stderr}"
logger.debug("publisher succeeded")
# Verify retained message is received by new subscriber
sub_proc = subprocess.run(
[
"amqtt_sub",
"--url", "mqtt://127.0.0.1:1884",
"-t", "topic/retain",
"-n", "1",
],
capture_output=True,
)
assert sub_proc.returncode == 0, f"subscriber error code: {sub_proc.returncode}\n{sub_proc.stderr}"
assert "last will message" in str(sub_proc.stdout)
@pytest.mark.asyncio
async def test_pub_errors(client_config_file):
"""Test error handling in pub/sub tools."""
@ -275,74 +245,3 @@ async def test_pub_client_config(broker, client_config_file):
logger.debug(f"Stderr: {stderr.decode()}")
assert proc.returncode == 0, f"publisher error code: {proc.returncode}"
@pytest.mark.asyncio
async def test_pub_client_config_will(broker, client_config_file):
# verifying client script functionality of will topic (publisher)
# https://github.com/Yakifo/amqtt/issues/159
await asyncio.sleep(1)
client1 = MQTTClient(client_id="client1")
await client1.connect('mqtt://localhost:1884')
await client1.subscribe([
("test/will/topic", QOS_0)
])
cmd = ["amqtt_pub",
"-t", "test/topic",
"-m", "\"test of regular topic\"",
"-c", client_config_file]
proc = await asyncio.create_subprocess_shell(
" ".join(cmd), stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
)
stdout, stderr = await proc.communicate()
logger.debug(f"Command: {cmd}")
logger.debug(f"Stdout: {stdout.decode()}")
logger.debug(f"Stderr: {stderr.decode()}")
message = await client1.deliver_message(timeout_duration=1)
assert message.topic == 'test/will/topic'
assert message.data == b'client ABC has disconnected'
await client1.disconnect()
@pytest.mark.asyncio
@pytest.mark.timeout(20)
async def test_sub_client_config_will(broker, client_config, client_config_file):
# verifying client script functionality of will topic (subscriber)
# https://github.com/Yakifo/amqtt/issues/159
client1 = MQTTClient(client_id="client1")
await client1.connect('mqtt://localhost:1884')
await client1.subscribe([
("test/will/topic", QOS_0)
])
cmd = ["amqtt_sub",
"-t", "test/topic",
"-c", client_config_file,
"-n", "1"]
proc = await asyncio.create_subprocess_shell(
" ".join(cmd), stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
)
await asyncio.sleep(2)
# cause `amqtt_sub` to exit after receiving this one message
await client1.publish("test/topic", b'my test message')
# validate the 'will' message was received correctly
message = await client1.deliver_message(timeout_duration=3)
assert message.topic == 'test/will/topic'
assert message.data == b'client ABC has disconnected'
await client1.disconnect()
stdout, stderr = await proc.communicate()
logger.debug(f"Command: {cmd}")
logger.debug(f"Stdout: {stdout.decode()}")
logger.debug(f"Stderr: {stderr.decode()}")

Wyświetl plik

@ -1,10 +1,13 @@
import asyncio
import logging
from importlib.metadata import EntryPoint
from unittest.mock import patch
import pytest
from amqtt.broker import Broker
from amqtt.client import MQTTClient
from amqtt.errors import ConnectError
from amqtt.errors import ClientError, ConnectError
from amqtt.mqtt.constants import QOS_0, QOS_1, QOS_2
formatter = "[%(asctime)s] %(name)s {%(filename)s:%(lineno)d} %(levelname)s - %(message)s"
@ -265,27 +268,121 @@ def client_config():
@pytest.mark.asyncio
async def test_client_publish_will_with_retain(broker_fixture, client_config):
async def test_client_will_with_clean_disconnect(broker_fixture):
config = {
"will": {
"topic": "test/will/topic",
"retain": False,
"message": "client ABC has disconnected",
"qos": 1
},
}
# verifying client functionality of will topic
# https://github.com/Yakifo/amqtt/issues/159
client1 = MQTTClient(client_id="client1", config=config)
await client1.connect("mqtt://localhost:1883")
client1 = MQTTClient(client_id="client1")
client2 = MQTTClient(client_id="client2")
await client2.connect("mqtt://localhost:1883")
await client2.subscribe(
[
("test/will/topic", QOS_0),
]
)
await client1.disconnect()
await asyncio.sleep(1)
with pytest.raises(asyncio.TimeoutError):
message = await client2.deliver_message(timeout_duration=2)
# if we do get a message, make sure it's not a will message
assert message.topic != "test/will/topic"
await client2.disconnect()
@pytest.mark.asyncio
async def test_client_will_with_abrupt_disconnect(broker_fixture):
config = {
"will": {
"topic": "test/will/topic",
"retain": False,
"message": "client ABC has disconnected",
"qos": 1
},
}
client1 = MQTTClient(client_id="client1", config=config)
await client1.connect("mqtt://localhost:1883")
client2 = MQTTClient(client_id="client2")
await client2.connect("mqtt://localhost:1883")
await client2.subscribe(
[
("test/will/topic", QOS_0),
]
)
# instead of client.disconnect, call the necessary closing but without sending the disconnect packet
await client1.cancel_tasks()
if client1._disconnect_task and not client1._disconnect_task.done():
client1._disconnect_task.cancel()
client1._connected_state.clear()
await client1._handler.stop()
client1.session.transitions.disconnect()
await asyncio.sleep(1)
message = await client2.deliver_message(timeout_duration=1)
# make sure we receive the will message
assert message.topic == "test/will/topic"
assert message.data == b'client ABC has disconnected'
await client2.disconnect()
@pytest.mark.asyncio
async def test_client_retained_will_with_abrupt_disconnect(broker_fixture):
# verifying client functionality of retained will topic/message
config = {
"will": {
"topic": "test/will/topic",
"retain": True,
"message": "client ABC has disconnected",
"qos": 1
},
}
# first client, connect with retained will message
client1 = MQTTClient(client_id="client1", config=config)
await client1.connect('mqtt://localhost:1883')
await client1.subscribe([
client2 = MQTTClient(client_id="client2")
await client2.connect('mqtt://localhost:1883')
await client2.subscribe([
("test/will/topic", QOS_0)
])
client2 = MQTTClient(client_id="client2", config=client_config)
await client2.connect('mqtt://localhost:1883')
await client2.publish('my/topic', b'my message')
await client2.disconnect()
message = await client1.deliver_message(timeout_duration=1)
# let's abruptly disconnect client1
await client1.cancel_tasks()
if client1._disconnect_task and not client1._disconnect_task.done():
client1._disconnect_task.cancel()
client1._connected_state.clear()
await client1._handler.stop()
client1.session.transitions.disconnect()
await asyncio.sleep(0.5)
# make sure the client which is still connected that we get the 'will' message
message = await client2.deliver_message(timeout_duration=1)
assert message.topic == 'test/will/topic'
assert message.data == b'client ABC has disconnected'
await client1.disconnect()
await client2.disconnect()
# make sure a client which is connected after client1 disconnected still receives the 'will' message from
client3 = MQTTClient(client_id="client3")
await client3.connect('mqtt://localhost:1883')
await client3.subscribe([
@ -295,3 +392,92 @@ async def test_client_publish_will_with_retain(broker_fixture, client_config):
assert message3.topic == 'test/will/topic'
assert message3.data == b'client ABC has disconnected'
await client3.disconnect()
@pytest.mark.asyncio
async def test_client_abruptly_disconnecting_with_empty_will_message(broker_fixture):
config = {
"will": {
"topic": "test/will/topic",
"retain": True,
"message": "",
"qos": 1
},
}
client1 = MQTTClient(client_id="client1", config=config)
await client1.connect('mqtt://localhost:1883')
client2 = MQTTClient(client_id="client2")
await client2.connect('mqtt://localhost:1883')
await client2.subscribe([
("test/will/topic", QOS_0)
])
# let's abruptly disconnect client1
await client1.cancel_tasks()
if client1._disconnect_task and not client1._disconnect_task.done():
client1._disconnect_task.cancel()
client1._connected_state.clear()
await client1._handler.stop()
client1.session.transitions.disconnect()
await asyncio.sleep(0.5)
message = await client2.deliver_message(timeout_duration=1)
assert message.topic == 'test/will/topic'
assert message.data == b''
await client2.disconnect()
async def test_connect_broken_uri():
config = {"auto_reconnect": False}
client = MQTTClient(config=config)
with pytest.raises(ClientError):
await client.connect('"mqtt://someplace')
@pytest.mark.asyncio
async def test_connect_incorrect_scheme():
config = {"auto_reconnect": False}
client = MQTTClient(config=config)
with pytest.raises(ClientError):
await client.connect('"mq://someplace')
async def test_client_no_auth():
class MockEntryPoints:
def select(self, group) -> list[EntryPoint]:
match group:
case 'tests.mock_plugins':
return [
EntryPoint(name='auth_plugin', group='tests.mock_plugins', value='tests.plugins.mocks:NoAuthPlugin'),
]
case _:
return list()
with patch("amqtt.plugins.manager.entry_points", side_effect=MockEntryPoints) as mocked_mqtt_publish:
config = {
"listeners": {
"default": {"type": "tcp", "bind": "127.0.0.1:1883", "max_connections": 10},
},
'sys_interval': 1,
'auth': {
'plugins': ['auth_plugin', ]
}
}
client = MQTTClient(client_id="client1", config={'auto_reconnect': False})
broker = Broker(plugin_namespace='tests.mock_plugins', config=config)
await broker.start()
with pytest.raises(ConnectError):
await client.connect("mqtt://127.0.0.1:1883/")
await broker.shutdown()

Wyświetl plik

@ -6,8 +6,7 @@ from unittest.mock import MagicMock, call, patch
import pytest
from paho.mqtt import client as mqtt_client
from amqtt.broker import EVENT_BROKER_CLIENT_CONNECTED, EVENT_BROKER_CLIENT_DISCONNECTED, EVENT_BROKER_PRE_START, \
EVENT_BROKER_POST_START
from amqtt.broker import EventBroker
from amqtt.client import MQTTClient
from amqtt.mqtt.constants import QOS_1, QOS_2
@ -54,11 +53,11 @@ async def test_paho_connect(broker, mock_plugin_manager):
broker.plugins_manager.assert_has_calls(
[
call.fire_event(
EVENT_BROKER_CLIENT_CONNECTED,
EventBroker.CLIENT_CONNECTED.value,
client_id=client_id,
),
call.fire_event(
EVENT_BROKER_CLIENT_DISCONNECTED,
EventBroker.CLIENT_DISCONNECTED.value,
client_id=client_id,
),
],