From f6575cca4f114021c226eadca408fad1220ec335 Mon Sep 17 00:00:00 2001 From: Ryan Barrett Date: Sun, 19 May 2024 13:47:13 -0700 Subject: [PATCH] atproto_firehose.subscribe: skip bad commits fixes #1061 --- atproto_firehose.py | 33 ++++++++++++++++++++++++--------- tests/test_atproto_firehose.py | 33 ++++++++++++++++++++++++++++----- 2 files changed, 52 insertions(+), 14 deletions(-) diff --git a/atproto_firehose.py b/atproto_firehose.py index bcd5af5..9a60583 100644 --- a/atproto_firehose.py +++ b/atproto_firehose.py @@ -42,6 +42,10 @@ Op = namedtuple('Op', ['action', 'repo', 'path', 'seq', 'record'], # contains Ops new_commits = SimpleQueue() +# global so that subscribe can reuse it across calls +subscribe_cursor = None + +# global: _load_dids populates them, subscribe and handle use them atproto_dids = set() atproto_loaded_at = datetime(1900, 1, 1) bridged_dids = set() @@ -77,7 +81,7 @@ def _load_dids(): bridged_dids.update(new_bridged) dids_initialized.set() - logger.info(f'DIDs: ATProto {len(atproto_dids)} (+{len(new_atproto)}, AtpRepo {len(bridged_dids)} (+{len(new_bridged)})') + logger.info(f'DIDs: ATProto {len(atproto_dids)} (+{len(new_atproto)}), AtpRepo {len(bridged_dids)} (+{len(new_bridged)})') def subscriber(): @@ -97,7 +101,6 @@ def subscriber(): report_exception() - def subscribe(): """Subscribes to the relay's firehose. @@ -108,14 +111,17 @@ def subscribe(): """ load_dids() - cursor = Cursor.get_by_id( - f'{os.environ["BGS_HOST"]} com.atproto.sync.subscribeRepos') - assert cursor + global subscribe_cursor + if not subscribe_cursor: + cursor = Cursor.get_by_id( + f'{os.environ["BGS_HOST"]} com.atproto.sync.subscribeRepos') + assert cursor + subscribe_cursor = cursor.cursor + 1 if cursor.cursor else None client = Client(f'https://{os.environ["BGS_HOST"]}') - sub_cursor = cursor.cursor + 1 if cursor.cursor else None - for header, payload in client.com.atproto.sync.subscribeRepos(cursor=sub_cursor): + for header, payload in client.com.atproto.sync.subscribeRepos( + cursor=subscribe_cursor): # parse header if header.get('op') == -1: logger.warning(f'Got error from relay! {payload}') @@ -130,6 +136,16 @@ def subscribe(): repo = payload.get('repo') assert repo + + seq = payload.get('seq') + if not seq: + logger.warning(f'Payload missing seq! {payload}') + continue + + # if we fail processing this commit and raise an exception up to subscriber, + # skip it and start with the next commit when we're restarted + subscribe_cursor = seq + 1 + # ops = ' '.join(f'{op.get("action")} {op.get("path")}' # for op in payload.get('ops', [])) # logger.info(f'seeing {payload.get("seq")} {repo} {ops}') @@ -147,7 +163,7 @@ def subscribe(): # reposts, mentions for p_op in payload.get('ops', []): op = Op(repo=repo, action=p_op.get('action'), path=p_op.get('path'), - seq=payload.get('seq')) + seq=seq) if not op.action or not op.path: logger.info( f'bad payload! seq {op.seq} has action {op.action} path {op.path}!') @@ -319,4 +335,3 @@ if LOCAL_SERVER or not DEBUG: Thread(target=subscriber, name='atproto_firehose.subscriber').start() Thread(target=handler, name='atproto_firehose.handler').start() - diff --git a/tests/test_atproto_firehose.py b/tests/test_atproto_firehose.py index 06e9a27..6b42e5c 100644 --- a/tests/test_atproto_firehose.py +++ b/tests/test_atproto_firehose.py @@ -125,7 +125,7 @@ class ATProtoFirehoseSubscribeTest(TestCase): subscribe() self.assertTrue(new_commits.empty()) - def test_error(self): + def test_error_message(self): FakeWebsocketClient.to_receive = [( {'op': -1}, {'error': 'ConsumerTooSlow', 'message': 'ketchup!'}, @@ -134,7 +134,7 @@ class ATProtoFirehoseSubscribeTest(TestCase): subscribe() self.assertTrue(new_commits.empty()) - def test_info(self): + def test_info_message(self): FakeWebsocketClient.to_receive = [( {'op': 1, 't': '#info'}, {'name': 'OutdatedCursor'}, @@ -356,6 +356,29 @@ class ATProtoFirehoseSubscribeTest(TestCase): 'subject': {'uri': 'at://did:alice/app.bsky.feed.post/tid'}, }) + def test_uncaught_exception_skips_commit(self): + self.cursor.cursor = 1 + self.cursor.put() + + FakeWebsocketClient.setup_receive( + Op(repo='did:x', action='create', path='y', seq=4, record={'foo': 'bar'})) + with patch('atproto_firehose.read_car', side_effect=RuntimeError('oops')), \ + self.assertRaises(RuntimeError): + subscribe() + + self.assertTrue(new_commits.empty()) + self.assertEqual( + 'https://bgs.local/xrpc/com.atproto.sync.subscribeRepos?cursor=2', + FakeWebsocketClient.url) + + self.assert_enqueues(action='update', record={ + '$type': 'app.bsky.feed.like', + 'subject': {'uri': 'at://did:alice/app.bsky.feed.post/tid'}, + }) + self.assertEqual( + 'https://bgs.local/xrpc/com.atproto.sync.subscribeRepos?cursor=5', + FakeWebsocketClient.url) + @patch('oauth_dropins.webutil.appengine_config.tasks_client.create_task') class ATProtoFirehoseHandleTest(TestCase): @@ -375,7 +398,7 @@ class ATProtoFirehoseHandleTest(TestCase): atproto_firehose.bridged_dids = None atproto_firehose.dids_initialized.clear() - def test_handle_create(self, mock_create_task): + def test_create(self, mock_create_task): reply = copy.deepcopy(REPLY_BSKY) # test that we encode CIDs and bytes as JSON reply['reply']['root']['cid'] = reply['reply']['parent']['cid'] = A_CID @@ -397,7 +420,7 @@ class ATProtoFirehoseHandleTest(TestCase): self.assert_task(mock_create_task, 'receive', '/queue/receive', obj=obj.key.urlsafe(), authed_as='did:plc:user') - def test_handle_delete(self, mock_create_task): + def test_delete(self, mock_create_task): new_commits.put(Op(repo='did:plc:user', action='delete', seq=789, path='app.bsky.feed.post/123', record=POST_BSKY)) @@ -417,7 +440,7 @@ class ATProtoFirehoseHandleTest(TestCase): self.assert_task(mock_create_task, 'receive', '/queue/receive', obj=obj.key.urlsafe(), authed_as='did:plc:user') - def test_handle_store_cursor(self, mock_create_task): + def test_store_cursor(self, mock_create_task): now = None def _now(tz=None): assert tz is None