diff --git a/amqtt/broker.py b/amqtt/broker.py index 6ddbcb7..f62a2cc 100644 --- a/amqtt/broker.py +++ b/amqtt/broker.py @@ -22,7 +22,7 @@ from amqtt.adapters import ( WebSocketsWriter, WriterAdapter, ) -from amqtt.contexts import Action, BaseContext, BrokerConfig, ListenerConfig +from amqtt.contexts import Action, BaseContext, BrokerConfig, ListenerConfig, ListenerType from amqtt.errors import AMQTTError, BrokerError, MQTTError, NoDataError from amqtt.mqtt.protocol.broker_handler import BrokerProtocolHandler from amqtt.session import ApplicationMessage, OutgoingApplicationMessage, Session @@ -52,6 +52,8 @@ class RetainedApplicationMessage(ApplicationMessage): class Server: + """Used to encapsulate the server associated with a listener. Allows broker to interact with the connection lifecycle.""" + def __init__( self, listener_name: str, @@ -89,11 +91,24 @@ class Server: await self.instance.wait_closed() -class BrokerContext(BaseContext): - """BrokerContext is used as the context passed to plugins interacting with the broker. +class ExternalServer(Server): + """For external listeners, the connection lifecycle is handled by that implementation so these are no-ops.""" - It act as an adapter to broker services from plugins developed for HBMQTT broker. - """ + def __init__(self) -> None: + super().__init__("aiohttp", None) # type: ignore[arg-type] + + async def acquire_connection(self) -> None: + pass + + def release_connection(self) -> None: + pass + + async def close_instance(self) -> None: + pass + + +class BrokerContext(BaseContext): + """BrokerContext is used as the context passed to plugins interacting with the broker.""" def __init__(self, broker: "Broker") -> None: super().__init__() @@ -243,16 +258,24 @@ class Broker: max_connections = listener.get("max_connections", -1) ssl_context = self._create_ssl_context(listener) if listener.get("ssl", False) else None - try: - address, port = self._split_bindaddr_port(listener["bind"], DEFAULT_PORTS[listener["type"]]) - except ValueError as e: - msg = f"Invalid port value in bind value: {listener['bind']}" - raise BrokerError(msg) from e + # for listeners which are external, don't need to create a server + if listener.type == ListenerType.EXTERNAL: - instance = await self._create_server_instance(listener_name, listener["type"], address, port, ssl_context) - self._servers[listener_name] = Server(listener_name, instance, max_connections) + # broker still needs to associate a new connection to the listener + self.logger.info(f"External listener exists for '{listener_name}' ") + self._servers[listener_name] = ExternalServer() + else: + # for tcp and websockets, start servers to listen for inbound connections + try: + address, port = self._split_bindaddr_port(listener["bind"], DEFAULT_PORTS[listener["type"]]) + except ValueError as e: + msg = f"Invalid port value in bind value: {listener['bind']}" + raise BrokerError(msg) from e - self.logger.info(f"Listener '{listener_name}' bind to {listener['bind']} (max_connections={max_connections})") + instance = await self._create_server_instance(listener_name, listener["type"], address, port, ssl_context) + self._servers[listener_name] = Server(listener_name, instance, max_connections) + + self.logger.info(f"Listener '{listener_name}' bind to {listener['bind']} (max_connections={max_connections})") @staticmethod def _create_ssl_context(listener: ListenerConfig) -> ssl.SSLContext: @@ -385,6 +408,10 @@ class Broker: async def stream_connected(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter, listener_name: str) -> None: await self._client_connected(listener_name, StreamReaderAdapter(reader), StreamWriterAdapter(writer)) + async def external_connected(self, reader: ReaderAdapter, writer: WriterAdapter, listener_name: str) -> None: + """Engage the broker in handling the data stream to/from an established connection.""" + await self._client_connected(listener_name, reader, writer) + async def _client_connected(self, listener_name: str, reader: ReaderAdapter, writer: WriterAdapter) -> None: """Handle a new client connection.""" server = self._servers.get(listener_name) diff --git a/amqtt/contexts.py b/amqtt/contexts.py index afd5dc1..bbff1a0 100644 --- a/amqtt/contexts.py +++ b/amqtt/contexts.py @@ -44,6 +44,7 @@ class ListenerType(StrEnum): TCP = "tcp" WS = "ws" + EXTERNAL = "external" def __repr__(self) -> str: """Display the string value, instead of the enum member.""" @@ -114,6 +115,8 @@ class ListenerConfig(Dictable): certificates needed to establish the certificate's authenticity.)""" keyfile: str | Path | None = None """Full path to file in PEM format containing the server's private key.""" + reader: str | None = None + writer: str | None = None def __post_init__(self) -> None: """Check config for errors and transform fields for easier use.""" diff --git a/samples/http_server_integration.py b/samples/http_server_integration.py new file mode 100644 index 0000000..243d4ba --- /dev/null +++ b/samples/http_server_integration.py @@ -0,0 +1,180 @@ +import asyncio +import io +import logging + +import aiohttp +from aiohttp import web + +from amqtt.adapters import ReaderAdapter, WriterAdapter +from amqtt.broker import Broker +from amqtt.contexts import BrokerConfig, ListenerConfig, ListenerType +from amqtt.errors import ConnectError + +logger = logging.getLogger(__name__) + +MQTT_LISTENER_NAME = "myMqttListener" + +async def hello(request): + """get request handler""" + return web.Response(text="Hello, world") + +class WebSocketResponseReader(ReaderAdapter): + """Interface to allow mqtt broker to read from an aiohttp websocket connection.""" + + def __init__(self, ws: web.WebSocketResponse): + self.ws = ws + self.buffer = bytearray() + + async def read(self, n: int = -1) -> bytes: + """ + read 'n' bytes from the datastream, if < 0 read all available bytes + + Raises: + BrokerPipeError : if reading on a closed websocket connection + """ + # continue until buffer contains at least the amount of data being requested + while not self.buffer or len(self.buffer) < n: + # if the websocket is closed + if self.ws.closed: + raise BrokenPipeError() + + try: + # read from stream + msg = await asyncio.wait_for(self.ws.receive(), timeout=0.5) + # mqtt streams should always be binary... + if msg.type == aiohttp.WSMsgType.BINARY: + self.buffer.extend(msg.data) + elif msg.type == aiohttp.WSMsgType.CLOSE: + raise BrokenPipeError() + + except asyncio.TimeoutError: + raise BrokenPipeError() + + # return all bytes currently in the buffer + if n == -1: + result = bytes(self.buffer) + self.buffer.clear() + # return the requested number of bytes from the buffer + else: + result = self.buffer[:n] + del self.buffer[:n] + + return result + + def feed_eof(self) -> None: + pass + +class WebSocketResponseWriter(WriterAdapter): + """Interface to allow mqtt broker to write to an aiohttp websocket connection.""" + + def __init__(self, ws: web.WebSocketResponse, request: web.Request): + super().__init__() + self.ws = ws + + # needed for `get_peer_info` + # https://docs.python.org/3/library/socket.html#socket.socket.getpeername + peer_name = request.transport.get_extra_info('peername') + if peer_name is not None: + self.client_ip, self.port = peer_name[0:2] + else: + self.client_ip, self.port = request.remote, 0 + + # interpret AF_INET6 + self.client_ip = "localhost" if self.client_ip == "::1" else self.client_ip + + self._stream = io.BytesIO(b"") + + def write(self, data: bytes) -> None: + """Add bytes to stream buffer.""" + self._stream.write(data) + + async def drain(self) -> None: + """Send the collected bytes in the buffer to the websocket connection.""" + data = self._stream.getvalue() + if data and len(data): + await self.ws.send_bytes(data) + self._stream = io.BytesIO(b"") + + def get_peer_info(self) -> tuple[str, int] | None: + return self.client_ip, self.port + + async def close(self) -> None: + # no clean up needed, stream will be gc along with instance + pass + +async def mqtt_websocket_handler(request: web.Request) -> web.StreamResponse: + + # establish connection by responding to the websocket request with the 'mqtt' protocol + ws = web.WebSocketResponse(protocols=['mqtt',]) + await ws.prepare(request) + + # access the broker created when the server started + b: Broker = request.app['broker'] + + # hand-off the websocket data stream to the broker for handling + # `listener_name` is the same name of the externalized listener in the broker config + await b.external_connected(WebSocketResponseReader(ws), WebSocketResponseWriter(ws, request), MQTT_LISTENER_NAME) + + logger.debug('websocket connection closed') + return ws + + +async def websocket_handler(request: web.Request) -> web.StreamResponse: + ws = web.WebSocketResponse() + await ws.prepare(request) + + async for msg in ws: + logging.info(msg) + + logging.info("websocket connection closed") + return ws + +def main(): + # create an `aiohttp` server + lp = asyncio.get_event_loop() + app = web.Application() + app.add_routes( + [ + web.get('/', hello), # http get request/response route + web.get('/ws', websocket_handler), # standard websocket handler + web.get('/mqtt', mqtt_websocket_handler), # websocket handler for mqtt connections + ]) + # create background task for running the `amqtt` broker + app.cleanup_ctx.append(run_broker) + + # make sure that both `aiohttp` server and `amqtt` broker run in the same loop + # so the server can hand off the connection to the broker (prevents attached-to-a-different-loop `RuntimeError`) + web.run_app(app, loop=lp) + + +async def run_broker(_app): + """App init function to start (and then shutdown) the `amqtt` broker. + https://docs.aiohttp.org/en/stable/web_advanced.html#background-tasks""" + + # standard TCP connection as well as an externalized-listener + cfg = BrokerConfig( + listeners={ + 'default':ListenerConfig(type=ListenerType.TCP, bind='127.0.0.1:1883'), + MQTT_LISTENER_NAME: ListenerConfig(type=ListenerType.EXTERNAL), + } + ) + + # make sure the `Broker` runs in the same loop as the aiohttp server + loop = asyncio.get_event_loop() + broker = Broker(config=cfg, loop=loop) + + # store broker instance so that incoming requests can hand off processing of a datastream + _app['broker'] = broker + # start the broker + await broker.start() + + # pass control back to web app + yield + + # closing activities + await broker.shutdown() + + +if __name__ == '__main__': + logging.basicConfig(level=logging.INFO) + main() diff --git a/tests/test_paho.py b/tests/test_paho.py index be53b2a..9a254a0 100644 --- a/tests/test_paho.py +++ b/tests/test_paho.py @@ -1,11 +1,19 @@ import asyncio import logging import random +import threading +import time +from pathlib import Path +from threading import Thread +from typing import Any from unittest.mock import MagicMock, call, patch import pytest -from paho.mqtt import client as mqtt_client +import yaml +from paho.mqtt import client as paho_client +from yaml import Loader +from amqtt.broker import Broker from amqtt.events import BrokerEvents from amqtt.client import MQTTClient from amqtt.mqtt.constants import QOS_1, QOS_2 @@ -40,7 +48,7 @@ async def test_paho_connect(broker, mock_plugin_manager): assert rc == 0, f"Disconnect failed with result code {rc}" test_complete.set() - test_client = mqtt_client.Client(mqtt_client.CallbackAPIVersion.VERSION2, client_id=client_id) + test_client = paho_client.Client(paho_client.CallbackAPIVersion.VERSION2, client_id=client_id) test_client.enable_logger(paho_logger) test_client.on_connect = on_connect @@ -76,7 +84,7 @@ async def test_paho_qos1(broker, mock_plugin_manager): port = 1883 client_id = f'python-mqtt-{random.randint(0, 1000)}' - test_client = mqtt_client.Client(mqtt_client.CallbackAPIVersion.VERSION2, client_id=client_id) + test_client = paho_client.Client(paho_client.CallbackAPIVersion.VERSION2, client_id=client_id) test_client.enable_logger(paho_logger) test_client.connect(host, port) @@ -107,7 +115,7 @@ async def test_paho_qos2(broker, mock_plugin_manager): port = 1883 client_id = f'python-mqtt-{random.randint(0, 1000)}' - test_client = mqtt_client.Client(mqtt_client.CallbackAPIVersion.VERSION2, client_id=client_id) + test_client = paho_client.Client(paho_client.CallbackAPIVersion.VERSION2, client_id=client_id) test_client.enable_logger(paho_logger) test_client.connect(host, port) @@ -124,3 +132,52 @@ async def test_paho_qos2(broker, mock_plugin_manager): assert message.data == b"test message" await sub_client.disconnect() await asyncio.sleep(0.1) + + + +def run_paho_client(flag): + client_id = 'websocket_client_1' + logging.info("creating paho client") + test_client = paho_client.Client(callback_api_version=paho_client.CallbackAPIVersion.VERSION2, + transport='websockets', + client_id=client_id) + + test_client.ws_set_options('') + logging.info("client connecting...") + test_client.connect('127.0.0.1', 8080) + logging.info("starting loop") + test_client.loop_start() + logging.info("client connected") + time.sleep(1) + logging.info("sending messages") + test_client.publish("/qos2", "test message", qos=2) + test_client.publish("/qos2", "test message", qos=2) + test_client.publish("/qos2", "test message", qos=2) + test_client.publish("/qos2", "test message", qos=2) + time.sleep(1) + test_client.loop_stop() + test_client.disconnect() + flag.set() + + +@pytest.mark.asyncio +async def test_paho_ws(): + path = Path('docs_test/test.amqtt.local.yaml') + with path.open() as f: + cfg: dict[str, Any] = yaml.load(f, Loader=Loader) + logger.warning(cfg) + broker = Broker(config=cfg) + await broker.start() + + # python websockets and paho mqtt don't play well with each other in the same thread + flag = threading.Event() + thread = Thread(target=run_paho_client, args=(flag,)) + thread.start() + + await asyncio.sleep(5) + thread.join(1) + + assert flag.is_set(), "paho thread didn't execute completely" + + logging.info("client disconnected") + await broker.shutdown() diff --git a/tests/test_samples.py b/tests/test_samples.py index 1e46046..fca99f3 100644 --- a/tests/test_samples.py +++ b/tests/test_samples.py @@ -2,14 +2,15 @@ import asyncio import logging import signal import subprocess + +from multiprocessing import Process from pathlib import Path +from samples.http_server_integration import main as http_server_main import pytest from amqtt.broker import Broker -from samples.client_publish import __main__ as client_publish_main -from samples.client_subscribe import __main__ as client_subscribe_main -from samples.client_keepalive import __main__ as client_keepalive_main +from amqtt.client import MQTTClient from samples.broker_acl import config as broker_acl_config from samples.broker_taboo import config as broker_taboo_config @@ -275,4 +276,25 @@ async def test_client_subscribe_plugin_taboo(): assert "ERROR" not in stderr.decode("utf-8") assert "Exception" not in stderr.decode("utf-8") - await broker.shutdown() \ No newline at end of file + await broker.shutdown() + + +@pytest.fixture +def external_http_server(): + p = Process(target=http_server_main) + p.start() + yield p + p.terminate() + + +@pytest.mark.asyncio +async def test_external_http_server(external_http_server): + + await asyncio.sleep(1) + client = MQTTClient(config={'auto_reconnect': False}) + await client.connect("ws://127.0.0.1:8080/mqtt") + assert client.session is not None + await client.publish("my/topic", b'test message') + await client.disconnect() + # Send the interrupt signal + await asyncio.sleep(1)