From 1c9592efc45739f436312e38312372545978e866 Mon Sep 17 00:00:00 2001
From: Dong Heng <dongheng@espressif.com>
Date: Thu, 12 Nov 2020 15:17:21 +0800
Subject: [PATCH] fix(mbedtls): fix mbedtls dynamic resource memory leaks and
 RSA cert drop earlier

RX process caches the session information in "ssl->in_ctr" not in "ssl->in_buf".
So when freeing the SSL, can't free the "ssl->in_ctr", because the "ssl->in_buf"
is empty.

Make the RX process like TX process, and cache the session information in
"ssl->in_buf", so that the cache buffer can be freed when freeing the SSL.

Closes https://github.com/espressif/esp-idf/issues/6104
---
 .../port/dynamic/esp_mbedtls_dynamic_impl.c   | 41 ++++++++++++++-----
 .../port/dynamic/esp_mbedtls_dynamic_impl.h   |  5 +--
 components/mbedtls/port/dynamic/esp_ssl_cli.c | 18 +++++++-
 components/mbedtls/port/dynamic/esp_ssl_tls.c | 11 ++++-
 4 files changed, 59 insertions(+), 16 deletions(-)

diff --git a/components/mbedtls/port/dynamic/esp_mbedtls_dynamic_impl.c b/components/mbedtls/port/dynamic/esp_mbedtls_dynamic_impl.c
index 602dfe9496..36896d386e 100644
--- a/components/mbedtls/port/dynamic/esp_mbedtls_dynamic_impl.c
+++ b/components/mbedtls/port/dynamic/esp_mbedtls_dynamic_impl.c
@@ -317,9 +317,13 @@ int esp_mbedtls_add_rx_buffer(mbedtls_ssl_context *ssl)
     ESP_LOGV(TAG, "--> add rx");
 
     if (ssl->in_buf) {
-        ESP_LOGV(TAG, "in buffer is not empty");
-        ret = 0;
-        goto exit;
+        if (ssl->in_iv) {
+            ESP_LOGV(TAG, "in buffer is not empty");
+            ret = 0;
+            goto exit;
+        } else {
+            cached = 1;
+        }
     }
 
     ssl->in_hdr = msg_head;
@@ -346,6 +350,12 @@ int esp_mbedtls_add_rx_buffer(mbedtls_ssl_context *ssl)
     ESP_LOGV(TAG, "message length is %d RX buffer length should be %d left is %d",
                 (int)in_msglen, (int)buffer_len, (int)ssl->in_left);
 
+    if (cached) {
+        memcpy(cache_buf, ssl->in_buf, 16);
+        mbedtls_free(ssl->in_buf);
+        init_rx_buffer(ssl, NULL);
+    }
+
     buf = mbedtls_calloc(1, buffer_len);
     if (!buf) {
         ESP_LOGE(TAG, "alloc(%d bytes) failed", buffer_len);
@@ -355,12 +365,6 @@ int esp_mbedtls_add_rx_buffer(mbedtls_ssl_context *ssl)
 
     ESP_LOGV(TAG, "add in buffer %d bytes @ %p", buffer_len, buf);
 
-    if (ssl->in_ctr) {
-        memcpy(cache_buf, ssl->in_ctr, 16);
-        mbedtls_free(ssl->in_ctr);
-        cached = 1;
-    }
-
     init_rx_buffer(ssl, buf);
 
     if (cached) {
@@ -389,7 +393,8 @@ int esp_mbedtls_free_rx_buffer(mbedtls_ssl_context *ssl)
     /**
      * When have read multi messages once, can't free the input buffer directly.
      */
-    if (!ssl->in_buf || (ssl->in_hslen && (ssl->in_hslen < ssl->in_msglen))) {
+    if (!ssl->in_buf || (ssl->in_hslen && (ssl->in_hslen < ssl->in_msglen)) ||
+        (ssl->in_buf && !ssl->in_iv)) {
         ret = 0;
         goto exit;
     }
@@ -418,7 +423,8 @@ int esp_mbedtls_free_rx_buffer(mbedtls_ssl_context *ssl)
     }
 
     memcpy(pdata, buf, 16);
-    ssl->in_ctr = pdata;
+    init_rx_buffer(ssl, pdata);
+    ssl->in_iv = NULL;
 
 exit:
     ESP_LOGV(TAG, "<-- free rx");
@@ -515,4 +521,17 @@ void esp_mbedtls_free_peer_cert(mbedtls_ssl_context *ssl)
         ssl->session_negotiate->peer_cert = NULL;
     }
 }
