kopia lustrzana https://github.com/Yakifo/amqtt
Add persistence plugin (WIP)
rodzic
c8e0a4e356
commit
9dc217cf82
|
@ -0,0 +1,65 @@
|
||||||
|
# Copyright (c) 2015 Nicolas JOUANIN
|
||||||
|
#
|
||||||
|
# See the file license.txt for copying permission.
|
||||||
|
import asyncio
|
||||||
|
import sqlite3
|
||||||
|
import pickle
|
||||||
|
|
||||||
|
|
||||||
|
class SQLitePlugin:
|
||||||
|
def __init__(self, context):
|
||||||
|
self.context = context
|
||||||
|
self.conn = None
|
||||||
|
self.cursor = None
|
||||||
|
self.db_file = None
|
||||||
|
try:
|
||||||
|
self.persistence_config = self.context.config['persistence']
|
||||||
|
self.init_db()
|
||||||
|
except KeyError:
|
||||||
|
self.context.logger.warn("'persistence' section not found in context configuration")
|
||||||
|
|
||||||
|
def init_db(self):
|
||||||
|
self.db_file = self.persistence_config.get('file', None)
|
||||||
|
if not self.db_file:
|
||||||
|
self.context.logger.warn("'file' persistence parameter not found")
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
self.conn = sqlite3.connect(self.db_file)
|
||||||
|
self.cursor = self.conn.cursor()
|
||||||
|
self.context.logger.info("Database file '%s' opened" % self.db_file)
|
||||||
|
except Exception as e:
|
||||||
|
self.context.logger.error("Error while initializing database '%s' : %s" % (self.db_file, e))
|
||||||
|
if self.cursor:
|
||||||
|
self.cursor.execute("CREATE TABLE IF NOT EXISTS session(client_id TEXT PRIMARY KEY, data BLOB)")
|
||||||
|
|
||||||
|
@asyncio.coroutine
|
||||||
|
def save_session(self, session):
|
||||||
|
if self.cursor:
|
||||||
|
dump = pickle.dumps(session)
|
||||||
|
try:
|
||||||
|
self.cursor.execute(
|
||||||
|
"INSERT OR REPLACE INTO session (client_id, data) VALUES (?,?)", (session.client_id, dump))
|
||||||
|
self.conn.commit()
|
||||||
|
except Exception as e:
|
||||||
|
self.context.logger.error("Failed saving session '%s': %s" % (session, e))
|
||||||
|
|
||||||
|
@asyncio.coroutine
|
||||||
|
def find_session(self, client_id):
|
||||||
|
if self.cursor:
|
||||||
|
row = self.cursor.execute("SELECT data FROM session where client_id=?", (client_id,)).fetchone()
|
||||||
|
if row:
|
||||||
|
return pickle.loads(row[0])
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
|
@asyncio.coroutine
|
||||||
|
def del_session(self, client_id):
|
||||||
|
if self.cursor:
|
||||||
|
self.cursor.execute("DELETE FROM session where client_id=?", (client_id,))
|
||||||
|
self.conn.commit()
|
||||||
|
|
||||||
|
@asyncio.coroutine
|
||||||
|
def on_broker_post_shutdown(self):
|
||||||
|
if self.conn:
|
||||||
|
self.conn.close()
|
||||||
|
self.context.logger.info("Database file '%s' closed" % self.db_file)
|
|
@ -72,3 +72,16 @@ class Session:
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return type(self).__name__ + '(clientId={0}, state={1})'.format(self.client_id, self.transitions.state)
|
return type(self).__name__ + '(clientId={0}, state={1})'.format(self.client_id, self.transitions.state)
|
||||||
|
|
||||||
|
def __getstate__(self):
|
||||||
|
state = self.__dict__.copy()
|
||||||
|
# Remove the unpicklable entries.
|
||||||
|
#del state['transitions']
|
||||||
|
del state['retained_messages']
|
||||||
|
del state['delivered_message_queue']
|
||||||
|
return state
|
||||||
|
|
||||||
|
def __setstate(self, state):
|
||||||
|
self.__dict__.update(state)
|
||||||
|
self.retained_messages = Queue()
|
||||||
|
self.delivered_message_queue = Queue()
|
||||||
|
|
|
@ -0,0 +1,59 @@
|
||||||
|
# Copyright (c) 2015 Nicolas JOUANIN
|
||||||
|
#
|
||||||
|
# See the file license.txt for copying permission.
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import asyncio
|
||||||
|
import sqlite3
|
||||||
|
from hbmqtt.plugins.manager import BaseContext
|
||||||
|
from hbmqtt.plugins.persistence import SQLitePlugin
|
||||||
|
from hbmqtt.session import Session
|
||||||
|
|
||||||
|
formatter = "[%(asctime)s] %(name)s {%(filename)s:%(lineno)d} %(levelname)s - %(message)s"
|
||||||
|
logging.basicConfig(level=logging.DEBUG, format=formatter)
|
||||||
|
|
||||||
|
|
||||||
|
class TestSQLitePlugin(unittest.TestCase):
|
||||||
|
def setUp(self):
|
||||||
|
self.loop = asyncio.new_event_loop()
|
||||||
|
|
||||||
|
def test_create_tables(self):
|
||||||
|
dbfile = os.path.join(os.path.dirname(os.path.realpath(__file__)), "test.db")
|
||||||
|
context = BaseContext()
|
||||||
|
context.logger = logging.getLogger(__name__)
|
||||||
|
context.config = {
|
||||||
|
'persistence': {
|
||||||
|
'file': dbfile
|
||||||
|
}
|
||||||
|
}
|
||||||
|
sql_plugin = SQLitePlugin(context)
|
||||||
|
|
||||||
|
conn = sqlite3.connect(dbfile)
|
||||||
|
cursor = conn.cursor()
|
||||||
|
rows = cursor.execute("SELECT name FROM sqlite_master where type = 'table'")
|
||||||
|
tables = []
|
||||||
|
for row in rows:
|
||||||
|
tables.append(row[0])
|
||||||
|
self.assertIn("session", tables)
|
||||||
|
|
||||||
|
def test_save_session(self):
|
||||||
|
dbfile = os.path.join(os.path.dirname(os.path.realpath(__file__)), "test.db")
|
||||||
|
context = BaseContext()
|
||||||
|
context.logger = logging.getLogger(__name__)
|
||||||
|
context.config = {
|
||||||
|
'persistence': {
|
||||||
|
'file': dbfile
|
||||||
|
}
|
||||||
|
}
|
||||||
|
sql_plugin = SQLitePlugin(context)
|
||||||
|
s = Session()
|
||||||
|
s.client_id = 'test_save_session'
|
||||||
|
ret = self.loop.run_until_complete(sql_plugin.save_session(session=s))
|
||||||
|
|
||||||
|
conn = sqlite3.connect(dbfile)
|
||||||
|
cursor = conn.cursor()
|
||||||
|
row = cursor.execute("SELECT client_id FROM session where client_id = 'test_save_session'").fetchone()
|
||||||
|
self.assertTrue(len(row) == 1)
|
||||||
|
self.assertEquals(row[0], s.client_id)
|
Ładowanie…
Reference in New Issue