diff --git a/amqtt/mqtt/packet.py b/amqtt/mqtt/packet.py index 23fa377..38ad4aa 100644 --- a/amqtt/mqtt/packet.py +++ b/amqtt/mqtt/packet.py @@ -9,7 +9,7 @@ except ImportError: UTC = timezone.utc from struct import unpack -from typing import Generic, Self, TypeVar +from typing_extensions import Generic, Self, TypeVar from amqtt.adapters import ReaderAdapter, WriterAdapter from amqtt.codecs_amqtt import bytes_to_hex_str, decode_packet_id, int_to_bytes, read_or_raise diff --git a/amqtt/scripts/pub_script.py b/amqtt/scripts/pub_script.py index 50bd085..656faf3 100644 --- a/amqtt/scripts/pub_script.py +++ b/amqtt/scripts/pub_script.py @@ -151,7 +151,7 @@ def publisher_main( # pylint: disable=R0914,R0917 # noqa : PLR0913 ca_data: str | None = typer.Option(None, "--ca-data", help="CA data"), will_topic: str | None = typer.Option(None, "--will-topic", help="Last will topic"), will_message: str | None = typer.Option(None, "--will-message", help="Last will message"), - will_qos: int | None = typer.Option(None, "--will-qos", help="Last will QoS"), + will_qos: int | None = typer.Option(0, "--will-qos", help="Last will QoS"), will_retain: bool = typer.Option(False, "--will-retain", help="Set retain flag for last will message"), extra_headers_json: str | None = typer.Option( None, "--extra-headers", help="JSON object with key-value headers for websocket connections" @@ -171,6 +171,14 @@ def publisher_main( # pylint: disable=R0914,R0917 # noqa : PLR0913 typer.echo("❌ You must provide exactly one of --config, --file, or --stdin.", err=True) raise typer.Exit(code=1) + if bool(will_message) != bool(will_topic): + typer.echo("❌ must specify both 'will_message' and 'will_topic' ") + raise typer.Exit(code=1) + + if will_retain and not (will_message and will_topic): + typer.echo("❌ 'will-retain' only valid if 'will_message' and 'will_topic' are specified.", err=True) + raise typer.Exit(code=1) + formatter = "[%(asctime)s] :: %(levelname)s - %(message)s" level = logging.DEBUG if debug else logging.INFO logging.basicConfig(level=level, format=formatter) @@ -194,11 +202,12 @@ def publisher_main( # pylint: disable=R0914,R0917 # noqa : PLR0913 if keep_alive: config["keep_alive"] = int(keep_alive) - if will_topic and will_message and will_qos: + + if will_topic and will_message and will_qos is not None and will_retain: config["will"] = { "topic": will_topic, - "message": will_message.encode("utf-8"), - "qos": int(will_qos), + "message": will_message.encode(), + "qos": will_qos, "retain": will_retain, } diff --git a/tests/test_cli.py b/tests/test_cli.py index 4748627..239c3d8 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -158,8 +158,10 @@ async def test_pub_sub_options(broker): [ "amqtt_pub", "--url", "mqtt://127.0.0.1:1884", - "-t", "test/retain", - "-m", "retained message", + "-t", "topic/test", + "-m", "standard message", + "--will-topic", "topic/retain", + "--will-message", "last will message", "--will-retain", ], capture_output=True, @@ -171,13 +173,13 @@ async def test_pub_sub_options(broker): [ "amqtt_sub", "--url", "mqtt://127.0.0.1:1884", - "-t", "test/retain", + "-t", "topic/retain", "-n", "1", ], capture_output=True, ) assert sub_proc.returncode == 0, "subscriber error code" - assert "retained message" in str(sub_proc.stdout) + assert "last will message" in str(sub_proc.stdout) @pytest.mark.asyncio