pull/289/head
Andrew Mirsky 2025-08-06 22:12:00 -04:00
rodzic 04c0212c46
commit afad06a4cc
Nie znaleziono w bazie danych klucza dla tego podpisu
ID klucza GPG: A98E67635CDF2C39
2 zmienionych plików z 47 dodań i 32 usunięć

Wyświetl plik

@ -1,9 +1,8 @@
import logging
from dataclasses import dataclass from dataclasses import dataclass
import logging
from typing import ClassVar from typing import ClassVar
from amqtt.broker import BrokerContext import jwt
from amqtt.plugins import TopicMatcher
try: try:
from enum import StrEnum from enum import StrEnum
@ -13,41 +12,43 @@ except ImportError:
class StrEnum(str, Enum): #type: ignore[no-redef] class StrEnum(str, Enum): #type: ignore[no-redef]
pass pass
from amqtt.broker import BrokerContext
from amqtt.contexts import Action from amqtt.contexts import Action
from amqtt.plugins import TopicMatcher
from amqtt.plugins.base import BaseAuthPlugin, BaseTopicPlugin from amqtt.plugins.base import BaseAuthPlugin, BaseTopicPlugin
from amqtt.session import Session from amqtt.session import Session
import jwt
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class Algorithms(StrEnum): class Algorithms(StrEnum):
ES256 = 'ES256' ES256 = "ES256"
ES256K = 'ES256K' ES256K = "ES256K"
ES384 = 'ES384' ES384 = "ES384"
ES512 = 'ES512' ES512 = "ES512"
ES521 = 'ES521' ES521 = "ES521"
EdDSA = 'EdDSA' EdDSA = "EdDSA"
HS256 = 'HS256' HS256 = "HS256"
HS384 = 'HS384' HS384 = "HS384"
HS512 = 'HS512' HS512 = "HS512"
PS256 = 'PS256' PS256 = "PS256"
PS384 = 'PS384' PS384 = "PS384"
PS512 = 'PS512' PS512 = "PS512"
RS256 = 'RS256' RS256 = "RS256"
RS384 = 'RS384' RS384 = "RS384"
RS512 = 'RS512' RS512 = "RS512"
class UserAuthJwtPlugin(BaseAuthPlugin): class UserAuthJwtPlugin(BaseAuthPlugin):
async def authenticate(self, *, session: Session) -> bool | None: async def authenticate(self, *, session: Session) -> bool | None:
if not session.username or not session.password:
return None
try: try:
decoded_payload = jwt.decode(session.password, self.config.secret_key, algorithms=["HS256"]) decoded_payload = jwt.decode(session.password, self.config.secret_key, algorithms=["HS256"])
return decoded_payload.get(self.config.user_claim, None) == session.username return bool(decoded_payload.get(self.config.user_claim, None) == session.username)
except jwt.ExpiredSignatureError: except jwt.ExpiredSignatureError:
logger.debug(f"jwt for '{session.username}' is expired") logger.debug(f"jwt for '{session.username}' is expired")
return False return False
@ -57,6 +58,8 @@ class UserAuthJwtPlugin(BaseAuthPlugin):
@dataclass @dataclass
class Config: class Config:
"""Configuration for the JWT user authentication."""
secret_key: str secret_key: str
"""Secret key to decrypt the token.""" """Secret key to decrypt the token."""
user_claim: str user_claim: str
@ -69,9 +72,9 @@ class UserAuthJwtPlugin(BaseAuthPlugin):
class TopicAuthJwtPlugin(BaseTopicPlugin): class TopicAuthJwtPlugin(BaseTopicPlugin):
_topic_jwt_claims: ClassVar = { _topic_jwt_claims: ClassVar = {
Action.PUBLISH: 'publish_claim', Action.PUBLISH: "publish_claim",
Action.SUBSCRIBE: 'subscribe_claim', Action.SUBSCRIBE: "subscribe_claim",
Action.RECEIVE: 'receive_claim', Action.RECEIVE: "receive_claim",
} }
def __init__(self, context: BrokerContext) -> None: def __init__(self, context: BrokerContext) -> None:
@ -83,11 +86,14 @@ class TopicAuthJwtPlugin(BaseTopicPlugin):
self, *, session: Session | None = None, topic: str | None = None, action: Action | None = None self, *, session: Session | None = None, topic: str | None = None, action: Action | None = None
) -> bool | None: ) -> bool | None:
if not any([session, topic, action]): if not session or not topic or not action:
return None
if not session.password:
return None return None
try: try:
decoded_payload = jwt.decode(session.password, self.config.secret_key, algorithms=["HS256"]) decoded_payload = jwt.decode(session.password.encode(), self.config.secret_key, algorithms=["HS256"])
claim = getattr(self.config, self._topic_jwt_claims[action]) claim = getattr(self.config, self._topic_jwt_claims[action])
return any(self.topic_matcher.is_topic_allowed(topic, a_filter) for a_filter in decoded_payload.get(claim, [])) return any(self.topic_matcher.is_topic_allowed(topic, a_filter) for a_filter in decoded_payload.get(claim, []))
except jwt.ExpiredSignatureError: except jwt.ExpiredSignatureError:
@ -99,6 +105,8 @@ class TopicAuthJwtPlugin(BaseTopicPlugin):
@dataclass @dataclass
class Config: class Config:
"""Configuration for the JWT topic authorization."""
secret_key: str secret_key: str
"""Secret key to decrypt the token.""" """Secret key to decrypt the token."""
publish_claim: str publish_claim: str

