diff --git a/components/tcp_transport/test/test_transport.c b/components/tcp_transport/test/test_transport.c index dc82672302..0395c0b6be 100644 --- a/components/tcp_transport/test/test_transport.c +++ b/components/tcp_transport/test/test_transport.c @@ -190,7 +190,7 @@ static void transport_connection_timeout_test(esp_transport_handle_t transport_u EventBits_t bits = xEventGroupWaitBits(params.tcp_connect_done, TCP_CONNECT_DONE, true, true, max_wait); TickType_t end = xTaskGetTickCount(); - TEST_ASSERT_EQUAL(TCP_CONNECT_DONE, TCP_CONNECT_DONE&bits); // Connection has finished + TEST_ASSERT_EQUAL(TCP_CONNECT_DONE, TCP_CONNECT_DONE & bits); // Connection has finished TEST_ASSERT_EQUAL(-1, params.ret); // Connection failed with -1 // Test connection attempt took expected timeout value @@ -316,24 +316,21 @@ static void socket_operation_test(esp_transport_handle_t transport_under_test, test_utils_task_delete(tcp_connect_task_handle); } -static void tcp_transport_keepalive_test(esp_transport_handle_t transport_under_test) +static void tcp_transport_keepalive_test(esp_transport_handle_t transport_under_test, esp_transport_keep_alive_t *config) { - static const int KEEP_ALIVE_INTERVAL = 1; - static const int KEEP_ALIVE_IDLE = 2; - static const int KEEP_ALIVE_COUNT = 3; - - static const struct expected_sock_option expected_opts[] = { + static struct expected_sock_option expected_opts[4] = { { .level = SOL_SOCKET, .optname = SO_KEEPALIVE, .optval = 1, .opttype = SOCK_OPT_TYPE_BOOL }, - { .level = IPPROTO_TCP, .optname = TCP_KEEPIDLE, .optval = KEEP_ALIVE_IDLE, .opttype = SOCK_OPT_TYPE_INT }, - { .level = IPPROTO_TCP, .optname = TCP_KEEPINTVL, .optval = KEEP_ALIVE_INTERVAL, .opttype = SOCK_OPT_TYPE_INT }, - { .level = IPPROTO_TCP, .optname = TCP_KEEPCNT, .optval = KEEP_ALIVE_COUNT, .opttype = SOCK_OPT_TYPE_INT }, + { .level = IPPROTO_TCP }, + { .level = IPPROTO_TCP }, + { .level = IPPROTO_TCP } }; - esp_transport_keep_alive_t keep_alive_cfg = { .keep_alive_interval = KEEP_ALIVE_INTERVAL, - .keep_alive_idle = KEEP_ALIVE_IDLE, - .keep_alive_enable = true, - .keep_alive_count = KEEP_ALIVE_COUNT }; - esp_transport_tcp_set_keep_alive(transport_under_test, &keep_alive_cfg); + expected_opts[1].optname = TCP_KEEPIDLE; + expected_opts[1].optval = config->keep_alive_idle; + expected_opts[2].optname = TCP_KEEPINTVL; + expected_opts[2].optval = config->keep_alive_interval; + expected_opts[3].optname = TCP_KEEPCNT; + expected_opts[3].optval = config->keep_alive_count; socket_operation_test(transport_under_test, expected_opts, sizeof(expected_opts)/sizeof(struct expected_sock_option)); } @@ -346,7 +343,14 @@ TEST_CASE("tcp_transport: Keep alive test", "[tcp_transport]") esp_transport_list_add(transport_list, tcp, "tcp"); // Perform the test - tcp_transport_keepalive_test(tcp); + esp_transport_keep_alive_t keep_alive_cfg = { + .keep_alive_interval = 5, + .keep_alive_idle = 4, + .keep_alive_enable = true, + .keep_alive_count = 3 }; + esp_transport_tcp_set_keep_alive(tcp, &keep_alive_cfg); + + tcp_transport_keepalive_test(tcp, &keep_alive_cfg); // Cleanup esp_transport_close(tcp); @@ -363,7 +367,14 @@ TEST_CASE("ssl_transport: Keep alive test", "[tcp_transport]") esp_transport_ssl_enable_global_ca_store(ssl); // Perform the test - tcp_transport_keepalive_test(ssl); + esp_transport_keep_alive_t keep_alive_cfg = { + .keep_alive_interval = 2, + .keep_alive_idle = 3, + .keep_alive_enable = true, + .keep_alive_count = 4 }; + esp_transport_ssl_set_keep_alive(ssl, &keep_alive_cfg); + + tcp_transport_keepalive_test(ssl, &keep_alive_cfg); // Cleanup esp_transport_close(ssl); @@ -382,9 +393,40 @@ TEST_CASE("ws_transport: Keep alive test", "[tcp_transport]") esp_transport_list_add(transport_list, ws, "wss"); // Perform the test - tcp_transport_keepalive_test(ws); + esp_transport_keep_alive_t keep_alive_cfg = { + .keep_alive_interval = 1, + .keep_alive_idle = 2, + .keep_alive_enable = true, + .keep_alive_count = 3 }; + esp_transport_tcp_set_keep_alive(ssl, &keep_alive_cfg); + + tcp_transport_keepalive_test(ws, &keep_alive_cfg); // Cleanup esp_transport_close(ssl); esp_transport_list_destroy(transport_list); } + +// Note: This functionality is tested and kept only for compatibility reasons with IDF <= 4.x +// It is strongly encouraged to use transport within lists only +TEST_CASE("ssl_transport: Check that parameters (keepalive) are set independently on the list", "[tcp_transport]") +{ + // Init the transport under test + esp_transport_handle_t ssl = esp_transport_ssl_init(); + esp_tls_init_global_ca_store(); + esp_transport_ssl_enable_global_ca_store(ssl); + + // Perform the test + esp_transport_keep_alive_t keep_alive_cfg = { + .keep_alive_interval = 2, + .keep_alive_idle = 4, + .keep_alive_enable = true, + .keep_alive_count = 3 }; + esp_transport_ssl_set_keep_alive(ssl, &keep_alive_cfg); + + tcp_transport_keepalive_test(ssl, &keep_alive_cfg); + + // Cleanup + esp_transport_close(ssl); + esp_transport_destroy(ssl); +} diff --git a/components/tcp_transport/transport_ssl.c b/components/tcp_transport/transport_ssl.c index f503aa74d4..fd2795a4bd 100644 --- a/components/tcp_transport/transport_ssl.c +++ b/components/tcp_transport/transport_ssl.c @@ -23,6 +23,10 @@ #include "esp_transport_utils.h" #include "esp_transport_internal.h" +#define GET_SSL_FROM_TRANSPORT_OR_RETURN(ssl, t) \ + transport_esp_tls_t *ssl = ssl_get_context_data(t); \ + if (!ssl) { return; } + static const char *TAG = "TRANSPORT_BASE"; typedef enum { @@ -40,11 +44,30 @@ typedef struct transport_esp_tls { transport_ssl_conn_state_t conn_state; } transport_esp_tls_t; +static inline struct transport_esp_tls * ssl_get_context_data(esp_transport_handle_t t) +{ + if (!t) { + return NULL; + } + if (t->data) { // Prefer internal ssl context (independent from the list) + return (transport_esp_tls_t*)t->data; + } + if (t->base && t->base->transport_esp_tls) { // Next one is the lists inherent context + t->data = t->base->transport_esp_tls; // Optimize: if we have base context, use it as internal + return t->base->transport_esp_tls; + } + // If we don't have a valid context, let's to create one + transport_esp_tls_t *ssl = esp_transport_esp_tls_create(); + ESP_TRANSPORT_MEM_CHECK(TAG, ssl, return NULL) + t->data = ssl; + return ssl; +} + static int ssl_close(esp_transport_handle_t t); static int esp_tls_connect_async(esp_transport_handle_t t, const char *host, int port, int timeout_ms, bool is_plain_tcp) { - transport_esp_tls_t *ssl = t->base->transport_esp_tls; + transport_esp_tls_t *ssl = ssl_get_context_data(t); if (ssl->conn_state == TRANS_SSL_INIT) { ssl->cfg.timeout_ms = timeout_ms; ssl->cfg.is_plain_tcp = is_plain_tcp; @@ -74,7 +97,7 @@ static inline int tcp_connect_async(esp_transport_handle_t t, const char *host, static int esp_tls_connect(esp_transport_handle_t t, const char *host, int port, int timeout_ms, bool is_plain_tcp) { - transport_esp_tls_t *ssl = t->base->transport_esp_tls; + transport_esp_tls_t *ssl = ssl_get_context_data(t); ssl->cfg.timeout_ms = timeout_ms; ssl->cfg.is_plain_tcp = is_plain_tcp; @@ -103,7 +126,7 @@ static inline int tcp_connect(esp_transport_handle_t t, const char *host, int po static int ssl_poll_read(esp_transport_handle_t t, int timeout_ms) { - transport_esp_tls_t *ssl = t->base->transport_esp_tls; + transport_esp_tls_t *ssl = ssl_get_context_data(t); int ret = -1; int remain = 0; struct timeval timeout; @@ -132,7 +155,7 @@ static int ssl_poll_read(esp_transport_handle_t t, int timeout_ms) static int ssl_poll_write(esp_transport_handle_t t, int timeout_ms) { - transport_esp_tls_t *ssl = t->base->transport_esp_tls; + transport_esp_tls_t *ssl = ssl_get_context_data(t); int ret = -1; struct timeval timeout; fd_set writeset; @@ -156,7 +179,7 @@ static int ssl_poll_write(esp_transport_handle_t t, int timeout_ms) static int ssl_write(esp_transport_handle_t t, const char *buffer, int len, int timeout_ms) { int poll, ret; - transport_esp_tls_t *ssl = t->base->transport_esp_tls; + transport_esp_tls_t *ssl = ssl_get_context_data(t); if ((poll = esp_transport_poll_write(t, timeout_ms)) <= 0) { ESP_LOGW(TAG, "Poll timeout or error, errno=%s, fd=%d, timeout_ms=%d", strerror(errno), ssl->tls->sockfd, timeout_ms); @@ -173,7 +196,7 @@ static int ssl_write(esp_transport_handle_t t, const char *buffer, int len, int static int ssl_read(esp_transport_handle_t t, char *buffer, int len, int timeout_ms) { int poll, ret; - transport_esp_tls_t *ssl = t->base->transport_esp_tls; + transport_esp_tls_t *ssl = ssl_get_context_data(t); if ((poll = esp_transport_poll_read(t, timeout_ms)) <= 0) { return poll; @@ -196,8 +219,8 @@ static int ssl_read(esp_transport_handle_t t, char *buffer, int len, int timeout static int ssl_close(esp_transport_handle_t t) { int ret = -1; - if (t && t->base && t->base->transport_esp_tls && t->base->transport_esp_tls->ssl_initialized) { - transport_esp_tls_t *ssl = t->base->transport_esp_tls; + transport_esp_tls_t *ssl = ssl_get_context_data(t); + if (ssl && ssl->ssl_initialized) { ret = esp_tls_conn_destroy(ssl->tls); ssl->conn_state = TRANS_SSL_INIT; ssl->ssl_initialized = false; @@ -207,127 +230,124 @@ static int ssl_close(esp_transport_handle_t t) static int ssl_destroy(esp_transport_handle_t t) { - esp_transport_close(t); + transport_esp_tls_t *ssl = ssl_get_context_data(t); + if (ssl) { + esp_transport_close(t); + if (t->base && t->base->transport_esp_tls && + t->data == t->base->transport_esp_tls) { + // if internal ssl the same as the foundation transport, + // just zero out, it will be freed on list destroy + t->data = NULL; + } + esp_transport_esp_tls_destroy(t->data); // okay to pass NULL + } return 0; } + void esp_transport_ssl_enable_global_ca_store(esp_transport_handle_t t) { - if (t && t->base && t->base->transport_esp_tls) { - t->base->transport_esp_tls->cfg.use_global_ca_store = true; - } + GET_SSL_FROM_TRANSPORT_OR_RETURN(ssl, t); + ssl->cfg.use_global_ca_store = true; } void esp_transport_ssl_set_psk_key_hint(esp_transport_handle_t t, const psk_hint_key_t* psk_hint_key) { - if (t && t->base && t->base->transport_esp_tls) { - t->base->transport_esp_tls->cfg.psk_hint_key = psk_hint_key; - } + GET_SSL_FROM_TRANSPORT_OR_RETURN(ssl, t); + ssl->cfg.psk_hint_key = psk_hint_key; } void esp_transport_ssl_set_cert_data(esp_transport_handle_t t, const char *data, int len) { - if (t && t->base && t->base->transport_esp_tls) { - t->base->transport_esp_tls->cfg.cacert_pem_buf = (void *)data; - t->base->transport_esp_tls->cfg.cacert_pem_bytes = len + 1; - } + GET_SSL_FROM_TRANSPORT_OR_RETURN(ssl, t); + ssl->cfg.cacert_pem_buf = (void *)data; + ssl->cfg.cacert_pem_bytes = len + 1; } void esp_transport_ssl_set_cert_data_der(esp_transport_handle_t t, const char *data, int len) { - if (t && t->base && t->base->transport_esp_tls) { - t->base->transport_esp_tls->cfg.cacert_buf = (void *)data; - t->base->transport_esp_tls->cfg.cacert_bytes = len; - } + GET_SSL_FROM_TRANSPORT_OR_RETURN(ssl, t); + ssl->cfg.cacert_buf = (void *)data; + ssl->cfg.cacert_bytes = len; } void esp_transport_ssl_set_client_cert_data(esp_transport_handle_t t, const char *data, int len) { - if (t && t->base && t->base->transport_esp_tls) { - t->base->transport_esp_tls->cfg.clientcert_pem_buf = (void *)data; - t->base->transport_esp_tls->cfg.clientcert_pem_bytes = len + 1; - } + GET_SSL_FROM_TRANSPORT_OR_RETURN(ssl, t); + ssl->cfg.clientcert_pem_buf = (void *)data; + ssl->cfg.clientcert_pem_bytes = len + 1; } void esp_transport_ssl_set_client_cert_data_der(esp_transport_handle_t t, const char *data, int len) { - if (t && t->base && t->base->transport_esp_tls) { - t->base->transport_esp_tls->cfg.clientcert_buf = (void *)data; - t->base->transport_esp_tls->cfg.clientcert_bytes = len; - } + GET_SSL_FROM_TRANSPORT_OR_RETURN(ssl, t); + ssl->cfg.clientcert_buf = (void *)data; + ssl->cfg.clientcert_bytes = len; } void esp_transport_ssl_set_client_key_data(esp_transport_handle_t t, const char *data, int len) { - if (t && t->base && t->base->transport_esp_tls) { - t->base->transport_esp_tls->cfg.clientkey_pem_buf = (void *)data; - t->base->transport_esp_tls->cfg.clientkey_pem_bytes = len + 1; - } + GET_SSL_FROM_TRANSPORT_OR_RETURN(ssl, t); + ssl->cfg.clientkey_pem_buf = (void *)data; + ssl->cfg.clientkey_pem_bytes = len + 1; } void esp_transport_ssl_set_client_key_password(esp_transport_handle_t t, const char *password, int password_len) { - if (t && t->base && t->base->transport_esp_tls) { - t->base->transport_esp_tls->cfg.clientkey_password = (void *)password; - t->base->transport_esp_tls->cfg.clientkey_password_len = password_len; - } + GET_SSL_FROM_TRANSPORT_OR_RETURN(ssl, t); + ssl->cfg.clientkey_password = (void *)password; + ssl->cfg.clientkey_password_len = password_len; } void esp_transport_ssl_set_client_key_data_der(esp_transport_handle_t t, const char *data, int len) { - if (t && t->base && t->base->transport_esp_tls) { - t->base->transport_esp_tls->cfg.clientkey_buf = (void *)data; - t->base->transport_esp_tls->cfg.clientkey_bytes = len; - } + GET_SSL_FROM_TRANSPORT_OR_RETURN(ssl, t); + ssl->cfg.clientkey_buf = (void *)data; + ssl->cfg.clientkey_bytes = len; } void esp_transport_ssl_set_alpn_protocol(esp_transport_handle_t t, const char **alpn_protos) { - if (t && t->base && t->base->transport_esp_tls) { - t->base->transport_esp_tls->cfg.alpn_protos = alpn_protos; - } + GET_SSL_FROM_TRANSPORT_OR_RETURN(ssl, t); + ssl->cfg.alpn_protos = alpn_protos; } void esp_transport_ssl_skip_common_name_check(esp_transport_handle_t t) { - if (t && t->base && t->base->transport_esp_tls) { - t->base->transport_esp_tls->cfg.skip_common_name = true; - } + GET_SSL_FROM_TRANSPORT_OR_RETURN(ssl, t); + ssl->cfg.skip_common_name = true; } void esp_transport_ssl_use_secure_element(esp_transport_handle_t t) { - if (t && t->base && t->base->transport_esp_tls) { - t->base->transport_esp_tls->cfg.use_secure_element = true; - } + GET_SSL_FROM_TRANSPORT_OR_RETURN(ssl, t); + ssl->cfg.use_secure_element = true; } static int ssl_get_socket(esp_transport_handle_t t) { - if (t && t->base && t->base->transport_esp_tls && t->base->transport_esp_tls->tls) { - return t->base->transport_esp_tls->tls->sockfd; + transport_esp_tls_t *ssl = ssl_get_context_data(t); + if (ssl && ssl->tls) { + return ssl->tls->sockfd; } return -1; } void esp_transport_ssl_set_ds_data(esp_transport_handle_t t, void *ds_data) { - if (t && t->base && t->base->transport_esp_tls) { - t->base->transport_esp_tls->cfg.ds_data = ds_data; - } + GET_SSL_FROM_TRANSPORT_OR_RETURN(ssl, t); + ssl->cfg.ds_data = ds_data; } void esp_transport_ssl_set_keep_alive(esp_transport_handle_t t, esp_transport_keep_alive_t *keep_alive_cfg) { - if (t && t->base && t->base->transport_esp_tls) { - t->base->transport_esp_tls->cfg.keep_alive_cfg = (tls_keep_alive_cfg_t *) keep_alive_cfg; - } + GET_SSL_FROM_TRANSPORT_OR_RETURN(ssl, t); + ssl->cfg.keep_alive_cfg = (tls_keep_alive_cfg_t *) keep_alive_cfg; } esp_transport_handle_t esp_transport_ssl_init(void) { esp_transport_handle_t t = esp_transport_init(); - esp_transport_set_context_data(t, NULL); esp_transport_set_func(t, ssl_connect, ssl_read, ssl_write, ssl_close, ssl_poll_read, ssl_poll_write, ssl_destroy); esp_transport_set_async_connect_func(t, ssl_connect_async); t->_get_socket = ssl_get_socket; @@ -348,7 +368,6 @@ void esp_transport_esp_tls_destroy(struct transport_esp_tls* transport_esp_tls) esp_transport_handle_t esp_transport_tcp_init(void) { esp_transport_handle_t t = esp_transport_init(); - esp_transport_set_context_data(t, NULL); esp_transport_set_func(t, tcp_connect, ssl_read, ssl_write, ssl_close, ssl_poll_read, ssl_poll_write, ssl_destroy); esp_transport_set_async_connect_func(t, tcp_connect_async); t->_get_socket = ssl_get_socket;