From e15286e82f141aaeb8b504eb7288192bae439326 Mon Sep 17 00:00:00 2001 From: Andrew Mirsky Date: Fri, 8 Aug 2025 22:15:46 -0400 Subject: [PATCH] moving unix socket into samples --- amqtt/broker.py | 6 -- amqtt/contexts.py | 11 --- samples/unix_socket_adapters.py | 57 ------------- samples/unix_socket_broker.py | 15 ---- samples/unix_socket_client.py | 27 ------- samples/unix_sockets.py | 138 ++++++++++++++++++++++++++++++++ 6 files changed, 138 insertions(+), 116 deletions(-) delete mode 100644 samples/unix_socket_adapters.py delete mode 100644 samples/unix_socket_broker.py delete mode 100644 samples/unix_socket_client.py create mode 100644 samples/unix_sockets.py diff --git a/amqtt/broker.py b/amqtt/broker.py index 7fc05ca..eba11be 100644 --- a/amqtt/broker.py +++ b/amqtt/broker.py @@ -327,12 +327,6 @@ class Broker: ssl=ssl_context, subprotocols=[websockets.Subprotocol("mqtt")], ) - case ListenerType.UNIX: - return await asyncio.start_unix_server( - partial(self.stream_connected, listener_name=listener_name), - address, - ssl=ssl_context, - loop=self._loop) case _: msg = f"Unsupported listener type: {listener_type}" raise BrokerError(msg) diff --git a/amqtt/contexts.py b/amqtt/contexts.py index 61bbcec..59de8a7 100644 --- a/amqtt/contexts.py +++ b/amqtt/contexts.py @@ -131,17 +131,6 @@ class ListenerConfig(Dictable): msg = f"'{fn}' does not exist : {getattr(self, fn)}" raise FileNotFoundError(msg) - if isinstance(self.bind, Path) and self.type != ListenerType.UNIX: - msg = "bind address can only be a `pathlib.Path` if listener type is unix" - raise ValueError(msg) - - if self.type == ListenerType.UNIX: - if isinstance(self.bind, str): - self.bind = Path(self.bind) - if self.bind and not self.bind.exists(): - msg = f"unix socket : '{self.bind}' does not exist" - raise FileNotFoundError(msg) - def apply(self, other: "ListenerConfig") -> None: """Apply the field from 'other', if 'self' field is default.""" for f in fields(self): diff --git a/samples/unix_socket_adapters.py b/samples/unix_socket_adapters.py deleted file mode 100644 index 35a5f35..0000000 --- a/samples/unix_socket_adapters.py +++ /dev/null @@ -1,57 +0,0 @@ -import contextlib -import logging -from asyncio import StreamWriter, StreamReader, Event - -from amqtt.adapters import ReaderAdapter, WriterAdapter - - -logger = logging.getLogger(__name__) - - -class UnixStreamReaderAdapter(ReaderAdapter): - - def __init__(self, reader: StreamReader): - self._reader = reader - - async def read(self, n:int = -1) -> bytes: - if n < 0: - return await self._reader.read() - return await self._reader.readexactly(n) - - def feed_eof(self): - return self._reader.feed_eof() - - -class UnixStreamWriterAdapter(WriterAdapter): - - def __init__(self, writer: StreamWriter): - self._writer = writer - self.is_closed = Event() - - def write(self, data): - if not self.is_closed: - self._writer.write(data) - - async def drain(self): - if self.is_closed.is_set(): - await self._writer.drain() - - def get_peer_info(self): - extra_info = self._writer.get_extra_info('socket') - return extra_info.getsockname(), 0 - - async def close(self): - if self.is_closed.is_set(): - return - self.is_closed.set() - - await self._writer.drain() - if self._writer.can_write_eof(): - self._writer.write_eof() - - self._writer.close() - - with contextlib.suppress(AttributeError): - await self._writer.wait_closed() - - diff --git a/samples/unix_socket_broker.py b/samples/unix_socket_broker.py deleted file mode 100644 index a3a8b72..0000000 --- a/samples/unix_socket_broker.py +++ /dev/null @@ -1,15 +0,0 @@ -from amqtt.broker import Broker -from amqtt.contexts import BrokerConfig, ListenerConfig - - -async def main(): - - cfg = BrokerConfig( - listeners={ - 'default': ListenerConfig( - ListenerType.External - ) - } - ) - - b = Broker() \ No newline at end of file diff --git a/samples/unix_socket_client.py b/samples/unix_socket_client.py deleted file mode 100644 index 5ac62d4..0000000 --- a/samples/unix_socket_client.py +++ /dev/null @@ -1,27 +0,0 @@ -import asyncio - -from amqtt.client import MQTTClient, ClientContext -from amqtt.contexts import ClientConfig -from amqtt.mqtt.protocol.client_handler import ClientProtocolHandler -from amqtt.plugins.manager import PluginManager -from amqtt.session import Session -from samples.unix_socket_adapters import UnixStreamReaderAdapter, UnixStreamWriterAdapter - - -async def client(): - config = ClientConfig() - context = ClientContext() - context.config = config - plugins_manager = PluginManager("amqtt.client.plugins", context) - - cph = ClientProtocolHandler(plugins_manager) - - s = Session() - r = UnixStreamReaderAdapter() - w = UnixStreamWriterAdapter() - - cph.attach(session=s, reader=r, writer=w) - await cph.mqtt_connect() - -if __name__ == '__main__': - asyncio.run(client()) diff --git a/samples/unix_sockets.py b/samples/unix_sockets.py new file mode 100644 index 0000000..62a3b47 --- /dev/null +++ b/samples/unix_sockets.py @@ -0,0 +1,138 @@ +import contextlib +import logging +import asyncio +from asyncio import StreamWriter, StreamReader, Event +from functools import partial + +import typer + +from amqtt.broker import Broker +from amqtt.client import ClientContext +from amqtt.contexts import ClientConfig, BrokerConfig, ListenerConfig, ListenerType +from amqtt.mqtt.protocol.client_handler import ClientProtocolHandler +from amqtt.plugins.manager import PluginManager +from amqtt.session import Session +from amqtt.adapters import ReaderAdapter, WriterAdapter + +logging.basicConfig(level=logging.DEBUG) +logger = logging.getLogger(__name__) + + +app = typer.Typer(add_completion=False, rich_markup_mode=None) + +class UnixStreamReaderAdapter(ReaderAdapter): + + def __init__(self, reader: StreamReader): + self._reader = reader + + async def read(self, n:int = -1) -> bytes: + if n < 0: + return await self._reader.read() + return await self._reader.readexactly(n) + + def feed_eof(self): + return self._reader.feed_eof() + + +class UnixStreamWriterAdapter(WriterAdapter): + + def __init__(self, writer: StreamWriter): + self._writer = writer + self.is_closed = Event() + + def write(self, data): + if not self.is_closed.is_set(): + self._writer.write(data) + + async def drain(self): + if self.is_closed.is_set(): + await self._writer.drain() + + def get_peer_info(self): + extra_info = self._writer.get_extra_info('socket') + return extra_info.getsockname(), 0 + + async def close(self): + if self.is_closed.is_set(): + return + self.is_closed.set() + + await self._writer.drain() + if self._writer.can_write_eof(): + self._writer.write_eof() + + self._writer.close() + + with contextlib.suppress(AttributeError): + await self._writer.wait_closed() + + + +async def run_broker(): + + cfg = BrokerConfig( + listeners={ + 'default': ListenerConfig( + type=ListenerType.EXTERNAL + ) + }, + plugins={ + "amqtt.plugins.logging_amqtt.EventLoggerPlugin":{}, + "amqtt.plugins.logging_amqtt.PacketLoggerPlugin":{}, + "amqtt.plugins.authentication.AnonymousAuthPlugin":{"allow_anonymous":True}, + } + ) + + b = Broker(cfg) + + async def unix_stream_connected(reader, writer, listener_name): + logger.info("received new unix connection....") + r = UnixStreamReaderAdapter(reader) + w = UnixStreamWriterAdapter(writer) + await b.external_connected(reader=r, writer=w, listener_name='default') + + await asyncio.start_unix_server(partial(unix_stream_connected, listener_name='default'), path="/tmp/mqtt") + await b.start() + + try: + logger.info("starting mqtt unix server") + while True: + await asyncio.sleep(1) + except KeyboardInterrupt: + await b.shutdown() + +@app.command() +def broker(): + asyncio.run(run_broker()) + +async def run_client(): + config = ClientConfig() + context = ClientContext() + context.config = config + plugins_manager = PluginManager("amqtt.client.plugins", context) + + cph = ClientProtocolHandler(plugins_manager) + conn_reader, conn_writer = await asyncio.open_unix_connection(path="/tmp/mqtt") + s = Session() + s.client_id = "myUnixClientID" + r = UnixStreamReaderAdapter(conn_reader) + w = UnixStreamWriterAdapter(conn_writer) + cph.attach(session=s, reader=r, writer=w) + logger.debug("handler attached") + ret = await cph.mqtt_connect() + logger.info(f"client connected: {ret}") + + try: + while True: + await cph.mqtt_publish('my/topic', b'my message', 0, False) + await asyncio.sleep(1) + except KeyboardInterrupt: + cph.detach() + +@app.command() +def client(): + asyncio.run(run_client()) + + +if __name__ == "__main__": + app()