kopia lustrzana https://github.com/jointakahe/takahe
299 wiersze
9.1 KiB
Python
299 wiersze
9.1 KiB
Python
"""
|
|
Wrapper around HTTPX that provides some fedi-specific features.
|
|
|
|
The API is identical to httpx, but some features have been added:
|
|
|
|
* Fedi-compatible HTTP signatures
|
|
* Blocked IP ranges
|
|
|
|
(Because Y is next after X).
|
|
"""
|
|
import asyncio
|
|
import ipaddress
|
|
import logging
|
|
import socket
|
|
import typing
|
|
from ssl import SSLCertVerificationError, SSLError
|
|
from types import EllipsisType
|
|
|
|
import httpx
|
|
from django.conf import settings
|
|
from httpx import RequestError
|
|
from httpx._types import TimeoutTypes, URLTypes
|
|
from idna.core import InvalidCodepoint
|
|
|
|
from .signatures import HttpSignature
|
|
|
|
__all__ = (
|
|
"SigningActor",
|
|
"Client",
|
|
"AsyncClient",
|
|
"RequestError",
|
|
)
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class SigningActor(typing.Protocol):
    """
    Structural type for an ActivityPub actor that holds keys and can
    therefore sign outgoing requests.

    Per the original notes, both :class:`users.models.identity.Identity` and
    :class:`users.models.system_actor.SystemActor` satisfy this protocol.
    """

    #: PEM-formatted private key used to produce the signature
    private_key: str

    # ``public_key: str`` is effectively part of the same interface, but
    # nothing on the request-making path reads it, so it is deliberately not
    # declared here.

    #: URL under which this key is advertised to remote servers
    public_key_id: str
|
|
|
|
|
|
class SignedAuth(httpx.Auth):
    """
    httpx auth hook that signs every outgoing request on behalf of an actor.
    """

    # Implemented as a generator-based auth flow with the body requested up
    # front, so the same code is used for both sync and async clients.
    requires_request_body = True

    def __init__(self, actor: SigningActor):
        self.actor = actor

    def auth_flow(self, request: httpx.Request):
        """
        Sign ``request`` in place with the actor's key, then yield it back
        to httpx for sending.
        """
        signer = self.actor
        HttpSignature.sign_request(
            request,
            signer.private_key,
            signer.public_key_id,
        )
        yield request
|
|
|
|
|
|
class BlockedIPError(Exception):
    """
    Raised when a request is refused because its target may resolve to an
    address inside one of the blocked IP ranges.
    """
