# amqtt/tests/contrib/test_db_plugin.py
#
# 505 lines
# 17 KiB
# Python

import asyncio
import sqlite3
import tempfile
from pathlib import Path
import pytest
import aiosqlite
from passlib.context import CryptContext
from amqtt.broker import BrokerContext, Broker
from amqtt.client import MQTTClient
from amqtt.contexts import Action
from amqtt.contrib.auth_db.models import AllowedTopic, PasswordHasher
from amqtt.contrib.auth_db.plugin import UserAuthDBPlugin, TopicAuthDBPlugin
from amqtt.contrib.auth_db.managers import UserManager, TopicManager
from amqtt.errors import ConnectError, MQTTError
from amqtt.mqtt.constants import QOS_1, QOS_0
from amqtt.session import Session
from argon2 import PasswordHasher as ArgonPasswordHasher
from argon2.exceptions import VerifyMismatchError
@pytest.fixture
def password_hasher():
    """Yield a ``PasswordHasher`` whose crypt context only uses the argon2 scheme."""
    hasher = PasswordHasher()
    argon2_only = CryptContext(schemes=["argon2"], deprecated="auto")
    hasher.crypt_context = argon2_only
    yield hasher
@pytest.fixture
def db_file():
    """Yield the path of a not-yet-created database file in a throwaway directory.

    The path lives inside a ``TemporaryDirectory`` so whatever the test writes
    there is removed when the fixture is torn down.  The previous version
    joined the directory with the *absolute* name of a ``NamedTemporaryFile``;
    joining with an absolute path discards the directory part, so the database
    was created outside the managed directory and never cleaned up.
    """
    with tempfile.TemporaryDirectory() as temp_dir:
        # A fixed file name is enough: every test gets its own directory.
        yield Path(temp_dir) / "amqtt_test.db"
@pytest.fixture
def db_connection(db_file):
    """Yield the ``sqlite+aiosqlite:///`` connection URL pointing at ``db_file``."""
    yield f"sqlite+aiosqlite:///{db_file}"
@pytest.fixture
async def user_manager(password_hasher, db_connection):
    """Yield a ``UserManager`` bound to the test database after ``db_sync()`` has run.

    ``password_hasher`` is requested only so that fixture runs first; its
    value is not used here.

    The ``@pytest.mark.asyncio`` decorator was dropped: marks applied to
    fixture functions have no effect and pytest has deprecated the practice.
    """
    manager = UserManager(db_connection)
    await manager.db_sync()
    yield manager
@pytest.fixture
async def topic_manager(password_hasher, db_connection):
    """Yield a ``TopicManager`` bound to the test database after ``db_sync()`` has run.

    ``password_hasher`` is requested only so that fixture runs first; its
    value is not used here.

    The ``@pytest.mark.asyncio`` decorator was dropped: marks applied to
    fixture functions have no effect and pytest has deprecated the practice.
    """
    manager = TopicManager(db_connection)
    await manager.db_sync()
    yield manager
# ######################################
# Tests for the UserAuthDBPlugin
@pytest.mark.asyncio
async def test_create_user(user_manager, db_file, db_connection):
    """A created user is stored with an argon2 hash that matches only its password."""
    await user_manager.create_user_auth("myuser", "mypassword")

    async with aiosqlite.connect(db_file) as db_conn:
        db_conn.row_factory = sqlite3.Row  # access columns by name
        async with await db_conn.execute("SELECT * FROM user_auth") as cursor:
            stored_rows = list(await cursor.fetchall())

        verifier = ArgonPasswordHasher()
        for stored in stored_rows:
            stored_hash = stored['password_hash']
            assert stored['username'] == "myuser"
            # The password must never be stored in clear text.
            assert stored_hash != "mypassword"
            assert '$argon2' in stored_hash

            verifier.verify(stored_hash, "mypassword")
            with pytest.raises(VerifyMismatchError):
                verifier.verify(stored_hash, "mywrongpassword")

    assert stored_rows, "user was not created"
@pytest.mark.asyncio
async def test_list_users(user_manager, db_file, db_connection):
    """Listing returns one entry per created user."""
    for username in ("myuser", "otheruser", "anotheruser"):
        await user_manager.create_user_auth(username, "mypassword")

    listed = list(await user_manager.list_user_auths())
    assert len(listed) == 3
