pull/283/head
Andrew Mirsky 2025-08-04 17:47:29 -04:00
rodzic ad4854df00
commit c1625e264e
Nie znaleziono w bazie danych klucza dla tego podpisu
ID klucza GPG: A98E67635CDF2C39
4 zmienionych plików z 27 dodań i 62 usunięć

Wyświetl plik

@ -10,9 +10,6 @@ import ssl
import time
from typing import Any, ClassVar, TypeAlias
from aiohttp.pytest_plugin import AiohttpServer
from aiohttp.web_ws import WebSocketResponse
from pygments.token import Other
from transitions import Machine, MachineError
import websockets.asyncio.server
from websockets.asyncio.server import ServerConnection
@ -25,7 +22,7 @@ from amqtt.adapters import (
WebSocketsWriter,
WriterAdapter,
)
from amqtt.contexts import Action, BaseContext, BrokerConfig, ListenerConfig
from amqtt.contexts import Action, BaseContext, BrokerConfig, ListenerConfig, ListenerType
from amqtt.errors import AMQTTError, BrokerError, MQTTError, NoDataError
from amqtt.mqtt.protocol.broker_handler import BrokerProtocolHandler
from amqtt.session import ApplicationMessage, OutgoingApplicationMessage, Session
@ -39,7 +36,7 @@ from .plugins.manager import PluginManager
_BROADCAST: TypeAlias = dict[str, Session | str | bytes | bytearray | int | None]
# Default port numbers
DEFAULT_PORTS = {"tcp": 1883, "ws": 8883, 'aiohttp': 8080}
DEFAULT_PORTS = {"tcp": 1883, "ws": 8883}
AMQTT_MAGIC_VALUE_RET_SUBSCRIBED = 0x80
@ -92,7 +89,7 @@ class Server:
await self.instance.wait_closed()
class OtherServer(Server):
class ExternalServer(Server):
def __init__(self):
super().__init__('aiohttp', None)
@ -260,19 +257,20 @@ class Broker:
max_connections = listener.get("max_connections", -1)
ssl_context = self._create_ssl_context(listener) if listener.get("ssl", False) else None
try:
address, port = self._split_bindaddr_port(listener["bind"], DEFAULT_PORTS[listener["type"]])
except ValueError as e:
msg = f"Invalid port value in bind value: {listener['bind']}"
raise BrokerError(msg) from e
if listener["type"] == 'aiohttp':
self._servers[listener_name] = OtherServer()
if listener.type == ListenerType.EXTERNAL:
self.logger.info(f"External listener exists for '{listener_name}' ")
self._servers[listener_name] = ExternalServer()
else:
try:
address, port = self._split_bindaddr_port(listener["bind"], DEFAULT_PORTS[listener["type"]])
except ValueError as e:
msg = f"Invalid port value in bind value: {listener['bind']}"
raise BrokerError(msg) from e
instance = await self._create_server_instance(listener_name, listener["type"], address, port, ssl_context)
self._servers[listener_name] = Server(listener_name, instance, max_connections)
self.logger.info(f"Listener '{listener_name}' bind to {listener['bind']} (max_connections={max_connections})")
self.logger.info(f"Listener '{listener_name}' bind to {listener['bind']} (max_connections={max_connections})")
@staticmethod
def _create_ssl_context(listener: ListenerConfig) -> ssl.SSLContext:
@ -301,7 +299,7 @@ class Broker:
address: str | None,
port: int,
ssl_context: ssl.SSLContext | None,
) -> asyncio.Server | websockets.asyncio.server.Server | OtherServer:
) -> asyncio.Server | websockets.asyncio.server.Server:
"""Create a server instance for a listener."""
if listener_type == "tcp":
return await asyncio.start_server(
@ -319,8 +317,6 @@ class Broker:
ssl=ssl_context,
subprotocols=[websockets.Subprotocol("mqtt")],
)
if listener_type == "aiohttp":
return OtherServer()
msg = f"Unsupported listener type: {listener_type}"
raise BrokerError(msg)

Wyświetl plik

@ -75,10 +75,8 @@ async def decode_string(reader: ReaderAdapter | asyncio.StreamReader) -> str:
length_bytes = await read_or_raise(reader, 2)
if len(length_bytes) < 1:
raise ZeroLengthReadError
print(f"attempting to decode {length_bytes}")
str_length = unpack("!H", length_bytes)[0]
if str_length:
print(f"reading {str_length} bytes")
byte_str = await read_or_raise(reader, str_length)
try:
return byte_str.decode(encoding="utf-8")