|
|
|
|
|
|
class IpFilterWrapperTransport(httpx.BaseTransport, httpx.AsyncBaseTransport):
|
|
def __init__(
|
|
self,
|
|
blocked_ranges: list[ipaddress.IPv4Network | ipaddress.IPv6Network | str],
|
|
wrappee: httpx.BaseTransport,
|
|
):
|
|
self.blocked_ranges = blocked_ranges
|
|
self.wrappee = wrappee
|
|
|
|
def __enter__(self):
|
|
self.wrappee.__enter__()
|
|
return self
|
|
|
|
def __exit__(self, *exc):
|
|
self.wrappee.__exit__(*exc)
|
|
|
|
def close(self):
|
|
self.wrappee.close()
|
|
|
|
async def __aenter__(self):
|
|
await self.wrappee.__aenter__()
|
|
return self
|
|
|
|
async def __aexit__(self, *exc):
|
|
await self.wrappee.__aexit__(self, *exc)
|
|
|
|
async def aclose(self):
|
|
await self.wrappee.close()
|
|
|
|
def _request_to_addrinfo(self, request) -> tuple:
|
|
return (
|
|
request.url.raw_host.decode("ascii"),
|
|
request.url.port or request.url.scheme,
|
|
)
|
|
|
|
def _check_addrinfo(self, req: httpx.Request, ai: typing.Sequence[tuple]):
|
|
"""
|
|
Compare an IP to the blocked ranges
|
|
"""
|
|
addr: ipaddress._BaseAddress
|
|
for info in ai:
|
|
match info:
|
|
case (socket.AF_INET, _, _, _, (addr, _)):
|
|
addr = ipaddress.IPv4Address(addr)
|
|
case (socket.AF_INET6, _, _, _, (addr, _, _, _)):
|
|
addr = ipaddress.IPv6Address(addr) # TODO: Do we need the flowinfo?
|
|
case _:
|
|
continue
|
|
|
|
for net in self.blocked_ranges:
|
|
if addr in net:
|
|
raise BlockedIPError(
|
|
"Attempted to make a connection to {addr} as {request.url.host} (blocked by {net})"
|
|
)
|
|
|
|
# It would have been nicer to do this at a lower level, so we know what
|
|
# IPs we're _actually_ connecting to, but:
|
|
# * That's really deep in httpcore and ughhhhhh
|
|
# * httpcore just passes the string hostname to the socket API anyway,
|
|
# and nobody wants to reimplement happy eyeballs, address fallback, etc
|
|
# * If any public name resolves to one of these ranges anyway, it's either
|
|
# misconfigured or malicious
|
|
|
|
def handle_request(self, request: httpx.Request) -> httpx.Response:
|
|
try:
|
|
self._check_addrinfo(
|
|
request, socket.getaddrinfo(*self._request_to_addrinfo(request))
|
|
)
|
|
except socket.gaierror:
|
|
# Some kind of look up error. Gonna assume safe and let farther
|
|
# down the stack handle it.
|
|
pass
|
|
return self.wrappee.handle_request(request)
|
|
|
|
async def handle_async_request(self, request: httpx.Request) -> httpx.Response:
|
|
try:
|
|
self._check_addrinfo(
|
|
request,
|
|
await asyncio.get_running_loop().getaddrinfo(
|
|
*self._request_to_addrinfo(request)
|
|
),
|
|
)
|
|
except socket.gaierror:
|
|
# Some kind of look up error. Gonna assume safe and let farther
|
|
# down the stack handle it.
|
|
pass
|
|
return await self.wrappee.handle_await_request(request)
|
|
|
|
|
|
def _wrap_transport(
|
|
blocked_ranges: list[ipaddress.IPv4Network | ipaddress.IPv6Network | str]
|
|
| None
|
|
| EllipsisType,
|
|
transport,
|
|
):
|
|
"""
|
|
Gets an (Async)Transport that blocks the given IP ranges
|
|
"""
|
|
if blocked_ranges is ...:
|
|
blocked_ranges = settings.HTTP_BLOCKED_RANGES
|
|
|
|
if not blocked_ranges:
|
|
return transport
|
|
|
|
blocked_ranges = [
|
|
ipaddress.ip_network(net) if isinstance(net, str) else net
|
|
for net in typing.cast(typing.Iterable, blocked_ranges)
|
|
]
|
|
return IpFilterWrapperTransport(blocked_ranges, transport)
|
|
|
|
|
|
class BaseClient(httpx._client.BaseClient):
    """
    Shared behaviour for the sync and async clients: optional request
    signing, blocked IP ranges and default headers.
    """

    def __init__(
        self,
        *,
        actor: SigningActor | None = None,
        blocked_ranges: list[ipaddress.IPv4Network | ipaddress.IPv6Network | str]
        | None
        | EllipsisType = ...,
        timeout: TimeoutTypes = settings.SETUP.REMOTE_TIMEOUT,
        **opts,
    ):
        """
        Params:
          actor: Actor to sign requests as, or None to not sign requests.
          blocked_ranges: IP ranges to refuse to connect to. Either a list of
                          Networks, None to disable the feature, or Ellipsis
                          to pull the Django setting.
          timeout: Passed through to httpx; defaults to the configured
                   remote timeout.
        """
        # Stored before calling up, since _init_transport() reads it --
        # presumably httpx builds its transports during __init__ (confirm).
        self._blocked_ranges = blocked_ranges

        if actor:
            opts["auth"] = SignedAuth(actor)

        super().__init__(timeout=timeout, **opts)

    def _init_transport(self, *args, **kwargs):
        """
        Build the transport as usual, then apply the IP range filter to it.
        """
        inner = super()._init_transport(*args, **kwargs)
        return _wrap_transport(self._blocked_ranges, inner)

    def build_request(self, *args, **kwargs):
        """
        Build the request as usual, then apply our default headers.
        """
        req = super().build_request(*args, **kwargs)
        headers = req.headers

        # GET requests without an Accept header get one added implicitly
        if req.method == "GET":
            if "Accept" not in headers:
                headers["Accept"] = "application/ld+json"

        # TODO: Move the User-Agent default to __init__
        headers["User-Agent"] = settings.TAKAHE_USER_AGENT
        return req