@pytest.mark.asyncio
async def test_list_empty_users(user_manager, db_file, db_connection):
    """Listing users on a freshly synced database yields nothing."""
    listed = list(await user_manager.list_user_auths())
    assert len(listed) == 0
@pytest.mark.asyncio
async def test_password_change(user_manager, db_file, db_connection):
    """Updating a password replaces the stored hash with one for the new password.

    Besides checking that the old password no longer verifies, the test now
    also checks that the new password *does* verify; previously any changed
    hash (even a corrupt one) would have passed.
    """
    new_user = await user_manager.create_user_auth("myuser", "mypassword")
    await user_manager.update_user_auth_password("myuser", "mynewpassword")

    async with aiosqlite.connect(db_file) as db_conn:
        db_conn.row_factory = sqlite3.Row  # access columns by name
        async with await db_conn.execute("SELECT * FROM user_auth") as cursor:
            rows = list(await cursor.fetchall())

    assert rows, "user was not found"

    ph = ArgonPasswordHasher()
    for row in rows:
        # new_user was returned before the update; presumably it still holds
        # the hash generated at creation time -- TODO confirm.
        assert row['password_hash'] != new_user._password_hash

        with pytest.raises(VerifyMismatchError):
            ph.verify(row['password_hash'], "mypassword")
        # Raises if the stored hash does not match the new password.
        ph.verify(row['password_hash'], "mynewpassword")
@pytest.mark.asyncio
async def test_remove_users(user_manager, db_file, db_connection):
    """Deleting one user leaves only the other users in the ``user_auth`` table."""
    await user_manager.create_user_auth("myuser", "mypassword")
    await user_manager.create_user_auth("otheruser", "mypassword")
    await user_manager.create_user_auth("anotheruser", "mypassword")
    assert len(list(await user_manager.list_user_auths())) == 3

    await user_manager.delete_user_auth("myuser")
    assert len(list(await user_manager.list_user_auths())) == 2

    async with aiosqlite.connect(db_file) as db_conn:
        db_conn.row_factory = sqlite3.Row  # access columns by name
        async with await db_conn.execute("SELECT * FROM user_auth") as cursor:
            rows = list(await cursor.fetchall())

    for row in rows:
        assert row['username'] in ("otheruser", "anotheruser")

    # Message fixed: the original double negative ("weren't not found")
    # said the opposite of what the failure means.
    assert rows, "users were not found"
@pytest.mark.parametrize("user_pwd,session_pwd,outcome", [
    ("mypassword", "mypassword", True),
    ("mypassword", "myotherpassword", False),
])
@pytest.mark.asyncio
async def test_db_auth(db_connection, user_manager, user_pwd, session_pwd, outcome):
    """``UserAuthDBPlugin.authenticate`` returns ``outcome`` for the session's password."""
    await user_manager.create_user_auth("myuser", user_pwd)

    context = BrokerContext(broker=Broker())
    context.config = UserAuthDBPlugin.Config(connection=db_connection)
    plugin = UserAuthDBPlugin(context=context)

    session = Session()
    session.username = "myuser"
    session.password = session_pwd

    authenticated = await plugin.authenticate(session=session)
    assert authenticated == outcome
@pytest.mark.asyncio
async def test_client_authentication(user_manager, db_connection):
    """A client with valid credentials can connect, subscribe, publish and receive.

    The broker and client are released in ``finally`` blocks so that a failing
    assertion does not leave the listener bound to 127.0.0.1:1883, which would
    break every later test that starts a broker on the same port.
    """
    user = await user_manager.create_user_auth("myuser", "mypassword")
    assert user is not None

    broker_cfg = {
        'listeners': {'default': {'type': 'tcp', 'bind': '127.0.0.1:1883'}},
        'plugins': {
            'amqtt.contrib.auth_db.UserAuthDBPlugin': {
                'connection': db_connection,
            }
        }
    }

    broker = Broker(config=broker_cfg)
    await broker.start()
    try:
        await asyncio.sleep(0.1)

        client = MQTTClient(client_id='myclientid', config={'auto_reconnect': False})
        await client.connect("mqtt://myuser:mypassword@127.0.0.1:1883")
        try:
            await client.subscribe([
                ("my/topic", QOS_1)
            ])
            await asyncio.sleep(0.1)
            await client.publish("my/topic", b"test")
            await asyncio.sleep(0.1)

            message = await client.deliver_message(timeout_duration=1)
            assert message.topic == "my/topic"
            await asyncio.sleep(0.1)
        finally:
            # Only reached after a successful connect, so disconnect is valid.
            await client.disconnect()
        await asyncio.sleep(0.1)
    finally:
        await broker.shutdown()
