Yakifo/amqtt#73 : refinement of external websocket server

pull/283/head
Andrew Mirsky 2025-08-04 19:19:51 -04:00
rodzic c1625e264e
commit 0141cddeeb
Nie znaleziono w bazie danych klucza dla tego podpisu
ID klucza GPG: A98E67635CDF2C39
3 zmienionych plików z 95 dodań i 26 usunięć

Wyświetl plik

@ -403,6 +403,9 @@ class Broker:
async def stream_connected(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter, listener_name: str) -> None:
await self._client_connected(listener_name, StreamReaderAdapter(reader), StreamWriterAdapter(writer))
async def external_connected(self, reader: ReaderAdapter, writer: WriterAdapter, listener_name: str) -> None:
await self._client_connected(listener_name, reader, writer)
async def _client_connected(self, listener_name: str, reader: ReaderAdapter, writer: WriterAdapter) -> None:
"""Handle a new client connection."""
server = self._servers.get(listener_name)

Wyświetl plik

@ -8,25 +8,38 @@ from aiohttp import web
from amqtt.adapters import ReaderAdapter, WriterAdapter
from amqtt.broker import Broker
from amqtt.contexts import BrokerConfig, ListenerConfig, ListenerType
from amqtt.errors import BrokerError
logger = logging.getLogger(__name__)
async def hello(request):
return web.Response(text="Hello, world")
class AIOWebSocketsReader(ReaderAdapter):
class WebSocketResponseReader(ReaderAdapter):
"""Interface to allow mqtt broker to read from an aiohttp websocket connection."""
def __init__(self, ws: web.WebSocketResponse):
self.ws = ws
self.buffer = bytearray()
async def read(self, n: int = -1) -> bytes:
"""
Raises:
BrokerPipeError : if reading on a closed websocket connection
"""
# continue until buffer contains at least the amount of data being requested
while not self.buffer or len(self.buffer) < n:
if self.ws.closed:
raise BrokenPipeError()
msg = await self.ws.receive()
if msg.type == aiohttp.WSMsgType.BINARY:
self.buffer.extend(msg.data)
elif msg.type == aiohttp.WSMsgType.CLOSE:
try:
async with asyncio.timeout(0.5):
msg = await self.ws.receive()
if msg.type == aiohttp.WSMsgType.BINARY:
self.buffer.extend(msg.data)
elif msg.type == aiohttp.WSMsgType.CLOSE:
raise BrokenPipeError()
except asyncio.TimeoutError:
raise BrokenPipeError()
if n == -1:
@ -41,11 +54,24 @@ class AIOWebSocketsReader(ReaderAdapter):
def feed_eof(self) -> None:
pass
class AIOWebSocketsWriter(WriterAdapter):
class WebSocketResponseWriter(WriterAdapter):
"""Interface to allow mqtt broker to write to an aiohttp websocket connection."""
def __init__(self, ws: web.WebSocketResponse):
def __init__(self, ws: web.WebSocketResponse, request: web.Request):
super().__init__()
self.ws = ws
# needed for `get_peer_info`
# https://docs.python.org/3/library/socket.html#socket.socket.getpeername
peer_name = request.transport.get_extra_info('peername')
if peer_name is not None:
self.client_ip, self.port = peer_name[0:2]
else:
self.client_ip, self.port = request.remote, 0
# interpret AF_INET6
self.client_ip = "localhost" if self.client_ip == "::1" else self.client_ip
self._stream = io.BytesIO(b"")
def write(self, data: bytes) -> None:
@ -58,21 +84,25 @@ class AIOWebSocketsWriter(WriterAdapter):
self._stream = io.BytesIO(b"")
def get_peer_info(self) -> tuple[str, int] | None:
return "external", 0
return self.client_ip, self.port
async def close(self) -> None:
pass
async def websocket_handler(request):
print()
async def websocket_handler(request: web.Request) -> web.StreamResponse:
# respond to the websocket request with the 'mqtt' protocol
ws = web.WebSocketResponse(protocols=['mqtt',])
await ws.prepare(request)
# access the broker created when the server started and notify the broker of this new connection
b: Broker = request.app['broker']
await b._client_connected('aiohttp', AIOWebSocketsReader(ws), AIOWebSocketsWriter(ws))
print('websocket connection closed')
# send/receive data to the websocket. must pass the name of the externalized listener in the broker config
await b.external_connected(WebSocketResponseReader(ws), WebSocketResponseWriter(ws, request), 'myAIOHttp')
logger.debug('websocket connection closed')
return ws
@ -89,17 +119,16 @@ def main():
async def run_broker(_app):
"""https://docs.aiohttp.org/en/stable/web_advanced.html#background-tasks"""
loop = asyncio.get_event_loop()
cfg = BrokerConfig(
listeners={
'default':ListenerConfig(type=ListenerType.WS, bind='127.0.0.1:8883'),
'aiohttp': ListenerConfig(type=ListenerType.EXTERNAL),
'default':ListenerConfig(type=ListenerType.TCP, bind='127.0.0.1:1883'),
'myAIOHttp': ListenerConfig(type=ListenerType.EXTERNAL),
}
)
broker = Broker(config=cfg, loop=loop)
_app['broker'] = broker
await broker.start()
@ -109,8 +138,6 @@ async def run_broker(_app):
await broker.shutdown()
if __name__ == '__main__':
logging.basicConfig(level=logging.INFO)
main()

Wyświetl plik

@ -1,11 +1,19 @@
import asyncio
import logging
import random
import threading
import time
from pathlib import Path
from threading import Thread
from typing import Any
from unittest.mock import MagicMock, call, patch
import pytest
import yaml
from paho.mqtt import client as paho_client
from yaml import Loader
from amqtt.broker import Broker
from amqtt.events import BrokerEvents
from amqtt.client import MQTTClient
from amqtt.mqtt.constants import QOS_1, QOS_2
@ -125,20 +133,51 @@ async def test_paho_qos2(broker, mock_plugin_manager):
await sub_client.disconnect()
await asyncio.sleep(0.1)
async def test_paho_ws():
def run_paho_client(flag):
client_id = 'websocket_client_1'
logging.info("creating paho client")
test_client = paho_client.Client(callback_api_version=paho_client.CallbackAPIVersion.VERSION2,
transport='websockets',
client_id=client_id)
test_client.ws_set_options('/ws')
test_client.connect('localhost', 8080)
test_client.ws_set_options('')
logging.info("client connecting...")
test_client.connect('127.0.0.1', 8080)
logging.info("starting loop")
test_client.loop_start()
await asyncio.sleep(0.1)
logging.info("client connected")
time.sleep(1)
logging.info("sending messages")
test_client.publish("/qos2", "test message", qos=2)
test_client.publish("/qos2", "test message", qos=2)
test_client.publish("/qos2", "test message", qos=2)
test_client.publish("/qos2", "test message", qos=2)
await asyncio.sleep(0.1)
test_client.loop_stop()
time.sleep(1)
test_client.loop_stop()
test_client.disconnect()
flag.set()
@pytest.mark.asyncio
async def test_paho_ws():
path = Path('docs_test/test.amqtt.local.yaml')
with path.open() as f:
cfg: dict[str, Any] = yaml.load(f, Loader=Loader)
logger.warning(cfg)
broker = Broker(config=cfg)
await broker.start()
# python websockets and paho mqtt don't play well with each other in the same thread
flag = threading.Event()
thread = Thread(target=run_paho_client, args=(flag,))
thread.start()
await asyncio.sleep(5)
thread.join(1)
assert flag.is_set(), "paho thread didn't execute completely"
logging.info("client disconnected")
await broker.shutdown()