Merge pull request #276 from Yakifo/0.11.3-rc.1

Release 0.11.3
main v0.11.3
Andrew Mirsky 2025-08-12 09:39:47 -04:00 zatwierdzone przez GitHub
commit 2637127b41
Nie znaleziono w bazie danych klucza dla tego podpisu
ID klucza GPG: B5690EEEBB952194
152 zmienionych plików z 9821 dodań i 2941 usunięć

Wyświetl plik

@ -4,6 +4,7 @@ source = bumper
omit =
tests/*
amqtt/scripts/*.py
[report]
exclude_lines =

Wyświetl plik

@ -32,8 +32,11 @@ jobs:
cache-local-path: ${{ env.UV_CACHE_DIR }}
python-version: "3.13"
- name: install openldap dependencies
run: sudo apt-get install -y libldap2-dev libsasl2-dev
- name: 🏗 Install the project
run: uv sync --locked --dev
run: uv sync --locked --dev --all-extras
- name: Run mypy
run: uv run --frozen mypy ${{ env.PROJECT_PATH }}/
@ -68,8 +71,11 @@ jobs:
cache-local-path: ${{ env.UV_CACHE_DIR }}
python-version: ${{ matrix.python-version }}
- name: install openldap dependencies
run: sudo apt-get install -y libldap2-dev libsasl2-dev
- name: 🏗 Install the project
run: uv sync --locked --dev
run: uv sync --locked --dev --all-extras
- name: Run pytest
run: uv run --frozen pytest tests/ --cov=./ --cov-report=xml --junitxml=pytest-report.xml

3
.gitignore vendored
Wyświetl plik

@ -4,6 +4,9 @@ __pycache__
node_modules
.vite
*.pem
*.crt
*.key
*.patch
#------- Environment Files -------
.python-version

Wyświetl plik

@ -1,15 +1,22 @@
version: 2
version: 2
build:
os: "ubuntu-24.04"
tools:
python: "3.13"
jobs:
pre_install:
- pip install --upgrade pip
- pip install uv
- uv pip install --group dev --group docs
- uv run pytest
build:
os: "ubuntu-24.04"
tools:
python: "3.13"
apt_packages:
- libldap2-dev
- libsasl2-dev
jobs:
pre_install:
- pip install --upgrade pip
- pip install uv
- uv venv
- uv pip install --group dev --group docs ".[contrib]"
- uv run pytest --mock-docker=true
build:
html:
- uv run python -m mkdocs build --clean --site-dir $READTHEDOCS_OUTPUT/html --config-file mkdocs.rtd.yml
mkdocs:
configuration: mkdocs.rtd.yml
mkdocs:
configuration: mkdocs.rtd.yml

Wyświetl plik

@ -1,7 +1,7 @@
# Image name and tag
IMAGE_NAME := amqtt
IMAGE_TAG := latest
VERSION_TAG := 0.11.2
VERSION_TAG := 0.11.3
REGISTRY := amqtt/$(IMAGE_NAME)
# Platforms to build for

Wyświetl plik

@ -14,10 +14,21 @@
## Features
- Full set of [MQTT 3.1.1](http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html) protocol specifications
- Communication over TCP and/or websocket, including support for SSL/TLS
- Communication over multiple TCP and/or websocket ports, including support for SSL/TLS
- Support QoS 0, QoS 1 and QoS 2 messages flow
- Client auto-reconnection on network lost
- Functionality expansion; plugins included: authentication and `$SYS` topic publishing
- Plugin framework for functionality expansion; included plugins:
- `$SYS` topic publishing
- AWS IOT-style shadow states
- x509 certificate authentication (including cli cert creation)
- Secure file-based password authentication
- Configuration-based topic authorization
- MySQL, Postgres & SQLite user and/or topic auth (including cli manager)
- External server (HTTP) user and/or topic auth
- LDAP user and/or topic auth
- JWT user and/or topic auth
- Fail over session persistence
## Installation

Wyświetl plik

@ -1,3 +1,3 @@
"""INIT."""
__version__ = "0.11.2"
__version__ = "0.11.3"

Wyświetl plik

@ -3,6 +3,8 @@ from asyncio import StreamReader, StreamWriter
from contextlib import suppress
import io
import logging
import ssl
from typing import cast
from websockets import ConnectionClosed
from websockets.asyncio.connection import Connection
@ -52,6 +54,11 @@ class WriterAdapter(ABC):
"""Return peer socket info (remote address and remote port as tuple)."""
raise NotImplementedError
@abstractmethod
def get_ssl_info(self) -> ssl.SSLObject | None:
"""Return peer certificate information (if available) used to establish a TLS session."""
raise NotImplementedError
@abstractmethod
async def close(self) -> None:
"""Close the protocol connection."""
@ -121,6 +128,9 @@ class WebSocketsWriter(WriterAdapter):
remote_address: tuple[str, int] | None = self._protocol.remote_address[:2]
return remote_address
def get_ssl_info(self) -> ssl.SSLObject | None:
return cast("ssl.SSLObject", self._protocol.transport.get_extra_info("ssl_object"))
async def close(self) -> None:
await self._protocol.close()
@ -170,6 +180,9 @@ class StreamWriterAdapter(WriterAdapter):
extra_info = self._writer.get_extra_info("peername")
return extra_info[0], extra_info[1]
def get_ssl_info(self) -> ssl.SSLObject | None:
return cast("ssl.SSLObject", self._writer.get_extra_info("ssl_object"))
async def close(self) -> None:
if not self.is_closed:
self.is_closed = True # we first mark this closed so yields below don't cause races with waiting writes
@ -204,6 +217,9 @@ class BufferWriter(WriterAdapter):
This adapter simply adapts writing to a byte buffer.
"""
def get_ssl_info(self) -> ssl.SSLObject | None:
return None
def __init__(self, buffer: bytes = b"") -> None:
self._stream = io.BytesIO(buffer)

Wyświetl plik

@ -2,12 +2,12 @@ import asyncio
from asyncio import CancelledError, futures
from collections import deque
from collections.abc import Generator
import copy
from functools import partial
import logging
from pathlib import Path
from math import floor
import re
import ssl
import time
from typing import Any, ClassVar, TypeAlias
from transitions import Machine, MachineError
@ -22,24 +22,19 @@ from amqtt.adapters import (
WebSocketsWriter,
WriterAdapter,
)
from amqtt.contexts import Action, BaseContext
from amqtt.contexts import Action, BaseContext, BrokerConfig, ListenerConfig, ListenerType
from amqtt.errors import AMQTTError, BrokerError, MQTTError, NoDataError
from amqtt.mqtt.protocol.broker_handler import BrokerProtocolHandler
from amqtt.session import ApplicationMessage, OutgoingApplicationMessage, Session
from amqtt.utils import format_client_message, gen_client_id, read_yaml_config
from amqtt.utils import format_client_message, gen_client_id
from .events import BrokerEvents
from .mqtt.constants import QOS_0, QOS_1, QOS_2
from .mqtt.disconnect import DisconnectPacket
from .plugins.manager import PluginManager
_CONFIG_LISTENER: TypeAlias = dict[str, int | bool | dict[str, Any]]
_BROADCAST: TypeAlias = dict[str, Session | str | bytes | bytearray | int | None]
_defaults = read_yaml_config(Path(__file__).parent / "scripts/default_broker.yaml")
# Default port numbers
DEFAULT_PORTS = {"tcp": 1883, "ws": 8883}
AMQTT_MAGIC_VALUE_RET_SUBSCRIBED = 0x80
@ -57,6 +52,8 @@ class RetainedApplicationMessage(ApplicationMessage):
class Server:
"""Used to encapsulate the server associated with a listener. Allows broker to interact with the connection lifecycle."""
def __init__(
self,
listener_name: str,
@ -94,28 +91,46 @@ class Server:
await self.instance.wait_closed()
class BrokerContext(BaseContext):
"""BrokerContext is used as the context passed to plugins interacting with the broker.
class ExternalServer(Server):
"""For external listeners, the connection lifecycle is handled by that implementation so these are no-ops."""
It act as an adapter to broker services from plugins developed for HBMQTT broker.
"""
def __init__(self) -> None:
super().__init__("aiohttp", None) # type: ignore[arg-type]
async def acquire_connection(self) -> None:
pass
def release_connection(self) -> None:
pass
async def close_instance(self) -> None:
pass
class BrokerContext(BaseContext):
"""Used to provide the server's context as well as public methods for accessing internal state."""
def __init__(self, broker: "Broker") -> None:
super().__init__()
self.config: _CONFIG_LISTENER | None = None
self.config: BrokerConfig | None = None
self._broker_instance = broker
async def broadcast_message(self, topic: str, data: bytes, qos: int | None = None) -> None:
"""Send message to all client sessions subscribing to `topic`."""
await self._broker_instance.internal_message_broadcast(topic, data, qos)
def retain_message(self, topic_name: str, data: bytes | bytearray, qos: int | None = None) -> None:
self._broker_instance.retain_message(None, topic_name, data, qos)
async def retain_message(self, topic_name: str, data: bytes | bytearray, qos: int | None = None) -> None:
await self._broker_instance.retain_message(None, topic_name, data, qos)
@property
def sessions(self) -> Generator[Session]:
for session in self._broker_instance.sessions.values():
yield session[0]
def get_session(self, client_id: str) -> Session | None:
"""Return the session associated with `client_id`, if it exists."""
return self._broker_instance.sessions.get(client_id, (None, None))[0]
@property
def retained_messages(self) -> dict[str, RetainedApplicationMessage]:
return self._broker_instance.retained_messages
@ -124,17 +139,33 @@ class BrokerContext(BaseContext):
def subscriptions(self) -> dict[str, list[tuple[Session, int]]]:
return self._broker_instance.subscriptions
async def add_subscription(self, client_id: str, topic: str | None, qos: int | None) -> None:
"""Create a topic subscription for the given `client_id`.
If a client session doesn't exist for `client_id`, create a disconnected session.
If `topic` and `qos` are both `None`, only create the client session.
"""
if client_id not in self._broker_instance.sessions:
broker_handler, session = self._broker_instance.create_offline_session(client_id)
self._broker_instance._sessions[client_id] = (session, broker_handler) # noqa: SLF001
if topic is not None and qos is not None:
session, _ = self._broker_instance.sessions[client_id]
await self._broker_instance.add_subscription((topic, qos), session)
class Broker:
"""MQTT 3.1.1 compliant broker implementation.
Args:
config: dictionary of configuration options (see [broker configuration](broker_config.md)).
config: `BrokerConfig` or dictionary of equivalent structure options (see [broker configuration](broker_config.md)).
loop: asyncio loop. defaults to `asyncio.new_event_loop()`.
plugin_namespace: plugin namespace to use when loading plugin entry_points. defaults to `amqtt.broker.plugins`.
Raises:
BrokerError, ParserError, PluginError
BrokerError: problem with broker configuration
PluginImportError: if importing a plugin from configuration
PluginInitError: if initialization plugin fails
"""
@ -150,20 +181,20 @@ class Broker:
def __init__(
self,
config: _CONFIG_LISTENER | None = None,
config: BrokerConfig | dict[str, Any] | None = None,
loop: asyncio.AbstractEventLoop | None = None,
plugin_namespace: str | None = None,
) -> None:
"""Initialize the broker."""
self.logger = logging.getLogger(__name__)
self.config = copy.deepcopy(_defaults or {})
if config is not None:
# if 'plugins' isn't in the config but 'auth'/'topic-check' is included, assume this is a legacy config
if ("auth" in config or "topic-check" in config) and "plugins" not in config:
# set to None so that the config isn't updated with the new-style default plugin list
config["plugins"] = None # type: ignore[assignment]
self.config.update(config)
self._build_listeners_config(self.config)
if isinstance(config, dict):
self.config = BrokerConfig.from_dict(config)
else:
self.config = config or BrokerConfig()
# listeners are populated from default within BrokerConfig
self.listeners_config = self.config.listeners
self._loop = loop or asyncio.get_running_loop()
self._servers: dict[str, Server] = {}
@ -182,6 +213,9 @@ class Broker:
# Tasks queue for managing broadcasting tasks
self._tasks_queue: deque[asyncio.Task[OutgoingApplicationMessage]] = deque()
# Task for session monitor
self._session_monitor_task: asyncio.Task[Any] | None = None
# Initialize plugins manager
context = BrokerContext(self)
@ -189,26 +223,6 @@ class Broker:
namespace = plugin_namespace or "amqtt.broker.plugins"
self.plugins_manager = PluginManager(namespace, context, self._loop)
def _build_listeners_config(self, broker_config: _CONFIG_LISTENER) -> None:
self.listeners_config = {}
try:
listeners_config = broker_config.get("listeners")
if not isinstance(listeners_config, dict):
msg = "Listener config not found or invalid"
raise BrokerError(msg)
defaults = listeners_config.get("default")
if defaults is None:
msg = "Listener config has not default included or is invalid"
raise BrokerError(msg)
for listener_name, listener_conf in listeners_config.items():
config = defaults.copy()
config.update(listener_conf)
self.listeners_config[listener_name] = config
except KeyError as ke:
msg = f"Listener config not found or invalid: {ke}"
raise BrokerError(msg) from ke
def _init_states(self) -> None:
self.transitions = Machine(states=Broker.states, initial="new")
self.transitions.add_transition(trigger="start", source="new", dest="starting", before=self._log_state_change)
@ -245,6 +259,7 @@ class Broker:
self.transitions.starting_success()
await self.plugins_manager.fire_event(BrokerEvents.POST_START)
self._broadcast_task = asyncio.ensure_future(self._broadcast_loop())
self._session_monitor_task = asyncio.create_task(self._session_monitor())
self.logger.debug("Broker started")
except Exception as e:
self.logger.exception("Broker startup failed")
@ -262,18 +277,27 @@ class Broker:
max_connections = listener.get("max_connections", -1)
ssl_context = self._create_ssl_context(listener) if listener.get("ssl", False) else None
try:
address, port = self._split_bindaddr_port(listener["bind"], DEFAULT_PORTS[listener["type"]])
except ValueError as e:
msg = f"Invalid port value in bind value: {listener['bind']}"
raise BrokerError(msg) from e
# for listeners which are external, don't need to create a server
if listener.type == ListenerType.EXTERNAL:
instance = await self._create_server_instance(listener_name, listener["type"], address, port, ssl_context)
self._servers[listener_name] = Server(listener_name, instance, max_connections)
# broker still needs to associate a new connection to the listener
self.logger.info(f"External listener exists for '{listener_name}' ")
self._servers[listener_name] = ExternalServer()
else:
# for tcp and websockets, start servers to listen for inbound connections
try:
address, port = self._split_bindaddr_port(listener["bind"], DEFAULT_PORTS[listener["type"]])
except ValueError as e:
msg = f"Invalid port value in bind value: {listener['bind']}"
raise BrokerError(msg) from e
self.logger.info(f"Listener '{listener_name}' bind to {listener['bind']} (max_connections={max_connections})")
instance = await self._create_server_instance(listener_name, listener.type, address, port, ssl_context)
self._servers[listener_name] = Server(listener_name, instance, max_connections)
def _create_ssl_context(self, listener: dict[str, Any]) -> ssl.SSLContext:
self.logger.info(f"Listener '{listener_name}' bind to {listener['bind']} (max_connections={max_connections})")
@staticmethod
def _create_ssl_context(listener: ListenerConfig) -> ssl.SSLContext:
"""Create an SSL context for a listener."""
try:
ssl_context = ssl.create_default_context(
@ -295,30 +319,61 @@ class Broker:
async def _create_server_instance(
self,
listener_name: str,
listener_type: str,
listener_type: ListenerType,
address: str | None,
port: int,
ssl_context: ssl.SSLContext | None,
) -> asyncio.Server | websockets.asyncio.server.Server:
"""Create a server instance for a listener."""
if listener_type == "tcp":
return await asyncio.start_server(
partial(self.stream_connected, listener_name=listener_name),
address,
port,
reuse_address=True,
ssl=ssl_context,
)
if listener_type == "ws":
return await websockets.serve(
partial(self.ws_connected, listener_name=listener_name),
address,
port,
ssl=ssl_context,
subprotocols=[websockets.Subprotocol("mqtt")],
)
msg = f"Unsupported listener type: {listener_type}"
raise BrokerError(msg)
match listener_type:
case ListenerType.TCP:
return await asyncio.start_server(
partial(self.stream_connected, listener_name=listener_name),
address,
port,
reuse_address=True,
ssl=ssl_context,
)
case ListenerType.WS:
return await websockets.serve(
partial(self.ws_connected, listener_name=listener_name),
address,
port,
ssl=ssl_context,
subprotocols=[websockets.Subprotocol("mqtt")],
)
case _:
msg = f"Unsupported listener type: {listener_type}"
raise BrokerError(msg)
async def _session_monitor(self) -> None:
self.logger.info("Starting session expiration monitor.")
while True:
session_count_before = len(self._sessions)
# clean or anonymous sessions don't retain messages (or subscriptions); the session can be filtered out
sessions_to_remove = [client_id for client_id, (session, _) in self._sessions.items()
if session.transitions.state == "disconnected" and (session.is_anonymous or session.clean_session)]
# if session expiration is enabled, check to see if any of the sessions are disconnected and past expiration
if self.config.session_expiry_interval is not None:
retain_after = floor(time.time() - self.config.session_expiry_interval)
sessions_to_remove += [client_id for client_id, (session, _) in self._sessions.items()
if session.transitions.state == "disconnected" and
session.last_disconnect_time and
session.last_disconnect_time < retain_after]
for client_id in sessions_to_remove:
await self._cleanup_session(client_id)
if session_count_before > (session_count_after := len(self._sessions)):
self.logger.debug(f"Expired {session_count_before - session_count_after} sessions")
await asyncio.sleep(1)
async def shutdown(self) -> None:
"""Stop broker instance."""
@ -337,6 +392,8 @@ class Broker:
self.transitions.shutdown()
await self._shutdown_broadcast_loop()
if self._session_monitor_task:
self._session_monitor_task.cancel()
for server in self._servers.values():
await server.close_instance()
@ -372,6 +429,10 @@ class Broker:
async def stream_connected(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter, listener_name: str) -> None:
await self._client_connected(listener_name, StreamReaderAdapter(reader), StreamWriterAdapter(writer))
async def external_connected(self, reader: ReaderAdapter, writer: WriterAdapter, listener_name: str) -> None:
"""Engage the broker in handling the data stream to/from an established connection."""
await self._client_connected(listener_name, reader, writer)
async def _client_connected(self, listener_name: str, reader: ReaderAdapter, writer: WriterAdapter) -> None:
"""Handle a new client connection."""
server = self._servers.get(listener_name)
@ -443,7 +504,17 @@ class Broker:
# Get session from cache
elif client_session.client_id in self._sessions:
self.logger.debug(f"Found old session {self._sessions[client_session.client_id]!r}")
client_session, _ = self._sessions[client_session.client_id]
# even though the session previously existed, the new connection can bring updated configuration and credentials
existing_client_session, _ = self._sessions[client_session.client_id]
existing_client_session.will_flag = client_session.will_flag
existing_client_session.will_message = client_session.will_message
existing_client_session.will_topic = client_session.will_topic
existing_client_session.will_qos = client_session.will_qos
existing_client_session.keep_alive = client_session.keep_alive
existing_client_session.username = client_session.username
existing_client_session.password = client_session.password
client_session = existing_client_session
client_session.parent = 1
else:
client_session.parent = 0
@ -455,6 +526,14 @@ class Broker:
self.logger.debug(f"Keep-alive timeout={client_session.keep_alive}")
return handler, client_session
def create_offline_session(self, client_id: str) -> tuple[BrokerProtocolHandler, Session]:
session = Session()
session.client_id = client_id
bph = BrokerProtocolHandler(self.plugins_manager, session)
session.transitions.disconnect()
return bph, session
async def _handle_client_session(
self,
reader: ReaderAdapter,
@ -506,9 +585,7 @@ class Broker:
# if this is not a new session, there are subscriptions associated with them; publish any topic retained messages
self.logger.debug("Publish retained messages to a pre-existing session's subscriptions.")
for topic in self._subscriptions:
await self._publish_retained_messages_for_subscription( (topic, QOS_0), client_session)
await self._publish_retained_messages_for_subscription((topic, QOS_0), client_session)
await self._client_message_loop(client_session, handler)
@ -540,7 +617,6 @@ class Broker:
# no need to reschedule the `disconnect_waiter` since we're exiting the message loop
if subscribe_waiter in done:
await self._handle_subscription(client_session, handler, subscribe_waiter)
subscribe_waiter = asyncio.ensure_future(handler.get_next_pending_subscription())
@ -595,7 +671,7 @@ class Broker:
client_session.will_qos,
)
if client_session.will_retain:
self.retain_message(
await self.retain_message(
client_session,
client_session.will_topic,
client_session.will_message,
@ -610,7 +686,6 @@ class Broker:
client_id=client_session.client_id,
client_session=client_session)
async def _handle_subscription(
self,
client_session: Session,
@ -620,7 +695,7 @@ class Broker:
"""Handle client subscription."""
self.logger.debug(f"{client_session.client_id} handling subscription")
subscriptions = subscribe_waiter.result()
return_codes = [await self._add_subscription(subscription, client_session) for subscription in subscriptions.topics]
return_codes = [await self.add_subscription(subscription, client_session) for subscription in subscriptions.topics]
await handler.mqtt_acknowledge_subscription(subscriptions.packet_id, return_codes)
for index, subscription in enumerate(subscriptions.topics):
if return_codes[index] != AMQTT_MAGIC_VALUE_RET_SUBSCRIBED:
@ -660,6 +735,13 @@ class Broker:
self.logger.debug(f"{client_session.client_id} handling message delivery")
app_message = wait_deliver.result()
# notify of a message's receipt, even if a client isn't necessarily allowed to send it
await self.plugins_manager.fire_event(
BrokerEvents.MESSAGE_RECEIVED,
client_id=client_session.client_id,
message=app_message,
)
if app_message is None:
self.logger.debug("app_message was empty!")
return True
@ -681,16 +763,17 @@ class Broker:
permitted = await self._topic_filtering(client_session, topic=app_message.topic, action=Action.PUBLISH)
if not permitted:
self.logger.info(f"{client_session.client_id} forbidden TOPIC {app_message.topic} sent in PUBLISH message.")
self.logger.info(f"{client_session.client_id} not allowed to publish to TOPIC {app_message.topic}.")
else:
# notify that a received message is valid and is allowed to be distributed to other clients
await self.plugins_manager.fire_event(
BrokerEvents.MESSAGE_RECEIVED,
BrokerEvents.MESSAGE_BROADCAST,
client_id=client_session.client_id,
message=app_message,
)
await self._broadcast_message(client_session, app_message.topic, app_message.data)
if app_message.publish_packet and app_message.publish_packet.retain_flag:
self.retain_message(client_session, app_message.topic, app_message.data, app_message.qos)
await self.retain_message(client_session, app_message.topic, app_message.data, app_message.qos)
return True
async def _init_handler(self, session: Session, reader: ReaderAdapter, writer: WriterAdapter) -> BrokerProtocolHandler:
@ -703,10 +786,11 @@ class Broker:
"""Stop a running handler and detach if from the session."""
try:
await handler.stop()
except Exception:
# a failure in stopping a handler shouldn't cause the broker to fail
except asyncio.QueueEmpty:
self.logger.exception("Failed to stop handler")
async def _authenticate(self, session: Session, _: dict[str, Any]) -> bool:
async def _authenticate(self, session: Session, _: ListenerConfig) -> bool:
"""Call the authenticate method on registered plugins to test user authentication.
User is considered authenticated if all plugins called returns True.
@ -719,7 +803,7 @@ class Broker:
"""
returns = await self.plugins_manager.map_plugin_auth(session=session)
results = [ result for _, result in returns.items() if result is not None] if returns else []
results = [result for _, result in returns.items() if result is not None] if returns else []
if len(results) < 1:
self.logger.debug("Authentication failed: no plugin responded with a boolean")
return False
@ -733,7 +817,7 @@ class Broker:
return False
def retain_message(
async def retain_message(
self,
source_session: Session | None,
topic_name: str | None,
@ -744,12 +828,25 @@ class Broker:
# If retained flag set, store the message for further subscriptions
self.logger.debug(f"Retaining message on topic {topic_name}")
self._retained_messages[topic_name] = RetainedApplicationMessage(source_session, topic_name, data, qos)
await self.plugins_manager.fire_event(BrokerEvents.RETAINED_MESSAGE,
client_id=None,
retained_message=self._retained_messages[topic_name])
# [MQTT-3.3.1-10]
elif topic_name in self._retained_messages:
self.logger.debug(f"Clearing retained messages for topic '{topic_name}'")
cleared_message = self._retained_messages[topic_name]
cleared_message.data = b""
await self.plugins_manager.fire_event(BrokerEvents.RETAINED_MESSAGE,
client_id=None,
retained_message=cleared_message)
del self._retained_messages[topic_name]
async def _add_subscription(self, subscription: tuple[str, int], session: Session) -> int:
async def add_subscription(self, subscription: tuple[str, int], session: Session) -> int:
topic_filter, qos = subscription
if "#" in topic_filter and not topic_filter.endswith("#"):
# [MQTT-4.7.1-2] Wildcard character '#' is only allowed as last character in filter
@ -853,7 +950,8 @@ class Broker:
task.result()
except CancelledError:
self.logger.info(f"Task has been cancelled: {task}")
except Exception:
# if a task fails, don't want it to cause the broker to fail
except Exception: # pylint: disable=W0718
self.logger.exception(f"Task failed and will be skipped: {task}")
run_broadcast_task = asyncio.ensure_future(self._run_broadcast(running_tasks))
@ -892,6 +990,12 @@ class Broker:
for target_session, sub_qos in subscriptions:
qos = broadcast.get("qos", sub_qos)
sendable = await self._topic_filtering(target_session, topic=broadcast["topic"], action=Action.RECEIVE)
if not sendable:
self.logger.info(
f"{target_session.client_id} not allowed to receive messages from TOPIC {broadcast['topic']}.")
continue
# Retain all messages which cannot be broadcasted, due to the session not being connected
# but only when clean session is false and qos is 1 or 2 [MQTT 3.1.2.4]
# and, if a client used anonymous authentication, there is no expectation that messages should be retained
@ -934,6 +1038,10 @@ class Broker:
retained_message = RetainedApplicationMessage(broadcast["session"], broadcast["topic"], broadcast["data"], qos)
await target_session.retained_messages.put(retained_message)
await self.plugins_manager.fire_event(BrokerEvents.RETAINED_MESSAGE,
client_id=target_session.client_id,
retained_message=retained_message)
if self.logger.isEnabledFor(logging.DEBUG):
self.logger.debug(f"target_session.retained_messages={target_session.retained_messages.qsize()}")

Wyświetl plik

@ -2,10 +2,8 @@ import asyncio
from collections import deque
from collections.abc import Callable, Coroutine
import contextlib
import copy
from functools import wraps
import logging
from pathlib import Path
import ssl
from typing import TYPE_CHECKING, Any, TypeAlias, cast
from urllib.parse import urlparse, urlunparse
@ -19,20 +17,18 @@ from amqtt.adapters import (
WebSocketsReader,
WebSocketsWriter,
)
from amqtt.contexts import BaseContext
from amqtt.contexts import BaseContext, ClientConfig
from amqtt.errors import ClientError, ConnectError, ProtocolHandlerError
from amqtt.mqtt.connack import CONNECTION_ACCEPTED
from amqtt.mqtt.constants import QOS_0, QOS_1, QOS_2
from amqtt.mqtt.protocol.client_handler import ClientProtocolHandler
from amqtt.plugins.manager import PluginManager
from amqtt.session import ApplicationMessage, OutgoingApplicationMessage, Session
from amqtt.utils import gen_client_id, read_yaml_config
from amqtt.utils import gen_client_id
if TYPE_CHECKING:
from websockets.asyncio.client import ClientConnection
_defaults: dict[str, Any] | None = read_yaml_config(Path(__file__).parent / "scripts/default_client.yaml")
class ClientContext(BaseContext):
"""ClientContext is used as the context passed to plugins interacting with the client.
@ -42,7 +38,7 @@ class ClientContext(BaseContext):
def __init__(self) -> None:
super().__init__()
self.config = None
self.config: ClientConfig | None = None
base_logger = logging.getLogger(__name__)
@ -79,26 +75,27 @@ def mqtt_connected(func: _F) -> _F:
class MQTTClient:
"""MQTT client implementation.
MQTTClient instances provides API for connecting to a broker and send/receive
messages using the MQTT protocol.
"""MQTT client implementation, providing an API for connecting to a broker and send/receive messages using the MQTT protocol.
Args:
client_id: MQTT client ID to use when connecting to the broker. If none,
it will be generated randomly by `amqtt.utils.gen_client_id`
config: dictionary of configuration options (see [client configuration](client_config.md)).
config: `ClientConfig` or dictionary of equivalent structure options (see [client configuration](client_config.md)).
Raises:
PluginError
PluginImportError: if importing a plugin from configuration fails
PluginInitError: if initialization plugin fails
"""
def __init__(self, client_id: str | None = None, config: dict[str, Any] | None = None) -> None:
def __init__(self, client_id: str | None = None, config: ClientConfig | dict[str, Any] | None = None) -> None:
self.logger = logging.getLogger(__name__)
self.config = copy.deepcopy(_defaults or {})
if config is not None:
self.config.update(config)
if isinstance(config, dict):
self.config = ClientConfig.from_dict(config)
else:
self.config = config or ClientConfig()
self.client_id = client_id if client_id is not None else gen_client_id()
self.session: Session | None = None
@ -146,7 +143,7 @@ class MQTTClient:
[CONNACK](http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc398718033)'s return code
Raises:
ClientError, ConnectError
ConnectError: could not connect to broker
"""
additional_headers = additional_headers if additional_headers is not None else {}
@ -159,7 +156,8 @@ class MQTTClient:
except asyncio.CancelledError as e:
msg = "Future or Task was cancelled"
raise ConnectError(msg) from e
except Exception as e:
# no matter the failure mode, still try to reconnect
except Exception as e: # pylint: disable=W0718
self.logger.warning(f"Connection failed: {e!r}")
if not self.config.get("auto_reconnect", False):
raise
@ -233,7 +231,8 @@ class MQTTClient:
except asyncio.CancelledError as e:
msg = "Future or Task was cancelled"
raise ConnectError(msg) from e
except Exception as e:
# no matter the failure mode, still try to reconnect
except Exception as e: # pylint: disable=W0718
self.logger.warning(f"Reconnection attempt failed: {e!r}")
self.logger.debug("", exc_info=True)
if 0 <= reconnect_retries < nb_attempt:
@ -381,6 +380,7 @@ class MQTTClient:
Raises:
asyncio.TimeoutError: if timeout occurs before a message is delivered
ClientError: if client is not connected
"""
if self._handler is None:
@ -424,14 +424,10 @@ class MQTTClient:
scheme = uri_attributes.scheme
secure = scheme in ("mqtts", "wss")
self.session.username = (
self.session.username
if self.session.username
else (str(uri_attributes.username) if uri_attributes.username else None)
self.session.username or (str(uri_attributes.username) if uri_attributes.username else None)
)
self.session.password = (
self.session.password
if self.session.password
else (str(uri_attributes.password) if uri_attributes.password else None)
self.session.password or (str(uri_attributes.password) if uri_attributes.password else None)
)
self.session.remote_address = str(uri_attributes.hostname) if uri_attributes.hostname else None
self.session.remote_port = uri_attributes.port
@ -462,15 +458,15 @@ class MQTTClient:
if secure:
sc = ssl.create_default_context(
ssl.Purpose.SERVER_AUTH,
cafile=self.session.cafile,
capath=self.session.capath,
cadata=self.session.cadata,
cafile=self.session.cafile
)
if "certfile" in self.config:
sc.load_verify_locations(cafile=self.config["certfile"])
if "check_hostname" in self.config and isinstance(self.config["check_hostname"], bool):
sc.check_hostname = self.config["check_hostname"]
if self.config.connection.certfile and self.config.connection.keyfile:
sc.load_cert_chain(certfile=self.config.connection.certfile, keyfile=self.config.connection.keyfile)
if self.config.connection.cafile:
sc.load_verify_locations(cafile=self.config.connection.cafile)
if self.config.check_hostname is not None:
sc.check_hostname = self.config.check_hostname
sc.verify_mode = ssl.CERT_REQUIRED
kwargs["ssl"] = sc
@ -525,7 +521,7 @@ class MQTTClient:
self._connected_state.set()
self.logger.debug(f"Connected to {self.session.remote_address}:{self.session.remote_port}")
except (InvalidURI, InvalidHandshake, ProtocolHandlerError, ConnectionError, OSError) as e:
except (InvalidURI, InvalidHandshake, ProtocolHandlerError, ConnectionError, OSError, asyncio.TimeoutError) as e:
self.logger.debug(f"Connection failed : {self.session.broker_uri} [{e!r}]")
self.session.transitions.disconnect()
raise ConnectError(e) from e
@ -581,10 +577,18 @@ class MQTTClient:
cadata: str | None = None,
) -> Session:
"""Initialize the MQTT session."""
broker_conf = self.config.get("broker", {}).copy()
broker_conf.update(
{k: v for k, v in {"uri": uri, "cafile": cafile, "capath": capath, "cadata": cadata}.items() if v is not None},
)
broker_conf = self.config.get("connection", {}).copy()
if uri is not None:
broker_conf.uri = uri
if cleansession is not None:
self.config.cleansession = cleansession
if cafile is not None:
broker_conf.cafile = cafile
if capath is not None:
broker_conf.capath = capath
if cadata is not None:
broker_conf.cadata = cadata
if not broker_conf.get("uri"):
msg = "Missing connection parameter 'uri'"
@ -593,15 +597,12 @@ class MQTTClient:
session = Session()
session.broker_uri = broker_conf["uri"]
session.client_id = self.client_id
session.cafile = broker_conf.get("cafile")
session.capath = broker_conf.get("capath")
session.cadata = broker_conf.get("cadata")
if cleansession is not None:
broker_conf["cleansession"] = cleansession # noop?
session.clean_session = cleansession
else:
session.clean_session = self.config.get("cleansession", True)
session.clean_session = self.config.get("cleansession", True)
session.keep_alive = self.config["keep_alive"] - self.config["ping_delay"]

Wyświetl plik

@ -142,8 +142,8 @@ def int_to_bytes_str(value: int) -> bytes:
return str(value).encode("utf-8")
def float_to_bytes_str(value: float, places:int=3) -> bytes:
def float_to_bytes_str(value: float, places: int = 3) -> bytes:
"""Convert an float value to a bytes array containing the numeric character."""
quant = Decimal(f"0.{''.join(['0' for i in range(places-1)])}1")
quant = Decimal(f"0.{''.join(['0' for i in range(places - 1)])}1")
rounded = Decimal(value).quantize(quant, rounding=ROUND_HALF_UP)
return str(rounded).encode("utf-8")

Wyświetl plik

@ -1,22 +1,379 @@
from enum import Enum
from dataclasses import dataclass, field, fields, replace
import logging
from typing import TYPE_CHECKING, Any
import warnings
_LOGGER = logging.getLogger(__name__)
try:
from enum import Enum, StrEnum
except ImportError:
# support for python 3.10
from enum import Enum
class StrEnum(str, Enum): # type: ignore[no-redef]
pass
from collections.abc import Iterator
from pathlib import Path
from typing import TYPE_CHECKING, Any, Literal
from dacite import Config as DaciteConfig, from_dict as dict_to_dataclass
from amqtt.mqtt.constants import QOS_0, QOS_2
if TYPE_CHECKING:
import asyncio
logger = logging.getLogger(__name__)
class BaseContext:
def __init__(self) -> None:
self.loop: asyncio.AbstractEventLoop | None = None
self.logger: logging.Logger = _LOGGER
self.config: dict[str, Any] | None = None
self.logger: logging.Logger = logging.getLogger(__name__)
# cleanup with a `Generic` type
self.config: ClientConfig | BrokerConfig | dict[str, Any] | None = None
class Action(Enum):
class Action(StrEnum):
"""Actions issued by the broker."""
SUBSCRIBE = "subscribe"
PUBLISH = "publish"
RECEIVE = "receive"
class ListenerType(StrEnum):
"""Types of mqtt listeners."""
TCP = "tcp"
WS = "ws"
EXTERNAL = "external"
def __repr__(self) -> str:
"""Display the string value, instead of the enum member."""
return f'"{self.value!s}"'
class Dictable:
"""Add dictionary methods to a dataclass."""
def __getitem__(self, key: str) -> Any:
"""Allow dict-style `[]` access to a dataclass."""
return self.get(key)
def get(self, name: str, default: Any = None) -> Any:
"""Allow dict-style access to a dataclass."""
name = name.replace("-", "_")
if hasattr(self, name):
return getattr(self, name)
if default is not None:
return default
msg = f"'{name}' is not defined"
raise ValueError(msg)
def __contains__(self, name: str) -> bool:
"""Provide dict-style 'in' check."""
return getattr(self, name.replace("-", "_"), None) is not None
def __iter__(self) -> Iterator[Any]:
"""Provide dict-style iteration."""
for f in fields(self): # type: ignore[arg-type]
yield getattr(self, f.name)
def copy(self) -> dataclass: # type: ignore[valid-type]
"""Return a copy of the dataclass."""
return replace(self) # type: ignore[type-var]
@staticmethod
def _coerce_lists(value: list[Any] | dict[str, Any] | Any) -> list[dict[str, Any]]:
if isinstance(value, list):
return value # It's already a list of dicts
if isinstance(value, dict):
return [value] # Promote single dict to a list
msg = "Could not convert 'list' to 'list[dict[str, Any]]'"
raise ValueError(msg)
@dataclass
class ListenerConfig(Dictable):
"""Structured configuration for a broker's listeners."""
type: ListenerType = ListenerType.TCP
"""Type of listener: `tcp` for 'mqtt' or `ws` for 'websocket' when specified in dictionary or yaml.'"""
bind: str | None = "0.0.0.0:1883"
"""address and port for the listener to bind to"""
max_connections: int = 0
"""max number of connections allowed for this listener"""
ssl: bool = False
"""secured by ssl"""
cafile: str | Path | None = None
"""Path to a file of concatenated CA certificates in PEM format. See
[Certificates](https://docs.python.org/3/library/ssl.html#ssl-certificates) for more info."""
capath: str | Path | None = None
"""Path to a directory containing one or more CA certificates in PEM format, following the
[OpenSSL-specific layout](https://docs.openssl.org/master/man3/SSL_CTX_load_verify_locations/)."""
cadata: str | Path | None = None
"""Either an ASCII string of one or more PEM-encoded certificates or a bytes-like object of DER-encoded certificates."""
certfile: str | Path | None = None
"""Full path to file in PEM format containing the server's certificate (as well as any number of CA
certificates needed to establish the certificate's authenticity.)"""
keyfile: str | Path | None = None
"""Full path to file in PEM format containing the server's private key."""
reader: str | None = None
writer: str | None = None
def __post_init__(self) -> None:
"""Check config for errors and transform fields for easier use."""
if (self.certfile is None) ^ (self.keyfile is None):
msg = "If specifying the 'certfile' or 'keyfile', both are required."
raise ValueError(msg)
for fn in ("cafile", "capath", "certfile", "keyfile"):
if isinstance(getattr(self, fn), str):
setattr(self, fn, Path(getattr(self, fn)))
if getattr(self, fn) and not getattr(self, fn).exists():
msg = f"'{fn}' does not exist : {getattr(self, fn)}"
raise FileNotFoundError(msg)
def apply(self, other: "ListenerConfig") -> None:
"""Apply the field from 'other', if 'self' field is default."""
for f in fields(self):
if getattr(self, f.name) == f.default:
setattr(self, f.name, other[f.name])
def default_listeners() -> dict[str, Any]:
"""Create defaults for BrokerConfig.listeners."""
return {
"default": ListenerConfig()
}
def default_broker_plugins() -> dict[str, Any]:
"""Create defaults for BrokerConfig.plugins."""
return {
"amqtt.plugins.logging_amqtt.EventLoggerPlugin": {},
"amqtt.plugins.logging_amqtt.PacketLoggerPlugin": {},
"amqtt.plugins.authentication.AnonymousAuthPlugin": {"allow_anonymous": True},
"amqtt.plugins.sys.broker.BrokerSysPlugin": {"sys_interval": 20}
}
@dataclass
class BrokerConfig(Dictable):
"""Structured configuration for a broker. Can be passed directly to `amqtt.broker.Broker` or created from a dictionary."""
listeners: dict[Literal["default"] | str, ListenerConfig] = field(default_factory=default_listeners) # noqa: PYI051
"""Network of listeners used by the services. a 'default' named listener is required; if another listener
does not set a value, the 'default' settings are applied. See
[`ListenerConfig`](broker_config.md#amqtt.contexts.ListenerConfig) for more information."""
sys_interval: int | None = None
"""*Deprecated field to configure the `BrokerSysPlugin`. See [`BrokerSysPlugin`](../plugins/packaged_plugins.md#sys-topics)
for recommended configuration.*"""
timeout_disconnect_delay: int | None = 0
"""Client disconnect timeout without a keep-alive."""
session_expiry_interval: int | None = None
"""Seconds for an inactive session to be retained."""
auth: dict[str, Any] | None = None
"""*Deprecated field used to config EntryPoint-loaded plugins. See
[`AnonymousAuthPlugin`](../plugins/packaged_plugins.md#anonymous-auth-plugin) and
[`FileAuthPlugin`](../plugins/packaged_plugins.md#password-file-auth-plugin) for recommended configuration.*"""
topic_check: dict[str, Any] | None = None
"""*Deprecated field used to config EntryPoint-loaded plugins. See
[`TopicTabooPlugin`](../plugins/packaged_plugins.md#taboo-topic-plugin) and
[`TopicACLPlugin`](../plugins/packaged_plugins.md#acl-topic-plugin) for recommended configuration method.*"""
plugins: dict[str, Any] | list[str | dict[str, Any]] | None = field(default_factory=default_broker_plugins)
"""The dictionary has a key of the dotted-module path of a class derived from `BasePlugin`, `BaseAuthPlugin`
or `BaseTopicPlugin`; the value is a dictionary of configuration options for that plugin. See
[custom plugins](../plugins/custom_plugins.md) for more information. `list[str | dict[str,Any]]` is deprecated but available
to support legacy use cases."""
def __post_init__(self) -> None:
"""Check config for errors and transform fields for easier use."""
if self.sys_interval is not None:
logger.warning("sys_interval is deprecated, use 'plugins' to define configuration")
if self.auth is not None or self.topic_check is not None:
logger.warning("'auth' and 'topic-check' are deprecated, use 'plugins' to define configuration")
default_listener = self.listeners["default"]
for listener_name, listener in self.listeners.items():
if listener_name == "default":
continue
listener.apply(default_listener)
if isinstance(self.plugins, list):
_plugins: dict[str, Any] = {}
for plugin in self.plugins:
# in case a plugin in a yaml file is listed without config map
if isinstance(plugin, str):
_plugins |= {plugin: {}}
continue
_plugins |= plugin
self.plugins = _plugins
@classmethod
def from_dict(cls, d: dict[str, Any] | None) -> "BrokerConfig":
"""Create a broker config from a dictionary."""
if d is None:
return BrokerConfig()
# patch the incoming dictionary so it can be loaded correctly
if "topic-check" in d:
d["topic_check"] = d["topic-check"]
del d["topic-check"]
# identify EntryPoint plugin loading and prevent 'plugins' from getting defaults
if ("auth" in d or "topic-check" in d) and "plugins" not in d:
d["plugins"] = None
return dict_to_dataclass(data_class=BrokerConfig,
data=d,
config=DaciteConfig(
cast=[StrEnum, ListenerType],
strict=True,
type_hooks={list[dict[str, Any]]: cls._coerce_lists}
))
@dataclass
class ConnectionConfig(Dictable):
"""Properties for connecting to the broker."""
uri: str | None = "mqtt://127.0.0.1:1883"
"""URI of the broker"""
cafile: str | Path | None = None
"""Path to a file of concatenated CA certificates in PEM format to verify broker's authenticity. See
[Certificates](https://docs.python.org/3/library/ssl.html#ssl-certificates) for more info."""
capath: str | Path | None = None
"""Path to a directory containing one or more CA certificates in PEM format, following the
[OpenSSL-specific layout](https://docs.openssl.org/master/man3/SSL_CTX_load_verify_locations/)."""
cadata: str | None = None
"""The certificate to verify the broker's authenticity in an ASCII string format of one or more PEM-encoded
certificates or a bytes-like object of DER-encoded certificates."""
certfile: str | Path | None = None
"""Full path to file in PEM format containing the client's certificate (as well as any number of CA
certificates needed to establish the certificate's authenticity.)"""
keyfile: str | Path | None = None
"""Full path to file in PEM format containing the client's private key associated with the certfile."""
def __post__init__(self) -> None:
"""Check config for errors and transform fields for easier use."""
if (self.certfile is None) ^ (self.keyfile is None):
msg = "If specifying the 'certfile' or 'keyfile', both are required."
raise ValueError(msg)
for fn in ("cafile", "capath", "certfile", "keyfile"):
if isinstance(getattr(self, fn), str):
setattr(self, fn, Path(getattr(self, fn)))
@dataclass
class TopicConfig(Dictable):
"""Configuration of how messages to specific topics are published.
The topic name is specified as the key in the dictionary of the `ClientConfig.topics.
"""
qos: int = 0
"""The quality of service associated with the publishing to this topic."""
retain: bool = False
"""Determines if the message should be retained by the topic it was published."""
def __post__init__(self) -> None:
"""Check config for errors and transform fields for easier use."""
if self.qos is not None and (self.qos < QOS_0 or self.qos > QOS_2):
msg = "Topic config: default QoS must be 0, 1 or 2."
raise ValueError(msg)
@dataclass
class WillConfig(Dictable):
"""Configuration of the 'last will & testament' of the client upon improper disconnection."""
topic: str
"""The will message will be published to this topic."""
message: str
"""The contents of the message to be published."""
qos: int | None = QOS_0
"""The quality of service associated with sending this message."""
retain: bool | None = False
"""Determines if the message should be retained by the topic it was published."""
def __post__init__(self) -> None:
"""Check config for errors and transform fields for easier use."""
if self.qos is not None and (self.qos < QOS_0 or self.qos > QOS_2):
msg = "Will config: default QoS must be 0, 1 or 2."
raise ValueError(msg)
def default_client_plugins() -> dict[str, Any]:
"""Create defaults for `ClientConfig.plugins`."""
return {
"amqtt.plugins.logging_amqtt.PacketLoggerPlugin": {}
}
@dataclass
class ClientConfig(Dictable):
"""Structured configuration for a broker. Can be passed directly to `amqtt.broker.Broker` or created from a dictionary."""
keep_alive: int | None = 10
"""Keep-alive timeout sent to the broker."""
ping_delay: int | None = 1
"""Auto-ping delay before keep-alive timeout. Setting to 0 will disable which may lead to broker disconnection."""
default_qos: int | None = QOS_0
"""Default QoS for messages published."""
default_retain: bool | None = False
"""Default retain value to messages published."""
auto_reconnect: bool | None = True
"""Enable or disable auto-reconnect if connection with the broker is interrupted."""
connection_timeout: int | None = 60
"""The number of seconds before a connection times out"""
reconnect_retries: int | None = 2
"""Number of reconnection retry attempts. Negative value will cause client to reconnect indefinitely."""
reconnect_max_interval: int | None = 10
"""Maximum seconds to wait before retrying a connection."""
cleansession: bool | None = True
"""Upon reconnect, should subscriptions be cleared. Can be overridden by `MQTTClient.connect`"""
topics: dict[str, TopicConfig] | None = field(default_factory=dict)
"""Specify the topics and what flags should be set for messages published to them."""
broker: ConnectionConfig | None = None
"""*Deprecated* Configuration for connecting to the broker. Use `connection` field instead."""
connection: ConnectionConfig = field(default_factory=ConnectionConfig)
"""Configuration for connecting to the broker. See
[`ConnectionConfig`](client_config.md#amqtt.contexts.ConnectionConfig) for more information."""
plugins: dict[str, Any] | list[dict[str, Any]] | None = field(default_factory=default_client_plugins)
"""The dictionary has a key of the dotted-module path of a class derived from `BasePlugin`; the value is
a dictionary of configuration options for that plugin. See [custom plugins](../plugins/custom_plugins.md) for
more information. `list[str | dict[str,Any]]` is deprecated but available to support legacy use cases."""
check_hostname: bool | None = True
"""If establishing a secure connection, should the hostname of the certificate be verified."""
will: WillConfig | None = None
"""Message, topic and flags that should be sent to if the client disconnects. See
[`WillConfig`](client_config.md#amqtt.contexts.WillConfig) for more information."""
def __post_init__(self) -> None:
"""Check config for errors and transform fields for easier use."""
if self.default_qos is not None and (self.default_qos < QOS_0 or self.default_qos > QOS_2):
msg = "Client config: default QoS must be 0, 1 or 2."
raise ValueError(msg)
if self.broker is not None:
warnings.warn("The 'broker' option is deprecated, please use 'connection' instead.", stacklevel=2)
self.connection = self.broker
if bool(not self.connection.keyfile) ^ bool(not self.connection.certfile):
msg = "Connection key and certificate files are _both_ required."
raise ValueError(msg)
@classmethod
def from_dict(cls, d: dict[str, Any] | None) -> "ClientConfig":
"""Create a client config from a dictionary."""
if d is None:
return ClientConfig()
return dict_to_dataclass(data_class=ClientConfig,
data=d,
config=DaciteConfig(
cast=[StrEnum],
strict=True)
)

Wyświetl plik

@ -0,0 +1,47 @@
"""Module for contributed plugins."""
from dataclasses import asdict, is_dataclass
from typing import Any, TypeVar
from sqlalchemy import JSON, TypeDecorator
T = TypeVar("T")
class DataClassListJSON(TypeDecorator[list[dict[str, Any]]]):
impl = JSON
cache_ok = True
def __init__(self, dataclass_type: type[T]) -> None:
if not is_dataclass(dataclass_type):
msg = f"{dataclass_type} must be a dataclass type"
raise TypeError(msg)
self.dataclass_type = dataclass_type
super().__init__()
def process_bind_param(
self,
value: list[Any] | None, # Python -> DB
dialect: Any
) -> list[dict[str, Any]] | None:
if value is None:
return None
return [asdict(item) for item in value]
def process_result_value(
self,
value: list[dict[str, Any]] | None, # DB -> Python
dialect: Any
) -> list[Any] | None:
if value is None:
return None
return [self.dataclass_type(**item) for item in value]
def process_literal_param(self, value: Any, dialect: Any) -> Any:
# Required by SQLAlchemy, typically used for literal SQL rendering.
return value
@property
def python_type(self) -> type:
# Required by TypeEngine to indicate the expected Python type.
return list

Wyświetl plik

@ -0,0 +1,52 @@
"""Plugin to determine authentication of clients with DB storage."""
from dataclasses import dataclass
import click
try:
from enum import StrEnum
except ImportError:
# support for python 3.10
from enum import Enum
class StrEnum(str, Enum): # type: ignore[no-redef]
pass
from .plugin import TopicAuthDBPlugin, UserAuthDBPlugin
class DBType(StrEnum):
"""Enumeration for supported relational databases."""
MARIA = "mariadb"
MYSQL = "mysql"
POSTGRESQL = "postgresql"
SQLITE = "sqlite"
@dataclass
class DBInfo:
"""SQLAlchemy database information."""
connect_str: str
connect_port: int | None
_db_map = {
DBType.MARIA: DBInfo("mysql+aiomysql", 3306),
DBType.MYSQL: DBInfo("mysql+aiomysql", 3306),
DBType.POSTGRESQL: DBInfo("postgresql+asyncpg", 5432),
DBType.SQLITE: DBInfo("sqlite+aiosqlite", None)
}
def db_connection_str(db_type: DBType, db_username: str, db_host: str, db_port: int | None, db_filename: str) -> str:
"""Create sqlalchemy database connection string."""
db_info = _db_map[db_type]
if db_type == DBType.SQLITE:
return f"{db_info.connect_str}:///{db_filename}"
db_password = click.prompt("Enter the db password (press enter for none)", hide_input=True)
pwd = f":{db_password}" if db_password else ""
return f"{db_info.connect_str}://{db_username}:{pwd}@{db_host}:{db_port or db_info.connect_port}"
__all__ = ["DBType", "TopicAuthDBPlugin", "UserAuthDBPlugin", "db_connection_str"]

Wyświetl plik

@ -0,0 +1,190 @@
from collections.abc import Iterator
import logging
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from amqtt.contexts import Action
from amqtt.contrib.auth_db.models import AllowedTopic, Base, TopicAuth, UserAuth
from amqtt.errors import MQTTError
logger = logging.getLogger(__name__)
class UserManager:
def __init__(self, connection: str) -> None:
self._engine = create_async_engine(connection)
self._db_session_maker = async_sessionmaker(self._engine, expire_on_commit=False)
async def db_sync(self) -> None:
"""Sync the database schema."""
async with self._engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
@staticmethod
async def _get_auth_or_raise(db_session: AsyncSession, username: str) -> UserAuth:
stmt = select(UserAuth).filter(UserAuth.username == username)
user_auth = await db_session.scalar(stmt)
if not user_auth:
msg = f"Username '{username}' doesn't exist."
logger.debug(msg)
raise MQTTError(msg)
return user_auth
async def get_user_auth(self, username: str) -> UserAuth | None:
"""Retrieve a user by username."""
async with self._db_session_maker() as db_session, db_session.begin():
try:
return await self._get_auth_or_raise(db_session, username)
except MQTTError:
return None
async def list_user_auths(self) -> Iterator[UserAuth]:
"""Return list of all clients."""
async with self._db_session_maker() as db_session, db_session.begin():
stmt = select(UserAuth).order_by(UserAuth.username)
users = await db_session.scalars(stmt)
if not users:
msg = "No users exist."
logger.info(msg)
raise MQTTError(msg)
return users
async def create_user_auth(self, username: str, plain_password: str) -> UserAuth | None:
"""Create a new user."""
async with self._db_session_maker() as db_session, db_session.begin():
stmt = select(UserAuth).filter(UserAuth.username == username)
user_auth = await db_session.scalar(stmt)
if user_auth:
msg = f"Username '{username}' already exists."
logger.info(msg)
raise MQTTError(msg)
user_auth = UserAuth(username=username)
user_auth.password = plain_password
db_session.add(user_auth)
await db_session.commit()
await db_session.flush()
return user_auth
async def delete_user_auth(self, username: str) -> UserAuth | None:
"""Delete a user."""
async with self._db_session_maker() as db_session, db_session.begin():
try:
user_auth = await self._get_auth_or_raise(db_session, username)
except MQTTError:
return None
await db_session.delete(user_auth)
await db_session.commit()
await db_session.flush()
return user_auth
async def update_user_auth_password(self, username: str, plain_password: str) -> UserAuth | None:
"""Change a user's password."""
async with self._db_session_maker() as db_session, db_session.begin():
user_auth = await self._get_auth_or_raise(db_session, username)
user_auth.password = plain_password
await db_session.commit()
await db_session.flush()
return user_auth
class TopicManager:
def __init__(self, connection: str) -> None:
self._engine = create_async_engine(connection)
self._db_session_maker = async_sessionmaker(self._engine, expire_on_commit=False)
async def db_sync(self) -> None:
"""Sync the database schema."""
async with self._engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
@staticmethod
async def _get_auth_or_raise(db_session: AsyncSession, username: str) -> TopicAuth:
stmt = select(TopicAuth).filter(TopicAuth.username == username)
topic_auth = await db_session.scalar(stmt)
if not topic_auth:
msg = f"Username '{username}' doesn't exist."
logger.debug(msg)
raise MQTTError(msg)
return topic_auth
@staticmethod
def _field_name(action: Action) -> str:
return f"{action}_acl"
async def create_topic_auth(self, username: str) -> TopicAuth | None:
"""Create a new user."""
async with self._db_session_maker() as db_session, db_session.begin():
stmt = select(TopicAuth).filter(TopicAuth.username == username)
topic_auth = await db_session.scalar(stmt)
if topic_auth:
msg = f"Username '{username}' already exists."
raise MQTTError(msg)
topic_auth = TopicAuth(username=username)
db_session.add(topic_auth)
await db_session.commit()
await db_session.flush()
return topic_auth
async def get_topic_auth(self, username: str) -> TopicAuth | None:
"""Retrieve a allowed topics by username."""
async with self._db_session_maker() as db_session, db_session.begin():
try:
return await self._get_auth_or_raise(db_session, username)
except MQTTError:
return None
async def list_topic_auths(self) -> Iterator[TopicAuth]:
"""Return list of all authorized clients."""
async with self._db_session_maker() as db_session, db_session.begin():
stmt = select(TopicAuth).order_by(TopicAuth.username)
topics = await db_session.scalars(stmt)
if not topics:
msg = "No topics exist."
logger.info(msg)
raise MQTTError(msg)
return topics
async def add_allowed_topic(self, username: str, topic: str, action: Action) -> list[AllowedTopic] | None:
"""Add allowed topic from action for user."""
if action == Action.PUBLISH and topic.startswith("$"):
msg = "MQTT does not allow clients to publish to $ topics."
raise MQTTError(msg)
async with self._db_session_maker() as db_session, db_session.begin():
user_auth = await self._get_auth_or_raise(db_session, username)
topic_list = getattr(user_auth, self._field_name(action))
updated_list = [*topic_list, AllowedTopic(topic)]
setattr(user_auth, self._field_name(action), updated_list)
await db_session.commit()
await db_session.flush()
return updated_list
async def remove_allowed_topic(self, username: str, topic: str, action: Action) -> list[AllowedTopic] | None:
"""Remove topic from action for user."""
async with self._db_session_maker() as db_session, db_session.begin():
topic_auth = await self._get_auth_or_raise(db_session, username)
topic_list = topic_auth.get_topic_list(action)
if AllowedTopic(topic) not in topic_list:
msg = f"Client '{username}' doesn't have topic '{topic}' for action '{action}'."
logger.debug(msg)
raise MQTTError(msg)
updated_list = [allowed_topic for allowed_topic in topic_list if allowed_topic != AllowedTopic(topic)]
setattr(topic_auth, f"{action}_acl", updated_list)
await db_session.commit()
await db_session.flush()
return updated_list

Wyświetl plik

@ -0,0 +1,124 @@
from dataclasses import dataclass
import logging
from typing import TYPE_CHECKING, Any, Optional, Union, cast
from sqlalchemy import String
from sqlalchemy.ext.hybrid import hybrid_property
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
from amqtt.contexts import Action
from amqtt.contrib import DataClassListJSON
from amqtt.plugins import TopicMatcher
if TYPE_CHECKING:
from passlib.context import CryptContext
logger = logging.getLogger(__name__)
matcher = TopicMatcher()
@dataclass
class AllowedTopic:
topic: str
def __contains__(self, item: Union[str, "AllowedTopic"]) -> bool:
"""Determine `in`."""
return self.__eq__(item)
def __eq__(self, item: object) -> bool:
"""Determine `==` or `!=`."""
if isinstance(item, str):
return matcher.is_topic_allowed(item, self.topic)
if isinstance(item, AllowedTopic):
return item.topic == self.topic
msg = "AllowedTopic can only be compared to another AllowedTopic or string."
raise AttributeError(msg)
def __str__(self) -> str:
"""Display topic."""
return self.topic
def __repr__(self) -> str:
"""Display topic."""
return self.topic
class PasswordHasher:
"""singleton to initialize the CryptContext and then use it elsewhere in the code."""
_instance: Optional["PasswordHasher"] = None
def __init__(self) -> None:
if not hasattr(self, "_crypt_context"):
self._crypt_context: CryptContext | None = None
def __new__(cls, *args: list[Any], **kwargs: dict[str, Any]) -> "PasswordHasher":
if cls._instance is None:
cls._instance = super().__new__(cls, *args, **kwargs)
return cls._instance
@property
def crypt_context(self) -> "CryptContext":
if not self._crypt_context:
msg = "CryptContext is empty"
raise ValueError(msg)
return self._crypt_context
@crypt_context.setter
def crypt_context(self, value: "CryptContext") -> None:
self._crypt_context = value
class Base(DeclarativeBase):
pass
class UserAuth(Base):
__tablename__ = "user_auth"
id: Mapped[int] = mapped_column(primary_key=True)
username: Mapped[str] = mapped_column(String, unique=True)
_password_hash: Mapped[str] = mapped_column("password_hash", String(128))
publish_acl: Mapped[list[AllowedTopic]] = mapped_column(DataClassListJSON(AllowedTopic), default=list)
subscribe_acl: Mapped[list[AllowedTopic]] = mapped_column(DataClassListJSON(AllowedTopic), default=list)
receive_acl: Mapped[list[AllowedTopic]] = mapped_column(DataClassListJSON(AllowedTopic), default=list)
@hybrid_property
def password(self) -> None:
msg = "Password is write-only"
raise AttributeError(msg)
@password.inplace.setter # type: ignore[arg-type]
def _password_setter(self, plain_password: str) -> None:
self._password_hash = PasswordHasher().crypt_context.hash(plain_password)
def verify_password(self, plain_password: str) -> bool:
return bool(PasswordHasher().crypt_context.verify(plain_password, self._password_hash))
def __str__(self) -> str:
"""Display client id and password hash."""
return f"'{self.username}' with password hash: {self._password_hash}"
class TopicAuth(Base):
__tablename__ = "topic_auth"
id: Mapped[int] = mapped_column(primary_key=True)
username: Mapped[str] = mapped_column(String, unique=True)
publish_acl: Mapped[list[AllowedTopic]] = mapped_column(DataClassListJSON(AllowedTopic), default=list)
subscribe_acl: Mapped[list[AllowedTopic]] = mapped_column(DataClassListJSON(AllowedTopic), default=list)
receive_acl: Mapped[list[AllowedTopic]] = mapped_column(DataClassListJSON(AllowedTopic), default=list)
def get_topic_list(self, action: Action) -> list[AllowedTopic]:
return cast("list[AllowedTopic]", getattr(self, f"{action}_acl"))
def __str__(self) -> str:
"""Display client id and password hash."""
return f"""'{self.username}':
\tpublish: {self.publish_acl}, subscribe: {self.subscribe_acl}, receive: {self.receive_acl}
"""

Wyświetl plik

@ -0,0 +1,111 @@
from dataclasses import dataclass, field
import logging
from passlib.context import CryptContext
from sqlalchemy.ext.asyncio import create_async_engine
from amqtt.broker import BrokerContext
from amqtt.contexts import Action
from amqtt.contrib.auth_db.managers import TopicManager, UserManager
from amqtt.contrib.auth_db.models import Base, PasswordHasher
from amqtt.errors import MQTTError
from amqtt.plugins.base import BaseAuthPlugin, BaseTopicPlugin
from amqtt.session import Session
logger = logging.getLogger(__name__)
def default_hash_scheme() -> list[str]:
"""Create config dataclass defaults."""
return ["argon2", "bcrypt", "pbkdf2_sha256", "scrypt"]
class UserAuthDBPlugin(BaseAuthPlugin):
def __init__(self, context: BrokerContext) -> None:
super().__init__(context)
# access the singleton and set the proper crypt context
pwd_hasher = PasswordHasher()
pwd_hasher.crypt_context = CryptContext(schemes=self.config.hash_schemes, deprecated="auto")
self._user_manager = UserManager(self.config.connection)
self._engine = create_async_engine(f"{self.config.connection}")
async def on_broker_pre_start(self) -> None:
"""Sync the schema (if configured)."""
if not self.config.sync_schema:
return
async with self._engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
async def authenticate(self, *, session: Session) -> bool | None:
"""Authenticate a client's session."""
if not session.username or not session.password:
return False
user_auth = await self._user_manager.get_user_auth(session.username)
if not user_auth:
return False
return bool(session.password) and user_auth.verify_password(session.password)
@dataclass
class Config:
"""Configuration for DB authentication."""
connection: str
"""SQLAlchemy connection string for the asyncio version of the database connector:
- `mysql+aiomysql://user:password@host:port/dbname`
- `postgresql+asyncpg://user:password@host:port/dbname`
- `sqlite+aiosqlite:///dbfilename.db`
"""
sync_schema: bool = False
"""Use SQLAlchemy to create / update the database schema."""
hash_schemes: list[str] = field(default_factory=default_hash_scheme)
"""list of hash schemes to use for passwords"""
class TopicAuthDBPlugin(BaseTopicPlugin):
def __init__(self, context: BrokerContext) -> None:
super().__init__(context)
self._topic_manager = TopicManager(self.config.connection)
self._engine = create_async_engine(f"{self.config.connection}")
async def on_broker_pre_start(self) -> None:
"""Sync the schema (if configured)."""
if not self.config.sync_schema:
return
async with self._engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
async def topic_filtering(
self, *, session: Session | None = None, topic: str | None = None, action: Action | None = None
) -> bool | None:
if not session or not session.username or not topic:
return None
try:
topic_auth = await self._topic_manager.get_topic_auth(session.username)
topic_list = getattr(topic_auth, f"{action}_acl")
except MQTTError:
return False
return topic in topic_list
@dataclass
class Config:
"""Configuration for DB topic filtering."""
connection: str
"""SQLAlchemy connection string for the asyncio version of the database connector:
- `mysql+aiomysql://user:password@host:port/dbname`
- `postgresql+asyncpg://user:password@host:port/dbname`
- `sqlite+aiosqlite:///dbfilename.db`
"""
sync_schema: bool = False
"""Use SQLAlchemy to create / update the database schema."""

Wyświetl plik

@ -0,0 +1,151 @@
import asyncio
import contextlib
import logging
from pathlib import Path
from typing import Annotated
import typer
from amqtt.contexts import Action
from amqtt.contrib.auth_db import DBType, db_connection_str
from amqtt.contrib.auth_db.managers import TopicManager, UserManager
from amqtt.errors import MQTTError
logging.basicConfig(level=logging.INFO, format="%(message)s")
logger = logging.getLogger(__name__)
topic_app = typer.Typer(no_args_is_help=True)
@topic_app.callback()
def main(
ctx: typer.Context,
db_type: Annotated[DBType, typer.Option("--db", "-d", help="db type", count=False)],
db_username: Annotated[str, typer.Option("--username", "-u", help="db username", show_default=False)] = "",
db_port: Annotated[int, typer.Option("--port", "-p", help="database port (defaults to db type)", show_default=False)] = 0,
db_host: Annotated[str, typer.Option("--host", "-h", help="database host")] = "localhost",
db_filename: Annotated[str, typer.Option("--file", "-f", help="database file name (sqlite only)")] = "auth.db",
) -> None:
"""Command line interface to add / remove topic authorization.
Passwords are not allowed to be passed via the command line for security reasons. You will be prompted for database
password (if applicable).
If you need to create users programmatically, see `amqtt.contrib.auth_db.managers.TopicManager` which provides
the underlying functionality to this command line interface.
"""
if db_type == DBType.SQLITE and ctx.invoked_subcommand == "sync" and not Path(db_filename).exists():
pass
elif db_type == DBType.SQLITE and not Path(db_filename).exists():
logger.error(f"SQLite option could not find '{db_filename}'")
raise typer.Exit(code=1)
elif db_type != DBType.SQLITE and not db_username:
logger.error("DB access requires a username be provided.")
raise typer.Exit(code=1)
ctx.obj = {"type": db_type, "username": db_username, "host": db_host, "port": db_port, "filename": db_filename}
@topic_app.command(name="sync")
def db_sync(ctx: typer.Context) -> None:
"""Create the table and schema for username and topic lists for subscribe, publish or receive.
Non-destructive if run multiple times. To clear the whole table, need to drop it manually.
"""
async def run_sync() -> None:
connect = db_connection_str(ctx.obj["type"], ctx.obj["username"], ctx.obj["host"], ctx.obj["port"], ctx.obj["filename"])
mgr = UserManager(connect)
try:
await mgr.db_sync()
except MQTTError as me:
logger.critical("Could not sync schema on db.")
raise typer.Exit(code=1) from me
asyncio.run(run_sync())
logger.info("Success: database synced.")
@topic_app.command(name="list")
def list_clients(ctx: typer.Context) -> None:
"""List all Client IDs (in alphabetical order). Will also display the hashed passwords."""
async def run_list() -> None:
connect = db_connection_str(ctx.obj["type"], ctx.obj["username"], ctx.obj["host"], ctx.obj["port"], ctx.obj["filename"])
mgr = TopicManager(connect)
user_count = 0
for user in await mgr.list_topic_auths():
user_count += 1
logger.info(user)
if not user_count:
logger.info("No client authorizations exist.")
asyncio.run(run_list())
@topic_app.command(name="add")
def add_topic_allowance(
ctx: typer.Context,
topic: Annotated[str, typer.Argument(help="list of topics", show_default=False)],
client_id: Annotated[str, typer.Option("--client-id", "-c", help="id for the client", show_default=False)],
action: Annotated[Action, typer.Option("--action", "-a", help="action for topic to allow", show_default=False)]
) -> None:
"""Create a new user with a client id and password (prompted)."""
async def run_add() -> None:
connect = db_connection_str(ctx.obj["type"], ctx.obj["username"], ctx.obj["host"], ctx.obj["port"],
ctx.obj["filename"])
mgr = TopicManager(connect)
with contextlib.suppress(MQTTError):
await mgr.create_topic_auth(client_id)
topic_auth = await mgr.get_topic_auth(client_id)
if not topic_auth:
logger.info(f"Topic auth doesn't exist for '{client_id}'")
raise typer.Exit(code=1)
if topic in [allowed_topic.topic for allowed_topic in topic_auth.get_topic_list(action)]:
logger.info(f"Topic '{topic}' already exists for '{action}'.")
raise typer.Exit(1)
await mgr.add_allowed_topic(client_id, topic, action)
logger.info(f"Success: topic '{topic}' added to {action} for '{client_id}'")
asyncio.run(run_add())
@topic_app.command(name="rm")
def remove_topic_allowance(ctx: typer.Context,
client_id: Annotated[str, typer.Option("--client-id", "-c", help="id for the client to remove")],
action: Annotated[Action, typer.Option("--action", "-a", help="action for topic to allow")],
topic: Annotated[str, typer.Argument(help="list of topics")]
) -> None:
"""Remove a client from the authentication database."""
async def run_remove() -> None:
connect = db_connection_str(ctx.obj["type"], ctx.obj["username"], ctx.obj["host"], ctx.obj["port"],
ctx.obj["filename"])
mgr = TopicManager(connect)
topic_auth = await mgr.get_topic_auth(client_id)
if not topic_auth:
logger.info(f"client '{client_id}' doesn't exist.")
raise typer.Exit(1)
if topic not in getattr(topic_auth, f"{action}_acl"):
logger.info(f"Error: topic '{topic}' not in the {action} allow list for {client_id}.")
raise typer.Exit(1)
try:
await mgr.remove_allowed_topic(client_id, topic, action)
except MQTTError as me:
logger.info(f"'Error: could not remove '{topic}' for client '{client_id}'.")
raise typer.Exit(1) from me
logger.info(f"Success: removed topic '{topic}' from {action} for '{client_id}'")
asyncio.run(run_remove())
if __name__ == "__main__":
topic_app()

Wyświetl plik

@ -0,0 +1,161 @@
import asyncio
import logging
from pathlib import Path
from typing import Annotated
import click
import passlib
import typer
from amqtt.contrib.auth_db import DBType, db_connection_str
from amqtt.contrib.auth_db.managers import UserManager
from amqtt.errors import MQTTError
logging.basicConfig(level=logging.INFO, format="%(message)s")
logger = logging.getLogger(__name__)
user_app = typer.Typer(no_args_is_help=True)
@user_app.callback()
def main(
ctx: typer.Context,
db_type: Annotated[DBType, typer.Option(..., "--db", "-d", help="db type", show_default=False)],
db_username: Annotated[str, typer.Option("--username", "-u", help="db username", show_default=False)] = "",
db_port: Annotated[int, typer.Option("--port", "-p", help="database port (defaults to db type)", show_default=False)] = 0,
db_host: Annotated[str, typer.Option("--host", "-h", help="database host")] = "localhost",
db_filename: Annotated[str, typer.Option("--file", "-f", help="database file name (sqlite only)")] = "auth.db",
) -> None:
"""Command line interface to list, create, remove and add clients.
Passwords are not allowed to be passed via the command line for security reasons. You will be prompted for database
password (if applicable) and the client id's password.
If you need to create users programmatically, see `amqtt.contrib.auth_db.managers.UserManager` which provides
the underlying functionality to this command line interface.
"""
if db_type == DBType.SQLITE and ctx.invoked_subcommand == "sync" and not Path(db_filename).exists():
pass
elif db_type == DBType.SQLITE and not Path(db_filename).exists():
logger.error(f"SQLite option could not find '{db_filename}'")
raise typer.Exit(code=1)
elif db_type != DBType.SQLITE and not db_username:
logger.error("DB access requires a username be provided.")
raise typer.Exit(code=1)
ctx.obj = {"type": db_type, "username": db_username, "host": db_host, "port": db_port, "filename": db_filename}
@user_app.command(name="sync")
def db_sync(ctx: typer.Context) -> None:
"""Create the table and schema for username and hashed password.
Non-destructive if run multiple times. To clear the whole table, need to drop it manually.
"""
async def run_sync() -> None:
connect = db_connection_str(ctx.obj["type"], ctx.obj["username"], ctx.obj["host"], ctx.obj["port"], ctx.obj["filename"])
mgr = UserManager(connect)
try:
await mgr.db_sync()
except MQTTError as me:
logger.critical("Could not sync schema on db.")
raise typer.Exit(code=1) from me
asyncio.run(run_sync())
logger.info("Success: database synced.")
@user_app.command(name="list")
def list_user_auths(ctx: typer.Context) -> None:
"""List all Client IDs (in alphabetical order). Will also display the hashed passwords."""
async def run_list() -> None:
connect = db_connection_str(ctx.obj["type"], ctx.obj["username"], ctx.obj["host"], ctx.obj["port"], ctx.obj["filename"])
mgr = UserManager(connect)
user_count = 0
for user in await mgr.list_user_auths():
user_count += 1
logger.info(user)
if not user_count:
logger.info("No client authentications exist.")
asyncio.run(run_list())
@user_app.command(name="add")
def create_user_auth(
ctx: typer.Context,
client_id: Annotated[str, typer.Option("--client-id", "-c", help="id for the new client")],
) -> None:
"""Create a new user with a client id and password (prompted)."""
async def run_create() -> None:
connect = db_connection_str(ctx.obj["type"], ctx.obj["username"], ctx.obj["host"], ctx.obj["port"],
ctx.obj["filename"])
mgr = UserManager(connect)
client_password = click.prompt("Enter the client's password", hide_input=True)
if not client_password.strip():
logger.info("Error: client password cannot be empty.")
raise typer.Exit(1)
try:
user = await mgr.create_user_auth(client_id, client_password.strip())
except passlib.exc.MissingBackendError as mbe:
logger.info(f"Please install backend: {mbe}")
raise typer.Exit(code=1) from mbe
if not user:
logger.info(f"Error: could not create user: {client_id}")
raise typer.Exit(code=1)
logger.info(f"Success: created {user}")
asyncio.run(run_create())
@user_app.command(name="rm")
def remove_user_auth(ctx: typer.Context,
client_id: Annotated[str, typer.Option("--client-id", "-c", help="id for the client to remove")]) -> None:
"""Remove a client from the authentication database."""
async def run_remove() -> None:
connect = db_connection_str(ctx.obj["type"], ctx.obj["username"], ctx.obj["host"], ctx.obj["port"],
ctx.obj["filename"])
mgr = UserManager(connect)
user = await mgr.get_user_auth(client_id)
if not user:
logger.info(f"Error: client '{client_id}' does not exist.")
raise typer.Exit(1)
if not click.confirm(f"Please confirm the removal of '{client_id}'?"):
raise typer.Exit(0)
user = await mgr.delete_user_auth(client_id)
if not user:
logger.info(f"Error: client '{client_id}' does not exist.")
raise typer.Exit(1)
logger.info(f"Success: '{user.username}' was removed.")
asyncio.run(run_remove())
@user_app.command(name="pwd")
def change_password(
ctx: typer.Context,
client_id: Annotated[str, typer.Option("--client-id", "-c", help="id for the new client")],
) -> None:
"""Update a user's password (prompted)."""
async def run_password() -> None:
client_password = click.prompt("Enter the client's new password", hide_input=True)
if not client_password.strip():
logger.error("Error: client password cannot be empty.")
raise typer.Exit(1)
connect = db_connection_str(ctx.obj["type"], ctx.obj["username"], ctx.obj["host"], ctx.obj["port"],
ctx.obj["filename"])
mgr = UserManager(connect)
await mgr.update_user_auth_password(client_id, client_password.strip())
logger.info(f"Success: client '{client_id}' password updated.")
asyncio.run(run_password())
if __name__ == "__main__":
user_app()

Wyświetl plik

@ -0,0 +1,250 @@
from dataclasses import dataclass
from datetime import datetime, timedelta
try:
from datetime import UTC
except ImportError:
# support for python 3.10
from datetime import timezone
UTC = timezone.utc
from ipaddress import IPv4Address
import logging
from pathlib import Path
import re
from cryptography import x509
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.x509 import Certificate, CertificateSigningRequest
from cryptography.x509.oid import NameOID
from amqtt.plugins.base import BaseAuthPlugin
from amqtt.session import Session
logger = logging.getLogger(__name__)
class UserAuthCertPlugin(BaseAuthPlugin):
"""Used a *signed* x509 certificate's `Subject AlternativeName` or `SAN` to verify client authentication.
Often used for IoT devices, this method provides the most secure form of identification. A root
certificate, often referenced as a CA certificate -- either issued by a known authority (such as LetsEncrypt)
or a self-signed certificate) is used to sign a private key and certificate for the server. Each device/client
also gets a unique private key and certificate signed by the same CA certificate; also included in the device
certificate is a 'SAN' or SubjectAlternativeName which is the device's unique identifier.
Since both server and device certificates are signed by the same CA certificate, the client can
verify the server's authenticity; and the server can verify the client's authenticity. And since
the device's certificate contains a x509 SAN, the server (with this plugin) can identify the device securely.
!!! note "URI and Client ID configuration"
`uri_domain` configuration must be set to the same uri used to generate the device credentials
when a device is connecting with private key and certificate, the `client_id` must
match the device id used to generate the device credentials.
Available ore three scripts to help with the key generation and certificate signing: `ca_creds`, `server_creds`
and `device_creds`.
!!! note "Configuring broker & client for using Self-signed root CA"
If using self-signed root credentials, the `cafile` configuration for both broker and client need to be
configured with `cafile` set to the `ca.crt`.
"""
async def authenticate(self, *, session: Session) -> bool | None:
"""Verify the client's session using the provided client's x509 certificate."""
if not session.ssl_object:
return False
der_cert = session.ssl_object.getpeercert(binary_form=True)
if der_cert:
cert = x509.load_der_x509_certificate(der_cert, backend=default_backend())
try:
san = cert.extensions.get_extension_for_class(x509.SubjectAlternativeName)
uris = san.value.get_values_for_type(x509.UniformResourceIdentifier)
if self.config.uri_domain not in uris[0]:
return False
pattern = rf"^spiffe://{re.escape(self.config.uri_domain)}/device/([^/]+)$"
match = re.match(pattern, uris[0])
if not match:
return False
return match.group(1) == session.client_id
except x509.ExtensionNotFound:
logger.warning("No SAN extension found.")
return False
@dataclass
class Config:
"""Configuration for the CertificateAuthPlugin."""
uri_domain: str
"""The domain that is expected as part of the device certificate's spiffe (e.g. test.amqtt.io)"""
def generate_root_creds(country: str, state: str, locality: str,
org_name: str, cn: str) -> tuple[rsa.RSAPrivateKey, Certificate]:
"""Generate CA key and certificate."""
# generate private key for the server
ca_key = rsa.generate_private_key(
public_exponent=65537,
key_size=4096,
)
# Create certificate subject and issuer (self-signed)
subject = issuer = x509.Name([
x509.NameAttribute(NameOID.COUNTRY_NAME, country),
x509.NameAttribute(NameOID.STATE_OR_PROVINCE_NAME, state),
x509.NameAttribute(NameOID.LOCALITY_NAME, locality),
x509.NameAttribute(NameOID.ORGANIZATION_NAME, org_name),
x509.NameAttribute(NameOID.COMMON_NAME, cn),
])
# 3. Build self-signed certificate
cert = (
x509.CertificateBuilder()
.subject_name(subject)
.issuer_name(issuer)
.public_key(ca_key.public_key())
.serial_number(x509.random_serial_number())
.not_valid_before(datetime.now(UTC))
.not_valid_after(datetime.now(UTC) + timedelta(days=3650)) # 10 years
.add_extension(
x509.BasicConstraints(ca=True, path_length=None),
critical=True,
)
.add_extension(
x509.SubjectKeyIdentifier.from_public_key(ca_key.public_key()),
critical=False,
)
.add_extension(
x509.KeyUsage(
key_cert_sign=True,
crl_sign=True,
digital_signature=False,
key_encipherment=False,
content_commitment=False,
data_encipherment=False,
key_agreement=False,
encipher_only=False,
decipher_only=False,
),
critical=True,
)
.sign(ca_key, hashes.SHA256())
)
return ca_key, cert
def generate_server_csr(country: str, org_name: str, cn: str) -> tuple[rsa.RSAPrivateKey, CertificateSigningRequest]:
"""Generate server private key and server certificate-signing-request."""
key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
csr = (
x509.CertificateSigningRequestBuilder()
.subject_name(x509.Name([
x509.NameAttribute(NameOID.COUNTRY_NAME, country),
x509.NameAttribute(NameOID.ORGANIZATION_NAME, org_name),
x509.NameAttribute(NameOID.COMMON_NAME, cn),
]))
.add_extension(
x509.SubjectAlternativeName([
x509.DNSName(cn),
x509.IPAddress(IPv4Address("127.0.0.1")),
]),
critical=False,
)
.sign(key, hashes.SHA256())
)
return key, csr
def generate_device_csr(country: str, org_name: str, common_name: str,
uri_san: str, dns_san: str
) -> tuple[rsa.RSAPrivateKey, CertificateSigningRequest]:
"""Generate a device key and a csr."""
key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
csr = (
x509.CertificateSigningRequestBuilder()
.subject_name(x509.Name([
x509.NameAttribute(NameOID.COUNTRY_NAME, country),
x509.NameAttribute(NameOID.ORGANIZATION_NAME, org_name
),
x509.NameAttribute(NameOID.COMMON_NAME, common_name),
]))
.add_extension(
x509.SubjectAlternativeName([
x509.UniformResourceIdentifier(uri_san),
x509.DNSName(dns_san),
]),
critical=False,
)
.sign(key, hashes.SHA256())
)
return key, csr
def sign_csr(csr: CertificateSigningRequest,
ca_key: rsa.RSAPrivateKey,
ca_cert: Certificate, validity_days: int = 365) -> Certificate:
"""Sign a csr with CA credentials."""
return (
x509.CertificateBuilder()
.subject_name(csr.subject)
.issuer_name(ca_cert.subject)
.public_key(csr.public_key())
.serial_number(x509.random_serial_number())
.not_valid_before(datetime.now(UTC))
.not_valid_after(datetime.now(UTC) + timedelta(days=validity_days))
.add_extension(
x509.BasicConstraints(ca=False, path_length=None),
critical=True,
)
.add_extension(
csr.extensions.get_extension_for_class(x509.SubjectAlternativeName).value,
critical=False,
)
.add_extension(
x509.AuthorityKeyIdentifier.from_issuer_public_key(ca_cert.public_key()), # type: ignore[arg-type]
critical=False,
)
.sign(ca_key, hashes.SHA256())
)
def load_ca(ca_key_fn: str, ca_crt_fn: str) -> tuple[rsa.RSAPrivateKey, Certificate]:
"""Load server key and certificate."""
with Path(ca_key_fn).open("rb") as f:
ca_key: rsa.RSAPrivateKey = serialization.load_pem_private_key(f.read(), password=None) # type: ignore[assignment]
with Path(ca_crt_fn).open("rb") as f:
ca_cert = x509.load_pem_x509_certificate(f.read())
return ca_key, ca_cert
def write_key_and_crt(key: rsa.RSAPrivateKey, crt: Certificate,
prefix: str, path: Path | None = None) -> None:
"""Create pem-encoded files for key and certificate."""
path = path or Path()
crt_fn = path / f"{prefix}.crt"
key_fn = path / f"{prefix}.key"
with crt_fn.open("wb") as f:
f.write(crt.public_bytes(serialization.Encoding.PEM))
with key_fn.open("wb") as f:
f.write(key.private_bytes(
serialization.Encoding.PEM,
serialization.PrivateFormat.TraditionalOpenSSL,
serialization.NoEncryption()
))

Wyświetl plik

@ -0,0 +1,181 @@
from dataclasses import dataclass
try:
from enum import StrEnum
except ImportError:
# support for python 3.10
from enum import Enum
class StrEnum(str, Enum): # type: ignore[no-redef]
pass
import logging
from typing import Any
from aiohttp import ClientResponse, ClientSession, FormData
from amqtt.broker import BrokerContext
from amqtt.contexts import Action
from amqtt.plugins.base import BaseAuthPlugin, BasePlugin, BaseTopicPlugin
from amqtt.session import Session
logger = logging.getLogger(__name__)
class ResponseMode(StrEnum):
STATUS = "status"
JSON = "json"
TEXT = "text"
class RequestMethod(StrEnum):
GET = "get"
POST = "post"
PUT = "put"
class ParamsMode(StrEnum):
JSON = "json"
FORM = "form"
class ACLError(Exception):
pass
HTTP_2xx_MIN = 200
HTTP_2xx_MAX = 299
HTTP_4xx_MIN = 400
HTTP_4xx_MAX = 499
@dataclass
class HttpConfig:
"""Configuration for the HTTP Auth & ACL Plugin."""
host: str
"""hostname of the server for the auth & acl check"""
port: int
"""port of the server for the auth & acl check"""
request_method: RequestMethod = RequestMethod.GET
"""send the request as a GET, POST or PUT"""
params_mode: ParamsMode = ParamsMode.JSON # see docs/plugins/http.md for additional details
"""send the request with `JSON` or `FORM` data. *additional details below*"""
response_mode: ResponseMode = ResponseMode.JSON # see docs/plugins/http.md for additional details
"""expected response from the auth/acl server. `STATUS` (code), `JSON`, or `TEXT`. *additional details below*"""
with_tls: bool = False
"""http or https"""
user_agent: str = "amqtt"
"""the 'User-Agent' header sent along with the request"""
superuser_uri: str | None = None
"""URI to verify if the user is a superuser (e.g. '/superuser'), `None` if superuser is not supported"""
timeout: int = 5
"""duration, in seconds, to wait for the HTTP server to respond"""
class AuthHttpPlugin(BasePlugin[BrokerContext]):
def __init__(self, context: BrokerContext) -> None:
super().__init__(context)
self.http = ClientSession(headers={"User-Agent": self.config.user_agent})
match self.config.request_method:
case RequestMethod.GET:
self.method = self.http.get
case RequestMethod.PUT:
self.method = self.http.put
case _:
self.method = self.http.post
async def on_broker_pre_shutdown(self) -> None:
await self.http.close()
@staticmethod
def _is_2xx(r: ClientResponse) -> bool:
return HTTP_2xx_MIN <= r.status <= HTTP_2xx_MAX
@staticmethod
def _is_4xx(r: ClientResponse) -> bool:
return HTTP_4xx_MIN <= r.status <= HTTP_4xx_MAX
def _get_params(self, payload: dict[str, Any]) -> dict[str, Any]:
match self.config.params_mode:
case ParamsMode.FORM:
match self.config.request_method:
case RequestMethod.GET:
kwargs = {"params": payload}
case _: # POST, PUT
d: Any = FormData(payload)
kwargs = {"data": d}
case _: # JSON
kwargs = {"json": payload}
return kwargs
async def _send_request(self, url: str, payload: dict[str, Any]) -> bool | None: # pylint: disable=R0911
kwargs = self._get_params(payload)
async with self.method(url, **kwargs) as r:
logger.debug(f"http request returned {r.status}")
match self.config.response_mode:
case ResponseMode.TEXT:
return self._is_2xx(r) and (await r.text()).lower() == "ok"
case ResponseMode.STATUS:
if self._is_2xx(r):
return True
if self._is_4xx(r):
return False
# any other code
return None
case _:
if not self._is_2xx(r):
return False
data: dict[str, Any] = await r.json()
data = {k.lower(): v for k, v in data.items()}
return data.get("ok", None)
def get_url(self, uri: str) -> str:
return f"{'https' if self.config.with_tls else 'http'}://{self.config.host}:{self.config.port}{uri}"
class UserAuthHttpPlugin(AuthHttpPlugin, BaseAuthPlugin):
async def authenticate(self, *, session: Session) -> bool | None:
d = {"username": session.username, "password": session.password, "client_id": session.client_id}
return await self._send_request(self.get_url(self.config.user_uri), d)
@dataclass
class Config(HttpConfig):
"""Configuration for the HTTP Auth Plugin."""
user_uri: str = "/user"
"""URI of the auth check."""
class TopicAuthHttpPlugin(AuthHttpPlugin, BaseTopicPlugin):
async def topic_filtering(self, *,
session: Session | None = None,
topic: str | None = None,
action: Action | None = None) -> bool | None:
if not session:
return None
acc = 0
match action:
case Action.PUBLISH:
acc = 2
case Action.SUBSCRIBE:
acc = 4
case Action.RECEIVE:
acc = 1
d = {"username": session.username, "client_id": session.client_id, "topic": topic, "acc": acc}
return await self._send_request(self.get_url(self.config.topic_uri), d)
@dataclass
class Config(HttpConfig):
"""Configuration for the HTTP Topic Plugin."""
topic_uri: str = "/acl"
"""URI of the topic check."""

Wyświetl plik

@ -0,0 +1,120 @@
from dataclasses import dataclass
import logging
from typing import ClassVar
import jwt
try:
from enum import StrEnum
except ImportError:
# support for python 3.10
from enum import Enum
class StrEnum(str, Enum): # type: ignore[no-redef]
pass
from amqtt.broker import BrokerContext
from amqtt.contexts import Action
from amqtt.plugins import TopicMatcher
from amqtt.plugins.base import BaseAuthPlugin, BaseTopicPlugin
from amqtt.session import Session
logger = logging.getLogger(__name__)
class Algorithms(StrEnum):
ES256 = "ES256"
ES256K = "ES256K"
ES384 = "ES384"
ES512 = "ES512"
ES521 = "ES521"
EdDSA = "EdDSA"
HS256 = "HS256"
HS384 = "HS384"
HS512 = "HS512"
PS256 = "PS256"
PS384 = "PS384"
PS512 = "PS512"
RS256 = "RS256"
RS384 = "RS384"
RS512 = "RS512"
class UserAuthJwtPlugin(BaseAuthPlugin):
async def authenticate(self, *, session: Session) -> bool | None:
if not session.username or not session.password:
return None
try:
decoded_payload = jwt.decode(session.password, self.config.secret_key, algorithms=["HS256"])
return bool(decoded_payload.get(self.config.user_claim, None) == session.username)
except jwt.ExpiredSignatureError:
logger.debug(f"jwt for '{session.username}' is expired")
return False
except jwt.InvalidTokenError:
logger.debug(f"jwt for '{session.username}' is invalid")
return False
@dataclass
class Config:
"""Configuration for the JWT user authentication."""
secret_key: str
"""Secret key to decrypt the token."""
user_claim: str
"""Payload key for user name."""
algorithm: str = "HS256"
"""Algorithm to use for token encryption: 'ES256', 'ES256K', 'ES384', 'ES512', 'ES521', 'EdDSA', 'HS256',
'HS384', 'HS512', 'PS256', 'PS384', 'PS512', 'RS256', 'RS384', 'RS512'"""
class TopicAuthJwtPlugin(BaseTopicPlugin):
_topic_jwt_claims: ClassVar = {
Action.PUBLISH: "publish_claim",
Action.SUBSCRIBE: "subscribe_claim",
Action.RECEIVE: "receive_claim",
}
def __init__(self, context: BrokerContext) -> None:
super().__init__(context)
self.topic_matcher = TopicMatcher()
async def topic_filtering(
self, *, session: Session | None = None, topic: str | None = None, action: Action | None = None
) -> bool | None:
if not session or not topic or not action:
return None
if not session.password:
return None
try:
decoded_payload = jwt.decode(session.password.encode(), self.config.secret_key, algorithms=["HS256"])
claim = getattr(self.config, self._topic_jwt_claims[action])
return any(self.topic_matcher.is_topic_allowed(topic, a_filter) for a_filter in decoded_payload.get(claim, []))
except jwt.ExpiredSignatureError:
logger.debug(f"jwt for '{session.username}' is expired")
return False
except jwt.InvalidTokenError:
logger.debug(f"jwt for '{session.username}' is invalid")
return False
@dataclass
class Config:
"""Configuration for the JWT topic authorization."""
secret_key: str
"""Secret key to decrypt the token."""
publish_claim: str
"""Payload key for contains a list of permissible publish topics."""
subscribe_claim: str
"""Payload key for contains a list of permissible subscribe topics."""
receive_claim: str
"""Payload key for contains a list of permissible receive topics."""
algorithm: str = "HS256"
"""Algorithm to use for token encryption: 'ES256', 'ES256K', 'ES384', 'ES512', 'ES521', 'EdDSA', 'HS256',
'HS384', 'HS512', 'PS256', 'PS384', 'PS512', 'RS256', 'RS384', 'RS512'"""

Wyświetl plik

@ -0,0 +1,138 @@
from dataclasses import dataclass
import logging
from typing import ClassVar
import ldap
from amqtt.broker import BrokerContext
from amqtt.contexts import Action
from amqtt.errors import PluginInitError
from amqtt.plugins import TopicMatcher
from amqtt.plugins.base import BaseAuthPlugin, BasePlugin, BaseTopicPlugin
from amqtt.session import Session
logger = logging.getLogger(__name__)
@dataclass
class LdapConfig:
"""Configuration for the LDAP Plugins."""
server: str
"""uri formatted server location. e.g `ldap://localhost:389`"""
base_dn: str
"""distinguished name (dn) of the ldap server. e.g. `dc=amqtt,dc=io`"""
user_attribute: str
"""attribute in ldap entry to match the username against"""
bind_dn: str
"""distinguished name (dn) of known, preferably read-only, user. e.g. `cn=admin,dc=amqtt,dc=io`"""
bind_password: str
"""password for known, preferably read-only, user"""
class AuthLdapPlugin(BasePlugin[BrokerContext]):
def __init__(self, context: BrokerContext) -> None:
super().__init__(context)
self.conn = ldap.initialize(self.config.server)
self.conn.protocol_version = ldap.VERSION3 # pylint: disable=E1101
try:
self.conn.simple_bind_s(self.config.bind_dn, self.config.bind_password)
except ldap.INVALID_CREDENTIALS as e: # pylint: disable=E1101
raise PluginInitError(self.__class__) from e
class UserAuthLdapPlugin(AuthLdapPlugin, BaseAuthPlugin):
"""Plugin to authenticate a user with an LDAP directory server."""
async def authenticate(self, *, session: Session) -> bool | None:
# use our initial creds to see if the user exists
search_filter = f"({self.config.user_attribute}={session.username})"
result = self.conn.search_s(self.config.base_dn, ldap.SCOPE_SUBTREE, search_filter, ["dn"]) # pylint: disable=E1101
if not result:
logger.debug(f"user not found: {session.username}")
return False
try:
# `search_s` responds with list of tuples: (dn, entry); first in list is our match
user_dn = result[0][0]
except IndexError:
return False
try:
user_conn = ldap.initialize(self.config.server)
user_conn.simple_bind_s(user_dn, session.password)
except ldap.INVALID_CREDENTIALS: # pylint: disable=E1101
logger.debug(f"invalid credentials for '{session.username}'")
return False
except ldap.LDAPError as e: # pylint: disable=E1101
logger.debug(f"LDAP error during user bind: {e}")
return False
return True
@dataclass
class Config(LdapConfig):
"""Configuration for the User Auth LDAP Plugin."""
class TopicAuthLdapPlugin(AuthLdapPlugin, BaseTopicPlugin):
"""Plugin to authenticate a user with an LDAP directory server."""
_action_attr_map: ClassVar = {
Action.PUBLISH: "publish_attribute",
Action.SUBSCRIBE: "subscribe_attribute",
Action.RECEIVE: "receive_attribute"
}
def __init__(self, context: BrokerContext) -> None:
super().__init__(context)
self.topic_matcher = TopicMatcher()
async def topic_filtering(
self, *, session: Session | None = None, topic: str | None = None, action: Action | None = None
) -> bool | None:
# if not provided needed criteria, can't properly evaluate topic filtering
if not session or not action or not topic:
return None
search_filter = f"({self.config.user_attribute}={session.username})"
attrs = [
"cn",
self.config.publish_attribute,
self.config.subscribe_attribute,
self.config.receive_attribute
]
results = self.conn.search_s(self.config.base_dn, ldap.SCOPE_SUBTREE, search_filter, attrs) # pylint: disable=E1101
if not results:
logger.debug(f"user not found: {session.username}")
return False
if len(results) > 1:
found_users = [dn for dn, _ in results]
logger.debug(f"multiple users found: {', '.join(found_users)}")
return False
dn, entry = results[0]
ldap_attribute = getattr(self.config, self._action_attr_map[action])
topic_filters = [t.decode("utf-8") for t in entry.get(ldap_attribute, [])]
logger.debug(f"DN: {dn} - {ldap_attribute}={topic_filters}")
return self.topic_matcher.are_topics_allowed(topic, topic_filters)
@dataclass
class Config(LdapConfig):
"""Configuration for the LDAPAuthPlugin."""
publish_attribute: str
"""LDAP attribute which contains a list of permissible publish topics."""
subscribe_attribute: str
"""LDAP attribute which contains a list of permissible subscribe topics."""
receive_attribute: str
"""LDAP attribute which contains a list of permissible receive topics."""

Wyświetl plik

@ -0,0 +1,264 @@
from dataclasses import dataclass
import logging
from pathlib import Path
from sqlalchemy import Boolean, Integer, LargeBinary, Result, String, select
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
from amqtt.broker import BrokerContext, RetainedApplicationMessage
from amqtt.contrib import DataClassListJSON
from amqtt.errors import PluginError
from amqtt.plugins.base import BasePlugin
from amqtt.session import Session
logger = logging.getLogger(__name__)
class Base(DeclarativeBase):
pass
@dataclass
class RetainedMessage:
topic: str
data: str
qos: int
@dataclass
class Subscription:
topic: str
qos: int
class StoredSession(Base):
__tablename__ = "stored_sessions"
id: Mapped[int] = mapped_column(primary_key=True)
client_id: Mapped[str] = mapped_column(String)
clean_session: Mapped[bool | None] = mapped_column(Boolean, nullable=True)
will_flag: Mapped[bool] = mapped_column(Boolean, default=False, server_default="false")
will_message: Mapped[bytes | None] = mapped_column(LargeBinary, nullable=True, default=None)
will_qos: Mapped[int | None] = mapped_column(Integer, nullable=True, default=None)
will_retain: Mapped[bool | None] = mapped_column(Boolean, nullable=True, default=None)
will_topic: Mapped[str | None] = mapped_column(String, nullable=True, default=None)
keep_alive: Mapped[int] = mapped_column(Integer, default=0)
retained: Mapped[list[RetainedMessage]] = mapped_column(DataClassListJSON(RetainedMessage), default=list)
subscriptions: Mapped[list[Subscription]] = mapped_column(DataClassListJSON(Subscription), default=list)
class StoredMessage(Base):
__tablename__ = "stored_messages"
id: Mapped[int] = mapped_column(primary_key=True)
topic: Mapped[str] = mapped_column(String)
data: Mapped[bytes | None] = mapped_column(LargeBinary, nullable=True, default=None)
qos: Mapped[int] = mapped_column(Integer, default=0)
class SessionDBPlugin(BasePlugin[BrokerContext]):
"""Plugin to store session information and retained topic messages in the event that the broker terminates abnormally.
Configuration:
- file *(string)* path & filename to store the session db. default: `amqtt.db`
- clear_on_shutdown *(bool)* if the broker shutdowns down normally, don't retain any information. default: `True`
"""
def __init__(self, context: BrokerContext) -> None:
super().__init__(context)
# bypass the `test_plugins_correct_has_attr` until it can be updated
if not hasattr(self.config, "file"):
logger.warning("`Config` is missing a `file` attribute")
return
self._engine = create_async_engine(f"sqlite+aiosqlite:///{self.config.file}")
self._db_session_maker = async_sessionmaker(self._engine, expire_on_commit=False)
@staticmethod
async def _get_or_create_session(db_session: AsyncSession, client_id: str) -> StoredSession:
stmt = select(StoredSession).filter(StoredSession.client_id == client_id)
stored_session = await db_session.scalar(stmt)
if stored_session is None:
stored_session = StoredSession(client_id=client_id)
db_session.add(stored_session)
await db_session.flush()
return stored_session
@staticmethod
async def _get_or_create_message(db_session: AsyncSession, topic: str) -> StoredMessage:
stmt = select(StoredMessage).filter(StoredMessage.topic == topic)
stored_message = await db_session.scalar(stmt)
if stored_message is None:
stored_message = StoredMessage(topic=topic)
db_session.add(stored_message)
await db_session.flush()
return stored_message
async def on_broker_client_connected(self, client_id: str, client_session: Session) -> None:
"""Search to see if session already exists."""
# if client id doesn't exist, create (can ignore if session is anonymous)
# update session information (will, clean_session, etc)
# don't store session information for clean or anonymous sessions
if client_session.clean_session in (None, True) or client_session.is_anonymous:
return
async with self._db_session_maker() as db_session, db_session.begin():
stored_session = await self._get_or_create_session(db_session, client_id)
stored_session.clean_session = client_session.clean_session
stored_session.will_flag = client_session.will_flag
stored_session.will_message = client_session.will_message # type: ignore[assignment]
stored_session.will_qos = client_session.will_qos
stored_session.will_retain = client_session.will_retain
stored_session.will_topic = client_session.will_topic
stored_session.keep_alive = client_session.keep_alive
await db_session.flush()
async def on_broker_client_subscribed(self, client_id: str, topic: str, qos: int) -> None:
"""Create/update subscription if clean session = false."""
session = self.context.get_session(client_id)
if not session:
logger.warning(f"'{client_id}' is subscribing but doesn't have a session")
return
if session.clean_session:
return
async with self._db_session_maker() as db_session, db_session.begin():
# stored sessions shouldn't need to be created here, but we'll use the same helper...
stored_session = await self._get_or_create_session(db_session, client_id)
stored_session.subscriptions = [*stored_session.subscriptions, Subscription(topic, qos)]
await db_session.flush()
async def on_broker_client_unsubscribed(self, client_id: str, topic: str) -> None:
"""Remove subscription if clean session = false."""
async def on_broker_retained_message(self, *, client_id: str | None, retained_message: RetainedApplicationMessage) -> None:
"""Update to retained messages.
if retained_message.data is None or '', the message is being cleared
"""
# if client_id is valid, the retained message is for a disconnected client
if client_id is not None:
async with self._db_session_maker() as db_session, db_session.begin():
# stored sessions shouldn't need to be created here, but we'll use the same helper...
stored_session = await self._get_or_create_session(db_session, client_id)
stored_session.retained = [*stored_session.retained, RetainedMessage(retained_message.topic,
retained_message.data.decode(),
retained_message.qos or 0)]
await db_session.flush()
return
async with self._db_session_maker() as db_session, db_session.begin():
# if the retained message has data, we need to store/update for the topic
if retained_message.data:
client_message = await self._get_or_create_message(db_session, retained_message.topic)
client_message.data = retained_message.data # type: ignore[assignment]
client_message.qos = retained_message.qos or 0
await db_session.flush()
return
# if there is no data, clear the stored message (if exists) for the topic
stmt = select(StoredMessage).filter(StoredMessage.topic == retained_message.topic)
topic_message = await db_session.scalar(stmt)
if topic_message is not None:
await db_session.delete(topic_message)
await db_session.flush()
return
async def on_broker_pre_start(self) -> None:
"""Initialize the database and db connection."""
async with self._engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
async def on_broker_post_start(self) -> None:
"""Load subscriptions."""
if len(self.context.subscriptions) > 0:
msg = "SessionDBPlugin : broker shouldn't have any subscriptions yet"
raise PluginError(msg)
if len(list(self.context.sessions)) > 0:
msg = "SessionDBPlugin : broker shouldn't have any sessions yet"
raise PluginError(msg)
async with self._db_session_maker() as db_session, db_session.begin():
stmt = select(StoredSession)
stored_sessions = await db_session.execute(stmt)
restored_sessions = 0
for stored_session in stored_sessions.scalars():
await self.context.add_subscription(stored_session.client_id, None, None)
for subscription in stored_session.subscriptions:
await self.context.add_subscription(stored_session.client_id,
subscription.topic,
subscription.qos)
session = self.context.get_session(stored_session.client_id)
if not session:
continue
session.clean_session = stored_session.clean_session
session.will_flag = stored_session.will_flag
session.will_message = stored_session.will_message
session.will_qos = stored_session.will_qos
session.will_retain = stored_session.will_retain
session.will_topic = stored_session.will_topic
session.keep_alive = stored_session.keep_alive
for message in stored_session.retained:
retained_message = RetainedApplicationMessage(
source_session=None,
topic=message.topic,
data=message.data.encode(),
qos=message.qos
)
await session.retained_messages.put(retained_message)
restored_sessions += 1
stmt = select(StoredMessage)
stored_messages: Result[tuple[StoredMessage]] = await db_session.execute(stmt)
restored_messages = 0
retained_messages = self.context.retained_messages
for stored_message in stored_messages.scalars():
retained_messages[stored_message.topic] = (RetainedApplicationMessage(
source_session=None,
topic=stored_message.topic,
data=stored_message.data or b"",
qos=stored_message.qos
))
restored_messages += 1
logger.info(f"Retained messages restored: {restored_messages}")
logger.info(f"Restored {restored_sessions} sessions.")
async def on_broker_pre_shutdown(self) -> None:
"""Clean up the db connection."""
await self._engine.dispose()
async def on_broker_post_shutdown(self) -> None:
if self.config.clear_on_shutdown and self.config.file.exists():
self.config.file.unlink()
@dataclass
class Config:
"""Configuration variables."""
file: str | Path = "amqtt.db"
"""path & filename to store the sqlite session db."""
clear_on_shutdown: bool = True
"""if the broker shutdowns down normally, don't retain any information."""
def __post_init__(self) -> None:
"""Create `Path` from string path."""
if isinstance(self.file, str):
self.file = Path(self.file)

Wyświetl plik

@ -0,0 +1,6 @@
"""Module for the shadow state plugin."""
from .plugin import ShadowPlugin, ShadowTopicAuthPlugin
from .states import ShadowOperation
__all__ = ["ShadowOperation", "ShadowPlugin", "ShadowTopicAuthPlugin"]

Wyświetl plik

@ -0,0 +1,113 @@
from collections.abc import MutableMapping
from dataclasses import dataclass, fields, is_dataclass
import json
from typing import Any
from amqtt.contrib.shadows.states import MetaTimestamp, ShadowOperation, State, StateDocument
def asdict_no_none(obj: Any) -> Any:
"""Create dictionary from dataclass, but eliminate any key set to `None`."""
if is_dataclass(obj):
result = {}
for f in fields(obj):
value = getattr(obj, f.name)
if value is not None:
result[f.name] = asdict_no_none(value)
return result
if isinstance(obj, list):
return [asdict_no_none(item) for item in obj if item is not None]
if isinstance(obj, dict):
return {
key: asdict_no_none(value)
for key, value in obj.items()
if value is not None
}
return obj
def create_shadow_topic(device_id: str, shadow_name: str, message_op: "ShadowOperation") -> str:
"""Create a shadow topic for message type."""
return f"$shadow/{device_id}/{shadow_name}/{message_op}"
class ShadowMessage:
def to_message(self) -> bytes:
return json.dumps(asdict_no_none(self)).encode("utf-8")
@dataclass
class GetAcceptedMessage(ShadowMessage):
state: State[dict[str, Any]]
metadata: State[MetaTimestamp]
timestamp: int
version: int
@staticmethod
def topic(device_id: str, shadow_name: str) -> str:
return create_shadow_topic(device_id, shadow_name, ShadowOperation.GET_ACCEPT)
@dataclass
class GetRejectedMessage(ShadowMessage):
code: int
message: str
timestamp: int | None = None
@staticmethod
def topic(device_id: str, shadow_name: str) -> str:
return create_shadow_topic(device_id, shadow_name, ShadowOperation.GET_REJECT)
@dataclass
class UpdateAcceptedMessage(ShadowMessage):
state: State[dict[str, Any]]
metadata: State[MetaTimestamp]
timestamp: int
version: int
@staticmethod
def topic(device_id: str, shadow_name: str) -> str:
return create_shadow_topic(device_id, shadow_name, ShadowOperation.UPDATE_ACCEPT)
@dataclass
class UpdateRejectedMessage(ShadowMessage):
code: int
message: str
timestamp: int
@staticmethod
def topic(device_id: str, shadow_name: str) -> str:
return create_shadow_topic(device_id, shadow_name, ShadowOperation.UPDATE_REJECT)
@dataclass
class UpdateDeltaMessage(ShadowMessage):
state: MutableMapping[str, Any]
metadata: MutableMapping[str, Any]
timestamp: int
version: int
@staticmethod
def topic(device_id: str, shadow_name: str) -> str:
return create_shadow_topic(device_id, shadow_name, ShadowOperation.UPDATE_DELTA)
class UpdateIotaMessage(UpdateDeltaMessage):
"""Same format, corollary name."""
@staticmethod
def topic(device_id: str, shadow_name: str) -> str:
return create_shadow_topic(device_id, shadow_name, ShadowOperation.UPDATE_IOTA)
@dataclass
class UpdateDocumentMessage(ShadowMessage):
previous: StateDocument
current: StateDocument
timestamp: int
@staticmethod
def topic(device_id: str, shadow_name: str) -> str:
return create_shadow_topic(device_id, shadow_name, ShadowOperation.UPDATE_DOCUMENTS)

Wyświetl plik

@ -0,0 +1,140 @@
from collections.abc import Sequence
from dataclasses import asdict
import logging
import time
from typing import Any, Optional
import uuid
from sqlalchemy import JSON, CheckConstraint, Integer, String, UniqueConstraint, desc, event, func, select
from sqlalchemy.ext.asyncio import AsyncConnection, AsyncSession
from sqlalchemy.orm import DeclarativeBase, Mapped, Mapper, Session, make_transient, mapped_column
from amqtt.contrib.shadows.states import StateDocument
logger = logging.getLogger(__name__)
class ShadowUpdateError(Exception):
def __init__(self, message: str = "updating an existing Shadow is not allowed") -> None:
super().__init__(message)
class ShadowBase(DeclarativeBase):
pass
async def sync_shadow_base(connection: AsyncConnection) -> None:
"""Create tables and table schemas."""
await connection.run_sync(ShadowBase.metadata.create_all)
def default_state_document() -> dict[str, Any]:
"""Create a default (empty) state document, factory for model field."""
return asdict(StateDocument())
class Shadow(ShadowBase):
__tablename__ = "shadows_shadow"
id: Mapped[str | None] = mapped_column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
device_id: Mapped[str] = mapped_column(String(128), nullable=False)
name: Mapped[str] = mapped_column(String(128), nullable=False)
version: Mapped[int] = mapped_column(Integer, nullable=False)
_state: Mapped[dict[str, Any]] = mapped_column("state", JSON, nullable=False, default=dict)
created_at: Mapped[int] = mapped_column(Integer, default=lambda: int(time.time()), nullable=False)
__table_args__ = (
CheckConstraint("version > 0", name="check_quantity_positive"),
UniqueConstraint("device_id", "name", "version", name="uq_device_id_name_version"),
)
@property
def state(self) -> StateDocument:
if not self._state:
return StateDocument()
return StateDocument.from_dict(self._state)
@state.setter
def state(self, value: StateDocument) -> None:
self._state = asdict(value)
@classmethod
async def latest_version(cls, session: AsyncSession, device_id: str, name: str) -> Optional["Shadow"]:
"""Get the latest version of the shadow associated with the device and name."""
stmt = (
select(cls).where(
cls.device_id == device_id,
cls.name == name
).order_by(desc(cls.version)).limit(1)
)
result = await session.execute(stmt)
return result.scalar_one_or_none()
@classmethod
async def all(cls, session: AsyncSession, device_id: str, name: str) -> Sequence["Shadow"]:
"""Return a list of all shadows associated with the device and name."""
stmt = (
select(cls).where(
cls.device_id == device_id,
cls.name == name
).order_by(desc(cls.version)))
result = await session.execute(stmt)
return result.scalars().all()
@event.listens_for(Shadow, "before_insert")
def assign_incremental_version(_: Mapper[Any], connection: Session, target: "Shadow") -> None:
"""Get the latest version of the state document."""
stmt = (
select(func.max(Shadow.version))
.where(
Shadow.device_id == target.device_id,
Shadow.name == target.name
)
)
result = connection.execute(stmt).scalar_one_or_none()
target.version = (result or 0) + 1
@event.listens_for(Shadow, "before_update")
def prevent_update(_mapper: Mapper[Any], _session: Session, _instance: "Shadow") -> None:
"""Prevent existing shadow from being updated."""
raise ShadowUpdateError
@event.listens_for(Session, "before_flush")
def convert_update_to_insert(session: Session, _flush_context: object, _instances: object | None) -> None:
"""Force a shadow to insert a new version, instead of updating an existing."""
# Make a copy of the dirty set so we can safely mutate the session
dirty = list(session.dirty)
for obj in dirty:
if not session.is_modified(obj, include_collections=False):
continue # skip unchanged
# You can scope this to a particular class
if not isinstance(obj, Shadow):
continue
# Clone logic: convert update into insert
session.expunge(obj) # remove from session
make_transient(obj) # remove identity and history
obj.id = "" # clear primary key
obj.version += 1 # bump version or modify fields
session.add(obj) # re-add as new object
_listener_example = '''#
# @event.listens_for(Shadow, "before_insert")
# def convert_state_document_to_json(_1: Mapper[Any], _2: Session, target: "Shadow") -> None:
# """Listen for insertion and convert state document to json."""
# if not isinstance(target.state, StateDocument):
# msg = "'state' field needs to be a StateDocument"
# raise TypeError(msg)
#
# target.state = target.state.to_dict()
'''

Wyświetl plik

@ -0,0 +1,198 @@
from collections import defaultdict
from dataclasses import dataclass, field
import json
import re
from typing import Any
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
from amqtt.broker import BrokerContext
from amqtt.contexts import Action
from amqtt.contrib.shadows.messages import (
GetAcceptedMessage,
GetRejectedMessage,
UpdateAcceptedMessage,
UpdateDeltaMessage,
UpdateDocumentMessage,
UpdateIotaMessage,
)
from amqtt.contrib.shadows.models import Shadow, sync_shadow_base
from amqtt.contrib.shadows.states import (
ShadowOperation,
StateDocument,
calculate_delta_update,
calculate_iota_update,
)
from amqtt.plugins.base import BasePlugin, BaseTopicPlugin
from amqtt.session import ApplicationMessage, Session
shadow_topic_re = re.compile(r"^\$shadow/(?P<client_id>[a-zA-Z0-9_-]+?)/(?P<shadow_name>[a-zA-Z0-9_-]+?)/(?P<request>get|update)")
DeviceID = str
ShadowName = str
@dataclass
class ShadowTopic:
device_id: DeviceID
name: ShadowName
message_op: ShadowOperation
def shadow_dict() -> dict[DeviceID, dict[ShadowName, StateDocument]]:
"""Nested defaultdict for shadow cache."""
return defaultdict(shadow_dict) # type: ignore[arg-type]
class ShadowPlugin(BasePlugin[BrokerContext]):
def __init__(self, context: BrokerContext) -> None:
super().__init__(context)
self._shadows: dict[DeviceID, dict[ShadowName, StateDocument]] = defaultdict(dict)
self._engine = create_async_engine(self.config.connection)
self._db_session_maker = async_sessionmaker(self._engine, expire_on_commit=False)
async def on_broker_pre_start(self) -> None:
"""Sync the schema."""
async with self._engine.begin() as conn:
await sync_shadow_base(conn)
@staticmethod
def shadow_topic_match(topic: str) -> ShadowTopic | None:
"""Check if topic matches the shadow topic format."""
# pattern is "$shadow/<username>/<shadow_name>/get, update, etc
match = shadow_topic_re.search(topic)
if match:
groups = match.groupdict()
return ShadowTopic(groups["client_id"], groups["shadow_name"], ShadowOperation(groups["request"]))
return None
async def _handle_get(self, st: ShadowTopic) -> None:
"""Send 'accepted."""
async with self._db_session_maker() as db_session, db_session.begin():
shadow = await Shadow.latest_version(db_session, st.device_id, st.name)
if not shadow:
reject_msg = GetRejectedMessage(
code=404,
message="shadow not found",
)
await self.context.broadcast_message(reject_msg.topic(st.device_id, st.name), reject_msg.to_message())
return
accept_msg = GetAcceptedMessage(
state=shadow.state.state,
metadata=shadow.state.metadata,
timestamp=shadow.created_at,
version=shadow.version
)
await self.context.broadcast_message(accept_msg.topic(st.device_id, st.name), accept_msg.to_message())
async def _handle_update(self, st: ShadowTopic, update: dict[str, Any]) -> None:
async with self._db_session_maker() as db_session, db_session.begin():
shadow = await Shadow.latest_version(db_session, st.device_id, st.name)
if not shadow:
shadow = Shadow(device_id=st.device_id, name=st.name)
state_update = StateDocument.from_dict(update)
prev_state = shadow.state or StateDocument()
prev_state.version = shadow.version or 0 # only required when generating shadow messages
prev_state.timestamp = shadow.created_at or 0 # only required when generating shadow messages
next_state = prev_state + state_update
shadow.state = next_state
db_session.add(shadow)
await db_session.commit()
next_state.version = shadow.version
next_state.timestamp = shadow.created_at
accept_msg = UpdateAcceptedMessage(
state=next_state.state,
metadata=next_state.metadata,
timestamp=123,
version=1
)
await self.context.broadcast_message(accept_msg.topic(st.device_id, st.name), accept_msg.to_message())
delta_msg = UpdateDeltaMessage(
state=calculate_delta_update(next_state.state.desired, next_state.state.reported),
metadata=calculate_delta_update(next_state.metadata.desired, next_state.metadata.reported),
version=shadow.version,
timestamp=shadow.created_at
)
await self.context.broadcast_message(delta_msg.topic(st.device_id, st.name), delta_msg.to_message())
iota_msg = UpdateIotaMessage(
state=calculate_iota_update(next_state.state.desired, next_state.state.reported),
metadata=calculate_delta_update(next_state.metadata.desired, next_state.metadata.reported),
version=shadow.version,
timestamp=shadow.created_at
)
await self.context.broadcast_message(iota_msg.topic(st.device_id, st.name), iota_msg.to_message())
doc_msg = UpdateDocumentMessage(
previous=prev_state,
current=next_state,
timestamp=shadow.created_at
)
await self.context.broadcast_message(doc_msg.topic(st.device_id, st.name), doc_msg.to_message())
async def on_broker_message_received(self, *, client_id: str, message: ApplicationMessage) -> None:
"""Process a message that was received from a client."""
topic = message.topic
if not topic.startswith("$shadow"): # this is less overhead than do the full regular expression match
return
if not (shadow_topic := self.shadow_topic_match(topic)):
return
match shadow_topic.message_op:
case ShadowOperation.GET:
await self._handle_get(shadow_topic)
case ShadowOperation.UPDATE:
await self._handle_update(shadow_topic, json.loads(message.data.decode("utf-8")))
@dataclass
class Config:
"""Configuration for shadow plugin."""
connection: str
"""SQLAlchemy connection string for the asyncio version of the database connector:
- `mysql+aiomysql://user:password@host:port/dbname`
- `postgresql+asyncpg://user:password@host:port/dbname`
- `sqlite+aiosqlite:///dbfilename.db`
"""
class ShadowTopicAuthPlugin(BaseTopicPlugin):
async def topic_filtering(self, *,
session: Session | None = None,
topic: str | None = None,
action: Action | None = None) -> bool | None:
session = session or Session()
if not topic:
return False
shadow_topic = ShadowPlugin.shadow_topic_match(topic)
if not shadow_topic:
return False
return shadow_topic.device_id == session.username or session.username in self.config.superusers
@dataclass
class Config:
"""Configuration for only allowing devices access to their own shadow topics."""
superusers: list[str] = field(default_factory=list)
"""A list of one or more usernames that can write to any device topic,
primarily for the central app sending updates to devices."""

Wyświetl plik

@ -0,0 +1,206 @@
from collections import Counter
from collections.abc import MutableMapping
from dataclasses import dataclass, field
try:
from enum import StrEnum
except ImportError:
# support for python 3.10
from enum import Enum
class StrEnum(str, Enum): # type: ignore[no-redef]
pass
import time
from typing import Any, Generic, TypeVar
from mergedeep import merge
C = TypeVar("C", bound=Any)
class StateError(Exception):
def __init__(self, msg: str = "'state' field is required") -> None:
super().__init__(msg)
@dataclass
class MetaTimestamp:
timestamp: int = 0
def __eq__(self, other: object) -> bool:
"""Compare timestamps."""
if isinstance(other, int):
return self.timestamp == other
if isinstance(other, self.__class__):
return self.timestamp == other.timestamp
msg = "needs to be int or MetaTimestamp"
raise ValueError(msg)
# numeric operations to make this dataclass transparent
def __abs__(self) -> int:
"""Absolute timestamp."""
return self.timestamp
def __add__(self, other: int) -> int:
"""Add to a timestamp."""
return self.timestamp + other
def __sub__(self, other: int) -> int:
"""Subtract from a timestamp."""
return self.timestamp - other
def __mul__(self, other: int) -> int:
"""Multiply a timestamp."""
return self.timestamp * other
def __float__(self) -> float:
"""Convert timestamp to float."""
return float(self.timestamp)
def __int__(self) -> int:
"""Convert timestamp to int."""
return int(self.timestamp)
def __lt__(self, other: int) -> bool:
"""Compare timestamp."""
return self.timestamp < other
def __le__(self, other: int) -> bool:
"""Compare timestamp."""
return self.timestamp <= other
def __gt__(self, other: int) -> bool:
"""Compare timestamp."""
return self.timestamp > other
def __ge__(self, other: int) -> bool:
"""Compare timestamp."""
return self.timestamp >= other
def create_metadata(state: MutableMapping[str, Any], timestamp: int) -> dict[str, Any]:
"""Create metadata (timestamps) for each of the keys in 'state'."""
metadata: dict[str, Any] = {}
for key, value in state.items():
if isinstance(value, dict):
metadata[key] = create_metadata(value, timestamp)
elif value is None:
metadata[key] = None
else:
metadata[key] = MetaTimestamp(timestamp)
return metadata
def calculate_delta_update(desired: MutableMapping[str, Any],
reported: MutableMapping[str, Any],
depth: bool = True,
exclude_nones: bool = True,
ordered_lists: bool = True) -> dict[str, Any]:
"""Calculate state differences between desired and reported."""
diff_dict = {}
for key, value in desired.items():
if value is None and exclude_nones:
continue
# if the desired has an element that the reported does not...
if key not in reported:
diff_dict[key] = value
# if the desired has an element that's a list, but the list is
elif isinstance(value, list) and not ordered_lists:
if Counter(value) != Counter(reported[key]):
diff_dict[key] = value
elif isinstance(value, dict) and depth:
# recurse, report when there is a difference
obj_diff = calculate_delta_update(value, reported[key])
if obj_diff:
diff_dict[key] = obj_diff
elif value != reported[key]:
diff_dict[key] = value
return diff_dict
def calculate_iota_update(desired: MutableMapping[str, Any], reported: MutableMapping[str, Any]) -> MutableMapping[str, Any]:
"""Calculate state differences between desired and reported (including missing keys)."""
delta = calculate_delta_update(desired, reported, depth=False, exclude_nones=False)
for key in reported:
if key not in desired:
delta[key] = None
return delta
@dataclass
class State(Generic[C]):
desired: MutableMapping[str, C] = field(default_factory=dict)
reported: MutableMapping[str, C] = field(default_factory=dict)
@classmethod
def from_dict(cls, data: dict[str, C]) -> "State[C]":
"""Create state from dictionary."""
return cls(
desired=data.get("desired", {}),
reported=data.get("reported", {})
)
def __bool__(self) -> bool:
"""Determine if state is empty."""
return bool(self.desired) or bool(self.reported)
def __add__(self, other: "State[C]") -> "State[C]":
"""Merge states together."""
return State(
desired=merge({}, self.desired, other.desired),
reported=merge({}, self.reported, other.reported)
)
@dataclass
class StateDocument:
state: State[dict[str, Any]] = field(default_factory=State)
metadata: State[MetaTimestamp] = field(default_factory=State)
version: int | None = None # only required when generating shadow messages
timestamp: int | None = None # only required when generating shadow messages
@classmethod
def from_dict(cls, data: dict[str, Any]) -> "StateDocument":
"""Create state document from dictionary."""
now = int(time.time())
if data and "state" not in data:
raise StateError
state = State.from_dict(data.get("state", {}))
metadata = State(
desired=create_metadata(state.desired, now),
reported=create_metadata(state.reported, now)
)
return cls(state=state, metadata=metadata)
def __post_init__(self) -> None:
"""Initialize meta data if not provided."""
now = int(time.time())
if not self.metadata:
self.metadata = State(
desired=create_metadata(self.state.desired, now),
reported=create_metadata(self.state.reported, now),
)
def __add__(self, other: "StateDocument") -> "StateDocument":
"""Merge two state documents together."""
return StateDocument(
state=self.state + other.state,
metadata=self.metadata + other.metadata
)
class ShadowOperation(StrEnum):
GET = "get"
UPDATE = "update"
GET_ACCEPT = "get/accepted"
GET_REJECT = "get/rejected"
UPDATE_ACCEPT = "update/accepted"
UPDATE_REJECT = "update/rejected"
UPDATE_DOCUMENTS = "update/documents"
UPDATE_DELTA = "update/delta"
UPDATE_IOTA = "update/iota"

Wyświetl plik

@ -16,10 +16,12 @@ class CodecError(Exception):
class NoDataError(Exception):
"""Exceptions thrown by packet encode/decode functions."""
class ZeroLengthReadError(NoDataError):
def __init__(self) -> None:
super().__init__("Decoding a string of length zero.")
class BrokerError(Exception):
"""Exceptions thrown by broker."""

Wyświetl plik

@ -3,7 +3,7 @@ try:
except ImportError:
# support for python 3.10
from enum import Enum
class StrEnum(str, Enum): #type: ignore[no-redef]
class StrEnum(str, Enum): # type: ignore[no-redef]
pass
@ -31,4 +31,6 @@ class BrokerEvents(Events):
CLIENT_DISCONNECTED = "broker_client_disconnected"
CLIENT_SUBSCRIBED = "broker_client_subscribed"
CLIENT_UNSUBSCRIBED = "broker_client_unsubscribed"
RETAINED_MESSAGE = "broker_retained_message"
MESSAGE_RECEIVED = "broker_message_received"
MESSAGE_BROADCAST = "broker_message_broadcast"

Wyświetl plik

@ -167,7 +167,7 @@ class PacketIdVariableHeader(MQTTVariableHeader):
_VH = TypeVar("_VH", bound=MQTTVariableHeader | None)
class MQTTPayload(Generic[_VH], ABC):
class MQTTPayload(ABC, Generic[_VH]):
"""Abstract base class for MQTT payloads."""
async def to_stream(self, writer: asyncio.StreamWriter) -> None:

Wyświetl plik

@ -31,6 +31,7 @@ _MQTT_PROTOCOL_LEVEL_SUPPORTED = 4
if TYPE_CHECKING:
from amqtt.broker import BrokerContext
class Subscription:
def __init__(self, packet_id: int, topics: list[tuple[str, int]]) -> None:
self.packet_id = packet_id
@ -235,6 +236,7 @@ class BrokerProtocolHandler(ProtocolHandler["BrokerContext"]):
incoming_session.password = connect.password
incoming_session.remote_address = remote_address
incoming_session.remote_port = remote_port
incoming_session.ssl_object = writer.get_ssl_info()
incoming_session.keep_alive = max(connect.keep_alive, 0)

Wyświetl plik

@ -19,6 +19,7 @@ from amqtt.session import Session
if TYPE_CHECKING:
from amqtt.client import ClientContext
class ClientProtocolHandler(ProtocolHandler["ClientContext"]):
def __init__(
self,

Wyświetl plik

@ -4,13 +4,13 @@ try:
from asyncio import InvalidStateError, QueueFull, QueueShutDown
except ImportError:
# Fallback for Python < 3.12
class InvalidStateError(Exception): # type: ignore[no-redef]
class InvalidStateError(Exception): # type: ignore[no-redef]
pass
class QueueFull(Exception): # type: ignore[no-redef] # noqa : N818
class QueueFull(Exception): # type: ignore[no-redef] # noqa : N818
pass
class QueueShutDown(Exception): # type: ignore[no-redef] # noqa : N818
class QueueShutDown(Exception): # type: ignore[no-redef] # noqa : N818
pass
@ -63,6 +63,7 @@ from amqtt.session import INCOMING, OUTGOING, ApplicationMessage, IncomingApplic
C = TypeVar("C", bound=BaseContext)
class ProtocolHandler(Generic[C]):
"""Class implementing the MQTT communication protocol using asyncio features."""
@ -199,7 +200,7 @@ class ProtocolHandler(Generic[C]):
async def mqtt_publish(
self,
topic: str,
data: bytes | bytearray ,
data: bytes | bytearray,
qos: int | None,
retain: bool,
ack_timeout: int | None = None,
@ -535,7 +536,7 @@ class ProtocolHandler(Generic[C]):
self.handle_read_timeout()
except NoDataError:
self.logger.debug(f"{self.session.client_id} No data available")
except Exception as e: # noqa: BLE001
except Exception as e: # noqa: BLE001, pylint: disable=W0718
self.logger.warning(f"{type(self).__name__} Unhandled exception in reader coro: {e!r}")
break
while running_tasks:

Wyświetl plik

@ -1 +1,38 @@
"""INIT."""
import re
from typing import Any, Optional
class TopicMatcher:
_instance: Optional["TopicMatcher"] = None
def __init__(self) -> None:
if not hasattr(self, "_topic_filter_matchers"):
self._topic_filter_matchers: dict[str, re.Pattern[str]] = {}
def __new__(cls, *args: list[Any], **kwargs: dict[str, Any]) -> "TopicMatcher":
if cls._instance is None:
cls._instance = super().__new__(cls, *args, **kwargs)
return cls._instance
def is_topic_allowed(self, topic: str, a_filter: str) -> bool:
if topic.startswith("$") and (a_filter.startswith(("+", "#"))):
return False
if "#" not in a_filter and "+" not in a_filter:
# if filter doesn't contain wildcard, return exact match
return a_filter == topic
# else use regex (re.compile is an expensive operation, store the matcher for future use)
if a_filter not in self._topic_filter_matchers:
self._topic_filter_matchers[a_filter] = re.compile(re.escape(a_filter)
.replace("\\#", "?.*")
.replace("\\+", "[^/]*")
.lstrip("?"))
match_pattern = self._topic_filter_matchers[a_filter]
return bool(match_pattern.fullmatch(topic))
def are_topics_allowed(self, topic: str, many_filters: list[str]) -> bool:
return any(self.is_topic_allowed(topic, a_filter) for a_filter in many_filters)

Wyświetl plik

@ -37,9 +37,10 @@ class AnonymousAuthPlugin(BaseAuthPlugin):
@dataclass
class Config:
"""Allow empty username."""
"""Configuration for AnonymousAuthPlugin."""
allow_anonymous: bool = field(default=True)
"""Allow all anonymous authentication (even with _no_ username)."""
class FileAuthPlugin(BaseAuthPlugin):
@ -78,7 +79,7 @@ class FileAuthPlugin(BaseAuthPlugin):
self.context.logger.warning(f"Password file '{password_file}' not found")
except ValueError:
self.context.logger.exception(f"Malformed password file '{password_file}'")
except Exception:
except OSError:
self.context.logger.exception(f"Unexpected error reading password file '{password_file}'")
async def authenticate(self, *, session: Session) -> bool | None:
@ -107,6 +108,7 @@ class FileAuthPlugin(BaseAuthPlugin):
@dataclass
class Config:
"""Path to the properly encoded password file."""
"""Configuration for FileAuthPlugin."""
password_file: str | Path | None = None
"""Path to file with `username:password` pairs, one per line. All passwords are encoded using sha-512."""

Wyświetl plik

@ -1,7 +1,7 @@
from dataclasses import dataclass, is_dataclass
from typing import Any, Generic, TypeVar, cast
from amqtt.contexts import Action, BaseContext
from amqtt.contexts import Action, BaseContext, BrokerConfig
from amqtt.session import Session
C = TypeVar("C", bound=BaseContext)
@ -46,13 +46,13 @@ class BasePlugin(Generic[C]):
return section_config
# Deprecated : supports entrypoint-style configs as well as dataclass configuration.
def _get_config_option(self, option_name: str, default: Any=None) -> Any:
def _get_config_option(self, option_name: str, default: Any = None) -> Any:
if not self.context.config:
return default
if is_dataclass(self.context.config):
# overloaded context.config for BasePlugin `Config` class, so ignoring static type check
return getattr(self.context.config, option_name.replace("-", "_"), default) # type: ignore[unreachable]
return getattr(self.context.config, option_name.replace("-", "_"), default)
if option_name in self.context.config:
return self.context.config[option_name]
return default
@ -75,20 +75,21 @@ class BaseTopicPlugin(BasePlugin[BaseContext]):
if not bool(self.topic_config) and not is_dataclass(self.context.config):
self.context.logger.warning("'topic-check' section not found in context configuration")
def _get_config_option(self, option_name: str, default: Any=None) -> Any:
def _get_config_option(self, option_name: str, default: Any = None) -> Any:
if not self.context.config:
return default
if is_dataclass(self.context.config):
# overloaded context.config with either BrokerConfig or plugin's Config
if is_dataclass(self.context.config) and not isinstance(self.context.config, BrokerConfig):
# overloaded context.config for BasePlugin `Config` class, so ignoring static type check
return getattr(self.context.config, option_name.replace("-", "_"), default) # type: ignore[unreachable]
return getattr(self.context.config, option_name.replace("-", "_"), default)
if self.topic_config and option_name in self.topic_config:
return self.topic_config[option_name]
return default
async def topic_filtering(
self, *, session: Session | None = None, topic: str | None = None, action: Action | None = None
) -> bool:
) -> bool | None:
"""Logic for filtering out topics.
Args:
@ -97,7 +98,7 @@ class BaseTopicPlugin(BasePlugin[BaseContext]):
action: amqtt.broker.Action
Returns:
bool: `True` if topic is allowed, `False` otherwise
bool: `True` if topic is allowed, `False` otherwise. `None` if it can't be determined
"""
return bool(self.topic_config) or is_dataclass(self.context.config)
@ -106,13 +107,13 @@ class BaseTopicPlugin(BasePlugin[BaseContext]):
class BaseAuthPlugin(BasePlugin[BaseContext]):
"""Base class for authentication plugins."""
def _get_config_option(self, option_name: str, default: Any=None) -> Any:
def _get_config_option(self, option_name: str, default: Any = None) -> Any:
if not self.context.config:
return default
if is_dataclass(self.context.config):
if is_dataclass(self.context.config) and not isinstance(self.context.config, BrokerConfig):
# overloaded context.config for BasePlugin `Config` class, so ignoring static type check
return getattr(self.context.config, option_name.replace("-", "_"), default) # type: ignore[unreachable]
return getattr(self.context.config, option_name.replace("-", "_"), default)
if self.auth_config and option_name in self.auth_config:
return self.auth_config[option_name]
return default
@ -125,7 +126,6 @@ class BaseAuthPlugin(BasePlugin[BaseContext]):
# auth config section not found and Config dataclass not provided
self.context.logger.warning("'auth' section not found in context configuration")
async def authenticate(self, *, session: Session) -> bool | None:
"""Logic for session authentication.

Wyświetl plik

@ -8,6 +8,8 @@ import copy
from importlib.metadata import EntryPoint, EntryPoints, entry_points
from inspect import iscoroutinefunction
import logging
import sys
import traceback
from typing import Any, Generic, NamedTuple, Optional, TypeAlias, TypeVar, cast
import warnings
@ -49,6 +51,7 @@ def safe_issubclass(sub_class: Any, super_class: Any) -> bool:
AsyncFunc: TypeAlias = Callable[..., Coroutine[Any, Any, None]]
C = TypeVar("C", bound=BaseContext)
class PluginManager(Generic[C]):
"""Wraps contextlib Entry point mechanism to provide a basic plugin system.
@ -95,10 +98,9 @@ class PluginManager(Generic[C]):
if self.app_context.config and self.app_context.config.get("plugins", None) is not None:
# plugins loaded directly from config dictionary
if "auth" in self.app_context.config:
if "auth" in self.app_context.config and self.app_context.config["auth"] is not None:
self.logger.warning("Loading plugins from config will ignore 'auth' section of config")
if "topic-check" in self.app_context.config:
if "topic-check" in self.app_context.config and self.app_context.config["topic-check"] is not None:
self.logger.warning("Loading plugins from config will ignore 'topic-check' section of config")
plugins_config: list[Any] | dict[str, Any] = self.app_context.config.get("plugins", [])
@ -130,7 +132,7 @@ class PluginManager(Generic[C]):
"Loading plugins from EntryPoints is deprecated and will be removed in a future version."
" Use `plugins` section of config instead.",
DeprecationWarning,
stacklevel=2
stacklevel=4
)
self._load_ep_plugins(namespace)
@ -145,7 +147,7 @@ class PluginManager(Generic[C]):
self.logger.debug(f"'{event}' handler found for '{plugin.__class__.__name__}'")
self._event_plugin_callbacks[event].append(awaitable)
def _load_ep_plugins(self, namespace:str) -> None:
def _load_ep_plugins(self, namespace: str) -> None:
"""Load plugins from `pyproject.toml` entrypoints. Deprecated."""
self.logger.debug(f"Loading plugins for namespace {namespace}")
auth_filter_list = []
@ -222,7 +224,7 @@ class PluginManager(Generic[C]):
def _load_str_plugin(self, plugin_path: str, plugin_cfg: dict[str, Any] | None = None) -> "BasePlugin[C]":
"""Load plugin from string dotted path: mymodule.myfile.MyPlugin."""
try:
plugin_class: Any = import_string(plugin_path)
plugin_class: Any = import_string(plugin_path)
except ImportError as ep:
msg = f"Plugin import failed: {plugin_path}"
raise PluginImportError(msg) from ep
@ -291,6 +293,15 @@ class PluginManager(Generic[C]):
return asyncio.ensure_future(coro)
def _clean_fired_events(self, future: asyncio.Future[Any]) -> None:
if self.logger.getEffectiveLevel() <= logging.DEBUG:
try:
future.result()
except asyncio.CancelledError:
self.logger.warning("fired event was cancelled")
# display plugin fault; don't allow it to cause a broker failure
except Exception as exc: # noqa: BLE001, pylint: disable=W0718
traceback.print_exception(type(exc), exc, exc.__traceback__, file=sys.stderr)
with contextlib.suppress(KeyError, ValueError):
self._fired_events.remove(future)
@ -366,7 +377,7 @@ class PluginManager(Generic[C]):
:return: dict containing return from coro call for each plugin.
"""
return await self._map_plugin_method(
self._auth_plugins, "authenticate", {"session": session }) # type: ignore[arg-type]
self._auth_plugins, "authenticate", {"session": session}) # type: ignore[arg-type]
async def map_plugin_topic(
self, *, session: Session, topic: str, action: "Action"

Wyświetl plik

@ -1,85 +1,11 @@
import json
import sqlite3
from typing import Any
import warnings
from amqtt.contexts import BaseContext
from amqtt.session import Session
from amqtt.broker import BrokerContext
from amqtt.plugins.base import BasePlugin
class SQLitePlugin:
def __init__(self, context: BaseContext) -> None:
self.context: BaseContext = context
self.conn: sqlite3.Connection | None = None
self.cursor: sqlite3.Cursor | None = None
self.db_file: str | None = None
self.persistence_config: dict[str, Any]
class SQLitePlugin(BasePlugin[BrokerContext]):
if (
persistence_config := self.context.config.get("persistence") if self.context.config is not None else None
) is not None:
self.persistence_config = persistence_config
self.init_db()
else:
self.context.logger.warning("'persistence' section not found in context configuration")
def init_db(self) -> None:
self.db_file = self.persistence_config.get("file")
if not self.db_file:
self.context.logger.warning("'file' persistence parameter not found")
else:
try:
self.conn = sqlite3.connect(self.db_file)
self.cursor = self.conn.cursor()
self.context.logger.info(f"Database file '{self.db_file}' opened")
except Exception:
self.context.logger.exception(f"Error while initializing database '{self.db_file}'")
if self.cursor:
self.cursor.execute(
"CREATE TABLE IF NOT EXISTS session(client_id TEXT PRIMARY KEY, data BLOB)",
)
self.cursor.execute("PRAGMA table_info(session)")
columns = {col[1] for col in self.cursor.fetchall()}
required_columns = {"client_id", "data"}
if not required_columns.issubset(columns):
self.context.logger.error("Database schema for 'session' table is incompatible.")
async def save_session(self, session: Session) -> None:
if self.cursor and self.conn:
dump: str = json.dumps(session, default=str)
try:
self.cursor.execute(
"INSERT OR REPLACE INTO session (client_id, data) VALUES (?, ?)",
(session.client_id, dump),
)
self.conn.commit()
except Exception:
self.context.logger.exception(f"Failed saving session '{session}'")
async def find_session(self, client_id: str) -> Session | None:
if self.cursor:
row = self.cursor.execute(
"SELECT data FROM session where client_id=?",
(client_id,),
).fetchone()
return json.loads(row[0]) if row else None
return None
async def del_session(self, client_id: str) -> None:
if self.cursor and self.conn:
try:
exists = self.cursor.execute("SELECT 1 FROM session WHERE client_id=?", (client_id,)).fetchone()
if exists:
self.cursor.execute("DELETE FROM session where client_id=?", (client_id,))
self.conn.commit()
except Exception:
self.context.logger.exception(f"Failed deleting session with client_id '{client_id}'")
async def on_broker_post_shutdown(self) -> None:
if self.conn:
try:
self.conn.close()
self.context.logger.info(f"Database file '{self.db_file}' closed")
except Exception:
self.context.logger.exception("Error closing database connection")
finally:
self.conn = None
def __init__(self, context: BrokerContext) -> None:
super().__init__(context)
warnings.warn("SQLitePlugin is deprecated, use amqtt.contrib.persistence.SessionDBPlugin", stacklevel=1)

Wyświetl plik

@ -14,7 +14,7 @@ except ImportError:
from typing import Protocol, runtime_checkable
@runtime_checkable
class Buffer(Protocol): # type: ignore[no-redef]
class Buffer(Protocol): # type: ignore[no-redef]
def __buffer__(self, flags: int = ...) -> memoryview:
"""Mimic the behavior of `collections.abc.Buffer` for python 3.10-3.12."""
@ -75,7 +75,6 @@ class BrokerSysPlugin(BasePlugin[BrokerContext]):
self._sys_interval: int = 0
self._current_process = psutil.Process()
def _clear_stats(self) -> None:
"""Initialize broker statistics data structures."""
for stat in (
@ -112,7 +111,7 @@ class BrokerSysPlugin(BasePlugin[BrokerContext]):
"""Initialize statistics and start $SYS broadcasting."""
self._stats[STAT_START_TIME] = int(datetime.now(tz=UTC).timestamp())
version = f"aMQTT version {amqtt.__version__}"
self.context.retain_message(DOLLAR_SYS_ROOT + "version", version.encode())
await self.context.retain_message(DOLLAR_SYS_ROOT + "version", version.encode())
# Start $SYS topics management
try:

Wyświetl plik

@ -1,7 +1,9 @@
from dataclasses import dataclass, field
from typing import Any
import warnings
from amqtt.contexts import Action, BaseContext
from amqtt.errors import PluginInitError
from amqtt.plugins.base import BaseTopicPlugin
from amqtt.session import Session
@ -13,17 +15,27 @@ class TopicTabooPlugin(BaseTopicPlugin):
async def topic_filtering(
self, *, session: Session | None = None, topic: str | None = None, action: Action | None = None
) -> bool:
) -> bool | None:
filter_result = await super().topic_filtering(session=session, topic=topic, action=action)
if filter_result:
if session and session.username == "admin":
return True
return not (topic and topic in self._taboo)
return filter_result
return bool(filter_result)
class TopicAccessControlListPlugin(BaseTopicPlugin):
def __init__(self, context: BaseContext) -> None:
super().__init__(context)
if self._get_config_option("acl", None):
warnings.warn("The 'acl' option is deprecated, please use 'subscribe-acl' instead.", stacklevel=1)
if self._get_config_option("acl", None) and self._get_config_option("subscribe-acl", None):
msg = "'acl' has been replaced with 'subscribe-acl'; only one may be included"
raise PluginInitError(msg)
@staticmethod
def topic_ac(topic_requested: str, topic_allowed: str) -> bool:
req_split = topic_requested.split("/")
@ -46,7 +58,7 @@ class TopicAccessControlListPlugin(BaseTopicPlugin):
async def topic_filtering(
self, *, session: Session | None = None, topic: str | None = None, action: Action | None = None
) -> bool:
) -> bool | None:
filter_result = await super().topic_filtering(session=session, topic=topic, action=action)
if not filter_result:
return False
@ -58,18 +70,26 @@ class TopicAccessControlListPlugin(BaseTopicPlugin):
req_topic = topic
if not req_topic:
return False\
return False
username = session.username if session else None
if username is None:
username = "anonymous"
acl: dict[str, Any] = {}
acl: dict[str, Any] | None = None
match action:
case Action.PUBLISH:
acl = self._get_config_option("publish-acl", {})
acl = self._get_config_option("publish-acl", None)
case Action.SUBSCRIBE:
acl = self._get_config_option("acl", {})
acl = self._get_config_option("subscribe-acl", self._get_config_option("acl", None))
case Action.RECEIVE:
acl = self._get_config_option("receive-acl", None)
case _:
msg = "Received an invalid action type."
raise ValueError(msg)
if acl is None:
return True
allowed_topics = acl.get(username, [])
if not allowed_topics:

Wyświetl plik

@ -21,7 +21,7 @@ def main() -> None:
app()
def _version(v:bool) -> None:
def _version(v: bool) -> None:
if v:
typer.echo(f"{amqtt_version}")
raise typer.Exit(code=0)
@ -41,9 +41,12 @@ def broker_main(
) -> None:
"""Command-line script for running a MQTT 3.1.1 broker."""
formatter = "[%(asctime)s] :: %(levelname)s - %(message)s"
if debug:
formatter = "[%(asctime)s] %(name)s:%(lineno)d :: %(levelname)s - %(message)s"
level = logging.DEBUG if debug else logging.INFO
logging.basicConfig(level=level, format=formatter)
logging.getLogger("transitions").setLevel(logging.WARNING)
try:
if config_file:
config = read_yaml_config(config_file)
@ -62,7 +65,7 @@ def broker_main(
typer.echo(f"❌ Broker failed to start: {exc}", err=True)
raise typer.Exit(code=1) from exc
_ = loop.create_task(broker.start()) #noqa : RUF006
_ = loop.create_task(broker.start()) # noqa : RUF006
try:
loop.run_forever()
except KeyboardInterrupt:

Wyświetl plik

@ -0,0 +1,42 @@
import logging
from pathlib import Path
import sys
import typer
logger = logging.getLogger(__name__)
app = typer.Typer(add_completion=False, rich_markup_mode=None)
def main() -> None:
"""Run the cli for `ca_creds`."""
app()
@app.command()
def ca_creds(
country: str = typer.Option(..., "--country", help="x509 'country_name' attribute"),
state: str = typer.Option(..., "--state", help="x509 'state_or_province_name' attribute"),
locality: str = typer.Option(..., "--locality", help="x509 'locality_name' attribute"),
org_name: str = typer.Option(..., "--org-name", help="x509 'organization_name' attribute"),
cn: str = typer.Option(..., "--cn", help="x509 'common_name' attribute"),
output_dir: str = typer.Option(Path.cwd().absolute(), "--output-dir", help="output directory"),
) -> None:
"""Generate a self-signed key and certificate to be used as the root CA, with a key size of 2048 and a 1-year expiration."""
formatter = "[%(asctime)s] :: %(levelname)s - %(message)s"
logging.basicConfig(level=logging.INFO, format=formatter)
try:
from amqtt.contrib.cert import generate_root_creds, write_key_and_crt # pylint: disable=import-outside-toplevel
except ImportError:
msg = "Requires installation of the optional 'contrib' package: `pip install amqtt[contrib]`"
logger.critical(msg)
sys.exit(1)
ca_key, ca_crt = generate_root_creds(country=country, state=state, locality=locality, org_name=org_name, cn=cn)
write_key_and_crt(ca_key, ca_crt, "ca", Path(output_dir))
if __name__ == "__main__":
main()

Wyświetl plik

@ -7,7 +7,7 @@ auto_reconnect: true
cleansession: true
reconnect_max_interval: 10
reconnect_retries: 2
broker:
connection:
uri: "mqtt://127.0.0.1"
plugins:
amqtt.plugins.logging_amqtt.PacketLoggerPlugin:

Wyświetl plik

@ -0,0 +1,64 @@
import logging
from pathlib import Path
import sys
import typer
logger = logging.getLogger(__name__)
app = typer.Typer(add_completion=False, rich_markup_mode=None)
def main() -> None:
"""Run the `device_creds` cli."""
app()
@app.command()
def device_creds( # pylint: disable=too-many-locals
country: str = typer.Option(..., "--country", help="x509 'country_name' attribute"),
org_name: str = typer.Option(..., "--org-name", help="x509 'organization_name' attribute"),
device_id: str = typer.Option(..., "--device-id", help="device id for the SAN"),
uri: str = typer.Option(..., "--uri", help="domain name for device SAN"),
output_dir: str = typer.Option(Path.cwd().absolute(), "--output-dir", help="output directory"),
ca_key_fn: str = typer.Option("ca.key", "--ca-key", help="root key filename used for signing."),
ca_crt_fn: str = typer.Option("ca.crt", "--ca-crt", help="root cert filename used for signing."),
) -> None:
"""Generate a key and certificate for each device in pem format, signed by the provided CA credentials. With a key size of 2048 and a 1-year expiration.""" # noqa: E501
formatter = "[%(asctime)s] :: %(levelname)s - %(message)s"
logging.basicConfig(level=logging.INFO, format=formatter)
try:
from amqtt.contrib.cert import ( # pylint: disable=import-outside-toplevel
generate_device_csr,
load_ca,
sign_csr,
write_key_and_crt,
)
except ImportError:
msg = "Requires installation of the optional 'contrib' package: `pip install amqtt[contrib]`"
logger.critical(msg)
sys.exit(1)
ca_key, ca_crt = load_ca(ca_key_fn, ca_crt_fn)
uri_san = f"spiffe://{uri}/device/{device_id}"
dns_san = f"{device_id}.local"
device_key, device_csr = generate_device_csr(
country=country,
org_name=org_name,
common_name=device_id,
uri_san=uri_san,
dns_san=dns_san
)
device_crt = sign_csr(device_csr, ca_key, ca_crt)
write_key_and_crt(device_key, device_crt, device_id, Path(output_dir))
logger.info(f"✅ Created: {device_id}.crt and {device_id}.key")
if __name__ == "__main__":
main()

Wyświetl plik

@ -0,0 +1,37 @@
import logging
import sys
from amqtt.errors import MQTTError
logger = logging.getLogger(__name__)
def main() -> None:
"""Run the auth db cli."""
try:
from amqtt.contrib.auth_db.topic_mgr_cli import topic_app # pylint: disable=import-outside-toplevel
except ImportError:
logger.critical("optional 'contrib' library is missing, please install: `pip install amqtt[contrib]`")
sys.exit(1)
from amqtt.contrib.auth_db.topic_mgr_cli import topic_app # pylint: disable=import-outside-toplevel
try:
topic_app()
except ModuleNotFoundError as mnfe:
logger.critical(f"Please install database-specific dependencies: {mnfe}")
sys.exit(1)
except ValueError as ve:
if "greenlet" in f"{ve}":
logger.critical("Please install database-specific dependencies: 'greenlet'")
sys.exit(1)
logger.critical(f"Unknown error: {ve}")
sys.exit(1)
except MQTTError as me:
logger.critical(f"could not execute command: {me}")
sys.exit(1)
if __name__ == "__main__":
main()

Wyświetl plik

@ -0,0 +1,36 @@
import logging
import sys
from amqtt.errors import MQTTError
logger = logging.getLogger(__name__)
def main() -> None:
"""Run the auth db cli."""
try:
from amqtt.contrib.auth_db.user_mgr_cli import user_app # pylint: disable=import-outside-toplevel
except ImportError:
logger.critical("optional 'contrib' library is missing, please install: `pip install amqtt[contrib]`")
sys.exit(1)
from amqtt.contrib.auth_db.user_mgr_cli import user_app # pylint: disable=import-outside-toplevel
try:
user_app()
except ModuleNotFoundError as mnfe:
logger.critical(f"Please install database-specific dependencies: {mnfe}")
sys.exit(1)
except ValueError as ve:
if "greenlet" in f"{ve}":
logger.critical("Please install database-specific dependencies: 'greenlet'")
sys.exit(1)
logger.critical(f"Unknown error: {ve}")
sys.exit(1)
except MQTTError as me:
logger.critical(f"could not execute command: {me}")
sys.exit(1)
if __name__ == "__main__":
main()

Wyświetl plik

@ -52,7 +52,7 @@ class MessageInput:
with Path(self.file).open(encoding="utf-8") as f:
for line in f:
yield line.encode(encoding="utf-8")
except Exception:
except (FileNotFoundError, OSError):
logger.exception(f"Failed to read file '{self.file}'")
if self.lines:
for line in sys.stdin:
@ -118,6 +118,7 @@ async def do_pub(
logger.fatal("Publish canceled due to previous error")
raise asyncio.CancelledError from ce
app = typer.Typer(add_completion=False, rich_markup_mode=None)
@ -131,8 +132,9 @@ def _version(v: bool) -> None:
typer.echo(f"{amqtt_version}")
raise typer.Exit(code=0)
@app.command()
def publisher_main( # pylint: disable=R0914,R0917 # noqa : PLR0913
def publisher_main( # pylint: disable=R0914,R0917
url: str | None = typer.Option(None, "--url", help="Broker connection URL, *must conform to MQTT or URI scheme: `[mqtt(s)|ws(s)]://<username:password>@HOST:port`*"),
config_file: str | None = typer.Option(None, "-c", "--config-file", help="Client configuration file"),
client_id: str | None = typer.Option(None, "-i", "--client-id", help="client identification for mqtt connection. *default: process id and the hostname of the client*"),
@ -155,7 +157,7 @@ def publisher_main( # pylint: disable=R0914,R0917 # noqa : PLR0913
will_retain: bool = typer.Option(False, "--will-retain", help="If the client disconnects unexpectedly the message sent out will be treated as a retained message. *only valid, if `--will-topic` is specified*"),
extra_headers_json: str | None = typer.Option(None, "--extra-headers", help="Specify a JSON object string with key-value pairs representing additional headers that are transmitted on the initial connection. *websocket connections only*."),
debug: bool = typer.Option(False, "-d", help="Enable debug messages"),
version: bool = typer.Option(False, "--version", callback=_version, is_eager=True, help="Show version and exit"), # noqa : ARG001
version: bool = typer.Option(False, "--version", callback=_version, is_eager=True, help="Show version and exit"), # noqa : ARG001
) -> None:
"""Command-line MQTT client for publishing simple messages."""
provided = [bool(message), bool(file), stdin, lines, no_message]

Wyświetl plik

@ -0,0 +1,50 @@
import logging
from pathlib import Path
import sys
import typer
logger = logging.getLogger(__name__)
app = typer.Typer(add_completion=False, rich_markup_mode=None)
def main() -> None:
"""Run the `server_creds` cli."""
app()
@app.command()
def server_creds(
country: str = typer.Option(..., "--country", help="x509 'country_name' attribute"),
org_name: str = typer.Option(..., "--org-name", help="x509 'organization_name' attribute"),
cn: str = typer.Option(..., "--cn", help="x509 'common_name' attribute"),
output_dir: str = typer.Option(Path.cwd().absolute(), "--output-dir", help="output directory"),
ca_key_fn: str = typer.Option("ca.key", "--ca-key", help="server key output filename."),
ca_crt_fn: str = typer.Option("ca.crt", "--ca-crt", help="server cert output filename."),
) -> None:
"""Generate a key and certificate for the broker in pem format, signed by the provided CA credentials. With a key size of 2048 and a 1-year expiration.""" # noqa : E501
formatter = "[%(asctime)s] :: %(levelname)s - %(message)s"
logging.basicConfig(level=logging.INFO, format=formatter)
try:
from amqtt.contrib.cert import ( # pylint: disable=import-outside-toplevel
generate_server_csr,
load_ca,
sign_csr,
write_key_and_crt,
)
except ImportError:
msg = "Requires installation of the optional 'contrib' package: `pip install amqtt[contrib]`"
logger.critical(msg)
sys.exit(1)
ca_key, ca_crt = load_ca(ca_key_fn, ca_crt_fn)
server_key, server_csr = generate_server_csr(country=country, org_name=org_name, cn=cn)
server_crt = sign_csr(server_csr, ca_key, ca_crt)
write_key_and_crt(server_key, server_crt, "server", Path(output_dir))
if __name__ == "__main__":
main()

Wyświetl plik

@ -100,17 +100,17 @@ def main() -> None:
app()
def _version(v:bool) -> None:
def _version(v: bool) -> None:
if v:
typer.echo(f"{amqtt_version}")
raise typer.Exit(code=0)
@app.command()
def subscribe_main( # pylint: disable=R0914,R0917 # noqa : PLR0913
def subscribe_main( # pylint: disable=R0914,R0917
url: str = typer.Option(None, help="Broker connection URL, *must conform to MQTT or URI scheme: `[mqtt(s)|ws(s)]://<username:password>@HOST:port`*", show_default=False),
config_file: str | None = typer.Option(None, "-c", help="Client configuration file"),
client_id: str | None = typer.Option(None, "-i", "--client-id", help="client identification for mqtt connection. *default: process id and the hostname of the client*"), max_count: int | None = typer.Option(None, "-n", help="Number of messages to read before ending *default: read indefinitely*"),
client_id: str | None = typer.Option(None, "-i", "--client-id", help="client identification for mqtt connection. *default: process id and the hostname of the client*"), max_count: int | None = typer.Option(None, "-n", help="Number of messages to read before ending *default: read indefinitely*"),
qos: int = typer.Option(0, "--qos", "-q", help="Quality of service (0, 1, or 2)"),
topics: list[str] = typer.Option(..., "-t", help="Topic filter to subscribe, can be used multiple times."), # noqa: B008
keep_alive: int | None = typer.Option(None, "-k", help="Keep alive timeout in seconds"),

Wyświetl plik

@ -1,6 +1,9 @@
from asyncio import Queue
from collections import OrderedDict
from typing import Any, ClassVar
import logging
from math import floor
import time
from typing import TYPE_CHECKING, Any, ClassVar
from transitions import Machine
@ -10,6 +13,11 @@ from amqtt.mqtt.publish import PublishPacket
OUTGOING = 0
INCOMING = 1
if TYPE_CHECKING:
import ssl
logger = logging.getLogger(__name__)
class ApplicationMessage:
"""ApplicationMessage and subclasses are used to store published message information flow.
@ -138,6 +146,9 @@ class Session:
self.cadata: bytes | None = None
self._packet_id: int = 0
self.parent: int = 0
self.last_connect_time: int | None = None
self.ssl_object: ssl.SSLObject | None = None
self.last_disconnect_time: int | None = None
# Used to store outgoing ApplicationMessage while publish protocol flows
self.inflight_out: OrderedDict[int, OutgoingApplicationMessage] = OrderedDict()
@ -161,6 +172,7 @@ class Session:
source="new",
dest="connected",
)
self.transitions.on_enter_connected(self._on_enter_connected)
self.transitions.add_transition(
trigger="connect",
source="disconnected",
@ -171,6 +183,7 @@ class Session:
source="connected",
dest="disconnected",
)
self.transitions.on_enter_disconnected(self._on_enter_disconnected)
self.transitions.add_transition(
trigger="disconnect",
source="new",
@ -182,6 +195,20 @@ class Session:
dest="disconnected",
)
def _on_enter_connected(self) -> None:
cur_time = floor(time.time())
if self.last_disconnect_time is not None:
logger.debug(f"Session reconnected after {cur_time - self.last_disconnect_time} seconds.")
self.last_connect_time = cur_time
self.last_disconnect_time = None
def _on_enter_disconnected(self) -> None:
cur_time = floor(time.time())
if self.last_connect_time is not None:
logger.debug(f"Session disconnected after {cur_time - self.last_connect_time} seconds.")
self.last_disconnect_time = cur_time
@property
def next_packet_id(self) -> int:
self._packet_id = (self._packet_id % 65535) + 1

Wyświetl plik

@ -30,3 +30,7 @@ h2.doc-heading-parameter {
.md-nav__link--active {
color: #f15581 !important;
}
.admonition {
font-size: 16px !important;
}

Wyświetl plik

@ -1,5 +1,43 @@
# Changelog
## 0.11.3
API changes:
- broker and client configuration via dataclasses and enums instead of unstructured dictionaries (backwards compatible)
- `MESSAGE_RECIEVE` event moved to after topic filtering
- `MESSAGE_BROADCAST` event added for prior topic filtering
- `RETAINED_MESSAGE` event added for messages with retained flag or offline clients without setting clean session flag
- method `retain_message` changed to coroutine (broker)
- change `add_subscription` to a public method (broker)
- add listener type for external servers and api method for passing new connection via `external_connected` method (broker)
- for TLS sessions, properly load the key and cert file (client)
- added abstract method `get_ssl_info` to `WriterAdapter`
Details:
* Structural elements for the 0.11.3 release https://github.com/Yakifo/amqtt/pull/265
* Release Candidate Branch for 0.11.3 https://github.com/Yakifo/amqtt/pull/272
* update the configuration for the broker running at test.amqtt.io https://github.com/Yakifo/amqtt/pull/271
* Improved broker script logging https://github.com/Yakifo/amqtt/pull/277
* test.amqtt.io dashboard cleanup https://github.com/Yakifo/amqtt/pull/278
* Structured broker and client configurations https://github.com/Yakifo/amqtt/pull/269
* Determine auth & topic access via external http server https://github.com/Yakifo/amqtt/pull/262
* Plugin: authentication against a relational database https://github.com/Yakifo/amqtt/pull/280
* Fixes #247 : expire disconnected sessions https://github.com/Yakifo/amqtt/pull/279
* Expanded structure for plugin documentation https://github.com/Yakifo/amqtt/pull/281
* Yakifo/amqtt#120 confirms : validate example is functioning https://github.com/Yakifo/amqtt/pull/284
* Yakifo/amqtt#39 : adding W0718 'broad exception caught' https://github.com/Yakifo/amqtt/pull/285
* Documentation improvement for 0.11.3 https://github.com/Yakifo/amqtt/pull/286
* Plugin naming convention https://github.com/Yakifo/amqtt/pull/288
* embed amqtt into an existing server https://github.com/Yakifo/amqtt/pull/283
* Plugin: rebuild of session persistence https://github.com/Yakifo/amqtt/pull/256
* Plugin: determine authentication based on X509 certificates https://github.com/Yakifo/amqtt/pull/264
* Plugin: device 'shadows' to bridge device online/offline states https://github.com/Yakifo/amqtt/pull/282
* Plugin: authenticate against LDAP server https://github.com/Yakifo/amqtt/pull/287
* Sample: broker and client communicating with mqtt over unix socket https://github.com/Yakifo/amqtt/pull/291
* Plugin: jwt authentication and authorization https://github.com/Yakifo/amqtt/pull/289
## 0.11.2
- config-file based plugin loading [PR #240](https://github.com/Yakifo/amqtt/pull/240)

Wyświetl plik

@ -1,8 +0,0 @@
{% extends "base.html" %}
{% block outdated %}
You're not viewing the latest version.
<a href="{{ '../' ~ base_url }}">
<strong>Click here to go to latest.</strong>
</a>
{% endblock %}

Wyświetl plik

@ -0,0 +1,39 @@
# Relational Database for Authentication and Authorization
- `amqtt.contrib.auth_db.UserAuthDBPlugin` (authentication) verify a client's ability to connect to broker
- `amqtt.contrib.auth_db.TopicAuthDBPlugin` (authorization) determine a client's access to topics
Relational database access is supported using SQLAlchemy so MySQL, MariaDB, Postgres and SQLite support is available.
For ease of use, the [`user_mgr` command-line utility](auth_db.md/#user_mgr) to add, remove, update and
list clients. And the [`topic_mgr` command-line utility](auth_db.md/#topic_mgr) to add client access to
subscribe, publish and receive messages on topics.
# Authentication Configuration
::: amqtt.contrib.auth_db.UserAuthDBPlugin.Config
options:
heading_level: 4
extra:
class_style: "simple"
# Authorization Configuration
::: amqtt.contrib.auth_db.TopicAuthDBPlugin.Config
options:
heading_level: 4
extra:
class_style: "simple"
## CLI
::: mkdocs-typer2
:module: amqtt.contrib.auth_db.user_mgr_cli
:name: user_mgr
::: mkdocs-typer2
:module: amqtt.contrib.auth_db.topic_mgr_cli
:name: topic_mgr

Wyświetl plik

@ -0,0 +1,175 @@
# Authentication Using Signed Certificates
Using client-specific certificates, signed by a common authority (even if self-signed) provides
a highly secure way of authenticating mqtt clients. Often used with IoT devices where a unique
certificate can be initialized on initial provisioning.
With so many options, X509 certificates can be daunting to create with `openssl`. Included are
command line utilities to generate a root self-signed certificate and then the proper broker and
device certificates with the correct X509 attributes to enable authenticity.
### Quick start
Generate a self-signed root credentials and server credentials:
```shell
$ ca_creds --country US --state NY --locality NY --org-name "My Org's Name" --cn "my.domain.name"
$ server_creds --country US --org-name "My Org's Name" --cn "my.domain.name"
```
!!! warning "Security of private keys"
Your root credential private key and your server key should *never* be shared with anyone. The
certificates -- specifically the root CA certificate -- is completely safe to share and will need
to be shared along with device credentials when using a self-signed CA.
Include in your server config:
```yaml
listeners:
ssl-mqtt:
bind: "127.0.0.1:8883"
ssl: true
certfile: server.crt
keyfile: server.key
cafile: ca.crt
plugins:
amqtt.contrib.cert.CertificateAuthPlugin:
uri_domain: my.domain.name
```
Generate a device's credentials:
```shell
$ device_creds --country US --org-name "My Org's Name" --device-id myUniqueDeviceId --uri my.domain.name
```
And use these to initialize the `MQTTClient`:
```python
import asyncio
from amqtt.client import MQTTClient
client_config = {
'keyfile': 'myUniqueDeviceId.key',
'certfile': 'myUniqueDeviceId.crt',
'broker': {
'cafile': 'ca.crt'
}
}
async def main():
client = MQTTClient(config=client_config)
await client.connect("mqtts://my.domain.name:8883")
# publish messages or subscribe to receive
asyncio.run(main())
```
## Background
Often used for IoT devices, this method provides the most secure form of identification. A root
certificate, often referenced as a CA certificate -- either issued by a known authority (such as LetsEncrypt)
or a self-signed certificate) is used to sign a private key and certificate for the server. Each device/client
also gets a unique private key and certificate signed by the same CA certificate; also included in the device
certificate is a 'SAN' or SubjectAlternativeName which is the device's unique identifier.
Since both server and device certificates are signed by the same CA certificate, the client can
verify the server's authenticity; and the server can verify the client's authenticity. And since
the device's certificate contains a x509 SAN, the server (with this plugin) can identify the device securely.
!!! note "URI and Client ID configuration"
`uri_domain` configuration must be set to the same uri used to generate the device credentials
when a device is connecting with private key and certificate, the `client_id` must
match the device id used to generate the device credentials.
Available ore three scripts to help with the key generation and certificate signing: `ca_creds`, `server_creds`
and `device_creds`.
!!! note "Configuring broker & client for using Self-signed root CA"
If using self-signed root credentials, the `cafile` configuration for both broker and client need to be
configured with `cafile` set to the `ca.crt`.
## Root & Certificate Credentials
The process for generating a server's private key and certificate is only done once. If you have a private key & certificate --
such as one from verifying your webserver's domain with LetsEncrypt -- that you want to use, pass them to the `server_creds` cli.
If you'd like to use a self-signed certificate, generate your own CA by running the `ca_creds` cli (make sure your client is
configured with `check_hostname` as `False`).
```mermaid
---
config:
theme: redux
---
flowchart LR
subgraph ca_cred["ca_cred #40;cli#41; or other CA"]
ca["ca key & cert"]
end
subgraph server_cred["server_cred fl°°40¶ßclifl°°41¶ß"]
scsr("certificate signing<br>request fl°°40¶ßCSRfl°°41¶ß with<br>SAN of DNS &amp; IP Address")
spk["private key"]
ssi["sign csr"]
end
spk -.-> skc["server key & cert"]
ca_cred --> ssi
spk --> scsr
con["country, org<br>&amp; common name"] --> scsr
scsr --> ssi
ssi --> skc
```
## Device credentials
For each device, create a device id to generate a device-specific private key and certificate using the `device_creds` cli.
Use the same CA as was used for the server (above) so the client & server recognize each other.
```mermaid
---
config:
theme: redux
---
flowchart LR
subgraph ca_cred["ca_cred #40;cli#41; or other CA"]
ca["ca key & cert"]
end
subgraph device_cred["device_cred fl°°40¶ßclifl°°41¶ß"]
ccsr("certificate signing<br>request fl°°40¶ßCSRfl°°41¶ß with<br>SAN of URI &amp; DNS")
cpk["private key"]
csi["sign csr"]
end
cpk --> ccsr
csi --> ckc[device key & cert]
cpk -.-> ckc
ccsr --> csi
ca_cred --> csi
con["country, org<br/>common name<br/>& device id"] --> ccsr
```
## Configuration
::: amqtt.contrib.cert.UserAuthCertPlugin.Config
options:
show_source: false
heading_level: 4
extra:
class_style: "simple"
## Key and Certificate Generation
::: mkdocs-typer2
:module: amqtt.scripts.ca_creds
:name: ca_creds
::: mkdocs-typer2
:module: amqtt.scripts.server_creds
:name: server_creds
::: mkdocs-typer2
:module: amqtt.scripts.device_creds
:name: device_creds

Wyświetl plik

@ -0,0 +1,45 @@
# Contributed Plugins
These are fully supported plugins but require additional dependencies to be installed:
`$ pip install '.[contrib]'`
- [Relational Database Auth](auth_db.md)<br/>
Grant or deny access to clients based on entries in a relational db (mysql, postgres, maria, sqlite). _Includes
manager script to add, remove and create db entries_<br/>
- `amqtt.contrib.auth_db.UserAuthDBPlugin`
- `amqtt.contrib.auth_db.TopicAuthDBPlugin`
- [HTTP Auth](http.md)<br/>
Determine client authentication and/or authorization based on response from a separate HTTP server.<br/>
- `amqtt.contrib.http.UserAuthHttpPlugin`
- `amqtt.contrib.http.TopicAuthHttpPlugin`
- [LDAP Auth](ldap.md)<br/>
Authenticate a user with an LDAP directory server.<br/>
- `amqtt.contrib.ldap.UserAuthLdapPlugin`
- `amqtt.contrib.ldap.TopicAuthLdapPlugin`
- [Shadows](shadows.md)<br/>
Device shadows provide a persistent, cloud-based representation of the state of a device,
even when the device is offline. This plugin tracks the desired and reported state of a client
and provides MQTT topic-based communication channels to retrieve and update a shadow.<br/>
`amqtt.contrib.shadows.ShadowPlugin`
- [Certificate Auth](cert.md)<br/>
Using client-specific certificates, signed by a common authority (even if self-signed) provides
a highly secure way of authenticating mqtt clients. Often used with IoT devices where a unique
certificate can be initialized on initial provisioning. _Includes command line utilities to generate
root, broker and device certificates with the correct X509
attributes to enable authenticity._<br/>
`amqtt.contrib.cert.UserAuthCertPlugin`
- [JWT Auth](jwt.md)<br/>
Plugin to determine user authentication and topic authorization based on claims in a JWT.
- `amqtt.contrib.jwt.UserAuthJwtPlugin` (client authentication)
- `amqtt.contrib.jwt.TopicAuthJwtPlugin` (topic authorization)
- [Session Persistence](session.md)<br/>
Plugin to store session information and retained topic messages in the event that the broker terminates abnormally.<br/>
`amqtt.contrib.persistence.SessionDBPlugin`

Wyświetl plik

@ -77,7 +77,9 @@ and then run via `amqtt -c myBroker.yaml`.
variables to configure its behavior.
::: amqtt.plugins.base.BasePlugin
options:
show_source: false
heading_level: 3
## Events
@ -85,16 +87,13 @@ and then run via `amqtt -c myBroker.yaml`.
All plugins are notified of events if the `BasePlugin` subclass implements one or more of these methods:
### Client and Broker
### Client
- `async def on_mqtt_packet_sent(self, *, packet: MQTTPacket[MQTTVariableHeader, MQTTPayload[MQTTVariableHeader], MQTTFixedHeader], session: Session | None = None) -> None`
- `async def on_mqtt_packet_received(self, *, packet: MQTTPacket[MQTTVariableHeader, MQTTPayload[MQTTVariableHeader], MQTTFixedHeader], session: Session | None = None) -> None`
### Client Only
none
### Broker Only
### Broker
- `async def on_broker_pre_start(self) -> None`
- `async def on_broker_post_start(self) -> None`
@ -107,29 +106,59 @@ none
- `async def on_broker_client_connected(self, *, client_id:str) -> None`
- `async def on_broker_client_disconnected(self, *, client_id:str) -> None`
- `async def on_broker_retained_message(self, *, client_id: str | None, retained_message: RetainedApplicationMessage) -> None`
- `async def on_broker_client_subscribed(self, *, client_id: str, topic: str, qos: int) -> None`
- `async def on_broker_client_unsubscribed(self, *, client_id: str, topic: str) -> None`
- `async def on_broker_message_received(self, *, client_id: str, message: ApplicationMessage) -> None`
- `async def on_broker_message_broadcast(self, *, client_id: str, message: ApplicationMessage) -> None`
- `async def on_mqtt_packet_sent(self, *, packet: MQTTPacket[MQTTVariableHeader, MQTTPayload[MQTTVariableHeader], MQTTFixedHeader], session: Session | None = None) -> None`
- `async def on_mqtt_packet_received(self, *, packet: MQTTPacket[MQTTVariableHeader, MQTTPayload[MQTTVariableHeader], MQTTFixedHeader], session: Session | None = None) -> None`
!!! note retained message event
if the `client_id` is `None`, the message is retained for a topic
if the `retained_message.data` is `None` or empty (`''`), the topic message is being cleared
## Authentication Plugins
In addition to receiving any of the event callbacks, a plugin which subclasses from `BaseAuthPlugin`
is used by the aMQTT `Broker` to determine if a connection from a client is allowed by
implementing the `authenticate` method and returning `True` if the session is allowed or `False` otherwise.
implementing the `authenticate` method and returning:
- `True` if the session is allowed
- `False` if not allowed
- `None` if plugin can't determine authentication
If there are multiple authentication plugins:
- at least one plugin must return `True` to allow access
- `False` from any plugin will deny access (i.e. all plugins must return `True` to allow access)
- `None` gets ignored from the determination
::: amqtt.plugins.base.BaseAuthPlugin
options:
show_source: false
heading_level: 3
## Topic Filter Plugins
In addition to receiving any of the event callbacks, a plugin which is subclassed from `BaseTopicPlugin`
is used by the aMQTT `Broker` to determine if a connected client can send (PUBLISH) or receive (SUBSCRIBE)
messages to a particular topic by implementing the `topic_filtering` method and returning `True` if allowed or
`False` otherwise.
is used by the aMQTT `Broker` to determine if a connected client can send (PUBLISH), receive (RECEIVE)
and/or subscribe (SUBSCRIBE) messages to a particular topic by implementing the `topic_filtering` method and returning:
- `True` if topic is allowed
- `False` if not allowed
- `None` will be ignored
If there are multiple topic plugins:
- at least one plugin must return `True` to allow access
- `False` from any plugin will deny access (i.e. all plugins must return `True` to allow access)
- `None` will be ignored
::: amqtt.plugins.base.BaseTopicPlugin
options:
show_source: false
heading_level: 3
!!! note
A custom plugin class can subclass from both `BaseAuthPlugin` and `BaseTopicPlugin` as long it defines

Wyświetl plik

@ -0,0 +1,112 @@
# Authentication & Authorization via external HTTP server
If clients accessing the broker are managed by another application, it can implement API endpoints
that respond with information about client authentication and/or topic-level authorization.
- `amqtt.contrib.http.UserAuthHttpPlugin` (client authentication)
- `amqtt.contrib.http.TopicAuthHttpPlugin` (topic authorization)
Configuration of these plugins is identical (except for the uri name) so that they can be used independently, if desired.
# User Auth
See the [Request and Response Modes](#request-response-modes) section below for details on `params_mode` and `response_mode`.
!!! info "browser-based mqtt over websockets"
One of the primary use cases for this plugin is to enable browser-based applications to communicate with mqtt
over websockets.
!!! warning
Care must be taken to make sure the mqtt password is secure (encrypted).
For more implementation information:
??? info "recipe for authentication"
Provide the client id and username when webpage is initially rendered or passed to the mqtt initialization from stored
cookies. If application is secure, the user's password will already be stored as a hashed value and, therefore, cannot
be used in this context to authenticate a client. Instead, the application should create its own encrypted key (eg jwt)
which the server can then verify when the broker contacts the application.
??? example "mqtt in javascript"
Example initialization of mqtt in javascript:
import mqtt from 'mqtt';
const url = 'https://path.to.amqtt.broker';
const options = {
'myclientid',
connectTimeout: 30000,
username: 'myclientid',
password: '' // encrypted password
};
try {
const clientMqtt = await mqtt.connect(url, options);
::: amqtt.contrib.http.UserAuthHttpPlugin.Config
options:
show_source: false
heading_level: 4
extra:
class_style: "simple"
# Topic ACL
See the [Request and Response Modes](#request-response-modes) section below for details on `params_mode` and `response_mode`.
::: amqtt.contrib.http.TopicAuthHttpPlugin.Config
options:
show_source: false
heading_level: 4
extra:
class_style: "simple"
[//]: # (manually creating the heading so it doesn't show in the sidebar ToC)
[](){#request-response-modes}
<h2>Request and Response Modes</h2>
Each URI endpoint will receive different information in order to determine authentication and authorization;
format will depend on `params_mode` configuration attribute (`json` or `form`).:
*For user authentication, the request will contain:*
- username *(str)*
- password *(str)*
- client_id *(str)*
*For acl check, the request will contain:*
- username *(str)*
- client_id *(str)*
- topic *(str)*
- acc *(int)* : client can receive (1), can publish(2), can receive & publish (3) and can subscribe (4)
All endpoints should respond with the following, dependent on `response_mode` configuration attribute:
*In `status` mode:*
- status code: 2xx (granted) or 4xx(denied) or 5xx (noop)
!!! note "5xx response"
**noop** (no operation): plugin will not participate in the operation and will defer to another
plugin to determine access. if there is no other auth/filtering plugin, access will be denied.
*In `json` mode:*
- status code: 2xx
- content-type: application/json
- response: {'ok': True } (granted)
or {'ok': False, 'error': 'optional error message' } (denied)
or { 'error': 'optional error message' } (noop)
!!! note "excluded 'ok' key"
**noop** (no operation): plugin will not participate in the operation and will defer to another
plugin to determine access. if there is no other auth/filtering plugin, access will be denied.
*In `text` mode:*
- status code: 2xx
- content-type: text/plain
- response: 'ok' or 'error'
!!! note "noop not supported"
in text mode, noop (no operation) is not supported

Wyświetl plik

@ -0,0 +1,50 @@
# Authentication & Authorization from JWT
- `amqtt.contrib.jwt.UserAuthJwtPlugin` (client authentication)
- `amqtt.contrib.jwt.TopicAuthJwtPlugin` (topic authorization)
Plugin to determine user authentication and topic authorization based on claims in a JWT.
# User Authentication
For auth, the JWT should include a key as specified in the configuration as `user_clam`:
```python
from datetime import datetime, UTC, timedelta
claims = {
"username": "example_user",
"exp": datetime.now(UTC) + timedelta(hours=1),
}
```
::: amqtt.contrib.jwt.UserAuthJwtPlugin.Config
options:
show_source: false
heading_level: 4
extra:
class_style: "simple"
# Topic Authorization
For authorizing a client for certain topics, the token should also include claims for publish, subscribe and receive;
keys based on how `publish_claim`, `subscribe_claim` and `receive_claim` are specified in the plugin's configuration.
```python
from datetime import datetime, UTC, timedelta
claims = {
"username": "example_user",
"exp": datetime.now(UTC) + timedelta(hours=1),
"publish_acl": ['my/topic/#', 'my/+/other'],
"subscribe_acl": ['my/+/other'],
"receive_acl": ['#']
}
```
::: amqtt.contrib.jwt.TopicAuthJwtPlugin.Config
options:
show_source: false
heading_level: 4
extra:
class_style: "simple"

Wyświetl plik

@ -0,0 +1,25 @@
# Authentication with LDAP Server
If clients accessing the broker are managed by an LDAP server, this plugin can verify credentials
for client authentication and/or topic-level authorization.
- `amqtt.contrib.ldap.UserAuthLdapPlugin` (client authentication)
- `amqtt.contrib.ldap.TopicAuthLdapPlugin` (topic authorization)
Authenticate a user with an LDAP directory server.
# User Auth
::: amqtt.contrib.ldap.UserAuthLdapPlugin.Config
options:
heading_level: 4
extra:
class_style: "simple"
# Topic Auth (ACL)
::: amqtt.contrib.ldap.TopicAuthLdapPlugin.Config
options:
heading_level: 4
extra:
class_style: "simple"

Wyświetl plik

@ -1,4 +1,4 @@
# Existing Plugins
# Packaged Plugins
With the aMQTT plugins framework, one can add additional functionality without
having to rewrite core logic in the broker or client. Plugins can be loaded and configured using
@ -23,7 +23,7 @@ and configured for the broker:
--8<-- "pyproject.toml:included"
```
But the same 4 plugins were activated in the previous default config:
But the previous default config only caused 4 plugins to be active:
```yaml
--8<-- "samples/legacy.yaml"
@ -31,7 +31,7 @@ and configured for the broker:
## Client
By default, the `PacketLoggerPlugin` is activated and configured for the client:
By default, the `PacketLoggerPlugin` is activated and configured for the client:
```yaml
--8<-- "amqtt/scripts/default_client.yaml"
@ -43,15 +43,13 @@ By default, the `PacketLoggerPlugin` is activated and configured for the clien
`amqtt.plugins.authentication.AnonymousAuthPlugin`
**Configuration**
Authentication plugin allowing anonymous access.
```yaml
plugins:
.
.
amqtt.plugins.authentication.AnonymousAuthPlugin:
allow_anonymous: false
```
::: amqtt.plugins.authentication.AnonymousAuthPlugin.Config
options:
heading_level: 4
extra:
class_style: "simple"
!!! danger
even if `allow_anonymous` is set to `false`, the plugin will still allow access if a username is provided by the client
@ -71,15 +69,14 @@ plugins:
`amqtt.plugins.authentication.FileAuthPlugin`
clients are authorized by providing username and password, compared against file
Authentication plugin based on a file-stored user database.
**Configuration**
::: amqtt.plugins.authentication.FileAuthPlugin.Config
options:
heading_level: 4
extra:
class_style: "simple"
```yaml
plugins:
amqtt.plugins.authentication.FileAuthPlugin:
password_file: /path/to/password_file
```
??? warning "EntryPoint-style configuration is deprecated"
```yaml
@ -134,13 +131,26 @@ plugins:
**Configuration**
- `acl` *(mapping)*: determines subscription access
The list should be a key-value pair, where:
`<username>:[<topic1>, <topic2>, ...]` *(string, list[string])*: username of the client followed by a list of allowed topics (wildcards are supported: `#`, `+`).
Each acl category are a list a key-value pair, where:
> `<username>:["<topic1>", "<topic2>", ...]` *(string, list[string])*: username of the client followed by a list of allowed topics (wildcards are supported: `#`, `+`).
- `publish-acl` *(mapping)*: determines publish access. If absent, no restrictions are placed on client publishing.
`<username>:[<topic1>, <topic2>, ...]` *(string, list[string])*: username of the client followed by a list of allowed topics (wildcards are supported: `#`, `+`).
!!! info "`#` and `$SYS` topics"
Per the MQTT 3.1.1 spec 4.7.2, a single `#` will not allow access to `$` broker
topics; need to additionally specify `$SYS/#` to allow a client full access subscribe & receive.
Also MQTT spec prevents clients from publishing to topics starting with `$`; these will be ignored.
If set to `None`, no restrictions are placed on client subscriptions (legacy behavior). An empty list will block clients from using any topics.
- `subscribe-acl` *(mapping)*: determines subscription access.
- `acl` *(mapping)*: Deprecated and replaced by `subscribe-acl`.
- `publish-acl` *(mapping)*: determines publish access.
- `receive-acl` *(mapping)*: determines if a message can be sent to a client.
!!! info "Reserved usernames"
@ -235,7 +245,7 @@ plugins:
`amqtt.plugins.logging_amqtt.PacketLoggerPlugin`
This plugin issues debug-level messages for [mqtt events](custom_plugins.md#client-and-broker): `on_mqtt_packet_sent`
This plugin issues debug-level messages for [mqtt events](custom_plugins.md#events): `on_mqtt_packet_sent`
and `on_mqtt_packet_received`.
```yaml

Wyświetl plik

@ -0,0 +1,12 @@
# Session Persistence
`amqtt.contrib.persistence.SessionDBPlugin`
Plugin to store session information and retained topic messages in the event that the broker terminates abnormally.
::: amqtt.contrib.persistence.SessionDBPlugin.Config
options:
show_source: false
heading_level: 4
extra:
class_style: "simple"

Wyświetl plik

@ -0,0 +1,202 @@
# Device Shadows Plugin
Device shadows provide a persistent, cloud-based representation of the state of a device,
even when the device is offline. This plugin tracks the desired and reported state of a client
and provides MQTT topic-based communication channels to retrieve and update a shadow.
Typically, this structure is used for MQTT IoT devices to communicate with a central application.
This plugin is patterned after [AWS's IoT Shadow](https://docs.aws.amazon.com/iot/latest/developerguide/iot-device-shadows.html) service.
## How it works
All shadow states are associated with a `device id` and `name` and have the following structure:
```json
{
"state": {
"desired": {
"property1": "value1"
},
"reported": {
"property1": "value1"
}
},
"metadata": {
"desired": {
"property1": {
"timestamp": 1623855600
}
},
"reported": {
"property1": {
"timestamp": 1623855602
}
}
},
"version": 10,
"timestamp": 1623855602
}
```
The `state` is updated by messages to shadow topics and includes key/value pairs, where the value can be any valid
json object (int, string, dictionary, list, etc). `metadata` is automatically updated by the plugin based on when
the key/values were most recently updated. Both `state` and `metadata` are split between:
- desired: the intended state of a device
- reported: the actual state of a device
A client can update a part or all of the desired or reported state. On any update, the plugin:
- updates the 'state' portion of the shadow with any key/values provided in the update
- stores a version of the update
- tracks the timestamp of each key/value pair change
- sends messages that the shadow was updated
## Typical usage
As mentioned above, this plugin is often used for MQTT IoT devices to communicate with a central application. The
app pushes updates to a device's 'desired' shadow state and the device can confirm the change was made by updating
the 'reported' state. With this sequence the 'desired' state matches the 'reported' state and the delta message is empty.
In most situations, the app only updates the 'desired' state and the device only updates the 'reported' state.
If online, the IoT device will receive and can act on that information immediately. If offline, the app doesn't need
to republish or retry a change 'command', waiting for an acknowledgement from the device. If a device is offline, it
simply retrieves the configuration changes when it comes back online.
Once a device receives its desired state, it should either (1) update its reported state to match the change in desired
or (2) if the desired state is invalid, clear that key/value from the desired state. The latter is the only case
when a device should update its own 'desired' state.
For example, if the app sends a command to set the brightness of a device to 100 lumens, but the device only supports
a maximum of 80, it can send an update `'state': {'desired': {'lumens': null}}` to clear the invalid state.
The reported state can (and most likely will) include key/values that will never show up in the desired state. For
example, the app might set the thermostat to 70 and the device reports both the configuration change of 70 to the
thermostat *and* the current temperature of the room.
```json
{
"state": {
"desired": {
"thermostat": 68
},
"reported": {
"thermostat": 68,
"temperature": 78
}
}
}
```
!!! note "desired and reported state structure"
It is important that both the app and the device have the same understanding of the key/value
state structure and units. Creating [JSON schemas](https://json-schema.org/) for desired and
reported shadow states are very useful as it can provide a clear way of describing the schema.
These schemas can also be used to generate [dataclasses](https://pypi.org/project/datamodel-code-generator/),
[pojos](https://github.com/joelittlejohn/jsonschema2pojo) or [many other language constructs](https://json-schema.org/tools?query=&sortBy=name&sortOrder=ascending&groupBy=toolingTypes&licenses=&languages=&drafts=&toolingTypes=&environments=&showObsolete=false&supportsBowtie=false#schema-to-code) that
can be easily included by both app and device to make state encoding and decoding consistent.
## Shadow state access
All shadows are addressed by using specific topics, all of which have the following base:
`$shadow/<device_id>/<shadow name>`
Clients send either `get` or `update` messages:
| Operation | Topic | Direction | Payload |
|-------------------------|-------------------------------------------------------|-----------|-----------------------------------------------------|
| **Update** | `$shadow/{device_id}/{shadow_name}/update` | → | `{ "state": { "desired" or "reported": ... } }` |
| **Get** | `$shadow/{device_id}/{shadow_name}/get` | → | Empty message triggers get accepted or rejected |
Then clients can subscribe to any or all of these topics which receive messages issued by the plugin:
| Operation | Topic | Direction | Payload |
|-------------------------|-------------------------------------------------------|-----------|-----------------------------------------------------|
| **Update Accepted** | `$shadow/{device_id}/{shadow_name}/update/accepted` | ← | Full updated document |
| **Update Rejected** | `$shadow/{device_id}/{shadow_name}/update/rejected` | ← | Error message |
| **Update Documents** | `$shadow/{device_id}/{shadow_name}/update/documents` | ← | Full current & previous shadow documents |
| **Get Accepted** | `$shadow/{device_id}/{shadow_name}/get/accepted` | ← | Full shadow document |
| **Get Rejected** | `$shadow/{device_id}/{shadow_name}/get/rejected` | ← | Error message |
| **Delta** | `$shadow/{device_id}/{shadow_name}/update/delta` | ← | Difference between desired and reported |
| **Iota** | `$shadow/{device_id}/{shadow_name}/update/iota` | ← | Difference between desired and reported, with nulls |
## Delta messages
While the 'accepted' and 'documents' messages carry the full desired and reported states, this plugin also generates
a 'delta' message - containing items in the desired state that are different from those items in the reported state. This
topic optimizes for IoT devices which typically have lower bandwidth and not as powerful processing by (1) to reducing the
amount of data transmitted and (2) simplifying device implementation as it only needs to respond to differences.
While shadows are stateful since delta messages are only based on the desired and reported state and *not on the previous
and current state*. Therefore, it doesn't matter if an IoT device is offline and misses a delta message. When it comes
back online, the delta is identical.
This is also an improvement over a connection without the clean flag and QoS > 0. When an IoT device comes back online, bandwidth
isn't consumed and the IoT device does not have to process a backlog of messages to understand how it should behave.
For a setting -- such as volume -- that goes from 80 then to 91 and then to 60 while the device is offline, it will
only receive a single change that its volume should now be 60.
| Reported Shadow State | Desired Shadow State | Resulting Delta Message (`delta`) |
|----------------------------------------|------------------------------------------|---------------------------------------|
| `{ "temperature": 70 }` | `{ "temperature": 72 }` | `{ "temperature": 72 }` |
| `{ "led": "off", "fan": "low" }` | `{ "led": "on", "fan": "low" }` | `{ "led": "on" }` |
| `{ "door": "closed" }` | `{ "door": "closed", "alarm": "armed" }` | `{ "alarm": "armed" }` |
| `{ "volume": 10 }` | `{ "volume": 10 }` | *(no delta; states match)* |
| `{ "brightness": 100 }` | `{ "brightness": 80, "mode": "eco" }` | `{ "brightness": 80, "mode": "eco" }` |
| `{ "levels": [1, 10, 4]}` | `{"levels": [1, 4, 10]}` | `{"levels": [1, 4, 10]}` |
| `{ "brightness": 100, "mode": "eco" }` | `{ "brightness": 80 }` | `{ "brightness": 80}` |
## Iota messages
Typically, null values never show in any received update message as a null signals the removal of a key from the desired
or reported state. However, if the app removes a key from the desired state -- such as a piece of state that is no longer
needed or applicable -- the device won't receive any notification of this deletion in a delta messages.
These messages are very similar to 'delta' messages as they also contain items in the desired state that are different from
those in the reported state; it *also* contains any items in the reported state that are *missing* from the desired
state (last row in table).
| Reported Shadow State | Desired Shadow State | Resulting Delta Message (`delta`) |
|----------------------------------------|------------------------------------------|-----------------------------------------|
| `{ "temperature": 70 }` | `{ "temperature": 72 }` | `{ "temperature": 72 }` |
| `{ "led": "off", "fan": "low" }` | `{ "led": "on", "fan": "low" }` | `{ "led": "on" }` |
| `{ "door": "closed" }` | `{ "door": "closed", "alarm": "armed" }` | `{ "alarm": "armed" }` |
| `{ "volume": 10 }` | `{ "volume": 10 }` | *(no delta; states match)* |
| `{ "brightness": 100 }` | `{ "brightness": 80, "mode": "eco" }` | `{ "brightness": 80, "mode": "eco" }` |
| `{ "levels": [1, 10, 4]}` | `{"levels": [1, 4, 10]}` | `{"levels": [1, 4, 10]}` |
| `{ "brightness": 100, "mode": "eco" }` | `{ "brightness": 80 }` | `{ "brightness": 80, "mode": null }` |
## Configuration
::: amqtt.contrib.shadows.ShadowPlugin.Config
options:
show_source: false
heading_level: 4
extra:
class_style: "simple"
## Security
Often a device only needs access to get/update and receive changes in its own shadow state. In addition to the `ShadowPlugin`,
included is the `ShadowTopicAuthPlugin`. This allows (authorizes) a device to only subscribe, publish and receive its own topics.
::: amqtt.contrib.shadows.ShadowTopicAuthPlugin.Config
options:
show_source: false
heading_level: 4
extra:
class_style: "simple"
!!! warning
`ShadowTopicAuthPlugin` only handles topic authorization. Another plugin should be used to authenticate client device
connections to the broker. See [file auth](packaged_plugins.md#password-file-auth-plugin),
[http auth](http.md), [db auth](auth_db.md) or [certificate auth](cert.md) plugins. Or create your own:
[auth plugins](custom_plugins.md#authentication-plugins):

Wyświetl plik

@ -1,4 +1,4 @@
#
# Broker
::: mkdocs-typer2
:module: amqtt.scripts.broker_script

Wyświetl plik

@ -21,15 +21,5 @@ The `amqtt.broker` module provides the following key methods in the `Broker` cla
- `start()`: Starts the broker and begins serving
- `shutdown()`: Gracefully shuts down the broker
### Broker configuration
The `Broker` class's `__init__` method accepts a `config` parameter which allows setup of default and custom behaviors.
Details on the `config` parameter structure is a dictionary whose structure is identical to yaml formatted file[^1]
used by the included broker script: [broker configuration](broker_config.md)
::: amqtt.broker.Broker
[^1]: See [PyYAML](http://pyyaml.org/wiki/PyYAMLDocumentation) for loading YAML files as Python dict.

Wyświetl plik

@ -1,42 +1,31 @@
# Broker Configuration
This configuration structure is valid as a python dictionary passed to the `amqtt.broker.Broker` class's `__init__` method or
as a yaml formatted file passed to the `amqtt` script.
### `listeners` *(list[dict[str, Any]])*
Defines the network listeners used by the service. Items defined in the `default` listener will be
applied to all other listeners, unless they are overridden by the configuration for the specific
listener.
- `default` | `<listener_name>`: Named listener
- `type` *(string)*: Transport type. Can be `tcp` or `ws`.
- `bind` *(string)*: IP address and port (e.g., `0.0.0.0:1883`)
- `max-connections` *(integer)*: Maximum number of clients that can connect to this interface
- `ssl` *(string)*: Enable SSL connection. Can be `on` or `off` (default: off).
- `cafile` *(string)*: Path to a file of concatenated CA certificates in PEM format. See [Certificates](https://docs.python.org/3/library/ssl.html#ssl-certificates) for more info.
- `capath` *(string)*: Path to a directory containing several CA certificates in PEM format, following an [OpenSSL specific layout](https://docs.openssl.org/master/man3/SSL_CTX_load_verify_locations/).
- `cadata` *(string)*: Either an ASCII string of one or more PEM-encoded certificates or a bytes-like object of DER-encoded certificates.
- `certfile` *(string)*: Path to a single file in PEM format containing the certificate as well as any number of CA certificates needed to establish the certificate's authenticity.
- `keyfile` *(string): A file containing the private key. Otherwise the private key will be taken from `certfile` as well.
### `timeout-disconnect-delay` *(int)*
Client disconnect timeout without a keep-alive.
### `plugins` *(mapping)*
A list of strings representing the modules and class name of `BasePlugin`, `BaseAuthPlugin` and `BaseTopicPlugins`. Each
entry may have one or more configuration settings. For more information, see the [configuration of the included plugins](../packaged_plugins.md)
??? warning "Deprecated: `sys_interval` "
**`sys_interval`** *(int)*
System status report interval in seconds, used by the `amqtt.plugins.sys.broker.BrokerSysPlugin`.
This configuration structure is a `amqtt.contexts.BrokerConfig` or a python dictionary with the same structure
when instantiating `amqtt.broker.Broker` or as a yaml formatted file passed to the `amqtt` script.
If not specified, the `Broker()` will be started with the default `BrokerConfig()`, as represented in yaml format:
```yaml
---
listeners:
default:
type: tcp
bind: 0.0.0.0:1883
timeout_disconnect_delay: 0
plugins:
amqtt.plugins.logging_amqtt.EventLoggerPlugin:
amqtt.plugins.logging_amqtt.PacketLoggerPlugin:
amqtt.plugins.authentication.AnonymousAuthPlugin:
allow_anonymous: true
amqtt.plugins.sys.broker.BrokerSysPlugin:
sys_interval: 20
```
::: amqtt.contexts.BrokerConfig
options:
heading_level: 3
extra:
class_style: "simple"
??? warning "Deprecated: `auth` configuration settings"
@ -64,6 +53,13 @@ entry may have one or more configuration settings. For more information, see the
- `password-file` *(string)*. Path to sha-512 encoded password file, used by `amqtt.plugins.authentication.FileAuthPlugin`.
??? warning "Deprecated: `sys_interval` "
**`sys_interval`** *(int)*
System status report interval in seconds, used by the `amqtt.plugins.sys.broker.BrokerSysPlugin`.
??? warning "Deprecated: `topic-check` configuration settings"
@ -88,19 +84,17 @@ entry may have one or more configuration settings. For more information, see the
- The username `admin` is allowed access to all topic.
- The username `anonymous` will control allowed topics if using the `auth_anonymous` plugin.
## Default Configuration
```yaml
--8<-- "amqtt/scripts/default_broker.yaml"
```
::: amqtt.contexts.ListenerConfig
options:
heading_level: 3
extra:
class_style: "simple"
## Example
When a configuration is passed to the `amqtt` script, here is the equivalent format based on the structures above:
```yaml
listeners:
default:
@ -149,7 +143,7 @@ This configuration file would create the following listeners:
- `my-ws-1`: an unsecured websocket listener on port 9001 allowing `500` clients connections simultaneously
- `my-wss-1`: a secured websocket listener on port 9003 allowing `500`
And enable the following access controls:
And enable the following topic access:
- `username1` to login and subscribe/publish to topics `repositories/+/master`, `calendar/#` and `data/memes`
- `username2` to login and subscribe/publish to topics `calendar/2025/#` and `data/memes`

Wyświetl plik

@ -129,15 +129,4 @@ amqtt/LYRf52W[56SOjW04 <-in-- PubcompPacket(ts=2015-11-11 21:54:48.713107, fixed
Both coroutines have the same results except that `test_coro2()` manages messages flow in parallel which may be more efficient.
### Client configuration
The `MQTTClient` class's `__init__` method accepts a `config` parameter which allows setup of default and custom behaviors.
Details on the `config` parameter structure is a dictionary whose structure is identical to yaml formatted file[^1]
used by the included broker script: [client configuration](client_config.md)
::: amqtt.client.MQTTClient
[^1]: See [PyYAML](http://pyyaml.org/wiki/PyYAMLDocumentation) for loading YAML files as Python dict.

Wyświetl plik

@ -1,97 +1,57 @@
# Client Configuration
This configuration structure is valid as a python dictionary passed to the `amqtt.broker.MQTTClient` class's `__init__` method or
as a yaml formatted file passed to the `amqtt_pub` script.
This configuration structure is either a `amqtt.contexts.ClientConfig` or a python dictionary with identical structure
when instantiating `amqtt.broker.MQTTClient` or as a yaml formatted file passed to the `amqtt_pub` script.
### `keep_alive` *(int)*
Keep-alive timeout sent to the broker. Defaults to `10` seconds.
### `ping_delay` *(int)*
Auto-ping delay before keep-alive timeout. Defaults to 1. Setting to `0` will disable to 0 and may lead to broker disconnection.
### `default_qos` *(int: 0-2)*
Default QoS for messages published. Defaults to 0.
### `default_retain` *(bool)*
Default retain value to messages published. Defaults to `false`.
### `auto_reconnect` *(bool)*
Enable or disable auto-reconnect if connection with the broker is interrupted. Defaults to `false`.
### `connect_timeout` *(int)*
If specified, the number of seconds before a connection times out
### `reconnect_retries` *(int)*
Maximum reconnection retries. Defaults to `2`. Negative value will cause client to reconnect infinitely.
### `reconnect_max_interval` *(int)*
Maximum interval between 2 connection retry. Defaults to `10`.
### `cleansession` *(bool)*
Upon reconnect, should subscriptions be cleared. Defaults to `true`.
### `topics` *(list[mapping])*
Specify the topics and what flags should be set for messages published to them.
- `<topic>`: Named listener
- `qos` *(int, 0-3)*:
- `retain` *(bool)*:
### `will` *(mapping)*
If included, the message that should be sent if the client disconnects.
- `topic` *(string)*:
- `message` *(string)*:
- `qos` *(int): 0, 1 or 2
- `retain`: *(bool)* new clients subscribing to `topic` will receive this message
### `broker` *(mapping)*
- `uri` *(string)*: Broker connection URL, *must conform to MQTT or URI scheme: `[mqtt(s)|ws(s)]://<username:password>@HOST:port`*
TLS certificates used to verify the broker's authenticity.
- `cafile` *(string)*: Path to a file of concatenated CA certificates in PEM format. See [Certificates](https://docs.python.org/3/library/ssl.html#ssl-certificates) for more info.
- `capath` *(string)*: Path to a directory containing several CA certificates in PEM format, following an [OpenSSL specific layout](https://docs.openssl.org/master/man3/SSL_CTX_load_verify_locations/).
- `cadata` *(string)*: Either an ASCII string of one or more PEM-encoded certificates or a bytes-like object of DER-encoded certificates.
### `certfile` *(string)*
Path to a single file in PEM format containing the certificate as well as any number of CA certificates needed to establish the server certificate's authenticity.
### `check_hostname` *(bool)*
Bypass ssl host certificate verification, allowing self-signed certificates
### `plugins` *(mapping)*
A list of strings representing the modules and class name of any `BasePlugin`s. Each entry may have one or more
configuration settings. For more information, see the [configuration of the included plugins](../packaged_plugins.md)
## Default Configuration
If not specified, the `MQTTClient()` will be started with the default `ClientConfig()`, as represented in yaml format:
```yaml
--8<-- "amqtt/scripts/default_client.yaml"
---
keep_alive: 10
ping_delay: 1
default_qos: 0
default_retain: false
auto_reconnect: true
connection_timeout: 60
reconnect_retries: 2
reconnect_max_interval: 10
cleansession: true
broker:
uri: "mqtt://127.0.0.1"
plugins:
amqtt.plugins.logging_amqtt.PacketLoggerPlugin:
```
::: amqtt.contexts.ClientConfig
options:
heading_level: 3
extra:
class_style: "simple"
::: amqtt.contexts.TopicConfig
options:
heading_level: 3
extra:
class_style: "simple"
::: amqtt.contexts.WillConfig
options:
heading_level: 3
extra:
class_style: "simple"
::: amqtt.contexts.ConnectionConfig
options:
heading_level: 3
extra:
class_style: "simple"
## Example
A more expansive `ClientConfig` in equivalent yaml format:
```yaml
keep_alive: 10
@ -102,9 +62,9 @@ auto_reconnect: true
reconnect_max_interval: 5
reconnect_retries: 10
topics:
test:
topic/subtopic:
qos: 0
some_topic:
topic/other:
qos: 2
retain: true
will:
@ -113,9 +73,8 @@ will:
qos: 1
retain: false
broker:
uri: mqtt://localhost:1883
cafile: /path/to/ca/file
uri: 'mqtt://localhost:1883'
cafile: '/path/to/ca/file'
plugins:
- amqtt.plugins.logging_amqtt.PacketLoggerPlugin:
```

Wyświetl plik

@ -2,20 +2,26 @@
This document describes `aMQTT` common API both used by [MQTT Client](client.md) and [Broker](broker.md).
## Reference
## ApplicationMessage
### ApplicationMessage
::: amqtt.session.ApplicationMessage
options:
heading_level: 3
The `amqtt.session` module provides the following message classes:
## IncomingApplicationMessage
#### ApplicationMessage
Represents messages received from MQTT clients.
Base class for MQTT application messages.
::: amqtt.session.IncomingApplicationMessage
options:
heading_level: 3
#### IncomingApplicationMessage
Inherits from ApplicationMessage. Represents messages received from MQTT clients.
#### OutgoingApplicationMessage
## OutgoingApplicationMessage
Inherits from ApplicationMessage. Represents messages to be sent to MQTT clients.
::: amqtt.session.OutgoingApplicationMessage
options:
heading_level: 3

Wyświetl plik

@ -0,0 +1,140 @@
import ast
import pprint
from typing import Any
import griffe
from griffe import Inspector, ObjectNode, Visitor, Attribute
from amqtt.contexts import default_listeners, default_broker_plugins, default_client_plugins
from amqtt.contrib.auth_db.plugin import default_hash_scheme
default_factory_map = {
'default_listeners': default_listeners(),
'default_broker_plugins': default_broker_plugins(),
'default_client_plugins': default_client_plugins(),
'default_hash_scheme': default_hash_scheme()
}
def get_qualified_name(node: ast.AST) -> str | None:
"""Recursively build the qualified name from an AST node."""
if isinstance(node, ast.Name):
return node.id
elif isinstance(node, ast.Attribute):
parent = get_qualified_name(node.value)
if parent:
return f"{parent}.{node.attr}"
return node.attr
elif isinstance(node, ast.Call):
# e.g., uuid.uuid4()
return get_qualified_name(node.func)
return None
def get_fully_qualified_name(call_node):
"""
Extracts the fully qualified name from an ast.Call node.
"""
if isinstance(call_node.func, ast.Name):
# Direct function call (e.g., "my_function(arg)")
return call_node.func.id
elif isinstance(call_node.func, ast.Attribute):
# Method call or qualified name (e.g., "obj.method(arg)" or "module.submodule.function(arg)")
parts = []
current = call_node.func
while isinstance(current, ast.Attribute):
parts.append(current.attr)
current = current.value
if isinstance(current, ast.Name):
parts.append(current.id)
return ".".join(reversed(parts))
else:
# Handle other potential cases (e.g., ast.Subscript) if necessary
return None
def get_callable_name(node):
if isinstance(node, ast.Name):
return node.id
elif isinstance(node, ast.Attribute):
return f"{get_callable_name(node.value)}.{node.attr}"
return None
def evaluate_callable_node(node):
try:
# Wrap the node in an Expression so it can be compiled
expr = ast.Expression(body=node)
compiled = compile(expr, filename="<ast>", mode="eval")
return eval(compiled, {"__builtins__": __builtins__, "list": list, "dict": dict})
except Exception as e:
return f"<unresolvable: {e}>"
class DataclassDefaultFactoryExtension(griffe.Extension):
"""Renders the output of a dataclasses field which uses a default factory.
def other_field_defaults():
return {'item1': 'value1', 'item2': 'value2'}
@dataclass
class MyDataClass:
my_field: dict[str, Any] = field(default_factory=dict)
my_other_field: dict[str, Any] = field(default_factory=other_field_defaults)
instead of documentation rendering this as:
```
class MyDataClass:
my_field: dict[str, Any] = dict()
my_other_field: dict[str, Any] = other_field_defaults()
```
it will be displayed with the output of factory functions for more clarity:
```
class MyDataClass:
my_field: dict[str, Any] = {}
my_other_field: dict[str, Any] = {'item1': 'value1', 'item2': 'value2'}
```
_note_ : for any custom default factory function, it must be added to the `default_factory_map`
in this file as `griffe` doesn't provide a straightforward mechanism with its AST to dynamically
import/call the function.
"""
def on_attribute_instance(
self,
*,
node: ast.AST | ObjectNode,
attr: Attribute,
agent: Visitor | Inspector,
**kwargs: Any,
) -> None:
"""Called for every `node` and/or `attr` on a file's AST."""
if not hasattr(node, "value"):
return
if isinstance(node.value, ast.Call):
# Search for all of the `default_factory` fields.
default_factory_value: str | None = None
for kw in node.value.keywords:
if kw.arg == "default_factory":
# based on the node type, return the proper function name
match get_callable_name(kw.value):
# `dict` and `list` are common default factory functions
case 'dict':
default_factory_value = "{}"
case 'list':
default_factory_value = "[]"
case _:
# otherwise, see the nodes is in our map for the custom default factory function
callable_name = get_callable_name(kw.value)
if callable_name in default_factory_map:
default_factory_value = pprint.pformat(default_factory_map[callable_name], indent=4, width=80, sort_dicts=False)
else:
# if not, display as the default
default_factory_value = f"{callable_name}()"
# store the information in the griffe attribute, which is what is passed to the template for rendering
if "dataclass_ext" not in attr.extra:
attr.extra["dataclass_ext"] = {}
attr.extra["dataclass_ext"]["has_default_factory"] = False
if default_factory_value is not None:
attr.extra["dataclass_ext"]["has_default_factory"] = True
attr.extra["dataclass_ext"]["default_factory"] = default_factory_value

1
docs/templates/README vendored 100644
Wyświetl plik

@ -0,0 +1 @@
template overrides for mkdocs-materials

Wyświetl plik

@ -0,0 +1,13 @@
{% extends "_base/class.html.jinja" %}
{% block signature scoped %}
{% if config.extra.class_style != 'simple' %}
{{ super() }}
{% endif %}
{% endblock signature %}
{% block bases scoped %}
{% if config.extra.class_style != 'simple' %}
{{ super() }}
{% endif %}
{% endblock bases %}

Wyświetl plik

@ -0,0 +1,7 @@
{% extends "_base/backlinks.html.jinja" %}
{% block logs scoped %}
<p style="color:red">backlinks.html.jinja</p>
{% endblock logs %}

Wyświetl plik

@ -0,0 +1,11 @@
{% extends "_base/class.html.jinja" %}
{% block logs scoped %}
{% if config.extra.template_log_display %}<p style="color:red">class.html.jinja</p>{% endif %}
{% endblock logs %}
{% block signature scoped %}
{% if config.extra.class_style == 'simple' %}{% else %}
{{ super() }}
{% endif %}
{% endblock signature %}

Wyświetl plik

@ -0,0 +1,21 @@
{% extends "_base/docstring/attributes.html.jinja" %}
{% block logs scoped %}
{% if config.extra.template_log_display %}<p style="color:red">docstring/attributes.html.jinja</p>{% endif %}
{% endblock logs %}
{% block table_style scoped %}
{% if config.extra.class_style == 'simple' %}{% else %}
{{ super() }}
{% endif %}
{% endblock table_style %}
{% block list_style scoped %}
{% if config.extra.class_style == 'simple' %}{% else %}
{{ super() }}
{% endif %}
{% endblock list_style %}
{% block spacy_style scoped %}
{% if config.extra.class_style == 'simple' %}{% else %}
{{ super() }}
{% endif %}
{% endblock spacy_style %}

Wyświetl plik

@ -0,0 +1,22 @@
{% extends "_base/docstring/functions.html.jinja" %}
{% block logs scoped %}
{% if config.extra.template_log_display %}<p style="color:red">docstring/functions.html.jinja</p>{% endif %}
{% endblock logs %}
{% block table_style scoped %}
{% if config.extra.class_style == 'simple' %}{% else %}
{{ super() }}
{% endif %}
{% endblock %}
{% block list_style scoped %}
{% if config.extra.class_style == 'simple' %}{% else %}
{{ super() }}
{% endif %}
{% endblock %}
{% block spacy_style scoped %}
{% if config.extra.class_style == 'simple' %}{% else %}
{{ super() }}
{% endif %}
{% endblock %}

Wyświetl plik

@ -0,0 +1,5 @@
{% if 'default_factory' in expression.__str__() %}
{{ obj.extra.dataclass_ext.default_factory | safe }}
{% else %}
{% extends "_base/expression.html.jinja" %}
{% endif %}

Wyświetl plik

@ -0,0 +1,6 @@
{% if config.extra.class_style == 'simple' %}{% else %}{% extends "_base/function.html.jinja" %}{% endif %}
{% block logs scoped %}
{% if config.extra.template_log_display %}<p style="color:red">function.html.jinja</p>{% endif %}
{% endblock logs %}

Wyświetl plik

@ -1,7 +1,7 @@
{
"name": "amqttio",
"private": true,
"version": "0.11.2",
"version": "0.11.3",
"type": "module",
"scripts": {
"dev": "vite",

Wyświetl plik

@ -15,4 +15,43 @@ export type TopicEntry<T> = {
// Define the topic_map type
export type TopicMap = {
[topic: string]: TopicEntry<DataPoint[]>;
};
};
// no need for a full uuid, generate a 6-character alphanumeric sequence, in two parts
export function getClientID() {
const genPart = () => {
const rand = (Math.random() * 46656) | 0
// convert random number into an ascii sequence of letters, trimmed to 3 characters
return ("000" + rand.toString(36)).slice(-3)
}
return `web-client-${genPart() + genPart()}`
}
export function secondsToDhms(seconds: number) {
const days = Math.floor(seconds / (24 * 3600));
seconds %= (24 * 3600);
const hours = Math.floor(seconds / 3600);
seconds %= 3600;
const minutes = Math.floor(seconds / 60);
seconds = seconds % 60;
return {
days: days,
hours: hours,
minutes: minutes,
seconds: seconds,
};
}
export function getMQTTSettings() {
return {
url: import.meta.env.VITE_MQTT_WS_TYPE + '://' + import.meta.env.VITE_MQTT_WS_HOST + ':' + import.meta.env.VITE_MQTT_WS_PORT,
client_id: getClientID(),
clean: true,
protocol: 'wss',
protocolVersion: 4, // MQTT 3.1.1
wsOptions: {
protocol: 'mqtt'
}
}
}

Wyświetl plik

@ -7,17 +7,10 @@ import MainGrid from './components/MainGrid';
import AppTheme from '../shared-theme/AppTheme';
import AmqttLogo from './amqtt_bw.svg';
import {
chartsCustomizations,
treeViewCustomizations,
} from './theme/customizations';
import AppBar from "@mui/material/AppBar";
import {Toolbar} from "@mui/material";
const xThemeComponents = {
...chartsCustomizations,
...treeViewCustomizations,
};
const xThemeComponents = {};
export default function Dashboard(props: { disableCustomTheme?: boolean }) {
return (

Wyświetl plik

@ -1,3 +0,0 @@
<Typography component="h2" variant="h6" color="primary" gutterBottom>
{props.children}
</Typography>

Wyświetl plik

@ -1,105 +0,0 @@
import * as React from 'react';
import { styled } from '@mui/material/styles';
import AppBar from '@mui/material/AppBar';
import Box from '@mui/material/Box';
import Stack from '@mui/material/Stack';
import MuiToolbar from '@mui/material/Toolbar';
import { tabsClasses } from '@mui/material/Tabs';
import Typography from '@mui/material/Typography';
import MenuRoundedIcon from '@mui/icons-material/MenuRounded';
import DashboardRoundedIcon from '@mui/icons-material/DashboardRounded';
import SideMenuMobile from './SideMenuMobile';
import MenuButton from './MenuButton';
import ColorModeIconDropdown from '../../shared-theme/ColorModeIconDropdown';
const Toolbar = styled(MuiToolbar)({
width: '100%',
padding: '12px',
display: 'flex',
flexDirection: 'column',
alignItems: 'start',
justifyContent: 'center',
gap: '12px',
flexShrink: 0,
[`& ${tabsClasses.flexContainer}`]: {
gap: '8px',
p: '8px',
pb: 0,
},
});
export default function AppNavbar() {
const [open, setOpen] = React.useState(false);
const toggleDrawer = (newOpen: boolean) => () => {
setOpen(newOpen);
};
return (
<AppBar
position="fixed"
sx={{
display: { xs: 'auto', md: 'none' },
boxShadow: 0,
bgcolor: 'background.paper',
backgroundImage: 'none',
borderBottom: '1px solid',
borderColor: 'divider',
top: 'var(--template-frame-height, 0px)',
}}
>
<Toolbar variant="regular">
<Stack
direction="row"
sx={{
alignItems: 'center',
flexGrow: 1,
width: '100%',
gap: 1,
}}
>
<Stack
direction="row"
spacing={1}
sx={{ justifyContent: 'center', mr: 'auto' }}
>
<CustomIcon />
<Typography variant="h4" component="h1" sx={{ color: 'text.primary' }}>
Dashboard
</Typography>
</Stack>
<ColorModeIconDropdown />
<MenuButton aria-label="menu" onClick={toggleDrawer(true)}>
<MenuRoundedIcon />
</MenuButton>
<SideMenuMobile open={open} toggleDrawer={toggleDrawer} />
</Stack>
</Toolbar>
</AppBar>
);
}
export function CustomIcon() {
return (
<Box
sx={{
width: '1.5rem',
height: '1.5rem',
bgcolor: 'black',
borderRadius: '999px',
display: 'flex',
justifyContent: 'center',
alignItems: 'center',
alignSelf: 'center',
backgroundImage:
'linear-gradient(135deg, hsl(210, 98%, 60%) 0%, hsl(210, 100%, 35%) 100%)',
color: 'hsla(210, 100%, 95%, 0.9)',
border: '1px solid',
borderColor: 'hsl(210, 100%, 55%)',
boxShadow: 'inset 0 2px 5px rgba(255, 255, 255, 0.3)',
}}
>
<DashboardRoundedIcon color="inherit" sx={{ fontSize: '1rem' }} />
</Box>
);
}

Wyświetl plik

@ -1,24 +0,0 @@
import Card from '@mui/material/Card';
import CardContent from '@mui/material/CardContent';
import Button from '@mui/material/Button';
import Typography from '@mui/material/Typography';
import AutoAwesomeRoundedIcon from '@mui/icons-material/AutoAwesomeRounded';
export default function CardAlert() {
return (
<Card variant="outlined" sx={{ m: 1.5, flexShrink: 0 }}>
<CardContent>
<AutoAwesomeRoundedIcon fontSize="small" />
<Typography gutterBottom sx={{ fontWeight: 600 }}>
Plan about to expire
</Typography>
<Typography variant="body2" sx={{ mb: 2, color: 'text.secondary' }}>
Enjoy 10% off when renewing your plan today.
</Typography>
<Button variant="contained" size="small" fullWidth>
Get the discount
</Button>
</CardContent>
</Card>
);
}

Wyświetl plik

@ -0,0 +1,49 @@
import CountUp from "react-countup";
const byte_units = [
'Bytes',
'KB',
'MB',
'GB',
'TB'
]
const update_time = 5;
export function ByteCounter(props: any) {
let start = props.start;
let end = props.end;
if(end - start < 200){
return <CountUp
start={start}
end={end}
duration={update_time}/>
}
let unit = byte_units[0];
for (let i = 0; i < byte_units.length; i++) {
if( start > 1_000) {
start = start / 1000;
end = end / 1000;
unit = byte_units[i+1];
}
}
return <CountUp
start={start}
end={end}
suffix={" " + unit}
decimals={2}
duration={update_time}/>
}
export function StandardCounter(props: any) {
return <CountUp
start={props.start}
end={props.end}
duration={update_time}/>
}

Wyświetl plik

@ -1,47 +0,0 @@
import { DataGrid } from '@mui/x-data-grid';
import { columns, rows } from '../internals/data/gridData';
export default function CustomizedDataGrid() {
return (
<DataGrid
checkboxSelection
rows={rows}
columns={columns}
getRowClassName={(params) =>
params.indexRelativeToCurrentPage % 2 === 0 ? 'even' : 'odd'
}
initialState={{
pagination: { paginationModel: { pageSize: 20 } },
}}
pageSizeOptions={[10, 20, 50]}
disableColumnResize
density="compact"
slotProps={{
filterPanel: {
filterFormProps: {
logicOperatorInputProps: {
variant: 'outlined',
size: 'small',
},
columnInputProps: {
variant: 'outlined',
size: 'small',
sx: { mt: 'auto' },
},
operatorInputProps: {
variant: 'outlined',
size: 'small',
sx: { mt: 'auto' },
},
valueInputProps: {
InputComponentProps: {
variant: 'outlined',
size: 'small',
},
},
},
},
}}
/>
);
}

Wyświetl plik

@ -4,9 +4,10 @@
import Typography from '@mui/material/Typography';
import Stack from '@mui/material/Stack';
import { LineChart } from '@mui/x-charts/LineChart';
import CountUp from 'react-countup';
import type { DataPoint } from '../../assets/helpers.jsx';
import {CircularProgress} from "@mui/material";
import {StandardCounter, ByteCounter} from "./Counter.tsx";
import {useRef} from "react";
const currentTimeZone = Intl.DateTimeFormat().resolvedOptions().timeZone;
@ -116,7 +117,7 @@
</LineChart>
}
export default function SessionsChart(props: any) {
export default function DashboardChart(props: any) {
const lastCalc = useRef<number>(0);
@ -153,13 +154,9 @@
<Typography variant="h4" component="p">
{ props.data.length < 2 ? "" :
<CountUp
start={props.data[props.data.length - 2].value}
end={props.data[props.data.length - 1].value}
duration={5}
decimals={props.decimals}
/>} {props.label}
props.isBytes ? <ByteCounter start={props.data[props.data.length - 2].value} end={props.data[props.data.length - 1].value}/> :
<StandardCounter start={props.data[props.data.length - 2].value} end={props.data[props.data.length - 1].value} />
} {props.label}
</Typography>
<p>
{ calc_per_second(props.data[props.data.length-1], props.data[props.data.length-2]) }

Wyświetl plik

@ -0,0 +1,105 @@
import Grid from "@mui/material/Grid";
import Typography from "@mui/material/Typography";
import {FontAwesomeIcon} from "@fortawesome/react-fontawesome";
import {faDiscord, faDocker, faGithub, faPython} from "@fortawesome/free-brands-svg-icons";
import rtdIcon from "../../assets/readthedocs.svg";
import {Paper, Table, TableBody, TableCell, TableContainer, TableHead, TableRow} from "@mui/material";
export default function DescriptionPanel() {
return <>
<Grid size={{xs: 10, md: 5}}>
<Typography component="h2" variant="h6" sx={{mb: 2}}>
Overview
</Typography>
<div>
<p style={{textAlign: 'left'}}>This is <b>test.amqtt.io</b>.</p>
<p style={{textAlign: 'left'}}>It hosts a publicly available aMQTT server/broker.</p>
<p style={{textAlign: 'left'}}><a href="http://www.mqtt.org">MQTT</a> is a very lightweight
protocol that uses a publish/subscribe model. This makes it suitable for "machine to machine"
messaging such as with low power sensors or mobile devices.
</p>
<p style={{textAlign: 'left'}}>For more information: </p>
<table>
<tbody>
<tr>
<td style={{width: 250}}>
<p style={{textAlign: 'left'}}>
<FontAwesomeIcon icon={faGithub} size="xl"/> github: <a
href="https://github.com/Yakifo/amqtt">Yakifo/amqtt</a>
</p>
<p style={{textAlign: 'left'}}>
<FontAwesomeIcon icon={faPython} size="xl"/> PyPi: <a
href="https://pypi.org/project/amqtt/">aMQTT</a>
</p>
<p style={{textAlign: 'left'}}>
<FontAwesomeIcon icon={faDiscord} size="xl"/> Discord: <a
href="https://discord.gg/S3sP6dDaF3">aMQTT</a>
</p>
</td>
<td>
<p style={{textAlign: 'left'}}>
<img
src={rtdIcon}
style={{width: 20, verticalAlign: -4}}
alt="website logo"
/>
ReadTheDocs: <a href="https://amqtt.readthedocs.io/">aMQTT</a>
</p>
<p style={{textAlign: 'left'}}>
<FontAwesomeIcon icon={faDocker} size="xl"/> DockerHub: <a
href="https://hub.docker.com/repositories/amqtt">aMQTT</a>
</p>
<p>&nbsp;</p>
</td>
</tr>
</tbody>
</table>
</div>
</Grid>
<Grid size={{xs: 1, md: 1}}></Grid>
<Grid size={{xs: 12, md: 6}}>
<Typography component="h2" variant="h6" sx={{mb: 2}}>
Access
</Typography>
<TableContainer component={Paper}>
<Table sx={{maxWidth: 400}} size="small">
<TableHead>
<TableRow>
<TableCell>Host</TableCell>
<TableCell>test.amqtt.io</TableCell>
</TableRow>
</TableHead>
<TableBody>
<TableRow>
<TableCell>TCP</TableCell>
<TableCell>1883</TableCell>
</TableRow>
<TableRow>
<TableCell>TLS TCP</TableCell>
<TableCell>8883</TableCell>
</TableRow>
<TableRow>
<TableCell>Websocket</TableCell>
<TableCell>8080</TableCell>
</TableRow>
<TableRow>
<TableCell>SSL Websocket</TableCell>
<TableCell>8443</TableCell>
</TableRow>
</TableBody>
</Table>
</TableContainer>
<p style={{textAlign: 'left'}}>
The purpose of this free MQTT broker at <strong>test.amqtt.io</strong> is to learn about and test the MQTT
protocol. It
should not be used in production, development, staging or uat environments. Do not to use it to send any
sensitive information or personal data into the system as all topics are public. Any illegal use of this
MQTT broker is strictly forbidden. By using this MQTT broker located at <strong>test.amqtt.io</strong> you
warrant that you are neither a sanctioned person nor located in a country that is subject to sanctions.
</p>
</Grid>
</>
}

Wyświetl plik

@ -1,33 +0,0 @@
import Stack from '@mui/material/Stack';
import NotificationsRoundedIcon from '@mui/icons-material/NotificationsRounded';
import NavbarBreadcrumbs from './NavbarBreadcrumbs';
import MenuButton from './MenuButton';
import ColorModeIconDropdown from '../../shared-theme/ColorModeIconDropdown';
import Search from './Search';
export default function Header() {
return (
<Stack
direction="row"
sx={{
display: { xs: 'none', md: 'flex' },
width: '100%',
alignItems: { xs: 'flex-start', md: 'center' },
justifyContent: 'space-between',
maxWidth: { sm: '100%', md: '1700px' },
pt: 1.5,
}}
spacing={2}
>
<NavbarBreadcrumbs />
<Stack direction="row" sx={{ gap: 1 }}>
<Search />
<MenuButton showBadge aria-label="Open notifications">
<NotificationsRoundedIcon />
</MenuButton>
<ColorModeIconDropdown />
</Stack>
</Stack>
);
}

Wyświetl plik

@ -1,41 +0,0 @@
import Card from '@mui/material/Card';
import CardContent from '@mui/material/CardContent';
import Button from '@mui/material/Button';
import Typography from '@mui/material/Typography';
import ChevronRightRoundedIcon from '@mui/icons-material/ChevronRightRounded';
import InsightsRoundedIcon from '@mui/icons-material/InsightsRounded';
import useMediaQuery from '@mui/material/useMediaQuery';
import { useTheme } from '@mui/material/styles';
export default function HighlightedCard() {
const theme = useTheme();
const isSmallScreen = useMediaQuery(theme.breakpoints.down('sm'));
return (
<Card sx={{ height: '100%' }}>
<CardContent>
<InsightsRoundedIcon />
<Typography
component="h2"
variant="subtitle2"
gutterBottom
sx={{ fontWeight: '600' }}
>
Explore your data
</Typography>
<Typography sx={{ color: 'text.secondary', mb: '8px' }}>
Uncover performance and visitor insights with our data wizardry.
</Typography>
<Button
variant="contained"
size="small"
color="primary"
endIcon={<ChevronRightRoundedIcon />}
fullWidth={isSmallScreen}
>
Get insights
</Button>
</CardContent>
</Card>
);
}

Wyświetl plik

@ -1,18 +1,13 @@
import Grid from '@mui/material/Grid';
import Box from '@mui/material/Box';
import Stack from '@mui/material/Stack';
import Typography from '@mui/material/Typography';
import Copyright from '../internals/components/Copyright';
import SessionsChart from './SessionsChart';
import Copyright from '../Copyright';
import DashboardChart from './DashboardChart.tsx';
import {useEffect, useState} from "react";
// @ts-ignore
import useMqtt from '../../assets/usemqtt';
import type {DataPoint, TopicMap} from '../../assets/helpers';
import {Paper, Table, TableBody, TableCell, TableContainer, TableHead, TableRow} from "@mui/material";
import {FontAwesomeIcon} from '@fortawesome/react-fontawesome'
import {faGithub, faPython, faDocker, faDiscord} from "@fortawesome/free-brands-svg-icons";
import rtdIcon from "../../assets/readthedocs.svg";
import {type DataPoint, getMQTTSettings, secondsToDhms, type TopicMap} from '../../assets/helpers';
import DescriptionPanel from "./DescriptionPanel";
export default function MainGrid() {
@ -27,54 +22,9 @@ export default function MainGrid() {
const [memSize, setMemSize] = useState<DataPoint[]>([]);
const [version, setVersion] = useState<string>('');
function getRandomInt(min: number, max: number) {
min = Math.ceil(min);
max = Math.floor(max);
return Math.floor(Math.random() * (max - min + 1)) + min;
}
const {mqttSubscribe, isConnected, messageQueue, messageTick} = useMqtt(getMQTTSettings());
function secondsToDhms(seconds: number) {
const days = Math.floor(seconds / (24 * 3600));
seconds %= (24 * 3600);
const hours = Math.floor(seconds / 3600);
seconds %= 3600;
const minutes = Math.floor(seconds / 60);
seconds = seconds % 60;
return {
days: days,
hours: hours,
minutes: minutes,
seconds: seconds,
};
}
const mqtt_settings = {
url: import.meta.env.VITE_MQTT_WS_TYPE + '://' + import.meta.env.VITE_MQTT_WS_HOST + ':' + import.meta.env.VITE_MQTT_WS_PORT, client_id: `web-client-${getRandomInt(1, 100)}`,
clean: true,
protocol: 'wss',
protocolVersion: 4, // MQTT 3.1.1
wsOptions: {
protocol: 'mqtt'
}
};
const {mqttSubscribe, isConnected, messageQueue, messageTick} = useMqtt(mqtt_settings);
useEffect(() => {
if (isConnected) {
mqttSubscribe('$SYS/broker/version');
mqttSubscribe('$SYS/broker/messages/publish/#');
mqttSubscribe('$SYS/broker/load/bytes/#');
mqttSubscribe('$SYS/broker/uptime/formatted');
mqttSubscribe('$SYS/broker/uptime');
mqttSubscribe('$SYS/broker/clients/connected');
mqttSubscribe('$SYS/broker/cpu/percent');
mqttSubscribe('$SYS/broker/heap/size')
}
}, [isConnected, mqttSubscribe]);
const topic_map: TopicMap = {
const topicMap: TopicMap = {
'$SYS/broker/messages/publish/sent': {current: sent, update: setSent},
'$SYS/broker/messages/publish/received': {current: received, update: setReceived},
'$SYS/broker/load/bytes/received': {current: bytesIn, update: setBytesIn},
@ -84,6 +34,17 @@ export default function MainGrid() {
'$SYS/broker/heap/size': {current: memSize, update: setMemSize},
};
useEffect(() => {
if (isConnected) {
for(const topic in topicMap) {
mqttSubscribe(topic);
}
mqttSubscribe('$SYS/broker/version');
mqttSubscribe('$SYS/broker/uptime/formatted');
mqttSubscribe('$SYS/broker/uptime');
}
}, [isConnected, mqttSubscribe]);
useEffect(() => {
while (messageQueue.current.length > 0) {
@ -92,8 +53,8 @@ export default function MainGrid() {
const d = payload.message;
if(payload.topic in topic_map) {
const { update } = topic_map[payload.topic];
if(payload.topic in topicMap) {
const { update } = topicMap[payload.topic];
const newPoint: DataPoint = {
time: new Date().toISOString(),
timestamp: Date.now(),
@ -115,6 +76,13 @@ export default function MainGrid() {
}
}, [messageTick, messageQueue]);
const upTime = () => {
return isConnected && serverUptime ? <>
<strong>aMQTT broker</strong> {version.replace('aMQTT version ', 'v')} <strong>started at </strong> {serverStart} &nbsp;
<strong>up for</strong> {serverUptime}
</> : <></>;
}
return (
<Box sx={{width: '100%', maxWidth: {sm: '100%', md: '1700px'}}}>
{/* cards */}
@ -125,131 +93,39 @@ export default function MainGrid() {
columns={12}
sx={{mb: (theme) => theme.spacing(2)}}
>
<Grid size={{xs: 10, md: 5}}>
<Typography component="h2" variant="h6" sx={{mb: 2}}>
Overview
</Typography>
<div>
<p style={{textAlign: 'left'}}>This is <b>test.amqtt.io</b>.</p>
<p style={{textAlign: 'left'}}>It hosts a publicly available aMQTT server/broker.</p>
<p style={{textAlign: 'left'}}><a href="http://www.mqtt.org">MQTT</a> is a very lightweight
protocol that uses a publish/subscribe model. This makes it suitable for "machine to machine"
messaging such as with low power sensors or mobile devices.
</p>
<p style={{textAlign: 'left'}}>For more information: </p>
<table>
<tbody>
<tr>
<td style={{width: 250}}>
<p style={{textAlign: 'left'}}>
<FontAwesomeIcon icon={faGithub} size="xl"/> github: <a
href="https://github.com/Yakifo/amqtt">Yakifo/amqtt</a>
</p>
<p style={{textAlign: 'left'}}>
<FontAwesomeIcon icon={faPython} size="xl"/> PyPi: <a
href="https://pypi.org/project/amqtt/">aMQTT</a>
</p>
<p style={{textAlign: 'left'}}>
<FontAwesomeIcon icon={faDiscord} size="xl"/> Discord: <a
href="https://discord.gg/S3sP6dDaF3">aMQTT</a>
</p>
</td>
<td>
<p style={{textAlign: 'left'}}>
<img
src={rtdIcon}
style={{width: 20, verticalAlign: -4}}
alt="website logo"
/>
ReadTheDocs: <a href="https://amqtt.readthedocs.io/">aMQTT</a>
</p>
<p style={{textAlign: 'left'}}>
<FontAwesomeIcon icon={faDocker} size="xl"/> DockerHub: <a
href="https://hub.docker.com/repositories/amqtt">aMQTT</a>
</p>
<p>&nbsp;</p>
</td>
</tr>
</tbody>
</table>
</div>
</Grid>
<Grid size={{xs: 1, md: 1}}></Grid>
<Grid size={{xs: 12, md: 6}}>
<Typography component="h2" variant="h6" sx={{mb: 2}}>
Access
</Typography>
<TableContainer component={Paper}>
<Table sx={{maxWidth: 400}} size="small">
<TableHead>
<TableRow>
<TableCell>Host</TableCell>
<TableCell>test.amqtt.io</TableCell>
</TableRow>
</TableHead>
<TableBody>
<TableRow>
<TableCell>TCP</TableCell>
<TableCell>1883</TableCell>
</TableRow>
<TableRow>
<TableCell>TLS TCP</TableCell>
<TableCell>8883</TableCell>
</TableRow>
<TableRow>
<TableCell>Websocket</TableCell>
<TableCell>8080</TableCell>
</TableRow>
<TableRow>
<TableCell>SSL Websocket</TableCell>
<TableCell>8443</TableCell>
</TableRow>
</TableBody>
</Table>
</TableContainer>
<p style={{textAlign: 'left'}}>
The purpose of this free MQTT broker at <strong>test.amqtt.io</strong> is to learn about and test the MQTT
protocol. It
should not be used in production, development, staging or uat environments. Do not to use it to send any
sensitive information or personal data into the system as all topics are public. Any illegal use of this
MQTT broker is strictly forbidden. By using this MQTT broker located at <strong>test.amqtt.io</strong> you
warrant that you are neither a sanctioned person nor located in a country that is subject to sanctions.
</p>
</Grid>
<DescriptionPanel/>
</Grid>
<Grid
container
spacing={2}
columns={12}
sx={{mb: (theme) => theme.spacing(2)}}
><Grid size={{xs: 12, md: 12}}>
<strong>broker</strong> ('{version}') <strong>started at </strong> {serverStart} &nbsp;&nbsp;&nbsp;
<strong>up for</strong> {serverUptime}
>
<Grid size={{xs: 12, md: 12}}>
{upTime()}
</Grid>
<Grid size={{xs: 12, md: 6}}>
<SessionsChart title={'Sent Messages'} label={''} data={sent} isConnected={isConnected} isPerSecond/>
<DashboardChart title={'Sent Messages'} label={''} data={sent} isConnected={isConnected} isPerSecond/>
</Grid>
<Grid size={{xs: 12, md: 6}}>
<SessionsChart title={'Received Messages'} label={''} data={received} isConnected={isConnected} isPerSecond/>
<DashboardChart title={'Received Messages'} label={''} data={received} isConnected={isConnected} isPerSecond/>
</Grid>
<Grid size={{xs: 12, md: 6}}>
<SessionsChart title={'Bytes Out'} label={'Bytes'} data={bytesOut} isConnected={isConnected}/>
<DashboardChart title={'Bytes Out'} label={''} data={bytesOut} isConnected={isConnected} isBytes/>
</Grid>
<Grid size={{xs: 12, md: 6}}>
<SessionsChart title={'Bytes In'} label={'Bytes'} data={bytesIn} isConnected={isConnected}/>
<DashboardChart title={'Bytes In'} label={''} data={bytesIn} isConnected={isConnected} isBytes/>
</Grid>
<Grid size={{xs: 12, md: 6}}>
<SessionsChart title={'Clients Connected'} label={''} data={clientsConnected} isConnected={isConnected}/>
<DashboardChart title={'Clients Connected'} label={''} data={clientsConnected} isConnected={isConnected}/>
</Grid>
<Grid size={{xs: 12, md: 6}}>
<Grid container spacing={2} columns={2}>
<Grid size={{lg:1}}>
<SessionsChart title={'CPU'} label={'%'} data={cpuPercent} decimals={2} isConnected={isConnected}/>
<DashboardChart title={'CPU'} label={'%'} data={cpuPercent} decimals={2} isConnected={isConnected}/>
</Grid>
<Grid size={{lg:1}}>
<SessionsChart title={'Memory'} label={'MB'} data={memSize} decimals={1} isConnected={isConnected}/>
<DashboardChart title={'Memory'} label={'MB'} data={memSize} decimals={1} isConnected={isConnected}/>
</Grid>
</Grid>
</Grid>

Wyświetl plik

@ -1,24 +0,0 @@
import Badge, { badgeClasses } from '@mui/material/Badge';
import IconButton from '@mui/material/IconButton';
import type { IconButtonProps } from '@mui/material/IconButton';
export interface MenuButtonProps extends IconButtonProps {
showBadge?: boolean;
}
export default function MenuButton({
showBadge = false,
...props
}) {
return (
<Badge
color="error"
variant="dot"
invisible={!showBadge}
sx={{ [`& .${badgeClasses.badge}`]: { right: 2, top: 2 } }}
>
<IconButton size="small" {...props} />
</Badge>
);
}

Wyświetl plik

@ -1,53 +0,0 @@
import List from '@mui/material/List';
import ListItem from '@mui/material/ListItem';
import ListItemButton from '@mui/material/ListItemButton';
import ListItemIcon from '@mui/material/ListItemIcon';
import ListItemText from '@mui/material/ListItemText';
import Stack from '@mui/material/Stack';
import HomeRoundedIcon from '@mui/icons-material/HomeRounded';
import AnalyticsRoundedIcon from '@mui/icons-material/AnalyticsRounded';
import PeopleRoundedIcon from '@mui/icons-material/PeopleRounded';
import AssignmentRoundedIcon from '@mui/icons-material/AssignmentRounded';
import SettingsRoundedIcon from '@mui/icons-material/SettingsRounded';
import InfoRoundedIcon from '@mui/icons-material/InfoRounded';
import HelpRoundedIcon from '@mui/icons-material/HelpRounded';
const mainListItems = [
{ text: 'Home', icon: <HomeRoundedIcon /> },
{ text: 'Analytics', icon: <AnalyticsRoundedIcon /> },
{ text: 'Clients', icon: <PeopleRoundedIcon /> },
{ text: 'Tasks', icon: <AssignmentRoundedIcon /> },
];
const secondaryListItems = [
{ text: 'Settings', icon: <SettingsRoundedIcon /> },
{ text: 'About', icon: <InfoRoundedIcon /> },
{ text: 'Feedback', icon: <HelpRoundedIcon /> },
];
export default function MenuContent() {
return (
<Stack sx={{ flexGrow: 1, p: 1, justifyContent: 'space-between' }}>
<List dense>
{mainListItems.map((item, index) => (
<ListItem key={index} disablePadding sx={{ display: 'block' }}>
<ListItemButton selected={index === 0}>
<ListItemIcon>{item.icon}</ListItemIcon>
<ListItemText primary={item.text} />
</ListItemButton>
</ListItem>
))}
</List>
<List dense>
{secondaryListItems.map((item, index) => (
<ListItem key={index} disablePadding sx={{ display: 'block' }}>
<ListItemButton>
<ListItemIcon>{item.icon}</ListItemIcon>
<ListItemText primary={item.text} />
</ListItemButton>
</ListItem>
))}
</List>
</Stack>
);
}

Wyświetl plik

@ -1,29 +0,0 @@
import { styled } from '@mui/material/styles';
import Typography from '@mui/material/Typography';
import Breadcrumbs, { breadcrumbsClasses } from '@mui/material/Breadcrumbs';
import NavigateNextRoundedIcon from '@mui/icons-material/NavigateNextRounded';
const StyledBreadcrumbs = styled(Breadcrumbs)(({ theme }) => ({
margin: theme.spacing(1, 0),
[`& .${breadcrumbsClasses.separator}`]: {
color: (theme.vars || theme).palette.action.disabled,
margin: 1,
},
[`& .${breadcrumbsClasses.ol}`]: {
alignItems: 'center',
},
}));
export default function NavbarBreadcrumbs() {
return (
<StyledBreadcrumbs
aria-label="breadcrumb"
separator={<NavigateNextRoundedIcon fontSize="small" />}
>
<Typography variant="body1">Dashboard</Typography>
<Typography variant="body1" sx={{ color: 'text.primary', fontWeight: 600 }}>
Home
</Typography>
</StyledBreadcrumbs>
);
}

Wyświetl plik

@ -1,79 +0,0 @@
import * as React from 'react';
import { styled } from '@mui/material/styles';
import Divider, { dividerClasses } from '@mui/material/Divider';
import Menu from '@mui/material/Menu';
import MuiMenuItem from '@mui/material/MenuItem';
import { paperClasses } from '@mui/material/Paper';
import { listClasses } from '@mui/material/List';
import ListItemText from '@mui/material/ListItemText';
import ListItemIcon, { listItemIconClasses } from '@mui/material/ListItemIcon';
import LogoutRoundedIcon from '@mui/icons-material/LogoutRounded';
import MoreVertRoundedIcon from '@mui/icons-material/MoreVertRounded';
import MenuButton from './MenuButton';
const MenuItem = styled(MuiMenuItem)({
margin: '2px 0',
});
export default function OptionsMenu() {
const [anchorEl, setAnchorEl] = React.useState<null | HTMLElement>(null);
const open = Boolean(anchorEl);
const handleClick = (event: React.MouseEvent<HTMLElement>) => {
setAnchorEl(event.currentTarget);
};
const handleClose = () => {
setAnchorEl(null);
};
return (
<React.Fragment>
<MenuButton
aria-label="Open menu"
onClick={handleClick}
sx={{ borderColor: 'transparent' }}
>
<MoreVertRoundedIcon />
</MenuButton>
<Menu
anchorEl={anchorEl}
id="menu"
open={open}
onClose={handleClose}
onClick={handleClose}
transformOrigin={{ horizontal: 'right', vertical: 'top' }}
anchorOrigin={{ horizontal: 'right', vertical: 'bottom' }}
sx={{
[`& .${listClasses.root}`]: {
padding: '4px',
},
[`& .${paperClasses.root}`]: {
padding: 0,
},
[`& .${dividerClasses.root}`]: {
margin: '4px -4px',
},
}}
>
<MenuItem onClick={handleClose}>Profile</MenuItem>
<MenuItem onClick={handleClose}>My account</MenuItem>
<Divider />
<MenuItem onClick={handleClose}>Add another account</MenuItem>
<MenuItem onClick={handleClose}>Settings</MenuItem>
<Divider />
<MenuItem
onClick={handleClose}
sx={{
[`& .${listItemIconClasses.root}`]: {
ml: 'auto',
minWidth: 0,
},
}}
>
<ListItemText>Logout</ListItemText>
<ListItemIcon>
<LogoutRoundedIcon fontSize="small" />
</ListItemIcon>
</MenuItem>
</Menu>
</React.Fragment>
);
}

Wyświetl plik

@ -1,84 +0,0 @@
import Card from '@mui/material/Card';
import CardContent from '@mui/material/CardContent';
import Chip from '@mui/material/Chip';
import Typography from '@mui/material/Typography';
import Stack from '@mui/material/Stack';
import { BarChart } from '@mui/x-charts/BarChart';
import { useTheme } from '@mui/material/styles';
export default function PageViewsBarChart() {
const theme = useTheme();
const colorPalette = [
(theme.vars || theme).palette.primary.dark,
(theme.vars || theme).palette.primary.main,
(theme.vars || theme).palette.primary.light,
];
return (
<Card variant="outlined" sx={{ width: '100%' }}>
<CardContent>
<Typography component="h2" variant="subtitle2" gutterBottom>
Page views and downloads
</Typography>
<Stack sx={{ justifyContent: 'space-between' }}>
<Stack
direction="row"
sx={{
alignContent: { xs: 'center', sm: 'flex-start' },
alignItems: 'center',
gap: 1,
}}
>
<Typography variant="h4" component="p">
1.3M
</Typography>
<Chip size="small" color="error" label="-8%" />
</Stack>
<Typography variant="caption" sx={{ color: 'text.secondary' }}>
Page views and downloads for the last 6 months
</Typography>
</Stack>
<BarChart
borderRadius={8}
colors={colorPalette}
xAxis={
[
{
scaleType: 'band',
categoryGapRatio: 0.5,
data: ['Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun', 'Jul'],
},
] as any
}
series={[
{
id: 'page-views',
label: 'Page views',
data: [2234, 3872, 2998, 4125, 3357, 2789, 2998],
stack: 'A',
},
{
id: 'downloads',
label: 'Downloads',
data: [3098, 4215, 2384, 2101, 4752, 3593, 2384],
stack: 'A',
},
{
id: 'conversions',
label: 'Conversions',
data: [4051, 2275, 3129, 4693, 3904, 2038, 2275],
stack: 'A',
},
]}
height={250}
margin={{ left: 50, right: 0, top: 20, bottom: 20 }}
grid={{ horizontal: true }}
hideLegend
slotProps={{
legend: {
},
}}
/>
</CardContent>
</Card>
);
}

Wyświetl plik

@ -1,25 +0,0 @@
import FormControl from '@mui/material/FormControl';
import InputAdornment from '@mui/material/InputAdornment';
import OutlinedInput from '@mui/material/OutlinedInput';
import SearchRoundedIcon from '@mui/icons-material/SearchRounded';
export default function Search() {
return (
<FormControl sx={{ width: { xs: '100%', md: '25ch' } }} variant="outlined">
<OutlinedInput
size="small"
id="search"
placeholder="Search…"
sx={{ flexGrow: 1 }}
startAdornment={
<InputAdornment position="start" sx={{ color: 'text.primary' }}>
<SearchRoundedIcon fontSize="small" />
</InputAdornment>
}
inputProps={{
'aria-label': 'search',
}}
/>
</FormControl>
);
}

Some files were not shown because too many files have changed in this diff Show More