+
+bool esp_mbedtls_ssl_is_rsa(mbedtls_ssl_context *ssl)
+{
+    const mbedtls_ssl_ciphersuite_t *ciphersuite_info =
+        ssl->transform_negotiate->ciphersuite_info;
+
+    if (ciphersuite_info->key_exchange == MBEDTLS_KEY_EXCHANGE_RSA ||
+        ciphersuite_info->key_exchange == MBEDTLS_KEY_EXCHANGE_RSA_PSK) {
+        return true;
+    } else {
+        return false;
+    }
+}
 #endif
diff --git a/components/mbedtls/port/dynamic/esp_mbedtls_dynamic_impl.h b/components/mbedtls/port/dynamic/esp_mbedtls_dynamic_impl.h
index d831d07364..8f4bb144cc 100644
--- a/components/mbedtls/port/dynamic/esp_mbedtls_dynamic_impl.h
+++ b/components/mbedtls/port/dynamic/esp_mbedtls_dynamic_impl.h
@@ -33,9 +33,6 @@
  \
     if ((_ret = _fn) != 0) { \
         ESP_LOGV(TAG, "\"%s\" result is -0x%x", # _fn, -_ret); \
-        if (_ret == MBEDTLS_ERR_SSL_CONN_EOF) {\
-            return 0; \
-        } \
         TRACE_CHECK(_fn, "fail"); \
         return _ret; \
     } \
@@ -80,6 +77,8 @@ void esp_mbedtls_free_cacert(mbedtls_ssl_context *ssl);
 
 #ifdef CONFIG_MBEDTLS_DYNAMIC_FREE_PEER_CERT
 void esp_mbedtls_free_peer_cert(mbedtls_ssl_context *ssl);
+
+bool esp_mbedtls_ssl_is_rsa(mbedtls_ssl_context *ssl);
 #endif
 
 #endif /* _DYNAMIC_IMPL_H_ */
diff --git a/components/mbedtls/port/dynamic/esp_ssl_cli.c b/components/mbedtls/port/dynamic/esp_ssl_cli.c
index 12b33f3ddc..0a0997adcc 100644
--- a/components/mbedtls/port/dynamic/esp_ssl_cli.c
+++ b/components/mbedtls/port/dynamic/esp_ssl_cli.c
@@ -73,7 +73,17 @@ static int manage_resource(mbedtls_ssl_context *ssl, bool add)
                     CHECK_OK(esp_mbedtls_free_rx_buffer(ssl));
                 }
 #ifdef CONFIG_MBEDTLS_DYNAMIC_FREE_PEER_CERT
-                esp_mbedtls_free_peer_cert(ssl);
+                /**
+                 * If current ciphersuite is RSA, we should free peer'
+                 * certificate at step  MBEDTLS_SSL_CLIENT_KEY_EXCHANGE.
+                 *
+                 * And if it is other kinds of ciphersuite, we can free
+                 * peer certificate here.
+                 */
+
+                if (esp_mbedtls_ssl_is_rsa(ssl) == false) {
+                    esp_mbedtls_free_peer_cert(ssl);
+                }
 #endif
             }
             break;
@@ -123,6 +133,12 @@ static int manage_resource(mbedtls_ssl_context *ssl, bool add)
                 size_t buffer_len = MBEDTLS_SSL_OUT_BUFFER_LEN;
 
                 CHECK_OK(esp_mbedtls_add_tx_buffer(ssl, buffer_len));
+            } else {
+#ifdef CONFIG_MBEDTLS_DYNAMIC_FREE_PEER_CERT
+                if (esp_mbedtls_ssl_is_rsa(ssl) == true) {
+                    esp_mbedtls_free_peer_cert(ssl);
+                }
+#endif
             }
             break;
         case MBEDTLS_SSL_CERTIFICATE_VERIFY:
diff --git a/components/mbedtls/port/dynamic/esp_ssl_tls.c b/components/mbedtls/port/dynamic/esp_ssl_tls.c
index 384586cfd8..25b2f3d1f0 100644
--- a/components/mbedtls/port/dynamic/esp_ssl_tls.c
+++ b/components/mbedtls/port/dynamic/esp_ssl_tls.c
@@ -85,7 +85,16 @@ int __wrap_mbedtls_ssl_read(mbedtls_ssl_context *ssl, unsigned char *buf, size_t
 {
     int ret;
 
-    CHECK_OK(esp_mbedtls_add_rx_buffer(ssl));
+    ESP_LOGD(TAG, "add mbedtls RX buffer");
+    ret = esp_mbedtls_add_rx_buffer(ssl);
+    if (ret == MBEDTLS_ERR_SSL_CONN_EOF) {
+        ESP_LOGD(TAG, "fail, the connection indicated an EOF");
+        return 0;
+    } else if (ret < 0) {
+        ESP_LOGD(TAG, "fail, error=-0x%x", -ret);
+        return ret;
+    }
+    ESP_LOGD(TAG, "end");
 
     ret = __real_mbedtls_ssl_read(ssl, buf, len);