From 9dc217cf82469b0f7dd056ecb409fc843d0a5e49 Mon Sep 17 00:00:00 2001 From: Nicolas Jouanin Date: Mon, 31 Aug 2015 22:37:01 +0200 Subject: [PATCH] Add persistence plugin (WIP) --- hbmqtt/plugins/persistence.py | 65 +++++++++++++++++++++++++++++++ hbmqtt/session.py | 13 +++++++ tests/plugins/test_persistence.py | 59 ++++++++++++++++++++++++++++ 3 files changed, 137 insertions(+) create mode 100644 hbmqtt/plugins/persistence.py create mode 100644 tests/plugins/test_persistence.py diff --git a/hbmqtt/plugins/persistence.py b/hbmqtt/plugins/persistence.py new file mode 100644 index 0000000..577fd81 --- /dev/null +++ b/hbmqtt/plugins/persistence.py @@ -0,0 +1,65 @@ +# Copyright (c) 2015 Nicolas JOUANIN +# +# See the file license.txt for copying permission. +import asyncio +import sqlite3 +import pickle + + +class SQLitePlugin: + def __init__(self, context): + self.context = context + self.conn = None + self.cursor = None + self.db_file = None + try: + self.persistence_config = self.context.config['persistence'] + self.init_db() + except KeyError: + self.context.logger.warn("'persistence' section not found in context configuration") + + def init_db(self): + self.db_file = self.persistence_config.get('file', None) + if not self.db_file: + self.context.logger.warn("'file' persistence parameter not found") + else: + try: + self.conn = sqlite3.connect(self.db_file) + self.cursor = self.conn.cursor() + self.context.logger.info("Database file '%s' opened" % self.db_file) + except Exception as e: + self.context.logger.error("Error while initializing database '%s' : %s" % (self.db_file, e)) + if self.cursor: + self.cursor.execute("CREATE TABLE IF NOT EXISTS session(client_id TEXT PRIMARY KEY, data BLOB)") + + @asyncio.coroutine + def save_session(self, session): + if self.cursor: + dump = pickle.dumps(session) + try: + self.cursor.execute( + "INSERT OR REPLACE INTO session (client_id, data) VALUES (?,?)", (session.client_id, dump)) + self.conn.commit() + except Exception as e: + self.context.logger.error("Failed saving session '%s': %s" % (session, e)) + + @asyncio.coroutine + def find_session(self, client_id): + if self.cursor: + row = self.cursor.execute("SELECT data FROM session where client_id=?", (client_id,)).fetchone() + if row: + return pickle.loads(row[0]) + else: + return None + + @asyncio.coroutine + def del_session(self, client_id): + if self.cursor: + self.cursor.execute("DELETE FROM session where client_id=?", (client_id,)) + self.conn.commit() + + @asyncio.coroutine + def on_broker_post_shutdown(self): + if self.conn: + self.conn.close() + self.context.logger.info("Database file '%s' closed" % self.db_file) diff --git a/hbmqtt/session.py b/hbmqtt/session.py index b627a28..cc37eaf 100644 --- a/hbmqtt/session.py +++ b/hbmqtt/session.py @@ -72,3 +72,16 @@ class Session: def __repr__(self): return type(self).__name__ + '(clientId={0}, state={1})'.format(self.client_id, self.transitions.state) + + def __getstate__(self): + state = self.__dict__.copy() + # Remove the unpicklable entries. + #del state['transitions'] + del state['retained_messages'] + del state['delivered_message_queue'] + return state + + def __setstate(self, state): + self.__dict__.update(state) + self.retained_messages = Queue() + self.delivered_message_queue = Queue() diff --git a/tests/plugins/test_persistence.py b/tests/plugins/test_persistence.py new file mode 100644 index 0000000..93d6135 --- /dev/null +++ b/tests/plugins/test_persistence.py @@ -0,0 +1,59 @@ +# Copyright (c) 2015 Nicolas JOUANIN +# +# See the file license.txt for copying permission. + +import unittest +import logging +import os +import asyncio +import sqlite3 +from hbmqtt.plugins.manager import BaseContext +from hbmqtt.plugins.persistence import SQLitePlugin +from hbmqtt.session import Session + +formatter = "[%(asctime)s] %(name)s {%(filename)s:%(lineno)d} %(levelname)s - %(message)s" +logging.basicConfig(level=logging.DEBUG, format=formatter) + + +class TestSQLitePlugin(unittest.TestCase): + def setUp(self): + self.loop = asyncio.new_event_loop() + + def test_create_tables(self): + dbfile = os.path.join(os.path.dirname(os.path.realpath(__file__)), "test.db") + context = BaseContext() + context.logger = logging.getLogger(__name__) + context.config = { + 'persistence': { + 'file': dbfile + } + } + sql_plugin = SQLitePlugin(context) + + conn = sqlite3.connect(dbfile) + cursor = conn.cursor() + rows = cursor.execute("SELECT name FROM sqlite_master where type = 'table'") + tables = [] + for row in rows: + tables.append(row[0]) + self.assertIn("session", tables) + + def test_save_session(self): + dbfile = os.path.join(os.path.dirname(os.path.realpath(__file__)), "test.db") + context = BaseContext() + context.logger = logging.getLogger(__name__) + context.config = { + 'persistence': { + 'file': dbfile + } + } + sql_plugin = SQLitePlugin(context) + s = Session() + s.client_id = 'test_save_session' + ret = self.loop.run_until_complete(sql_plugin.save_session(session=s)) + + conn = sqlite3.connect(dbfile) + cursor = conn.cursor() + row = cursor.execute("SELECT client_id FROM session where client_id = 'test_save_session'").fetchone() + self.assertTrue(len(row) == 1) + self.assertEquals(row[0], s.client_id)