diff --git a/activitypub/database/__init__.py b/activitypub/database/__init__.py index 1465cd1..793214e 100644 --- a/activitypub/database/__init__.py +++ b/activitypub/database/__init__.py @@ -3,3 +3,4 @@ from .base import Database, Table from .listdb import ListDatabase from .mongodb import MongoDatabase from .redisdb import RedisDatabase +from .sqldb import SQLDatabase diff --git a/activitypub/database/listdb.py b/activitypub/database/listdb.py index f6aea3c..3ee6db2 100644 --- a/activitypub/database/listdb.py +++ b/activitypub/database/listdb.py @@ -256,14 +256,31 @@ class ListTable(Table): [] >>> table.find({"d": 4}) # doctest: +ELLIPSIS [{'c': 3, 'd': 4, '_id': ObjectId('...')}] + >>> table.remove({"d": 4}) + >>> table.find({"d": 4}) + [] + >>> table.find({"b": 2}) # doctest: +ELLIPSIS + [{'a': 1, 'b': 2, '_id': ObjectId('...')}] """ if query: - items = [doc for doc in self.data if self.match(doc, query)] - # delete them + items = [(i,doc) for (i,doc) in enumerate(self.data) if self.match(doc, query)] + for i,doc in items: + del self.data[i] else: - self.data = [] + self.data.clear() def find_one(self, query): + """ + >>> table = ListTable() + >>> table.insert_one({"a": 1, "b": 2}) + >>> table.insert_one({"a": 3, "b": 4}) + + >>> table.find_one({"b": 2}) # doctest: +ELLIPSIS + {'a': 1, 'b': 2, '_id': ObjectId('...')} + >>> table.find_one({"b": 3}) + >>> table.find_one({"b": 4}) # doctest: +ELLIPSIS + {'a': 3, 'b': 4, '_id': ObjectId('...')} + """ results = [doc for doc in self.data if self.match(doc, query)] if results: return results[0] diff --git a/activitypub/database/sqldb.py b/activitypub/database/sqldb.py new file mode 100644 index 0000000..04bd044 --- /dev/null +++ b/activitypub/database/sqldb.py @@ -0,0 +1,162 @@ +import json + +from ..bson import ObjectId +from .base import Database +from .listdb import ListTable + +class JSONEncoder(json.JSONEncoder): + def default(self, o): + if isinstance(o, ObjectId): + return {"$oid": str(o)} + return super().default(o) + +class JSONDecoder(json.JSONDecoder): + def __init__(self, *args, **kwargs): + super().__init__(object_hook=self.object_hook, *args, **kwargs) + + def object_hook(self, obj): + if '$oid' not in obj: + return obj + return ObjectId(obj['$oid']) + +class SQLList(): + def __init__(self, database, name): + self.database = database + self.name = name + + def __getitem__(self, item): + if isinstance(item, int): + result = self.database.execute( + """SELECT blob_data FROM %s WHERE rowid = :rowid""" + % (self.name), {"rowid": item}) + item = result.fetchone() + if item: + item = json.loads(item[0], cls=JSONDecoder) + elif isinstance(item, slice): + items = [item in islice(self, item.start, item.stop, item.step)] + return items + if item: + return item + else: + raise IndexError("list index out of range") + + def __setitem__(self, item, value): + s = json.dumps(value, cls=JSONEncoder) + # first see if it exists: + try: + old_item = self[item] + except: + old_item = None + if old_item: + # update it + try: + self.database.execute( + """UPDATE %s SET blob_data = :s WHERE rowid = :rowid;""" + % (self.name), {"s": s, "rowid": item}) + self.database.commit() + except: + self.database.rollback() + raise + else: + # insert it + oid = str(value["_id"]) + try: + self.database.execute( + """INSERT INTO %s (blob_data, oid, rowid) VALUES (:s, :oid, :rowid);""" + % (self.name), {"s": s, "rowid": item, "oid": oid}) + self.database.commit() + except: + self.database.rollback() + raise + + def __delitem__(self, key): + try: + self.database.execute( + """DELETE FROM %s WHERE rowid = :rowid;""" + % (self.name), {"rowid": key}) + self.database.execute( + """UPDATE %s SET rowid = (rowid - 1) WHERE rowid > :rowid;""" + % self.name, {"rowid": key}) + self.database.commit() + except: + self.database.rollback() + raise + + def clear(self): + try: + self.database.execute("DELETE from %s;" % self.name) + self.database.commit() + except: + self.database.rollback() + raise + + def append(self, item): + pos = len(self) + self[pos] = item + + def __len__(self): + result = self.database.execute("SELECT count(1) FROM %s" % self.name) + row = result.fetchone() + return row[0] + +class SQLTable(ListTable): + def __init__(self, database, name): + super().__init__(database, name) + if not self.table_exists(name): + self.build_table(name) + self.data = SQLList(database, name) + + def table_exists(self, table): + result = self.database.execute("""SELECT COUNT(*) + FROM sqlite_master + WHERE type='table' AND name='%s';""" % table) + return result.fetchone()[0] != 0 + + def build_table(self, name): + try: + self.database.execute( + """CREATE TABLE %s ( + rowid INTEGER PRIMARY KEY ASC, + oid CHAR(24), + blob_data BLOB + )""" % name) + self.database.commit() + except: + self.database.rollback() + raise + + def sort(self, sort_key, sort_order): + # sort_key = "_id" + # sort_order = 1 or -1 + ## Always use ListTable here: + return ListTable(data=sorted( + self.data, + key=lambda row: self.get_item_in_dict(row, sort_key), + reverse=(sort_order == -1))) + +class SQLDatabase(Database): + Table = SQLTable + + def __init__(self, *args, **kwargs): + from sqlalchemy import create_engine + from sqlalchemy.orm import scoped_session, sessionmaker + #from sqlalchemy.pool import StaticPool + #from sqlalchemy.pool import QueuePool + + super().__init__() + self.engine = create_engine(*args, **kwargs) + # poolclass=QueuePool, + # convert_unicode=True, + # connect_args={'check_same_thread':False}, + # poolclass=StaticPool, + self.session = scoped_session(sessionmaker(bind=self.engine)) + + def commit(self): + self.session.commit() + + def rollback(self): + self.session.rollback() + + def execute(self, *args, **kwargs): + return self.session.execute(*args, **kwargs) +