From afad06a4cc3b81a2aa1d4c9fbb4dbbab3de588c0 Mon Sep 17 00:00:00 2001 From: Andrew Mirsky Date: Wed, 6 Aug 2025 22:12:00 -0400 Subject: [PATCH] lint fixes --- amqtt/contrib/jwt.py | 62 ++++++++++++++++++++++----------------- tests/contrib/test_jwt.py | 17 +++++++---- 2 files changed, 47 insertions(+), 32 deletions(-) diff --git a/amqtt/contrib/jwt.py b/amqtt/contrib/jwt.py index dcdaeb9..43de890 100644 --- a/amqtt/contrib/jwt.py +++ b/amqtt/contrib/jwt.py @@ -1,9 +1,8 @@ -import logging from dataclasses import dataclass +import logging from typing import ClassVar -from amqtt.broker import BrokerContext -from amqtt.plugins import TopicMatcher +import jwt try: from enum import StrEnum @@ -13,41 +12,43 @@ except ImportError: class StrEnum(str, Enum): #type: ignore[no-redef] pass - +from amqtt.broker import BrokerContext from amqtt.contexts import Action +from amqtt.plugins import TopicMatcher from amqtt.plugins.base import BaseAuthPlugin, BaseTopicPlugin from amqtt.session import Session -import jwt - logger = logging.getLogger(__name__) class Algorithms(StrEnum): - ES256 = 'ES256' - ES256K = 'ES256K' - ES384 = 'ES384' - ES512 = 'ES512' - ES521 = 'ES521' - EdDSA = 'EdDSA' - HS256 = 'HS256' - HS384 = 'HS384' - HS512 = 'HS512' - PS256 = 'PS256' - PS384 = 'PS384' - PS512 = 'PS512' - RS256 = 'RS256' - RS384 = 'RS384' - RS512 = 'RS512' + ES256 = "ES256" + ES256K = "ES256K" + ES384 = "ES384" + ES512 = "ES512" + ES521 = "ES521" + EdDSA = "EdDSA" + HS256 = "HS256" + HS384 = "HS384" + HS512 = "HS512" + PS256 = "PS256" + PS384 = "PS384" + PS512 = "PS512" + RS256 = "RS256" + RS384 = "RS384" + RS512 = "RS512" class UserAuthJwtPlugin(BaseAuthPlugin): async def authenticate(self, *, session: Session) -> bool | None: + if not session.username or not session.password: + return None + try: decoded_payload = jwt.decode(session.password, self.config.secret_key, algorithms=["HS256"]) - return decoded_payload.get(self.config.user_claim, None) == session.username + return bool(decoded_payload.get(self.config.user_claim, None) == session.username) except jwt.ExpiredSignatureError: logger.debug(f"jwt for '{session.username}' is expired") return False @@ -57,6 +58,8 @@ class UserAuthJwtPlugin(BaseAuthPlugin): @dataclass class Config: + """Configuration for the JWT user authentication.""" + secret_key: str """Secret key to decrypt the token.""" user_claim: str @@ -69,9 +72,9 @@ class UserAuthJwtPlugin(BaseAuthPlugin): class TopicAuthJwtPlugin(BaseTopicPlugin): _topic_jwt_claims: ClassVar = { - Action.PUBLISH: 'publish_claim', - Action.SUBSCRIBE: 'subscribe_claim', - Action.RECEIVE: 'receive_claim', + Action.PUBLISH: "publish_claim", + Action.SUBSCRIBE: "subscribe_claim", + Action.RECEIVE: "receive_claim", } def __init__(self, context: BrokerContext) -> None: @@ -83,11 +86,14 @@ class TopicAuthJwtPlugin(BaseTopicPlugin): self, *, session: Session | None = None, topic: str | None = None, action: Action | None = None ) -> bool | None: - if not any([session, topic, action]): + if not session or not topic or not action: + return None + + if not session.password: return None try: - decoded_payload = jwt.decode(session.password, self.config.secret_key, algorithms=["HS256"]) + decoded_payload = jwt.decode(session.password.encode(), self.config.secret_key, algorithms=["HS256"]) claim = getattr(self.config, self._topic_jwt_claims[action]) return any(self.topic_matcher.is_topic_allowed(topic, a_filter) for a_filter in decoded_payload.get(claim, [])) except jwt.ExpiredSignatureError: @@ -99,6 +105,8 @@ class TopicAuthJwtPlugin(BaseTopicPlugin): @dataclass class Config: + """Configuration for the JWT topic authorization.""" + secret_key: str """Secret key to decrypt the token.""" publish_claim: str diff --git a/tests/contrib/test_jwt.py b/tests/contrib/test_jwt.py index 2a84065..20d37ea 100644 --- a/tests/contrib/test_jwt.py +++ b/tests/contrib/test_jwt.py @@ -1,8 +1,14 @@ import asyncio -import datetime import logging import secrets +try: + from datetime import UTC, datetime, timedelta +except ImportError: + from datetime import datetime, timezone, timedelta + + UTC = timezone.utc + import jwt import pytest @@ -18,9 +24,10 @@ from amqtt.session import Session def secret_key(): return secrets.token_urlsafe(32) + @pytest.mark.parametrize("exp_time, outcome", [ - (datetime.datetime.now(datetime.UTC) + datetime.timedelta(hours=1), True), - (datetime.datetime.now(datetime.UTC) - datetime.timedelta(hours=1), False), + (datetime.now(UTC) + timedelta(hours=1), True), + (datetime.now(UTC) - timedelta(hours=1), False), ]) @pytest.mark.asyncio async def test_user_jwt_plugin(secret_key, exp_time, outcome): @@ -50,7 +57,7 @@ async def test_topic_jwt_plugin(secret_key): payload = { "username": "example_user", - "exp": datetime.datetime.now(datetime.UTC) + datetime.timedelta(hours=1), + "exp": datetime.now(UTC) + timedelta(hours=1), "publish_acl": ['my/topic/#', 'my/+/other'] } @@ -75,7 +82,7 @@ async def test_topic_jwt_plugin(secret_key): async def test_broker_with_jwt_plugin(secret_key, caplog): payload = { "username": "example_user", - "exp": datetime.datetime.now(datetime.UTC) + datetime.timedelta(hours=1), + "exp": datetime.now(UTC) + timedelta(hours=1), "publish_acl": ['my/topic/#', 'my/+/other'], "subscribe_acl": ['my/+/other'], }