kopia lustrzana https://github.com/Yakifo/amqtt
lint fixes
rodzic
04c0212c46
commit
afad06a4cc
|
@ -1,9 +1,8 @@
|
|||
import logging
|
||||
from dataclasses import dataclass
|
||||
import logging
|
||||
from typing import ClassVar
|
||||
|
||||
from amqtt.broker import BrokerContext
|
||||
from amqtt.plugins import TopicMatcher
|
||||
import jwt
|
||||
|
||||
try:
|
||||
from enum import StrEnum
|
||||
|
@ -13,41 +12,43 @@ except ImportError:
|
|||
class StrEnum(str, Enum): #type: ignore[no-redef]
|
||||
pass
|
||||
|
||||
|
||||
from amqtt.broker import BrokerContext
|
||||
from amqtt.contexts import Action
|
||||
from amqtt.plugins import TopicMatcher
|
||||
from amqtt.plugins.base import BaseAuthPlugin, BaseTopicPlugin
|
||||
from amqtt.session import Session
|
||||
|
||||
import jwt
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Algorithms(StrEnum):
|
||||
ES256 = 'ES256'
|
||||
ES256K = 'ES256K'
|
||||
ES384 = 'ES384'
|
||||
ES512 = 'ES512'
|
||||
ES521 = 'ES521'
|
||||
EdDSA = 'EdDSA'
|
||||
HS256 = 'HS256'
|
||||
HS384 = 'HS384'
|
||||
HS512 = 'HS512'
|
||||
PS256 = 'PS256'
|
||||
PS384 = 'PS384'
|
||||
PS512 = 'PS512'
|
||||
RS256 = 'RS256'
|
||||
RS384 = 'RS384'
|
||||
RS512 = 'RS512'
|
||||
ES256 = "ES256"
|
||||
ES256K = "ES256K"
|
||||
ES384 = "ES384"
|
||||
ES512 = "ES512"
|
||||
ES521 = "ES521"
|
||||
EdDSA = "EdDSA"
|
||||
HS256 = "HS256"
|
||||
HS384 = "HS384"
|
||||
HS512 = "HS512"
|
||||
PS256 = "PS256"
|
||||
PS384 = "PS384"
|
||||
PS512 = "PS512"
|
||||
RS256 = "RS256"
|
||||
RS384 = "RS384"
|
||||
RS512 = "RS512"
|
||||
|
||||
|
||||
class UserAuthJwtPlugin(BaseAuthPlugin):
|
||||
|
||||
async def authenticate(self, *, session: Session) -> bool | None:
|
||||
|
||||
if not session.username or not session.password:
|
||||
return None
|
||||
|
||||
try:
|
||||
decoded_payload = jwt.decode(session.password, self.config.secret_key, algorithms=["HS256"])
|
||||
return decoded_payload.get(self.config.user_claim, None) == session.username
|
||||
return bool(decoded_payload.get(self.config.user_claim, None) == session.username)
|
||||
except jwt.ExpiredSignatureError:
|
||||
logger.debug(f"jwt for '{session.username}' is expired")
|
||||
return False
|
||||
|
@ -57,6 +58,8 @@ class UserAuthJwtPlugin(BaseAuthPlugin):
|
|||
|
||||
@dataclass
|
||||
class Config:
|
||||
"""Configuration for the JWT user authentication."""
|
||||
|
||||
secret_key: str
|
||||
"""Secret key to decrypt the token."""
|
||||
user_claim: str
|
||||
|
@ -69,9 +72,9 @@ class UserAuthJwtPlugin(BaseAuthPlugin):
|
|||
class TopicAuthJwtPlugin(BaseTopicPlugin):
|
||||
|
||||
_topic_jwt_claims: ClassVar = {
|
||||
Action.PUBLISH: 'publish_claim',
|
||||
Action.SUBSCRIBE: 'subscribe_claim',
|
||||
Action.RECEIVE: 'receive_claim',
|
||||
Action.PUBLISH: "publish_claim",
|
||||
Action.SUBSCRIBE: "subscribe_claim",
|
||||
Action.RECEIVE: "receive_claim",
|
||||
}
|
||||
|
||||
def __init__(self, context: BrokerContext) -> None:
|
||||
|
@ -83,11 +86,14 @@ class TopicAuthJwtPlugin(BaseTopicPlugin):
|
|||
self, *, session: Session | None = None, topic: str | None = None, action: Action | None = None
|
||||
) -> bool | None:
|
||||
|
||||
if not any([session, topic, action]):
|
||||
if not session or not topic or not action:
|
||||
return None
|
||||
|
||||
if not session.password:
|
||||
return None
|
||||
|
||||
try:
|
||||
decoded_payload = jwt.decode(session.password, self.config.secret_key, algorithms=["HS256"])
|
||||
decoded_payload = jwt.decode(session.password.encode(), self.config.secret_key, algorithms=["HS256"])
|
||||
claim = getattr(self.config, self._topic_jwt_claims[action])
|
||||
return any(self.topic_matcher.is_topic_allowed(topic, a_filter) for a_filter in decoded_payload.get(claim, []))
|
||||
except jwt.ExpiredSignatureError:
|
||||
|
@ -99,6 +105,8 @@ class TopicAuthJwtPlugin(BaseTopicPlugin):
|
|||
|
||||
@dataclass
|
||||
class Config:
|
||||
"""Configuration for the JWT topic authorization."""
|
||||
|
||||
secret_key: str
|
||||
"""Secret key to decrypt the token."""
|
||||
publish_claim: str
|
||||
|
|
|
@ -1,8 +1,14 @@
|
|||
import asyncio
|
||||
import datetime
|
||||
import logging
|
||||
import secrets
|
||||
|
||||
try:
|
||||
from datetime import UTC, datetime, timedelta
|
||||
except ImportError:
|
||||
from datetime import datetime, timezone, timedelta
|
||||
|
||||
UTC = timezone.utc
|
||||
|
||||
import jwt
|
||||
import pytest
|
||||
|
||||
|
@ -18,9 +24,10 @@ from amqtt.session import Session
|
|||
def secret_key():
|
||||
return secrets.token_urlsafe(32)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("exp_time, outcome", [
|
||||
(datetime.datetime.now(datetime.UTC) + datetime.timedelta(hours=1), True),
|
||||
(datetime.datetime.now(datetime.UTC) - datetime.timedelta(hours=1), False),
|
||||
(datetime.now(UTC) + timedelta(hours=1), True),
|
||||
(datetime.now(UTC) - timedelta(hours=1), False),
|
||||
])
|
||||
@pytest.mark.asyncio
|
||||
async def test_user_jwt_plugin(secret_key, exp_time, outcome):
|
||||
|
@ -50,7 +57,7 @@ async def test_topic_jwt_plugin(secret_key):
|
|||
|
||||
payload = {
|
||||
"username": "example_user",
|
||||
"exp": datetime.datetime.now(datetime.UTC) + datetime.timedelta(hours=1),
|
||||
"exp": datetime.now(UTC) + timedelta(hours=1),
|
||||
"publish_acl": ['my/topic/#', 'my/+/other']
|
||||
}
|
||||
|
||||
|
@ -75,7 +82,7 @@ async def test_topic_jwt_plugin(secret_key):
|
|||
async def test_broker_with_jwt_plugin(secret_key, caplog):
|
||||
payload = {
|
||||
"username": "example_user",
|
||||
"exp": datetime.datetime.now(datetime.UTC) + datetime.timedelta(hours=1),
|
||||
"exp": datetime.now(UTC) + timedelta(hours=1),
|
||||
"publish_acl": ['my/topic/#', 'my/+/other'],
|
||||
"subscribe_acl": ['my/+/other'],
|
||||
}
|
||||
|
|
Ładowanie…
Reference in New Issue