From 6b318fe583ddd1a396608d687219a39f30484371 Mon Sep 17 00:00:00 2001 From: David Cermak Date: Mon, 25 Jan 2021 09:45:38 +0100 Subject: [PATCH] esp_tls: Updated connection method to use non-blocking connect For better control over connection timeouts To be in line with former tcp_transport, as esp-tls is not used for plain tcp transports --- components/esp-tls/esp_tls.c | 118 +++++++++--- components/esp-tls/esp_tls.h | 3 +- components/esp-tls/esp_tls_errors.h | 2 +- .../private_include/esp_transport_internal.h | 34 +++- .../esp_transport_ssl_internal.h | 30 --- components/tcp_transport/transport.c | 58 ++++-- components/tcp_transport/transport_ssl.c | 175 ++++++++---------- 7 files changed, 244 insertions(+), 176 deletions(-) delete mode 100644 components/tcp_transport/private_include/esp_transport_ssl_internal.h diff --git a/components/esp-tls/esp_tls.c b/components/esp-tls/esp_tls.c index c2a6ad64f9..cc3af34c1a 100644 --- a/components/esp-tls/esp_tls.c +++ b/components/esp-tls/esp_tls.c @@ -146,8 +146,10 @@ static esp_err_t resolve_host_name(const char *host, size_t hostlen, struct addr } ESP_LOGD(TAG, "host:%s: strlen %lu", use_host, (unsigned long)hostlen); - if (getaddrinfo(use_host, NULL, &hints, address_info)) { - ESP_LOGE(TAG, "couldn't get hostname for :%s:", use_host); + int res = getaddrinfo(use_host, NULL, &hints, address_info); + if (res != 0 || *address_info == NULL) { + ESP_LOGE(TAG, "couldn't get hostname for :%s: " + "getaddrinfo() returns %d, addrinfo=%p", use_host, res, *address_info); free(use_host); return ESP_ERR_ESP_TLS_CANNOT_RESOLVE_HOSTNAME; } @@ -209,9 +211,11 @@ static esp_err_t esp_tcp_connect(const char *host, int hostlen, int port, int *s if (addrinfo->ai_family == AF_INET) { struct sockaddr_in *p = (struct sockaddr_in *)addrinfo->ai_addr; p->sin_port = htons(port); + ESP_LOGD(TAG, "[sock=%d] Resolved IPv4 address: %s", fd, ipaddr_ntoa((const ip_addr_t*)&p->sin_addr.s_addr)); addr_ptr = p; } else if (addrinfo->ai_family == AF_INET6) { struct sockaddr_in6 *p = (struct sockaddr_in6 *)addrinfo->ai_addr; + ESP_LOGD(TAG, "[sock=%d] Resolved IPv6 address: %s", fd, ip6addr_ntoa((const ip6_addr_t*)&p->sin6_addr)); p->sin6_port = htons(port); p->sin6_family = AF_INET6; addr_ptr = p; @@ -221,37 +225,95 @@ static esp_err_t esp_tcp_connect(const char *host, int hostlen, int port, int *s goto err_freesocket; } - if (cfg) { - if (cfg->timeout_ms >= 0) { - struct timeval tv; - ms_to_timeval(cfg->timeout_ms, &tv); - setsockopt(fd, SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv)); - setsockopt(fd, SOL_SOCKET, SO_SNDTIMEO, &tv, sizeof(tv)); - if (cfg->keep_alive_cfg && cfg->keep_alive_cfg->keep_alive_enable) { - if (esp_tls_tcp_enable_keep_alive(fd, cfg->keep_alive_cfg) < 0) { - ESP_LOGE(TAG, "Error setting keep-alive"); - goto err_freesocket; - } - } - } - if (cfg->non_block) { - int flags = fcntl(fd, F_GETFL, 0); - ret = fcntl(fd, F_SETFL, flags | O_NONBLOCK); - if (ret < 0) { - ESP_LOGE(TAG, "Failed to configure the socket as non-blocking (errno %d)", errno); + if (cfg && cfg->timeout_ms >= 0) { + struct timeval tv; + ms_to_timeval(cfg->timeout_ms, &tv); + setsockopt(fd, SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv)); + setsockopt(fd, SOL_SOCKET, SO_SNDTIMEO, &tv, sizeof(tv)); + if (cfg->keep_alive_cfg && cfg->keep_alive_cfg->keep_alive_enable) { + if (esp_tls_tcp_enable_keep_alive(fd, cfg->keep_alive_cfg) < 0) { + ESP_LOGE(TAG, "Error setting keep-alive"); goto err_freesocket; } } } - ret = connect(fd, addr_ptr, addrinfo->ai_addrlen); - if (ret < 0 && !(errno == EINPROGRESS && cfg && cfg->non_block)) { - - ESP_LOGE(TAG, "Failed to connnect to host (errno %d)", errno); - ESP_INT_EVENT_TRACKER_CAPTURE(tls->error_handle, ESP_TLS_ERR_TYPE_SYSTEM, errno); - ret = ESP_ERR_ESP_TLS_FAILED_CONNECT_TO_HOST; + // Set socket to non-blocking + int flags; + if ((flags = fcntl(fd, F_GETFL, NULL)) < 0) { + ESP_LOGE(TAG, "[sock=%d] get file flags error: %s", fd, strerror(errno)); goto err_freesocket; } + if (fcntl(fd, F_SETFL, flags |= O_NONBLOCK) < 0) { + ESP_LOGE(TAG, "[sock=%d] set nonblocking error: %s", fd, strerror(errno)); + goto err_freesocket; + } + + ESP_LOGD(TAG, "[sock=%d] Connecting to server. HOST: %s, Port: %d", fd, host, port); + + if (connect(fd, (struct sockaddr *)(addr_ptr), sizeof(struct sockaddr)) < 0) { + if (errno == EINPROGRESS) { + fd_set fdset; + struct timeval tv = { .tv_usec = 0, .tv_sec = 10 }; // Default connection timeout is 10 s + + if (cfg && cfg->non_block) { + // Non-blocking mode -> just return successfully at this stage + *sockfd = fd; + freeaddrinfo(addrinfo); + return ESP_OK; + } + + if ( cfg && cfg->timeout_ms > 0 ) { + ms_to_timeval(cfg->timeout_ms, &tv); + } + FD_ZERO(&fdset); + FD_SET(fd, &fdset); + + int res = select(fd+1, NULL, &fdset, NULL, &tv); + if (res < 0) { + ESP_LOGE(TAG, "[sock=%d] select() error: %s", fd, strerror(errno)); + ESP_INT_EVENT_TRACKER_CAPTURE(tls->error_handle, ESP_TLS_ERR_TYPE_SYSTEM, errno); + goto err_freesocket; + } + else if (res == 0) { + ESP_LOGE(TAG, "[sock=%d] select() timeout", fd); + ret = ESP_ERR_ESP_TLS_FAILED_CONNECT_TO_HOST; + goto err_freesocket; + } else { + int sockerr; + socklen_t len = (socklen_t)sizeof(int); + + if (getsockopt(fd, SOL_SOCKET, SO_ERROR, (void*)(&sockerr), &len) < 0) { + ESP_LOGE(TAG, "[sock=%d] getsockopt() error: %s", fd, strerror(errno)); + ret = ESP_ERR_ESP_TLS_SOCKET_SETOPT_FAILED; + goto err_freesocket; + } + else if (sockerr) { + ESP_INT_EVENT_TRACKER_CAPTURE(tls->error_handle, ESP_TLS_ERR_TYPE_SYSTEM, sockerr); + ESP_LOGE(TAG, "[sock=%d] delayed connect error: %s", fd, strerror(sockerr)); + ret = ESP_ERR_ESP_TLS_FAILED_CONNECT_TO_HOST; + goto err_freesocket; + } + } + } else { + ESP_LOGE(TAG, "[sock=%d] connect() error: %s", fd, strerror(errno)); + goto err_freesocket; + } + } + + if (cfg && cfg->non_block == false) { + // Reset socket to blocking (unless non-blocking option set) + if ((flags = fcntl(fd, F_GETFL, NULL)) < 0) { + ESP_LOGE(TAG, "[sock=%d] get file flags error: %s", fd, strerror(errno)); + ret = ESP_ERR_ESP_TLS_SOCKET_SETOPT_FAILED; + goto err_freesocket; + } + if (fcntl(fd, F_SETFL, flags & ~O_NONBLOCK) < 0) { + ESP_LOGE(TAG, "[sock=%d] reset blocking error: %s", fd, strerror(errno)); + ret = ESP_ERR_ESP_TLS_SOCKET_SETOPT_FAILED; + goto err_freesocket; + } + } *sockfd = fd; freeaddrinfo(addrinfo); @@ -292,7 +354,7 @@ static int esp_tls_low_level_conn(const char *hostname, int hostlen, int port, c ESP_LOGD(TAG, "non-tls connection established"); return 1; } - if (cfg->non_block) { + if (cfg && cfg->non_block) { FD_ZERO(&tls->rset); FD_SET(tls->sockfd, &tls->rset); tls->wset = tls->rset; @@ -300,7 +362,7 @@ static int esp_tls_low_level_conn(const char *hostname, int hostlen, int port, c tls->conn_state = ESP_TLS_CONNECTING; /* falls through */ case ESP_TLS_CONNECTING: - if (cfg->non_block) { + if (cfg && cfg->non_block) { ESP_LOGD(TAG, "connecting..."); struct timeval tv; ms_to_timeval(cfg->timeout_ms, &tv); diff --git a/components/esp-tls/esp_tls.h b/components/esp-tls/esp_tls.h index f3e36a22b5..981379435c 100644 --- a/components/esp-tls/esp_tls.h +++ b/components/esp-tls/esp_tls.h @@ -170,7 +170,8 @@ typedef struct esp_tls_cfg { bundle for server verification, must be enabled in menuconfig */ void *ds_data; /*!< Pointer for digital signature peripheral context */ - bool is_plain_tcp; + bool is_plain_tcp; /*!< Use non-TLS connection: When set to true, the esp-tls uses + plain TCP transport rather then TLS/SSL connection */ } esp_tls_cfg_t; #ifdef CONFIG_ESP_TLS_SERVER diff --git a/components/esp-tls/esp_tls_errors.h b/components/esp-tls/esp_tls_errors.h index 3b39b56209..a24fb7f7f1 100644 --- a/components/esp-tls/esp_tls_errors.h +++ b/components/esp-tls/esp_tls_errors.h @@ -26,7 +26,7 @@ extern "C" { #define ESP_ERR_ESP_TLS_CANNOT_CREATE_SOCKET (ESP_ERR_ESP_TLS_BASE + 0x02) /*!< Failed to create socket */ #define ESP_ERR_ESP_TLS_UNSUPPORTED_PROTOCOL_FAMILY (ESP_ERR_ESP_TLS_BASE + 0x03) /*!< Unsupported protocol family */ #define ESP_ERR_ESP_TLS_FAILED_CONNECT_TO_HOST (ESP_ERR_ESP_TLS_BASE + 0x04) /*!< Failed to connect to host */ -#define ESP_ERR_ESP_TLS_SOCKET_SETOPT_FAILED (ESP_ERR_ESP_TLS_BASE + 0x05) /*!< failed to set socket option */ +#define ESP_ERR_ESP_TLS_SOCKET_SETOPT_FAILED (ESP_ERR_ESP_TLS_BASE + 0x05) /*!< failed to set/get socket option */ #define ESP_ERR_MBEDTLS_CERT_PARTLY_OK (ESP_ERR_ESP_TLS_BASE + 0x06) /*!< mbedtls parse certificates was partly successful */ #define ESP_ERR_MBEDTLS_CTR_DRBG_SEED_FAILED (ESP_ERR_ESP_TLS_BASE + 0x07) /*!< mbedtls api returned error */ #define ESP_ERR_MBEDTLS_SSL_SET_HOSTNAME_FAILED (ESP_ERR_ESP_TLS_BASE + 0x08) /*!< mbedtls api returned error */ diff --git a/components/tcp_transport/private_include/esp_transport_internal.h b/components/tcp_transport/private_include/esp_transport_internal.h index 761efeef29..ea9c64b4a8 100644 --- a/components/tcp_transport/private_include/esp_transport_internal.h +++ b/components/tcp_transport/private_include/esp_transport_internal.h @@ -20,7 +20,10 @@ typedef int (*get_socket_func)(esp_transport_handle_t t); -struct transport_esp_tls; +typedef struct esp_foundation_transport { + struct esp_transport_error_storage *error_handle; /*!< Pointer to the transport error container */ + struct transport_esp_tls *transport_esp_tls; /*!< Pointer to the base transport which uses esp-tls */ +} esp_foundation_transport_t; /** * Transport layer structure, which will provide functions, basic properties for transport types @@ -39,10 +42,8 @@ struct esp_transport_item_t { connect_async_func _connect_async; /*!< non-blocking connect function of this transport */ payload_transfer_func _parent_transfer; /*!< Function returning underlying transport layer */ get_socket_func _get_socket; /*!< Function returning the transport's socket */ - struct esp_transport_error_s* error_handle; /*!< Error handle (based on esp-tls error handle) - * extended with transport's specific errors */ esp_transport_keep_alive_t *keep_alive_cfg; /*!< TCP keep-alive config */ - struct transport_esp_tls *foundation_transport; + struct esp_foundation_transport *base; /*!< Foundation transport pointer available from each transport */ STAILQ_ENTRY(esp_transport_item_t) next; }; @@ -89,6 +90,29 @@ int esp_transport_get_socket(esp_transport_handle_t t); */ void esp_transport_capture_errno(esp_transport_handle_t t, int sock_errno); -struct transport_esp_tls* esp_transport_init_foundation(void); +/** + * @brief Creates esp-tls transport used in the foundation transport + * + * @return transport esp-tls handle + */ +struct transport_esp_tls* esp_transport_esp_tls_create(void); + +/** + * @brief Destroys esp-tls transport used in the foundation transport + * + * @param[in] transport esp-tls handle + */ +void esp_transport_esp_tls_destroy(struct transport_esp_tls* transport_esp_tls); + +/** + * @brief Sets error to common transport handle + * + * Note: This function copies the supplied error handle object to tcp_transport's internal + * error handle object + * + * @param[in] A transport handle + * + */ +void esp_transport_set_errors(esp_transport_handle_t t, const esp_tls_error_handle_t error_handle); #endif //_ESP_TRANSPORT_INTERNAL_H_ diff --git a/components/tcp_transport/private_include/esp_transport_ssl_internal.h b/components/tcp_transport/private_include/esp_transport_ssl_internal.h deleted file mode 100644 index 07b8c39435..0000000000 --- a/components/tcp_transport/private_include/esp_transport_ssl_internal.h +++ /dev/null @@ -1,30 +0,0 @@ -// Copyright 2015-2019 Espressif Systems (Shanghai) PTE LTD -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at - -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef _ESP_TRANSPORT_SSL_INTERNAL_H_ -#define _ESP_TRANSPORT_SSL_INTERNAL_H_ - -/** - * @brief Sets error to common transport handle - * - * Note: This function copies the supplied error handle object to tcp_transport's internal - * error handle object - * - * @param[in] A transport handle - * - */ -void esp_transport_set_errors(esp_transport_handle_t t, const esp_tls_error_handle_t error_handle); - - -#endif /* _ESP_TRANSPORT_SSL_INTERNAL_H_ */ diff --git a/components/tcp_transport/transport.c b/components/tcp_transport/transport.c index f6ae753279..8de4f9d355 100644 --- a/components/tcp_transport/transport.c +++ b/components/tcp_transport/transport.c @@ -31,7 +31,7 @@ static const char *TAG = "TRANSPORT"; * * esp-tls last error storage * * sock-errno */ -struct esp_transport_error_s { +struct esp_transport_error_storage { struct esp_tls_last_error esp_tls_err_h_base; /*!< esp-tls last error container */ // additional fields int sock_errno; /*!< last socket error captured for this transport */ @@ -49,10 +49,32 @@ struct transport_esp_tls; */ typedef struct esp_transport_internal { struct esp_transport_list_t list; /*!< List of transports */ - struct esp_transport_error_s* error_handle; /*!< Pointer to the transport error container */ - struct transport_esp_tls *foundation_transport; + struct esp_foundation_transport *base; /*!< Base transport pointer shared for each list item */ } esp_transport_internal_t; +static esp_foundation_transport_t * esp_transport_init_foundation_transport(void) +{ + esp_foundation_transport_t *foundation = calloc(1, sizeof(esp_foundation_transport_t)); + ESP_TRANSPORT_MEM_CHECK(TAG, foundation, return NULL); + foundation->error_handle = calloc(1, sizeof(struct esp_transport_error_storage)); + ESP_TRANSPORT_MEM_CHECK(TAG, foundation->error_handle, + free(foundation); + return NULL); + foundation->transport_esp_tls = esp_transport_esp_tls_create(); + ESP_TRANSPORT_MEM_CHECK(TAG, foundation->transport_esp_tls, + free(foundation->error_handle); + free(foundation); + return NULL); + return foundation; +} + +static void esp_transport_destroy_foundation_transport(esp_foundation_transport_t *foundation) +{ + esp_transport_esp_tls_destroy(foundation->transport_esp_tls); + free(foundation->error_handle); + free(foundation); +} + static esp_transport_handle_t esp_transport_get_default_parent(esp_transport_handle_t t) { /* @@ -66,8 +88,10 @@ esp_transport_list_handle_t esp_transport_list_init(void) esp_transport_list_handle_t transport = calloc(1, sizeof(esp_transport_internal_t)); ESP_TRANSPORT_MEM_CHECK(TAG, transport, return NULL); STAILQ_INIT(&transport->list); - transport->error_handle = calloc(1, sizeof(struct esp_transport_error_s)); - transport->foundation_transport = esp_transport_init_foundation(); + transport->base = esp_transport_init_foundation_transport(); + ESP_TRANSPORT_MEM_CHECK(TAG, transport->base, + free(transport); + return NULL); return transport; } @@ -81,8 +105,7 @@ esp_err_t esp_transport_list_add(esp_transport_list_handle_t h, esp_transport_ha strcpy(t->scheme, scheme); STAILQ_INSERT_TAIL(&h->list, t, next); // Each transport in a list to share the same error tracker - t->error_handle = h->error_handle; - t->foundation_transport = h->foundation_transport; + t->base = h->base; return ESP_OK; } @@ -106,8 +129,7 @@ esp_transport_handle_t esp_transport_list_get_transport(esp_transport_list_handl esp_err_t esp_transport_list_destroy(esp_transport_list_handle_t h) { esp_transport_list_clean(h); - free(h->error_handle); - free(h->foundation_transport); // TODO: make it destroy foundation + esp_transport_destroy_foundation_transport(h->base); free(h); return ESP_OK; } @@ -289,16 +311,16 @@ esp_err_t esp_transport_set_parent_transport_func(esp_transport_handle_t t, payl esp_tls_error_handle_t esp_transport_get_error_handle(esp_transport_handle_t t) { if (t) { - return &t->error_handle->esp_tls_err_h_base; + return &t->base->error_handle->esp_tls_err_h_base; } return NULL; } int esp_transport_get_errno(esp_transport_handle_t t) { - if (t && t->error_handle) { - int actual_errno = t->error_handle->sock_errno; - t->error_handle->sock_errno = 0; + if (t && t->base && t->base->error_handle) { + int actual_errno = t->base->error_handle->sock_errno; + t->base->error_handle->sock_errno = 0; return actual_errno; } return -1; @@ -328,19 +350,19 @@ void capture_tcp_transport_error(esp_transport_handle_t t, enum tcp_transport_er void esp_transport_set_errors(esp_transport_handle_t t, const esp_tls_error_handle_t error_handle) { - if (t && t->error_handle) { - memcpy(&t->error_handle->esp_tls_err_h_base, error_handle, sizeof(esp_tls_last_error_t)); + if (t && t->base && t->base->error_handle) { + memcpy(&t->base->error_handle->esp_tls_err_h_base, error_handle, sizeof(esp_tls_last_error_t)); int sock_error; if (esp_tls_get_and_clear_error_type(error_handle, ESP_TLS_ERR_TYPE_SYSTEM, &sock_error) == ESP_OK) { - t->error_handle->sock_errno = sock_error; + t->base->error_handle->sock_errno = sock_error; } } } void esp_transport_capture_errno(esp_transport_handle_t t, int sock_errno) { - if (t && t->error_handle) { - t->error_handle->sock_errno = sock_errno; + if (t && t->base && t->base->error_handle) { + t->base->error_handle->sock_errno = sock_errno; } } diff --git a/components/tcp_transport/transport_ssl.c b/components/tcp_transport/transport_ssl.c index 96034c76e3..f503aa74d4 100644 --- a/components/tcp_transport/transport_ssl.c +++ b/components/tcp_transport/transport_ssl.c @@ -15,19 +15,15 @@ #include #include -#include "freertos/FreeRTOS.h" -#include "freertos/task.h" #include "esp_tls.h" #include "esp_log.h" -#include "esp_system.h" #include "esp_transport.h" #include "esp_transport_ssl.h" #include "esp_transport_utils.h" -#include "esp_transport_ssl_internal.h" #include "esp_transport_internal.h" -static const char *TAG = "TRANS_SSL"; +static const char *TAG = "TRANSPORT_BASE"; typedef enum { TRANS_SSL_INIT = 0, @@ -42,15 +38,16 @@ typedef struct transport_esp_tls { esp_tls_cfg_t cfg; bool ssl_initialized; transport_ssl_conn_state_t conn_state; -} transport_ssl_t; +} transport_esp_tls_t; static int ssl_close(esp_transport_handle_t t); -static int ssl_connect_async(esp_transport_handle_t t, const char *host, int port, int timeout_ms) +static int esp_tls_connect_async(esp_transport_handle_t t, const char *host, int port, int timeout_ms, bool is_plain_tcp) { - transport_ssl_t *ssl = t->foundation_transport; + transport_esp_tls_t *ssl = t->base->transport_esp_tls; if (ssl->conn_state == TRANS_SSL_INIT) { ssl->cfg.timeout_ms = timeout_ms; + ssl->cfg.is_plain_tcp = is_plain_tcp; ssl->cfg.non_block = true; ssl->ssl_initialized = true; ssl->tls = esp_tls_init(); @@ -65,11 +62,23 @@ static int ssl_connect_async(esp_transport_handle_t t, const char *host, int por return 0; } -static int ssl_connect(esp_transport_handle_t t, const char *host, int port, int timeout_ms) +static inline int ssl_connect_async(esp_transport_handle_t t, const char *host, int port, int timeout_ms) { - transport_ssl_t *ssl = t->foundation_transport; + return esp_tls_connect_async(t, host, port, timeout_ms, false); +} + +static inline int tcp_connect_async(esp_transport_handle_t t, const char *host, int port, int timeout_ms) +{ + return esp_tls_connect_async(t, host, port, timeout_ms, true); +} + +static int esp_tls_connect(esp_transport_handle_t t, const char *host, int port, int timeout_ms, bool is_plain_tcp) +{ + transport_esp_tls_t *ssl = t->base->transport_esp_tls; ssl->cfg.timeout_ms = timeout_ms; + ssl->cfg.is_plain_tcp = is_plain_tcp; + ssl->ssl_initialized = true; ssl->tls = esp_tls_init(); if (esp_tls_conn_new_sync(host, strlen(host), port, &ssl->cfg, ssl->tls) <= 0) { @@ -79,33 +88,22 @@ static int ssl_connect(esp_transport_handle_t t, const char *host, int port, int ssl->tls = NULL; return -1; } - return 0; } -static int tcp_connect(esp_transport_handle_t t, const char *host, int port, int timeout_ms) +static inline int ssl_connect(esp_transport_handle_t t, const char *host, int port, int timeout_ms) { - transport_ssl_t *ssl = t->foundation_transport; - - ssl->cfg.timeout_ms = timeout_ms; - ssl->cfg.is_plain_tcp = true; - ssl->ssl_initialized = true; - ssl->tls = esp_tls_init(); - if (esp_tls_conn_new_sync(host, strlen(host), port, &ssl->cfg, ssl->tls) <= 0) { - ESP_LOGE(TAG, "Failed to open a new connection"); - esp_transport_set_errors(t, ssl->tls->error_handle); - esp_tls_conn_destroy(ssl->tls); - ssl->tls = NULL; - return -1; - } - - return 0; + return esp_tls_connect(t, host, port, timeout_ms, false); } +static inline int tcp_connect(esp_transport_handle_t t, const char *host, int port, int timeout_ms) +{ + return esp_tls_connect(t, host, port, timeout_ms, true); +} static int ssl_poll_read(esp_transport_handle_t t, int timeout_ms) { - transport_ssl_t *ssl = t->foundation_transport; + transport_esp_tls_t *ssl = t->base->transport_esp_tls; int ret = -1; int remain = 0; struct timeval timeout; @@ -134,7 +132,7 @@ static int ssl_poll_read(esp_transport_handle_t t, int timeout_ms) static int ssl_poll_write(esp_transport_handle_t t, int timeout_ms) { - transport_ssl_t *ssl = t->foundation_transport; + transport_esp_tls_t *ssl = t->base->transport_esp_tls; int ret = -1; struct timeval timeout; fd_set writeset; @@ -158,7 +156,7 @@ static int ssl_poll_write(esp_transport_handle_t t, int timeout_ms) static int ssl_write(esp_transport_handle_t t, const char *buffer, int len, int timeout_ms) { int poll, ret; - transport_ssl_t *ssl = t->foundation_transport; + transport_esp_tls_t *ssl = t->base->transport_esp_tls; if ((poll = esp_transport_poll_write(t, timeout_ms)) <= 0) { ESP_LOGW(TAG, "Poll timeout or error, errno=%s, fd=%d, timeout_ms=%d", strerror(errno), ssl->tls->sockfd, timeout_ms); @@ -175,7 +173,7 @@ static int ssl_write(esp_transport_handle_t t, const char *buffer, int len, int static int ssl_read(esp_transport_handle_t t, char *buffer, int len, int timeout_ms) { int poll, ret; - transport_ssl_t *ssl = t->foundation_transport; + transport_esp_tls_t *ssl = t->base->transport_esp_tls; if ((poll = esp_transport_poll_read(t, timeout_ms)) <= 0) { return poll; @@ -198,8 +196,8 @@ static int ssl_read(esp_transport_handle_t t, char *buffer, int len, int timeout static int ssl_close(esp_transport_handle_t t) { int ret = -1; - transport_ssl_t *ssl = t->foundation_transport; - if (ssl->ssl_initialized) { + if (t && t->base && t->base->transport_esp_tls && t->base->transport_esp_tls->ssl_initialized) { + transport_esp_tls_t *ssl = t->base->transport_esp_tls; ret = esp_tls_conn_destroy(ssl->tls); ssl->conn_state = TRANS_SSL_INIT; ssl->ssl_initialized = false; @@ -209,139 +207,120 @@ static int ssl_close(esp_transport_handle_t t) static int ssl_destroy(esp_transport_handle_t t) { - transport_ssl_t *ssl = t->foundation_transport; esp_transport_close(t); - free(ssl); return 0; } void esp_transport_ssl_enable_global_ca_store(esp_transport_handle_t t) { - transport_ssl_t *ssl = t->foundation_transport; - if (t && ssl) { - ssl->cfg.use_global_ca_store = true; + if (t && t->base && t->base->transport_esp_tls) { + t->base->transport_esp_tls->cfg.use_global_ca_store = true; } } void esp_transport_ssl_set_psk_key_hint(esp_transport_handle_t t, const psk_hint_key_t* psk_hint_key) { - transport_ssl_t *ssl = t->foundation_transport; - if (t && ssl) { - ssl->cfg.psk_hint_key = psk_hint_key; + if (t && t->base && t->base->transport_esp_tls) { + t->base->transport_esp_tls->cfg.psk_hint_key = psk_hint_key; } } void esp_transport_ssl_set_cert_data(esp_transport_handle_t t, const char *data, int len) { - transport_ssl_t *ssl = t->foundation_transport; - if (t && ssl) { - ssl->cfg.cacert_pem_buf = (void *)data; - ssl->cfg.cacert_pem_bytes = len + 1; + if (t && t->base && t->base->transport_esp_tls) { + t->base->transport_esp_tls->cfg.cacert_pem_buf = (void *)data; + t->base->transport_esp_tls->cfg.cacert_pem_bytes = len + 1; } } void esp_transport_ssl_set_cert_data_der(esp_transport_handle_t t, const char *data, int len) { - transport_ssl_t *ssl = t->foundation_transport; - if (t && ssl) { - ssl->cfg.cacert_buf = (void *)data; - ssl->cfg.cacert_bytes = len; + if (t && t->base && t->base->transport_esp_tls) { + t->base->transport_esp_tls->cfg.cacert_buf = (void *)data; + t->base->transport_esp_tls->cfg.cacert_bytes = len; } } void esp_transport_ssl_set_client_cert_data(esp_transport_handle_t t, const char *data, int len) { - transport_ssl_t *ssl = t->foundation_transport; - if (t && ssl) { - ssl->cfg.clientcert_pem_buf = (void *)data; - ssl->cfg.clientcert_pem_bytes = len + 1; + if (t && t->base && t->base->transport_esp_tls) { + t->base->transport_esp_tls->cfg.clientcert_pem_buf = (void *)data; + t->base->transport_esp_tls->cfg.clientcert_pem_bytes = len + 1; } } void esp_transport_ssl_set_client_cert_data_der(esp_transport_handle_t t, const char *data, int len) { - transport_ssl_t *ssl = t->foundation_transport; - if (t && ssl) { - ssl->cfg.clientcert_buf = (void *)data; - ssl->cfg.clientcert_bytes = len; + if (t && t->base && t->base->transport_esp_tls) { + t->base->transport_esp_tls->cfg.clientcert_buf = (void *)data; + t->base->transport_esp_tls->cfg.clientcert_bytes = len; } } void esp_transport_ssl_set_client_key_data(esp_transport_handle_t t, const char *data, int len) { - transport_ssl_t *ssl = t->foundation_transport; - if (t && ssl) { - ssl->cfg.clientkey_pem_buf = (void *)data; - ssl->cfg.clientkey_pem_bytes = len + 1; + if (t && t->base && t->base->transport_esp_tls) { + t->base->transport_esp_tls->cfg.clientkey_pem_buf = (void *)data; + t->base->transport_esp_tls->cfg.clientkey_pem_bytes = len + 1; } } void esp_transport_ssl_set_client_key_password(esp_transport_handle_t t, const char *password, int password_len) { - transport_ssl_t *ssl = t->foundation_transport; - if (t && ssl) { - ssl->cfg.clientkey_password = (void *)password; - ssl->cfg.clientkey_password_len = password_len; + if (t && t->base && t->base->transport_esp_tls) { + t->base->transport_esp_tls->cfg.clientkey_password = (void *)password; + t->base->transport_esp_tls->cfg.clientkey_password_len = password_len; } } void esp_transport_ssl_set_client_key_data_der(esp_transport_handle_t t, const char *data, int len) { - transport_ssl_t *ssl = t->foundation_transport; - if (t && ssl) { - ssl->cfg.clientkey_buf = (void *)data; - ssl->cfg.clientkey_bytes = len; + if (t && t->base && t->base->transport_esp_tls) { + t->base->transport_esp_tls->cfg.clientkey_buf = (void *)data; + t->base->transport_esp_tls->cfg.clientkey_bytes = len; } } void esp_transport_ssl_set_alpn_protocol(esp_transport_handle_t t, const char **alpn_protos) { - transport_ssl_t *ssl = t->foundation_transport; - if (t && ssl) { - ssl->cfg.alpn_protos = alpn_protos; + if (t && t->base && t->base->transport_esp_tls) { + t->base->transport_esp_tls->cfg.alpn_protos = alpn_protos; } } void esp_transport_ssl_skip_common_name_check(esp_transport_handle_t t) { - transport_ssl_t *ssl = t->foundation_transport; - if (t && ssl) { - ssl->cfg.skip_common_name = true; + if (t && t->base && t->base->transport_esp_tls) { + t->base->transport_esp_tls->cfg.skip_common_name = true; } } void esp_transport_ssl_use_secure_element(esp_transport_handle_t t) { - transport_ssl_t *ssl = t->foundation_transport; - if (t && ssl) { - ssl->cfg.use_secure_element = true; + if (t && t->base && t->base->transport_esp_tls) { + t->base->transport_esp_tls->cfg.use_secure_element = true; } } static int ssl_get_socket(esp_transport_handle_t t) { - if (t) { - transport_ssl_t *ssl = t->data; - if (ssl && ssl->tls) { - return ssl->tls->sockfd; - } + if (t && t->base && t->base->transport_esp_tls && t->base->transport_esp_tls->tls) { + return t->base->transport_esp_tls->tls->sockfd; } return -1; } void esp_transport_ssl_set_ds_data(esp_transport_handle_t t, void *ds_data) { - transport_ssl_t *ssl = t->foundation_transport; - if (t && ssl) { // TODO: check t NULL first! - ssl->cfg.ds_data = ds_data; + if (t && t->base && t->base->transport_esp_tls) { + t->base->transport_esp_tls->cfg.ds_data = ds_data; } } void esp_transport_ssl_set_keep_alive(esp_transport_handle_t t, esp_transport_keep_alive_t *keep_alive_cfg) { - transport_ssl_t *ssl = esp_transport_get_context_data(t); - if (t && ssl) { - ssl->cfg.keep_alive_cfg = (tls_keep_alive_cfg_t *)keep_alive_cfg; + if (t && t->base && t->base->transport_esp_tls) { + t->base->transport_esp_tls->cfg.keep_alive_cfg = (tls_keep_alive_cfg_t *) keep_alive_cfg; } } @@ -355,10 +334,15 @@ esp_transport_handle_t esp_transport_ssl_init(void) return t; } -struct transport_esp_tls* esp_transport_init_foundation(void) +struct transport_esp_tls* esp_transport_esp_tls_create(void) { - transport_ssl_t *ssl = calloc(1, sizeof(transport_ssl_t)); - return ssl; + transport_esp_tls_t *transport_esp_tls = calloc(1, sizeof(transport_esp_tls_t)); + return transport_esp_tls; +} + +void esp_transport_esp_tls_destroy(struct transport_esp_tls* transport_esp_tls) +{ + free(transport_esp_tls); } esp_transport_handle_t esp_transport_tcp_init(void) @@ -366,7 +350,12 @@ esp_transport_handle_t esp_transport_tcp_init(void) esp_transport_handle_t t = esp_transport_init(); esp_transport_set_context_data(t, NULL); esp_transport_set_func(t, tcp_connect, ssl_read, ssl_write, ssl_close, ssl_poll_read, ssl_poll_write, ssl_destroy); - esp_transport_set_async_connect_func(t, ssl_connect_async); // TODO: tcp_connect_async() + esp_transport_set_async_connect_func(t, tcp_connect_async); t->_get_socket = ssl_get_socket; return t; } + +void esp_transport_tcp_set_keep_alive(esp_transport_handle_t t, esp_transport_keep_alive_t *keep_alive_cfg) +{ + return esp_transport_ssl_set_keep_alive(t, keep_alive_cfg); +}