|
|
|
|
|
|
# BaseClient is listed before httpx's (Async)Client in the bases so that its __init__ comes first in the MRO.
|
|
|
|
|
|
class Client(BaseClient, httpx.Client):
    """
    Synchronous client with friendlier error handling and a few
    ActivityPub-oriented convenience methods.
    """

    def request(self, method: str, url: URLTypes, **params) -> httpx.Response:
        """
        Like ``httpx.Client.request()``, but wraps some errors up nicer and
        follows redirects on GET unless the caller chose otherwise.

        Raises:
            SSLCertVerificationError: If the request failed with an SSLError.
            httpx.HTTPError: If the URL contains an invalid IDNA codepoint.
        """
        # ``method.lower`` without the call never equals "get", which left
        # this default permanently disabled. ``.get()`` also avoids a KeyError
        # when ``follow_redirects`` was not passed at all.
        if method.lower() == "get":
            client_default = httpx._client.USE_CLIENT_DEFAULT
            if params.get("follow_redirects", client_default) is client_default:
                params["follow_redirects"] = True

        try:
            response = super().request(method, url, **params)
        except SSLError as invalid_cert:
            # Not our problem if the other end doesn't have proper SSL
            logger.info("Invalid cert on %s %s", url, invalid_cert)
            raise SSLCertVerificationError(invalid_cert) from invalid_cert
        except InvalidCodepoint as ex:
            # Convert to a more generic error we handle
            raise httpx.HTTPError(f"InvalidCodepoint: {str(ex)}") from None
        else:
            return response

    # Deliberately not doing the above to stream() because those use cases don't
    # want that handling

    def get(
        self, url: URLTypes, *, accept: str | None = "application/ld+json", **params
    ):
        """
        Args:
            accept: Accept header, set to None to get the open option
        """
        if accept:
            # Copy into an httpx.Headers so the caller's mapping is not
            # mutated and ``headers=None`` or a list of pairs also works.
            headers = httpx.Headers(params.get("headers"))
            headers["Accept"] = accept
            params["headers"] = headers
        return super().get(url, **params)

    def post2(self, url: URLTypes, *, activity=None, **params):
        """
        Like .post() but:

        * Adds activity which is like json but for activities
        * Handles response errors a bit

        Raises:
            ValueError: If the response status is 4xx, except for 404.
        """
        if activity is not None:
            params["json"] = activity
            # Same copy-first approach as in get(); an explicit Content-Type
            # supplied by the caller is left alone.
            headers = httpx.Headers(params.get("headers"))
            headers.setdefault("Content-Type", "application/activity+json")
            params["headers"] = headers

        response = self.post(url, **params)

        if 400 <= response.status_code < 500 and response.status_code != 404:
            raise ValueError(
                f"POST error to {url}: {response.status_code} {response.content!r}"
            )
        return response
|
|
|
|
|
|
class AsyncClient(BaseClient, httpx.AsyncClient):
    """
    Async flavour of the client; currently only adds what BaseClient provides.
    """

    # FIXME: Add the fancy methods the sync version has.
    # (Skipped so far out of laziness, on the assumption that nobody is
    # making async requests.)
|