Actually implement range blocking.

pull/679/head
Jamie Bliss 2024-01-11 03:45:04 +00:00
rodzic 6ebc23e24b
commit c2ecb53bb3
Nie znaleziono w bazie danych klucza dla tego podpisu
2 zmienionych plików z 116 dodań i 9 usunięć

Wyświetl plik

@ -8,8 +8,9 @@ The API is identical to httpx, but some features has been added:
(Because Y is next after X).
"""
import functools
import asyncio
import ipaddress
import socket
import typing
from types import EllipsisType
@ -57,11 +58,96 @@ class SignedAuth(httpx.Auth):
yield request
@functools.lru_cache # Reuse transports
def _get_transport(
class BlockedIPError(Exception):
    """Raised when a request's target may resolve into a blocked IP range."""
class IpFilterWrapperTransport(httpx.BaseTransport, httpx.AsyncBaseTransport):
def __init__(
self,
blocked_ranges: list[ipaddress.IPv4Network | ipaddress.IPv6Network | str],
wrappee: httpx.BaseTransport,
):
self.blocked_ranges = blocked_ranges
self.wrappee = wrappee
def __enter__(self):
self.wrappee.__enter__()
return self
def __exit__(self, *exc):
self.wrappee.__exit__(*exc)
def close(self):
self.wrappee.close()
async def __aenter__(self):
await self.wrappee.__aenter__()
return self
async def __aexit__(self, *exc):
await self.wrappee.__aexit__(self, *exc)
async def aclose(self):
await self.wrappee.close()
def _request_to_addrinfo(self, request) -> tuple:
return (
request.url.raw_host.decode("ascii"),
request.url.port or request.url.scheme,
)
def _check_addrinfo(self, req: httpx.Request, ai: typing.Sequence[tuple]):
"""
Compare an IP to the blocked ranges
"""
addr: ipaddress._BaseAddress
for info in ai:
match info:
case (socket.AF_INET, _, _, _, (addr, _)):
addr = ipaddress.IPv4Address(addr)
case (socket.AF_INET6, _, _, _, (addr, _, _, _)):
addr = ipaddress.IPv6Address(addr) # TODO: Do we need the flowinfo?
case _:
continue
for net in self.blocked_ranges:
if addr in net:
raise BlockedIPError(
"Attempted to make a connection to {addr} as {request.url.host} (blocked by {net})"
)
# It would have been nicer to do this at a lower level, so we know what
# IPs we're _actually_ connecting to, but:
# * That's really deep in httpcore and ughhhhhh
# * httpcore just passes the string hostname to the socket API anyway,
# and nobody wants to reimplement happy eyeballs, address fallback, etc
# * If any public name resolves to one of these ranges anyway, it's either
# misconfigured or malicious
def handle_request(self, request: httpx.Request) -> httpx.Response:
self._check_addrinfo(
request, socket.getaddrinfo(*self._request_to_addrinfo(request))
)
return super().handle_request(request)
async def handle_async_request(self, request: httpx.Request) -> httpx.Response:
self._check_addrinfo(
request,
await asyncio.get_running_loop().getaddrinfo(
*self._request_to_addrinfo(request)
),
)
return await super().handle_await_request(request)
def _wrap_transport(
blocked_ranges: list[ipaddress.IPv4Network | ipaddress.IPv6Network | str]
| None
| EllipsisType,
sync: bool,
transport,
):
"""
Gets an (Async)Transport that blocks the given IP ranges
@ -69,10 +155,14 @@ def _get_transport(
if blocked_ranges is ...:
blocked_ranges = settings.HTTP_BLOCKED_RANGES
if not blocked_ranges:
return transport
blocked_ranges = [
ipaddress.ip_network(net) if isinstance(net, str) else net
for net in typing.cast(typing.Iterable, blocked_ranges)
]
return IpFilterWrapperTransport(blocked_ranges, transport)
class BaseClient(httpx.BaseClient):
@ -84,7 +174,7 @@ class BaseClient(httpx.BaseClient):
| None
| EllipsisType = ...,
timeout: TimeoutTypes = settings.SETUP.REMOTE_TIMEOUT,
**opts
**opts,
):
"""
Params:
@ -95,10 +185,13 @@ class BaseClient(httpx.BaseClient):
"""
if actor:
opts["auth"] = SignedAuth(actor)
self._blocked_ranges = blocked_ranges
super().__init__(timeout=timeout, **opts)
# TODO: If we're given blocked ranges, customize transport
def _init_transport(self, *p, **kw):
transport = super()._init_transport(*p, **kw)
return _wrap_transport(self._blocked_ranges, transport)
def build_request(self, *pargs, **kwargs):
request = super().build_request(*pargs, **kwargs)
@ -107,9 +200,8 @@ class BaseClient(httpx.BaseClient):
if request.method == "GET" and "Accept" not in request.headers:
request.headers["Accept"] = "application/ld+json"
request.headers[
"User-Agent"
] = settings.TAKAHE_USER_AGENT # TODO: Move this to __init__
# TODO: Move this to __init__
request.headers["User-Agent"] = settings.TAKAHE_USER_AGENT
return request

Wyświetl plik

@ -1,3 +1,4 @@
import ipaddress
import os
import secrets
import sys
@ -476,6 +477,20 @@ TAKAHE_USER_AGENT = (
f"(Takahe/{__version__}; +https://{SETUP.MAIN_DOMAIN}/)"
)
# IP networks that outgoing HTTP requests must not be allowed to reach.
# This has to be a real list: a bare map() object is a one-shot iterator
# that is always truthy, so after its first consumer every later one would
# see an empty sequence and end up blocking nothing.
# NOTE(review): 192.168.0.0/16 and the IPv6 loopback/link-local/unique-local
# ranges are not listed here - confirm whether that is intentional.
HTTP_BLOCKED_RANGES = [
    ipaddress.ip_network(cidr)
    for cidr in (
        # All of these are RFC reserved ranges
        # Pulled from Wikipedia
        "0.0.0.0/8",  # Current network
        "10.0.0.0/8",  # Private, local network
        "100.64.0.0/10",  # Private, CGNAT
        "127.0.0.0/8",  # Localhost
        "169.254.0.0/16",  # Link-local address, zeroconf
        "172.16.0.0/12",  # Private, local network
    )
]
if SETUP.LOCAL_SETTINGS:
# Let any errors bubble up
from .local_settings import * # noqa