fix #398 and also fix pagination info not surviving persisting

pull/400/head
halcy 2025-03-02 11:13:35 +02:00
rodzic 85e0cf468b
commit 0f775f5fe1
6 zmienionych plików z 1699 dodań i 41 usunięć

Wyświetl plik

@ -82,7 +82,13 @@ class Mastodon():
return_type = override_type return_type = override_type
except: except:
return_type = AttribAccessDict return_type = AttribAccessDict
return try_cast_recurse(return_type, value) return_val = try_cast_recurse(return_type, value)
return_type_repr = None
try:
return_type_repr = return_val._mastopy_type
except:
pass
return return_val, return_type_repr
def __api_request(self, method, endpoint, params={}, files={}, headers={}, access_token_override=None, base_url_override=None, def __api_request(self, method, endpoint, params={}, files={}, headers={}, access_token_override=None, base_url_override=None,
do_ratelimiting=True, use_json=False, parse=True, return_response_object=False, skip_error_check=False, lang_override=None, override_type=None, do_ratelimiting=True, use_json=False, parse=True, return_response_object=False, skip_error_check=False, lang_override=None, override_type=None,
@ -93,6 +99,7 @@ class Mastodon():
Does a large amount of different things that I should document one day, but not today. Does a large amount of different things that I should document one day, but not today.
""" """
response = None response = None
final_type = None
remaining_wait = 0 remaining_wait = 0
# Add language to params if not None # Add language to params if not None
@ -208,7 +215,7 @@ class Mastodon():
if not response_object.ok: if not response_object.ok:
try: try:
response = self.__try_cast_to_type(response_object.json(), override_type = override_type) # TODO actually cast to an error type response, final_type = self.__try_cast_to_type(response_object.json(), override_type = override_type) # TODO actually cast to an error type
if isinstance(response, dict) and 'error' in response: if isinstance(response, dict) and 'error' in response:
error_msg = response['error'] error_msg = response['error']
elif isinstance(response, str): elif isinstance(response, str):
@ -270,7 +277,7 @@ class Mastodon():
f"bad json content was {response_object.content!r}.", f"bad json content was {response_object.content!r}.",
f"Exception was: {e}" f"Exception was: {e}"
) )
response = self.__try_cast_to_type(response, override_type = override_type) response, final_type = self.__try_cast_to_type(response, override_type = override_type)
else: else:
response = response_object.content response = response_object.content
@ -278,6 +285,8 @@ class Mastodon():
if (isinstance(response, list) or force_pagination) and 'Link' in response_object.headers and response_object.headers['Link'] != "": if (isinstance(response, list) or force_pagination) and 'Link' in response_object.headers and response_object.headers['Link'] != "":
if not isinstance(response, PaginatableList) and not force_pagination: if not isinstance(response, PaginatableList) and not force_pagination:
response = PaginatableList(response) response = PaginatableList(response)
if final_type is None:
final_type = str(type(response))
tmp_urls = requests.utils.parse_header_links(response_object.headers['Link'].rstrip('>').replace('>,<', ',<')) tmp_urls = requests.utils.parse_header_links(response_object.headers['Link'].rstrip('>').replace('>,<', ',<'))
for url in tmp_urls: for url in tmp_urls:
if 'rel' not in url: if 'rel' not in url:
@ -292,6 +301,7 @@ class Mastodon():
next_params = copy.deepcopy(params) next_params = copy.deepcopy(params)
next_params['_pagination_method'] = method next_params['_pagination_method'] = method
next_params['_pagination_endpoint'] = endpoint next_params['_pagination_endpoint'] = endpoint
next_params['_mastopy_type'] = final_type
max_id = matchgroups.group(1) max_id = matchgroups.group(1)
if max_id.isdigit(): if max_id.isdigit():
next_params['max_id'] = int(max_id) next_params['max_id'] = int(max_id)
@ -313,6 +323,7 @@ class Mastodon():
prev_params = copy.deepcopy(params) prev_params = copy.deepcopy(params)
prev_params['_pagination_method'] = method prev_params['_pagination_method'] = method
prev_params['_pagination_endpoint'] = endpoint prev_params['_pagination_endpoint'] = endpoint
prev_params['_mastopy_type'] = final_type
since_id = matchgroups.group(1) since_id = matchgroups.group(1)
if since_id.isdigit(): if since_id.isdigit():
prev_params['since_id'] = int(since_id) prev_params['since_id'] = int(since_id)
@ -328,6 +339,7 @@ class Mastodon():
prev_params = copy.deepcopy(params) prev_params = copy.deepcopy(params)
prev_params['_pagination_method'] = method prev_params['_pagination_method'] = method
prev_params['_pagination_endpoint'] = endpoint prev_params['_pagination_endpoint'] = endpoint
prev_params['_mastopy_type'] = final_type
min_id = matchgroups.group(1) min_id = matchgroups.group(1)
if min_id.isdigit(): if min_id.isdigit():
prev_params['min_id'] = int(min_id) prev_params['min_id'] = int(min_id)

