diff --git a/amqtt/contrib/jwt.py b/amqtt/contrib/jwt.py index fe4f728..dcdaeb9 100644 --- a/amqtt/contrib/jwt.py +++ b/amqtt/contrib/jwt.py @@ -88,8 +88,8 @@ class TopicAuthJwtPlugin(BaseTopicPlugin): try: decoded_payload = jwt.decode(session.password, self.config.secret_key, algorithms=["HS256"]) - return any(self.topic_matcher.is_topic_allowed(topic, a_filter) - for a_filter in decoded_payload.get(self._topic_jwt_claims, [])) + claim = getattr(self.config, self._topic_jwt_claims[action]) + return any(self.topic_matcher.is_topic_allowed(topic, a_filter) for a_filter in decoded_payload.get(claim, [])) except jwt.ExpiredSignatureError: logger.debug(f"jwt for '{session.username}' is expired") return False @@ -102,9 +102,11 @@ class TopicAuthJwtPlugin(BaseTopicPlugin): secret_key: str """Secret key to decrypt the token.""" publish_claim: str - """Payload key for list of .""" + """Payload key for contains a list of permissible publish topics.""" subscribe_claim: str + """Payload key for contains a list of permissible subscribe topics.""" receive_claim: str + """Payload key for contains a list of permissible receive topics.""" algorithm: str = "HS256" """Algorithm to use for token encryption: 'ES256', 'ES256K', 'ES384', 'ES512', 'ES521', 'EdDSA', 'HS256', 'HS384', 'HS512', 'PS256', 'PS384', 'PS512', 'RS256', 'RS384', 'RS512'""" diff --git a/tests/contrib/test_jwt.py b/tests/contrib/test_jwt.py index ec40a52..2a84065 100644 --- a/tests/contrib/test_jwt.py +++ b/tests/contrib/test_jwt.py @@ -1,11 +1,16 @@ +import asyncio import datetime +import logging import secrets import jwt import pytest from amqtt.broker import BrokerContext, Broker -from amqtt.contrib.jwt import UserAuthJwtPlugin +from amqtt.client import MQTTClient +from amqtt.contexts import Action, ListenerConfig, BrokerConfig +from amqtt.contrib.jwt import UserAuthJwtPlugin, TopicAuthJwtPlugin +from amqtt.mqtt.constants import QOS_0 from amqtt.session import Session @@ -38,3 +43,77 @@ async def test_user_jwt_plugin(secret_key, exp_time, outcome): s.password = jwt.encode(payload, secret_key, algorithm="HS256") assert await jwt_plugin.authenticate(session=s) == outcome, "access should have been granted" + + +@pytest.mark.asyncio +async def test_topic_jwt_plugin(secret_key): + + payload = { + "username": "example_user", + "exp": datetime.datetime.now(datetime.UTC) + datetime.timedelta(hours=1), + "publish_acl": ['my/topic/#', 'my/+/other'] + } + + ctx = BrokerContext(Broker()) + ctx.config = TopicAuthJwtPlugin.Config( + secret_key=secret_key, + publish_claim='publish_acl', + subscribe_claim='subscribe_acl', + receive_claim='receive_acl' + ) + + jwt_plugin = TopicAuthJwtPlugin(context=ctx) + + s = Session() + s.username = "example_user" + s.password = jwt.encode(payload, secret_key, algorithm="HS256") + + assert await jwt_plugin.topic_filtering(session=s, topic="my/topic/one", action=Action.PUBLISH), "access should be granted" + + +@pytest.mark.asyncio +async def test_broker_with_jwt_plugin(secret_key, caplog): + payload = { + "username": "example_user", + "exp": datetime.datetime.now(datetime.UTC) + datetime.timedelta(hours=1), + "publish_acl": ['my/topic/#', 'my/+/other'], + "subscribe_acl": ['my/+/other'], + } + username = "example_user" + password = jwt.encode(payload, secret_key, algorithm="HS256") + + cfg = BrokerConfig( + listeners={'default': ListenerConfig()}, + plugins={ + 'amqtt.contrib.jwt.UserAuthJwtPlugin': { + 'secret_key': secret_key, + 'user_claim': 'username', + }, + 'amqtt.contrib.jwt.TopicAuthJwtPlugin': { + 'secret_key': secret_key, + 'publish_claim': 'publish_acl', + 'subscribe_claim': 'subscribe_acl', + 'receive_claim': 'receive_acl' + } + } + ) + with caplog.at_level(logging.INFO): + b = Broker(config=cfg) + await b.start() + await asyncio.sleep(0.1) + + c = MQTTClient() + await c.connect(f'mqtt://{username}:{password}@localhost:1883') + await asyncio.sleep(0.1) + result = await c.subscribe([('my/one', QOS_0)]) + assert result == [128, ] + result = await c.subscribe([('my/one/other', QOS_0)]) + assert result == [0] + await c.publish('my/one', b'message should not get published') + await asyncio.sleep(0.1) + assert "not allowed to publish to TOPIC my/one" in caplog.text + await asyncio.sleep(0.1) + + await c.disconnect() + await asyncio.sleep(0.1) + await b.shutdown()