diff --git a/atproto_firehose.py b/atproto_firehose.py index ed0911c..3168a5f 100644 --- a/atproto_firehose.py +++ b/atproto_firehose.py @@ -6,13 +6,16 @@ import itertools import logging import os from queue import SimpleQueue +from threading import Lock, Thread, Timer import time from carbox import read_car import dag_json +from google.cloud import ndb from granary.bluesky import AT_URI_PATTERN from lexrpc.client import Client from oauth_dropins.webutil import util +from oauth_dropins.webutil.appengine_config import ndb_client from oauth_dropins.webutil.appengine_info import DEBUG from atproto import ATProto, Cursor @@ -36,24 +39,37 @@ new_commits = SimpleQueue() atproto_dids = None bridged_dids = None loaded_dids_at = None +load_dids_lock = Lock() + def load_dids(): + # start in a separate thread since it needs to make its own NDB context + # when it runs in the timer thread + thread = Thread(target=_load_dids) + thread.start() + thread.join() + + +def _load_dids(): global atproto_dids, bridged_dids, loaded_dids_at - if loaded_dids_at and loaded_dids_at > util.now() - RECONNECT_DELAY: - return + with load_dids_lock, ndb_client.context(): + if loaded_dids_at and loaded_dids_at > util.now() - RECONNECT_DELAY: + return - atproto_query = ATProto.query(ATProto.enabled_protocols != None) - atproto_dids = frozenset(key.id() for key in atproto_query.iter(keys_only=True)) + atproto_query = ATProto.query(ATProto.enabled_protocols != None) + atproto_dids = frozenset(key.id() for key in atproto_query.iter(keys_only=True)) - others_queries = itertools.chain(*( - cls.query(cls.copies.protocol == 'atproto').iter() - for cls in set(models.PROTOCOLS.values()) - if cls and cls != ATProto)) - bridged_dids = frozenset(user.get_copy(ATProto) for user in others_queries) + others_queries = itertools.chain(*( + cls.query(cls.copies.protocol == 'atproto').iter() + for cls in set(models.PROTOCOLS.values()) + if cls and cls != ATProto)) + bridged_dids = frozenset(user.get_copy(ATProto) 
for user in others_queries) - logger.info(f'Loaded {len(atproto_dids)} ATProto, {len(bridged_dids)} bridged dids') - loaded_dids_at = util.now() + logger.info(f'Loaded {len(atproto_dids)} ATProto, {len(bridged_dids)} bridged dids') + loaded_dids_at = util.now() + if not DEBUG: + Timer(STORE_CURSOR_FREQ.total_seconds(), _load_dids).start() def subscribe(reconnect=True):