pull/283/head
Andrew Mirsky 2025-08-04 17:47:29 -04:00
rodzic ad4854df00
commit c1625e264e
Nie znaleziono w bazie danych klucza dla tego podpisu
ID klucza GPG: A98E67635CDF2C39
4 zmienionych plików z 27 dodań i 62 usunięć

Wyświetl plik

@ -10,9 +10,6 @@ import ssl
import time import time
from typing import Any, ClassVar, TypeAlias from typing import Any, ClassVar, TypeAlias
from aiohttp.pytest_plugin import AiohttpServer
from aiohttp.web_ws import WebSocketResponse
from pygments.token import Other
from transitions import Machine, MachineError from transitions import Machine, MachineError
import websockets.asyncio.server import websockets.asyncio.server
from websockets.asyncio.server import ServerConnection from websockets.asyncio.server import ServerConnection
@ -25,7 +22,7 @@ from amqtt.adapters import (
WebSocketsWriter, WebSocketsWriter,
WriterAdapter, WriterAdapter,
) )
from amqtt.contexts import Action, BaseContext, BrokerConfig, ListenerConfig from amqtt.contexts import Action, BaseContext, BrokerConfig, ListenerConfig, ListenerType
from amqtt.errors import AMQTTError, BrokerError, MQTTError, NoDataError from amqtt.errors import AMQTTError, BrokerError, MQTTError, NoDataError
from amqtt.mqtt.protocol.broker_handler import BrokerProtocolHandler from amqtt.mqtt.protocol.broker_handler import BrokerProtocolHandler
from amqtt.session import ApplicationMessage, OutgoingApplicationMessage, Session from amqtt.session import ApplicationMessage, OutgoingApplicationMessage, Session
@ -39,7 +36,7 @@ from .plugins.manager import PluginManager
_BROADCAST: TypeAlias = dict[str, Session | str | bytes | bytearray | int | None] _BROADCAST: TypeAlias = dict[str, Session | str | bytes | bytearray | int | None]
# Default port numbers # Default port numbers
DEFAULT_PORTS = {"tcp": 1883, "ws": 8883, 'aiohttp': 8080} DEFAULT_PORTS = {"tcp": 1883, "ws": 8883}
AMQTT_MAGIC_VALUE_RET_SUBSCRIBED = 0x80 AMQTT_MAGIC_VALUE_RET_SUBSCRIBED = 0x80
@ -92,7 +89,7 @@ class Server:
await self.instance.wait_closed() await self.instance.wait_closed()
class OtherServer(Server): class ExternalServer(Server):
def __init__(self): def __init__(self):
super().__init__('aiohttp', None) super().__init__('aiohttp', None)
@ -260,19 +257,20 @@ class Broker:
max_connections = listener.get("max_connections", -1) max_connections = listener.get("max_connections", -1)
ssl_context = self._create_ssl_context(listener) if listener.get("ssl", False) else None ssl_context = self._create_ssl_context(listener) if listener.get("ssl", False) else None
try: if listener.type == ListenerType.EXTERNAL:
address, port = self._split_bindaddr_port(listener["bind"], DEFAULT_PORTS[listener["type"]]) self.logger.info(f"External listener exists for '{listener_name}' ")
except ValueError as e: self._servers[listener_name] = ExternalServer()
msg = f"Invalid port value in bind value: {listener['bind']}"
raise BrokerError(msg) from e
if listener["type"] == 'aiohttp':
self._servers[listener_name] = OtherServer()
else: else:
try:
address, port = self._split_bindaddr_port(listener["bind"], DEFAULT_PORTS[listener["type"]])
except ValueError as e:
msg = f"Invalid port value in bind value: {listener['bind']}"
raise BrokerError(msg) from e
instance = await self._create_server_instance(listener_name, listener["type"], address, port, ssl_context) instance = await self._create_server_instance(listener_name, listener["type"], address, port, ssl_context)
self._servers[listener_name] = Server(listener_name, instance, max_connections) self._servers[listener_name] = Server(listener_name, instance, max_connections)
self.logger.info(f"Listener '{listener_name}' bind to {listener['bind']} (max_connections={max_connections})") self.logger.info(f"Listener '{listener_name}' bind to {listener['bind']} (max_connections={max_connections})")
@staticmethod @staticmethod
def _create_ssl_context(listener: ListenerConfig) -> ssl.SSLContext: def _create_ssl_context(listener: ListenerConfig) -> ssl.SSLContext:
@ -301,7 +299,7 @@ class Broker:
address: str | None, address: str | None,
port: int, port: int,
ssl_context: ssl.SSLContext | None, ssl_context: ssl.SSLContext | None,
) -> asyncio.Server | websockets.asyncio.server.Server | OtherServer: ) -> asyncio.Server | websockets.asyncio.server.Server:
"""Create a server instance for a listener.""" """Create a server instance for a listener."""
if listener_type == "tcp": if listener_type == "tcp":
return await asyncio.start_server( return await asyncio.start_server(
@ -319,8 +317,6 @@ class Broker:
ssl=ssl_context, ssl=ssl_context,
subprotocols=[websockets.Subprotocol("mqtt")], subprotocols=[websockets.Subprotocol("mqtt")],
) )
if listener_type == "aiohttp":
return OtherServer()
msg = f"Unsupported listener type: {listener_type}" msg = f"Unsupported listener type: {listener_type}"
raise BrokerError(msg) raise BrokerError(msg)

