fix #398 and also fix pagination info not surviving persisting

pull/400/head
halcy 2025-03-02 11:13:35 +02:00
rodzic 85e0cf468b
commit 0f775f5fe1
6 zmienionych plików z 1699 dodań i 41 usunięć

Wyświetl plik

@ -82,7 +82,13 @@ class Mastodon():
return_type = override_type
except:
return_type = AttribAccessDict
return try_cast_recurse(return_type, value)
return_val = try_cast_recurse(return_type, value)
return_type_repr = None
try:
return_type_repr = return_val._mastopy_type
except:
pass
return return_val, return_type_repr
def __api_request(self, method, endpoint, params={}, files={}, headers={}, access_token_override=None, base_url_override=None,
do_ratelimiting=True, use_json=False, parse=True, return_response_object=False, skip_error_check=False, lang_override=None, override_type=None,
@ -93,6 +99,7 @@ class Mastodon():
Does a large amount of different things that I should document one day, but not today.
"""
response = None
final_type = None
remaining_wait = 0
# Add language to params if not None
@ -208,7 +215,7 @@ class Mastodon():
if not response_object.ok:
try:
response = self.__try_cast_to_type(response_object.json(), override_type = override_type) # TODO actually cast to an error type
response, final_type = self.__try_cast_to_type(response_object.json(), override_type = override_type) # TODO actually cast to an error type
if isinstance(response, dict) and 'error' in response:
error_msg = response['error']
elif isinstance(response, str):
@ -270,7 +277,7 @@ class Mastodon():
f"bad json content was {response_object.content!r}.",
f"Exception was: {e}"
)
response = self.__try_cast_to_type(response, override_type = override_type)
response, final_type = self.__try_cast_to_type(response, override_type = override_type)
else:
response = response_object.content
@ -278,6 +285,8 @@ class Mastodon():
if (isinstance(response, list) or force_pagination) and 'Link' in response_object.headers and response_object.headers['Link'] != "":
if not isinstance(response, PaginatableList) and not force_pagination:
response = PaginatableList(response)
if final_type is None:
final_type = str(type(response))
tmp_urls = requests.utils.parse_header_links(response_object.headers['Link'].rstrip('>').replace('>,<', ',<'))
for url in tmp_urls:
if 'rel' not in url:
@ -292,6 +301,7 @@ class Mastodon():
next_params = copy.deepcopy(params)
next_params['_pagination_method'] = method
next_params['_pagination_endpoint'] = endpoint
next_params['_mastopy_type'] = final_type
max_id = matchgroups.group(1)
if max_id.isdigit():
next_params['max_id'] = int(max_id)
@ -313,6 +323,7 @@ class Mastodon():
prev_params = copy.deepcopy(params)
prev_params['_pagination_method'] = method
prev_params['_pagination_endpoint'] = endpoint
prev_params['_mastopy_type'] = final_type
since_id = matchgroups.group(1)
if since_id.isdigit():
prev_params['since_id'] = int(since_id)
@ -328,6 +339,7 @@ class Mastodon():
prev_params = copy.deepcopy(params)
prev_params['_pagination_method'] = method
prev_params['_pagination_endpoint'] = endpoint
prev_params['_mastopy_type'] = final_type
min_id = matchgroups.group(1)
if min_id.isdigit():
prev_params['min_id'] = int(min_id)

Wyświetl plik

@ -54,6 +54,39 @@ This is a breaking change, and I'm sorry about it, but this will make every piec
of software using Mastodon.py more robust in the long run.
"""
def _str_to_type(mastopy_type):
"""
String name to internal type resolver
"""
# See if we need to parse a sub-type (i.e. [<something>] in the type name
sub_type = None
if "[" in mastopy_type and "]" in mastopy_type:
mastopy_type, sub_type = mastopy_type.split("[")
sub_type = sub_type[:-1]
if mastopy_type not in ["PaginatableList", "NonPaginatableList", "typing.Optional", "typing.Union"]:
raise ValueError(f"Subtype not allowed for type {mastopy_type} and subtype {sub_type}")
if "[" in mastopy_type or "]" in mastopy_type:
raise ValueError(f"Invalid type {mastopy_type}")
if sub_type is not None and ("[" in sub_type or "]" in sub_type):
raise ValueError(f"Invalid subtype {sub_type}")
# Build the actual type object.
from mastodon.return_types import ENTITY_NAME_MAP
full_type = None
if sub_type is not None:
sub_type = ENTITY_NAME_MAP.get(sub_type, None)
full_type = {
"PaginatableList": PaginatableList[sub_type],
"NonPaginatableList": NonPaginatableList[sub_type],
"typing.Optional": Optional[sub_type],
"typing.Union": Union[sub_type],
}[mastopy_type]
else:
full_type = ENTITY_NAME_MAP.get(mastopy_type, None)
if full_type is None:
raise ValueError(f"Unknown type {mastopy_type}")
return full_type
class MaybeSnowflakeIdType(str):
"""
Represents, maybe, a snowflake ID.
@ -282,6 +315,8 @@ def try_cast_recurse(t, value, union_specializer=None):
* Casting to Union, trying all types in the union until one works
Gives up and returns as-is if none of the above work.
"""
if type(t) == str:
t = _str_to_type(t)
if value is None:
return value
t = resolve_type(t)
@ -354,7 +389,7 @@ class Entity():
"""
def __init__(self):
self._mastopy_type = None
def to_json(self, pretty=True) -> str:
"""
Serialize to JSON.
@ -378,16 +413,21 @@ class Entity():
remove_renamed_fields(mastopy_data)
serialize_data = {
"_mastopy_version": "2.0.0",
"_mastopy_version": "2.0.1",
"_mastopy_type": self._mastopy_type,
"_mastopy_data": mastopy_data
"_mastopy_data": mastopy_data,
"_mastopy_extra_data": {}
}
if hasattr(self, "_pagination_next") and self._pagination_next is not None:
serialize_data["_mastopy_extra_data"]["_pagination_next"] = self._pagination_next
if hasattr(self, "_pagination_prev") and self._pagination_prev is not None:
serialize_data["_mastopy_extra_data"]["_pagination_prev"] = self._pagination_prev
def json_serial(obj):
if isinstance(obj, datetime):
return obj.isoformat()
if pretty:
return json.dumps(serialize_data, default=json_serial, indent=4)
else:
@ -421,37 +461,25 @@ class Entity():
if "_mastopy_type" not in json_result:
raise ValueError("JSON does not contain _mastopy_type field, refusing to parse.")
mastopy_type = json_result["_mastopy_type"]
# See if we need to parse a sub-type (i.e. [<something>] in the type name
sub_type = None
if "[" in mastopy_type and "]" in mastopy_type:
mastopy_type, sub_type = mastopy_type.split("[")
sub_type = sub_type[:-1]
if mastopy_type not in ["PaginatableList", "NonPaginatableList", "typing.Optional", "typing.Union"]:
raise ValueError(f"Subtype not allowed for type {mastopy_type} and subtype {sub_type}")
if "[" in mastopy_type or "]" in mastopy_type:
raise ValueError(f"Invalid type {mastopy_type}")
if sub_type is not None and ("[" in sub_type or "]" in sub_type):
raise ValueError(f"Invalid subtype {sub_type}")
# Build the actual type object.
from mastodon.return_types import ENTITY_NAME_MAP
full_type = None
if sub_type is not None:
sub_type = ENTITY_NAME_MAP.get(sub_type, None)
full_type = {
"PaginatableList": PaginatableList[sub_type],
"NonPaginatableList": NonPaginatableList[sub_type],
"typing.Optional": Optional[sub_type],
"typing.Union": Union[sub_type],
}[mastopy_type]
else:
full_type = ENTITY_NAME_MAP.get(mastopy_type, None)
if full_type is None:
raise ValueError(f"Unknown type {mastopy_type}")
full_type = _str_to_type(mastopy_type)
# Finally, try to cast to the generated type
return try_cast_recurse(full_type, json_result["_mastopy_data"])
return_data = try_cast_recurse(full_type, json_result["_mastopy_data"])
# Fill in pagination information if it is present in the persisted data
if "_mastopy_extra_data" in json_result:
if "_pagination_next" in json_result["_mastopy_extra_data"]:
return_data._pagination_next = try_cast_recurse(PaginationInfo, json_result["_mastopy_extra_data"]["_pagination_next"])
response_type = return_data._pagination_next.get("_mastopy_type", None)
if response_type is not None:
return_data._pagination_next["_mastopy_type"] = _str_to_type(response_type)
if "_pagination_prev" in json_result["_mastopy_extra_data"]:
return_data._pagination_prev = try_cast_recurse(PaginationInfo, json_result["_mastopy_extra_data"]["_pagination_prev"])
response_type = return_data._pagination_prev.get("_mastopy_type", None)
if response_type is not None:
return_data._pagination_prev["_mastopy_type"] = _str_to_type(response_type)
return return_data
class PaginationInfo(OrderedDict):

Wyświetl plik

@ -151,14 +151,19 @@ class Mastodon(Internals):
endpoint = params['_pagination_endpoint']
del params['_pagination_endpoint']
response_type = None
if '_mastopy_type' in params:
response_type = params['_mastopy_type']
del params['_mastopy_type']
force_pagination = False
if not isinstance(previous_page, list):
force_pagination = True
if not is_pagination_dict:
return self.__api_request(method, endpoint, params, force_pagination=force_pagination, override_type=type(previous_page))
return self.__api_request(method, endpoint, params, force_pagination=force_pagination, override_type=response_type)
else:
return self.__api_request(method, endpoint, params)
return self.__api_request(method, endpoint, params, override_type=response_type)
def fetch_previous(self, next_page: Union[PaginatableList[Entity], Entity, Dict]) -> Optional[Union[PaginatableList[Entity], Entity]]:
"""
@ -193,14 +198,19 @@ class Mastodon(Internals):
endpoint = params['_pagination_endpoint']
del params['_pagination_endpoint']
response_type = None
if '_mastopy_type' in params:
response_type = params['_mastopy_type']
del params['_mastopy_type']
force_pagination = False
if not isinstance(next_page, list):
force_pagination = True
if not is_pagination_dict:
return self.__api_request(method, endpoint, params, force_pagination=force_pagination, override_type=type(next_page))
return self.__api_request(method, endpoint, params, force_pagination=force_pagination, override_type=response_type)
else:
return self.__api_request(method, endpoint, params)
return self.__api_request(method, endpoint, params, override_type=response_type)
def fetch_remaining(self, first_page: PaginatableList[Entity]) -> PaginatableList[Entity]:
"""

Wyświetl plik

@ -3,7 +3,7 @@ import os
import vcr
# Set this to True to debug issues with tests
DEBUG_REQUESTS = True
DEBUG_REQUESTS = False
def _api(access_token='__MASTODON_PY_TEST_ACCESS_TOKEN', version="4.3.0", version_check_mode="created"):
import mastodon

Wyświetl plik

@ -11,6 +11,7 @@ import requests_mock
UNLIKELY_HASHTAG = "fgiztsshwiaqqiztpmmjbtvmescsculuvmgjgopwoeidbcrixp"
from mastodon.types_base import Entity
@contextmanager
def many_statuses(api, n=10, suffix=''):
@ -30,9 +31,44 @@ def test_fetch_next_previous(api):
statuses = api.account_statuses(account['id'], limit=5)
next_statuses = api.fetch_next(statuses)
assert next_statuses
assert type(next_statuses) == type(statuses)
for status in next_statuses:
assert status['id'] < statuses[0]['id']
assert type(status) == type(statuses[0])
previous_statuses = api.fetch_previous(next_statuses)
assert previous_statuses
assert type(previous_statuses) == type(statuses)
for status in previous_statuses:
assert status['id'] > next_statuses[-1]['id']
assert type(status) == type(statuses[0])
@pytest.mark.vcr()
def test_fetch_next_previous_after_persist(api):
account = api.account_verify_credentials()
with many_statuses(api):
statuses_orig = api.account_statuses(account['id'], limit=5)
statuses_persist_json = statuses_orig.to_json()
statuses = Entity.from_json(statuses_persist_json)
assert type(statuses) == type(statuses_orig)
assert type(statuses[0]) == type(statuses_orig[0])
next_statuses = api.fetch_next(statuses)
assert next_statuses
assert type(next_statuses) == type(statuses)
for status in next_statuses:
assert status['id'] < statuses[0]['id']
assert type(status) == type(statuses[0])
persisted_next_json = next_statuses.to_json()
next_statuses = Entity.from_json(persisted_next_json)
assert type(next_statuses) == type(statuses)
for status in next_statuses:
assert status['id'] < statuses[0]['id']
assert type(status) == type(statuses[0])
previous_statuses = api.fetch_previous(next_statuses)
assert previous_statuses
assert type(previous_statuses) == type(statuses)
for status in previous_statuses:
assert status['id'] > next_statuses[-1]['id']
assert type(status) == type(statuses[0])
@pytest.mark.vcr()
def test_fetch_next_previous_from_pagination_info(api):
@ -41,8 +77,17 @@ def test_fetch_next_previous_from_pagination_info(api):
statuses = api.account_statuses(account['id'], limit=5)
next_statuses = api.fetch_next(statuses._pagination_next)
assert next_statuses
assert type(next_statuses) == type(statuses)
for status in next_statuses:
assert status['id'] < statuses[0]['id']
assert type(status) == type(statuses[0])
previous_statuses = api.fetch_previous(next_statuses._pagination_prev)
assert previous_statuses
assert type(previous_statuses) == type(statuses)
for status in previous_statuses:
assert status['id'] > next_statuses[-1]['id']
assert type(status) == type(statuses[0])
@pytest.mark.vcr()
def test_fetch_remaining(api3):
@ -51,6 +96,9 @@ def test_fetch_remaining(api3):
hashtag_remaining = api3.fetch_remaining(hashtag)
assert hashtag_remaining
assert len(hashtag_remaining) >= 30
for status in hashtag_remaining:
assert UNLIKELY_HASHTAG in status['content']
assert type(status) == type(hashtag[0])
def test_link_headers(api):
rmock = requests_mock.Adapter()