diff --git a/extmod/modussl_axtls.c b/extmod/modussl_axtls.c index 7b0e3cbcbc..0b0ce35fc8 100644 --- a/extmod/modussl_axtls.c +++ b/extmod/modussl_axtls.c @@ -29,6 +29,7 @@ #include "py/runtime.h" #include "py/stream.h" +#include "py/objstr.h" #if MICROPY_PY_USSL && MICROPY_SSL_AXTLS @@ -54,6 +55,56 @@ struct ssl_args { STATIC const mp_obj_type_t ussl_socket_type; +// Table of errors +struct ssl_errs { + int16_t errnum; + const char *errstr; +}; +STATIC const struct ssl_errs ssl_error_tab[] = { + { SSL_NOT_OK, "NOT_OK" }, + { SSL_ERROR_DEAD, "DEAD" }, + { SSL_CLOSE_NOTIFY, "CLOSE_NOTIFY" }, + { SSL_EAGAIN, "EAGAIN" }, + { SSL_ERROR_CONN_LOST, "CONN_LOST" }, + { SSL_ERROR_RECORD_OVERFLOW, "RECORD_OVERFLOW" }, + { SSL_ERROR_SOCK_SETUP_FAILURE, "SOCK_SETUP_FAILURE" }, + { SSL_ERROR_INVALID_HANDSHAKE, "INVALID_HANDSHAKE" }, + { SSL_ERROR_INVALID_PROT_MSG, "INVALID_PROT_MSG" }, + { SSL_ERROR_INVALID_HMAC, "INVALID_HMAC" }, + { SSL_ERROR_INVALID_VERSION, "INVALID_VERSION" }, + { SSL_ERROR_UNSUPPORTED_EXTENSION, "UNSUPPORTED_EXTENSION" }, + { SSL_ERROR_INVALID_SESSION, "INVALID_SESSION" }, + { SSL_ERROR_NO_CIPHER, "NO_CIPHER" }, + { SSL_ERROR_INVALID_CERT_HASH_ALG, "INVALID_CERT_HASH_ALG" }, + { SSL_ERROR_BAD_CERTIFICATE, "BAD_CERTIFICATE" }, + { SSL_ERROR_INVALID_KEY, "INVALID_KEY" }, + { SSL_ERROR_FINISHED_INVALID, "FINISHED_INVALID" }, + { SSL_ERROR_NO_CERT_DEFINED, "NO_CERT_DEFINED" }, + { SSL_ERROR_NO_CLIENT_RENOG, "NO_CLIENT_RENOG" }, + { SSL_ERROR_NOT_SUPPORTED, "NOT_SUPPORTED" }, +}; + +STATIC NORETURN void ussl_raise_error(int err) { + for (size_t i = 0; i < MP_ARRAY_SIZE(ssl_error_tab); i++) { + if (ssl_error_tab[i].errnum == err) { + // construct string object + mp_obj_str_t *o_str = m_new_obj_maybe(mp_obj_str_t); + if (o_str == NULL) { + break; + } + o_str->base.type = &mp_type_str; + o_str->data = (const byte *)ssl_error_tab[i].errstr; + o_str->len = strlen((char *)o_str->data); + o_str->hash = qstr_compute_hash(o_str->data, o_str->len); + // raise + mp_obj_t args[2] = { MP_OBJ_NEW_SMALL_INT(err), MP_OBJ_FROM_PTR(o_str)}; + nlr_raise(mp_obj_exception_make_new(&mp_type_OSError, 2, 0, args)); + } + } + mp_raise_OSError(err); +} + + STATIC mp_obj_ssl_socket_t *ussl_socket_new(mp_obj_t sock, struct ssl_args *args) { #if MICROPY_PY_USSL_FINALISER mp_obj_ssl_socket_t *o = m_new_obj_with_finaliser(mp_obj_ssl_socket_t); @@ -107,9 +158,7 @@ STATIC mp_obj_ssl_socket_t *ussl_socket_new(mp_obj_t sock, struct ssl_args *args int res = ssl_handshake_status(o->ssl_sock); if (res != SSL_OK) { - printf("ssl_handshake_status: %d\n", res); - ssl_display_error(res); - mp_raise_OSError(MP_EIO); + ussl_raise_error(res); } } diff --git a/extmod/modussl_mbedtls.c b/extmod/modussl_mbedtls.c index 9e117c82cc..94061ddc8b 100644 --- a/extmod/modussl_mbedtls.c +++ b/extmod/modussl_mbedtls.c @@ -34,6 +34,7 @@ #include "py/runtime.h" #include "py/stream.h" +#include "py/objstr.h" // mbedtls_time_t #include "mbedtls/platform.h" @@ -43,6 +44,7 @@ #include "mbedtls/entropy.h" #include "mbedtls/ctr_drbg.h" #include "mbedtls/debug.h" +#include "mbedtls/error.h" typedef struct _mp_obj_ssl_socket_t { mp_obj_base_t base; @@ -74,6 +76,42 @@ STATIC void mbedtls_debug(void *ctx, int level, const char *file, int line, cons } #endif +STATIC NORETURN void mbedtls_raise_error(int err) { + #if defined(MBEDTLS_ERROR_C) + // Including mbedtls_strerror takes about 16KB on the esp32 due to all the strings. + // MBEDTLS_ERROR_C is the define used by mbedtls to conditionally include mbedtls_strerror. + // It is set/unset in the MBEDTLS_CONFIG_FILE which is defined in the Makefile. + // "small" negative integer error codes come from underlying stream/sockets, not mbedtls + if (err < 0 && err > -256) { + mp_raise_OSError(-err); + } + + // Try to allocate memory for the message + #define ERR_STR_MAX 100 // mbedtls_strerror truncates if it doesn't fit + mp_obj_str_t *o_str = m_new_obj_maybe(mp_obj_str_t); + byte *o_str_buf = m_new_maybe(byte, ERR_STR_MAX); + if (o_str == NULL || o_str_buf == NULL) { + mp_raise_OSError(err); + } + + // print the error message into the allocated buffer + mbedtls_strerror(err, (char *)o_str_buf, ERR_STR_MAX); + size_t len = strnlen((char *)o_str_buf, ERR_STR_MAX); + + // Put the exception object together + o_str->base.type = &mp_type_str; + o_str->data = o_str_buf; + o_str->len = len; + o_str->hash = qstr_compute_hash(o_str->data, o_str->len); + // raise + mp_obj_t args[2] = { MP_OBJ_NEW_SMALL_INT(err), MP_OBJ_FROM_PTR(o_str)}; + nlr_raise(mp_obj_exception_make_new(&mp_type_OSError, 2, 0, args)); + #else + // mbedtls is compiled without error strings so we simply return the err number + mp_raise_OSError(err); // typ. err is negative + #endif +} + STATIC int _mbedtls_ssl_send(void *ctx, const byte *buf, size_t len) { mp_obj_t sock = *(mp_obj_t *)ctx; @@ -85,7 +123,7 @@ STATIC int _mbedtls_ssl_send(void *ctx, const byte *buf, size_t len) { if (mp_is_nonblocking_error(err)) { return MBEDTLS_ERR_SSL_WANT_WRITE; } - return -err; + return -err; // convert an MP_ERRNO to something mbedtls passes through as error } else { return out_sz; } @@ -197,7 +235,6 @@ STATIC mp_obj_ssl_socket_t *socket_new(mp_obj_t sock, struct ssl_args *args) { if (args->do_handshake.u_bool) { while ((ret = mbedtls_ssl_handshake(&o->ssl)) != 0) { if (ret != MBEDTLS_ERR_SSL_WANT_READ && ret != MBEDTLS_ERR_SSL_WANT_WRITE) { - printf("mbedtls_ssl_handshake error: -%x\n", -ret); goto cleanup; } } @@ -221,7 +258,7 @@ cleanup: } else if (ret == MBEDTLS_ERR_X509_BAD_INPUT_DATA) { mp_raise_ValueError(MP_ERROR_TEXT("invalid cert")); } else { - mp_raise_OSError(MP_EIO); + mbedtls_raise_error(ret); } } diff --git a/tests/extmod/ussl_basic.py b/tests/extmod/ussl_basic.py index b4e21c7dce..9e1821dca9 100644 --- a/tests/extmod/ussl_basic.py +++ b/tests/extmod/ussl_basic.py @@ -9,7 +9,7 @@ except ImportError: # create in client mode try: - ss = ssl.wrap_socket(io.BytesIO()) + ss = ssl.wrap_socket(io.BytesIO(), server_hostname="test.example.com") except OSError as er: print("wrap_socket:", repr(er)) diff --git a/tests/extmod/ussl_basic.py.exp b/tests/extmod/ussl_basic.py.exp index 5282338319..eb7df855aa 100644 --- a/tests/extmod/ussl_basic.py.exp +++ b/tests/extmod/ussl_basic.py.exp @@ -1,5 +1,4 @@ -ssl_handshake_status: -256 -wrap_socket: OSError(5,) +wrap_socket: OSError(-256, 'CONN_LOST') <_SSLSocket 4 b'' diff --git a/tests/net_inet/tls_num_errors.py b/tests/net_inet/tls_num_errors.py new file mode 100644 index 0000000000..dd7f714e6e --- /dev/null +++ b/tests/net_inet/tls_num_errors.py @@ -0,0 +1,44 @@ +# test that modtls produces a numerical error message when out of heap + +try: + import usocket as socket, ussl as ssl, sys +except: + import socket, ssl, sys +try: + from micropython import alloc_emergency_exception_buf, heap_lock, heap_unlock +except: + print("SKIP") + raise SystemExit + + +# test with heap locked to see it switch to number-only error message +def test(addr): + alloc_emergency_exception_buf(256) + s = socket.socket() + s.connect(addr) + try: + s.setblocking(False) + s = ssl.wrap_socket(s, do_handshake=False) + heap_lock() + print("heap is locked") + while True: + ret = s.write("foo") + if ret: + break + heap_unlock() + print("wrap: no exception") + except OSError as e: + heap_unlock() + # mbedtls produces "-29184" + # axtls produces "RECORD_OVERFLOW" + ok = "-29184" in str(e) or "RECORD_OVERFLOW" in str(e) + print("wrap:", ok) + if not ok: + print("got exception:", e) + s.close() + + +if __name__ == "__main__": + # connect to plain HTTP port, oops! + addr = socket.getaddrinfo("micropython.org", 80)[0][-1] + test(addr) diff --git a/tests/net_inet/tls_num_errors.py.exp b/tests/net_inet/tls_num_errors.py.exp new file mode 100644 index 0000000000..e6a15634d0 --- /dev/null +++ b/tests/net_inet/tls_num_errors.py.exp @@ -0,0 +1,2 @@ +heap is locked +wrap: True diff --git a/tests/net_inet/tls_text_errors.py b/tests/net_inet/tls_text_errors.py new file mode 100644 index 0000000000..2ba167b868 --- /dev/null +++ b/tests/net_inet/tls_text_errors.py @@ -0,0 +1,33 @@ +# test that modtls produces a text error message + +try: + import usocket as socket, ussl as ssl, sys +except: + import socket, ssl, sys + + +def test(addr): + s = socket.socket() + s.connect(addr) + try: + s = ssl.wrap_socket(s) + print("wrap: no exception") + except OSError as e: + # mbedtls produces "mbedtls -0x7200: SSL - An invalid SSL record was received" + # axtls produces "RECORD_OVERFLOW" + # CPython produces "[SSL: WRONG_VERSION_NUMBER] wrong version number (_ssl.c:1108)" + ok = ( + "invalid SSL record" in str(e) + or "RECORD_OVERFLOW" in str(e) + or "wrong version" in str(e) + ) + print("wrap:", ok) + if not ok: + print("got exception:", e) + s.close() + + +if __name__ == "__main__": + # connect to plain HTTP port, oops! + addr = socket.getaddrinfo("micropython.org", 80)[0][-1] + test(addr)