# With a single argname, pytest passes each list element as-is.  The original
# 1-tuples therefore reached the test as tuples and were formatted into the
# URL as "('mywrongpassword',)" / "('',)", so the empty password was never
# actually exercised.  Plain strings are used instead.
@pytest.mark.parametrize("client_pwd", [
    "mywrongpassword",
    "",
])
@pytest.mark.asyncio
async def test_client_blocked(user_manager, db_connection, client_pwd):
    """Connecting with a wrong or empty password raises ``ConnectError``.

    The broker is shut down in a ``finally`` block so a failed expectation
    does not leave the listener bound to 127.0.0.1:1883.
    """
    user = await user_manager.create_user_auth("myuser", "mypassword")
    assert user is not None

    broker_cfg = {
        'listeners': {'default': {'type': 'tcp', 'bind': '127.0.0.1:1883'}},
        'plugins': {
            'amqtt.contrib.auth_db.UserAuthDBPlugin': {
                'connection': db_connection,
            }
        }
    }

    broker = Broker(config=broker_cfg)
    await broker.start()
    try:
        await asyncio.sleep(0.1)

        client = MQTTClient(client_id='myclientid', config={'auto_reconnect': False})
        with pytest.raises(ConnectError):
            await client.connect(f"mqtt://myuser:{client_pwd}@127.0.0.1:1883")
    finally:
        await broker.shutdown()
    await asyncio.sleep(0.1)
# ######################################
# Tests for the TopicAuthDBPlugin
def test_allowed_topic_match():
    """Check ``in`` and ``==`` on ``AllowedTopic`` for literal, ``#`` and ``+`` filters."""
    # Literal filters.
    literal = AllowedTopic(topic="my/topic")
    assert "my/topic" in literal

    nested_literal = AllowedTopic(topic="my/other/topic")
    assert "my/other/topic" in nested_literal
    assert "my/another/topic" not in nested_literal

    # Multi-level wildcard directly under the root segment.
    multi_level = AllowedTopic(topic="my/#")
    assert "my/other" in multi_level
    assert "my/other/topic" in multi_level
    assert "other/topic" not in multi_level

    # Multi-level wildcard below a two-segment prefix.
    nested_multi_level = AllowedTopic(topic="my/other/#")
    assert "my/other" in nested_multi_level
    assert "my/other/topic" in nested_multi_level
    assert "other/topic" not in nested_multi_level
    assert "/my/other/topic" not in nested_multi_level

    # Single-level wildcard.
    single_level = AllowedTopic(topic="my/+/topic")
    assert "my/other/topic" in single_level
    assert "my/another/topic" in single_level
    assert "my/other/another/topic" not in single_level

    # Equality between AllowedTopic instances, and an instance used with ``in``.
    nested_literal_twin = AllowedTopic(topic="my/other/topic")
    assert nested_literal_twin == nested_literal
    assert nested_literal == nested_literal_twin
    assert nested_literal_twin != literal
    assert nested_literal_twin not in single_level
def test_allowed_topic_list_match():
    """Topic strings are found in a plain list of ``AllowedTopic`` via ``in``."""
    at_list = [
        AllowedTopic(topic=topic_filter)
        for topic_filter in ("one/topic", "one/other/topic", "two/+/topic", "three/topic/#")
    ]

    expected_matches = (
        "one/topic",
        "two/other/topic",
        "two/another/topic",
        "three/topic",
        "three/topic/other",
    )
    for topic in expected_matches:
        assert topic in at_list
def test_remove_topic_list():
    """``list.remove`` finds an equal ``AllowedTopic`` and raises for an unknown one."""
    allowed = [
        AllowedTopic(topic=topic_filter)
        for topic_filter in ("my/topic", "my/other/topic", "my/another/topic")
    ]
    present = AllowedTopic(topic="my/topic")
    absent = AllowedTopic(topic="my/not/topic")

    allowed.remove(present)

    with pytest.raises(ValueError):
        allowed.remove(absent)