Wyświetl plik

@ -54,6 +54,39 @@ This is a breaking change, and I'm sorry about it, but this will make every piec
of software using Mastodon.py more robust in the long run. of software using Mastodon.py more robust in the long run.
""" """
def _str_to_type(mastopy_type):
"""
String name to internal type resolver
"""
# See if we need to parse a sub-type (i.e. [<something>] in the type name
sub_type = None
if "[" in mastopy_type and "]" in mastopy_type:
mastopy_type, sub_type = mastopy_type.split("[")
sub_type = sub_type[:-1]
if mastopy_type not in ["PaginatableList", "NonPaginatableList", "typing.Optional", "typing.Union"]:
raise ValueError(f"Subtype not allowed for type {mastopy_type} and subtype {sub_type}")
if "[" in mastopy_type or "]" in mastopy_type:
raise ValueError(f"Invalid type {mastopy_type}")
if sub_type is not None and ("[" in sub_type or "]" in sub_type):
raise ValueError(f"Invalid subtype {sub_type}")
# Build the actual type object.
from mastodon.return_types import ENTITY_NAME_MAP
full_type = None
if sub_type is not None:
sub_type = ENTITY_NAME_MAP.get(sub_type, None)
full_type = {
"PaginatableList": PaginatableList[sub_type],
"NonPaginatableList": NonPaginatableList[sub_type],
"typing.Optional": Optional[sub_type],
"typing.Union": Union[sub_type],
}[mastopy_type]
else:
full_type = ENTITY_NAME_MAP.get(mastopy_type, None)
if full_type is None:
raise ValueError(f"Unknown type {mastopy_type}")
return full_type
class MaybeSnowflakeIdType(str): class MaybeSnowflakeIdType(str):
""" """
Represents, maybe, a snowflake ID. Represents, maybe, a snowflake ID.
@ -282,6 +315,8 @@ def try_cast_recurse(t, value, union_specializer=None):
* Casting to Union, trying all types in the union until one works * Casting to Union, trying all types in the union until one works
Gives up and returns as-is if none of the above work. Gives up and returns as-is if none of the above work.
""" """
if type(t) == str:
t = _str_to_type(t)
if value is None: if value is None:
return value return value
t = resolve_type(t) t = resolve_type(t)
@ -354,7 +389,7 @@ class Entity():
""" """
def __init__(self): def __init__(self):
self._mastopy_type = None self._mastopy_type = None
def to_json(self, pretty=True) -> str: def to_json(self, pretty=True) -> str:
""" """
Serialize to JSON. Serialize to JSON.
@ -378,16 +413,21 @@ class Entity():
remove_renamed_fields(mastopy_data) remove_renamed_fields(mastopy_data)
serialize_data = { serialize_data = {
"_mastopy_version": "2.0.0", "_mastopy_version": "2.0.1",
"_mastopy_type": self._mastopy_type, "_mastopy_type": self._mastopy_type,
"_mastopy_data": mastopy_data "_mastopy_data": mastopy_data,
"_mastopy_extra_data": {}
} }
if hasattr(self, "_pagination_next") and self._pagination_next is not None:
serialize_data["_mastopy_extra_data"]["_pagination_next"] = self._pagination_next
if hasattr(self, "_pagination_prev") and self._pagination_prev is not None:
serialize_data["_mastopy_extra_data"]["_pagination_prev"] = self._pagination_prev
def json_serial(obj): def json_serial(obj):
if isinstance(obj, datetime): if isinstance(obj, datetime):
return obj.isoformat() return obj.isoformat()
if pretty: if pretty:
return json.dumps(serialize_data, default=json_serial, indent=4) return json.dumps(serialize_data, default=json_serial, indent=4)
else: else:
@ -421,37 +461,25 @@ class Entity():
if "_mastopy_type" not in json_result: if "_mastopy_type" not in json_result:
raise ValueError("JSON does not contain _mastopy_type field, refusing to parse.") raise ValueError("JSON does not contain _mastopy_type field, refusing to parse.")
mastopy_type = json_result["_mastopy_type"] mastopy_type = json_result["_mastopy_type"]
full_type = _str_to_type(mastopy_type)
# See if we need to parse a sub-type (i.e. [<something>] in the type name
sub_type = None
if "[" in mastopy_type and "]" in mastopy_type:
mastopy_type, sub_type = mastopy_type.split("[")
sub_type = sub_type[:-1]
if mastopy_type not in ["PaginatableList", "NonPaginatableList", "typing.Optional", "typing.Union"]:
raise ValueError(f"Subtype not allowed for type {mastopy_type} and subtype {sub_type}")
if "[" in mastopy_type or "]" in mastopy_type:
raise ValueError(f"Invalid type {mastopy_type}")
if sub_type is not None and ("[" in sub_type or "]" in sub_type):
raise ValueError(f"Invalid subtype {sub_type}")
# Build the actual type object.
from mastodon.return_types import ENTITY_NAME_MAP
full_type = None
if sub_type is not None:
sub_type = ENTITY_NAME_MAP.get(sub_type, None)
full_type = {
"PaginatableList": PaginatableList[sub_type],
"NonPaginatableList": NonPaginatableList[sub_type],
"typing.Optional": Optional[sub_type],
"typing.Union": Union[sub_type],
}[mastopy_type]
else:
full_type = ENTITY_NAME_MAP.get(mastopy_type, None)
if full_type is None:
raise ValueError(f"Unknown type {mastopy_type}")
# Finally, try to cast to the generated type # Finally, try to cast to the generated type
return try_cast_recurse(full_type, json_result["_mastopy_data"]) return_data = try_cast_recurse(full_type, json_result["_mastopy_data"])
# Fill in pagination information if it is present in the persisted data
if "_mastopy_extra_data" in json_result:
if "_pagination_next" in json_result["_mastopy_extra_data"]:
return_data._pagination_next = try_cast_recurse(PaginationInfo, json_result["_mastopy_extra_data"]["_pagination_next"])
response_type = return_data._pagination_next.get("_mastopy_type", None)
if response_type is not None:
return_data._pagination_next["_mastopy_type"] = _str_to_type(response_type)
if "_pagination_prev" in json_result["_mastopy_extra_data"]:
return_data._pagination_prev = try_cast_recurse(PaginationInfo, json_result["_mastopy_extra_data"]["_pagination_prev"])
response_type = return_data._pagination_prev.get("_mastopy_type", None)
if response_type is not None:
return_data._pagination_prev["_mastopy_type"] = _str_to_type(response_type)
return return_data
class PaginationInfo(OrderedDict): class PaginationInfo(OrderedDict):

