kopia lustrzana https://github.com/Yakifo/amqtt
lint fixes
rodzic
04c0212c46
commit
afad06a4cc
|
@ -1,9 +1,8 @@
|
||||||
import logging
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
import logging
|
||||||
from typing import ClassVar
|
from typing import ClassVar
|
||||||
|
|
||||||
from amqtt.broker import BrokerContext
|
import jwt
|
||||||
from amqtt.plugins import TopicMatcher
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from enum import StrEnum
|
from enum import StrEnum
|
||||||
|
@ -13,41 +12,43 @@ except ImportError:
|
||||||
class StrEnum(str, Enum): #type: ignore[no-redef]
|
class StrEnum(str, Enum): #type: ignore[no-redef]
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
from amqtt.broker import BrokerContext
|
||||||
from amqtt.contexts import Action
|
from amqtt.contexts import Action
|
||||||
|
from amqtt.plugins import TopicMatcher
|
||||||
from amqtt.plugins.base import BaseAuthPlugin, BaseTopicPlugin
|
from amqtt.plugins.base import BaseAuthPlugin, BaseTopicPlugin
|
||||||
from amqtt.session import Session
|
from amqtt.session import Session
|
||||||
|
|
||||||
import jwt
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class Algorithms(StrEnum):
|
class Algorithms(StrEnum):
|
||||||
ES256 = 'ES256'
|
ES256 = "ES256"
|
||||||
ES256K = 'ES256K'
|
ES256K = "ES256K"
|
||||||
ES384 = 'ES384'
|
ES384 = "ES384"
|
||||||
ES512 = 'ES512'
|
ES512 = "ES512"
|
||||||
ES521 = 'ES521'
|
ES521 = "ES521"
|
||||||
EdDSA = 'EdDSA'
|
EdDSA = "EdDSA"
|
||||||
HS256 = 'HS256'
|
HS256 = "HS256"
|
||||||
HS384 = 'HS384'
|
HS384 = "HS384"
|
||||||
HS512 = 'HS512'
|
HS512 = "HS512"
|
||||||
PS256 = 'PS256'
|
PS256 = "PS256"
|
||||||
PS384 = 'PS384'
|
PS384 = "PS384"
|
||||||
PS512 = 'PS512'
|
PS512 = "PS512"
|
||||||
RS256 = 'RS256'
|
RS256 = "RS256"
|
||||||
RS384 = 'RS384'
|
RS384 = "RS384"
|
||||||
RS512 = 'RS512'
|
RS512 = "RS512"
|
||||||
|
|
||||||
|
|
||||||
class UserAuthJwtPlugin(BaseAuthPlugin):
|
class UserAuthJwtPlugin(BaseAuthPlugin):
|
||||||
|
|
||||||
async def authenticate(self, *, session: Session) -> bool | None:
|
async def authenticate(self, *, session: Session) -> bool | None:
|
||||||
|
|
||||||
|
if not session.username or not session.password:
|
||||||
|
return None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
decoded_payload = jwt.decode(session.password, self.config.secret_key, algorithms=["HS256"])
|
decoded_payload = jwt.decode(session.password, self.config.secret_key, algorithms=["HS256"])
|
||||||
return decoded_payload.get(self.config.user_claim, None) == session.username
|
return bool(decoded_payload.get(self.config.user_claim, None) == session.username)
|
||||||
except jwt.ExpiredSignatureError:
|
except jwt.ExpiredSignatureError:
|
||||||
logger.debug(f"jwt for '{session.username}' is expired")
|
logger.debug(f"jwt for '{session.username}' is expired")
|
||||||
return False
|
return False
|
||||||
|
@ -57,6 +58,8 @@ class UserAuthJwtPlugin(BaseAuthPlugin):
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Config:
|
class Config:
|
||||||
|
"""Configuration for the JWT user authentication."""
|
||||||
|
|
||||||
secret_key: str
|
secret_key: str
|
||||||
"""Secret key to decrypt the token."""
|
"""Secret key to decrypt the token."""
|
||||||
user_claim: str
|
user_claim: str
|
||||||
|
@ -69,9 +72,9 @@ class UserAuthJwtPlugin(BaseAuthPlugin):
|
||||||
class TopicAuthJwtPlugin(BaseTopicPlugin):
|
class TopicAuthJwtPlugin(BaseTopicPlugin):
|
||||||
|
|
||||||
_topic_jwt_claims: ClassVar = {
|
_topic_jwt_claims: ClassVar = {
|
||||||
Action.PUBLISH: 'publish_claim',
|
Action.PUBLISH: "publish_claim",
|
||||||
Action.SUBSCRIBE: 'subscribe_claim',
|
Action.SUBSCRIBE: "subscribe_claim",
|
||||||
Action.RECEIVE: 'receive_claim',
|
Action.RECEIVE: "receive_claim",
|
||||||
}
|
}
|
||||||
|
|
||||||
def __init__(self, context: BrokerContext) -> None:
|
def __init__(self, context: BrokerContext) -> None:
|
||||||
|
@ -83,11 +86,14 @@ class TopicAuthJwtPlugin(BaseTopicPlugin):
|
||||||
self, *, session: Session | None = None, topic: str | None = None, action: Action | None = None
|
self, *, session: Session | None = None, topic: str | None = None, action: Action | None = None
|
||||||
) -> bool | None:
|
) -> bool | None:
|
||||||
|
|
||||||
if not any([session, topic, action]):
|
if not session or not topic or not action:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if not session.password:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
decoded_payload = jwt.decode(session.password, self.config.secret_key, algorithms=["HS256"])
|
decoded_payload = jwt.decode(session.password.encode(), self.config.secret_key, algorithms=["HS256"])
|
||||||
claim = getattr(self.config, self._topic_jwt_claims[action])
|
claim = getattr(self.config, self._topic_jwt_claims[action])
|
||||||
return any(self.topic_matcher.is_topic_allowed(topic, a_filter) for a_filter in decoded_payload.get(claim, []))
|
return any(self.topic_matcher.is_topic_allowed(topic, a_filter) for a_filter in decoded_payload.get(claim, []))
|
||||||
except jwt.ExpiredSignatureError:
|
except jwt.ExpiredSignatureError:
|
||||||
|
@ -99,6 +105,8 @@ class TopicAuthJwtPlugin(BaseTopicPlugin):
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Config:
|
class Config:
|
||||||
|
"""Configuration for the JWT topic authorization."""
|
||||||
|
|
||||||
secret_key: str
|
secret_key: str
|
||||||
"""Secret key to decrypt the token."""
|
"""Secret key to decrypt the token."""
|
||||||
publish_claim: str
|
publish_claim: str
|
||||||
|
|
|
@ -1,8 +1,14 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
import datetime
|
|
||||||
import logging
|
import logging
|
||||||
import secrets
|
import secrets
|
||||||
|
|
||||||
|
try:
|
||||||
|
from datetime import UTC, datetime, timedelta
|
||||||
|
except ImportError:
|
||||||
|
from datetime import datetime, timezone, timedelta
|
||||||
|
|
||||||
|
UTC = timezone.utc
|
||||||
|
|
||||||
import jwt
|
import jwt
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
@ -18,9 +24,10 @@ from amqtt.session import Session
|
||||||
def secret_key():
|
def secret_key():
|
||||||
return secrets.token_urlsafe(32)
|
return secrets.token_urlsafe(32)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("exp_time, outcome", [
|
@pytest.mark.parametrize("exp_time, outcome", [
|
||||||
(datetime.datetime.now(datetime.UTC) + datetime.timedelta(hours=1), True),
|
(datetime.now(UTC) + timedelta(hours=1), True),
|
||||||
(datetime.datetime.now(datetime.UTC) - datetime.timedelta(hours=1), False),
|
(datetime.now(UTC) - timedelta(hours=1), False),
|
||||||
])
|
])
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_user_jwt_plugin(secret_key, exp_time, outcome):
|
async def test_user_jwt_plugin(secret_key, exp_time, outcome):
|
||||||
|
@ -50,7 +57,7 @@ async def test_topic_jwt_plugin(secret_key):
|
||||||
|
|
||||||
payload = {
|
payload = {
|
||||||
"username": "example_user",
|
"username": "example_user",
|
||||||
"exp": datetime.datetime.now(datetime.UTC) + datetime.timedelta(hours=1),
|
"exp": datetime.now(UTC) + timedelta(hours=1),
|
||||||
"publish_acl": ['my/topic/#', 'my/+/other']
|
"publish_acl": ['my/topic/#', 'my/+/other']
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -75,7 +82,7 @@ async def test_topic_jwt_plugin(secret_key):
|
||||||
async def test_broker_with_jwt_plugin(secret_key, caplog):
|
async def test_broker_with_jwt_plugin(secret_key, caplog):
|
||||||
payload = {
|
payload = {
|
||||||
"username": "example_user",
|
"username": "example_user",
|
||||||
"exp": datetime.datetime.now(datetime.UTC) + datetime.timedelta(hours=1),
|
"exp": datetime.now(UTC) + timedelta(hours=1),
|
||||||
"publish_acl": ['my/topic/#', 'my/+/other'],
|
"publish_acl": ['my/topic/#', 'my/+/other'],
|
||||||
"subscribe_acl": ['my/+/other'],
|
"subscribe_acl": ['my/+/other'],
|
||||||
}
|
}
|
||||||
|
|
Ładowanie…
Reference in New Issue