Wyświetl plik

@ -75,10 +75,8 @@ async def decode_string(reader: ReaderAdapter | asyncio.StreamReader) -> str:
length_bytes = await read_or_raise(reader, 2) length_bytes = await read_or_raise(reader, 2)
if len(length_bytes) < 1: if len(length_bytes) < 1:
raise ZeroLengthReadError raise ZeroLengthReadError
print(f"attempting to decode {length_bytes}")
str_length = unpack("!H", length_bytes)[0] str_length = unpack("!H", length_bytes)[0]
if str_length: if str_length:
print(f"reading {str_length} bytes")
byte_str = await read_or_raise(reader, str_length) byte_str = await read_or_raise(reader, str_length)
try: try:
return byte_str.decode(encoding="utf-8") return byte_str.decode(encoding="utf-8")

Wyświetl plik

@ -44,7 +44,7 @@ class ListenerType(StrEnum):
TCP = "tcp" TCP = "tcp"
WS = "ws" WS = "ws"
AIOHTTP = "aiohttp" EXTERNAL = "external"
def __repr__(self) -> str: def __repr__(self) -> str:
"""Display the string value, instead of the enum member.""" """Display the string value, instead of the enum member."""
@ -115,6 +115,8 @@ class ListenerConfig(Dictable):
certificates needed to establish the certificate's authenticity.)""" certificates needed to establish the certificate's authenticity.)"""
keyfile: str | Path | None = None keyfile: str | Path | None = None
"""Full path to file in PEM format containing the server's private key.""" """Full path to file in PEM format containing the server's private key."""
reader: str | None = None
writer: str | None = None
def __post_init__(self) -> None: def __post_init__(self) -> None:
"""Check config for errors and transform fields for easier use.""" """Check config for errors and transform fields for easier use."""
@ -126,6 +128,10 @@ class ListenerConfig(Dictable):
if isinstance(getattr(self, fn), str): if isinstance(getattr(self, fn), str):
setattr(self, fn, Path(getattr(self, fn))) setattr(self, fn, Path(getattr(self, fn)))
# if self.type == ListenerType.EXTERNAL and not all([self.reader, self.writer]):
# msg = "external type requires specifying reader, writer and server classes"
# raise ValueError(msg)
def apply(self, other: "ListenerConfig") -> None: def apply(self, other: "ListenerConfig") -> None:
"""Apply the field from 'other', if 'self' field is default.""" """Apply the field from 'other', if 'self' field is default."""
for f in fields(self): for f in fields(self):

Wyświetl plik

@ -14,15 +14,12 @@ from amqtt.errors import BrokerError
async def hello(request): async def hello(request):
return web.Response(text="Hello, world") return web.Response(text="Hello, world")
class AIOWebSocketsReader(ReaderAdapter): class AIOWebSocketsReader(ReaderAdapter):
def __init__(self, ws: web.WebSocketResponse): def __init__(self, ws: web.WebSocketResponse):
self.ws = ws self.ws = ws
self.buffer = bytearray() self.buffer = bytearray()
async def read(self, n: int = -1) -> bytes: async def read(self, n: int = -1) -> bytes:
print(f"attempting to read {n} bytes")
while not self.buffer or len(self.buffer) < n: while not self.buffer or len(self.buffer) < n:
if self.ws.closed: if self.ws.closed:
raise BrokenPipeError() raise BrokenPipeError()
@ -30,16 +27,15 @@ class AIOWebSocketsReader(ReaderAdapter):
if msg.type == aiohttp.WSMsgType.BINARY: if msg.type == aiohttp.WSMsgType.BINARY:
self.buffer.extend(msg.data) self.buffer.extend(msg.data)
elif msg.type == aiohttp.WSMsgType.CLOSE: elif msg.type == aiohttp.WSMsgType.CLOSE:
print("received a close message!") raise BrokenPipeError()
break
print(f"buffer size: {len(self.buffer)}")
if n == -1: if n == -1:
result = bytes(self.buffer) result = bytes(self.buffer)
self.buffer.clear() self.buffer.clear()
else: else:
result = self.buffer[:n] result = self.buffer[:n]
del self.buffer[:n] del self.buffer[:n]
print(f"bytes: {result}")
return result return result
def feed_eof(self) -> None: def feed_eof(self) -> None:
@ -53,7 +49,6 @@ class AIOWebSocketsWriter(WriterAdapter):
self._stream = io.BytesIO(b"") self._stream = io.BytesIO(b"")
def write(self, data: bytes) -> None: def write(self, data: bytes) -> None:
print(f"broker wants to write data: {data}")
self._stream.write(data) self._stream.write(data)
async def drain(self) -> None: async def drain(self) -> None:
@ -63,7 +58,7 @@ class AIOWebSocketsWriter(WriterAdapter):
self._stream = io.BytesIO(b"") self._stream = io.BytesIO(b"")
def get_peer_info(self) -> tuple[str, int] | None: def get_peer_info(self) -> tuple[str, int] | None:
return "aiohttp", 1234567 return "external", 0
async def close(self) -> None: async def close(self) -> None:
pass pass
@ -73,39 +68,9 @@ async def websocket_handler(request):
print() print()
ws = web.WebSocketResponse(protocols=['mqtt',]) ws = web.WebSocketResponse(protocols=['mqtt',])
await ws.prepare(request) await ws.prepare(request)
#
# readQ = asyncio.Queue()
# writeQ = asyncio.Queue()
#
# async def receiver():
# async for msg in ws:
# match msg.type:
# case aiohttp.WSMsgType.BINARY:
# readQ.put_nowait(msg.data)
# case _:
# return
#
# async def send_items():
# while not ws.closed:
# if not writeQ.empty():
# item = await writeQ.get()
# await ws.send_bytes(item)
#
# await asyncio.create_task(send_items())
#
b: Broker = request.app['broker'] b: Broker = request.app['broker']
await b._client_connected('aiohttp', AIOWebSocketsReader(ws), AIOWebSocketsWriter(ws)) await b._client_connected('aiohttp', AIOWebSocketsReader(ws), AIOWebSocketsWriter(ws))
# async for msg in ws:
# print(f"ws: {msg}")
#
# elif msg.type == aiohttp.WSMsgType.ERROR:
# print('ws connection closed with exception %s' %
# ws.exception())
print('websocket connection closed') print('websocket connection closed')
return ws return ws
@ -129,7 +94,7 @@ async def run_broker(_app):
cfg = BrokerConfig( cfg = BrokerConfig(
listeners={ listeners={
'default':ListenerConfig(type=ListenerType.WS, bind='127.0.0.1:8883'), 'default':ListenerConfig(type=ListenerType.WS, bind='127.0.0.1:8883'),
'aiohttp': ListenerConfig(type=ListenerType.AIOHTTP), 'aiohttp': ListenerConfig(type=ListenerType.EXTERNAL),
} }
) )