kopia lustrzana https://github.com/snarfed/bridgy-fed
				
				
				
			
		
			
				
	
	
		
			368 wiersze
		
	
	
		
			14 KiB
		
	
	
	
		
			Python
		
	
	
			
		
		
	
	
			368 wiersze
		
	
	
		
			14 KiB
		
	
	
	
		
			Python
		
	
	
"""ATProto firehose client. Enqueues receive tasks for events for bridged users.
 | 
						|
 | 
						|
https://atproto.com/specs/event-stream
 | 
						|
https://atproto.com/specs/sync#firehose
 | 
						|
"""
 | 
						|
from collections import namedtuple
 | 
						|
from datetime import datetime, timedelta
 | 
						|
from io import BytesIO
 | 
						|
import itertools
 | 
						|
import logging
 | 
						|
import os
 | 
						|
from queue import Queue
 | 
						|
from threading import Event, Lock, Thread, Timer
 | 
						|
import threading
 | 
						|
import time
 | 
						|
 | 
						|
from arroba.datastore_storage import AtpRepo
 | 
						|
from arroba.util import parse_at_uri
 | 
						|
import dag_cbor
 | 
						|
import dag_json
 | 
						|
from google.cloud import ndb
 | 
						|
from google.cloud.ndb.exceptions import ContextError
 | 
						|
from granary.bluesky import AT_URI_PATTERN
 | 
						|
from lexrpc.client import Client
 | 
						|
import libipld
 | 
						|
from oauth_dropins.webutil import util
 | 
						|
from oauth_dropins.webutil.appengine_config import ndb_client
 | 
						|
from oauth_dropins.webutil.appengine_info import DEBUG
 | 
						|
from oauth_dropins.webutil.util import json_dumps, json_loads
 | 
						|
 | 
						|
from atproto import ATProto, Cursor
 | 
						|
from common import (
 | 
						|
    create_task,
 | 
						|
    NDB_CONTEXT_KWARGS,
 | 
						|
    PROTOCOL_DOMAINS,
 | 
						|
    report_error,
 | 
						|
    report_exception,
 | 
						|
    USER_AGENT,
 | 
						|
)
 | 
						|
from protocol import DELETE_TASK_DELAY
 | 
						|
from web import Web
 | 
						|
 | 
						|
logger = logging.getLogger(__name__)

# how long subscriber() sleeps before reconnecting after subscribe() ends
RECONNECT_DELAY = timedelta(seconds=30)
# minimum interval between datastore writes of the firehose cursor; _load_dids
# also reuses it as its reload interval
STORE_CURSOR_FREQ = timedelta(seconds=10)

# One commit operation, modeled loosely on arroba.repo.Write. Only action and
# repo are required; path, seq, record, and time default to None. record stays
# None for deletes.
Op = namedtuple(
    'Op',
    ('action', 'repo', 'path', 'seq', 'record', 'time'),
    defaults=(None,) * 4,
)

# Ops handed from subscribe (producer) to handle (consumer).
#
# The bound matters: once 1000 Ops are waiting, subscribe's put() blocks until
# handle drains some, which keeps subscribe from racing ahead of handle and
# growing this queue without limit.
commits = Queue(maxsize=1000)

# the stored Cursor entity; module-level so repeated subscribe() calls reuse it
cursor = None

# DID sets and their load watermarks. _load_dids fills these in; subscribe
# reads them to decide which events to enqueue.
atproto_dids = set()
atproto_loaded_at = datetime(1900, 1, 1)
bridged_dids = set()
bridged_loaded_at = datetime(1900, 1, 1)
protocol_bot_dids = set()
dids_initialized = Event()
 | 
						|
 | 
						|
 | 
						|
