kopia lustrzana https://github.com/Yakifo/amqtt
intermediate check in
rodzic
c06e585be5
commit
e42461a8cc
|
@ -5,22 +5,22 @@ from typing import Any, TypeVar
|
|||
import warnings
|
||||
|
||||
from sqlalchemy import JSON, Boolean, Integer, LargeBinary, String, select
|
||||
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine, AsyncSession
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
|
||||
from sqlalchemy.types import TypeDecorator
|
||||
|
||||
from amqtt.broker import BrokerContext
|
||||
from amqtt.broker import BrokerContext, RetainedApplicationMessage
|
||||
from amqtt.errors import PluginError
|
||||
from amqtt.mqtt.constants import QOS_0
|
||||
from amqtt.plugins.base import BasePlugin
|
||||
from amqtt.session import Session
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SQLitePlugin:
|
||||
class SQLitePlugin(BasePlugin[BrokerContext]):
|
||||
|
||||
def __init__(self) -> None:
|
||||
def __init__(self, context: BrokerContext) -> None:
|
||||
super().__init__(context)
|
||||
warnings.warn("SQLitePlugin is deprecated, use amqtt.plugins.persistence.SessionDBPlugin", stacklevel=1)
|
||||
|
||||
|
||||
|
@ -108,7 +108,7 @@ class SessionDBPlugin(BasePlugin[BrokerContext]):
|
|||
self._db_session_maker = async_sessionmaker(self._engine, expire_on_commit=False)
|
||||
|
||||
@staticmethod
|
||||
async def _get_or_create(db_session: AsyncSession, client_id:str):
|
||||
async def _get_or_create(db_session: AsyncSession, client_id:str) -> StoredSession:
|
||||
|
||||
stmt = select(StoredSession).filter(StoredSession.client_id == client_id)
|
||||
stored_session = await db_session.scalar(stmt)
|
||||
|
@ -181,7 +181,35 @@ class SessionDBPlugin(BasePlugin[BrokerContext]):
|
|||
msg = "SessionDBPlugin : broker shouldn't have any sessions yet"
|
||||
raise PluginError(msg)
|
||||
|
||||
await self.context.add_subscription("test_client1", "a/b", QOS_0)
|
||||
async with self._db_session_maker() as db_session:
|
||||
async with db_session.begin():
|
||||
stmt = select(StoredSession)
|
||||
stored_sessions = await db_session.execute(stmt)
|
||||
logger.debug("> stored sessions retrieved")
|
||||
for stored_session in stored_sessions.scalars():
|
||||
for subscription in stored_session.subscriptions:
|
||||
await self.context.add_subscription(stored_session.client_id,
|
||||
subscription.topic,
|
||||
subscription.qos)
|
||||
session, _ = self.context.get_session(stored_session.client_id)
|
||||
if not session:
|
||||
continue
|
||||
session.clean_session = stored_session.clean_session
|
||||
session.will_flag = stored_session.will_flag
|
||||
session.will_message = stored_session.will_message
|
||||
session.will_qos = stored_session.will_qos
|
||||
session.will_retain = stored_session.will_retain
|
||||
session.will_topic = stored_session.will_topic
|
||||
session.keep_alive = stored_session.keep_alive
|
||||
|
||||
for message in stored_session.retained:
|
||||
retained_message = RetainedApplicationMessage(
|
||||
source_session=None,
|
||||
topic=message.topic,
|
||||
data=message.data.encode(),
|
||||
qos=message.qos
|
||||
)
|
||||
await session.retained_messages.put(retained_message)
|
||||
|
||||
async def on_broker_pre_shutdown(self) -> None:
|
||||
"""Clean up the db connection."""
|
||||
|
|
|
@ -143,7 +143,41 @@ async def test_update_stored_session(db_file, broker_context, db_session_factory
|
|||
assert row[1] == 'test_client_1'
|
||||
assert row[-1] == '[{"topic": "sensors/#", "qos": 1}, {"topic": "my/topic", "qos": 2}]'
|
||||
has_stored_session = True
|
||||
assert has_stored_session
|
||||
assert has_stored_session, "stored session wasn't updated"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_repopulate_stored_sessions(db_file, broker_context, db_session_factory) -> None:
|
||||
|
||||
broker_context.config = SessionDBPlugin.Config(file=db_file)
|
||||
session_db_plugin = SessionDBPlugin(broker_context)
|
||||
await session_db_plugin.on_broker_pre_start()
|
||||
|
||||
async with aiosqlite.connect(str(db_file)) as db:
|
||||
sql = """INSERT INTO stored_sessions (
|
||||
client_id, clean_session, will_flag,
|
||||
will_qos, keep_alive,
|
||||
retained, subscriptions
|
||||
) VALUES (
|
||||
'test_client_1',
|
||||
1,
|
||||
0,
|
||||
1,
|
||||
60,
|
||||
'[{"topic":"sensors/#","data":"this message is retained when client reconnects","qos":1}]',
|
||||
'[{"topic":"sensors/#","qos":1}]'
|
||||
)"""
|
||||
await db.execute(sql)
|
||||
await db.commit()
|
||||
|
||||
await session_db_plugin.on_broker_post_start()
|
||||
|
||||
session, _ = broker_context.get_session('test_client_1')
|
||||
assert session is not None
|
||||
assert session.retained_messages.qsize() == 1
|
||||
|
||||
assert 'sensors/#' in broker_context._broker_instance._subscriptions
|
||||
|
||||
|
||||
|
||||
# @pytest.mark.asyncio
|
||||
|
|
Ładowanie…
Reference in New Issue