fix various issues. some remain.

pull/350/head
halcy 2023-06-21 00:15:40 +03:00
rodzic 826a6f457a
commit d891589092
4 zmienionych plików z 103 dodań i 87 usunięć

Wyświetl plik

@ -380,7 +380,7 @@ class Mastodon(Internals):
@api_version("4.0.0", "4.0.0", _DICT_VERSION_ADMIN_DOMAIN_BLOCK)
def admin_domain_blocks(self, id: Optional[IdType] = None, max_id: Optional[IdType] = None, min_id: Optional[IdType] = None,
since_id: Optional[IdType] = None, limit: Optional[int] = None) -> PaginatableList[AdminDomainBlock]:
since_id: Optional[IdType] = None, limit: Optional[int] = None) -> Union[AdminDomainBlock, PaginatableList[AdminDomainBlock]]:
"""
Fetches a list of blocked domains. Requires scope `admin:read:domain_blocks`.

Wyświetl plik

@ -22,7 +22,7 @@ from mastodon.errors import MastodonNetworkError, MastodonIllegalArgumentError,
MastodonGatewayTimeoutError, MastodonServerError, MastodonAPIError, MastodonMalformedEventError
from mastodon.compat import urlparse, magic, PurePath, Path
from mastodon.defaults import _DEFAULT_STREAM_TIMEOUT, _DEFAULT_STREAM_RECONNECT_WAIT_SEC
from mastodon.types import AttribAccessDict, try_cast_recurse
from mastodon.types import AttribAccessDict, PaginatableList, try_cast_recurse
from mastodon.types import *
###
@ -271,10 +271,9 @@ class Mastodon():
response = response_object.content
# Parse link headers
if isinstance(response, list) and \
'Link' in response_object.headers and \
response_object.headers['Link'] != "":
response = AttribAccessList(response)
if isinstance(response, list) and 'Link' in response_object.headers and response_object.headers['Link'] != "":
if not isinstance(response, PaginatableList):
response = PaginatableList(response)
tmp_urls = requests.utils.parse_header_links(
response_object.headers['Link'].rstrip('>').replace('>,<', ',<'))
for url in tmp_urls:
@ -301,18 +300,12 @@ class Mastodon():
del next_params['min_id']
response._pagination_next = next_params
# Maybe other API users rely on the pagination info in the last item
# Will be removed in future
if isinstance(response[-1], AttribAccessDict):
response[-1]._pagination_next = next_params
if url['rel'] == 'prev':
# Be paranoid and extract since_id or min_id specifically
prev_url = url['url']
# Old and busted (pre-2.6.0): since_id pagination
matchgroups = re.search(
r"[?&]since_id=([^&]+)", prev_url)
matchgroups = re.search(r"[?&]since_id=([^&]+)", prev_url)
if matchgroups:
prev_params = copy.deepcopy(params)
prev_params['_pagination_method'] = method
@ -326,14 +319,8 @@ class Mastodon():
del prev_params['max_id']
response._pagination_prev = prev_params
# Maybe other API users rely on the pagination info in the first item
# Will be removed in future
if isinstance(response[0], AttribAccessDict):
response[0]._pagination_prev = prev_params
# New and fantastico (post-2.6.0): min_id pagination
matchgroups = re.search(
r"[?&]min_id=([^&]+)", prev_url)
matchgroups = re.search(r"[?&]min_id=([^&]+)", prev_url)
if matchgroups:
prev_params = copy.deepcopy(params)
prev_params['_pagination_method'] = method
@ -346,11 +333,6 @@ class Mastodon():
if "max_id" in prev_params:
del prev_params['max_id']
response._pagination_prev = prev_params
# Maybe other API users rely on the pagination info in the first item
# Will be removed in future
if isinstance(response[0], AttribAccessDict):
response[0]._pagination_prev = prev_params
return response

Wyświetl plik

@ -137,72 +137,28 @@ class MaybeSnowflakeIdType(str):
"""
return str(self.__val)
"""
IDs returned from Mastodon.py ar either primitive (int or str) or snowflake
(still int or str, but potentially convertible to datetime).
"""
IdType = Union[PrimitiveIdType, MaybeSnowflakeIdType]
T = TypeVar('T')
class PaginatableList(List[T]):
"""
This is a list with pagination information attached.
It is returned by the API when a list of items is requested, and the response contains
a Link header with pagination information.
"""
def __getattr__(self, attr):
if attr in self:
return self[attr]
else:
raise AttributeError(f"Attribute not found: {attr}")
def __setattr__(self, attr, val):
if attr in self:
raise AttributeError("Attribute-style access is read only")
super(NonPaginatableList, self).__setattr__(attr, val)
# TODO add the pagination housekeeping stuff
class NonPaginatableList(List[T]):
"""
This is just a list. I am subclassing the regular list out of pure paranoia about
potential oversights that might require me to add things to it later.
"""
pass
# Lists in Mastodon.py are either regular or paginatable
EntityList = Union[NonPaginatableList[T], PaginatableList[T]]
# Backwards compat alias
AttribAccessList = EntityList
# Helper functions for typecasting attempts
def try_cast(t, value, retry = True):
"""
Base case casting function. Handles:
* Casting to any AttribAccessDict subclass (directly, no special handling)
* Casting to MaybeSnowflakeIdType (directly, no special handling)
* Casting to bool (with possible conversion from json bool strings)
* Casting to datetime (with possible conversion from all kinds of funny date formats because unfortunately this is the world we live in)
* Casting to whatever t is
* Trying once again to AttribAccessDict as a fallback
Gives up and returns as-is if none of the above work.
"""
try:
if issubclass(t, AttribAccessDict) or t is MaybeSnowflakeIdType:
try:
value = t(**value)
except:
try:
value = AttribAccessDict(**value)
except:
pass
elif isinstance(t, bool):
if issubclass(t, AttribAccessDict):
value = t(**value)
elif issubclass(t, bool):
if isinstance(value, str):
if value.lower() == 'true':
value = True
elif value.lower() == 'false':
value = False
value = bool(value)
elif isinstance(t, datetime):
elif issubclass(t, datetime):
if isinstance(value, int):
value = datetime.fromtimestamp(value, timezone.utc)
elif isinstance(value, str):
@ -211,8 +167,11 @@ def try_cast(t, value, retry = True):
value = datetime.fromtimestamp(value_int, timezone.utc)
except:
value = dateutil.parser.parse(value)
except:
value = try_cast(AttribAccessDict, value, False)
else:
value = t(**value)
except Exception as e:
if retry:
value = try_cast(AttribAccessDict, value, False)
return value
def try_cast_recurse(t, value):
@ -241,6 +200,38 @@ def try_cast_recurse(t, value):
pass
return try_cast(t, value)
"""
IDs returned from Mastodon.py ar either primitive (int or str) or snowflake
(still int or str, but potentially convertible to datetime).
"""
IdType = Union[PrimitiveIdType, MaybeSnowflakeIdType]
T = TypeVar('T')
class PaginatableList(List[T]):
"""
This is a list with pagination information attached.
It is returned by the API when a list of items is requested, and the response contains
a Link header with pagination information.
"""
def __init__(self, *args, **kwargs):
"""
Initializes basic list and adds empty pagination information.
"""
super(PaginatableList, self).__init__(*args, **kwargs)
self._pagination_next = None
self._pagination_prev = None
class NonPaginatableList(List[T]):
"""
This is just a list. I am subclassing the regular list out of pure paranoia about
potential oversights that might require me to add things to it later.
"""
pass
"""Lists in Mastodon.py are either regular or paginatable"""
EntityList = Union[NonPaginatableList[T], PaginatableList[T]]
class AttribAccessDict(OrderedDict[str, Any]):
"""
Base return object class for Mastodon.py.
@ -256,12 +247,12 @@ class AttribAccessDict(OrderedDict[str, Any]):
Constructor that calls through to dict constructor and then sets attributes for all keys.
"""
super(AttribAccessDict, self).__init__()
if __annotations__ in self.__class__.__dict__:
if "__annotations__" in self.__class__.__dict__:
for attr, _ in self.__class__.__annotations__.items():
attr_name = attr
if hasattr(self.__class__, "_rename_map"):
attr_name = getattr(self.__class__, "_rename_map").get(attr, attr)
if attr_name in kwargs:
if attr_name in kwargs:
self[attr] = kwargs[attr_name]
assert not attr in kwargs, f"Duplicate attribute {attr}"
elif attr in kwargs:
@ -337,6 +328,23 @@ class AttribAccessDict(OrderedDict[str, Any]):
super(AttribAccessDict, self).__setattr__(key, val)
super(AttribAccessDict, self).__setitem__(key, val)
def __eq__(self, other):
"""
Equality checker with casting
"""
if isinstance(other, self.__class__):
return super(AttribAccessDict, self).__eq__(other)
else:
try:
casted = try_cast_recurse(self.__class__, other)
if isinstance(casted, self.__class__):
return super(AttribAccessDict, self).__eq__(casted)
else:
return False
except Exception as e:
pass
return False
"""An entity returned by the Mastodon API is either a dict or a list"""
Entity = Union[AttribAccessDict, EntityList]
@ -344,4 +352,7 @@ Entity = Union[AttribAccessDict, EntityList]
WebpushCryptoParamsPubkey = Dict[str, str]
"""A type containing the parameters for a derypting webpush data. Considered opaque / implementation detail."""
WebpushCryptoParamsPrivkey = Dict[str, str]
WebpushCryptoParamsPrivkey = Dict[str, str]
"""Backwards compatibility alias"""
AttribAccessList = PaginatableList

