From 0e318671ddb1795b8cb3caa7128da4e46fcdeefd Mon Sep 17 00:00:00 2001 From: Andrew Mirsky Date: Sat, 9 Aug 2025 15:37:33 -0400 Subject: [PATCH] Sample: broker and client communicating with mqtt over unix socket (#291) * Yakifo/aqmtt#290 : create a sample that implements mqtt over unix socket. documentation and test case. --- amqtt/broker.py | 43 +++++----- amqtt/contexts.py | 5 ++ samples/unix_sockets.py | 179 ++++++++++++++++++++++++++++++++++++++++ tests/test_samples.py | 32 +++++++ 4 files changed, 239 insertions(+), 20 deletions(-) create mode 100644 samples/unix_sockets.py diff --git a/amqtt/broker.py b/amqtt/broker.py index 5354391..02ae1c5 100644 --- a/amqtt/broker.py +++ b/amqtt/broker.py @@ -291,7 +291,7 @@ class Broker: msg = f"Invalid port value in bind value: {listener['bind']}" raise BrokerError(msg) from e - instance = await self._create_server_instance(listener_name, listener["type"], address, port, ssl_context) + instance = await self._create_server_instance(listener_name, listener.type, address, port, ssl_context) self._servers[listener_name] = Server(listener_name, instance, max_connections) self.logger.info(f"Listener '{listener_name}' bind to {listener['bind']} (max_connections={max_connections})") @@ -319,30 +319,33 @@ class Broker: async def _create_server_instance( self, listener_name: str, - listener_type: str, + listener_type: ListenerType, address: str | None, port: int, ssl_context: ssl.SSLContext | None, ) -> asyncio.Server | websockets.asyncio.server.Server: """Create a server instance for a listener.""" - if listener_type == "tcp": - return await asyncio.start_server( - partial(self.stream_connected, listener_name=listener_name), - address, - port, - reuse_address=True, - ssl=ssl_context, - ) - if listener_type == "ws": - return await websockets.serve( - partial(self.ws_connected, listener_name=listener_name), - address, - port, - ssl=ssl_context, - subprotocols=[websockets.Subprotocol("mqtt")], - ) - msg = f"Unsupported listener type: {listener_type}" - raise BrokerError(msg) + + match listener_type: + case ListenerType.TCP: + return await asyncio.start_server( + partial(self.stream_connected, listener_name=listener_name), + address, + port, + reuse_address=True, + ssl=ssl_context, + ) + case ListenerType.WS: + return await websockets.serve( + partial(self.ws_connected, listener_name=listener_name), + address, + port, + ssl=ssl_context, + subprotocols=[websockets.Subprotocol("mqtt")], + ) + case _: + msg = f"Unsupported listener type: {listener_type}" + raise BrokerError(msg) async def _session_monitor(self) -> None: diff --git a/amqtt/contexts.py b/amqtt/contexts.py index 722318b..ff24e7e 100644 --- a/amqtt/contexts.py +++ b/amqtt/contexts.py @@ -127,6 +127,9 @@ class ListenerConfig(Dictable): for fn in ("cafile", "capath", "certfile", "keyfile"): if isinstance(getattr(self, fn), str): setattr(self, fn, Path(getattr(self, fn))) + if getattr(self, fn) and not getattr(self, fn).exists(): + msg = f"'{fn}' does not exist : {getattr(self, fn)}" + raise FileNotFoundError(msg) def apply(self, other: "ListenerConfig") -> None: """Apply the field from 'other', if 'self' field is default.""" @@ -134,12 +137,14 @@ class ListenerConfig(Dictable): if getattr(self, f.name) == f.default: setattr(self, f.name, other[f.name]) + def default_listeners() -> dict[str, Any]: """Create defaults for BrokerConfig.listeners.""" return { "default": ListenerConfig() } + def default_broker_plugins() -> dict[str, Any]: """Create defaults for BrokerConfig.plugins.""" return { diff --git a/samples/unix_sockets.py b/samples/unix_sockets.py new file mode 100644 index 0000000..6bafd97 --- /dev/null +++ b/samples/unix_sockets.py @@ -0,0 +1,179 @@ +import contextlib +import logging +import asyncio +from asyncio import StreamWriter, StreamReader, Event +from functools import partial +from pathlib import Path + +import typer + +from amqtt.broker import Broker +from amqtt.client import ClientContext +from amqtt.contexts import ClientConfig, BrokerConfig, ListenerConfig, ListenerType +from amqtt.mqtt.protocol.client_handler import ClientProtocolHandler +from amqtt.plugins.manager import PluginManager +from amqtt.session import Session +from amqtt.adapters import ReaderAdapter, WriterAdapter + +logger = logging.getLogger(__name__) + + +app = typer.Typer(add_completion=False, rich_markup_mode=None) + +# Usage: unix_sockets.py [OPTIONS] COMMAND [ARGS]... +# +# Options: +# --help Show this message and exit. +# +# Commands: +# broker Run an mqtt broker that communicates over a unix (file) socket. +# client Run an mqtt client that communicates over a unix (file) socket. + + +class UnixStreamReaderAdapter(ReaderAdapter): + + def __init__(self, reader: StreamReader) -> None: + self._reader = reader + + async def read(self, n:int = -1) -> bytes: + if n < 0: + return await self._reader.read() + return await self._reader.readexactly(n) + + def feed_eof(self) -> None: + return self._reader.feed_eof() + + +class UnixStreamWriterAdapter(WriterAdapter): + + def __init__(self, writer: StreamWriter) -> None: + self._writer = writer + self.is_closed = Event() + + def write(self, data: bytes) -> None: + if not self.is_closed.is_set(): + self._writer.write(data) + + async def drain(self) -> None: + if self.is_closed.is_set(): + await self._writer.drain() + + def get_peer_info(self) -> tuple[str, int]: + extra_info = self._writer.get_extra_info('socket') + return extra_info.getsockname(), 0 + + async def close(self) -> None: + if self.is_closed.is_set(): + return + self.is_closed.set() + + await self._writer.drain() + if self._writer.can_write_eof(): + self._writer.write_eof() + + self._writer.close() + + with contextlib.suppress(AttributeError): + await self._writer.wait_closed() + + +async def run_broker(socket_file: Path): + + # configure the broker with a single, external listener + cfg = BrokerConfig( + listeners={ + 'default': ListenerConfig( + type=ListenerType.EXTERNAL + ) + }, + plugins={ + "amqtt.plugins.logging_amqtt.EventLoggerPlugin":{}, + "amqtt.plugins.logging_amqtt.PacketLoggerPlugin":{}, + "amqtt.plugins.authentication.AnonymousAuthPlugin":{"allow_anonymous":True}, + } + ) + + b = Broker(cfg) + + # new connection handler + async def unix_stream_connected(reader: StreamReader, writer: StreamWriter, listener_name: str): + logger.info("received new unix connection....") + # wraps the reader/writer in a compatible interface + r = UnixStreamReaderAdapter(reader) + w = UnixStreamWriterAdapter(writer) + + # passes the connection to the broker for protocol communications + await b.external_connected(reader=r, writer=w, listener_name=listener_name) + + await asyncio.start_unix_server(partial(unix_stream_connected, listener_name='default'), path=socket_file) + await b.start() + + try: + logger.info("starting mqtt unix server") + # run until ctrl-c + while True: + await asyncio.sleep(1) + except KeyboardInterrupt: + await b.shutdown() + + +@app.command() +def broker( + socket_file: str | None = typer.Option("/tmp/mqtt", "-s", "--socket", help="path and file for unix socket"), + verbose: bool = typer.Option(False, "-v", "--verbose", help="set logging level to DEBUG"), +): + """Run an mqtt broker that communicates over a unix (file) socket.""" + logging.basicConfig(level=logging.DEBUG if verbose else logging.INFO) + asyncio.run(run_broker(Path(socket_file))) + + +async def run_client(socket_file: Path): + # 'MQTTClient' establishes the connection but uses the ClientProtocolHandler for MQTT protocol communications + + # create a plugin manager + config = ClientConfig() + context = ClientContext() + context.config = config + plugins_manager = PluginManager("amqtt.client.plugins", context) + + # create a client protocol handler + cph = ClientProtocolHandler(plugins_manager) + + # connect to the unix socket + conn_reader, conn_writer = await asyncio.open_unix_connection(path=socket_file) + + # anonymous session connection just needs a client_id + s = Session() + s.client_id = "myUnixClientID" + + # wraps the reader/writer in compatible interface + r = UnixStreamReaderAdapter(conn_reader) + w = UnixStreamWriterAdapter(conn_writer) + + # pass the connection to the protocol handler for mqtt communications and initiate CONNECT/CONNACK + cph.attach(session=s, reader=r, writer=w) + logger.debug("handler attached") + ret = await cph.mqtt_connect() + logger.info(f"client connected: {ret}") + + try: + while True: + # periodically send a message + await cph.mqtt_publish('my/topic', b'my message', 0, False) + await asyncio.sleep(1) + except KeyboardInterrupt: + cph.detach() + + +@app.command() +def client( + socket_file: str | None = typer.Option("/tmp/mqtt", "-s", "--socket", help="path and file for unix socket"), + verbose: bool = typer.Option(False, "-v", "--verbose", help="set logging level to DEBUG"), +): + """Run an mqtt client that communicates over a unix (file) socket.""" + logging.basicConfig(level=logging.DEBUG if verbose else logging.INFO) + asyncio.run(run_client(Path(socket_file))) + + +if __name__ == "__main__": + app() diff --git a/tests/test_samples.py b/tests/test_samples.py index 26e4ce8..1ddf9b9 100644 --- a/tests/test_samples.py +++ b/tests/test_samples.py @@ -5,7 +5,11 @@ import subprocess from multiprocessing import Process from pathlib import Path + +from typer.testing import CliRunner + from samples.http_server_integration import main as http_server_main +from samples.unix_sockets import app as unix_sockets_app import pytest @@ -302,3 +306,31 @@ async def test_external_http_server(external_http_server): await client.disconnect() # Send the interrupt signal await asyncio.sleep(1) + + +@pytest.mark.asyncio +async def test_unix_connection(): + + unix_socket_script = Path(__file__).parent.parent / "samples/unix_sockets.py" + broker_process = subprocess.Popen(["python", unix_socket_script, "broker", "-s", "/tmp/mqtt"], stdout=subprocess.PIPE, stderr=subprocess.PIPE) + # start the broker + await asyncio.sleep(1) + + # start the client + client_process = subprocess.Popen(["python", unix_socket_script, "client", "-s", "/tmp/mqtt"], stdout=subprocess.PIPE, stderr=subprocess.PIPE) + + await asyncio.sleep(3) + + # stop the client (ctrl-c) + client_process.send_signal(signal.SIGINT) + _ = client_process.communicate() + + # stop the broker (ctrl-c) + broker_process.send_signal(signal.SIGINT) + broker_stdout, broker_stderr = broker_process.communicate() + + logger.debug(broker_stderr.decode("utf-8")) + + # verify that the broker received client connected/disconnected + assert "on_broker_client_connected" in broker_stderr.decode("utf-8") + assert "on_broker_client_disconnected" in broker_stderr.decode("utf-8")