kopia lustrzana https://github.com/Yakifo/amqtt
replacing amqtt publisher use of 'docopts' with 'typer'
rodzic
d39bd2ec5c
commit
7255561b09
|
@ -5,11 +5,11 @@ try:
|
|||
from datetime import UTC, datetime
|
||||
except ImportError:
|
||||
from datetime import datetime, timezone
|
||||
|
||||
UTC = timezone.utc
|
||||
|
||||
from struct import unpack
|
||||
from typing import Generic
|
||||
from typing_extensions import Self, TypeVar
|
||||
from typing import Generic, Self, TypeVar
|
||||
|
||||
from amqtt.adapters import ReaderAdapter, WriterAdapter
|
||||
from amqtt.codecs_amqtt import bytes_to_hex_str, decode_packet_id, int_to_bytes, read_or_raise
|
||||
|
|
|
@ -1,20 +1,24 @@
|
|||
import asyncio
|
||||
from typing import SupportsIndex, SupportsInt # pylint: disable=C0412
|
||||
|
||||
from collections import deque
|
||||
try:
|
||||
from collections import deque
|
||||
from collections.abc import Buffer
|
||||
except ImportError:
|
||||
from collections import deque
|
||||
from typing import Protocol, runtime_checkable
|
||||
|
||||
@runtime_checkable
|
||||
class Buffer(Protocol): # type: ignore[no-redef]
|
||||
def __buffer__(self, flags: int = ...) -> memoryview:
|
||||
"""Mimic the behavior of `collections.abc.Buffer` for python 3.10-3.12."""
|
||||
|
||||
|
||||
try:
|
||||
from datetime import UTC, datetime
|
||||
except ImportError:
|
||||
from datetime import datetime, timezone
|
||||
|
||||
UTC = timezone.utc
|
||||
|
||||
|
||||
|
|
|
@ -1,38 +1,7 @@
|
|||
"""amqtt_pub - MQTT 3.1.1 publisher.
|
||||
|
||||
Usage:
|
||||
amqtt_pub --version
|
||||
amqtt_pub (-h | --help)
|
||||
amqtt_pub --url BROKER_URL -t TOPIC (-f FILE | -l | -m MESSAGE | -n | -s) [-c CONFIG_FILE] [-i CLIENT_ID] [-q | --qos QOS] [-d] [-k KEEP_ALIVE] [--clean-session] [--ca-file CAFILE] [--ca-path CAPATH] [--ca-data CADATA] [ --will-topic WILL_TOPIC [--will-message WILL_MESSAGE] [--will-qos WILL_QOS] [--will-retain] ] [--extra-headers HEADER] [-r]
|
||||
|
||||
Options:
|
||||
-h --help Show this screen.
|
||||
--version Show version.
|
||||
--url BROKER_URL Broker connection URL (must conform to MQTT URI scheme (see https://github.com/mqtt/mqtt.github.io/wiki/URI-Scheme>)
|
||||
-c CONFIG_FILE Broker configuration file (YAML format)
|
||||
-i CLIENT_ID Id to use as client ID.
|
||||
-q | --qos QOS Quality of service to use for the message, from 0, 1, and 2. Defaults to 0.
|
||||
-r Set retain flag on connect
|
||||
-t TOPIC Message topic
|
||||
-m MESSAGE Message data to send
|
||||
-f FILE Read file by line and publish message for each line
|
||||
-s Read from stdin and publish message for each line
|
||||
-k KEEP_ALIVE Keep alive timeout in seconds
|
||||
--clean-session Clean session on connect (defaults to False)
|
||||
--ca-file CAFILE CA file
|
||||
--ca-path CAPATH CA Path
|
||||
--ca-data CADATA CA data
|
||||
--will-topic WILL_TOPIC
|
||||
--will-message WILL_MESSAGE
|
||||
--will-qos WILL_QOS
|
||||
--will-retain
|
||||
--extra-headers EXTRA_HEADERS JSON object with key-value pairs of additional headers for websocket connections
|
||||
-d Enable debug messages
|
||||
""" # noqa: E501
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Generator
|
||||
import contextlib
|
||||
from dataclasses import dataclass
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
|
@ -41,11 +10,11 @@ import socket
|
|||
import sys
|
||||
from typing import Any
|
||||
|
||||
from docopt import docopt
|
||||
import typer
|
||||
|
||||
import amqtt
|
||||
from amqtt import __version__ as amqtt_version
|
||||
from amqtt.client import MQTTClient
|
||||
from amqtt.errors import ConnectError
|
||||
from amqtt.errors import ClientError, ConnectError
|
||||
from amqtt.utils import read_yaml_config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
@ -57,64 +26,79 @@ def _gen_client_id() -> str:
|
|||
return f"amqtt_pub/{pid}-{hostname}"
|
||||
|
||||
|
||||
def _get_qos(arguments: dict[str, Any]) -> int | None:
|
||||
def _get_extra_headers(extra_headers_json: str | None = None) -> dict[str, Any]:
|
||||
try:
|
||||
return int(arguments["--qos"][0])
|
||||
except (ValueError, IndexError):
|
||||
return None
|
||||
|
||||
|
||||
def _get_extra_headers(arguments: dict[str, Any]) -> dict[str, Any]:
|
||||
try:
|
||||
extra_headers: dict[str, Any] = json.loads(arguments["--extra-headers"])
|
||||
extra_headers: dict[str, Any] = json.loads(extra_headers_json or "{}")
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
return {}
|
||||
return extra_headers
|
||||
|
||||
|
||||
def _get_message(arguments: dict[str, Any]) -> Generator[bytes | bytearray]:
|
||||
if arguments["-n"]:
|
||||
yield b""
|
||||
if arguments["-m"]:
|
||||
yield arguments["-m"].encode(encoding="utf-8")
|
||||
if arguments["-f"]:
|
||||
try:
|
||||
with Path(arguments["-f"]).open(encoding="utf-8") as f:
|
||||
for line in f:
|
||||
@dataclass
|
||||
class MessageInput:
|
||||
message_str: str | None = None
|
||||
file: str | None = None
|
||||
stdin: bool | None = False
|
||||
lines: bool | None = False
|
||||
no_message: bool | None = False
|
||||
|
||||
def get_message(self) -> Generator[bytes | bytearray]:
|
||||
if self.no_message:
|
||||
yield b""
|
||||
if self.message_str:
|
||||
yield self.message_str.encode(encoding="utf-8")
|
||||
if self.file:
|
||||
try:
|
||||
with Path(self.file).open(encoding="utf-8") as f:
|
||||
for line in f:
|
||||
yield line.encode(encoding="utf-8")
|
||||
except Exception:
|
||||
logger.exception(f"Failed to read file '{self.file}'")
|
||||
if self.lines:
|
||||
for line in sys.stdin:
|
||||
if line:
|
||||
yield line.encode(encoding="utf-8")
|
||||
except Exception:
|
||||
logger.exception(f"Failed to read file '{arguments['-f']}'")
|
||||
if arguments["-l"]:
|
||||
for line in sys.stdin:
|
||||
if line:
|
||||
yield line.encode(encoding="utf-8")
|
||||
if arguments["-s"]:
|
||||
message = bytearray()
|
||||
for line in sys.stdin:
|
||||
message.extend(line.encode(encoding="utf-8"))
|
||||
yield message
|
||||
if self.stdin:
|
||||
messages = bytearray()
|
||||
for line in sys.stdin:
|
||||
messages.extend(line.encode(encoding="utf-8"))
|
||||
yield messages
|
||||
|
||||
|
||||
async def do_pub(client: MQTTClient, arguments: dict[str, Any]) -> None:
|
||||
"""Perform the publish."""
|
||||
@dataclass
|
||||
class CAInfo:
|
||||
ca_file: str | None = None
|
||||
ca_path: str | None = None
|
||||
ca_data: str | None = None
|
||||
|
||||
|
||||
async def do_pub(
|
||||
client: MQTTClient,
|
||||
url: str,
|
||||
topic: str,
|
||||
message_input: MessageInput,
|
||||
ca_info: CAInfo,
|
||||
clean_session: bool = False,
|
||||
retain: bool = False,
|
||||
extra_headers_json: str | None = None,
|
||||
qos: int | None = None,
|
||||
) -> None:
|
||||
"""Publish the message."""
|
||||
running_tasks = []
|
||||
|
||||
try:
|
||||
logger.info(f"{client.client_id} Connecting to broker")
|
||||
|
||||
await client.connect(
|
||||
uri=arguments["--url"],
|
||||
cleansession=arguments["--clean-session"],
|
||||
cafile=arguments["--ca-file"],
|
||||
capath=arguments["--ca-path"],
|
||||
cadata=arguments["--ca-data"],
|
||||
additional_headers=_get_extra_headers(arguments),
|
||||
uri=url,
|
||||
cleansession=clean_session,
|
||||
cafile=ca_info.ca_file,
|
||||
capath=ca_info.ca_path,
|
||||
cadata=ca_info.ca_data,
|
||||
additional_headers=_get_extra_headers(extra_headers_json),
|
||||
)
|
||||
|
||||
qos = _get_qos(arguments)
|
||||
topic = arguments["-t"]
|
||||
retain = arguments["-r"]
|
||||
for message in _get_message(arguments):
|
||||
for message in message_input.get_message():
|
||||
logger.info(f"{client.client_id} Publishing to '{topic}'")
|
||||
task = asyncio.ensure_future(client.publish(topic, message, qos, retain))
|
||||
running_tasks.append(task)
|
||||
|
@ -128,22 +112,68 @@ async def do_pub(client: MQTTClient, arguments: dict[str, Any]) -> None:
|
|||
await client.disconnect()
|
||||
logger.info(f"{client.client_id} Disconnected from broker")
|
||||
except ConnectError as ce:
|
||||
logger.fatal(f"Connection to '{arguments['--url']}' failed: {ce!r}")
|
||||
logger.fatal(f"Connection to '{url}' failed: {ce!r}")
|
||||
except asyncio.CancelledError:
|
||||
logger.fatal("Publish canceled due to previous error")
|
||||
|
||||
|
||||
def main() -> None:
|
||||
"""Entry point for the amqtt publisher."""
|
||||
typer.run(publisher)
|
||||
|
||||
|
||||
def _version() -> None:
|
||||
typer.echo(f"{amqtt_version}")
|
||||
raise typer.Exit(code=0)
|
||||
|
||||
|
||||
def publisher( # pylint: disable=R0914,R0917 # noqa : PLR0913
|
||||
url: str = typer.Option(
|
||||
..., "--url", help="Broker connection URL (must conform to MQTT URI scheme: mqtt://<username:password>@HOST:port)"
|
||||
),
|
||||
config_file: str | None = typer.Option(None, "-c", "--config-file", help="Broker configuration file (YAML format)"),
|
||||
client_id: str | None = typer.Option(None, "-i", "--client-id", help="Client ID to use for the connection"),
|
||||
qos: int = typer.Option(0, "--qos", "-q", help="Quality of service (0, 1, or 2)"),
|
||||
retain: bool = typer.Option(False, "-r", help="Set retain flag on connect"),
|
||||
topic: str = typer.Option(..., "-t", help="Message topic"),
|
||||
message: str | None = typer.Option(None, "-m", help="Message data to send"),
|
||||
file: str | None = typer.Option(None, "-f", help="Read file by line and publish each line as a message"),
|
||||
stdin: bool = typer.Option(False, "-s", help="Read from stdin and publish message for first line"),
|
||||
lines: bool = typer.Option(False, "-l", help="Read from stdin and publish message for each line"),
|
||||
no_message: bool = typer.Option(False, "-n", help="Publish an empty message"),
|
||||
keep_alive: int | None = typer.Option(None, "-k", help="Keep alive timeout in seconds"),
|
||||
clean_session: bool = typer.Option(False, "--clean-session", help="Clean session on connect (defaults to False)"),
|
||||
ca_file: str | None = typer.Option(None, "--ca-file", help="CA file"),
|
||||
ca_path: str | None = typer.Option(None, "--ca-path", help="CA path"),
|
||||
ca_data: str | None = typer.Option(None, "--ca-data", help="CA data"),
|
||||
will_topic: str | None = typer.Option(None, "--will-topic", help="Last will topic"),
|
||||
will_message: str | None = typer.Option(None, "--will-message", help="Last will message"),
|
||||
will_qos: int | None = typer.Option(None, "--will-qos", help="Last will QoS"),
|
||||
will_retain: bool = typer.Option(False, "--will-retain", help="Set retain flag for last will message"),
|
||||
extra_headers_json: str | None = typer.Option(
|
||||
None, "--extra-headers", help="JSON object with key-value headers for websocket connections"
|
||||
),
|
||||
debug: bool = typer.Option(False, "-d", help="Enable debug messages"),
|
||||
version: bool | None = typer.Option( # noqa : ARG001
|
||||
None,
|
||||
"--version",
|
||||
callback=_version,
|
||||
is_eager=True,
|
||||
help="Show version and exit",
|
||||
),
|
||||
) -> None:
|
||||
"""Run the MQTT publisher."""
|
||||
arguments = docopt(__doc__, version=amqtt.__version__)
|
||||
provided = [bool(message), bool(file), stdin, lines, no_message]
|
||||
if sum(provided) != 1:
|
||||
typer.echo("❌ You must provide exactly one of --config, --file, or --stdin.", err=True)
|
||||
raise typer.Exit(code=1)
|
||||
|
||||
formatter = "[%(asctime)s] :: %(levelname)s - %(message)s"
|
||||
level = logging.DEBUG if arguments["-d"] else logging.INFO
|
||||
level = logging.DEBUG if debug else logging.INFO
|
||||
logging.basicConfig(level=level, format=formatter)
|
||||
|
||||
config = None
|
||||
if arguments["-c"]:
|
||||
config = read_yaml_config(arguments["-c"])
|
||||
if config_file:
|
||||
config = read_yaml_config(config_file)
|
||||
else:
|
||||
default_config_path = Path(__file__).parent / "default_client.yaml"
|
||||
logger.debug(f"Using default configuration from {default_config_path}")
|
||||
|
@ -151,7 +181,6 @@ def main() -> None:
|
|||
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
client_id = arguments.get("-i", None)
|
||||
if not client_id:
|
||||
client_id = _gen_client_id()
|
||||
|
||||
|
@ -159,22 +188,49 @@ def main() -> None:
|
|||
logger.debug("Failed to correctly initialize config")
|
||||
return
|
||||
|
||||
if arguments["-k"]:
|
||||
config["keep_alive"] = int(arguments["-k"])
|
||||
if keep_alive:
|
||||
config["keep_alive"] = int(keep_alive)
|
||||
|
||||
if arguments["--will-topic"] and arguments["--will-message"] and arguments["--will-qos"]:
|
||||
if will_topic and will_message and will_qos:
|
||||
config["will"] = {
|
||||
"topic": arguments["--will-topic"],
|
||||
"message": arguments["--will-message"].encode("utf-8"),
|
||||
"qos": int(arguments["--will-qos"]),
|
||||
"retain": arguments["--will-retain"],
|
||||
"topic": will_topic,
|
||||
"message": will_message.encode("utf-8"),
|
||||
"qos": int(will_qos),
|
||||
"retain": will_retain,
|
||||
}
|
||||
|
||||
client = MQTTClient(client_id=client_id, config=config)
|
||||
message_input = MessageInput(
|
||||
message_str=message,
|
||||
file=file,
|
||||
stdin=stdin,
|
||||
no_message=no_message,
|
||||
lines=lines,
|
||||
)
|
||||
ca_info = CAInfo(
|
||||
ca_file=ca_file,
|
||||
ca_path=ca_path,
|
||||
ca_data=ca_data,
|
||||
)
|
||||
with contextlib.suppress(KeyboardInterrupt):
|
||||
loop.run_until_complete(do_pub(client, arguments))
|
||||
try:
|
||||
loop.run_until_complete(
|
||||
do_pub(
|
||||
client=client,
|
||||
message_input=message_input,
|
||||
url=url,
|
||||
topic=topic,
|
||||
retain=retain,
|
||||
clean_session=clean_session,
|
||||
ca_info=ca_info,
|
||||
extra_headers_json=extra_headers_json,
|
||||
qos=qos,
|
||||
)
|
||||
)
|
||||
except (ClientError, ConnectError) as ce:
|
||||
typer.echo(f"❌ Connection failed: {ce}", err=True)
|
||||
loop.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
typer.run(main)
|
||||
|
|
|
@ -2,6 +2,7 @@ try:
|
|||
from datetime import UTC, datetime
|
||||
except ImportError:
|
||||
from datetime import datetime, timezone
|
||||
|
||||
UTC = timezone.utc
|
||||
|
||||
import logging
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
[build-system]
|
||||
requires = ["hatchling", "hatch-vcs"]
|
||||
requires = ["hatchling", "hatch-vcs", "uv-dynamic-versioning"]
|
||||
|
||||
build-backend = "hatchling.build"
|
||||
|
||||
[project]
|
||||
|
@ -117,11 +118,13 @@ ignore = [
|
|||
"TD003", # TODO Missing issue link for this TODO.
|
||||
"ANN401", # Dynamically typed expressions (typing.Any) are disallowed
|
||||
"ARG002", # Unused method argument
|
||||
"PERF203" # try-except penalty within loops (3.10 only)
|
||||
"PERF203",# try-except penalty within loops (3.10 only),
|
||||
"COM812" # rule causes conflicts when used with the formatter
|
||||
]
|
||||
|
||||
[tool.ruff.lint.per-file-ignores]
|
||||
"tests/**" = ["ALL"]
|
||||
"amqtt/scripts/pub_script.py" = ["FBT003"]
|
||||
|
||||
[tool.ruff.lint.flake8-pytest-style]
|
||||
fixture-parentheses = false
|
||||
|
|
Ładowanie…
Reference in New Issue