diff --git a/activitypub/database/base.py b/activitypub/database/base.py index 9b6e89e..073387d 100644 --- a/activitypub/database/base.py +++ b/activitypub/database/base.py @@ -18,3 +18,10 @@ class Database(): if attr not in self._tables: self._tables[attr] = self.Table(self, attr) return self._tables[attr] + + def table_exists(self, table): + return table in self._tables + + def build_table(self, name): + self._tables[name] = self.Table(self, name) + diff --git a/activitypub/database/listdb.py b/activitypub/database/listdb.py index 3ee6db2..f685a13 100644 --- a/activitypub/database/listdb.py +++ b/activitypub/database/listdb.py @@ -170,7 +170,7 @@ class ListTable(Table): if i is not None: self.data[i] = dictionary - def drop(self): + def clear(self): self.data.clear() def sort(self, sort_key, sort_order): @@ -381,7 +381,8 @@ class ListTable(Table): else: return len(self.data) - count_documents = count + def count_documents(self, query): + return self.count(query) class ListDatabase(Database): Table = ListTable diff --git a/activitypub/database/mongodb.py b/activitypub/database/mongodb.py index 2cc2221..a399335 100644 --- a/activitypub/database/mongodb.py +++ b/activitypub/database/mongodb.py @@ -45,6 +45,9 @@ class MongoTable(Table): else: self.__dict__[attr] = value + def clear(self): + self.collection.drop() + class MongoDatabase(Database): Table = MongoTable def __init__(self, uri, db_name): @@ -53,3 +56,6 @@ class MongoDatabase(Database): self.client = MongoClient(self.uri) self.db_name = db_name self.DB = self.client[self.db_name] + + def table_exists(self, table): + return table in self.client.database_names() diff --git a/tests/test_all.py b/tests/test_all.py new file mode 100644 index 0000000..22b97af --- /dev/null +++ b/tests/test_all.py @@ -0,0 +1,51 @@ + +from activitypub.database import * +from activitypub.manager import Manager + +def test_all(): + for db in [ + ListDatabase(), + SQLDatabase("sqlite://"), + SQLDatabase("sqlite:///sqlite.db"), + MongoDatabase("mongodb://localhost:27017", "dsblank_localhost:5005"), + RedisDatabase("redis://localhost:6379/0"), + ]: + print("Testing", db.__class__.__name__, "...") + manager = Manager(database=db) + if manager.database.table_exists("activities"): + manager.database.activities.clear() + else: + manager.database.build_table("activities") + manager.database.activities.clear() + manager.database.actors.clear() + + assert manager.database.actors.count_documents( + {"$or": [{'id': 'https://example.com/alyssa'}, + {'id': 'https://example.com/alyssa'}]}) == 0 + assert manager.database.actors.count_documents( + {"$or": [{'id': 'https://example.com/alyssa'}, + {'id': 'https://example.com/alyssa'}]}) == 0 + + p1 = manager.Person(id="alyssa") + p2 = manager.Person(id="brenda") + + manager.database.actors.insert_one(p1.to_dict()) + manager.database.actors.insert_one(p2.to_dict()) + + assert manager.database.actors.count_documents( + {"$or": [{'id': 'https://example.com/alyssa'}, + {'id': 'https://example.com/brenda'}]}) == 2 + assert len(list(manager.database.actors.find( + {"$or": [{'id': 'https://example.com/alyssa'}, + {'id': 'https://example.com/brenda'}]}))) == 2 + + assert manager.database.actors.count_documents( + {'id': 'https://example.com/alyssa'}) == 1 + assert manager.database.actors.count_documents({}) == 2 + + ## Clean up + manager.database.activities.clear() + manager.database.actors.clear() + +if __name__ == "__main__": + test_all()