Wyświetl plik

@ -151,14 +151,19 @@ class Mastodon(Internals):
endpoint = params['_pagination_endpoint'] endpoint = params['_pagination_endpoint']
del params['_pagination_endpoint'] del params['_pagination_endpoint']
response_type = None
if '_mastopy_type' in params:
response_type = params['_mastopy_type']
del params['_mastopy_type']
force_pagination = False force_pagination = False
if not isinstance(previous_page, list): if not isinstance(previous_page, list):
force_pagination = True force_pagination = True
if not is_pagination_dict: if not is_pagination_dict:
return self.__api_request(method, endpoint, params, force_pagination=force_pagination, override_type=type(previous_page)) return self.__api_request(method, endpoint, params, force_pagination=force_pagination, override_type=response_type)
else: else:
return self.__api_request(method, endpoint, params) return self.__api_request(method, endpoint, params, override_type=response_type)
def fetch_previous(self, next_page: Union[PaginatableList[Entity], Entity, Dict]) -> Optional[Union[PaginatableList[Entity], Entity]]: def fetch_previous(self, next_page: Union[PaginatableList[Entity], Entity, Dict]) -> Optional[Union[PaginatableList[Entity], Entity]]:
""" """
@ -193,14 +198,19 @@ class Mastodon(Internals):
endpoint = params['_pagination_endpoint'] endpoint = params['_pagination_endpoint']
del params['_pagination_endpoint'] del params['_pagination_endpoint']
response_type = None
if '_mastopy_type' in params:
response_type = params['_mastopy_type']
del params['_mastopy_type']
force_pagination = False force_pagination = False
if not isinstance(next_page, list): if not isinstance(next_page, list):
force_pagination = True force_pagination = True
if not is_pagination_dict: if not is_pagination_dict:
return self.__api_request(method, endpoint, params, force_pagination=force_pagination, override_type=type(next_page)) return self.__api_request(method, endpoint, params, force_pagination=force_pagination, override_type=response_type)
else: else:
return self.__api_request(method, endpoint, params) return self.__api_request(method, endpoint, params, override_type=response_type)
def fetch_remaining(self, first_page: PaginatableList[Entity]) -> PaginatableList[Entity]: def fetch_remaining(self, first_page: PaginatableList[Entity]) -> PaginatableList[Entity]:
""" """

Wyświetl plik