def load_dids():
    """Loads the DID sets and blocks until :func:`_load_dids` signals it's done.

    The work runs on a fresh thread because :func:`_load_dids` opens its own
    NDB context, and it may also be re-run later from a timer thread.
    """
    loader = Thread(target=_load_dids)
    loader.start()

    # _load_dids sets this once it has populated the DID sets
    dids_initialized.wait()
    # reset it so the event doesn't stay set for a later wait
    dids_initialized.clear()
 | 
						|
 | 
						|
 | 
						|
def _load_dids():
    """Incrementally loads bridged DIDs from the datastore into module globals.

    * Adds to ``atproto_dids`` the keys of ATProto users with no status, with
      enabled protocols, and updated after ``atproto_loaded_at``.
    * Adds to ``bridged_dids`` the keys of AtpRepos with no status created
      after ``bridged_loaded_at``.
    * Until at least one is found, loads ``protocol_bot_dids`` from the
      ATProto copies of the :const:`PROTOCOL_DOMAINS` Web users.

    On success, sets ``dids_initialized`` and logs totals. Outside of
    ``DEBUG``, also schedules itself to run again after ``STORE_CURSOR_FREQ``
    with a :class:`threading.Timer`.
    """
    global atproto_dids, atproto_loaded_at, bridged_dids, bridged_loaded_at

    with ndb_client.context(**NDB_CONTEXT_KWARGS):
        if not DEBUG:
            # schedule the next incremental reload up front, so that it still
            # happens even if this run raises
            Timer(STORE_CURSOR_FREQ.total_seconds(), _load_dids).start()

        atproto_query = ATProto.query(ATProto.status == None,
                                      ATProto.enabled_protocols != None,
                                      ATProto.updated > atproto_loaded_at)
        # The newest updated timestamp becomes the next watermark. .get()
        # returns None when there are no ATProto entities yet. In that case
        # keep the old watermark instead of raising AttributeError, which would
        # also leave dids_initialized unset and load_dids() waiting.
        newest = ATProto.query().order(-ATProto.updated).get()
        loaded_at = newest.updated if newest is not None else atproto_loaded_at
        new_atproto = [key.id() for key in atproto_query.iter(keys_only=True)]
        atproto_dids.update(new_atproto)
        # set *after* we populate atproto_dids so that if we crash earlier, we
        # re-query from the earlier timestamp
        atproto_loaded_at = loaded_at

        bridged_query = AtpRepo.query(AtpRepo.status == None,
                                      AtpRepo.created > bridged_loaded_at)
        # same watermark logic as above, keyed on AtpRepo.created
        newest = AtpRepo.query().order(-AtpRepo.created).get()
        loaded_at = newest.created if newest is not None else bridged_loaded_at
        new_bridged = [key.id() for key in bridged_query.iter(keys_only=True)]
        bridged_dids.update(new_bridged)
        # set *after* we populate bridged_dids so that if we crash earlier, we
        # re-query from the earlier timestamp
        bridged_loaded_at = loaded_at

        # retried on every run until at least one bot DID has been found
        if not protocol_bot_dids:
            bot_keys = [Web(id=domain).key for domain in PROTOCOL_DOMAINS]
            for bot in ndb.get_multi(bot_keys):
                if bot:
                    if did := bot.get_copy(ATProto):
                        logger.info(f'Loaded protocol bot user {bot.key.id()} {did}')
                        protocol_bot_dids.add(did)

        dids_initialized.set()
        total = len(atproto_dids) + len(bridged_dids)
        logger.info(f'DIDs: {total} ATProto {len(atproto_dids)} (+{len(new_atproto)}), AtpRepo {len(bridged_dids)} (+{len(new_bridged)}); commits {commits.qsize()}')
 | 
						|
 | 
						|
 | 
						|
