intermediate check in

pull/256/head
Andrew Mirsky 2025-07-05 22:55:19 -04:00
rodzic c06e585be5
commit e42461a8cc
Nie znaleziono w bazie danych klucza dla tego podpisu
ID klucza GPG: A98E67635CDF2C39
2 zmienionych plików z 70 dodań i 8 usunięć

Wyświetl plik

@ -5,22 +5,22 @@ from typing import Any, TypeVar
import warnings
from sqlalchemy import JSON, Boolean, Integer, LargeBinary, String, select
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine, AsyncSession
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
from sqlalchemy.types import TypeDecorator
from amqtt.broker import BrokerContext
from amqtt.broker import BrokerContext, RetainedApplicationMessage
from amqtt.errors import PluginError
from amqtt.mqtt.constants import QOS_0
from amqtt.plugins.base import BasePlugin
from amqtt.session import Session
logger = logging.getLogger(__name__)
class SQLitePlugin:
class SQLitePlugin(BasePlugin[BrokerContext]):
def __init__(self) -> None:
def __init__(self, context: BrokerContext) -> None:
super().__init__(context)
warnings.warn("SQLitePlugin is deprecated, use amqtt.plugins.persistence.SessionDBPlugin", stacklevel=1)
@ -108,7 +108,7 @@ class SessionDBPlugin(BasePlugin[BrokerContext]):
self._db_session_maker = async_sessionmaker(self._engine, expire_on_commit=False)
@staticmethod
async def _get_or_create(db_session: AsyncSession, client_id:str):
async def _get_or_create(db_session: AsyncSession, client_id:str) -> StoredSession:
stmt = select(StoredSession).filter(StoredSession.client_id == client_id)
stored_session = await db_session.scalar(stmt)
@ -181,7 +181,35 @@ class SessionDBPlugin(BasePlugin[BrokerContext]):
msg = "SessionDBPlugin : broker shouldn't have any sessions yet"
raise PluginError(msg)
await self.context.add_subscription("test_client1", "a/b", QOS_0)
async with self._db_session_maker() as db_session:
async with db_session.begin():
stmt = select(StoredSession)
stored_sessions = await db_session.execute(stmt)
logger.debug("> stored sessions retrieved")
for stored_session in stored_sessions.scalars():
for subscription in stored_session.subscriptions:
await self.context.add_subscription(stored_session.client_id,
subscription.topic,
subscription.qos)
session, _ = self.context.get_session(stored_session.client_id)
if not session:
continue
session.clean_session = stored_session.clean_session
session.will_flag = stored_session.will_flag
session.will_message = stored_session.will_message
session.will_qos = stored_session.will_qos
session.will_retain = stored_session.will_retain
session.will_topic = stored_session.will_topic
session.keep_alive = stored_session.keep_alive
for message in stored_session.retained:
retained_message = RetainedApplicationMessage(
source_session=None,
topic=message.topic,
data=message.data.encode(),
qos=message.qos
)
await session.retained_messages.put(retained_message)
async def on_broker_pre_shutdown(self) -> None:
"""Clean up the db connection."""

Wyświetl plik

@ -143,7 +143,41 @@ async def test_update_stored_session(db_file, broker_context, db_session_factory
assert row[1] == 'test_client_1'
assert row[-1] == '[{"topic": "sensors/#", "qos": 1}, {"topic": "my/topic", "qos": 2}]'
has_stored_session = True
assert has_stored_session
assert has_stored_session, "stored session wasn't updated"
@pytest.mark.asyncio
async def test_repopulate_stored_sessions(db_file, broker_context, db_session_factory) -> None:
broker_context.config = SessionDBPlugin.Config(file=db_file)
session_db_plugin = SessionDBPlugin(broker_context)
await session_db_plugin.on_broker_pre_start()
async with aiosqlite.connect(str(db_file)) as db:
sql = """INSERT INTO stored_sessions (
client_id, clean_session, will_flag,
will_qos, keep_alive,
retained, subscriptions
) VALUES (
'test_client_1',
1,
0,
1,
60,
'[{"topic":"sensors/#","data":"this message is retained when client reconnects","qos":1}]',
'[{"topic":"sensors/#","qos":1}]'
)"""
await db.execute(sql)
await db.commit()
await session_db_plugin.on_broker_post_start()
session, _ = broker_context.get_session('test_client_1')
assert session is not None
assert session.retained_messages.qsize() == 1
assert 'sensors/#' in broker_context._broker_instance._subscriptions
# @pytest.mark.asyncio