From 0575f0e04167b4ea563bd90ce7a6cd34a54f84c4 Mon Sep 17 00:00:00 2001 From: Andrew Mirsky Date: Tue, 8 Jul 2025 22:05:58 -0400 Subject: [PATCH] Yakifo/amqtt#258 (in progress): a plugin which requests topic acl via http --- amqtt/broker.py | 8 ++- amqtt/contexts.py | 1 + amqtt/contrib/http_acl.py | 18 ++++- amqtt/plugins/base.py | 4 +- amqtt/plugins/topic_checking.py | 32 +++++++-- docs/custom_plugins.md | 23 ++++-- docs/packaged_plugins.md | 13 ++-- tests/contrib/test_http_acl.py | 124 ++++++++++++++++++++++++++++---- 8 files changed, 188 insertions(+), 35 deletions(-) diff --git a/amqtt/broker.py b/amqtt/broker.py index e9740f6..b0c835a 100644 --- a/amqtt/broker.py +++ b/amqtt/broker.py @@ -681,7 +681,7 @@ class Broker: permitted = await self._topic_filtering(client_session, topic=app_message.topic, action=Action.PUBLISH) if not permitted: - self.logger.info(f"{client_session.client_id} forbidden TOPIC {app_message.topic} sent in PUBLISH message.") + self.logger.info(f"{client_session.client_id} not allowed to publish to TOPIC {app_message.topic}.") else: await self.plugins_manager.fire_event( BrokerEvents.MESSAGE_RECEIVED, @@ -892,6 +892,12 @@ class Broker: for target_session, sub_qos in subscriptions: qos = broadcast.get("qos", sub_qos) + sendable = await self._topic_filtering(target_session, topic=broadcast["topic"], action=Action.RECEIVE) + if not sendable: + self.logger.info( + f"{target_session.client_id} not allowed to receive messages from TOPIC {broadcast['topic']}.") + continue + # Retain all messages which cannot be broadcasted, due to the session not being connected # but only when clean session is false and qos is 1 or 2 [MQTT 3.1.2.4] # and, if a client used anonymous authentication, there is no expectation that messages should be retained diff --git a/amqtt/contexts.py b/amqtt/contexts.py index 3a78b44..26442d2 100644 --- a/amqtt/contexts.py +++ b/amqtt/contexts.py @@ -20,3 +20,4 @@ class Action(Enum): SUBSCRIBE = "subscribe" PUBLISH = "publish" + RECEIVE = "receive" diff --git a/amqtt/contrib/http_acl.py b/amqtt/contrib/http_acl.py index 7891cd5..302fa5f 100644 --- a/amqtt/contrib/http_acl.py +++ b/amqtt/contrib/http_acl.py @@ -97,8 +97,20 @@ class HttpAuthACL(BaseAuthPlugin, BaseTopicPlugin): async def topic_filtering(self, *, session: Session | None = None, topic: str | None = None, - action: Action | None = None) -> bool: - return False + action: Action | None = None) -> bool | None: + if not session: + return None + acc = 0 + match action: + case Action.PUBLISH: + acc = 2 + case Action.SUBSCRIBE: + acc = 4 + case Action.RECEIVE: + acc = 1 + + d = {"username": session.username, "client_id": session.client_id, "topic": topic, "acc": acc} + return await self._send_request(self.get_url(self.config.acl_uri), d) @dataclass class Config: @@ -128,7 +140,7 @@ class HttpAuthACL(BaseAuthPlugin, BaseTopicPlugin): - username *(str)* - client_id *(str)* - topic *(str)* - - acc *(int)* read only = 1, write only = 2, read & write = 3 and subscribe = 4 + - acc *(int)* client can receive (1), can publish(2), can receive & publish (3) and can subscribe (4) """ host: str diff --git a/amqtt/plugins/base.py b/amqtt/plugins/base.py index 41b951b..eb5c637 100644 --- a/amqtt/plugins/base.py +++ b/amqtt/plugins/base.py @@ -88,7 +88,7 @@ class BaseTopicPlugin(BasePlugin[BaseContext]): async def topic_filtering( self, *, session: Session | None = None, topic: str | None = None, action: Action | None = None - ) -> bool: + ) -> bool | None: """Logic for filtering out topics. Args: @@ -97,7 +97,7 @@ class BaseTopicPlugin(BasePlugin[BaseContext]): action: amqtt.broker.Action Returns: - bool: `True` if topic is allowed, `False` otherwise + bool: `True` if topic is allowed, `False` otherwise. `None` if it can't be determined """ return bool(self.topic_config) or is_dataclass(self.context.config) diff --git a/amqtt/plugins/topic_checking.py b/amqtt/plugins/topic_checking.py index c5183ea..7f06100 100644 --- a/amqtt/plugins/topic_checking.py +++ b/amqtt/plugins/topic_checking.py @@ -1,7 +1,9 @@ from dataclasses import dataclass, field from typing import Any +import warnings from amqtt.contexts import Action, BaseContext +from amqtt.errors import PluginInitError from amqtt.plugins.base import BaseTopicPlugin from amqtt.session import Session @@ -13,7 +15,7 @@ class TopicTabooPlugin(BaseTopicPlugin): async def topic_filtering( self, *, session: Session | None = None, topic: str | None = None, action: Action | None = None - ) -> bool: + ) -> bool | None: filter_result = await super().topic_filtering(session=session, topic=topic, action=action) if filter_result: if session and session.username == "admin": @@ -24,6 +26,16 @@ class TopicTabooPlugin(BaseTopicPlugin): class TopicAccessControlListPlugin(BaseTopicPlugin): + def __init__(self, context: BaseContext) -> None: + super().__init__(context) + + if self._get_config_option("acl", None): + warnings.warn("The 'acl' option is deprecated, please use 'subscribe-acl' instead.", stacklevel=1) + + if self._get_config_option("acl", None) and self._get_config_option("subscribe-acl", None): + msg = "'acl' has been replaced with 'subscribe-acl'; only one may be included" + raise PluginInitError(msg) + @staticmethod def topic_ac(topic_requested: str, topic_allowed: str) -> bool: req_split = topic_requested.split("/") @@ -46,7 +58,7 @@ class TopicAccessControlListPlugin(BaseTopicPlugin): async def topic_filtering( self, *, session: Session | None = None, topic: str | None = None, action: Action | None = None - ) -> bool: + ) -> bool | None: filter_result = await super().topic_filtering(session=session, topic=topic, action=action) if not filter_result: return False @@ -58,18 +70,26 @@ class TopicAccessControlListPlugin(BaseTopicPlugin): req_topic = topic if not req_topic: - return False\ + return False username = session.username if session else None if username is None: username = "anonymous" - acl: dict[str, Any] = {} + acl: dict[str, Any] | None = None match action: case Action.PUBLISH: - acl = self._get_config_option("publish-acl", {}) + acl = self._get_config_option("publish-acl", None) case Action.SUBSCRIBE: - acl = self._get_config_option("acl", {}) + acl = self._get_config_option("subscribe-acl", self._get_config_option("acl", None)) + case Action.RECEIVE: + acl = self._get_config_option("receive-acl", None) + case _: + msg = "Received an invalid action type." + raise ValueError(msg) + + if acl is None: + return True allowed_topics = acl.get(username, []) if not allowed_topics: diff --git a/docs/custom_plugins.md b/docs/custom_plugins.md index 346babb..c8f0c55 100644 --- a/docs/custom_plugins.md +++ b/docs/custom_plugins.md @@ -96,16 +96,31 @@ none In addition to receiving any of the event callbacks, a plugin which subclasses from `BaseAuthPlugin` is used by the aMQTT `Broker` to determine if a connection from a client is allowed by -implementing the `authenticate` method and returning `True` if the session is allowed or `False` otherwise. +implementing the `authenticate` method and returning: +- `True` if the session is allowed +- `False` if not allowed +- `None` if plugin can't determine authentication + +If there are multiple authentication plugins: +- at least one plugin must return `True` to allow access +- `False` from any plugin will deny access (i.e. all plugins must return `True` to allow access) +- `None` gets ignored from the determination ::: amqtt.plugins.base.BaseAuthPlugin ## Topic Filter Plugins In addition to receiving any of the event callbacks, a plugin which is subclassed from `BaseTopicPlugin` -is used by the aMQTT `Broker` to determine if a connected client can send (PUBLISH) or receive (SUBSCRIBE) -messages to a particular topic by implementing the `topic_filtering` method and returning `True` if allowed or -`False` otherwise. +is used by the aMQTT `Broker` to determine if a connected client can send (PUBLISH), receive (RECEIVE) +and/or subscribe (SUBSCRIBE) messages to a particular topic by implementing the `topic_filtering` method and returning: +- `True` if topic is allowed +- `False` if not allowed +- `None` will be ignored + +If there are multiple topic plugins: +- at least one plugin must return `True` to allow access +- `False` from any plugin will deny access (i.e. all plugins must return `True` to allow access) +- `None` will be ignored ::: amqtt.plugins.base.BaseTopicPlugin diff --git a/docs/packaged_plugins.md b/docs/packaged_plugins.md index bc43d3a..0bc23d9 100644 --- a/docs/packaged_plugins.md +++ b/docs/packaged_plugins.md @@ -134,13 +134,18 @@ plugins: **Configuration** -- `acl` *(mapping)*: determines subscription access +- `subscribe-acl` *(mapping)*: determines subscription access. If absent, no restrictions are placed on client subscriptions. The list should be a key-value pair, where: - `:[, , ...]` *(string, list[string])*: username of the client followed by a list of allowed topics (wildcards are supported: `#`, `+`). + `:[, , ...]` *(string, list[string])*: username of the client followed by a list of allowed topics (wildcards are supported: `#`, `+`). +- `acl` *(mapping)*: Deprecated and replaced by `subscribe-acl`. -- `publish-acl` *(mapping)*: determines publish access. If absent, no restrictions are placed on client publishing. - `:[, , ...]` *(string, list[string])*: username of the client followed by a list of allowed topics (wildcards are supported: `#`, `+`). +- `publish-acl` *(mapping)*: determines publish access. If absent, no restrictions are placed on client publishing. + `:[, , ...]` *(string, list[string])*: username of the client followed by a list of allowed topics (wildcards are supported: `#`, `+`). + +- `receive-acl` *(mapping)*: determines if a message can be sent. If absent, no restrictions are placed on client's receiving messages. + `:[, , ...]` *(string, list[string])*: username of the client followed by a list of allowed topics (wildcards are supported: `#`, `+`). + !!! info "Reserved usernames" diff --git a/tests/contrib/test_http_acl.py b/tests/contrib/test_http_acl.py index 5401047..bbaf85e 100644 --- a/tests/contrib/test_http_acl.py +++ b/tests/contrib/test_http_acl.py @@ -9,13 +9,14 @@ from aiohttp import web from aiohttp.web import Response, Request from amqtt.broker import BrokerContext, Broker +from amqtt.contexts import Action from amqtt.contrib.http_acl import HttpAuthACL, ParamsMode, ResponseMode, RequestMethod from amqtt.session import Session logger = logging.getLogger(__name__) -def determine_response_mode(d) -> Response: +def determine_auth_response_mode(d) -> Response: assert 'username' in d assert 'password' in d assert 'client_id' in d @@ -31,29 +32,29 @@ class JsonAuthView(web.View): async def get(self) -> Response: d = await self.request.json() - return determine_response_mode(d) + return determine_auth_response_mode(d) async def post(self) -> Response: d = dict(await self.request.json()) - return determine_response_mode(d) + return determine_auth_response_mode(d) async def put(self) -> Response: d = dict(await self.request.json()) - return determine_response_mode(d) + return determine_auth_response_mode(d) class FormAuthView(web.View): async def get(self) -> Response: d = self.request.query - return determine_response_mode(d) + return determine_auth_response_mode(d) async def post(self) -> Response: d = dict(await self.request.post()) - return determine_response_mode(d) + return determine_auth_response_mode(d) async def put(self) -> Response: d = dict(await self.request.post()) - return determine_response_mode(d) + return determine_auth_response_mode(d) @pytest.fixture @@ -64,7 +65,7 @@ async def empty_broker(): @pytest_asyncio.fixture -async def http_acl_server(): +async def http_auth_server(): app = web.Application() app.add_routes([ web.view('/user/json', JsonAuthView), @@ -80,10 +81,10 @@ async def http_acl_server(): await runner.cleanup() -def test_server_up_and_down(http_acl_server): +def test_server_up_and_down(http_auth_server): pass -def generate_cases(): +def generate_use_cases(root_url): # generate all variations of: # ('/user/json', RequestMethod.GET, ParamsMode.JSON, ResponseMode.JSON, 'json', 'json', True), @@ -91,22 +92,22 @@ def generate_cases(): for request in RequestMethod: for params in ParamsMode: for response in ResponseMode: - url = '/user/json' if params == ParamsMode.JSON else '/user/form' + url = f'/{root_url}/json' if params == ParamsMode.JSON else f'/{root_url}/form' for is_authenticated in [True, False]: prefix = '' if is_authenticated else 'not' case = (url, request, params, response, str(response), f"{prefix}{str(response)}", is_authenticated) cases.append(case) return cases -def test_generated_cases(): - cases = generate_cases() +def test_generated_use_cases(): + cases = generate_use_cases('user') assert len(cases) == 36 @pytest.mark.parametrize("url,request_method,params_mode,response_mode,username,password,is_authenticated", - generate_cases()) + generate_use_cases('user')) @pytest.mark.asyncio -async def test_request_params_response(empty_broker, http_acl_server, url, +async def test_request_auth_response(empty_broker, http_auth_server, url, request_method, params_mode, response_mode, username, password, is_authenticated): @@ -129,3 +130,96 @@ async def test_request_params_response(empty_broker, http_acl_server, url, assert await http_acl.authenticate(session=session) == is_authenticated await http_acl.on_broker_pre_shutdown() + + +def determine_acl_response(d) -> Response: + assert 'username' in d + assert 'client_id' in d + assert 'topic' in d + assert 'acc' in d + if d['username'] == 'json': + return web.json_response({'Ok': d['username'] == d['client_id']}) + elif d['username'] == 'status': + return web.Response(status=200) if d['username'] == d['client_id'] else web.Response(status=400) + else: # text + return web.Response(text='ok' if d['username'] == d['client_id'] else 'error') + + +class JsonACLView(web.View): + + async def get(self) -> Response: + d = await self.request.json() + return determine_acl_response(d) + + async def post(self) -> Response: + d = dict(await self.request.json()) + return determine_acl_response(d) + + async def put(self) -> Response: + d = dict(await self.request.json()) + return determine_acl_response(d) + + +class FormACLView(web.View): + + async def get(self) -> Response: + d = self.request.query + return determine_acl_response(d) + + async def post(self) -> Response: + d = dict(await self.request.post()) + return determine_acl_response(d) + + async def put(self) -> Response: + d = dict(await self.request.post()) + return determine_acl_response(d) + + +@pytest_asyncio.fixture +async def http_acl_server(): + app = web.Application() + app.add_routes([ + web.view('/acl/json', JsonACLView), + web.view('/acl/form', FormACLView), + ]) + runner = web.AppRunner(app) + await runner.setup() + site = web.TCPSite(runner, "localhost", 8080) + await site.start() + + yield f"http://localhost:8080" + + await runner.cleanup() + + +@pytest.mark.parametrize("url,request_method,params_mode,response_mode,username,client_id,is_authenticated", + generate_use_cases('acl')) +@pytest.mark.asyncio +async def test_request_acl_response(empty_broker, http_acl_server, url, + request_method, params_mode, response_mode, + username, client_id, is_authenticated): + + # url = '/acl/json' + # request_method = RequestMethod.GET + # params_mode = ParamsMode.JSON + # response_mode = ResponseMode.JSON + + context = BrokerContext(broker=empty_broker) + context.config = HttpAuthACL.Config( + host="localhost", + port=8080, + user_uri='/user', + acl_uri=url, + request_method=request_method, + params_mode=params_mode, + response_mode=response_mode, + ) + http_acl = HttpAuthACL(context) + + s = Session() + s.username = username + s.client_id = client_id + t = 'my/topic' + a = Action.PUBLISH + + assert await http_acl.topic_filtering(session=s, topic=t, action=a) == is_authenticated \ No newline at end of file