From feae3a7545c65ccc794d165b9915fee9a2f99b3b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20van=20de=20Giessen?= Date: Tue, 19 Mar 2024 15:24:39 +0100 Subject: [PATCH] extmod/modtls_mbedtls: Test SSLSession reuse. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Daniƫl van de Giessen --- .../sslcontext_server_client_session.py | 131 ++++++++++++++++++ .../sslcontext_server_client_session.py.exp | 29 ++++ 2 files changed, 160 insertions(+) create mode 100644 tests/multi_net/sslcontext_server_client_session.py create mode 100644 tests/multi_net/sslcontext_server_client_session.py.exp diff --git a/tests/multi_net/sslcontext_server_client_session.py b/tests/multi_net/sslcontext_server_client_session.py new file mode 100644 index 0000000000..cee8093475 --- /dev/null +++ b/tests/multi_net/sslcontext_server_client_session.py @@ -0,0 +1,131 @@ +# Test creating an SSL connection with certificates as bytes objects. + +try: + from io import IOBase + import os + import socket + import ssl +except ImportError: + print("SKIP") + raise SystemExit + +if not hasattr(ssl, "SSLSession"): + print("SKIP") + raise SystemExit + +PORT = 8000 + +# These are test certificates. See tests/README.md for details. +certfile = "ec_cert.der" +keyfile = "ec_key.der" + +try: + os.stat(certfile) + os.stat(keyfile) +except OSError: + print("SKIP") + raise SystemExit + +with open(certfile, "rb") as cf: + cert = cadata = cf.read() + +with open(keyfile, "rb") as kf: + key = kf.read() + + +# Helper class to count number of bytes going over a TCP socket +class CountingStream(IOBase): + def __init__(self, stream): + self.stream = stream + self.count = 0 + + def readinto(self, buf, nbytes=None): + result = self.stream.readinto(buf) if nbytes is None else self.stream.readinto(buf, nbytes) + self.count += result + return result + + def write(self, buf): + self.count += len(buf) + return self.stream.write(buf) + + def ioctl(self, req, arg): + if hasattr(self.stream, "ioctl"): + return self.stream.ioctl(req, arg) + return 0 + + +# Server +def instance0(): + multitest.globals(IP=multitest.get_network_ip()) + s = socket.socket() + s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + s.bind(socket.getaddrinfo("0.0.0.0", PORT)[0][-1]) + s.listen(1) + multitest.next() + server_ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) + server_ctx.load_cert_chain(cert, key) + for i in range(7): + s2, _ = s.accept() + s2 = server_ctx.wrap_socket(s2, server_side=True) + print(s2.read(18)) + s2.write(b"server to client {}".format(i)) + s2.close() + s.close() + + +# Client +def instance1(): + multitest.next() + + def connect_and_count(i, session, set_method="wrap_socket"): + s = socket.socket() + s.connect(socket.getaddrinfo(IP, PORT)[0][-1]) + s = CountingStream(s) + client_ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + client_ctx.verify_mode = ssl.CERT_REQUIRED + client_ctx.load_verify_locations(cadata=cadata) + wrap_socket_kwargs = {} + if set_method == "wrap_socket": + wrap_socket_kwargs = {"session": session} + elif set_method == "socket_attr": + wrap_socket_kwargs = {"do_handshake_on_connect": False} + s2 = client_ctx.wrap_socket(s, server_hostname="micropython.local", **wrap_socket_kwargs) + if set_method == "socket_attr" and session is not None: + s2.session = session + s2.write(b"client to server {}".format(i)) + print(s2.read(18)) + session = s2.session + print(type(session)) + s2.close() + return session, s.count + + # No session reuse + session, count_without_reuse = connect_and_count(0, None) + + # Direct session reuse + session, count = connect_and_count(1, session, "wrap_socket") + print(count < count_without_reuse) + + # Serialized session reuse + session = ssl.SSLSession(session.serialize()) + session, count = connect_and_count(2, session, "wrap_socket") + print(count < count_without_reuse) + + # Serialized session reuse (using buffer protocol) + session = ssl.SSLSession(bytes(session)) + session, count = connect_and_count(3, session, "wrap_socket") + print(count < count_without_reuse) + + # Direct session reuse + session, count = connect_and_count(4, session, "socket_attr") + print(count < count_without_reuse) + + # Serialized session reuse + session = ssl.SSLSession(session.serialize()) + session, count = connect_and_count(5, session, "socket_attr") + print(count < count_without_reuse) + + # Serialized session reuse (using buffer protocol) + session = ssl.SSLSession(bytes(session)) + session, count = connect_and_count(6, session, "socket_attr") + print(count < count_without_reuse) diff --git a/tests/multi_net/sslcontext_server_client_session.py.exp b/tests/multi_net/sslcontext_server_client_session.py.exp new file mode 100644 index 0000000000..f3ed2c57d6 --- /dev/null +++ b/tests/multi_net/sslcontext_server_client_session.py.exp @@ -0,0 +1,29 @@ +--- instance0 --- +b'client to server 0' +b'client to server 1' +b'client to server 2' +b'client to server 3' +b'client to server 4' +b'client to server 5' +b'client to server 6' +--- instance1 --- +b'server to client 0' + +b'server to client 1' + +True +b'server to client 2' + +True +b'server to client 3' + +True +b'server to client 4' + +True +b'server to client 5' + +True +b'server to client 6' + +True