pull/289/head
Andrew Mirsky 2025-08-06 22:12:00 -04:00
rodzic 04c0212c46
commit afad06a4cc
Nie znaleziono w bazie danych klucza dla tego podpisu
ID klucza GPG: A98E67635CDF2C39
2 zmienionych plików z 47 dodań i 32 usunięć

Wyświetl plik

@ -1,9 +1,8 @@
import logging
from dataclasses import dataclass
import logging
from typing import ClassVar
from amqtt.broker import BrokerContext
from amqtt.plugins import TopicMatcher
import jwt
try:
from enum import StrEnum
@ -13,41 +12,43 @@ except ImportError:
class StrEnum(str, Enum): #type: ignore[no-redef]
pass
from amqtt.broker import BrokerContext
from amqtt.contexts import Action
from amqtt.plugins import TopicMatcher
from amqtt.plugins.base import BaseAuthPlugin, BaseTopicPlugin
from amqtt.session import Session
import jwt
logger = logging.getLogger(__name__)
class Algorithms(StrEnum):
ES256 = 'ES256'
ES256K = 'ES256K'
ES384 = 'ES384'
ES512 = 'ES512'
ES521 = 'ES521'
EdDSA = 'EdDSA'
HS256 = 'HS256'
HS384 = 'HS384'
HS512 = 'HS512'
PS256 = 'PS256'
PS384 = 'PS384'
PS512 = 'PS512'
RS256 = 'RS256'
RS384 = 'RS384'
RS512 = 'RS512'
ES256 = "ES256"
ES256K = "ES256K"
ES384 = "ES384"
ES512 = "ES512"
ES521 = "ES521"
EdDSA = "EdDSA"
HS256 = "HS256"
HS384 = "HS384"
HS512 = "HS512"
PS256 = "PS256"
PS384 = "PS384"
PS512 = "PS512"
RS256 = "RS256"
RS384 = "RS384"
RS512 = "RS512"
class UserAuthJwtPlugin(BaseAuthPlugin):
async def authenticate(self, *, session: Session) -> bool | None:
if not session.username or not session.password:
return None
try:
decoded_payload = jwt.decode(session.password, self.config.secret_key, algorithms=["HS256"])
return decoded_payload.get(self.config.user_claim, None) == session.username
return bool(decoded_payload.get(self.config.user_claim, None) == session.username)
except jwt.ExpiredSignatureError:
logger.debug(f"jwt for '{session.username}' is expired")
return False
@ -57,6 +58,8 @@ class UserAuthJwtPlugin(BaseAuthPlugin):
@dataclass
class Config:
"""Configuration for the JWT user authentication."""
secret_key: str
"""Secret key to decrypt the token."""
user_claim: str
@ -69,9 +72,9 @@ class UserAuthJwtPlugin(BaseAuthPlugin):
class TopicAuthJwtPlugin(BaseTopicPlugin):
_topic_jwt_claims: ClassVar = {
Action.PUBLISH: 'publish_claim',
Action.SUBSCRIBE: 'subscribe_claim',
Action.RECEIVE: 'receive_claim',
Action.PUBLISH: "publish_claim",
Action.SUBSCRIBE: "subscribe_claim",
Action.RECEIVE: "receive_claim",
}
def __init__(self, context: BrokerContext) -> None:
@ -83,11 +86,14 @@ class TopicAuthJwtPlugin(BaseTopicPlugin):
self, *, session: Session | None = None, topic: str | None = None, action: Action | None = None
) -> bool | None:
if not any([session, topic, action]):
if not session or not topic or not action:
return None
if not session.password:
return None
try:
decoded_payload = jwt.decode(session.password, self.config.secret_key, algorithms=["HS256"])
decoded_payload = jwt.decode(session.password.encode(), self.config.secret_key, algorithms=["HS256"])
claim = getattr(self.config, self._topic_jwt_claims[action])
return any(self.topic_matcher.is_topic_allowed(topic, a_filter) for a_filter in decoded_payload.get(claim, []))
except jwt.ExpiredSignatureError:
@ -99,6 +105,8 @@ class TopicAuthJwtPlugin(BaseTopicPlugin):
@dataclass
class Config:
"""Configuration for the JWT topic authorization."""
secret_key: str
"""Secret key to decrypt the token."""
publish_claim: str

Wyświetl plik

@ -1,8 +1,14 @@
import asyncio
import datetime
import logging
import secrets
try:
from datetime import UTC, datetime, timedelta
except ImportError:
from datetime import datetime, timezone, timedelta
UTC = timezone.utc
import jwt
import pytest
@ -18,9 +24,10 @@ from amqtt.session import Session
def secret_key():
return secrets.token_urlsafe(32)
@pytest.mark.parametrize("exp_time, outcome", [
(datetime.datetime.now(datetime.UTC) + datetime.timedelta(hours=1), True),
(datetime.datetime.now(datetime.UTC) - datetime.timedelta(hours=1), False),
(datetime.now(UTC) + timedelta(hours=1), True),
(datetime.now(UTC) - timedelta(hours=1), False),
])
@pytest.mark.asyncio
async def test_user_jwt_plugin(secret_key, exp_time, outcome):
@ -50,7 +57,7 @@ async def test_topic_jwt_plugin(secret_key):
payload = {
"username": "example_user",
"exp": datetime.datetime.now(datetime.UTC) + datetime.timedelta(hours=1),
"exp": datetime.now(UTC) + timedelta(hours=1),
"publish_acl": ['my/topic/#', 'my/+/other']
}
@ -75,7 +82,7 @@ async def test_topic_jwt_plugin(secret_key):
async def test_broker_with_jwt_plugin(secret_key, caplog):
payload = {
"username": "example_user",
"exp": datetime.datetime.now(datetime.UTC) + datetime.timedelta(hours=1),
"exp": datetime.now(UTC) + timedelta(hours=1),
"publish_acl": ['my/topic/#', 'my/+/other'],
"subscribe_acl": ['my/+/other'],
}