@pytest.mark.asyncio
async def test_add_topic_to_client(db_file, user_manager, topic_manager, db_connection):
    """An allowed PUBLISH topic shows up in the ``publish_acl`` column of ``topic_auth``."""
    client_id = "myuser"

    created = await topic_manager.create_topic_auth(client_id)
    assert created is not None

    allowed_topics = await topic_manager.add_allowed_topic(client_id, "my/topic", Action.PUBLISH)
    assert len(allowed_topics) > 0

    async with aiosqlite.connect(db_file) as db_conn:
        db_conn.row_factory = sqlite3.Row  # access columns by name
        async with await db_conn.execute("SELECT * FROM topic_auth") as cursor:
            acl_rows = list(await cursor.fetchall())

        for acl_row in acl_rows:
            assert acl_row['username'] == client_id
            assert "my/topic" in acl_row['publish_acl']

    assert acl_rows
@pytest.mark.asyncio
async def test_invalid_dollar_topic_for_publish(db_file, user_manager, topic_manager, db_connection):
    """Allowing PUBLISH on ``$MY/topic`` raises ``MQTTError`` and nothing is stored."""
    client_id = "myuser"

    created = await topic_manager.create_topic_auth(client_id)
    assert created is not None

    with pytest.raises(MQTTError):
        await topic_manager.add_allowed_topic(client_id, "$MY/topic", Action.PUBLISH)

    async with aiosqlite.connect(db_file) as db_conn:
        db_conn.row_factory = sqlite3.Row  # access columns by name
        async with await db_conn.execute("SELECT * FROM topic_auth") as cursor:
            acl_rows = list(await cursor.fetchall())

        for acl_row in acl_rows:
            assert acl_row['username'] == client_id
            assert "$MY/topic" not in acl_row['publish_acl']

    assert acl_rows
@pytest.mark.asyncio
async def test_remove_topic_from_client(db_file, user_manager, topic_manager, db_connection):
    """Removing one allowed PUBLISH topic leaves the other two in ``publish_acl``."""

    async def fetch_topic_auth_rows():
        # Read every row of topic_auth with name-based column access.
        async with aiosqlite.connect(db_file) as db_conn:
            db_conn.row_factory = sqlite3.Row
            async with await db_conn.execute("SELECT * FROM topic_auth") as cursor:
                return list(await cursor.fetchall())

    client_id = "myuser"

    created = await topic_manager.create_topic_auth(client_id)
    assert created is not None
    user = await user_manager.create_user_auth(client_id, "mypassword")
    assert user is not None

    topic_list = []
    for topic in ("my/topic", "my/other/topic", "my/another/topic"):
        topic_list = await topic_manager.add_allowed_topic(client_id, topic, Action.PUBLISH)
    assert len(topic_list) == 3

    # All three topics are stored before the removal.
    rows_before = await fetch_topic_auth_rows()
    for acl_row in rows_before:
        assert acl_row['username'] == client_id
        assert "my/topic" in acl_row['publish_acl']
        assert "my/other/topic" in acl_row['publish_acl']
        assert "my/another/topic" in acl_row['publish_acl']
    assert rows_before

    topic_list = await topic_manager.remove_allowed_topic(client_id, "my/other/topic", Action.PUBLISH)
    assert len(topic_list) == 2

    # Only the removed topic is gone afterwards.
    rows_after = await fetch_topic_auth_rows()
    for acl_row in rows_after:
        assert acl_row['username'] == client_id
        assert "my/topic" in acl_row['publish_acl']
        assert "my/other/topic" not in acl_row['publish_acl']
        assert "my/another/topic" in acl_row['publish_acl']
    assert rows_after
@pytest.mark.asyncio
async def test_remove_missing_topic(db_file, user_manager, topic_manager, db_connection):
    """Removing a topic that was never allowed raises ``MQTTError``."""
    client_id = "myuser"

    created = await topic_manager.create_topic_auth(client_id)
    assert created is not None
    user = await user_manager.create_user_auth(client_id, "mypassword")
    assert user is not None

    topic_list = []
    for topic in ("my/topic", "my/other/topic", "my/another/topic"):
        topic_list = await topic_manager.add_allowed_topic(client_id, topic, Action.PUBLISH)
    assert len(topic_list) == 3

    with pytest.raises(MQTTError):
        await topic_manager.remove_allowed_topic(client_id, "my/not/topic", Action.PUBLISH)
