diff --git a/protocol.py b/protocol.py index fa847e3..0169fb2 100644 --- a/protocol.py +++ b/protocol.py @@ -45,7 +45,19 @@ objects_cache_lock = threading.Lock() logger = logging.getLogger(__name__) -class Protocol: +# maps string label to Protocol subclass. populated by ProtocolMeta. +protocols = {} + +class ProtocolMeta(type): + """:class:`Protocol` metaclass. Registers all subclasses in the protocols global.""" + def __new__(meta, name, bases, class_dict): + cls = super().__new__(meta, name, bases, class_dict) + if cls.LABEL: + protocols[cls.LABEL] = cls + return cls + + +class Protocol(metaclass=ProtocolMeta): """Base protocol class. Not to be instantiated; classmethods only. Attributes: diff --git a/tests/test_protocol.py b/tests/test_protocol.py index 0366a6f..312bcc0 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -5,9 +5,11 @@ from flask import g from oauth_dropins.webutil.testutil import requests_response import requests +import protocol from protocol import Protocol from flask_app import app from models import Follower, Object, User +from webmention import Webmention from .test_activitypub import ACTOR, REPLY from . import testutil @@ -35,6 +37,10 @@ class ProtocolTest(testutil.TestCase): self.request_context.pop() super().tearDown() + def test_protocols_global(self): + self.assertEqual(FakeProtocol, protocol.protocols['fake']) + self.assertEqual(Webmention, protocol.protocols['webmention']) + @patch('requests.get') def test_receive_reply_not_feed_not_notification(self, mock_get): Follower.get_or_create(ACTOR['id'], 'foo.com')