kopia lustrzana https://github.com/Yakifo/amqtt
505 wiersze
17 KiB
Python
505 wiersze
17 KiB
Python
import asyncio
|
|
import sqlite3
|
|
import tempfile
|
|
from pathlib import Path
|
|
|
|
import pytest
|
|
import aiosqlite
|
|
from passlib.context import CryptContext
|
|
|
|
from amqtt.broker import BrokerContext, Broker
|
|
from amqtt.client import MQTTClient
|
|
from amqtt.contexts import Action
|
|
from amqtt.contrib.auth_db.models import AllowedTopic, PasswordHasher
|
|
from amqtt.contrib.auth_db.plugin import UserAuthDBPlugin, TopicAuthDBPlugin
|
|
from amqtt.contrib.auth_db.managers import UserManager, TopicManager
|
|
from amqtt.errors import ConnectError, MQTTError
|
|
from amqtt.mqtt.constants import QOS_1, QOS_0
|
|
from amqtt.session import Session
|
|
from argon2 import PasswordHasher as ArgonPasswordHasher
|
|
from argon2.exceptions import VerifyMismatchError
|
|
|
|
@pytest.fixture
|
|
def password_hasher():
|
|
pwd_hasher = PasswordHasher()
|
|
pwd_hasher.crypt_context = CryptContext(schemes=["argon2", ], deprecated="auto")
|
|
yield pwd_hasher
|
|
|
|
|
|
@pytest.fixture
|
|
def db_file():
|
|
with tempfile.TemporaryDirectory() as temp_dir:
|
|
with tempfile.NamedTemporaryFile(mode='wb', delete=True) as tmp:
|
|
yield Path(temp_dir) / f"{tmp.name}.db"
|
|
|
|
|
|
@pytest.fixture
|
|
def db_connection(db_file):
|
|
test_db_connect = f"sqlite+aiosqlite:///{db_file}"
|
|
yield test_db_connect
|
|
|
|
|
|
@pytest.fixture
|
|
@pytest.mark.asyncio
|
|
async def user_manager(password_hasher, db_connection):
|
|
um = UserManager(db_connection)
|
|
await um.db_sync()
|
|
yield um
|
|
|
|
|
|
@pytest.fixture
|
|
@pytest.mark.asyncio
|
|
async def topic_manager(password_hasher, db_connection):
|
|
tm = TopicManager(db_connection)
|
|
await tm.db_sync()
|
|
yield tm
|
|
|
|
|
|
# ######################################
|
|
# Tests for the UserAuthDBPlugin
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_create_user(user_manager, db_file, db_connection):
|
|
await user_manager.create_user_auth("myuser", "mypassword")
|
|
|
|
async with aiosqlite.connect(db_file) as db_conn:
|
|
db_conn.row_factory = sqlite3.Row # Set the row_factory
|
|
|
|
has_user = False
|
|
async with await db_conn.execute("SELECT * FROM user_auth") as cursor:
|
|
for row in await cursor.fetchall():
|
|
assert row['username'] == "myuser"
|
|
assert row['password_hash'] != "mypassword"
|
|
assert '$argon2' in row['password_hash']
|
|
ph = ArgonPasswordHasher()
|
|
ph.verify(row['password_hash'], "mypassword")
|
|
|
|
with pytest.raises(VerifyMismatchError):
|
|
ph.verify(row['password_hash'], "mywrongpassword")
|
|
|
|
has_user = True
|
|
|
|
assert has_user, "user was not created"
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_list_users(user_manager, db_file, db_connection):
|
|
await user_manager.create_user_auth("myuser", "mypassword")
|
|
await user_manager.create_user_auth("otheruser", "mypassword")
|
|
await user_manager.create_user_auth("anotheruser", "mypassword")
|
|
|
|
assert len(list(await user_manager.list_user_auths())) == 3
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_list_empty_users(user_manager, db_file, db_connection):
|
|
|
|
assert len(list(await user_manager.list_user_auths())) == 0
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_password_change(user_manager, db_file, db_connection):
|
|
new_user = await user_manager.create_user_auth("myuser", "mypassword")
|
|
|
|
await user_manager.update_user_auth_password("myuser", "mynewpassword")
|
|
async with aiosqlite.connect(db_file) as db_conn:
|
|
db_conn.row_factory = sqlite3.Row # Set the row_factory
|
|
|
|
has_user = False
|
|
async with await db_conn.execute("SELECT * FROM user_auth") as cursor:
|
|
for row in await cursor.fetchall():
|
|
assert row['password_hash'] != new_user._password_hash
|
|
ph = ArgonPasswordHasher()
|
|
with pytest.raises(VerifyMismatchError):
|
|
ph.verify(row['password_hash'], "mypassword")
|
|
|
|
has_user = True
|
|
|
|
assert has_user, "user was not found"
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_remove_users(user_manager, db_file, db_connection):
|
|
await user_manager.create_user_auth("myuser", "mypassword")
|
|
await user_manager.create_user_auth("otheruser", "mypassword")
|
|
await user_manager.create_user_auth("anotheruser", "mypassword")
|
|
|
|
assert len(list(await user_manager.list_user_auths())) == 3
|
|
|
|
await user_manager.delete_user_auth("myuser")
|
|
|
|
assert len(list(await user_manager.list_user_auths())) == 2
|
|
|
|
async with aiosqlite.connect(db_file) as db_conn:
|
|
db_conn.row_factory = sqlite3.Row # Set the row_factory
|
|
|
|
test_run = False
|
|
async with await db_conn.execute("SELECT * FROM user_auth") as cursor:
|
|
for row in await cursor.fetchall():
|
|
assert row['username'] in ("otheruser", "anotheruser")
|
|
test_run = True
|
|
|
|
assert test_run, "users weren't not found"
|
|
|
|
|
|
@pytest.mark.parametrize("user_pwd,session_pwd,outcome", [
|
|
("mypassword", "mypassword", True),
|
|
("mypassword", "myotherpassword", False),
|
|
])
|
|
@pytest.mark.asyncio
|
|
async def test_db_auth(db_connection, user_manager, user_pwd, session_pwd, outcome):
|
|
|
|
await user_manager.create_user_auth("myuser", user_pwd)
|
|
|
|
broker_context = BrokerContext(broker=Broker())
|
|
broker_context.config = UserAuthDBPlugin.Config(
|
|
connection=db_connection
|
|
)
|
|
db_auth_plugin = UserAuthDBPlugin(context=broker_context)
|
|
|
|
s = Session()
|
|
s.username = "myuser"
|
|
s.password = session_pwd
|
|
|
|
assert await db_auth_plugin.authenticate(session=s) == outcome
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_client_authentication(user_manager, db_connection):
|
|
|
|
user = await user_manager.create_user_auth("myuser", "mypassword")
|
|
assert user is not None
|
|
|
|
broker_cfg = {
|
|
'listeners': { 'default': {'type': 'tcp', 'bind': '127.0.0.1:1883'}},
|
|
'plugins': {
|
|
'amqtt.contrib.auth_db.UserAuthDBPlugin': {
|
|
'connection': db_connection,
|
|
}
|
|
}
|
|
}
|
|
|
|
broker = Broker(config=broker_cfg)
|
|
|
|
await broker.start()
|
|
await asyncio.sleep(0.1)
|
|
|
|
client = MQTTClient(client_id='myclientid', config={'auto_reconnect': False})
|
|
await client.connect("mqtt://myuser:mypassword@127.0.0.1:1883")
|
|
|
|
await client.subscribe([
|
|
("my/topic", QOS_1)
|
|
])
|
|
await asyncio.sleep(0.1)
|
|
await client.publish("my/topic", b"test")
|
|
await asyncio.sleep(0.1)
|
|
message = await client.deliver_message(timeout_duration=1)
|
|
assert message.topic == "my/topic"
|
|
await asyncio.sleep(0.1)
|
|
await client.disconnect()
|
|
await asyncio.sleep(0.1)
|
|
await broker.shutdown()
|
|
|
|
|
|
@pytest.mark.parametrize("client_pwd", [
|
|
("mywrongpassword", ),
|
|
("", ),
|
|
])
|
|
@pytest.mark.asyncio
|
|
async def test_client_blocked(user_manager, db_connection, client_pwd):
|
|
|
|
user = await user_manager.create_user_auth("myuser", "mypassword")
|
|
assert user is not None
|
|
|
|
broker_cfg = {
|
|
'listeners': { 'default': {'type': 'tcp', 'bind': '127.0.0.1:1883'}},
|
|
'plugins': {
|
|
'amqtt.contrib.auth_db.UserAuthDBPlugin': {
|
|
'connection': db_connection,
|
|
}
|
|
}
|
|
}
|
|
|
|
broker = Broker(config=broker_cfg)
|
|
|
|
await broker.start()
|
|
await asyncio.sleep(0.1)
|
|
|
|
client = MQTTClient(client_id='myclientid', config={'auto_reconnect': False})
|
|
with pytest.raises(ConnectError):
|
|
await client.connect(f"mqtt://myuser:{client_pwd}@127.0.0.1:1883")
|
|
|
|
await broker.shutdown()
|
|
await asyncio.sleep(0.1)
|
|
|
|
|
|
# ######################################
|
|
# Tests for the TopicAuthDBPlugin
|
|
|
|
def test_allowed_topic_match():
|
|
|
|
at = AllowedTopic(topic="my/topic")
|
|
assert "my/topic" in at
|
|
|
|
at2 = AllowedTopic(topic="my/other/topic")
|
|
assert "my/other/topic" in at2
|
|
assert "my/another/topic" not in at2
|
|
|
|
at3 = AllowedTopic(topic="my/#")
|
|
assert "my/other" in at3
|
|
assert "my/other/topic" in at3
|
|
assert "other/topic" not in at3
|
|
|
|
at4 = AllowedTopic(topic="my/other/#")
|
|
assert "my/other" in at4
|
|
assert "my/other/topic" in at4
|
|
assert "other/topic" not in at4
|
|
assert "/my/other/topic" not in at4
|
|
|
|
at5 = AllowedTopic(topic="my/+/topic")
|
|
assert "my/other/topic" in at5
|
|
assert "my/another/topic" in at5
|
|
assert "my/other/another/topic" not in at5
|
|
|
|
at6 = AllowedTopic(topic="my/other/topic")
|
|
assert at6 == at2
|
|
assert at2 == at6
|
|
assert at6 != at
|
|
assert at6 not in at5
|
|
|
|
|
|
def test_allowed_topic_list_match():
|
|
at1 = AllowedTopic(topic="one/topic")
|
|
at2 = AllowedTopic(topic="one/other/topic")
|
|
at3 = AllowedTopic(topic="two/+/topic")
|
|
at4 = AllowedTopic(topic="three/topic/#")
|
|
|
|
at_list = [at1, at2, at3, at4]
|
|
|
|
assert "one/topic" in at_list
|
|
assert "two/other/topic" in at_list
|
|
assert "two/another/topic" in at_list
|
|
assert "three/topic" in at_list
|
|
assert "three/topic/other" in at_list
|
|
|
|
|
|
def test_remove_topic_list():
|
|
at1 = AllowedTopic(topic="my/topic")
|
|
at2 = AllowedTopic(topic="my/other/topic")
|
|
at3 = AllowedTopic(topic="my/another/topic")
|
|
|
|
at_list = [at1, at2, at3]
|
|
|
|
at4 = AllowedTopic(topic="my/topic")
|
|
at5 = AllowedTopic(topic="my/not/topic")
|
|
at_list.remove(at4)
|
|
|
|
with pytest.raises(ValueError):
|
|
at_list.remove(at5)
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_add_topic_to_client(db_file, user_manager, topic_manager, db_connection):
|
|
client_id = "myuser"
|
|
|
|
topic_auth = await topic_manager.create_topic_auth(client_id)
|
|
assert topic_auth is not None
|
|
topic_list = await topic_manager.add_allowed_topic(client_id, "my/topic", Action.PUBLISH)
|
|
assert len(topic_list) > 0
|
|
|
|
async with aiosqlite.connect(db_file) as db_conn:
|
|
db_conn.row_factory = sqlite3.Row # Set the row_factory
|
|
|
|
user_found = False
|
|
async with await db_conn.execute("SELECT * FROM topic_auth") as cursor:
|
|
for row in await cursor.fetchall():
|
|
assert row['username'] == client_id
|
|
assert "my/topic" in row['publish_acl']
|
|
user_found = True
|
|
|
|
assert user_found
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_invalid_dollar_topic_for_publish(db_file, user_manager, topic_manager, db_connection):
|
|
client_id = "myuser"
|
|
|
|
topic_auth = await topic_manager.create_topic_auth(client_id)
|
|
assert topic_auth is not None
|
|
with pytest.raises(MQTTError):
|
|
await topic_manager.add_allowed_topic(client_id, "$MY/topic", Action.PUBLISH)
|
|
|
|
async with aiosqlite.connect(db_file) as db_conn:
|
|
db_conn.row_factory = sqlite3.Row # Set the row_factory
|
|
|
|
auth_topic_found = False
|
|
async with await db_conn.execute("SELECT * FROM topic_auth") as cursor:
|
|
for row in await cursor.fetchall():
|
|
assert row['username'] == client_id
|
|
assert "$MY/topic" not in row['publish_acl']
|
|
auth_topic_found = True
|
|
|
|
assert auth_topic_found
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_remove_topic_from_client(db_file, user_manager, topic_manager, db_connection):
|
|
client_id = "myuser"
|
|
topic_auth = await topic_manager.create_topic_auth(client_id)
|
|
assert topic_auth is not None
|
|
user = await user_manager.create_user_auth(client_id, "mypassword")
|
|
assert user is not None
|
|
|
|
await topic_manager.add_allowed_topic(client_id, "my/topic", Action.PUBLISH)
|
|
await topic_manager.add_allowed_topic(client_id, "my/other/topic", Action.PUBLISH)
|
|
topic_list = await topic_manager.add_allowed_topic(client_id, "my/another/topic", Action.PUBLISH)
|
|
assert len(topic_list) == 3
|
|
|
|
async with aiosqlite.connect(db_file) as db_conn:
|
|
db_conn.row_factory = sqlite3.Row # Set the row_factory
|
|
|
|
user_found = False
|
|
async with await db_conn.execute("SELECT * FROM topic_auth") as cursor:
|
|
for row in await cursor.fetchall():
|
|
assert row['username'] == client_id
|
|
assert "my/topic" in row['publish_acl']
|
|
assert "my/other/topic" in row['publish_acl']
|
|
assert "my/another/topic" in row['publish_acl']
|
|
user_found = True
|
|
assert user_found
|
|
|
|
topic_list = await topic_manager.remove_allowed_topic(client_id, "my/other/topic", Action.PUBLISH)
|
|
assert len(topic_list) == 2
|
|
|
|
async with aiosqlite.connect(db_file) as db_conn:
|
|
db_conn.row_factory = sqlite3.Row # Set the row_factory
|
|
|
|
user_found = False
|
|
async with await db_conn.execute("SELECT * FROM topic_auth") as cursor:
|
|
for row in await cursor.fetchall():
|
|
assert row['username'] == client_id
|
|
assert "my/topic" in row['publish_acl']
|
|
assert "my/other/topic" not in row['publish_acl']
|
|
assert "my/another/topic" in row['publish_acl']
|
|
user_found = True
|
|
assert user_found
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_remove_missing_topic(db_file, user_manager, topic_manager, db_connection):
|
|
client_id = "myuser"
|
|
topic_auth = await topic_manager.create_topic_auth(client_id)
|
|
assert topic_auth is not None
|
|
user = await user_manager.create_user_auth(client_id, "mypassword")
|
|
assert user is not None
|
|
|
|
await topic_manager.add_allowed_topic(client_id, "my/topic", Action.PUBLISH)
|
|
await topic_manager.add_allowed_topic(client_id, "my/other/topic", Action.PUBLISH)
|
|
topic_list = await topic_manager.add_allowed_topic(client_id, "my/another/topic", Action.PUBLISH)
|
|
assert len(topic_list) == 3
|
|
|
|
with pytest.raises(MQTTError):
|
|
await topic_manager.remove_allowed_topic(client_id, "my/not/topic", Action.PUBLISH)
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_remove_topic_wrong_action(db_file, user_manager, topic_manager, db_connection):
|
|
client_id = "myuser"
|
|
topic_auth = await topic_manager.create_topic_auth(client_id)
|
|
assert topic_auth is not None
|
|
user = await user_manager.create_user_auth(client_id, "mypassword")
|
|
assert user is not None
|
|
|
|
await topic_manager.add_allowed_topic(client_id, "my/topic", Action.PUBLISH)
|
|
await topic_manager.add_allowed_topic(client_id, "my/other/topic", Action.PUBLISH)
|
|
topic_list = await topic_manager.add_allowed_topic(client_id, "my/another/topic", Action.PUBLISH)
|
|
assert len(topic_list) == 3
|
|
|
|
with pytest.raises(MQTTError):
|
|
await topic_manager.remove_allowed_topic(client_id, "my/other/topic", Action.SUBSCRIBE)
|
|
|
|
|
|
@pytest.mark.parametrize("acl_topic,msg_topic,outcome", [
|
|
("my/topic", "my/topic", True),
|
|
("my/topic", "my/other/topic", False),
|
|
("my/#", "my/other/topic", True),
|
|
("my/#", "my/another/topic", True),
|
|
])
|
|
@pytest.mark.asyncio
|
|
async def test_topic_publish_filter_plugin(db_file, topic_manager, db_connection, acl_topic, msg_topic, outcome):
|
|
client_id = "myuser"
|
|
|
|
user = await topic_manager.create_topic_auth(client_id)
|
|
assert user is not None
|
|
|
|
await topic_manager.add_allowed_topic(client_id, acl_topic, Action.PUBLISH)
|
|
|
|
broker_context = BrokerContext(broker=Broker())
|
|
broker_context.config = TopicAuthDBPlugin.Config(
|
|
connection=db_connection
|
|
)
|
|
db_auth_plugin = TopicAuthDBPlugin(context=broker_context)
|
|
|
|
s = Session()
|
|
s.username = client_id
|
|
assert await db_auth_plugin.topic_filtering(session=s, topic=msg_topic, action=Action.PUBLISH) == outcome,\
|
|
f"topic filter responded incorrectly: {not outcome} vs {outcome}."
|
|
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_topic_subscribe(db_file, topic_manager, db_connection):
|
|
|
|
broker_cfg = {
|
|
'listeners': { 'default': {'type': 'tcp', 'bind': '127.0.0.1:1883'}},
|
|
'plugins': {
|
|
'amqtt.plugins.authentication.AnonymousAuthPlugin': {},
|
|
'amqtt.contrib.auth_db.TopicAuthDBPlugin': {
|
|
'connection': db_connection
|
|
},
|
|
'amqtt.plugins.sys.broker.BrokerSysPlugin': {
|
|
'sys_interval' : 2
|
|
}
|
|
}
|
|
}
|
|
client_id = "myuser"
|
|
|
|
topic_auth = await topic_manager.create_topic_auth(client_id)
|
|
assert topic_auth is not None
|
|
|
|
sub_topic_list = await topic_manager.add_allowed_topic(client_id, '$SYS/#', Action.SUBSCRIBE)
|
|
rcv_topic_list = await topic_manager.add_allowed_topic(client_id, '$SYS/#', Action.RECEIVE)
|
|
|
|
assert len(sub_topic_list) > 0
|
|
assert len(rcv_topic_list) > 0
|
|
|
|
broker = Broker(config=broker_cfg)
|
|
|
|
await broker.start()
|
|
await asyncio.sleep(0.1)
|
|
|
|
client = MQTTClient(client_id='myclientid', config={'auto_reconnect': False})
|
|
await client.connect("mqtt://myuser:mypassword@127.0.0.1:1883")
|
|
|
|
ret = await client.subscribe([
|
|
('$SYS/broker/clients/connected', QOS_0)
|
|
])
|
|
|
|
assert ret == [0x0,]
|
|
|
|
await asyncio.sleep(0.1)
|
|
|
|
message_received = False
|
|
try:
|
|
message = await client.deliver_message(timeout_duration=4)
|
|
assert message.topic == '$SYS/broker/clients/connected'
|
|
message_received = True
|
|
except asyncio.TimeoutError:
|
|
pass
|
|
|
|
assert message_received, "Did not receive a $SYS message"
|
|
await asyncio.sleep(0.1)
|
|
await client.disconnect()
|
|
await asyncio.sleep(0.1)
|
|
await broker.shutdown()
|