From 7255561b09385fd0aa54fa41521b02b16f5bb52b Mon Sep 17 00:00:00 2001 From: Andrew Mirsky Date: Sun, 18 May 2025 11:12:14 -0400 Subject: [PATCH] replacing amqtt publisher use of 'docopts' with 'typer' --- amqtt/mqtt/packet.py | 4 +- amqtt/plugins/sys/broker.py | 6 +- amqtt/scripts/pub_script.py | 240 ++++++++++++++++++++++-------------- amqtt/version.py | 1 + pyproject.toml | 7 +- 5 files changed, 161 insertions(+), 97 deletions(-) diff --git a/amqtt/mqtt/packet.py b/amqtt/mqtt/packet.py index b674e95..23fa377 100644 --- a/amqtt/mqtt/packet.py +++ b/amqtt/mqtt/packet.py @@ -5,11 +5,11 @@ try: from datetime import UTC, datetime except ImportError: from datetime import datetime, timezone + UTC = timezone.utc from struct import unpack -from typing import Generic -from typing_extensions import Self, TypeVar +from typing import Generic, Self, TypeVar from amqtt.adapters import ReaderAdapter, WriterAdapter from amqtt.codecs_amqtt import bytes_to_hex_str, decode_packet_id, int_to_bytes, read_or_raise diff --git a/amqtt/plugins/sys/broker.py b/amqtt/plugins/sys/broker.py index cf53fea..63d4ec5 100644 --- a/amqtt/plugins/sys/broker.py +++ b/amqtt/plugins/sys/broker.py @@ -1,20 +1,24 @@ import asyncio from typing import SupportsIndex, SupportsInt # pylint: disable=C0412 -from collections import deque try: + from collections import deque from collections.abc import Buffer except ImportError: + from collections import deque from typing import Protocol, runtime_checkable + @runtime_checkable class Buffer(Protocol): # type: ignore[no-redef] def __buffer__(self, flags: int = ...) -> memoryview: """Mimic the behavior of `collections.abc.Buffer` for python 3.10-3.12.""" + try: from datetime import UTC, datetime except ImportError: from datetime import datetime, timezone + UTC = timezone.utc diff --git a/amqtt/scripts/pub_script.py b/amqtt/scripts/pub_script.py index 942ce74..1e0a8f2 100644 --- a/amqtt/scripts/pub_script.py +++ b/amqtt/scripts/pub_script.py @@ -1,38 +1,7 @@ -"""amqtt_pub - MQTT 3.1.1 publisher. - -Usage: - amqtt_pub --version - amqtt_pub (-h | --help) - amqtt_pub --url BROKER_URL -t TOPIC (-f FILE | -l | -m MESSAGE | -n | -s) [-c CONFIG_FILE] [-i CLIENT_ID] [-q | --qos QOS] [-d] [-k KEEP_ALIVE] [--clean-session] [--ca-file CAFILE] [--ca-path CAPATH] [--ca-data CADATA] [ --will-topic WILL_TOPIC [--will-message WILL_MESSAGE] [--will-qos WILL_QOS] [--will-retain] ] [--extra-headers HEADER] [-r] - -Options: - -h --help Show this screen. - --version Show version. - --url BROKER_URL Broker connection URL (must conform to MQTT URI scheme (see https://github.com/mqtt/mqtt.github.io/wiki/URI-Scheme>) - -c CONFIG_FILE Broker configuration file (YAML format) - -i CLIENT_ID Id to use as client ID. - -q | --qos QOS Quality of service to use for the message, from 0, 1, and 2. Defaults to 0. - -r Set retain flag on connect - -t TOPIC Message topic - -m MESSAGE Message data to send - -f FILE Read file by line and publish message for each line - -s Read from stdin and publish message for each line - -k KEEP_ALIVE Keep alive timeout in seconds - --clean-session Clean session on connect (defaults to False) - --ca-file CAFILE CA file - --ca-path CAPATH CA Path - --ca-data CADATA CA data - --will-topic WILL_TOPIC - --will-message WILL_MESSAGE - --will-qos WILL_QOS - --will-retain - --extra-headers EXTRA_HEADERS JSON object with key-value pairs of additional headers for websocket connections - -d Enable debug messages -""" # noqa: E501 - import asyncio from collections.abc import Generator import contextlib +from dataclasses import dataclass import json import logging import os @@ -41,11 +10,11 @@ import socket import sys from typing import Any -from docopt import docopt +import typer -import amqtt +from amqtt import __version__ as amqtt_version from amqtt.client import MQTTClient -from amqtt.errors import ConnectError +from amqtt.errors import ClientError, ConnectError from amqtt.utils import read_yaml_config logger = logging.getLogger(__name__) @@ -57,64 +26,79 @@ def _gen_client_id() -> str: return f"amqtt_pub/{pid}-{hostname}" -def _get_qos(arguments: dict[str, Any]) -> int | None: +def _get_extra_headers(extra_headers_json: str | None = None) -> dict[str, Any]: try: - return int(arguments["--qos"][0]) - except (ValueError, IndexError): - return None - - -def _get_extra_headers(arguments: dict[str, Any]) -> dict[str, Any]: - try: - extra_headers: dict[str, Any] = json.loads(arguments["--extra-headers"]) + extra_headers: dict[str, Any] = json.loads(extra_headers_json or "{}") except (json.JSONDecodeError, TypeError): return {} return extra_headers -def _get_message(arguments: dict[str, Any]) -> Generator[bytes | bytearray]: - if arguments["-n"]: - yield b"" - if arguments["-m"]: - yield arguments["-m"].encode(encoding="utf-8") - if arguments["-f"]: - try: - with Path(arguments["-f"]).open(encoding="utf-8") as f: - for line in f: +@dataclass +class MessageInput: + message_str: str | None = None + file: str | None = None + stdin: bool | None = False + lines: bool | None = False + no_message: bool | None = False + + def get_message(self) -> Generator[bytes | bytearray]: + if self.no_message: + yield b"" + if self.message_str: + yield self.message_str.encode(encoding="utf-8") + if self.file: + try: + with Path(self.file).open(encoding="utf-8") as f: + for line in f: + yield line.encode(encoding="utf-8") + except Exception: + logger.exception(f"Failed to read file '{self.file}'") + if self.lines: + for line in sys.stdin: + if line: yield line.encode(encoding="utf-8") - except Exception: - logger.exception(f"Failed to read file '{arguments['-f']}'") - if arguments["-l"]: - for line in sys.stdin: - if line: - yield line.encode(encoding="utf-8") - if arguments["-s"]: - message = bytearray() - for line in sys.stdin: - message.extend(line.encode(encoding="utf-8")) - yield message + if self.stdin: + messages = bytearray() + for line in sys.stdin: + messages.extend(line.encode(encoding="utf-8")) + yield messages -async def do_pub(client: MQTTClient, arguments: dict[str, Any]) -> None: - """Perform the publish.""" +@dataclass +class CAInfo: + ca_file: str | None = None + ca_path: str | None = None + ca_data: str | None = None + + +async def do_pub( + client: MQTTClient, + url: str, + topic: str, + message_input: MessageInput, + ca_info: CAInfo, + clean_session: bool = False, + retain: bool = False, + extra_headers_json: str | None = None, + qos: int | None = None, +) -> None: + """Publish the message.""" running_tasks = [] try: logger.info(f"{client.client_id} Connecting to broker") await client.connect( - uri=arguments["--url"], - cleansession=arguments["--clean-session"], - cafile=arguments["--ca-file"], - capath=arguments["--ca-path"], - cadata=arguments["--ca-data"], - additional_headers=_get_extra_headers(arguments), + uri=url, + cleansession=clean_session, + cafile=ca_info.ca_file, + capath=ca_info.ca_path, + cadata=ca_info.ca_data, + additional_headers=_get_extra_headers(extra_headers_json), ) - qos = _get_qos(arguments) - topic = arguments["-t"] - retain = arguments["-r"] - for message in _get_message(arguments): + for message in message_input.get_message(): logger.info(f"{client.client_id} Publishing to '{topic}'") task = asyncio.ensure_future(client.publish(topic, message, qos, retain)) running_tasks.append(task) @@ -128,22 +112,68 @@ async def do_pub(client: MQTTClient, arguments: dict[str, Any]) -> None: await client.disconnect() logger.info(f"{client.client_id} Disconnected from broker") except ConnectError as ce: - logger.fatal(f"Connection to '{arguments['--url']}' failed: {ce!r}") + logger.fatal(f"Connection to '{url}' failed: {ce!r}") except asyncio.CancelledError: logger.fatal("Publish canceled due to previous error") def main() -> None: + """Entry point for the amqtt publisher.""" + typer.run(publisher) + + +def _version() -> None: + typer.echo(f"{amqtt_version}") + raise typer.Exit(code=0) + + +def publisher( # pylint: disable=R0914,R0917 # noqa : PLR0913 + url: str = typer.Option( + ..., "--url", help="Broker connection URL (must conform to MQTT URI scheme: mqtt://@HOST:port)" + ), + config_file: str | None = typer.Option(None, "-c", "--config-file", help="Broker configuration file (YAML format)"), + client_id: str | None = typer.Option(None, "-i", "--client-id", help="Client ID to use for the connection"), + qos: int = typer.Option(0, "--qos", "-q", help="Quality of service (0, 1, or 2)"), + retain: bool = typer.Option(False, "-r", help="Set retain flag on connect"), + topic: str = typer.Option(..., "-t", help="Message topic"), + message: str | None = typer.Option(None, "-m", help="Message data to send"), + file: str | None = typer.Option(None, "-f", help="Read file by line and publish each line as a message"), + stdin: bool = typer.Option(False, "-s", help="Read from stdin and publish message for first line"), + lines: bool = typer.Option(False, "-l", help="Read from stdin and publish message for each line"), + no_message: bool = typer.Option(False, "-n", help="Publish an empty message"), + keep_alive: int | None = typer.Option(None, "-k", help="Keep alive timeout in seconds"), + clean_session: bool = typer.Option(False, "--clean-session", help="Clean session on connect (defaults to False)"), + ca_file: str | None = typer.Option(None, "--ca-file", help="CA file"), + ca_path: str | None = typer.Option(None, "--ca-path", help="CA path"), + ca_data: str | None = typer.Option(None, "--ca-data", help="CA data"), + will_topic: str | None = typer.Option(None, "--will-topic", help="Last will topic"), + will_message: str | None = typer.Option(None, "--will-message", help="Last will message"), + will_qos: int | None = typer.Option(None, "--will-qos", help="Last will QoS"), + will_retain: bool = typer.Option(False, "--will-retain", help="Set retain flag for last will message"), + extra_headers_json: str | None = typer.Option( + None, "--extra-headers", help="JSON object with key-value headers for websocket connections" + ), + debug: bool = typer.Option(False, "-d", help="Enable debug messages"), + version: bool | None = typer.Option( # noqa : ARG001 + None, + "--version", + callback=_version, + is_eager=True, + help="Show version and exit", + ), +) -> None: """Run the MQTT publisher.""" - arguments = docopt(__doc__, version=amqtt.__version__) + provided = [bool(message), bool(file), stdin, lines, no_message] + if sum(provided) != 1: + typer.echo("❌ You must provide exactly one of --config, --file, or --stdin.", err=True) + raise typer.Exit(code=1) formatter = "[%(asctime)s] :: %(levelname)s - %(message)s" - level = logging.DEBUG if arguments["-d"] else logging.INFO + level = logging.DEBUG if debug else logging.INFO logging.basicConfig(level=level, format=formatter) - config = None - if arguments["-c"]: - config = read_yaml_config(arguments["-c"]) + if config_file: + config = read_yaml_config(config_file) else: default_config_path = Path(__file__).parent / "default_client.yaml" logger.debug(f"Using default configuration from {default_config_path}") @@ -151,7 +181,6 @@ def main() -> None: loop = asyncio.get_event_loop() - client_id = arguments.get("-i", None) if not client_id: client_id = _gen_client_id() @@ -159,22 +188,49 @@ def main() -> None: logger.debug("Failed to correctly initialize config") return - if arguments["-k"]: - config["keep_alive"] = int(arguments["-k"]) + if keep_alive: + config["keep_alive"] = int(keep_alive) - if arguments["--will-topic"] and arguments["--will-message"] and arguments["--will-qos"]: + if will_topic and will_message and will_qos: config["will"] = { - "topic": arguments["--will-topic"], - "message": arguments["--will-message"].encode("utf-8"), - "qos": int(arguments["--will-qos"]), - "retain": arguments["--will-retain"], + "topic": will_topic, + "message": will_message.encode("utf-8"), + "qos": int(will_qos), + "retain": will_retain, } client = MQTTClient(client_id=client_id, config=config) + message_input = MessageInput( + message_str=message, + file=file, + stdin=stdin, + no_message=no_message, + lines=lines, + ) + ca_info = CAInfo( + ca_file=ca_file, + ca_path=ca_path, + ca_data=ca_data, + ) with contextlib.suppress(KeyboardInterrupt): - loop.run_until_complete(do_pub(client, arguments)) + try: + loop.run_until_complete( + do_pub( + client=client, + message_input=message_input, + url=url, + topic=topic, + retain=retain, + clean_session=clean_session, + ca_info=ca_info, + extra_headers_json=extra_headers_json, + qos=qos, + ) + ) + except (ClientError, ConnectError) as ce: + typer.echo(f"❌ Connection failed: {ce}", err=True) loop.close() if __name__ == "__main__": - main() + typer.run(main) diff --git a/amqtt/version.py b/amqtt/version.py index d5dde47..2b41d19 100644 --- a/amqtt/version.py +++ b/amqtt/version.py @@ -2,6 +2,7 @@ try: from datetime import UTC, datetime except ImportError: from datetime import datetime, timezone + UTC = timezone.utc import logging diff --git a/pyproject.toml b/pyproject.toml index fcb7c3f..ec2a350 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,5 +1,6 @@ [build-system] -requires = ["hatchling", "hatch-vcs"] +requires = ["hatchling", "hatch-vcs", "uv-dynamic-versioning"] + build-backend = "hatchling.build" [project] @@ -117,11 +118,13 @@ ignore = [ "TD003", # TODO Missing issue link for this TODO. "ANN401", # Dynamically typed expressions (typing.Any) are disallowed "ARG002", # Unused method argument - "PERF203" # try-except penalty within loops (3.10 only) + "PERF203",# try-except penalty within loops (3.10 only), + "COM812" # rule causes conflicts when used with the formatter ] [tool.ruff.lint.per-file-ignores] "tests/**" = ["ALL"] +"amqtt/scripts/pub_script.py" = ["FBT003"] [tool.ruff.lint.flake8-pytest-style] fixture-parentheses = false