def subscriber():
    """Runs :func:`subscribe` forever, reconnecting after errors or disconnects.

    Loads the DID sets first, then keeps one NDB context open for the life of
    the loop. Any exception from :func:`subscribe`, including
    ``BaseException`` subclasses, is reported with :func:`report_exception`.
    Whether it raised or returned, we wait ``RECONNECT_DELAY`` and subscribe
    again.
    """
    host = os.environ["BGS_HOST"]
    logger.info(f'started thread to subscribe to {host} firehose')
    load_dids()

    delay_s = RECONNECT_DELAY.total_seconds()
    with ndb_client.context(**NDB_CONTEXT_KWARGS):
        while True:
            try:
                subscribe()
            except BaseException:
                report_exception()

            logger.info(f'disconnected! waiting {RECONNECT_DELAY} and then reconnecting')
            time.sleep(delay_s)
 | 
						|
 | 
						|
 | 
						|
def subscribe():
    """Subscribes to the relay's firehose and enqueues relevant events.

    Relay hostname comes from the ``BGS_HOST`` environment variable. Reads
    frames until the relay's stream ends, then returns. Exceptions propagate
    to :func:`subscriber`, which reconnects.

    * ``#account`` and ``#identity`` events for DIDs in ``atproto_dids`` or
      ``bridged_dids`` are enqueued as ``Op(action='account', ...)``.
    * ``#commit`` ops are enqueued when they're deletes by users in
      ``atproto_dids``, or supported records that pass
      :func:`_should_enqueue`.

    Ops go onto the ``commits`` queue, whose ``put`` blocks when it's full.
    The cursor is written to the datastore at most every
    ``STORE_CURSOR_FREQ``.
    """
    global cursor
    if not cursor:
        cursor = Cursor.get_or_insert(
            f'{os.environ["BGS_HOST"]} com.atproto.sync.subscribeRepos')
        # TODO: remove? does this make us skip events? if we remove it, will we
        # infinite loop when we fail on an event?
        if cursor.cursor:
            cursor.cursor += 1

    last_stored_cursor = cur_timestamp = None

    client = Client(f'https://{os.environ["BGS_HOST"]}',
                    headers={'User-Agent': USER_AGENT})

    for frame in client.com.atproto.sync.subscribeRepos(decode=False,
                                                        cursor=cursor.cursor):
        # parse header
        header = libipld.decode_dag_cbor(frame)
        if header.get('op') == -1:
            _, payload = libipld.decode_dag_cbor_multi(frame)
            logger.warning(f'Got error from relay! {payload}')
            continue

        t = header.get('t')

        if t not in ('#commit', '#account', '#identity'):
            if t not in ('#handle', '#tombstone'):
                logger.info(f'Got {t} from relay')
            continue

        # parse payload
        _, payload = libipld.decode_dag_cbor_multi(frame)
        repo = payload.get('repo') or payload.get('did')
        if not repo:
            logger.warning(f'Payload missing repo! {payload}')
            continue

        seq = payload.get('seq')
        if not seq:
            logger.warning(f'Payload missing seq! {payload}')
            continue

        cur_timestamp = payload['time']

        # if we fail processing this commit and raise an exception up to subscriber,
        # skip it and start with the next commit when we're restarted
        cursor.cursor = seq + 1

        elapsed = util.now().replace(tzinfo=None) - cursor.updated
        if elapsed > STORE_CURSOR_FREQ:
            events_s = 0
            if last_stored_cursor:
                events_s = int((cursor.cursor - last_stored_cursor) /
                               elapsed.total_seconds())
            last_stored_cursor = cursor.cursor

            behind = util.now() - util.parse_iso8601(cur_timestamp)

            # it's been long enough, update our stored cursor and metrics
            logger.info(f'updating stored cursor to {cursor.cursor}, {events_s} events/s, {behind} ({int(behind.total_seconds())} s) behind')
            cursor.put()
            # when running locally, comment out put above and uncomment this
            # cursor.updated = util.now().replace(tzinfo=None)

        if t in ('#account', '#identity'):
            if repo in atproto_dids or repo in bridged_dids:
                logger.debug(f'Got {t[1:]} {repo}')
                commits.put(Op(action='account', repo=repo, seq=seq,
                               time=cur_timestamp))
            continue

        blocks = {}  # maps base32 str CID to dict block
        if block_bytes := payload.get('blocks'):
            _, blocks = libipld.decode_car(block_bytes)

        # detect records from bridged ATProto users that we should handle
        for p_op in payload.get('ops', []):
            op = Op(repo=repo, action=p_op.get('action'),
                    path=p_op.get('path'), seq=seq, time=cur_timestamp)
            if not op.action or not op.path:
                logger.info(
                    f'bad payload! seq {op.seq} action {op.action} path {op.path}!')
                continue

            if op.repo in atproto_dids and op.action == 'delete':
                # TODO: also detect deletes of records that *reference* our bridged
                # users, eg a delete of a follow or like or repost of them.
                # not easy because we need to getRecord the record to check
                commits.put(op)
                continue

            cid = p_op.get('cid')
            block = blocks.get(cid)
            # our own commits are sometimes missing the record
            # https://github.com/snarfed/bridgy-fed/issues/1016
            if not cid or not block:
                continue

            op = op._replace(record=block)
            # records should be maps. treat anything else like a record with
            # no $type instead of crashing the whole firehose subscription.
            record_type = block.get('$type') if isinstance(block, dict) else None
            if not record_type:
                logger.warning(f'commit record missing $type! {op.action} {op.repo} {op.path} {cid}')
                logger.warning(dag_json.encode(op.record).decode())
                continue
            elif record_type not in ATProto.SUPPORTED_RECORD_TYPES:
                continue

            if _should_enqueue(op, record_type):
                commits.put(op)


