ATProto firehose: correctly load and handle bridged ATProto vs non-ATProto DIDs

for #978
pull/1049/head
Ryan Barrett 2024-05-07 16:17:44 -07:00
rodzic 38a8067ef7
commit 3057dd6757
Nie znaleziono w bazie danych klucza dla tego podpisu
ID klucza GPG: 6BE31FDF4776E9D4
1 zmienionych plików z 42 dodań i 20 usunięć

Wyświetl plik

@ -4,7 +4,9 @@ Usage:
ve && env GOOGLE_APPLICATION_CREDENTIALS=service_account_creds.json \
python firehose.py [RELAY_HOST [CURSOR]]
"""
import itertools
import json
import logging
import os
import sys
@ -14,15 +16,28 @@ from granary.bluesky import AT_URI_PATTERN
from lexrpc.client import Client
from oauth_dropins.webutil import appengine_config
from arroba.datastore_storage import AtpRepo
import activitypub, web # load protocol classes
from atproto import ATProto
from common import add
import models
logger = logging.getLogger(__name__)
with appengine_config.ndb_client.context():
dids = frozenset(key.id() for key in AtpRepo.query().iter(keys_only=True))
reset_protocol_properties()
print(f'Loaded {len(dids)} dids')
query = ATProto.query(ATProto.enabled_protocols != None)
our_atproto_dids = frozenset(key.id() for key in query.iter(keys_only=True))
other_queries = itertools.chain(*(
cls.query(cls.copies.protocol == 'atproto').iter()
for cls in set(models.PROTOCOLS.values())
if cls and cls != ATProto))
our_bridged_dids = frozenset(user.get_copy(ATProto) for user in other_queries)
print(f'Loaded {len(our_atproto_dids)} ATProto, {len(our_bridged_dids)} bridged dids')
print(f'Examples: {next(iter(our_atproto_dids))} {next(iter(our_bridged_dids))}')
assert len(sys.argv) <= 3
host = sys.argv[1] if len(sys.argv) >= 2 else 'bgs.bsky-sandbox.dev'
@ -30,24 +45,30 @@ cursor = sys.argv[2] if len(sys.argv) == 3 else None
scheme = 'http' if host.split(':')[0] == 'localhost' else 'https'
client = Client(f'{scheme}://{host}')
for header, payload in client.com.atproto.sync.subscribeRepos(cursor=cursor):
if header['op'] == -1:
print('error!', header)
elif header['t'] != '#commit':
continue
# is this from one of our bridged users?
repo = payload.get('repo')
if repo in dids:
# TODO: send
print('ours, from', repo)
# continue
# detect records that reference a bridged user, eg replies, likes,
# reposts, mentions
root, blocks = read_car(payload['blocks'])
blocks = {block.cid: block for block in blocks}
repo = payload.get('repo')
# is this from one of our non-Bluesky users?
if repo in our_bridged_dids:
logger.info(f'Got record from our non-ATProto bridged user {repo}, ignoring')
continue
# is this from one of our Bluesky users?
if repo in our_atproto_dids:
logger.info(f'Got record from our ATProto user {repo}, enqueueing')
# TODO: send
continue
# detect records that reference an ATProto user, eg replies, likes,
# reposts, mentions
for op in payload['ops']:
action = op['action']
cid = op['cid']
@ -61,10 +82,11 @@ for header, payload in client.com.atproto.sync.subscribeRepos(cursor=cursor):
block = blocks.get(op['cid'])
if not block:
# TODO: ???
print('missing block!!!', action, cid)
print(dag_json.encode(payload).decode())
for cid, block in blocks.items():
print(cid, dag_json.encode(block.decoded).decode())
# these are ours
# print('missing block!!!', action, cid)
# print(dag_json.encode(payload).decode())
# for cid, block in blocks.items():
# print(cid, dag_json.encode(block.decoded).decode())
continue
record = block.decoded
@ -81,7 +103,7 @@ for header, payload in client.com.atproto.sync.subscribeRepos(cursor=cursor):
subjects = []
def maybe_add(did):
if did and did in dids:
if did and did in our_atproto_dids:
add(subjects, did)
if type in ('app.bsky.feed.like', 'app.bsky.feed.repost'):
@ -99,7 +121,7 @@ for header, payload in client.com.atproto.sync.subscribeRepos(cursor=cursor):
# mentions
for facet in record.get('facets', []):
for feature in facet['features']:
if feature['$type'] == '#mention' and feature['did'] in dids:
if feature['$type'] == '#mention':
maybe_add(feature['did'])
# TODO: quote posts
@ -108,6 +130,6 @@ for header, payload in client.com.atproto.sync.subscribeRepos(cursor=cursor):
# 'app.bsky.embed.recordWithMedia'):
# if embed['record']
# print(action, type, repo, subjects)
if subjects:
logger.info(f'Got {type} that references {subjects}, enqueueing')
print(subjects, dag_json.encode(record).decode())