@pytest.mark.asyncio
async def test_remove_topic_wrong_action(db_file, user_manager, topic_manager, db_connection):
    """Removing a PUBLISH-allowed topic under the SUBSCRIBE action raises ``MQTTError``."""
    client_id = "myuser"

    created = await topic_manager.create_topic_auth(client_id)
    assert created is not None
    user = await user_manager.create_user_auth(client_id, "mypassword")
    assert user is not None

    topic_list = []
    for topic in ("my/topic", "my/other/topic", "my/another/topic"):
        topic_list = await topic_manager.add_allowed_topic(client_id, topic, Action.PUBLISH)
    assert len(topic_list) == 3

    with pytest.raises(MQTTError):
        await topic_manager.remove_allowed_topic(client_id, "my/other/topic", Action.SUBSCRIBE)
@pytest.mark.parametrize("acl_topic,msg_topic,outcome", [
    ("my/topic", "my/topic", True),
    ("my/topic", "my/other/topic", False),
    ("my/#", "my/other/topic", True),
    ("my/#", "my/another/topic", True),
])
@pytest.mark.asyncio
async def test_topic_publish_filter_plugin(db_file, topic_manager, db_connection, acl_topic, msg_topic, outcome):
    """``TopicAuthDBPlugin.topic_filtering`` returns ``outcome`` for a PUBLISH on ``msg_topic``."""
    client_id = "myuser"

    created = await topic_manager.create_topic_auth(client_id)
    assert created is not None
    await topic_manager.add_allowed_topic(client_id, acl_topic, Action.PUBLISH)

    context = BrokerContext(broker=Broker())
    context.config = TopicAuthDBPlugin.Config(connection=db_connection)
    plugin = TopicAuthDBPlugin(context=context)

    session = Session()
    session.username = client_id

    allowed = await plugin.topic_filtering(session=session, topic=msg_topic, action=Action.PUBLISH)
    assert allowed == outcome, f"topic filter responded incorrectly: {not outcome} vs {outcome}."
@pytest.mark.asyncio
async def test_topic_subscribe(db_file, topic_manager, db_connection):
    """A client allowed ``$SYS/#`` for SUBSCRIBE and RECEIVE gets a ``$SYS`` message.

    The broker and client are released in ``finally`` blocks so that a failing
    assertion does not leave the listener bound to 127.0.0.1:1883.
    """
    broker_cfg = {
        'listeners': {'default': {'type': 'tcp', 'bind': '127.0.0.1:1883'}},
        'plugins': {
            'amqtt.plugins.authentication.AnonymousAuthPlugin': {},
            'amqtt.contrib.auth_db.TopicAuthDBPlugin': {
                'connection': db_connection
            },
            'amqtt.plugins.sys.broker.BrokerSysPlugin': {
                # presumably the $SYS publish period in seconds -- TODO confirm;
                # the 4 second delivery timeout below is chosen to exceed it.
                'sys_interval': 2
            }
        }
    }

    client_id = "myuser"
    topic_auth = await topic_manager.create_topic_auth(client_id)
    assert topic_auth is not None

    sub_topic_list = await topic_manager.add_allowed_topic(client_id, '$SYS/#', Action.SUBSCRIBE)
    rcv_topic_list = await topic_manager.add_allowed_topic(client_id, '$SYS/#', Action.RECEIVE)
    assert len(sub_topic_list) > 0
    assert len(rcv_topic_list) > 0

    broker = Broker(config=broker_cfg)
    await broker.start()
    try:
        await asyncio.sleep(0.1)

        client = MQTTClient(client_id='myclientid', config={'auto_reconnect': False})
        await client.connect("mqtt://myuser:mypassword@127.0.0.1:1883")
        try:
            ret = await client.subscribe([
                ('$SYS/broker/clients/connected', QOS_0)
            ])
            assert ret == [0x0,]
            await asyncio.sleep(0.1)

            message_received = False
            try:
                message = await client.deliver_message(timeout_duration=4)
                assert message.topic == '$SYS/broker/clients/connected'
                message_received = True
            except asyncio.TimeoutError:
                # A timeout is reported through the assertion below so the
                # failure carries a descriptive message.
                pass

            assert message_received, "Did not receive a $SYS message"
            await asyncio.sleep(0.1)
        finally:
            # Only reached after a successful connect, so disconnect is valid.
            await client.disconnect()
        await asyncio.sleep(0.1)
    finally:
        await broker.shutdown()