def _is_ours(ref, also_atproto_users=False):
    """Returns True if a strong ref points into a bridged user's repo.

    Args:
      ref (dict): strong ref, eg a like's ``subject`` or a reply's ``parent``,
        whose ``uri`` field is an AT URI. Malformed refs (not a dict, or a
        missing or non-string ``uri``) return False instead of raising.
      also_atproto_users (bool): whether DIDs in ``atproto_dids`` also count

    Returns:
      bool
    """
    if not isinstance(ref, dict):
        return False

    uri = ref.get('uri')
    if not isinstance(uri, str):
        return False

    if match := AT_URI_PATTERN.match(uri):
        did = match.group('repo')
        return bool(did) and (did in bridged_dids
                              or (also_atproto_users and did in atproto_dids))

    return False


def _should_enqueue(op, record_type):
    """Returns True if a create/update commit op is relevant to the bridge.

    Args:
      op (Op): with ``record`` populated
      record_type (str): the record's ``$type``, already known to be supported

    Returns:
      bool
    """
    record = op.record
    subject = record.get('subject')
    # Follows and blocks have a DID string subject; likes and reposts have a
    # strong ref dict. Only use it in set lookups when it's a string, since a
    # dict isn't hashable.
    subject_did = subject if isinstance(subject, str) else None

    # generally we only want records from bridged Bluesky users. the one
    # exception is follows of protocol bot users.
    if (op.repo not in atproto_dids
            and not (record_type == 'app.bsky.graph.follow'
                     and subject_did in protocol_bot_dids)):
        return False

    if record_type == 'app.bsky.feed.repost':
        return _is_ours(subject, also_atproto_users=True)

    elif record_type == 'app.bsky.feed.like':
        return _is_ours(subject, also_atproto_users=False)

    elif record_type in ('app.bsky.graph.block', 'app.bsky.graph.follow'):
        return subject_did in bridged_dids

    elif record_type == 'app.bsky.feed.post':
        if reply := record.get('reply'):
            parent = reply.get('parent') if isinstance(reply, dict) else None
            return _is_ours(parent, also_atproto_users=True)

    return True
 | 
						|
 | 
						|
 | 
						|