@ -3,7 +3,7 @@ import os
import vcr import vcr
# Set this to True to debug issues with tests # Set this to True to debug issues with tests
DEBUG_REQUESTS = True DEBUG_REQUESTS = False
def _api(access_token='__MASTODON_PY_TEST_ACCESS_TOKEN', version="4.3.0", version_check_mode="created"): def _api(access_token='__MASTODON_PY_TEST_ACCESS_TOKEN', version="4.3.0", version_check_mode="created"):
import mastodon import mastodon

Wyświetl plik

@ -11,6 +11,7 @@ import requests_mock
UNLIKELY_HASHTAG = "fgiztsshwiaqqiztpmmjbtvmescsculuvmgjgopwoeidbcrixp" UNLIKELY_HASHTAG = "fgiztsshwiaqqiztpmmjbtvmescsculuvmgjgopwoeidbcrixp"
from mastodon.types_base import Entity
@contextmanager @contextmanager
def many_statuses(api, n=10, suffix=''): def many_statuses(api, n=10, suffix=''):
@ -30,9 +31,44 @@ def test_fetch_next_previous(api):
statuses = api.account_statuses(account['id'], limit=5) statuses = api.account_statuses(account['id'], limit=5)
next_statuses = api.fetch_next(statuses) next_statuses = api.fetch_next(statuses)
assert next_statuses assert next_statuses
assert type(next_statuses) == type(statuses)
for status in next_statuses:
assert status['id'] < statuses[0]['id']
assert type(status) == type(statuses[0])
previous_statuses = api.fetch_previous(next_statuses) previous_statuses = api.fetch_previous(next_statuses)
assert previous_statuses assert previous_statuses
assert type(previous_statuses) == type(statuses)
for status in previous_statuses:
assert status['id'] > next_statuses[-1]['id']
assert type(status) == type(statuses[0])
@pytest.mark.vcr()
def test_fetch_next_previous_after_persist(api):
account = api.account_verify_credentials()
with many_statuses(api):
statuses_orig = api.account_statuses(account['id'], limit=5)
statuses_persist_json = statuses_orig.to_json()
statuses = Entity.from_json(statuses_persist_json)
assert type(statuses) == type(statuses_orig)
assert type(statuses[0]) == type(statuses_orig[0])
next_statuses = api.fetch_next(statuses)
assert next_statuses
assert type(next_statuses) == type(statuses)
for status in next_statuses:
assert status['id'] < statuses[0]['id']
assert type(status) == type(statuses[0])
persisted_next_json = next_statuses.to_json()
next_statuses = Entity.from_json(persisted_next_json)
assert type(next_statuses) == type(statuses)
for status in next_statuses:
assert status['id'] < statuses[0]['id']
assert type(status) == type(statuses[0])
previous_statuses = api.fetch_previous(next_statuses)
assert previous_statuses
assert type(previous_statuses) == type(statuses)
for status in previous_statuses:
assert status['id'] > next_statuses[-1]['id']
assert type(status) == type(statuses[0])
@pytest.mark.vcr() @pytest.mark.vcr()
def test_fetch_next_previous_from_pagination_info(api): def test_fetch_next_previous_from_pagination_info(api):
@ -41,8 +77,17 @@ def test_fetch_next_previous_from_pagination_info(api):
statuses = api.account_statuses(account['id'], limit=5) statuses = api.account_statuses(account['id'], limit=5)
next_statuses = api.fetch_next(statuses._pagination_next) next_statuses = api.fetch_next(statuses._pagination_next)
assert next_statuses assert next_statuses
assert type(next_statuses) == type(statuses)
for status in next_statuses:
assert status['id'] < statuses[0]['id']
assert type(status) == type(statuses[0])
previous_statuses = api.fetch_previous(next_statuses._pagination_prev) previous_statuses = api.fetch_previous(next_statuses._pagination_prev)
assert previous_statuses assert previous_statuses
assert type(previous_statuses) == type(statuses)
for status in previous_statuses:
assert status['id'] > next_statuses[-1]['id']
assert type(status) == type(statuses[0])
@pytest.mark.vcr() @pytest.mark.vcr()
def test_fetch_remaining(api3): def test_fetch_remaining(api3):
@ -51,6 +96,9 @@ def test_fetch_remaining(api3):
hashtag_remaining = api3.fetch_remaining(hashtag) hashtag_remaining = api3.fetch_remaining(hashtag)
assert hashtag_remaining assert hashtag_remaining
assert len(hashtag_remaining) >= 30 assert len(hashtag_remaining) >= 30
for status in hashtag_remaining:
assert UNLIKELY_HASHTAG in status['content']
assert type(status) == type(hashtag[0])
def test_link_headers(api): def test_link_headers(api):
rmock = requests_mock.Adapter() rmock = requests_mock.Adapter()