Wyświetl plik

@ -1,18 +1,42 @@
import pytest
from datetime import datetime
from mastodon.types import IdType
import typing
def get_type_class(typ):
try:
return typ.__extra__
except AttributeError:
try:
return typ.__origin__
except AttributeError:
pass
return typ
def real_issubclass(obj1, type2orig):
type1 = get_type_class(type(obj1))
type2 = get_type_class(type2orig)
valid_types = []
if type2 is typing.Union:
valid_types = type2orig.__args__
elif type2 is typing.Generic:
valid_types = [type2orig.__args__[0]]
else:
valid_types = [type2orig]
return issubclass(type1, tuple(valid_types))
@pytest.mark.vcr()
def test_id_hook(status):
assert isinstance(status['id'], int)
assert real_issubclass(status['id'], IdType)
@pytest.mark.vcr()
def test_id_hook_in_reply_to(api, status):
reply = api.status_post('Reply!', in_reply_to_id=status['id'])
try:
assert isinstance(reply['in_reply_to_id'], int)
assert isinstance(reply['in_reply_to_account_id'], int)
assert real_issubclass(reply['in_reply_to_id'], IdType)
assert real_issubclass(reply['in_reply_to_account_id'], IdType)
finally:
api.status_delete(reply['id'])
@ -21,18 +45,17 @@ def test_id_hook_in_reply_to(api, status):
def test_id_hook_within_reblog(api, status):
reblog = api.status_reblog(status['id'])
try:
assert isinstance(reblog['reblog']['id'], int)
assert real_issubclass(reblog['reblog']['id'], IdType)
finally:
api.status_delete(reblog['id'])
@pytest.mark.vcr()
def test_date_hook(status):
assert isinstance(status['created_at'], datetime)
assert real_issubclass(status['created_at'], datetime)
@pytest.mark.vcr()
def test_attribute_access(status):
assert status.id is not None
with pytest.raises(AttributeError):
status.id = 420
status.id = 420