Wyświetl plik

@ -44,7 +44,7 @@ class ListenerType(StrEnum):
TCP = "tcp"
WS = "ws"
AIOHTTP = "aiohttp"
EXTERNAL = "external"
def __repr__(self) -> str:
"""Display the string value, instead of the enum member."""
@ -115,6 +115,8 @@ class ListenerConfig(Dictable):
certificates needed to establish the certificate's authenticity.)"""
keyfile: str | Path | None = None
"""Full path to file in PEM format containing the server's private key."""
reader: str | None = None
writer: str | None = None
def __post_init__(self) -> None:
"""Check config for errors and transform fields for easier use."""
@ -126,6 +128,10 @@ class ListenerConfig(Dictable):
if isinstance(getattr(self, fn), str):
setattr(self, fn, Path(getattr(self, fn)))
# if self.type == ListenerType.EXTERNAL and not all([self.reader, self.writer]):
# msg = "external type requires specifying reader, writer and server classes"
# raise ValueError(msg)
def apply(self, other: "ListenerConfig") -> None:
"""Apply the field from 'other', if 'self' field is default."""
for f in fields(self):

Wyświetl plik

@ -14,15 +14,12 @@ from amqtt.errors import BrokerError
async def hello(request):
return web.Response(text="Hello, world")
class AIOWebSocketsReader(ReaderAdapter):
def __init__(self, ws: web.WebSocketResponse):
self.ws = ws
self.buffer = bytearray()
async def read(self, n: int = -1) -> bytes:
print(f"attempting to read {n} bytes")
while not self.buffer or len(self.buffer) < n:
if self.ws.closed:
raise BrokenPipeError()
@ -30,16 +27,15 @@ class AIOWebSocketsReader(ReaderAdapter):
if msg.type == aiohttp.WSMsgType.BINARY:
self.buffer.extend(msg.data)
elif msg.type == aiohttp.WSMsgType.CLOSE:
print("received a close message!")
break
print(f"buffer size: {len(self.buffer)}")
raise BrokenPipeError()
if n == -1:
result = bytes(self.buffer)
self.buffer.clear()
else:
result = self.buffer[:n]
del self.buffer[:n]
print(f"bytes: {result}")
return result
def feed_eof(self) -> None:
@ -53,7 +49,6 @@ class AIOWebSocketsWriter(WriterAdapter):
self._stream = io.BytesIO(b"")
def write(self, data: bytes) -> None:
print(f"broker wants to write data: {data}")
self._stream.write(data)
async def drain(self) -> None:
@ -63,7 +58,7 @@ class AIOWebSocketsWriter(WriterAdapter):
self._stream = io.BytesIO(b"")
def get_peer_info(self) -> tuple[str, int] | None:
return "aiohttp", 1234567
return "external", 0
async def close(self) -> None:
pass
@ -73,39 +68,9 @@ async def websocket_handler(request):
print()
ws = web.WebSocketResponse(protocols=['mqtt',])
await ws.prepare(request)
#
# readQ = asyncio.Queue()
# writeQ = asyncio.Queue()
#
# async def receiver():
# async for msg in ws:
# match msg.type:
# case aiohttp.WSMsgType.BINARY:
# readQ.put_nowait(msg.data)
# case _:
# return
#
# async def send_items():
# while not ws.closed:
# if not writeQ.empty():
# item = await writeQ.get()
# await ws.send_bytes(item)
#
# await asyncio.create_task(send_items())
#
b: Broker = request.app['broker']
await b._client_connected('aiohttp', AIOWebSocketsReader(ws), AIOWebSocketsWriter(ws))
# async for msg in ws:
# print(f"ws: {msg}")
#
# elif msg.type == aiohttp.WSMsgType.ERROR:
# print('ws connection closed with exception %s' %
# ws.exception())
print('websocket connection closed')
return ws
@ -129,7 +94,7 @@ async def run_broker(_app):
cfg = BrokerConfig(
listeners={
'default':ListenerConfig(type=ListenerType.WS, bind='127.0.0.1:8883'),
'aiohttp': ListenerConfig(type=ListenerType.AIOHTTP),
'aiohttp': ListenerConfig(type=ListenerType.EXTERNAL),
}
)