tcp_transport: Allow for using transport independently from list

For compatibility reasons, support also transports separately if the transport is used
before attaching to parent list. In this case we create an internal context which is
independent on the foundation transport and used preferably
pull/6718/head
David Cermak 2021-02-04 10:18:56 +01:00
rodzic 41146d674f
commit 1fa0db8d44
2 zmienionych plików z 141 dodań i 80 usunięć

Wyświetl plik

@ -190,7 +190,7 @@ static void transport_connection_timeout_test(esp_transport_handle_t transport_u
EventBits_t bits = xEventGroupWaitBits(params.tcp_connect_done, TCP_CONNECT_DONE, true, true, max_wait);
TickType_t end = xTaskGetTickCount();
TEST_ASSERT_EQUAL(TCP_CONNECT_DONE, TCP_CONNECT_DONE&bits); // Connection has finished
TEST_ASSERT_EQUAL(TCP_CONNECT_DONE, TCP_CONNECT_DONE & bits); // Connection has finished
TEST_ASSERT_EQUAL(-1, params.ret); // Connection failed with -1
// Test connection attempt took expected timeout value
@ -316,24 +316,21 @@ static void socket_operation_test(esp_transport_handle_t transport_under_test,
test_utils_task_delete(tcp_connect_task_handle);
}
static void tcp_transport_keepalive_test(esp_transport_handle_t transport_under_test)
static void tcp_transport_keepalive_test(esp_transport_handle_t transport_under_test, esp_transport_keep_alive_t *config)
{
static const int KEEP_ALIVE_INTERVAL = 1;
static const int KEEP_ALIVE_IDLE = 2;
static const int KEEP_ALIVE_COUNT = 3;
static const struct expected_sock_option expected_opts[] = {
static struct expected_sock_option expected_opts[4] = {
{ .level = SOL_SOCKET, .optname = SO_KEEPALIVE, .optval = 1, .opttype = SOCK_OPT_TYPE_BOOL },
{ .level = IPPROTO_TCP, .optname = TCP_KEEPIDLE, .optval = KEEP_ALIVE_IDLE, .opttype = SOCK_OPT_TYPE_INT },
{ .level = IPPROTO_TCP, .optname = TCP_KEEPINTVL, .optval = KEEP_ALIVE_INTERVAL, .opttype = SOCK_OPT_TYPE_INT },
{ .level = IPPROTO_TCP, .optname = TCP_KEEPCNT, .optval = KEEP_ALIVE_COUNT, .opttype = SOCK_OPT_TYPE_INT },
{ .level = IPPROTO_TCP },
{ .level = IPPROTO_TCP },
{ .level = IPPROTO_TCP }
};
esp_transport_keep_alive_t keep_alive_cfg = { .keep_alive_interval = KEEP_ALIVE_INTERVAL,
.keep_alive_idle = KEEP_ALIVE_IDLE,
.keep_alive_enable = true,
.keep_alive_count = KEEP_ALIVE_COUNT };
esp_transport_tcp_set_keep_alive(transport_under_test, &keep_alive_cfg);
expected_opts[1].optname = TCP_KEEPIDLE;
expected_opts[1].optval = config->keep_alive_idle;
expected_opts[2].optname = TCP_KEEPINTVL;
expected_opts[2].optval = config->keep_alive_interval;
expected_opts[3].optname = TCP_KEEPCNT;
expected_opts[3].optval = config->keep_alive_count;
socket_operation_test(transport_under_test, expected_opts, sizeof(expected_opts)/sizeof(struct expected_sock_option));
}
@ -346,7 +343,14 @@ TEST_CASE("tcp_transport: Keep alive test", "[tcp_transport]")
esp_transport_list_add(transport_list, tcp, "tcp");
// Perform the test
tcp_transport_keepalive_test(tcp);
esp_transport_keep_alive_t keep_alive_cfg = {
.keep_alive_interval = 5,
.keep_alive_idle = 4,
.keep_alive_enable = true,
.keep_alive_count = 3 };
esp_transport_tcp_set_keep_alive(tcp, &keep_alive_cfg);
tcp_transport_keepalive_test(tcp, &keep_alive_cfg);
// Cleanup
esp_transport_close(tcp);
@ -363,7 +367,14 @@ TEST_CASE("ssl_transport: Keep alive test", "[tcp_transport]")
esp_transport_ssl_enable_global_ca_store(ssl);
// Perform the test
tcp_transport_keepalive_test(ssl);
esp_transport_keep_alive_t keep_alive_cfg = {
.keep_alive_interval = 2,
.keep_alive_idle = 3,
.keep_alive_enable = true,
.keep_alive_count = 4 };
esp_transport_ssl_set_keep_alive(ssl, &keep_alive_cfg);
tcp_transport_keepalive_test(ssl, &keep_alive_cfg);
// Cleanup
esp_transport_close(ssl);
@ -382,9 +393,40 @@ TEST_CASE("ws_transport: Keep alive test", "[tcp_transport]")
esp_transport_list_add(transport_list, ws, "wss");
// Perform the test
tcp_transport_keepalive_test(ws);
esp_transport_keep_alive_t keep_alive_cfg = {
.keep_alive_interval = 1,
.keep_alive_idle = 2,
.keep_alive_enable = true,
.keep_alive_count = 3 };
esp_transport_tcp_set_keep_alive(ssl, &keep_alive_cfg);
tcp_transport_keepalive_test(ws, &keep_alive_cfg);
// Cleanup
esp_transport_close(ssl);
esp_transport_list_destroy(transport_list);
}
// Note: This functionality is tested and kept only for compatibility reasons with IDF <= 4.x
// It is strongly encouraged to use transport within lists only
TEST_CASE("ssl_transport: Check that parameters (keepalive) are set independently on the list", "[tcp_transport]")
{
// Init the transport under test
esp_transport_handle_t ssl = esp_transport_ssl_init();
esp_tls_init_global_ca_store();
esp_transport_ssl_enable_global_ca_store(ssl);
// Perform the test
esp_transport_keep_alive_t keep_alive_cfg = {
.keep_alive_interval = 2,
.keep_alive_idle = 4,
.keep_alive_enable = true,
.keep_alive_count = 3 };
esp_transport_ssl_set_keep_alive(ssl, &keep_alive_cfg);
tcp_transport_keepalive_test(ssl, &keep_alive_cfg);
// Cleanup
esp_transport_close(ssl);
esp_transport_destroy(ssl);
}

Wyświetl plik

@ -23,6 +23,10 @@
#include "esp_transport_utils.h"
#include "esp_transport_internal.h"
#define GET_SSL_FROM_TRANSPORT_OR_RETURN(ssl, t) \
transport_esp_tls_t *ssl = ssl_get_context_data(t); \
if (!ssl) { return; }
static const char *TAG = "TRANSPORT_BASE";
typedef enum {
@ -40,11 +44,30 @@ typedef struct transport_esp_tls {
transport_ssl_conn_state_t conn_state;
} transport_esp_tls_t;
static inline struct transport_esp_tls * ssl_get_context_data(esp_transport_handle_t t)
{
if (!t) {
return NULL;
}
if (t->data) { // Prefer internal ssl context (independent from the list)
return (transport_esp_tls_t*)t->data;
}
if (t->base && t->base->transport_esp_tls) { // Next one is the lists inherent context
t->data = t->base->transport_esp_tls; // Optimize: if we have base context, use it as internal
return t->base->transport_esp_tls;
}
// If we don't have a valid context, let's to create one
transport_esp_tls_t *ssl = esp_transport_esp_tls_create();
ESP_TRANSPORT_MEM_CHECK(TAG, ssl, return NULL)
t->data = ssl;
return ssl;
}
static int ssl_close(esp_transport_handle_t t);
static int esp_tls_connect_async(esp_transport_handle_t t, const char *host, int port, int timeout_ms, bool is_plain_tcp)
{
transport_esp_tls_t *ssl = t->base->transport_esp_tls;
transport_esp_tls_t *ssl = ssl_get_context_data(t);
if (ssl->conn_state == TRANS_SSL_INIT) {
ssl->cfg.timeout_ms = timeout_ms;
ssl->cfg.is_plain_tcp = is_plain_tcp;
@ -74,7 +97,7 @@ static inline int tcp_connect_async(esp_transport_handle_t t, const char *host,
static int esp_tls_connect(esp_transport_handle_t t, const char *host, int port, int timeout_ms, bool is_plain_tcp)
{
transport_esp_tls_t *ssl = t->base->transport_esp_tls;
transport_esp_tls_t *ssl = ssl_get_context_data(t);
ssl->cfg.timeout_ms = timeout_ms;
ssl->cfg.is_plain_tcp = is_plain_tcp;
@ -103,7 +126,7 @@ static inline int tcp_connect(esp_transport_handle_t t, const char *host, int po
static int ssl_poll_read(esp_transport_handle_t t, int timeout_ms)
{
transport_esp_tls_t *ssl = t->base->transport_esp_tls;
transport_esp_tls_t *ssl = ssl_get_context_data(t);
int ret = -1;
int remain = 0;
struct timeval timeout;
@ -132,7 +155,7 @@ static int ssl_poll_read(esp_transport_handle_t t, int timeout_ms)
static int ssl_poll_write(esp_transport_handle_t t, int timeout_ms)
{
transport_esp_tls_t *ssl = t->base->transport_esp_tls;
transport_esp_tls_t *ssl = ssl_get_context_data(t);
int ret = -1;
struct timeval timeout;
fd_set writeset;
@ -156,7 +179,7 @@ static int ssl_poll_write(esp_transport_handle_t t, int timeout_ms)
static int ssl_write(esp_transport_handle_t t, const char *buffer, int len, int timeout_ms)
{
int poll, ret;
transport_esp_tls_t *ssl = t->base->transport_esp_tls;
transport_esp_tls_t *ssl = ssl_get_context_data(t);
if ((poll = esp_transport_poll_write(t, timeout_ms)) <= 0) {
ESP_LOGW(TAG, "Poll timeout or error, errno=%s, fd=%d, timeout_ms=%d", strerror(errno), ssl->tls->sockfd, timeout_ms);
@ -173,7 +196,7 @@ static int ssl_write(esp_transport_handle_t t, const char *buffer, int len, int
static int ssl_read(esp_transport_handle_t t, char *buffer, int len, int timeout_ms)
{
int poll, ret;
transport_esp_tls_t *ssl = t->base->transport_esp_tls;
transport_esp_tls_t *ssl = ssl_get_context_data(t);
if ((poll = esp_transport_poll_read(t, timeout_ms)) <= 0) {
return poll;
@ -196,8 +219,8 @@ static int ssl_read(esp_transport_handle_t t, char *buffer, int len, int timeout
static int ssl_close(esp_transport_handle_t t)
{
int ret = -1;
if (t && t->base && t->base->transport_esp_tls && t->base->transport_esp_tls->ssl_initialized) {
transport_esp_tls_t *ssl = t->base->transport_esp_tls;
transport_esp_tls_t *ssl = ssl_get_context_data(t);
if (ssl && ssl->ssl_initialized) {
ret = esp_tls_conn_destroy(ssl->tls);
ssl->conn_state = TRANS_SSL_INIT;
ssl->ssl_initialized = false;
@ -207,127 +230,124 @@ static int ssl_close(esp_transport_handle_t t)
static int ssl_destroy(esp_transport_handle_t t)
{
esp_transport_close(t);
transport_esp_tls_t *ssl = ssl_get_context_data(t);
if (ssl) {
esp_transport_close(t);
if (t->base && t->base->transport_esp_tls &&
t->data == t->base->transport_esp_tls) {
// if internal ssl the same as the foundation transport,
// just zero out, it will be freed on list destroy
t->data = NULL;
}
esp_transport_esp_tls_destroy(t->data); // okay to pass NULL
}
return 0;
}
void esp_transport_ssl_enable_global_ca_store(esp_transport_handle_t t)
{
if (t && t->base && t->base->transport_esp_tls) {
t->base->transport_esp_tls->cfg.use_global_ca_store = true;
}
GET_SSL_FROM_TRANSPORT_OR_RETURN(ssl, t);
ssl->cfg.use_global_ca_store = true;
}
void esp_transport_ssl_set_psk_key_hint(esp_transport_handle_t t, const psk_hint_key_t* psk_hint_key)
{
if (t && t->base && t->base->transport_esp_tls) {
t->base->transport_esp_tls->cfg.psk_hint_key = psk_hint_key;
}
GET_SSL_FROM_TRANSPORT_OR_RETURN(ssl, t);
ssl->cfg.psk_hint_key = psk_hint_key;
}
void esp_transport_ssl_set_cert_data(esp_transport_handle_t t, const char *data, int len)
{
if (t && t->base && t->base->transport_esp_tls) {
t->base->transport_esp_tls->cfg.cacert_pem_buf = (void *)data;
t->base->transport_esp_tls->cfg.cacert_pem_bytes = len + 1;
}
GET_SSL_FROM_TRANSPORT_OR_RETURN(ssl, t);
ssl->cfg.cacert_pem_buf = (void *)data;
ssl->cfg.cacert_pem_bytes = len + 1;
}
void esp_transport_ssl_set_cert_data_der(esp_transport_handle_t t, const char *data, int len)
{
if (t && t->base && t->base->transport_esp_tls) {
t->base->transport_esp_tls->cfg.cacert_buf = (void *)data;
t->base->transport_esp_tls->cfg.cacert_bytes = len;
}
GET_SSL_FROM_TRANSPORT_OR_RETURN(ssl, t);
ssl->cfg.cacert_buf = (void *)data;
ssl->cfg.cacert_bytes = len;
}
void esp_transport_ssl_set_client_cert_data(esp_transport_handle_t t, const char *data, int len)
{
if (t && t->base && t->base->transport_esp_tls) {
t->base->transport_esp_tls->cfg.clientcert_pem_buf = (void *)data;
t->base->transport_esp_tls->cfg.clientcert_pem_bytes = len + 1;
}
GET_SSL_FROM_TRANSPORT_OR_RETURN(ssl, t);
ssl->cfg.clientcert_pem_buf = (void *)data;
ssl->cfg.clientcert_pem_bytes = len + 1;
}
void esp_transport_ssl_set_client_cert_data_der(esp_transport_handle_t t, const char *data, int len)
{
if (t && t->base && t->base->transport_esp_tls) {
t->base->transport_esp_tls->cfg.clientcert_buf = (void *)data;
t->base->transport_esp_tls->cfg.clientcert_bytes = len;
}
GET_SSL_FROM_TRANSPORT_OR_RETURN(ssl, t);
ssl->cfg.clientcert_buf = (void *)data;
ssl->cfg.clientcert_bytes = len;
}
void esp_transport_ssl_set_client_key_data(esp_transport_handle_t t, const char *data, int len)
{
if (t && t->base && t->base->transport_esp_tls) {
t->base->transport_esp_tls->cfg.clientkey_pem_buf = (void *)data;
t->base->transport_esp_tls->cfg.clientkey_pem_bytes = len + 1;
}
GET_SSL_FROM_TRANSPORT_OR_RETURN(ssl, t);
ssl->cfg.clientkey_pem_buf = (void *)data;
ssl->cfg.clientkey_pem_bytes = len + 1;
}
void esp_transport_ssl_set_client_key_password(esp_transport_handle_t t, const char *password, int password_len)
{
if (t && t->base && t->base->transport_esp_tls) {
t->base->transport_esp_tls->cfg.clientkey_password = (void *)password;
t->base->transport_esp_tls->cfg.clientkey_password_len = password_len;
}
GET_SSL_FROM_TRANSPORT_OR_RETURN(ssl, t);
ssl->cfg.clientkey_password = (void *)password;
ssl->cfg.clientkey_password_len = password_len;
}
void esp_transport_ssl_set_client_key_data_der(esp_transport_handle_t t, const char *data, int len)
{
if (t && t->base && t->base->transport_esp_tls) {
t->base->transport_esp_tls->cfg.clientkey_buf = (void *)data;
t->base->transport_esp_tls->cfg.clientkey_bytes = len;
}
GET_SSL_FROM_TRANSPORT_OR_RETURN(ssl, t);
ssl->cfg.clientkey_buf = (void *)data;
ssl->cfg.clientkey_bytes = len;
}
void esp_transport_ssl_set_alpn_protocol(esp_transport_handle_t t, const char **alpn_protos)
{
if (t && t->base && t->base->transport_esp_tls) {
t->base->transport_esp_tls->cfg.alpn_protos = alpn_protos;
}
GET_SSL_FROM_TRANSPORT_OR_RETURN(ssl, t);
ssl->cfg.alpn_protos = alpn_protos;
}
void esp_transport_ssl_skip_common_name_check(esp_transport_handle_t t)
{
if (t && t->base && t->base->transport_esp_tls) {
t->base->transport_esp_tls->cfg.skip_common_name = true;
}
GET_SSL_FROM_TRANSPORT_OR_RETURN(ssl, t);
ssl->cfg.skip_common_name = true;
}
void esp_transport_ssl_use_secure_element(esp_transport_handle_t t)
{
if (t && t->base && t->base->transport_esp_tls) {
t->base->transport_esp_tls->cfg.use_secure_element = true;
}
GET_SSL_FROM_TRANSPORT_OR_RETURN(ssl, t);
ssl->cfg.use_secure_element = true;
}
static int ssl_get_socket(esp_transport_handle_t t)
{
if (t && t->base && t->base->transport_esp_tls && t->base->transport_esp_tls->tls) {
return t->base->transport_esp_tls->tls->sockfd;
transport_esp_tls_t *ssl = ssl_get_context_data(t);
if (ssl && ssl->tls) {
return ssl->tls->sockfd;
}
return -1;
}
void esp_transport_ssl_set_ds_data(esp_transport_handle_t t, void *ds_data)
{
if (t && t->base && t->base->transport_esp_tls) {
t->base->transport_esp_tls->cfg.ds_data = ds_data;
}
GET_SSL_FROM_TRANSPORT_OR_RETURN(ssl, t);
ssl->cfg.ds_data = ds_data;
}
void esp_transport_ssl_set_keep_alive(esp_transport_handle_t t, esp_transport_keep_alive_t *keep_alive_cfg)
{
if (t && t->base && t->base->transport_esp_tls) {
t->base->transport_esp_tls->cfg.keep_alive_cfg = (tls_keep_alive_cfg_t *) keep_alive_cfg;
}
GET_SSL_FROM_TRANSPORT_OR_RETURN(ssl, t);
ssl->cfg.keep_alive_cfg = (tls_keep_alive_cfg_t *) keep_alive_cfg;
}
esp_transport_handle_t esp_transport_ssl_init(void)
{
esp_transport_handle_t t = esp_transport_init();
esp_transport_set_context_data(t, NULL);
esp_transport_set_func(t, ssl_connect, ssl_read, ssl_write, ssl_close, ssl_poll_read, ssl_poll_write, ssl_destroy);
esp_transport_set_async_connect_func(t, ssl_connect_async);
t->_get_socket = ssl_get_socket;
@ -348,7 +368,6 @@ void esp_transport_esp_tls_destroy(struct transport_esp_tls* transport_esp_tls)
esp_transport_handle_t esp_transport_tcp_init(void)
{
esp_transport_handle_t t = esp_transport_init();
esp_transport_set_context_data(t, NULL);
esp_transport_set_func(t, tcp_connect, ssl_read, ssl_write, ssl_close, ssl_poll_read, ssl_poll_write, ssl_destroy);
esp_transport_set_async_connect_func(t, tcp_connect_async);
t->_get_socket = ssl_get_socket;