kopia lustrzana https://github.com/jointakahe/takahe
Actually implement range blocking.
rodzic
6ebc23e24b
commit
c2ecb53bb3
110
core/httpy.py
110
core/httpy.py
|
@ -8,8 +8,9 @@ The API is identical to httpx, but some features has been added:
|
|||
|
||||
(Because Y is next after X).
|
||||
"""
|
||||
import functools
|
||||
import asyncio
|
||||
import ipaddress
|
||||
import socket
|
||||
import typing
|
||||
from types import EllipsisType
|
||||
|
||||
|
@ -57,11 +58,96 @@ class SignedAuth(httpx.Auth):
|
|||
yield request
|
||||
|
||||
|
||||
@functools.lru_cache # Reuse transports
|
||||
def _get_transport(
|
||||
class BlockedIPError(Exception):
|
||||
"""
|
||||
Attempted to make a request that might have hit a blocked IP range.
|
||||
"""
|
||||
|
||||
|
||||
class IpFilterWrapperTransport(httpx.BaseTransport, httpx.AsyncBaseTransport):
|
||||
def __init__(
|
||||
self,
|
||||
blocked_ranges: list[ipaddress.IPv4Network | ipaddress.IPv6Network | str],
|
||||
wrappee: httpx.BaseTransport,
|
||||
):
|
||||
self.blocked_ranges = blocked_ranges
|
||||
self.wrappee = wrappee
|
||||
|
||||
def __enter__(self):
|
||||
self.wrappee.__enter__()
|
||||
return self
|
||||
|
||||
def __exit__(self, *exc):
|
||||
self.wrappee.__exit__(*exc)
|
||||
|
||||
def close(self):
|
||||
self.wrappee.close()
|
||||
|
||||
async def __aenter__(self):
|
||||
await self.wrappee.__aenter__()
|
||||
return self
|
||||
|
||||
async def __aexit__(self, *exc):
|
||||
await self.wrappee.__aexit__(self, *exc)
|
||||
|
||||
async def aclose(self):
|
||||
await self.wrappee.close()
|
||||
|
||||
def _request_to_addrinfo(self, request) -> tuple:
|
||||
return (
|
||||
request.url.raw_host.decode("ascii"),
|
||||
request.url.port or request.url.scheme,
|
||||
)
|
||||
|
||||
def _check_addrinfo(self, req: httpx.Request, ai: typing.Sequence[tuple]):
|
||||
"""
|
||||
Compare an IP to the blocked ranges
|
||||
"""
|
||||
addr: ipaddress._BaseAddress
|
||||
for info in ai:
|
||||
match info:
|
||||
case (socket.AF_INET, _, _, _, (addr, _)):
|
||||
addr = ipaddress.IPv4Address(addr)
|
||||
case (socket.AF_INET6, _, _, _, (addr, _, _, _)):
|
||||
addr = ipaddress.IPv6Address(addr) # TODO: Do we need the flowinfo?
|
||||
case _:
|
||||
continue
|
||||
|
||||
for net in self.blocked_ranges:
|
||||
if addr in net:
|
||||
raise BlockedIPError(
|
||||
"Attempted to make a connection to {addr} as {request.url.host} (blocked by {net})"
|
||||
)
|
||||
|
||||
# It would have been nicer to do this at a lower level, so we know what
|
||||
# IPs we're _actually_ connecting to, but:
|
||||
# * That's really deep in httpcore and ughhhhhh
|
||||
# * httpcore just passes the string hostname to the socket API anyway,
|
||||
# and nobody wants to reimplement happy eyeballs, address fallback, etc
|
||||
# * If any public name resolves to one of these ranges anyway, it's either
|
||||
# misconfigured or malicious
|
||||
|
||||
def handle_request(self, request: httpx.Request) -> httpx.Response:
|
||||
self._check_addrinfo(
|
||||
request, socket.getaddrinfo(*self._request_to_addrinfo(request))
|
||||
)
|
||||
return super().handle_request(request)
|
||||
|
||||
async def handle_async_request(self, request: httpx.Request) -> httpx.Response:
|
||||
self._check_addrinfo(
|
||||
request,
|
||||
await asyncio.get_running_loop().getaddrinfo(
|
||||
*self._request_to_addrinfo(request)
|
||||
),
|
||||
)
|
||||
return await super().handle_await_request(request)
|
||||
|
||||
|
||||
def _wrap_transport(
|
||||
blocked_ranges: list[ipaddress.IPv4Network | ipaddress.IPv6Network | str]
|
||||
| None
|
||||
| EllipsisType,
|
||||
sync: bool,
|
||||
transport,
|
||||
):
|
||||
"""
|
||||
Gets an (Async)Transport that blocks the given IP ranges
|
||||
|
@ -69,10 +155,14 @@ def _get_transport(
|
|||
if blocked_ranges is ...:
|
||||
blocked_ranges = settings.HTTP_BLOCKED_RANGES
|
||||
|
||||
if not blocked_ranges:
|
||||
return transport
|
||||
|
||||
blocked_ranges = [
|
||||
ipaddress.ip_network(net) if isinstance(net, str) else net
|
||||
for net in typing.cast(typing.Iterable, blocked_ranges)
|
||||
]
|
||||
return IpFilterWrapperTransport(blocked_ranges, transport)
|
||||
|
||||
|
||||
class BaseClient(httpx.BaseClient):
|
||||
|
@ -84,7 +174,7 @@ class BaseClient(httpx.BaseClient):
|
|||
| None
|
||||
| EllipsisType = ...,
|
||||
timeout: TimeoutTypes = settings.SETUP.REMOTE_TIMEOUT,
|
||||
**opts
|
||||
**opts,
|
||||
):
|
||||
"""
|
||||
Params:
|
||||
|
@ -95,10 +185,13 @@ class BaseClient(httpx.BaseClient):
|
|||
"""
|
||||
if actor:
|
||||
opts["auth"] = SignedAuth(actor)
|
||||
self._blocked_ranges = blocked_ranges
|
||||
|
||||
super().__init__(timeout=timeout, **opts)
|
||||
|
||||
# TODO: If we're given blocked ranges, customize transport
|
||||
def _init_transport(self, *p, **kw):
|
||||
transport = super()._init_transport(*p, **kw)
|
||||
return _wrap_transport(self._blocked_ranges, transport)
|
||||
|
||||
def build_request(self, *pargs, **kwargs):
|
||||
request = super().build_request(*pargs, **kwargs)
|
||||
|
@ -107,9 +200,8 @@ class BaseClient(httpx.BaseClient):
|
|||
if request.method == "GET" and "Accept" not in request.headers:
|
||||
request.headers["Accept"] = "application/ld+json"
|
||||
|
||||
request.headers[
|
||||
"User-Agent"
|
||||
] = settings.TAKAHE_USER_AGENT # TODO: Move this to __init__
|
||||
# TODO: Move this to __init__
|
||||
request.headers["User-Agent"] = settings.TAKAHE_USER_AGENT
|
||||
return request
|
||||
|
||||
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
import ipaddress
|
||||
import os
|
||||
import secrets
|
||||
import sys
|
||||
|
@ -476,6 +477,20 @@ TAKAHE_USER_AGENT = (
|
|||
f"(Takahe/{__version__}; +https://{SETUP.MAIN_DOMAIN}/)"
|
||||
)
|
||||
|
||||
HTTP_BLOCKED_RANGES = map(
|
||||
ipaddress.ip_network,
|
||||
[
|
||||
# All of these are RFC reserved ranges
|
||||
# Pulled from Wikipedia
|
||||
"0.0.0.0/8", # Current network
|
||||
"10.0.0.0/8", # Private, local network
|
||||
"100.64.0.0/10", # Private, CGNAT
|
||||
"127.0.0.0/8", # Localhost
|
||||
"169.254.0.0/16", # Link-local address, zeroconf
|
||||
"172.16.0.0/12", # Private, local network
|
||||
],
|
||||
)
|
||||
|
||||
if SETUP.LOCAL_SETTINGS:
|
||||
# Let any errors bubble up
|
||||
from .local_settings import * # noqa
|
||||
|
|
Ładowanie…
Reference in New Issue