kopia lustrzana https://github.com/Yakifo/amqtt
Yakifo/amqtt#260 : adding additional test cases
rodzic
2ac9ba043c
commit
04c0212c46
|
@ -88,8 +88,8 @@ class TopicAuthJwtPlugin(BaseTopicPlugin):
|
|||
|
||||
try:
|
||||
decoded_payload = jwt.decode(session.password, self.config.secret_key, algorithms=["HS256"])
|
||||
return any(self.topic_matcher.is_topic_allowed(topic, a_filter)
|
||||
for a_filter in decoded_payload.get(self._topic_jwt_claims, []))
|
||||
claim = getattr(self.config, self._topic_jwt_claims[action])
|
||||
return any(self.topic_matcher.is_topic_allowed(topic, a_filter) for a_filter in decoded_payload.get(claim, []))
|
||||
except jwt.ExpiredSignatureError:
|
||||
logger.debug(f"jwt for '{session.username}' is expired")
|
||||
return False
|
||||
|
@ -102,9 +102,11 @@ class TopicAuthJwtPlugin(BaseTopicPlugin):
|
|||
secret_key: str
|
||||
"""Secret key to decrypt the token."""
|
||||
publish_claim: str
|
||||
"""Payload key for list of ."""
|
||||
"""Payload key for contains a list of permissible publish topics."""
|
||||
subscribe_claim: str
|
||||
"""Payload key for contains a list of permissible subscribe topics."""
|
||||
receive_claim: str
|
||||
"""Payload key for contains a list of permissible receive topics."""
|
||||
algorithm: str = "HS256"
|
||||
"""Algorithm to use for token encryption: 'ES256', 'ES256K', 'ES384', 'ES512', 'ES521', 'EdDSA', 'HS256',
|
||||
'HS384', 'HS512', 'PS256', 'PS384', 'PS512', 'RS256', 'RS384', 'RS512'"""
|
||||
|
|
|
@ -1,11 +1,16 @@
|
|||
import asyncio
|
||||
import datetime
|
||||
import logging
|
||||
import secrets
|
||||
|
||||
import jwt
|
||||
import pytest
|
||||
|
||||
from amqtt.broker import BrokerContext, Broker
|
||||
from amqtt.contrib.jwt import UserAuthJwtPlugin
|
||||
from amqtt.client import MQTTClient
|
||||
from amqtt.contexts import Action, ListenerConfig, BrokerConfig
|
||||
from amqtt.contrib.jwt import UserAuthJwtPlugin, TopicAuthJwtPlugin
|
||||
from amqtt.mqtt.constants import QOS_0
|
||||
from amqtt.session import Session
|
||||
|
||||
|
||||
|
@ -38,3 +43,77 @@ async def test_user_jwt_plugin(secret_key, exp_time, outcome):
|
|||
s.password = jwt.encode(payload, secret_key, algorithm="HS256")
|
||||
|
||||
assert await jwt_plugin.authenticate(session=s) == outcome, "access should have been granted"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_topic_jwt_plugin(secret_key):
|
||||
|
||||
payload = {
|
||||
"username": "example_user",
|
||||
"exp": datetime.datetime.now(datetime.UTC) + datetime.timedelta(hours=1),
|
||||
"publish_acl": ['my/topic/#', 'my/+/other']
|
||||
}
|
||||
|
||||
ctx = BrokerContext(Broker())
|
||||
ctx.config = TopicAuthJwtPlugin.Config(
|
||||
secret_key=secret_key,
|
||||
publish_claim='publish_acl',
|
||||
subscribe_claim='subscribe_acl',
|
||||
receive_claim='receive_acl'
|
||||
)
|
||||
|
||||
jwt_plugin = TopicAuthJwtPlugin(context=ctx)
|
||||
|
||||
s = Session()
|
||||
s.username = "example_user"
|
||||
s.password = jwt.encode(payload, secret_key, algorithm="HS256")
|
||||
|
||||
assert await jwt_plugin.topic_filtering(session=s, topic="my/topic/one", action=Action.PUBLISH), "access should be granted"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_broker_with_jwt_plugin(secret_key, caplog):
|
||||
payload = {
|
||||
"username": "example_user",
|
||||
"exp": datetime.datetime.now(datetime.UTC) + datetime.timedelta(hours=1),
|
||||
"publish_acl": ['my/topic/#', 'my/+/other'],
|
||||
"subscribe_acl": ['my/+/other'],
|
||||
}
|
||||
username = "example_user"
|
||||
password = jwt.encode(payload, secret_key, algorithm="HS256")
|
||||
|
||||
cfg = BrokerConfig(
|
||||
listeners={'default': ListenerConfig()},
|
||||
plugins={
|
||||
'amqtt.contrib.jwt.UserAuthJwtPlugin': {
|
||||
'secret_key': secret_key,
|
||||
'user_claim': 'username',
|
||||
},
|
||||
'amqtt.contrib.jwt.TopicAuthJwtPlugin': {
|
||||
'secret_key': secret_key,
|
||||
'publish_claim': 'publish_acl',
|
||||
'subscribe_claim': 'subscribe_acl',
|
||||
'receive_claim': 'receive_acl'
|
||||
}
|
||||
}
|
||||
)
|
||||
with caplog.at_level(logging.INFO):
|
||||
b = Broker(config=cfg)
|
||||
await b.start()
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
c = MQTTClient()
|
||||
await c.connect(f'mqtt://{username}:{password}@localhost:1883')
|
||||
await asyncio.sleep(0.1)
|
||||
result = await c.subscribe([('my/one', QOS_0)])
|
||||
assert result == [128, ]
|
||||
result = await c.subscribe([('my/one/other', QOS_0)])
|
||||
assert result == [0]
|
||||
await c.publish('my/one', b'message should not get published')
|
||||
await asyncio.sleep(0.1)
|
||||
assert "not allowed to publish to TOPIC my/one" in caplog.text
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
await c.disconnect()
|
||||
await asyncio.sleep(0.1)
|
||||
await b.shutdown()
|
||||
|
|
Ładowanie…
Reference in New Issue