From 76651e32f3a21569424e59a9c7d25728fa37f2ea Mon Sep 17 00:00:00 2001 From: Andrew Mirsky Date: Mon, 4 Aug 2025 16:27:51 -0400 Subject: [PATCH 1/5] websockets via aiohttp --- amqtt/broker.py | 30 +++++++-- amqtt/codecs_amqtt.py | 2 + amqtt/contexts.py | 1 + samples/web.py | 145 ++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 174 insertions(+), 4 deletions(-) create mode 100644 samples/web.py diff --git a/amqtt/broker.py b/amqtt/broker.py index b89d45a..3bbc8c0 100644 --- a/amqtt/broker.py +++ b/amqtt/broker.py @@ -10,6 +10,9 @@ import ssl import time from typing import Any, ClassVar, TypeAlias +from aiohttp.pytest_plugin import AiohttpServer +from aiohttp.web_ws import WebSocketResponse +from pygments.token import Other from transitions import Machine, MachineError import websockets.asyncio.server from websockets.asyncio.server import ServerConnection @@ -36,7 +39,7 @@ from .plugins.manager import PluginManager _BROADCAST: TypeAlias = dict[str, Session | str | bytes | bytearray | int | None] # Default port numbers -DEFAULT_PORTS = {"tcp": 1883, "ws": 8883} +DEFAULT_PORTS = {"tcp": 1883, "ws": 8883, 'aiohttp': 8080} AMQTT_MAGIC_VALUE_RET_SUBSCRIBED = 0x80 @@ -89,6 +92,20 @@ class Server: await self.instance.wait_closed() +class OtherServer(Server): + def __init__(self): + super().__init__('aiohttp', None) + + async def acquire_connection(self) -> None: + pass + + async def release_connection(self) -> None: + pass + + async def close_instance(self) -> None: + pass + + class BrokerContext(BaseContext): """BrokerContext is used as the context passed to plugins interacting with the broker. @@ -249,8 +266,11 @@ class Broker: msg = f"Invalid port value in bind value: {listener['bind']}" raise BrokerError(msg) from e - instance = await self._create_server_instance(listener_name, listener["type"], address, port, ssl_context) - self._servers[listener_name] = Server(listener_name, instance, max_connections) + if listener["type"] == 'aiohttp': + self._servers[listener_name] = OtherServer() + else: + instance = await self._create_server_instance(listener_name, listener["type"], address, port, ssl_context) + self._servers[listener_name] = Server(listener_name, instance, max_connections) self.logger.info(f"Listener '{listener_name}' bind to {listener['bind']} (max_connections={max_connections})") @@ -281,7 +301,7 @@ class Broker: address: str | None, port: int, ssl_context: ssl.SSLContext | None, - ) -> asyncio.Server | websockets.asyncio.server.Server: + ) -> asyncio.Server | websockets.asyncio.server.Server | OtherServer: """Create a server instance for a listener.""" if listener_type == "tcp": return await asyncio.start_server( @@ -299,6 +319,8 @@ class Broker: ssl=ssl_context, subprotocols=[websockets.Subprotocol("mqtt")], ) + if listener_type == "aiohttp": + return OtherServer() msg = f"Unsupported listener type: {listener_type}" raise BrokerError(msg) diff --git a/amqtt/codecs_amqtt.py b/amqtt/codecs_amqtt.py index 1db2d9c..20bb536 100644 --- a/amqtt/codecs_amqtt.py +++ b/amqtt/codecs_amqtt.py @@ -75,8 +75,10 @@ async def decode_string(reader: ReaderAdapter | asyncio.StreamReader) -> str: length_bytes = await read_or_raise(reader, 2) if len(length_bytes) < 1: raise ZeroLengthReadError + print(f"attempting to decode {length_bytes}") str_length = unpack("!H", length_bytes)[0] if str_length: + print(f"reading {str_length} bytes") byte_str = await read_or_raise(reader, str_length) try: return byte_str.decode(encoding="utf-8") diff --git a/amqtt/contexts.py b/amqtt/contexts.py index 1b37153..1bd733c 100644 --- a/amqtt/contexts.py +++ b/amqtt/contexts.py @@ -44,6 +44,7 @@ class ListenerType(StrEnum): TCP = "tcp" WS = "ws" + AIOHTTP = "aiohttp" def __repr__(self) -> str: """Display the string value, instead of the enum member.""" diff --git a/samples/web.py b/samples/web.py new file mode 100644 index 0000000..e54d4e5 --- /dev/null +++ b/samples/web.py @@ -0,0 +1,145 @@ +import asyncio +import logging + +import aiohttp +from aiohttp import web + +from amqtt.adapters import ReaderAdapter, WriterAdapter +from amqtt.broker import Broker +from amqtt.contexts import BrokerConfig, ListenerConfig, ListenerType +from amqtt.errors import BrokerError + + +async def hello(request): + return web.Response(text="Hello, world") + + + +class AIOWebSocketsReader(ReaderAdapter): + def __init__(self, ws: web.WebSocketResponse): + self.ws = ws + self.buffer = bytearray() + self.closed = False + + async def read(self, n: int = -1) -> bytes: + print(f"attempting to read {n} bytes") + while not self.buffer and not self.closed or len(self.buffer) < n: + msg = await self.ws.receive() + if msg.type == aiohttp.WSMsgType.BINARY: + self.buffer.extend(msg.data) + elif msg.type == aiohttp.WSMsgType.CLOSE: + self.closed = True + print(f"buffer size: {len(self.buffer)}") + if n == -1: + result = bytes(self.buffer) + self.buffer.clear() + else: + result = self.buffer[:n] + del self.buffer[:n] + print(f"bytes: {result}") + return result + + def feed_eof(self) -> None: + pass + +class AIOWebSocketsWriter(WriterAdapter): + + def __init__(self, ws: web.WebSocketResponse): + super().__init__() + self.ws = ws + + def write(self, data: bytes) -> None: + print(f"broker wants to write data: {data}") + self.ws.send_bytes(data) + + + async def drain(self) -> None: + pass + + def get_peer_info(self) -> tuple[str, int] | None: + return "aiohttp", 1234567 + + async def close(self) -> None: + pass + + +async def websocket_handler(request): + print() + ws = web.WebSocketResponse(protocols=['mqtt',]) + await ws.prepare(request) + # + # readQ = asyncio.Queue() + # writeQ = asyncio.Queue() + # + # async def receiver(): + # async for msg in ws: + # match msg.type: + # case aiohttp.WSMsgType.BINARY: + # readQ.put_nowait(msg.data) + # case _: + # return + # + # async def send_items(): + # while not ws.closed: + # if not writeQ.empty(): + # item = await writeQ.get() + # await ws.send_bytes(item) + # + # await asyncio.create_task(send_items()) + # + b: Broker = request.app['broker'] + await b._client_connected('aiohttp', AIOWebSocketsReader(ws), AIOWebSocketsWriter(ws)) + + # async for msg in ws: + # print(f"ws: {msg}") + + + + # + # elif msg.type == aiohttp.WSMsgType.ERROR: + # print('ws connection closed with exception %s' % + # ws.exception()) + + print('websocket connection closed') + + return ws + + +def main(): + + app = web.Application() + app.add_routes( + [ + web.get('/', hello), + web.get('/ws', websocket_handler) + ]) + app.cleanup_ctx.append(run_broker) + web.run_app(app) + + +async def run_broker(_app): + loop = asyncio.get_event_loop() + + cfg = BrokerConfig( + listeners={ + 'default':ListenerConfig(type=ListenerType.WS, bind='127.0.0.1:8883'), + 'aiohttp': ListenerConfig(type=ListenerType.AIOHTTP), + } + ) + + + + broker = Broker(config=cfg, loop=loop) + _app['broker'] = broker + await broker.start() + + yield + + await broker.shutdown() + + + + +if __name__ == '__main__': + logging.basicConfig(level=logging.INFO) + main() From ad4854df007baf1d191420c50cd078ae661fc7d2 Mon Sep 17 00:00:00 2001 From: Andrew Mirsky Date: Mon, 4 Aug 2025 17:23:18 -0400 Subject: [PATCH 2/5] Yakifo/amqtt#73 : working example of using aiohttp server to receive the websocket connection --- amqtt/broker.py | 2 +- amqtt/mqtt/protocol/handler.py | 5 +++++ samples/web.py | 18 ++++++++++++------ tests/test_paho.py | 26 ++++++++++++++++++++++---- 4 files changed, 40 insertions(+), 11 deletions(-) diff --git a/amqtt/broker.py b/amqtt/broker.py index 3bbc8c0..028af12 100644 --- a/amqtt/broker.py +++ b/amqtt/broker.py @@ -99,7 +99,7 @@ class OtherServer(Server): async def acquire_connection(self) -> None: pass - async def release_connection(self) -> None: + def release_connection(self) -> None: pass async def close_instance(self) -> None: diff --git a/amqtt/mqtt/protocol/handler.py b/amqtt/mqtt/protocol/handler.py index 66bc1cb..8c53dc1 100644 --- a/amqtt/mqtt/protocol/handler.py +++ b/amqtt/mqtt/protocol/handler.py @@ -1,4 +1,5 @@ import asyncio +import traceback try: from asyncio import InvalidStateError, QueueFull, QueueShutDown @@ -535,6 +536,10 @@ class ProtocolHandler(Generic[C]): self.handle_read_timeout() except NoDataError: self.logger.debug(f"{self.session.client_id} No data available") + except RuntimeError: + self.logger.debug(f"{self.session.client_id} websocket closed") + traceback.print_exc() + break except Exception as e: # noqa: BLE001 self.logger.warning(f"{type(self).__name__} Unhandled exception in reader coro: {e!r}") break diff --git a/samples/web.py b/samples/web.py index e54d4e5..5f8bc68 100644 --- a/samples/web.py +++ b/samples/web.py @@ -1,4 +1,5 @@ import asyncio +import io import logging import aiohttp @@ -19,16 +20,18 @@ class AIOWebSocketsReader(ReaderAdapter): def __init__(self, ws: web.WebSocketResponse): self.ws = ws self.buffer = bytearray() - self.closed = False async def read(self, n: int = -1) -> bytes: print(f"attempting to read {n} bytes") - while not self.buffer and not self.closed or len(self.buffer) < n: + while not self.buffer or len(self.buffer) < n: + if self.ws.closed: + raise BrokenPipeError() msg = await self.ws.receive() if msg.type == aiohttp.WSMsgType.BINARY: self.buffer.extend(msg.data) elif msg.type == aiohttp.WSMsgType.CLOSE: - self.closed = True + print("received a close message!") + break print(f"buffer size: {len(self.buffer)}") if n == -1: result = bytes(self.buffer) @@ -47,14 +50,17 @@ class AIOWebSocketsWriter(WriterAdapter): def __init__(self, ws: web.WebSocketResponse): super().__init__() self.ws = ws + self._stream = io.BytesIO(b"") def write(self, data: bytes) -> None: print(f"broker wants to write data: {data}") - self.ws.send_bytes(data) - + self._stream.write(data) async def drain(self) -> None: - pass + data = self._stream.getvalue() + if data and len(data): + await self.ws.send_bytes(data) + self._stream = io.BytesIO(b"") def get_peer_info(self) -> tuple[str, int] | None: return "aiohttp", 1234567 diff --git a/tests/test_paho.py b/tests/test_paho.py index be53b2a..0e4fb3c 100644 --- a/tests/test_paho.py +++ b/tests/test_paho.py @@ -4,7 +4,7 @@ import random from unittest.mock import MagicMock, call, patch import pytest -from paho.mqtt import client as mqtt_client +from paho.mqtt import client as paho_client from amqtt.events import BrokerEvents from amqtt.client import MQTTClient @@ -40,7 +40,7 @@ async def test_paho_connect(broker, mock_plugin_manager): assert rc == 0, f"Disconnect failed with result code {rc}" test_complete.set() - test_client = mqtt_client.Client(mqtt_client.CallbackAPIVersion.VERSION2, client_id=client_id) + test_client = paho_client.Client(paho_client.CallbackAPIVersion.VERSION2, client_id=client_id) test_client.enable_logger(paho_logger) test_client.on_connect = on_connect @@ -76,7 +76,7 @@ async def test_paho_qos1(broker, mock_plugin_manager): port = 1883 client_id = f'python-mqtt-{random.randint(0, 1000)}' - test_client = mqtt_client.Client(mqtt_client.CallbackAPIVersion.VERSION2, client_id=client_id) + test_client = paho_client.Client(paho_client.CallbackAPIVersion.VERSION2, client_id=client_id) test_client.enable_logger(paho_logger) test_client.connect(host, port) @@ -107,7 +107,7 @@ async def test_paho_qos2(broker, mock_plugin_manager): port = 1883 client_id = f'python-mqtt-{random.randint(0, 1000)}' - test_client = mqtt_client.Client(mqtt_client.CallbackAPIVersion.VERSION2, client_id=client_id) + test_client = paho_client.Client(paho_client.CallbackAPIVersion.VERSION2, client_id=client_id) test_client.enable_logger(paho_logger) test_client.connect(host, port) @@ -124,3 +124,21 @@ async def test_paho_qos2(broker, mock_plugin_manager): assert message.data == b"test message" await sub_client.disconnect() await asyncio.sleep(0.1) + +async def test_paho_ws(): + client_id = 'websocket_client_1' + test_client = paho_client.Client(callback_api_version=paho_client.CallbackAPIVersion.VERSION2, + transport='websockets', + client_id=client_id) + + test_client.ws_set_options('/ws') + + test_client.connect('localhost', 8080) + test_client.loop_start() + await asyncio.sleep(0.1) + test_client.publish("/qos2", "test message", qos=2) + test_client.publish("/qos2", "test message", qos=2) + test_client.publish("/qos2", "test message", qos=2) + test_client.publish("/qos2", "test message", qos=2) + await asyncio.sleep(0.1) + test_client.loop_stop() \ No newline at end of file From c1625e264e1e130956ffebf8e09b5b5042130c07 Mon Sep 17 00:00:00 2001 From: Andrew Mirsky Date: Mon, 4 Aug 2025 17:47:29 -0400 Subject: [PATCH 3/5] Yakifo/amqtt#73 : removal of wip code --- amqtt/broker.py | 32 +++++++++++++---------------- amqtt/codecs_amqtt.py | 2 -- amqtt/contexts.py | 8 +++++++- samples/web.py | 47 ++++++------------------------------------- 4 files changed, 27 insertions(+), 62 deletions(-) diff --git a/amqtt/broker.py b/amqtt/broker.py index 028af12..1a797d6 100644 --- a/amqtt/broker.py +++ b/amqtt/broker.py @@ -10,9 +10,6 @@ import ssl import time from typing import Any, ClassVar, TypeAlias -from aiohttp.pytest_plugin import AiohttpServer -from aiohttp.web_ws import WebSocketResponse -from pygments.token import Other from transitions import Machine, MachineError import websockets.asyncio.server from websockets.asyncio.server import ServerConnection @@ -25,7 +22,7 @@ from amqtt.adapters import ( WebSocketsWriter, WriterAdapter, ) -from amqtt.contexts import Action, BaseContext, BrokerConfig, ListenerConfig +from amqtt.contexts import Action, BaseContext, BrokerConfig, ListenerConfig, ListenerType from amqtt.errors import AMQTTError, BrokerError, MQTTError, NoDataError from amqtt.mqtt.protocol.broker_handler import BrokerProtocolHandler from amqtt.session import ApplicationMessage, OutgoingApplicationMessage, Session @@ -39,7 +36,7 @@ from .plugins.manager import PluginManager _BROADCAST: TypeAlias = dict[str, Session | str | bytes | bytearray | int | None] # Default port numbers -DEFAULT_PORTS = {"tcp": 1883, "ws": 8883, 'aiohttp': 8080} +DEFAULT_PORTS = {"tcp": 1883, "ws": 8883} AMQTT_MAGIC_VALUE_RET_SUBSCRIBED = 0x80 @@ -92,7 +89,7 @@ class Server: await self.instance.wait_closed() -class OtherServer(Server): +class ExternalServer(Server): def __init__(self): super().__init__('aiohttp', None) @@ -260,19 +257,20 @@ class Broker: max_connections = listener.get("max_connections", -1) ssl_context = self._create_ssl_context(listener) if listener.get("ssl", False) else None - try: - address, port = self._split_bindaddr_port(listener["bind"], DEFAULT_PORTS[listener["type"]]) - except ValueError as e: - msg = f"Invalid port value in bind value: {listener['bind']}" - raise BrokerError(msg) from e - - if listener["type"] == 'aiohttp': - self._servers[listener_name] = OtherServer() + if listener.type == ListenerType.EXTERNAL: + self.logger.info(f"External listener exists for '{listener_name}' ") + self._servers[listener_name] = ExternalServer() else: + try: + address, port = self._split_bindaddr_port(listener["bind"], DEFAULT_PORTS[listener["type"]]) + except ValueError as e: + msg = f"Invalid port value in bind value: {listener['bind']}" + raise BrokerError(msg) from e + instance = await self._create_server_instance(listener_name, listener["type"], address, port, ssl_context) self._servers[listener_name] = Server(listener_name, instance, max_connections) - self.logger.info(f"Listener '{listener_name}' bind to {listener['bind']} (max_connections={max_connections})") + self.logger.info(f"Listener '{listener_name}' bind to {listener['bind']} (max_connections={max_connections})") @staticmethod def _create_ssl_context(listener: ListenerConfig) -> ssl.SSLContext: @@ -301,7 +299,7 @@ class Broker: address: str | None, port: int, ssl_context: ssl.SSLContext | None, - ) -> asyncio.Server | websockets.asyncio.server.Server | OtherServer: + ) -> asyncio.Server | websockets.asyncio.server.Server: """Create a server instance for a listener.""" if listener_type == "tcp": return await asyncio.start_server( @@ -319,8 +317,6 @@ class Broker: ssl=ssl_context, subprotocols=[websockets.Subprotocol("mqtt")], ) - if listener_type == "aiohttp": - return OtherServer() msg = f"Unsupported listener type: {listener_type}" raise BrokerError(msg) diff --git a/amqtt/codecs_amqtt.py b/amqtt/codecs_amqtt.py index 20bb536..1db2d9c 100644 --- a/amqtt/codecs_amqtt.py +++ b/amqtt/codecs_amqtt.py @@ -75,10 +75,8 @@ async def decode_string(reader: ReaderAdapter | asyncio.StreamReader) -> str: length_bytes = await read_or_raise(reader, 2) if len(length_bytes) < 1: raise ZeroLengthReadError - print(f"attempting to decode {length_bytes}") str_length = unpack("!H", length_bytes)[0] if str_length: - print(f"reading {str_length} bytes") byte_str = await read_or_raise(reader, str_length) try: return byte_str.decode(encoding="utf-8") diff --git a/amqtt/contexts.py b/amqtt/contexts.py index 1bd733c..b397d62 100644 --- a/amqtt/contexts.py +++ b/amqtt/contexts.py @@ -44,7 +44,7 @@ class ListenerType(StrEnum): TCP = "tcp" WS = "ws" - AIOHTTP = "aiohttp" + EXTERNAL = "external" def __repr__(self) -> str: """Display the string value, instead of the enum member.""" @@ -115,6 +115,8 @@ class ListenerConfig(Dictable): certificates needed to establish the certificate's authenticity.)""" keyfile: str | Path | None = None """Full path to file in PEM format containing the server's private key.""" + reader: str | None = None + writer: str | None = None def __post_init__(self) -> None: """Check config for errors and transform fields for easier use.""" @@ -126,6 +128,10 @@ class ListenerConfig(Dictable): if isinstance(getattr(self, fn), str): setattr(self, fn, Path(getattr(self, fn))) + # if self.type == ListenerType.EXTERNAL and not all([self.reader, self.writer]): + # msg = "external type requires specifying reader, writer and server classes" + # raise ValueError(msg) + def apply(self, other: "ListenerConfig") -> None: """Apply the field from 'other', if 'self' field is default.""" for f in fields(self): diff --git a/samples/web.py b/samples/web.py index 5f8bc68..552953f 100644 --- a/samples/web.py +++ b/samples/web.py @@ -14,15 +14,12 @@ from amqtt.errors import BrokerError async def hello(request): return web.Response(text="Hello, world") - - class AIOWebSocketsReader(ReaderAdapter): def __init__(self, ws: web.WebSocketResponse): self.ws = ws self.buffer = bytearray() async def read(self, n: int = -1) -> bytes: - print(f"attempting to read {n} bytes") while not self.buffer or len(self.buffer) < n: if self.ws.closed: raise BrokenPipeError() @@ -30,16 +27,15 @@ class AIOWebSocketsReader(ReaderAdapter): if msg.type == aiohttp.WSMsgType.BINARY: self.buffer.extend(msg.data) elif msg.type == aiohttp.WSMsgType.CLOSE: - print("received a close message!") - break - print(f"buffer size: {len(self.buffer)}") + raise BrokenPipeError() + if n == -1: result = bytes(self.buffer) self.buffer.clear() else: result = self.buffer[:n] del self.buffer[:n] - print(f"bytes: {result}") + return result def feed_eof(self) -> None: @@ -53,7 +49,6 @@ class AIOWebSocketsWriter(WriterAdapter): self._stream = io.BytesIO(b"") def write(self, data: bytes) -> None: - print(f"broker wants to write data: {data}") self._stream.write(data) async def drain(self) -> None: @@ -63,7 +58,7 @@ class AIOWebSocketsWriter(WriterAdapter): self._stream = io.BytesIO(b"") def get_peer_info(self) -> tuple[str, int] | None: - return "aiohttp", 1234567 + return "external", 0 async def close(self) -> None: pass @@ -73,39 +68,9 @@ async def websocket_handler(request): print() ws = web.WebSocketResponse(protocols=['mqtt',]) await ws.prepare(request) - # - # readQ = asyncio.Queue() - # writeQ = asyncio.Queue() - # - # async def receiver(): - # async for msg in ws: - # match msg.type: - # case aiohttp.WSMsgType.BINARY: - # readQ.put_nowait(msg.data) - # case _: - # return - # - # async def send_items(): - # while not ws.closed: - # if not writeQ.empty(): - # item = await writeQ.get() - # await ws.send_bytes(item) - # - # await asyncio.create_task(send_items()) - # + b: Broker = request.app['broker'] await b._client_connected('aiohttp', AIOWebSocketsReader(ws), AIOWebSocketsWriter(ws)) - - # async for msg in ws: - # print(f"ws: {msg}") - - - - # - # elif msg.type == aiohttp.WSMsgType.ERROR: - # print('ws connection closed with exception %s' % - # ws.exception()) - print('websocket connection closed') return ws @@ -129,7 +94,7 @@ async def run_broker(_app): cfg = BrokerConfig( listeners={ 'default':ListenerConfig(type=ListenerType.WS, bind='127.0.0.1:8883'), - 'aiohttp': ListenerConfig(type=ListenerType.AIOHTTP), + 'aiohttp': ListenerConfig(type=ListenerType.EXTERNAL), } ) From 0141cddeebb9d55f8290af99d8a5a69a9916fc74 Mon Sep 17 00:00:00 2001 From: Andrew Mirsky Date: Mon, 4 Aug 2025 19:19:51 -0400 Subject: [PATCH 4/5] Yakifo/amqtt#73 : refinement of external websocket server --- amqtt/broker.py | 3 +++ samples/web.py | 65 ++++++++++++++++++++++++++++++++-------------- tests/test_paho.py | 53 ++++++++++++++++++++++++++++++++----- 3 files changed, 95 insertions(+), 26 deletions(-) diff --git a/amqtt/broker.py b/amqtt/broker.py index 1a797d6..7bcbeec 100644 --- a/amqtt/broker.py +++ b/amqtt/broker.py @@ -403,6 +403,9 @@ class Broker: async def stream_connected(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter, listener_name: str) -> None: await self._client_connected(listener_name, StreamReaderAdapter(reader), StreamWriterAdapter(writer)) + async def external_connected(self, reader: ReaderAdapter, writer: WriterAdapter, listener_name: str) -> None: + await self._client_connected(listener_name, reader, writer) + async def _client_connected(self, listener_name: str, reader: ReaderAdapter, writer: WriterAdapter) -> None: """Handle a new client connection.""" server = self._servers.get(listener_name) diff --git a/samples/web.py b/samples/web.py index 552953f..787ccf4 100644 --- a/samples/web.py +++ b/samples/web.py @@ -8,25 +8,38 @@ from aiohttp import web from amqtt.adapters import ReaderAdapter, WriterAdapter from amqtt.broker import Broker from amqtt.contexts import BrokerConfig, ListenerConfig, ListenerType -from amqtt.errors import BrokerError + + +logger = logging.getLogger(__name__) async def hello(request): return web.Response(text="Hello, world") -class AIOWebSocketsReader(ReaderAdapter): +class WebSocketResponseReader(ReaderAdapter): + """Interface to allow mqtt broker to read from an aiohttp websocket connection.""" + def __init__(self, ws: web.WebSocketResponse): self.ws = ws self.buffer = bytearray() async def read(self, n: int = -1) -> bytes: + """ + Raises: + BrokerPipeError : if reading on a closed websocket connection + """ + # continue until buffer contains at least the amount of data being requested while not self.buffer or len(self.buffer) < n: if self.ws.closed: raise BrokenPipeError() - msg = await self.ws.receive() - if msg.type == aiohttp.WSMsgType.BINARY: - self.buffer.extend(msg.data) - elif msg.type == aiohttp.WSMsgType.CLOSE: + try: + async with asyncio.timeout(0.5): + msg = await self.ws.receive() + if msg.type == aiohttp.WSMsgType.BINARY: + self.buffer.extend(msg.data) + elif msg.type == aiohttp.WSMsgType.CLOSE: + raise BrokenPipeError() + except asyncio.TimeoutError: raise BrokenPipeError() if n == -1: @@ -41,11 +54,24 @@ class AIOWebSocketsReader(ReaderAdapter): def feed_eof(self) -> None: pass -class AIOWebSocketsWriter(WriterAdapter): +class WebSocketResponseWriter(WriterAdapter): + """Interface to allow mqtt broker to write to an aiohttp websocket connection.""" - def __init__(self, ws: web.WebSocketResponse): + def __init__(self, ws: web.WebSocketResponse, request: web.Request): super().__init__() self.ws = ws + + # needed for `get_peer_info` + # https://docs.python.org/3/library/socket.html#socket.socket.getpeername + peer_name = request.transport.get_extra_info('peername') + if peer_name is not None: + self.client_ip, self.port = peer_name[0:2] + else: + self.client_ip, self.port = request.remote, 0 + + # interpret AF_INET6 + self.client_ip = "localhost" if self.client_ip == "::1" else self.client_ip + self._stream = io.BytesIO(b"") def write(self, data: bytes) -> None: @@ -58,21 +84,25 @@ class AIOWebSocketsWriter(WriterAdapter): self._stream = io.BytesIO(b"") def get_peer_info(self) -> tuple[str, int] | None: - return "external", 0 + return self.client_ip, self.port async def close(self) -> None: pass -async def websocket_handler(request): - print() +async def websocket_handler(request: web.Request) -> web.StreamResponse: + + # respond to the websocket request with the 'mqtt' protocol ws = web.WebSocketResponse(protocols=['mqtt',]) await ws.prepare(request) + # access the broker created when the server started and notify the broker of this new connection b: Broker = request.app['broker'] - await b._client_connected('aiohttp', AIOWebSocketsReader(ws), AIOWebSocketsWriter(ws)) - print('websocket connection closed') + # send/receive data to the websocket. must pass the name of the externalized listener in the broker config + await b.external_connected(WebSocketResponseReader(ws), WebSocketResponseWriter(ws, request), 'myAIOHttp') + + logger.debug('websocket connection closed') return ws @@ -89,17 +119,16 @@ def main(): async def run_broker(_app): + """https://docs.aiohttp.org/en/stable/web_advanced.html#background-tasks""" loop = asyncio.get_event_loop() cfg = BrokerConfig( listeners={ - 'default':ListenerConfig(type=ListenerType.WS, bind='127.0.0.1:8883'), - 'aiohttp': ListenerConfig(type=ListenerType.EXTERNAL), + 'default':ListenerConfig(type=ListenerType.TCP, bind='127.0.0.1:1883'), + 'myAIOHttp': ListenerConfig(type=ListenerType.EXTERNAL), } ) - - broker = Broker(config=cfg, loop=loop) _app['broker'] = broker await broker.start() @@ -109,8 +138,6 @@ async def run_broker(_app): await broker.shutdown() - - if __name__ == '__main__': logging.basicConfig(level=logging.INFO) main() diff --git a/tests/test_paho.py b/tests/test_paho.py index 0e4fb3c..9a254a0 100644 --- a/tests/test_paho.py +++ b/tests/test_paho.py @@ -1,11 +1,19 @@ import asyncio import logging import random +import threading +import time +from pathlib import Path +from threading import Thread +from typing import Any from unittest.mock import MagicMock, call, patch import pytest +import yaml from paho.mqtt import client as paho_client +from yaml import Loader +from amqtt.broker import Broker from amqtt.events import BrokerEvents from amqtt.client import MQTTClient from amqtt.mqtt.constants import QOS_1, QOS_2 @@ -125,20 +133,51 @@ async def test_paho_qos2(broker, mock_plugin_manager): await sub_client.disconnect() await asyncio.sleep(0.1) -async def test_paho_ws(): + + +def run_paho_client(flag): client_id = 'websocket_client_1' + logging.info("creating paho client") test_client = paho_client.Client(callback_api_version=paho_client.CallbackAPIVersion.VERSION2, transport='websockets', client_id=client_id) - test_client.ws_set_options('/ws') - - test_client.connect('localhost', 8080) + test_client.ws_set_options('') + logging.info("client connecting...") + test_client.connect('127.0.0.1', 8080) + logging.info("starting loop") test_client.loop_start() - await asyncio.sleep(0.1) + logging.info("client connected") + time.sleep(1) + logging.info("sending messages") test_client.publish("/qos2", "test message", qos=2) test_client.publish("/qos2", "test message", qos=2) test_client.publish("/qos2", "test message", qos=2) test_client.publish("/qos2", "test message", qos=2) - await asyncio.sleep(0.1) - test_client.loop_stop() \ No newline at end of file + time.sleep(1) + test_client.loop_stop() + test_client.disconnect() + flag.set() + + +@pytest.mark.asyncio +async def test_paho_ws(): + path = Path('docs_test/test.amqtt.local.yaml') + with path.open() as f: + cfg: dict[str, Any] = yaml.load(f, Loader=Loader) + logger.warning(cfg) + broker = Broker(config=cfg) + await broker.start() + + # python websockets and paho mqtt don't play well with each other in the same thread + flag = threading.Event() + thread = Thread(target=run_paho_client, args=(flag,)) + thread.start() + + await asyncio.sleep(5) + thread.join(1) + + assert flag.is_set(), "paho thread didn't execute completely" + + logging.info("client disconnected") + await broker.shutdown() From 6ff7c5a87abddeb18239c6ff5816a30f08607851 Mon Sep 17 00:00:00 2001 From: Andrew Mirsky Date: Mon, 4 Aug 2025 19:31:41 -0400 Subject: [PATCH 5/5] Yakifo/amqtt#73 : adding test case for external http server integration. comments and documentation. --- amqtt/broker.py | 18 +++-- amqtt/contexts.py | 4 - amqtt/mqtt/protocol/handler.py | 5 -- .../{web.py => http_server_integration.py} | 77 ++++++++++++++----- tests/test_samples.py | 30 +++++++- 5 files changed, 95 insertions(+), 39 deletions(-) rename samples/{web.py => http_server_integration.py} (55%) diff --git a/amqtt/broker.py b/amqtt/broker.py index 7bcbeec..d2e8bd1 100644 --- a/amqtt/broker.py +++ b/amqtt/broker.py @@ -52,6 +52,8 @@ class RetainedApplicationMessage(ApplicationMessage): class Server: + """Used to encapsulate the server associated with a listener. Allows broker to interact with the connection lifecycle.""" + def __init__( self, listener_name: str, @@ -90,8 +92,10 @@ class Server: class ExternalServer(Server): - def __init__(self): - super().__init__('aiohttp', None) + """For external listeners, the connection lifecycle is handled by that implementation so these are no-ops.""" + + def __init__(self) -> None: + super().__init__("aiohttp", None) # type: ignore[arg-type] async def acquire_connection(self) -> None: pass @@ -104,10 +108,7 @@ class ExternalServer(Server): class BrokerContext(BaseContext): - """BrokerContext is used as the context passed to plugins interacting with the broker. - - It act as an adapter to broker services from plugins developed for HBMQTT broker. - """ + """BrokerContext is used as the context passed to plugins interacting with the broker.""" def __init__(self, broker: "Broker") -> None: super().__init__() @@ -257,10 +258,14 @@ class Broker: max_connections = listener.get("max_connections", -1) ssl_context = self._create_ssl_context(listener) if listener.get("ssl", False) else None + # for listeners which are external, don't need to create a server if listener.type == ListenerType.EXTERNAL: + + # broker still needs to associate a new connection to the listener self.logger.info(f"External listener exists for '{listener_name}' ") self._servers[listener_name] = ExternalServer() else: + # for tcp and websockets, start servers to listen for inbound connections try: address, port = self._split_bindaddr_port(listener["bind"], DEFAULT_PORTS[listener["type"]]) except ValueError as e: @@ -404,6 +409,7 @@ class Broker: await self._client_connected(listener_name, StreamReaderAdapter(reader), StreamWriterAdapter(writer)) async def external_connected(self, reader: ReaderAdapter, writer: WriterAdapter, listener_name: str) -> None: + """Engage the broker in handling the data stream to/from an established connection.""" await self._client_connected(listener_name, reader, writer) async def _client_connected(self, listener_name: str, reader: ReaderAdapter, writer: WriterAdapter) -> None: diff --git a/amqtt/contexts.py b/amqtt/contexts.py index b397d62..66d7050 100644 --- a/amqtt/contexts.py +++ b/amqtt/contexts.py @@ -128,10 +128,6 @@ class ListenerConfig(Dictable): if isinstance(getattr(self, fn), str): setattr(self, fn, Path(getattr(self, fn))) - # if self.type == ListenerType.EXTERNAL and not all([self.reader, self.writer]): - # msg = "external type requires specifying reader, writer and server classes" - # raise ValueError(msg) - def apply(self, other: "ListenerConfig") -> None: """Apply the field from 'other', if 'self' field is default.""" for f in fields(self): diff --git a/amqtt/mqtt/protocol/handler.py b/amqtt/mqtt/protocol/handler.py index 8c53dc1..66bc1cb 100644 --- a/amqtt/mqtt/protocol/handler.py +++ b/amqtt/mqtt/protocol/handler.py @@ -1,5 +1,4 @@ import asyncio -import traceback try: from asyncio import InvalidStateError, QueueFull, QueueShutDown @@ -536,10 +535,6 @@ class ProtocolHandler(Generic[C]): self.handle_read_timeout() except NoDataError: self.logger.debug(f"{self.session.client_id} No data available") - except RuntimeError: - self.logger.debug(f"{self.session.client_id} websocket closed") - traceback.print_exc() - break except Exception as e: # noqa: BLE001 self.logger.warning(f"{type(self).__name__} Unhandled exception in reader coro: {e!r}") break diff --git a/samples/web.py b/samples/http_server_integration.py similarity index 55% rename from samples/web.py rename to samples/http_server_integration.py index 787ccf4..243d4ba 100644 --- a/samples/web.py +++ b/samples/http_server_integration.py @@ -8,12 +8,14 @@ from aiohttp import web from amqtt.adapters import ReaderAdapter, WriterAdapter from amqtt.broker import Broker from amqtt.contexts import BrokerConfig, ListenerConfig, ListenerType - +from amqtt.errors import ConnectError logger = logging.getLogger(__name__) +MQTT_LISTENER_NAME = "myMqttListener" async def hello(request): + """get request handler""" return web.Response(text="Hello, world") class WebSocketResponseReader(ReaderAdapter): @@ -25,26 +27,34 @@ class WebSocketResponseReader(ReaderAdapter): async def read(self, n: int = -1) -> bytes: """ + read 'n' bytes from the datastream, if < 0 read all available bytes + Raises: BrokerPipeError : if reading on a closed websocket connection """ # continue until buffer contains at least the amount of data being requested while not self.buffer or len(self.buffer) < n: + # if the websocket is closed if self.ws.closed: raise BrokenPipeError() + try: - async with asyncio.timeout(0.5): - msg = await self.ws.receive() - if msg.type == aiohttp.WSMsgType.BINARY: - self.buffer.extend(msg.data) - elif msg.type == aiohttp.WSMsgType.CLOSE: - raise BrokenPipeError() + # read from stream + msg = await asyncio.wait_for(self.ws.receive(), timeout=0.5) + # mqtt streams should always be binary... + if msg.type == aiohttp.WSMsgType.BINARY: + self.buffer.extend(msg.data) + elif msg.type == aiohttp.WSMsgType.CLOSE: + raise BrokenPipeError() + except asyncio.TimeoutError: raise BrokenPipeError() + # return all bytes currently in the buffer if n == -1: result = bytes(self.buffer) self.buffer.clear() + # return the requested number of bytes from the buffer else: result = self.buffer[:n] del self.buffer[:n] @@ -75,9 +85,11 @@ class WebSocketResponseWriter(WriterAdapter): self._stream = io.BytesIO(b"") def write(self, data: bytes) -> None: + """Add bytes to stream buffer.""" self._stream.write(data) async def drain(self) -> None: + """Send the collected bytes in the buffer to the websocket connection.""" data = self._stream.getvalue() if data and len(data): await self.ws.send_bytes(data) @@ -87,54 +99,79 @@ class WebSocketResponseWriter(WriterAdapter): return self.client_ip, self.port async def close(self) -> None: + # no clean up needed, stream will be gc along with instance pass +async def mqtt_websocket_handler(request: web.Request) -> web.StreamResponse: -async def websocket_handler(request: web.Request) -> web.StreamResponse: - - # respond to the websocket request with the 'mqtt' protocol + # establish connection by responding to the websocket request with the 'mqtt' protocol ws = web.WebSocketResponse(protocols=['mqtt',]) await ws.prepare(request) - # access the broker created when the server started and notify the broker of this new connection + # access the broker created when the server started b: Broker = request.app['broker'] - # send/receive data to the websocket. must pass the name of the externalized listener in the broker config - await b.external_connected(WebSocketResponseReader(ws), WebSocketResponseWriter(ws, request), 'myAIOHttp') + # hand-off the websocket data stream to the broker for handling + # `listener_name` is the same name of the externalized listener in the broker config + await b.external_connected(WebSocketResponseReader(ws), WebSocketResponseWriter(ws, request), MQTT_LISTENER_NAME) logger.debug('websocket connection closed') return ws -def main(): +async def websocket_handler(request: web.Request) -> web.StreamResponse: + ws = web.WebSocketResponse() + await ws.prepare(request) + async for msg in ws: + logging.info(msg) + + logging.info("websocket connection closed") + return ws + +def main(): + # create an `aiohttp` server + lp = asyncio.get_event_loop() app = web.Application() app.add_routes( [ - web.get('/', hello), - web.get('/ws', websocket_handler) + web.get('/', hello), # http get request/response route + web.get('/ws', websocket_handler), # standard websocket handler + web.get('/mqtt', mqtt_websocket_handler), # websocket handler for mqtt connections ]) + # create background task for running the `amqtt` broker app.cleanup_ctx.append(run_broker) - web.run_app(app) + + # make sure that both `aiohttp` server and `amqtt` broker run in the same loop + # so the server can hand off the connection to the broker (prevents attached-to-a-different-loop `RuntimeError`) + web.run_app(app, loop=lp) async def run_broker(_app): - """https://docs.aiohttp.org/en/stable/web_advanced.html#background-tasks""" - loop = asyncio.get_event_loop() + """App init function to start (and then shutdown) the `amqtt` broker. + https://docs.aiohttp.org/en/stable/web_advanced.html#background-tasks""" + # standard TCP connection as well as an externalized-listener cfg = BrokerConfig( listeners={ 'default':ListenerConfig(type=ListenerType.TCP, bind='127.0.0.1:1883'), - 'myAIOHttp': ListenerConfig(type=ListenerType.EXTERNAL), + MQTT_LISTENER_NAME: ListenerConfig(type=ListenerType.EXTERNAL), } ) + # make sure the `Broker` runs in the same loop as the aiohttp server + loop = asyncio.get_event_loop() broker = Broker(config=cfg, loop=loop) + + # store broker instance so that incoming requests can hand off processing of a datastream _app['broker'] = broker + # start the broker await broker.start() + # pass control back to web app yield + # closing activities await broker.shutdown() diff --git a/tests/test_samples.py b/tests/test_samples.py index 1e46046..fca99f3 100644 --- a/tests/test_samples.py +++ b/tests/test_samples.py @@ -2,14 +2,15 @@ import asyncio import logging import signal import subprocess + +from multiprocessing import Process from pathlib import Path +from samples.http_server_integration import main as http_server_main import pytest from amqtt.broker import Broker -from samples.client_publish import __main__ as client_publish_main -from samples.client_subscribe import __main__ as client_subscribe_main -from samples.client_keepalive import __main__ as client_keepalive_main +from amqtt.client import MQTTClient from samples.broker_acl import config as broker_acl_config from samples.broker_taboo import config as broker_taboo_config @@ -275,4 +276,25 @@ async def test_client_subscribe_plugin_taboo(): assert "ERROR" not in stderr.decode("utf-8") assert "Exception" not in stderr.decode("utf-8") - await broker.shutdown() \ No newline at end of file + await broker.shutdown() + + +@pytest.fixture +def external_http_server(): + p = Process(target=http_server_main) + p.start() + yield p + p.terminate() + + +@pytest.mark.asyncio +async def test_external_http_server(external_http_server): + + await asyncio.sleep(1) + client = MQTTClient(config={'auto_reconnect': False}) + await client.connect("ws://127.0.0.1:8080/mqtt") + assert client.session is not None + await client.publish("my/topic", b'test message') + await client.disconnect() + # Send the interrupt signal + await asyncio.sleep(1)