From 1f019f90eaef146960988319166076e12e689e81 Mon Sep 17 00:00:00 2001 From: Mirza Kapetanovic Date: Wed, 12 Jun 2024 11:41:04 +0200 Subject: [PATCH] requests: Make it possible to override headers and allow raw data upload. This removes all the hard-coded request headers from the requests module so they can be overridden by a user-provided headers dict. Furthermore, allow streaming request data without chunked encoding in those cases where content length is known but it's not desirable to load the whole content into memory. Also, some servers (e.g. nginx) reject HTTP/1.0 requests with the Transfer-Encoding header set. The change should be backwards compatible as long as the user hasn't provided any of the previously hard-coded headers. Signed-off-by: Mirza Kapetanovic --- python-ecosys/requests/manifest.py | 2 +- python-ecosys/requests/requests/__init__.py | 55 ++++--- python-ecosys/requests/test_requests.py | 155 ++++++++++++++++++++ 3 files changed, 193 insertions(+), 19 deletions(-) create mode 100644 python-ecosys/requests/test_requests.py diff --git a/python-ecosys/requests/manifest.py b/python-ecosys/requests/manifest.py index 97df1560..eb7bb2d4 100644 --- a/python-ecosys/requests/manifest.py +++ b/python-ecosys/requests/manifest.py @@ -1,3 +1,3 @@ -metadata(version="0.9.0", pypi="requests") +metadata(version="0.10.0", pypi="requests") package("requests") diff --git a/python-ecosys/requests/requests/__init__.py b/python-ecosys/requests/requests/__init__.py index 74010291..b6bf515d 100644 --- a/python-ecosys/requests/requests/__init__.py +++ b/python-ecosys/requests/requests/__init__.py @@ -38,12 +38,15 @@ def request( url, data=None, json=None, - headers={}, + headers=None, stream=None, auth=None, timeout=None, parse_headers=True, ): + if headers is None: + headers = {} + redirect = None # redirection url, None means no redirection chunked_data = data and getattr(data, "__next__", None) and not getattr(data, "__len__", None) @@ -94,33 +97,49 @@ def request( 
context.verify_mode = tls.CERT_NONE s = context.wrap_socket(s, server_hostname=host) s.write(b"%s /%s HTTP/1.0\r\n" % (method, path)) + if "Host" not in headers: - s.write(b"Host: %s\r\n" % host) + headers["Host"] = host + + if json is not None: + assert data is None + import ujson + + data = ujson.dumps(json) + + if "Content-Type" not in headers: + headers["Content-Type"] = "application/json" + + if data: + if chunked_data: + if "Transfer-Encoding" not in headers and "Content-Length" not in headers: + headers["Transfer-Encoding"] = "chunked" + elif "Content-Length" not in headers: + headers["Content-Length"] = str(len(data)) + + if "Connection" not in headers: + headers["Connection"] = "close" + # Iterate over keys to avoid tuple alloc for k in headers: s.write(k) s.write(b": ") s.write(headers[k]) s.write(b"\r\n") - if json is not None: - assert data is None - import ujson - data = ujson.dumps(json) - s.write(b"Content-Type: application/json\r\n") + s.write(b"\r\n") + if data: if chunked_data: - s.write(b"Transfer-Encoding: chunked\r\n") - else: - s.write(b"Content-Length: %d\r\n" % len(data)) - s.write(b"Connection: close\r\n\r\n") - if data: - if chunked_data: - for chunk in data: - s.write(b"%x\r\n" % len(chunk)) - s.write(chunk) - s.write(b"\r\n") - s.write("0\r\n\r\n") + if headers.get("Transfer-Encoding", None) == "chunked": + for chunk in data: + s.write(b"%x\r\n" % len(chunk)) + s.write(chunk) + s.write(b"\r\n") + s.write("0\r\n\r\n") + else: + for chunk in data: + s.write(chunk) else: s.write(data) diff --git a/python-ecosys/requests/test_requests.py b/python-ecosys/requests/test_requests.py new file mode 100644 index 00000000..540d335c --- /dev/null +++ b/python-ecosys/requests/test_requests.py @@ -0,0 +1,155 @@ +import io +import sys + + +class Socket: + def __init__(self): + self._write_buffer = io.BytesIO() + self._read_buffer = io.BytesIO(b"HTTP/1.0 200 OK\r\n\r\n") + + def connect(self, address): + pass + + def write(self, buf): + 
self._write_buffer.write(buf) + + def readline(self): + return self._read_buffer.readline() + + +class usocket: + AF_INET = 2 + SOCK_STREAM = 1 + IPPROTO_TCP = 6 + + @staticmethod + def getaddrinfo(host, port, af=0, type=0, flags=0): + return [(usocket.AF_INET, usocket.SOCK_STREAM, usocket.IPPROTO_TCP, "", ("127.0.0.1", 80))] + + def socket(af=AF_INET, type=SOCK_STREAM, proto=IPPROTO_TCP): + return Socket() + + +sys.modules["usocket"] = usocket +# ruff: noqa: E402 +import requests + + +def format_message(response): + return response.raw._write_buffer.getvalue().decode("utf8") + + +def test_simple_get(): + response = requests.request("GET", "http://example.com") + + assert response.raw._write_buffer.getvalue() == ( + b"GET / HTTP/1.0\r\n" + b"Connection: close\r\n" + b"Host: example.com\r\n\r\n" + ), format_message(response) + + +def test_get_auth(): + response = requests.request( + "GET", "http://example.com", auth=("test-username", "test-password") + ) + + assert response.raw._write_buffer.getvalue() == ( + b"GET / HTTP/1.0\r\n" + + b"Host: example.com\r\n" + + b"Authorization: Basic dGVzdC11c2VybmFtZTp0ZXN0LXBhc3N3b3Jk\r\n" + + b"Connection: close\r\n\r\n" + ), format_message(response) + + +def test_get_custom_header(): + response = requests.request("GET", "http://example.com", headers={"User-Agent": "test-agent"}) + + assert response.raw._write_buffer.getvalue() == ( + b"GET / HTTP/1.0\r\n" + + b"User-Agent: test-agent\r\n" + + b"Host: example.com\r\n" + + b"Connection: close\r\n\r\n" + ), format_message(response) + + +def test_post_json(): + response = requests.request("GET", "http://example.com", json="test") + + assert response.raw._write_buffer.getvalue() == ( + b"GET / HTTP/1.0\r\n" + + b"Connection: close\r\n" + + b"Content-Type: application/json\r\n" + + b"Host: example.com\r\n" + + b"Content-Length: 6\r\n\r\n" + + b'"test"' + ), format_message(response) + + +def test_post_chunked_data(): + def chunks(): + yield "test" + + response = 
requests.request("GET", "http://example.com", data=chunks()) + + assert response.raw._write_buffer.getvalue() == ( + b"GET / HTTP/1.0\r\n" + + b"Transfer-Encoding: chunked\r\n" + + b"Host: example.com\r\n" + + b"Connection: close\r\n\r\n" + + b"4\r\ntest\r\n" + + b"0\r\n\r\n" + ), format_message(response) + + +def test_overwrite_get_headers(): + response = requests.request( + "GET", "http://example.com", headers={"Connection": "keep-alive", "Host": "test.com"} + ) + + assert response.raw._write_buffer.getvalue() == ( + b"GET / HTTP/1.0\r\n" + b"Host: test.com\r\n" + b"Connection: keep-alive\r\n\r\n" + ), format_message(response) + + +def test_overwrite_post_json_headers(): + response = requests.request( + "GET", + "http://example.com", + json="test", + headers={"Content-Type": "text/plain", "Content-Length": "10"}, + ) + + assert response.raw._write_buffer.getvalue() == ( + b"GET / HTTP/1.0\r\n" + + b"Connection: close\r\n" + + b"Content-Length: 10\r\n" + + b"Content-Type: text/plain\r\n" + + b"Host: example.com\r\n\r\n" + + b'"test"' + ), format_message(response) + + +def test_overwrite_post_chunked_data_headers(): + def chunks(): + yield "test" + + response = requests.request( + "GET", "http://example.com", data=chunks(), headers={"Content-Length": "4"} + ) + + assert response.raw._write_buffer.getvalue() == ( + b"GET / HTTP/1.0\r\n" + + b"Host: example.com\r\n" + + b"Content-Length: 4\r\n" + + b"Connection: close\r\n\r\n" + + b"test" + ), format_message(response) + + +test_simple_get() +test_get_auth() +test_get_custom_header() +test_post_json() +test_post_chunked_data() +test_overwrite_get_headers() +test_overwrite_post_json_headers() +test_overwrite_post_chunked_data_headers()