From d891589092c49206c2bc86bdbe9abb31a5bffa03 Mon Sep 17 00:00:00 2001 From: halcy Date: Wed, 21 Jun 2023 00:15:40 +0300 Subject: [PATCH] fix various issues. some remain. --- mastodon/admin.py | 2 +- mastodon/internals.py | 30 ++-------- mastodon/types_base.py | 121 ++++++++++++++++++++++------------------- tests/test_hooks.py | 37 ++++++++++--- 4 files changed, 103 insertions(+), 87 deletions(-) diff --git a/mastodon/admin.py b/mastodon/admin.py index 351f8ef..ca726c6 100644 --- a/mastodon/admin.py +++ b/mastodon/admin.py @@ -380,7 +380,7 @@ class Mastodon(Internals): @api_version("4.0.0", "4.0.0", _DICT_VERSION_ADMIN_DOMAIN_BLOCK) def admin_domain_blocks(self, id: Optional[IdType] = None, max_id: Optional[IdType] = None, min_id: Optional[IdType] = None, - since_id: Optional[IdType] = None, limit: Optional[int] = None) -> PaginatableList[AdminDomainBlock]: + since_id: Optional[IdType] = None, limit: Optional[int] = None) -> Union[AdminDomainBlock, PaginatableList[AdminDomainBlock]]: """ Fetches a list of blocked domains. Requires scope `admin:read:domain_blocks`. diff --git a/mastodon/internals.py b/mastodon/internals.py index 61cac86..0984b51 100644 --- a/mastodon/internals.py +++ b/mastodon/internals.py @@ -22,7 +22,7 @@ from mastodon.errors import MastodonNetworkError, MastodonIllegalArgumentError, MastodonGatewayTimeoutError, MastodonServerError, MastodonAPIError, MastodonMalformedEventError from mastodon.compat import urlparse, magic, PurePath, Path from mastodon.defaults import _DEFAULT_STREAM_TIMEOUT, _DEFAULT_STREAM_RECONNECT_WAIT_SEC -from mastodon.types import AttribAccessDict, try_cast_recurse +from mastodon.types import AttribAccessDict, PaginatableList, try_cast_recurse from mastodon.types import * ### @@ -271,10 +271,9 @@ class Mastodon(): response = response_object.content # Parse link headers - if isinstance(response, list) and \ - 'Link' in response_object.headers and \ - response_object.headers['Link'] != "": - response = AttribAccessList(response) + if isinstance(response, list) and 'Link' in response_object.headers and response_object.headers['Link'] != "": + if not isinstance(response, PaginatableList): + response = PaginatableList(response) tmp_urls = requests.utils.parse_header_links( response_object.headers['Link'].rstrip('>').replace('>,<', ',<')) for url in tmp_urls: @@ -301,18 +300,12 @@ class Mastodon(): del next_params['min_id'] response._pagination_next = next_params - # Maybe other API users rely on the pagination info in the last item - # Will be removed in future - if isinstance(response[-1], AttribAccessDict): - response[-1]._pagination_next = next_params - if url['rel'] == 'prev': # Be paranoid and extract since_id or min_id specifically prev_url = url['url'] # Old and busted (pre-2.6.0): since_id pagination - matchgroups = re.search( - r"[?&]since_id=([^&]+)", prev_url) + matchgroups = re.search(r"[?&]since_id=([^&]+)", prev_url) if matchgroups: prev_params = copy.deepcopy(params) prev_params['_pagination_method'] = method @@ -326,14 +319,8 @@ class Mastodon(): del prev_params['max_id'] response._pagination_prev = prev_params - # Maybe other API users rely on the pagination info in the first item - # Will be removed in future - if isinstance(response[0], AttribAccessDict): - response[0]._pagination_prev = prev_params - # New and fantastico (post-2.6.0): min_id pagination - matchgroups = re.search( - r"[?&]min_id=([^&]+)", prev_url) + matchgroups = re.search(r"[?&]min_id=([^&]+)", prev_url) if matchgroups: prev_params = copy.deepcopy(params) prev_params['_pagination_method'] = method @@ -346,11 +333,6 @@ class Mastodon(): if "max_id" in prev_params: del prev_params['max_id'] response._pagination_prev = prev_params - - # Maybe other API users rely on the pagination info in the first item - # Will be removed in future - if isinstance(response[0], AttribAccessDict): - response[0]._pagination_prev = prev_params return response diff --git a/mastodon/types_base.py b/mastodon/types_base.py index a50f2fb..be07213 100644 --- a/mastodon/types_base.py +++ b/mastodon/types_base.py @@ -137,72 +137,28 @@ class MaybeSnowflakeIdType(str): """ return str(self.__val) -""" -IDs returned from Mastodon.py ar either primitive (int or str) or snowflake -(still int or str, but potentially convertible to datetime). -""" -IdType = Union[PrimitiveIdType, MaybeSnowflakeIdType] - -T = TypeVar('T') -class PaginatableList(List[T]): - """ - This is a list with pagination information attached. - - It is returned by the API when a list of items is requested, and the response contains - a Link header with pagination information. - """ - def __getattr__(self, attr): - if attr in self: - return self[attr] - else: - raise AttributeError(f"Attribute not found: {attr}") - - def __setattr__(self, attr, val): - if attr in self: - raise AttributeError("Attribute-style access is read only") - super(NonPaginatableList, self).__setattr__(attr, val) - # TODO add the pagination housekeeping stuff - -class NonPaginatableList(List[T]): - """ - This is just a list. I am subclassing the regular list out of pure paranoia about - potential oversights that might require me to add things to it later. - """ - pass - -# Lists in Mastodon.py are either regular or paginatable -EntityList = Union[NonPaginatableList[T], PaginatableList[T]] - -# Backwards compat alias -AttribAccessList = EntityList - # Helper functions for typecasting attempts def try_cast(t, value, retry = True): """ Base case casting function. Handles: * Casting to any AttribAccessDict subclass (directly, no special handling) - * Casting to MaybeSnowflakeIdType (directly, no special handling) * Casting to bool (with possible conversion from json bool strings) * Casting to datetime (with possible conversion from all kinds of funny date formats because unfortunately this is the world we live in) + * Casting to whatever t is + * Trying once again to AttribAccessDict as a fallback Gives up and returns as-is if none of the above work. """ try: - if issubclass(t, AttribAccessDict) or t is MaybeSnowflakeIdType: - try: - value = t(**value) - except: - try: - value = AttribAccessDict(**value) - except: - pass - elif isinstance(t, bool): + if issubclass(t, AttribAccessDict): + value = t(**value) + elif issubclass(t, bool): if isinstance(value, str): if value.lower() == 'true': value = True elif value.lower() == 'false': value = False value = bool(value) - elif isinstance(t, datetime): + elif issubclass(t, datetime): if isinstance(value, int): value = datetime.fromtimestamp(value, timezone.utc) elif isinstance(value, str): @@ -211,8 +167,11 @@ def try_cast(t, value, retry = True): value = datetime.fromtimestamp(value_int, timezone.utc) except: value = dateutil.parser.parse(value) - except: - value = try_cast(AttribAccessDict, value, False) + else: + value = t(**value) + except Exception as e: + if retry: + value = try_cast(AttribAccessDict, value, False) return value def try_cast_recurse(t, value): @@ -241,6 +200,38 @@ def try_cast_recurse(t, value): pass return try_cast(t, value) +""" +IDs returned from Mastodon.py ar either primitive (int or str) or snowflake +(still int or str, but potentially convertible to datetime). +""" +IdType = Union[PrimitiveIdType, MaybeSnowflakeIdType] + +T = TypeVar('T') +class PaginatableList(List[T]): + """ + This is a list with pagination information attached. + + It is returned by the API when a list of items is requested, and the response contains + a Link header with pagination information. + """ + def __init__(self, *args, **kwargs): + """ + Initializes basic list and adds empty pagination information. + """ + super(PaginatableList, self).__init__(*args, **kwargs) + self._pagination_next = None + self._pagination_prev = None + +class NonPaginatableList(List[T]): + """ + This is just a list. I am subclassing the regular list out of pure paranoia about + potential oversights that might require me to add things to it later. + """ + pass + +"""Lists in Mastodon.py are either regular or paginatable""" +EntityList = Union[NonPaginatableList[T], PaginatableList[T]] + class AttribAccessDict(OrderedDict[str, Any]): """ Base return object class for Mastodon.py. @@ -256,12 +247,12 @@ class AttribAccessDict(OrderedDict[str, Any]): Constructor that calls through to dict constructor and then sets attributes for all keys. """ super(AttribAccessDict, self).__init__() - if __annotations__ in self.__class__.__dict__: + if "__annotations__" in self.__class__.__dict__: for attr, _ in self.__class__.__annotations__.items(): attr_name = attr if hasattr(self.__class__, "_rename_map"): attr_name = getattr(self.__class__, "_rename_map").get(attr, attr) - if attr_name in kwargs: + if attr_name in kwargs: self[attr] = kwargs[attr_name] assert not attr in kwargs, f"Duplicate attribute {attr}" elif attr in kwargs: @@ -337,6 +328,23 @@ class AttribAccessDict(OrderedDict[str, Any]): super(AttribAccessDict, self).__setattr__(key, val) super(AttribAccessDict, self).__setitem__(key, val) + def __eq__(self, other): + """ + Equality checker with casting + """ + if isinstance(other, self.__class__): + return super(AttribAccessDict, self).__eq__(other) + else: + try: + casted = try_cast_recurse(self.__class__, other) + if isinstance(casted, self.__class__): + return super(AttribAccessDict, self).__eq__(casted) + else: + return False + except Exception as e: + pass + return False + """An entity returned by the Mastodon API is either a dict or a list""" Entity = Union[AttribAccessDict, EntityList] @@ -344,4 +352,7 @@ Entity = Union[AttribAccessDict, EntityList] WebpushCryptoParamsPubkey = Dict[str, str] """A type containing the parameters for a derypting webpush data. Considered opaque / implementation detail.""" -WebpushCryptoParamsPrivkey = Dict[str, str] \ No newline at end of file +WebpushCryptoParamsPrivkey = Dict[str, str] + +"""Backwards compatibility alias""" +AttribAccessList = PaginatableList \ No newline at end of file diff --git a/tests/test_hooks.py b/tests/test_hooks.py index ab33d4c..1bd14a8 100644 --- a/tests/test_hooks.py +++ b/tests/test_hooks.py @@ -1,18 +1,42 @@ import pytest from datetime import datetime +from mastodon.types import IdType +import typing +def get_type_class(typ): + try: + return typ.__extra__ + except AttributeError: + try: + return typ.__origin__ + except AttributeError: + pass + return typ + + +def real_issubclass(obj1, type2orig): + type1 = get_type_class(type(obj1)) + type2 = get_type_class(type2orig) + valid_types = [] + if type2 is typing.Union: + valid_types = type2orig.__args__ + elif type2 is typing.Generic: + valid_types = [type2orig.__args__[0]] + else: + valid_types = [type2orig] + return issubclass(type1, tuple(valid_types)) @pytest.mark.vcr() def test_id_hook(status): - assert isinstance(status['id'], int) + assert real_issubclass(status['id'], IdType) @pytest.mark.vcr() def test_id_hook_in_reply_to(api, status): reply = api.status_post('Reply!', in_reply_to_id=status['id']) try: - assert isinstance(reply['in_reply_to_id'], int) - assert isinstance(reply['in_reply_to_account_id'], int) + assert real_issubclass(reply['in_reply_to_id'], IdType) + assert real_issubclass(reply['in_reply_to_account_id'], IdType) finally: api.status_delete(reply['id']) @@ -21,18 +45,17 @@ def test_id_hook_in_reply_to(api, status): def test_id_hook_within_reblog(api, status): reblog = api.status_reblog(status['id']) try: - assert isinstance(reblog['reblog']['id'], int) + assert real_issubclass(reblog['reblog']['id'], IdType) finally: api.status_delete(reblog['id']) @pytest.mark.vcr() def test_date_hook(status): - assert isinstance(status['created_at'], datetime) + assert real_issubclass(status['created_at'], datetime) @pytest.mark.vcr() def test_attribute_access(status): assert status.id is not None - with pytest.raises(AttributeError): - status.id = 420 + status.id = 420 \ No newline at end of file