From 3a315e36add54f9db813c20ff3c4b37b54812931 Mon Sep 17 00:00:00 2001 From: moprg Date: Tue, 30 Apr 2024 10:34:18 +0700 Subject: [PATCH] extmod/modtls_mbedtls.c: Add set_psk_server_callback() method. Signed-off-by: moprg --- extmod/modtls_mbedtls.c | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/extmod/modtls_mbedtls.c b/extmod/modtls_mbedtls.c index 6db6ac1958..4019733347 100644 --- a/extmod/modtls_mbedtls.c +++ b/extmod/modtls_mbedtls.c @@ -199,6 +199,24 @@ static int ssl_sock_cert_verify(void *ptr, mbedtls_x509_crt *crt, int depth, uin return mp_obj_get_int(mp_call_function_2(o->handler, MP_OBJ_FROM_PTR(&cert), MP_OBJ_NEW_SMALL_INT(depth))); } +static int ssl_conf_psk_cb(void *parameter, mbedtls_ssl_context *ssl, + const unsigned char *psk_identity, size_t psk_identity_len) { + mp_obj_t callback = MP_OBJ_FROM_PTR(parameter); + mp_obj_t psk_identity_obj = mp_obj_new_bytes(psk_identity, psk_identity_len); + + // Call the callback + mp_obj_t psk_key_obj = mp_call_function_1(callback, psk_identity_obj); + + // Check if psk key object is supplied, set handshake psk + int ret = -1; + if (psk_key_obj != mp_const_none) { + size_t psk_key_len; + const unsigned char *psk_key = (const unsigned char *)mp_obj_str_get_data(psk_key_obj, &psk_key_len); + ret = mbedtls_ssl_set_hs_psk(ssl, psk_key, psk_key_len); + } + return ret; +} + /******************************************************************************/ // SSLContext type. @@ -401,6 +419,18 @@ static mp_obj_t ssl_context_load_verify_locations(mp_obj_t self_in, mp_obj_t cad } static MP_DEFINE_CONST_FUN_OBJ_2(ssl_context_load_verify_locations_obj, ssl_context_load_verify_locations); +// SSLContext.set_psk_server_callback(callback) +static mp_obj_t ssl_context_set_psk_server_callback(mp_obj_t self_in, mp_obj_t callback) { + mp_obj_ssl_context_t *self = MP_OBJ_TO_PTR(self_in); + + // Check if callback is supplied, config psk callback + if (callback != mp_const_none) { + mbedtls_ssl_conf_psk_cb(&self->conf, ssl_conf_psk_cb, MP_OBJ_TO_PTR(callback)); + } + return mp_const_none; +} +static MP_DEFINE_CONST_FUN_OBJ_2(ssl_context_set_psk_server_callback_obj, ssl_context_set_psk_server_callback); + static mp_obj_t ssl_context_wrap_socket(size_t n_args, const mp_obj_t *pos_args, mp_map_t *kw_args) { enum { ARG_server_side, ARG_do_handshake_on_connect, ARG_server_hostname }; static const mp_arg_t allowed_args[] = { @@ -429,6 +459,7 @@ static const mp_rom_map_elem_t ssl_context_locals_dict_table[] = { { MP_ROM_QSTR(MP_QSTR_set_ciphers), MP_ROM_PTR(&ssl_context_set_ciphers_obj)}, { MP_ROM_QSTR(MP_QSTR_load_cert_chain), MP_ROM_PTR(&ssl_context_load_cert_chain_obj)}, { MP_ROM_QSTR(MP_QSTR_load_verify_locations), MP_ROM_PTR(&ssl_context_load_verify_locations_obj)}, + { MP_ROM_QSTR(MP_QSTR_set_psk_server_callback), MP_ROM_PTR(&ssl_context_set_psk_server_callback_obj)}, { MP_ROM_QSTR(MP_QSTR_wrap_socket), MP_ROM_PTR(&ssl_context_wrap_socket_obj) }, }; static MP_DEFINE_CONST_DICT(ssl_context_locals_dict, ssl_context_locals_dict_table);