def handler():
    """Runs :func:`handle` in a loop, restarting it after any exception.

    Each attempt gets a fresh NDB context, so a broken context (eg after a
    ``ContextError``) isn't reused. Exceptions, including ``BaseException``
    subclasses, are reported with :func:`report_exception`. Returns only once
    :func:`handle` returns normally.
    """
    logger.info('started handle thread to store objects and enqueue receive tasks')

    finished = False
    while not finished:
        # new NDB context per attempt in case the last failure was a ContextError
        # https://console.cloud.google.com/errors/detail/CIvwj_7MmsfOWw;time=P1D;locations=global?project=bridgy-federated
        with ndb_client.context(**NDB_CONTEXT_KWARGS):
            try:
                handle()
            except BaseException:
                report_exception()
            else:
                # a clean return means handle() hit its limit
                finished = True
 | 
						|
 | 
						|
 | 
						|
def handle(limit=None):
 | 
						|
    def _handle_account(op):
 | 
						|
        # reload DID doc to fetch new changes
 | 
						|
        ATProto.load(op.repo, did_doc=True, remote=True)
 | 
						|
 | 
						|
    def _handle(op):
 | 
						|
        at_uri = f'at://{op.repo}/{op.path}'
 | 
						|
 | 
						|
        type, _ = op.path.strip('/').split('/', maxsplit=1)
 | 
						|
        if type not in ATProto.SUPPORTED_RECORD_TYPES:
 | 
						|
            logger.info(f'Skipping unsupported type {type}: {at_uri}')
 | 
						|
            return
 | 
						|
 | 
						|
        # store object, enqueue receive task
 | 
						|
        verb = None
 | 
						|
        if op.action in ('create', 'update'):
 | 
						|
            record_kwarg = {
 | 
						|
                'bsky': op.record,
 | 
						|
            }
 | 
						|
            obj_id = at_uri
 | 
						|
        elif op.action == 'delete':
 | 
						|
            verb = (
 | 
						|
                'delete' if type in ('app.bsky.actor.profile', 'app.bsky.feed.post')
 | 
						|
                else 'stop-following' if type == 'app.bsky.graph.follow'
 | 
						|
                else 'undo')
 | 
						|
            obj_id = f'{at_uri}#{verb}'
 | 
						|
            record_kwarg = {
 | 
						|
                'our_as1': {
 | 
						|
                    'objectType': 'activity',
 | 
						|
                    'verb': verb,
 | 
						|
                    'id': obj_id,
 | 
						|
                    'actor': op.repo,
 | 
						|
                    'object': at_uri,
 | 
						|
                },
 | 
						|
            }
 | 
						|
        else:
 | 
						|
            logger.error(f'Unknown action {op.action} for {op.repo} {op.path}')
 | 
						|
            return
 | 
						|
 | 
						|
        if verb and verb not in ATProto.SUPPORTED_AS1_TYPES:
 | 
						|
            return
 | 
						|
 | 
						|
        logger.debug(f'Got {op.action} {op.repo} {op.path}')
 | 
						|
        delay = DELETE_TASK_DELAY if op.action == 'delete' else None
 | 
						|
        try:
 | 
						|
            create_task(queue='receive', id=obj_id, source_protocol=ATProto.LABEL,
 | 
						|
                        authed_as=op.repo, received_at=op.time, delay=delay,
 | 
						|
                        **record_kwarg)
 | 
						|
            # when running locally, comment out above and uncomment this
 | 
						|
            # logger.info(f'enqueuing receive task for {at_uri}')
 | 
						|
        except ContextError:
 | 
						|
            raise  # handled in handle()
 | 
						|
        except BaseException:
 | 
						|
            report_error(obj_id, exception=True)
 | 
						|
 | 
						|
    seen = 0
 | 
						|
    while op := commits.get():
 | 
						|
        match op.action:
 | 
						|
            case 'account':
 | 
						|
                _handle_account(op)
 | 
						|
            case _:
 | 
						|
                _handle(op)
 | 
						|
 | 
						|
        seen += 1
 | 
						|
        if limit is not None and seen >= limit:
 | 
						|
            return
 | 
						|
 | 
						|
    assert False, "handle thread shouldn't reach here!"
 |