bridgy-fed/atproto_firehose.py

"""Bridgy Fed firehose client. Enqueues receive tasks for events for bridged users.
"""
from collections import namedtuple
from datetime import datetime, timedelta
import itertools
import logging
import os
from queue import SimpleQueue
import threading
from threading import Event, Lock, Thread, Timer
import time
from arroba.datastore_storage import AtpRepo
from carbox import read_car
import dag_json
from flask import Flask
from google.cloud import ndb
from granary.bluesky import AT_URI_PATTERN
from lexrpc.client import Client
from oauth_dropins.webutil import util
from oauth_dropins.webutil.appengine_config import ndb_client
from oauth_dropins.webutil.appengine_info import DEBUG, LOCAL_SERVER
from oauth_dropins.webutil.util import json_loads
from atproto import ATProto, Cursor
from common import add, create_task, report_exception
from models import Object, reset_protocol_properties
# all protocols
import activitypub, atproto, web
RECONNECT_DELAY = timedelta(seconds=30)
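# how often to store our cursor; _load_dids also reloads DIDs on this interval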
STORE_CURSOR_FREQ = timedelta(seconds=10)
logger = logging.getLogger(__name__)
# a commit operation. similar to arroba.repo.Write. record is None for deletes.
Op = namedtuple('Op', ['action', 'repo', 'path', 'seq', 'record'],
# record is optional
defaults=[None])
# contains Ops
new_commits = SimpleQueue()
# global so that subscribe can reuse it across calls
subscribe_cursor = None
# globals: _load_dids populates them, subscribe and handle use them.
# atproto_dids holds DIDs of native ATProto users we bridge out; bridged_dids
# holds DIDs of the AtpRepos we serve for users bridged into ATProto.
atproto_dids = set()
atproto_loaded_at = datetime(1900, 1, 1)
bridged_dids = set()
bridged_loaded_at = datetime(1900, 1, 1)
dids_initialized = Event()
reset_protocol_properties()
def load_dids():
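    """Starts _load_dids in a new thread and waits for the initial load to finish."""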
# run in a separate thread since it needs to make its own NDB
# context when it runs in the timer thread
Thread(target=_load_dids).start()
dids_initialized.wait()
def _load_dids():
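    """Incrementally loads DIDs for ATProto users and bridged (AtpRepo) users.

    Updates the atproto_dids and bridged_dids globals, and outside of DEBUG,
    reschedules itself on a Timer every STORE_CURSOR_FREQ.
    """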
global atproto_dids, atproto_loaded_at, bridged_dids, bridged_loaded_at
with ndb_client.context():
if not DEBUG:
Timer(STORE_CURSOR_FREQ.total_seconds(), _load_dids).start()
atproto_query = ATProto.query(ATProto.enabled_protocols != None,
ATProto.created > atproto_loaded_at)
atproto_loaded_at = ATProto.query().order(-ATProto.created).get().created
new_atproto = [key.id() for key in atproto_query.iter(keys_only=True)]
atproto_dids.update(new_atproto)
bridged_query = AtpRepo.query(AtpRepo.created > bridged_loaded_at)
bridged_loaded_at = AtpRepo.query().order(-AtpRepo.created).get().created
new_bridged = [key.id() for key in bridged_query.iter(keys_only=True)]
bridged_dids.update(new_bridged)
dids_initialized.set()
logger.info(f'DIDs: ATProto {len(atproto_dids)} (+{len(new_atproto)}), AtpRepo {len(bridged_dids)} (+{len(new_bridged)})')
def subscriber():
"""Wrapper around :func:`_subscribe` that catches exceptions and reconnects."""
logger.info(f'started thread to subscribe to {os.environ["BGS_HOST"]} firehose')
while True:
try:
with ndb_client.context(cache_policy=lambda key: False,
global_cache_policy=lambda key: False):
subscribe()
logger.info(f'disconnected! waiting {RECONNECT_DELAY} and then reconnecting')
time.sleep(RECONNECT_DELAY.total_seconds())
except BaseException:
report_exception()
def subscribe():
    """Subscribes to the relay's firehose.

    Relay hostname comes from the ``BGS_HOST`` environment variable.
    """
load_dids()
global subscribe_cursor
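    # resume from the stored cursor, one past the last seq we handled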
if not subscribe_cursor:
cursor = Cursor.get_by_id(
f'{os.environ["BGS_HOST"]} com.atproto.sync.subscribeRepos')
assert cursor
subscribe_cursor = cursor.cursor + 1 if cursor.cursor else None
client = Client(f'https://{os.environ["BGS_HOST"]}')
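    # the subscription yields (header, payload) tuples, one per firehose event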
for header, payload in client.com.atproto.sync.subscribeRepos(
cursor=subscribe_cursor):
# parse header
if header.get('op') == -1:
logger.warning(f'Got error from relay! {payload}')
continue
elif header.get('t') == '#info':
logger.info(f'Got info from relay: {payload}')
continue
elif header.get('t') != '#commit':
continue
# parse payload
repo = payload.get('repo')
assert repo
seq = payload.get('seq')
if not seq:
logger.warning(f'Payload missing seq! {payload}')
continue
# if we fail processing this commit and raise an exception up to subscriber,
# skip it and start with the next commit when we're restarted
subscribe_cursor = seq + 1
# ops = ' '.join(f'{op.get("action")} {op.get("path")}'
# for op in payload.get('ops', []))
# logger.info(f'seeing {payload.get("seq")} {repo} {ops}')
if repo in bridged_dids: # from a Bridgy Fed non-Bluesky user; ignore
logger.info(f'Ignoring record from our non-ATProto bridged user {repo}')
continue
blocks = {}
if block_bytes := payload.get('blocks'):
_, blocks = read_car(block_bytes)
blocks = {block.cid: block for block in blocks}
# detect records that reference an ATProto user, eg replies, likes,
# reposts, mentions
for p_op in payload.get('ops', []):
op = Op(repo=repo, action=p_op.get('action'), path=p_op.get('path'),
seq=seq)
if not op.action or not op.path:
logger.info(
f'bad payload! seq {op.seq} has action {op.action} path {op.path}!')
continue
is_ours = repo in atproto_dids
if is_ours and op.action == 'delete':
logger.info(f'Got delete from our ATProto user: {op}')
# TODO: also detect deletes of records that *reference* our bridged
# users, eg a delete of a follow or like or repost of them.
# not easy because we need to getRecord the record to check
new_commits.put(op)
continue
cid = p_op.get('cid')
block = blocks.get(cid)
# our own commits are sometimes missing the record
# https://github.com/snarfed/bridgy-fed/issues/1016
if not cid or not block:
continue
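            # rebuild the Op with its record field set to the decoded block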
try:
op = Op(*op[:-1], record=block.decoded)
except BaseException:
# https://github.com/hashberg-io/dag-cbor/issues/14
logger.error(f"Couldn't decode block {cid} seq {op.seq}",
exc_info=True)
continue
type = op.record.get('$type')
if not type:
                logger.warning(f'commit record missing $type! {op.action} {op.repo} {op.path} {cid}')
logger.warning(dag_json.encode(op.record).decode())
continue
if is_ours:
logger.info(f'Got one from our ATProto user: {op.action} {op.repo} {op.path}')
new_commits.put(op)
continue
subjects = []
def maybe_add(did_or_ref):
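                # did_or_ref is either a bare DID string or a ref dict with an
                # at:// URI; add the DID to subjects if it's one of our bridged users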
if isinstance(did_or_ref, dict):
match = AT_URI_PATTERN.match(did_or_ref['uri'])
if match:
did = match.group('repo')
else:
return
else:
did = did_or_ref
if did and did in bridged_dids:
add(subjects, did)
if type in ('app.bsky.feed.like', 'app.bsky.feed.repost'):
maybe_add(op.record['subject'])
elif type in ('app.bsky.graph.block', 'app.bsky.graph.follow'):
maybe_add(op.record['subject'])
elif type == 'app.bsky.feed.post':
# replies
if reply := op.record.get('reply'):
for ref in 'parent', 'root':
maybe_add(reply[ref])
# mentions
for facet in op.record.get('facets', []):
for feature in facet['features']:
if feature['$type'] == 'app.bsky.richtext.facet#mention':
maybe_add(feature['did'])
# quote posts
if embed := op.record.get('embed'):
if embed['$type'] == 'app.bsky.embed.record':
maybe_add(embed['record'])
elif embed['$type'] == 'app.bsky.embed.recordWithMedia':
maybe_add(embed['record']['record'])
if subjects:
logger.info(f'Got one re our ATProto users {subjects}: {op.action} {op.repo} {op.path}')
new_commits.put(op)
def handler():
"""Wrapper around :func:`_handle` that catches exceptions and restarts."""
logger.info(f'started handle thread to store objects and enqueue receive tasks')
while True:
try:
with ndb_client.context(cache_policy=lambda key: False,
global_cache_policy=lambda key: False):
handle()
# if we return cleanly, that means we hit the limit
break
except BaseException:
report_exception()
def handle(limit=None):
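    """Pops ops off new_commits, stores them as Objects, enqueues receive tasks.

    Also stores our cursor in the datastore every STORE_CURSOR_FREQ.

    Args:
      limit (int): if set, returns after handling this many commits
    """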
cursor = Cursor.get_by_id(
f'{os.environ["BGS_HOST"]} com.atproto.sync.subscribeRepos')
assert cursor
count = 0
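    # new_commits.get blocks until an op is available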
while op := new_commits.get():
at_uri = f'at://{op.repo}/{op.path}'
# store object, enqueue receive task
if op.action in ('create', 'update'):
record_kwarg = {
'bsky': json_loads(dag_json.encode(op.record, dialect='atproto')),
}
obj_id = at_uri
elif op.action == 'delete':
obj_id = f'{at_uri}#delete'
record_kwarg = {'our_as1': {
'objectType': 'activity',
'verb': 'delete',
'id': obj_id,
'actor': op.repo,
'object': at_uri,
}}
else:
            logger.error(f'Unknown action {op.action} for {op.repo} {op.path}')
continue
try:
# logger.info(f'enqueuing receive task for {at_uri}')
obj = Object.get_or_create(id=obj_id, actor=op.repo, status='new',
users=[ATProto(id=op.repo).key],
source_protocol=ATProto.LABEL, **record_kwarg)
create_task(queue='receive', obj=obj.key.urlsafe(), authed_as=op.repo)
except BaseException:
if DEBUG:
raise
report_exception()
if util.now().replace(tzinfo=None) - cursor.updated > STORE_CURSOR_FREQ:
# it's been long enough, update our stored cursor
logger.info(f'updating stored cursor to {op.seq}')
cursor.cursor = op.seq
cursor.put()
count += 1
if limit is not None and count >= limit:
return
assert False, "handle thread shouldn't reach here!"
# Flask app
app = Flask(__name__)
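# health checks, eg from App Engine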
@app.get('/liveness_check')
@app.get('/readiness_check')
def health_check():
return 'OK'
if LOCAL_SERVER or not DEBUG:
threads = [t.name for t in threading.enumerate()]
assert 'atproto_firehose.subscriber' not in threads
assert 'atproto_firehose.handler' not in threads
Thread(target=subscriber, name='atproto_firehose.subscriber').start()
Thread(target=handler, name='atproto_firehose.handler').start()