Wyświetl plik

@ -1,8 +1,14 @@
import asyncio import asyncio
import datetime
import logging import logging
import secrets import secrets
try:
from datetime import UTC, datetime, timedelta
except ImportError:
from datetime import datetime, timezone, timedelta
UTC = timezone.utc
import jwt import jwt
import pytest import pytest
@ -18,9 +24,10 @@ from amqtt.session import Session
def secret_key(): def secret_key():
return secrets.token_urlsafe(32) return secrets.token_urlsafe(32)
@pytest.mark.parametrize("exp_time, outcome", [ @pytest.mark.parametrize("exp_time, outcome", [
(datetime.datetime.now(datetime.UTC) + datetime.timedelta(hours=1), True), (datetime.now(UTC) + timedelta(hours=1), True),
(datetime.datetime.now(datetime.UTC) - datetime.timedelta(hours=1), False), (datetime.now(UTC) - timedelta(hours=1), False),
]) ])
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_user_jwt_plugin(secret_key, exp_time, outcome): async def test_user_jwt_plugin(secret_key, exp_time, outcome):
@ -50,7 +57,7 @@ async def test_topic_jwt_plugin(secret_key):
payload = { payload = {
"username": "example_user", "username": "example_user",
"exp": datetime.datetime.now(datetime.UTC) + datetime.timedelta(hours=1), "exp": datetime.now(UTC) + timedelta(hours=1),
"publish_acl": ['my/topic/#', 'my/+/other'] "publish_acl": ['my/topic/#', 'my/+/other']
} }
@ -75,7 +82,7 @@ async def test_topic_jwt_plugin(secret_key):
async def test_broker_with_jwt_plugin(secret_key, caplog): async def test_broker_with_jwt_plugin(secret_key, caplog):
payload = { payload = {
"username": "example_user", "username": "example_user",
"exp": datetime.datetime.now(datetime.UTC) + datetime.timedelta(hours=1), "exp": datetime.now(UTC) + timedelta(hours=1),
"publish_acl": ['my/topic/#', 'my/+/other'], "publish_acl": ['my/topic/#', 'my/+/other'],
"subscribe_acl": ['my/+/other'], "subscribe_acl": ['my/+/other'],
} }