diff --git a/amqtt/broker.py b/amqtt/broker.py index 028af12..1a797d6 100644 --- a/amqtt/broker.py +++ b/amqtt/broker.py @@ -10,9 +10,6 @@ import ssl import time from typing import Any, ClassVar, TypeAlias -from aiohttp.pytest_plugin import AiohttpServer -from aiohttp.web_ws import WebSocketResponse -from pygments.token import Other from transitions import Machine, MachineError import websockets.asyncio.server from websockets.asyncio.server import ServerConnection @@ -25,7 +22,7 @@ from amqtt.adapters import ( WebSocketsWriter, WriterAdapter, ) -from amqtt.contexts import Action, BaseContext, BrokerConfig, ListenerConfig +from amqtt.contexts import Action, BaseContext, BrokerConfig, ListenerConfig, ListenerType from amqtt.errors import AMQTTError, BrokerError, MQTTError, NoDataError from amqtt.mqtt.protocol.broker_handler import BrokerProtocolHandler from amqtt.session import ApplicationMessage, OutgoingApplicationMessage, Session @@ -39,7 +36,7 @@ from .plugins.manager import PluginManager _BROADCAST: TypeAlias = dict[str, Session | str | bytes | bytearray | int | None] # Default port numbers -DEFAULT_PORTS = {"tcp": 1883, "ws": 8883, 'aiohttp': 8080} +DEFAULT_PORTS = {"tcp": 1883, "ws": 8883} AMQTT_MAGIC_VALUE_RET_SUBSCRIBED = 0x80 @@ -92,7 +89,7 @@ class Server: await self.instance.wait_closed() -class OtherServer(Server): +class ExternalServer(Server): def __init__(self): super().__init__('aiohttp', None) @@ -260,19 +257,20 @@ class Broker: max_connections = listener.get("max_connections", -1) ssl_context = self._create_ssl_context(listener) if listener.get("ssl", False) else None - try: - address, port = self._split_bindaddr_port(listener["bind"], DEFAULT_PORTS[listener["type"]]) - except ValueError as e: - msg = f"Invalid port value in bind value: {listener['bind']}" - raise BrokerError(msg) from e - - if listener["type"] == 'aiohttp': - self._servers[listener_name] = OtherServer() + if listener.type == ListenerType.EXTERNAL: + self.logger.info(f"External listener exists for '{listener_name}' ") + self._servers[listener_name] = ExternalServer() else: + try: + address, port = self._split_bindaddr_port(listener["bind"], DEFAULT_PORTS[listener["type"]]) + except ValueError as e: + msg = f"Invalid port value in bind value: {listener['bind']}" + raise BrokerError(msg) from e + instance = await self._create_server_instance(listener_name, listener["type"], address, port, ssl_context) self._servers[listener_name] = Server(listener_name, instance, max_connections) - self.logger.info(f"Listener '{listener_name}' bind to {listener['bind']} (max_connections={max_connections})") + self.logger.info(f"Listener '{listener_name}' bind to {listener['bind']} (max_connections={max_connections})") @staticmethod def _create_ssl_context(listener: ListenerConfig) -> ssl.SSLContext: @@ -301,7 +299,7 @@ class Broker: address: str | None, port: int, ssl_context: ssl.SSLContext | None, - ) -> asyncio.Server | websockets.asyncio.server.Server | OtherServer: + ) -> asyncio.Server | websockets.asyncio.server.Server: """Create a server instance for a listener.""" if listener_type == "tcp": return await asyncio.start_server( @@ -319,8 +317,6 @@ class Broker: ssl=ssl_context, subprotocols=[websockets.Subprotocol("mqtt")], ) - if listener_type == "aiohttp": - return OtherServer() msg = f"Unsupported listener type: {listener_type}" raise BrokerError(msg) diff --git a/amqtt/codecs_amqtt.py b/amqtt/codecs_amqtt.py index 20bb536..1db2d9c 100644 --- a/amqtt/codecs_amqtt.py +++ b/amqtt/codecs_amqtt.py @@ -75,10 +75,8 @@ async def decode_string(reader: ReaderAdapter | asyncio.StreamReader) -> str: length_bytes = await read_or_raise(reader, 2) if len(length_bytes) < 1: raise ZeroLengthReadError - print(f"attempting to decode {length_bytes}") str_length = unpack("!H", length_bytes)[0] if str_length: - print(f"reading {str_length} bytes") byte_str = await read_or_raise(reader, str_length) try: return byte_str.decode(encoding="utf-8") diff --git a/amqtt/contexts.py b/amqtt/contexts.py index 1bd733c..b397d62 100644 --- a/amqtt/contexts.py +++ b/amqtt/contexts.py @@ -44,7 +44,7 @@ class ListenerType(StrEnum): TCP = "tcp" WS = "ws" - AIOHTTP = "aiohttp" + EXTERNAL = "external" def __repr__(self) -> str: """Display the string value, instead of the enum member.""" @@ -115,6 +115,8 @@ class ListenerConfig(Dictable): certificates needed to establish the certificate's authenticity.)""" keyfile: str | Path | None = None """Full path to file in PEM format containing the server's private key.""" + reader: str | None = None + writer: str | None = None def __post_init__(self) -> None: """Check config for errors and transform fields for easier use.""" @@ -126,6 +128,10 @@ class ListenerConfig(Dictable): if isinstance(getattr(self, fn), str): setattr(self, fn, Path(getattr(self, fn))) + # if self.type == ListenerType.EXTERNAL and not all([self.reader, self.writer]): + # msg = "external type requires specifying reader, writer and server classes" + # raise ValueError(msg) + def apply(self, other: "ListenerConfig") -> None: """Apply the field from 'other', if 'self' field is default.""" for f in fields(self): diff --git a/samples/web.py b/samples/web.py index 5f8bc68..552953f 100644 --- a/samples/web.py +++ b/samples/web.py @@ -14,15 +14,12 @@ from amqtt.errors import BrokerError async def hello(request): return web.Response(text="Hello, world") - - class AIOWebSocketsReader(ReaderAdapter): def __init__(self, ws: web.WebSocketResponse): self.ws = ws self.buffer = bytearray() async def read(self, n: int = -1) -> bytes: - print(f"attempting to read {n} bytes") while not self.buffer or len(self.buffer) < n: if self.ws.closed: raise BrokenPipeError() @@ -30,16 +27,15 @@ class AIOWebSocketsReader(ReaderAdapter): if msg.type == aiohttp.WSMsgType.BINARY: self.buffer.extend(msg.data) elif msg.type == aiohttp.WSMsgType.CLOSE: - print("received a close message!") - break - print(f"buffer size: {len(self.buffer)}") + raise BrokenPipeError() + if n == -1: result = bytes(self.buffer) self.buffer.clear() else: result = self.buffer[:n] del self.buffer[:n] - print(f"bytes: {result}") + return result def feed_eof(self) -> None: @@ -53,7 +49,6 @@ class AIOWebSocketsWriter(WriterAdapter): self._stream = io.BytesIO(b"") def write(self, data: bytes) -> None: - print(f"broker wants to write data: {data}") self._stream.write(data) async def drain(self) -> None: @@ -63,7 +58,7 @@ class AIOWebSocketsWriter(WriterAdapter): self._stream = io.BytesIO(b"") def get_peer_info(self) -> tuple[str, int] | None: - return "aiohttp", 1234567 + return "external", 0 async def close(self) -> None: pass @@ -73,39 +68,9 @@ async def websocket_handler(request): print() ws = web.WebSocketResponse(protocols=['mqtt',]) await ws.prepare(request) - # - # readQ = asyncio.Queue() - # writeQ = asyncio.Queue() - # - # async def receiver(): - # async for msg in ws: - # match msg.type: - # case aiohttp.WSMsgType.BINARY: - # readQ.put_nowait(msg.data) - # case _: - # return - # - # async def send_items(): - # while not ws.closed: - # if not writeQ.empty(): - # item = await writeQ.get() - # await ws.send_bytes(item) - # - # await asyncio.create_task(send_items()) - # + b: Broker = request.app['broker'] await b._client_connected('aiohttp', AIOWebSocketsReader(ws), AIOWebSocketsWriter(ws)) - - # async for msg in ws: - # print(f"ws: {msg}") - - - - # - # elif msg.type == aiohttp.WSMsgType.ERROR: - # print('ws connection closed with exception %s' % - # ws.exception()) - print('websocket connection closed') return ws @@ -129,7 +94,7 @@ async def run_broker(_app): cfg = BrokerConfig( listeners={ 'default':ListenerConfig(type=ListenerType.WS, bind='127.0.0.1:8883'), - 'aiohttp': ListenerConfig(type=ListenerType.AIOHTTP), + 'aiohttp': ListenerConfig(type=ListenerType.EXTERNAL), } )