diff --git a/atproto_firehose.py b/atproto_firehose.py index b9504465..990a097b 100644 --- a/atproto_firehose.py +++ b/atproto_firehose.py @@ -31,11 +31,13 @@ from common import ( global_cache, global_cache_policy, global_cache_timeout_policy, + PROTOCOL_DOMAINS, report_error, report_exception, USER_AGENT, ) from models import Object, reset_protocol_properties +from web import Web logger = logging.getLogger(__name__) @@ -62,6 +64,7 @@ atproto_dids = set() atproto_loaded_at = datetime(1900, 1, 1) bridged_dids = set() bridged_loaded_at = datetime(1900, 1, 1) +protocol_bot_dids = None dids_initialized = Event() @@ -69,6 +72,16 @@ def load_dids(): # run in a separate thread since it needs to make its own NDB # context when it runs in the timer thread Thread(target=_load_dids).start() + + global protocol_bot_dids + protocol_bot_dids = set() + bot_keys = [Web(id=domain).key for domain in PROTOCOL_DOMAINS] + for bot in ndb.get_multi(bot_keys): + if bot: + if did := bot.get_copy(ATProto): + logger.info(f'Loaded protocol bot user {bot.key.id()} {did}') + protocol_bot_dids.add(did) + dids_initialized.wait() dids_initialized.clear() @@ -188,9 +201,6 @@ def subscribe(): # when running locally, comment out put above and uncomment this # cursor.updated = util.now().replace(tzinfo=None) - if payload['repo'] not in atproto_dids: - continue - blocks = {} # maps base32 str CID to dict block if block_bytes := payload.get('blocks'): _, blocks = libipld.decode_car(block_bytes) @@ -204,7 +214,7 @@ def subscribe(): f'bad payload! seq {op.seq} action {op.action} path {op.path}!') continue - if op.action == 'delete': + if op.repo in atproto_dids and op.action == 'delete': logger.info(f'Got delete from our ATProto user: {op}') # TODO: also detect deletes of records that *reference* our bridged # users, eg a delete of a follow or like or repost of them. @@ -228,6 +238,13 @@ def subscribe(): elif type not in ATProto.SUPPORTED_RECORD_TYPES: continue + # generally we only want records from bridged Bluesky users. the one + # exception is follows of protocol bot users. + if (op.repo not in atproto_dids + and not (type == 'app.bsky.graph.follow' + and op.record['subject'] in protocol_bot_dids)): + continue + def is_ours(ref, also_atproto_users=False): """Returns True if the arg is a bridge user.""" if match := AT_URI_PATTERN.match(ref['uri']): diff --git a/tests/test_atproto_firehose.py b/tests/test_atproto_firehose.py index 7a8c5739..bd4bcab4 100644 --- a/tests/test_atproto_firehose.py +++ b/tests/test_atproto_firehose.py @@ -28,10 +28,11 @@ from atproto import ATProto, Cursor import atproto_firehose from atproto_firehose import commits, handle, Op, STORE_CURSOR_FREQ import common -from models import Object +from models import Object, Target import protocol from .testutil import TestCase from .test_atproto import DID_DOC +from web import Web A_CID = CID.decode('bafkreicqpqncshdd27sgztqgzocd3zhhqnnsv6slvzhs5uz6f57cq6lmtq') @@ -109,7 +110,7 @@ class ATProtoFirehoseSubscribeTest(ATProtoTestCase): atproto_firehose.bridged_loaded_at = datetime(1900, 1, 1) atproto_firehose.dids_initialized.clear() - self.make_bridged_atproto_user() + self.user = self.make_bridged_atproto_user() AtpRepo(id='did:alice', head='', signing_key_pem=b'').put() self.store_object(id='did:plc:bob', raw=DID_DOC) ATProto(id='did:plc:bob').put() @@ -295,6 +296,19 @@ class ATProtoFirehoseSubscribeTest(ATProtoTestCase): 'subject': 'did:eve', }) + def test_follow_of_protocol_bot_account_by_unbridged_user(self): + self.user.enabled_protocols = [] + self.user.put() + + self.make_user('fa.brid.gy', cls=Web, enabled_protocols=['atproto'], + copies=[Target(protocol='atproto', uri='did:fa')]) + AtpRepo(id='did:fa', head='', signing_key_pem=b'').put() + + self.assert_enqueues({ + '$type': 'app.bsky.graph.follow', + 'subject': 'did:fa', + }) + def test_block_of_our_user(self): self.assert_enqueues({ '$type': 'app.bsky.graph.block', @@ -373,7 +387,6 @@ class ATProtoFirehoseSubscribeTest(ATProtoTestCase): self.assertIn('did:plc:eve', atproto_firehose.atproto_dids) def test_load_dids_atprepo(self): - FakeWebsocketClient.to_receive = [({'op': 1, 't': '#info'}, {})] self.subscribe()