"""ATProto firehose client. Enqueues receive tasks for events for bridged users."""
from collections import namedtuple
from datetime import datetime, timedelta
from io import BytesIO
import itertools
import logging
import os
from queue import Queue
from threading import Event, Lock, Thread, Timer
import threading
import time

from arroba.datastore_storage import AtpRepo
from arroba.util import parse_at_uri
import dag_cbor
import dag_json
from google.cloud import ndb
from google.cloud.ndb.exceptions import ContextError
from granary.bluesky import AT_URI_PATTERN
from lexrpc.client import Client
import libipld
from oauth_dropins.webutil import util
from oauth_dropins.webutil.appengine_config import ndb_client
from oauth_dropins.webutil.appengine_info import DEBUG
from oauth_dropins.webutil.util import json_dumps, json_loads

from atproto import ATProto, Cursor
from common import (
    create_task,
    NDB_CONTEXT_KWARGS,
    PROTOCOL_DOMAINS,
    report_error,
    report_exception,
    USER_AGENT,
)
from protocol import DELETE_TASK_DELAY
from web import Web

logger = logging.getLogger(__name__)
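
# Hypothetical usage sketch, for orientation only: this module defines the two
# long-running thread bodies below; actual thread startup lives elsewhere in
# the app (an assumption, not shown in this module):
#   Thread(target=subscriber, name='atproto-firehose-subscriber').start()
#   Thread(target=handler, name='atproto-firehose-handler').start()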

RECONNECT_DELAY = timedelta(seconds=30)
STORE_CURSOR_FREQ = timedelta(seconds=10)

# a commit operation. similar to arroba.repo.Write. record is None for deletes.
Op = namedtuple('Op', ['action', 'repo', 'path', 'seq', 'record', 'time'],
                # last four fields are optional
                defaults=[None, None, None, None])
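
# Illustrative example with hypothetical values: a create commit might arrive as
#   Op(action='create', repo='did:plc:abc123', path='app.bsky.feed.post/3kxyz',
#      seq=12345, record={'$type': 'app.bsky.feed.post', 'text': 'hi'},
#      time='2024-01-01T00:00:00.000Z')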

# contains Ops
#
# maxsize is important here! if we hit this limit, subscribe will block when it
# tries to add more commits until handle consumes some. this keeps subscribe
# from getting too far ahead of handle and using too much memory in this queue.
commits = Queue(maxsize=1000)

# global so that subscribe can reuse it across calls
cursor = None

# global: _load_dids populates them, subscribe and handle use them
atproto_dids = set()
atproto_loaded_at = datetime(1900, 1, 1)
bridged_dids = set()
bridged_loaded_at = datetime(1900, 1, 1)
protocol_bot_dids = set()
dids_initialized = Event()


def load_dids():
    """Kicks off the initial DID load and blocks until it finishes."""
    # run in a separate thread since it needs to make its own NDB
    # context when it runs in the timer thread
    Thread(target=_load_dids).start()
    dids_initialized.wait()
    dids_initialized.clear()


def _load_dids():
    """Loads new bridged user DIDs into the in-memory sets, then re-arms itself."""
    global atproto_dids, atproto_loaded_at, bridged_dids, bridged_loaded_at

    with ndb_client.context(**NDB_CONTEXT_KWARGS):
        if not DEBUG:
            Timer(STORE_CURSOR_FREQ.total_seconds(), _load_dids).start()

        atproto_query = ATProto.query(ATProto.status == None,
                                      ATProto.enabled_protocols != None,
                                      ATProto.updated > atproto_loaded_at)
        loaded_at = ATProto.query().order(-ATProto.updated).get().updated
        new_atproto = [key.id() for key in atproto_query.iter(keys_only=True)]
        atproto_dids.update(new_atproto)
        # set *after* we populate atproto_dids so that if we crash earlier, we
        # re-query from the earlier timestamp
        atproto_loaded_at = loaded_at

        bridged_query = AtpRepo.query(AtpRepo.status == None,
                                      AtpRepo.created > bridged_loaded_at)
        loaded_at = AtpRepo.query().order(-AtpRepo.created).get().created
        new_bridged = [key.id() for key in bridged_query.iter(keys_only=True)]
        bridged_dids.update(new_bridged)
        # set *after* we populate bridged_dids so that if we crash earlier, we
        # re-query from the earlier timestamp
        bridged_loaded_at = loaded_at

        if not protocol_bot_dids:
            bot_keys = [Web(id=domain).key for domain in PROTOCOL_DOMAINS]
            for bot in ndb.get_multi(bot_keys):
                if bot:
                    if did := bot.get_copy(ATProto):
                        logger.info(f'Loaded protocol bot user {bot.key.id()} {did}')
                        protocol_bot_dids.add(did)

    dids_initialized.set()
    total = len(atproto_dids) + len(bridged_dids)
    logger.info(f'DIDs: {total} ATProto {len(atproto_dids)} (+{len(new_atproto)}), AtpRepo {len(bridged_dids)} (+{len(new_bridged)}); commits {commits.qsize()}')


def subscriber():
    """Wrapper around :func:`subscribe` that catches exceptions and reconnects."""
    logger.info(f'started thread to subscribe to {os.environ["BGS_HOST"]} firehose')
    load_dids()

    with ndb_client.context(**NDB_CONTEXT_KWARGS):
        while True:
            try:
                subscribe()
            except BaseException:
                report_exception()
            logger.info(f'disconnected! waiting {RECONNECT_DELAY} and then reconnecting')
            time.sleep(RECONNECT_DELAY.total_seconds())


def subscribe():
    """Subscribes to the relay's firehose.

    Relay hostname comes from the ``BGS_HOST`` environment variable.
    """
    global cursor
    if not cursor:
        cursor = Cursor.get_or_insert(
            f'{os.environ["BGS_HOST"]} com.atproto.sync.subscribeRepos')
        # TODO: remove? does this make us skip events? if we remove it, will we
        # infinite loop when we fail on an event?
        if cursor.cursor:
            cursor.cursor += 1

    last_stored_cursor = cur_timestamp = None

    client = Client(f'https://{os.environ["BGS_HOST"]}',
                    headers={'User-Agent': USER_AGENT})

    for frame in client.com.atproto.sync.subscribeRepos(decode=False,
                                                        cursor=cursor.cursor):
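        # per the ATProto event stream spec, each frame is two concatenated
        # DAG-CBOR objects: a header like {'op': 1, 't': '#commit'} (op -1
        # signals an error frame), then the event payload itself.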
        # parse header
        header = libipld.decode_dag_cbor(frame)
        if header.get('op') == -1:
            _, payload = libipld.decode_dag_cbor_multi(frame)
            logger.warning(f'Got error from relay! {payload}')
            continue

        t = header.get('t')

        if t not in ('#commit', '#account', '#identity'):
            if t not in ('#handle', '#tombstone'):
                logger.info(f'Got {t} from relay')
            continue

        # parse payload
        _, payload = libipld.decode_dag_cbor_multi(frame)
        repo = payload.get('repo') or payload.get('did')
        if not repo:
            logger.warning(f'Payload missing repo! {payload}')
            continue

        seq = payload.get('seq')
        if not seq:
            logger.warning(f'Payload missing seq! {payload}')
            continue

        cur_timestamp = payload['time']

        # if we fail processing this commit and raise an exception up to subscriber,
        # skip it and start with the next commit when we're restarted
        cursor.cursor = seq + 1

        elapsed = util.now().replace(tzinfo=None) - cursor.updated
        if elapsed > STORE_CURSOR_FREQ:
            events_s = 0
            if last_stored_cursor:
                events_s = int((cursor.cursor - last_stored_cursor) /
                               elapsed.total_seconds())
            last_stored_cursor = cursor.cursor

            behind = util.now() - util.parse_iso8601(cur_timestamp)

            # it's been long enough, update our stored cursor and metrics
            logger.info(f'updating stored cursor to {cursor.cursor}, {events_s} events/s, {behind} ({int(behind.total_seconds())} s) behind')
            cursor.put()
            # when running locally, comment out put above and uncomment this
            # cursor.updated = util.now().replace(tzinfo=None)
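
        # per the ATProto event stream spec, #account events signal account
        # status changes (eg deactivation) and #identity events signal handle
        # or DID document changes.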
        if t in ('#account', '#identity'):
            if repo in atproto_dids or repo in bridged_dids:
                logger.debug(f'Got {t[1:]} {repo}')
                commits.put(Op(action='account', repo=repo, seq=seq,
                               time=cur_timestamp))
            continue

        blocks = {}  # maps base32 str CID to dict block
        if block_bytes := payload.get('blocks'):
            _, blocks = libipld.decode_car(block_bytes)

        # detect records from bridged ATProto users that we should handle
        for p_op in payload.get('ops', []):
            op = Op(repo=payload['repo'], action=p_op.get('action'),
                    path=p_op.get('path'), seq=payload['seq'], time=payload['time'])
            if not op.action or not op.path:
                logger.info(
                    f'bad payload! seq {op.seq} action {op.action} path {op.path}!')
                continue

            if op.repo in atproto_dids and op.action == 'delete':
                logger.debug(f'Got delete from our ATProto user: {op}')
                # TODO: also detect deletes of records that *reference* our bridged
                # users, eg a delete of a follow or like or repost of them.
                # not easy because we need to getRecord the record to check
                commits.put(op)
                continue

            cid = p_op.get('cid')
            block = blocks.get(cid)
            # our own commits are sometimes missing the record
            # https://github.com/snarfed/bridgy-fed/issues/1016
            if not cid or not block:
                continue

            op = op._replace(record=block)
            type = op.record.get('$type')
            if not type:
                logger.warning(f'commit record missing $type! {op.action} {op.repo} {op.path} {cid}')
                logger.warning(dag_json.encode(op.record).decode())
                continue
            elif type not in ATProto.SUPPORTED_RECORD_TYPES:
                continue

            # generally we only want records from bridged Bluesky users. the one
            # exception is follows of protocol bot users.
            if (op.repo not in atproto_dids
                    and not (type == 'app.bsky.graph.follow'
                             and op.record['subject'] in protocol_bot_dids)):
                continue

            def is_ours(ref, also_atproto_users=False):
                """Returns True if the referenced record's author is bridged.

                Implicitly returns None (falsy) if the at:// URI doesn't parse.
                """
                if match := AT_URI_PATTERN.match(ref['uri']):
                    did = match.group('repo')
                    return did and (did in bridged_dids
                                    or also_atproto_users and did in atproto_dids)

            if type == 'app.bsky.feed.repost':
                if not is_ours(op.record['subject'], also_atproto_users=True):
                    continue

            elif type == 'app.bsky.feed.like':
                if not is_ours(op.record['subject'], also_atproto_users=False):
                    continue

            elif type in ('app.bsky.graph.block', 'app.bsky.graph.follow'):
                if op.record['subject'] not in bridged_dids:
                    continue

            elif type == 'app.bsky.feed.post':
                if reply := op.record.get('reply'):
                    if not is_ours(reply['parent'], also_atproto_users=True):
                        continue

            logger.debug(f'Got {op.action} {op.repo} {op.path}')
            commits.put(op)


def handler():
    """Wrapper around :func:`handle` that catches exceptions and restarts."""
    logger.info('started handle thread to store objects and enqueue receive tasks')

    while True:
        with ndb_client.context(**NDB_CONTEXT_KWARGS):
            try:
                handle()
                # if we return cleanly, that means we hit the limit
                break
            except BaseException:
                report_exception()
                # fall through to loop to create new ndb context in case this is
                # a ContextError
                # https://console.cloud.google.com/errors/detail/CIvwj_7MmsfOWw;time=P1D;locations=global?project=bridgy-federated


def handle(limit=None):
    """Consumes Ops from the commits queue and enqueues receive tasks for them.

    Args:
      limit (int): return after handling this many commits; used in tests
    """
    def _handle_account(op):
        # reload DID doc to fetch new changes
        ATProto.load(op.repo, did_doc=True, remote=True)

    def _handle(op):
        at_uri = f'at://{op.repo}/{op.path}'

        type, _ = op.path.strip('/').split('/', maxsplit=1)
        if type not in ATProto.SUPPORTED_RECORD_TYPES:
            logger.info(f'Skipping unsupported type {type}: {at_uri}')
            return

        # store object, enqueue receive task
        verb = None
        if op.action in ('create', 'update'):
            record_kwarg = {
                'bsky': op.record,
            }
            obj_id = at_uri
        elif op.action == 'delete':
            verb = (
                'delete' if type in ('app.bsky.actor.profile', 'app.bsky.feed.post')
                else 'stop-following' if type == 'app.bsky.graph.follow'
                else 'undo')
            obj_id = f'{at_uri}#{verb}'
            record_kwarg = {
                'our_as1': {
                    'objectType': 'activity',
                    'verb': verb,
                    'id': obj_id,
                    'actor': op.repo,
                    'object': at_uri,
                },
            }
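            # Illustrative example with hypothetical ids: deleting
            # at://did:plc:abc/app.bsky.graph.follow/123 yields this AS1 activity:
            #   {'objectType': 'activity', 'verb': 'stop-following',
            #    'id': 'at://did:plc:abc/app.bsky.graph.follow/123#stop-following',
            #    'actor': 'did:plc:abc',
            #    'object': 'at://did:plc:abc/app.bsky.graph.follow/123'}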
        else:
            logger.error(f'Unknown action {op.action} for {op.repo} {op.path}')
            return

        if verb and verb not in ATProto.SUPPORTED_AS1_TYPES:
            return

        delay = DELETE_TASK_DELAY if op.action == 'delete' else None
        try:
            create_task(queue='receive', id=obj_id, source_protocol=ATProto.LABEL,
                        authed_as=op.repo, received_at=op.time, delay=delay,
                        **record_kwarg)
            # when running locally, comment out above and uncomment this
            # logger.info(f'enqueuing receive task for {at_uri}')
        except ContextError:
            raise  # handled in handler(), which recreates the ndb context
        except BaseException:
            report_error(obj_id, exception=True)

    seen = 0
    while op := commits.get():
        match op.action:
            case 'account':
                _handle_account(op)
            case _:
                _handle(op)

        seen += 1
        if limit is not None and seen >= limit:
            return

    assert False, "handle thread shouldn't reach here!"