From d0d7cbb9eb78b3bbb1a0aa9606d1646766d53c97 Mon Sep 17 00:00:00 2001 From: Douglas Blank Date: Tue, 24 Jul 2018 16:38:53 -0400 Subject: [PATCH] SQLAlchemy working, but not fast; WIP: use indexes for queries --- activitypub/database/sqldb.py | 113 +++++++++++++++++++++++++--------- 1 file changed, 83 insertions(+), 30 deletions(-) diff --git a/activitypub/database/sqldb.py b/activitypub/database/sqldb.py index 04bd044..20dc57b 100644 --- a/activitypub/database/sqldb.py +++ b/activitypub/database/sqldb.py @@ -1,3 +1,6 @@ +from sqlalchemy import create_engine, inspect +from sqlalchemy.orm import scoped_session, sessionmaker +import logging import json from ..bson import ObjectId @@ -102,28 +105,71 @@ class SQLList(): class SQLTable(ListTable): def __init__(self, database, name): super().__init__(database, name) - if not self.table_exists(name): - self.build_table(name) + if not self.database.table_exists(name): + self.database.build_table(name) self.data = SQLList(database, name) - def table_exists(self, table): - result = self.database.execute("""SELECT COUNT(*) - FROM sqlite_master - WHERE type='table' AND name='%s';""" % table) - return result.fetchone()[0] != 0 + def get_schema(self): + ins = inspect(self.database.engine) + return ins.get_columns(self.name) - def build_table(self, name): - try: - self.database.execute( - """CREATE TABLE %s ( - rowid INTEGER PRIMARY KEY ASC, - oid CHAR(24), - blob_data BLOB - )""" % name) - self.database.commit() - except: - self.database.rollback() - raise + def get_columns(self): + schema = self.get_schema() + return [d["name"] for d in schema] + + def build_compare(self, lhs, rhs): + if isinstance(rhs, dict): + q = [] + for item in rhs: + if item == "$regex": + q.append("SQL regex") ## FIXME + elif item == "$lt": + q.append("(%s < %s)" % (lhs, rhs[item])) + elif item == "$gt": + q.append("(%s > %s)" % (lhs, rhs[item])) + elif item == "$in": + if isinstance(lhs, list): + q.append("(%s IN %s)" % (lhs, rhs)) ## FIXME? + else: + q.append("(%s IN %s)" % (lhs, rhs[item])) ## FIXME? + else: + raise Exception("unknown operator: %s" % item) + return "(" + (" AND ".join(q)) + ")" + else: + if isinstance(lhs, list): + if isinstance(rhs, list): + return "(%s = %s)" % (lhs, repr(rhs)) ## FIXME? + else: + return "(%s IN %s)" % (rhs, lhs) ## FIXME? + else: + return "(%s = %s)" % (lhs, repr(rhs)) + + def build_query(self, query, limit=None): + q = [] + for item in query: + if item == "$or": + expr = "(" + (" OR ".join([self.build_query(each) for each in query[item]])) + ")" + elif item == "$and": + expr = "(" + (" AND ".join([self.build_query(each) for each in query[item]])) + ")" + else: + expr = self.build_compare(item, query[item]) + q.append(expr) + return "(" + (" AND ".join(q)) + ")" + + def find(self, query=None, limit=None, enumerated=False): + ## if the query contains a SQL table field, then + ## use that portion + ## WIP: find portion of query that can be SQL selected + ## NOTE: limit can only be applied if full query applies + # logging.info("query: %s" % query) + # if query is not None or limit is not None: + # q = self.build_query(query, limit) + # logging.info("built q: %s" % q) + # if False: ## TODO: handle query + # results = self.database.execute(q) + # return ListTable(data=results.fetchall()).find(query, enumerated=enumerated) + ## else, just go through all of the items + return super().find(query, limit, enumerated) def sort(self, sort_key, sort_order): # sort_key = "_id" @@ -136,21 +182,12 @@ class SQLTable(ListTable): class SQLDatabase(Database): Table = SQLTable - - def __init__(self, *args, **kwargs): - from sqlalchemy import create_engine - from sqlalchemy.orm import scoped_session, sessionmaker - #from sqlalchemy.pool import StaticPool - #from sqlalchemy.pool import QueuePool + def __init__(self, *args, **kwargs): super().__init__() self.engine = create_engine(*args, **kwargs) - # poolclass=QueuePool, - # convert_unicode=True, - # connect_args={'check_same_thread':False}, - # poolclass=StaticPool, self.session = scoped_session(sessionmaker(bind=self.engine)) - + def commit(self): self.session.commit() @@ -160,3 +197,19 @@ class SQLDatabase(Database): def execute(self, *args, **kwargs): return self.session.execute(*args, **kwargs) + def table_exists(self, table): + ins = inspect(self.engine) + return table in ins.get_table_names() + + def build_table(self, name): + try: + self.execute( + """CREATE TABLE %s ( + rowid INTEGER PRIMARY KEY ASC, + oid CHAR(24), + blob_data BLOB + )""" % name) + self.commit() + except: + self.rollback() + raise