From 4368ae31424f93f3272209ea61847f2406dd23ad Mon Sep 17 00:00:00 2001 From: Paul Sokolovsky Date: Thu, 20 Jul 2017 00:20:53 +0300 Subject: [PATCH] extmod/modussl_axtls: Allow to close ssl stream multiple times. Make sure that 2nd close has no effect and operations on closed streams are handled properly. --- extmod/modussl_axtls.c | 22 +++++++++++++++++++--- tests/extmod/ussl_basic.py | 8 ++++++++ tests/extmod/ussl_basic.py.exp | 3 ++- 3 files changed, 29 insertions(+), 4 deletions(-) diff --git a/extmod/modussl_axtls.c b/extmod/modussl_axtls.c index a27f0f1fe5..a5ab8896c0 100644 --- a/extmod/modussl_axtls.c +++ b/extmod/modussl_axtls.c @@ -102,6 +102,11 @@ STATIC void socket_print(const mp_print_t *print, mp_obj_t self_in, mp_print_kin STATIC mp_uint_t socket_read(mp_obj_t o_in, void *buf, mp_uint_t size, int *errcode) { mp_obj_ssl_socket_t *o = MP_OBJ_TO_PTR(o_in); + if (o->ssl_sock == NULL) { + *errcode = EBADF; + return MP_STREAM_ERROR; + } + while (o->bytes_left == 0) { mp_int_t r = ssl_read(o->ssl_sock, &o->buf); if (r == SSL_OK) { @@ -131,6 +136,12 @@ STATIC mp_uint_t socket_read(mp_obj_t o_in, void *buf, mp_uint_t size, int *errc STATIC mp_uint_t socket_write(mp_obj_t o_in, const void *buf, mp_uint_t size, int *errcode) { mp_obj_ssl_socket_t *o = MP_OBJ_TO_PTR(o_in); + + if (o->ssl_sock == NULL) { + *errcode = EBADF; + return MP_STREAM_ERROR; + } + mp_int_t r = ssl_write(o->ssl_sock, buf, size); if (r < 0) { *errcode = r; @@ -151,9 +162,14 @@ STATIC MP_DEFINE_CONST_FUN_OBJ_2(socket_setblocking_obj, socket_setblocking); STATIC mp_obj_t socket_close(mp_obj_t self_in) { mp_obj_ssl_socket_t *self = MP_OBJ_TO_PTR(self_in); - ssl_free(self->ssl_sock); - ssl_ctx_free(self->ssl_ctx); - return mp_stream_close(self->sock); + if (self->ssl_sock != NULL) { + ssl_free(self->ssl_sock); + ssl_ctx_free(self->ssl_ctx); + self->ssl_sock = NULL; + return mp_stream_close(self->sock); + } + + return mp_const_none; } STATIC MP_DEFINE_CONST_FUN_OBJ_1(socket_close_obj, socket_close); diff --git a/tests/extmod/ussl_basic.py b/tests/extmod/ussl_basic.py index 9f8019a0bc..e8710ed51a 100644 --- a/tests/extmod/ussl_basic.py +++ b/tests/extmod/ussl_basic.py @@ -43,6 +43,14 @@ except OSError as er: # close ss.close() +# close 2nd time +ss.close() + +# read on closed socket +try: + ss.read(10) +except OSError as er: + print('read:', repr(er)) # write on closed socket try: diff --git a/tests/extmod/ussl_basic.py.exp b/tests/extmod/ussl_basic.py.exp index b4dd038606..cb9c51f7a1 100644 --- a/tests/extmod/ussl_basic.py.exp +++ b/tests/extmod/ussl_basic.py.exp @@ -5,4 +5,5 @@ setblocking: NotImplementedError 4 b'' read: OSError(-261,) -write: OSError(-256,) +read: OSError(9,) +write: OSError(9,)