diff --git a/extmod/moducryptolib.c b/extmod/moducryptolib.c index 4f3a6b8e84..23178acf21 100644 --- a/extmod/moducryptolib.c +++ b/extmod/moducryptolib.c @@ -50,23 +50,67 @@ enum { #if MICROPY_SSL_AXTLS #include "lib/axtls/crypto/crypto.h" + +#define AES_CTX_IMPL AES_CTX #endif typedef struct _mp_obj_aes_t { mp_obj_base_t base; - AES_CTX ctx; - uint8_t block_mode: 7; - uint8_t is_decrypt_key: 1; + AES_CTX_IMPL ctx; + uint8_t block_mode: 6; +#define AES_KEYTYPE_NONE 0 +#define AES_KEYTYPE_ENC 1 +#define AES_KEYTYPE_DEC 2 + uint8_t key_type: 2; } mp_obj_aes_t; +#if MICROPY_SSL_AXTLS +STATIC void aes_initial_set_key_impl(AES_CTX_IMPL *ctx, const uint8_t *key, size_t keysize, const uint8_t iv[16]) { + assert(16 == keysize || 32 == keysize); + AES_set_key(ctx, key, iv, (16 == keysize) ? AES_MODE_128 : AES_MODE_256); +} + +STATIC void aes_final_set_key_impl(AES_CTX_IMPL *ctx, bool encrypt) { + if (!encrypt) { + AES_convert_key(ctx); + } +} + +STATIC void aes_process_ecb_impl(AES_CTX_IMPL *ctx, const uint8_t in[16], uint8_t out[16], bool encrypt) { + memcpy(out, in, 16); + // We assume that out (vstr.buf or given output buffer) is uint32_t aligned + uint32_t *p = (uint32_t*)out; + // axTLS likes it weird and complicated with byteswaps + for (int i = 0; i < 4; i++) { + p[i] = MP_HTOBE32(p[i]); + } + if (encrypt) { + AES_encrypt(ctx, p); + } else { + AES_decrypt(ctx, p); + } + for (int i = 0; i < 4; i++) { + p[i] = MP_BE32TOH(p[i]); + } +} + +STATIC void aes_process_cbc_impl(AES_CTX_IMPL *ctx, const uint8_t *in, uint8_t *out, size_t in_len, bool encrypt) { + if (encrypt) { + AES_cbc_encrypt(ctx, in, out, in_len); + } else { + AES_cbc_decrypt(ctx, in, out, in_len); + } +} +#endif + STATIC mp_obj_t aes_make_new(const mp_obj_type_t *type, size_t n_args, size_t n_kw, const mp_obj_t *args) { mp_arg_check_num(n_args, n_kw, 2, 3, false); mp_obj_aes_t *o = m_new_obj(mp_obj_aes_t); o->base.type = type; o->block_mode = mp_obj_get_int(args[1]); - o->is_decrypt_key = 0; + o->key_type = AES_KEYTYPE_NONE; if (o->block_mode <= UCRYPTOLIB_MODE_MIN || o->block_mode >= UCRYPTOLIB_MODE_MAX) { mp_raise_ValueError("mode"); @@ -74,19 +118,23 @@ STATIC mp_obj_t aes_make_new(const mp_obj_type_t *type, size_t n_args, size_t n_ mp_buffer_info_t keyinfo; mp_get_buffer_raise(args[0], &keyinfo, MP_BUFFER_READ); + if (32 != keyinfo.len && 16 != keyinfo.len) { + mp_raise_ValueError("bad key length"); + } mp_buffer_info_t ivinfo; ivinfo.buf = NULL; if (n_args > 2 && args[2] != mp_const_none) { mp_get_buffer_raise(args[2], &ivinfo, MP_BUFFER_READ); + + if (16 != ivinfo.len) { + mp_raise_ValueError("bad iv length"); + } + } else if (o->block_mode == UCRYPTOLIB_MODE_CBC) { + mp_raise_ValueError("iv required for MODE_CBC"); } - AES_MODE keysize = AES_MODE_128; - if (keyinfo.len == 32) { - keysize = AES_MODE_256; - } - - AES_set_key(&o->ctx, keyinfo.buf, ivinfo.buf, keysize); + aes_initial_set_key_impl(&o->ctx, keyinfo.buf, keyinfo.len, ivinfo.buf); return MP_OBJ_FROM_PTR(o); } @@ -94,10 +142,6 @@ STATIC mp_obj_t aes_make_new(const mp_obj_type_t *type, size_t n_args, size_t n_ STATIC mp_obj_t aes_process(size_t n_args, const mp_obj_t *args, bool encrypt) { mp_obj_aes_t *self = MP_OBJ_TO_PTR(args[0]); - if (encrypt && self->is_decrypt_key) { - mp_raise_TypeError("can't enc after dec"); - } - mp_obj_t in_buf = args[1]; mp_obj_t out_buf = MP_OBJ_NULL; if (n_args > 2) { @@ -118,7 +162,7 @@ STATIC mp_obj_t aes_process(size_t n_args, const mp_obj_t *args, bool encrypt) { if (out_buf != MP_OBJ_NULL) { mp_get_buffer_raise(out_buf, &out_bufinfo, MP_BUFFER_WRITE); if (out_bufinfo.len < in_bufinfo.len) { - mp_raise_ValueError("out blksize"); + mp_raise_ValueError("output buffer too small"); } out_buf_ptr = out_bufinfo.buf; } else { @@ -126,37 +170,25 @@ STATIC mp_obj_t aes_process(size_t n_args, const mp_obj_t *args, bool encrypt) { out_buf_ptr = (uint8_t*)vstr.buf; } - if (!encrypt && !self->is_decrypt_key) { - AES_convert_key(&self->ctx); - self->is_decrypt_key = 1; + if (AES_KEYTYPE_NONE == self->key_type) { + aes_final_set_key_impl(&self->ctx, encrypt); + self->key_type = encrypt ? AES_KEYTYPE_ENC : AES_KEYTYPE_DEC; + } else { + if ((encrypt && self->key_type == AES_KEYTYPE_DEC) || + (!encrypt && self->key_type == AES_KEYTYPE_ENC)) { + + mp_raise_ValueError("can't use same aes object for encrypt & decrypt"); + } } if (self->block_mode == UCRYPTOLIB_MODE_ECB) { - uint8_t *in = in_bufinfo.buf, *out = out_buf_ptr; - uint8_t *top = in + in_bufinfo.len; - for (; in < top; in += 16, out += 16) { - memcpy(out, in, 16); - // We assume that vstr.buf is uint32_t aligned - uint32_t *p = (uint32_t*)out; - // axTLS likes it weird and complicated with byteswaps - for (int i = 0; i < 4; i++) { - p[i] = MP_HTOBE32(p[i]); - } - if (encrypt) { - AES_encrypt(&self->ctx, p); - } else { - AES_decrypt(&self->ctx, p); - } - for (int i = 0; i < 4; i++) { - p[i] = MP_BE32TOH(p[i]); - } - } - } else { - if (encrypt) { - AES_cbc_encrypt(&self->ctx, in_bufinfo.buf, out_buf_ptr, in_bufinfo.len); - } else { - AES_cbc_decrypt(&self->ctx, in_bufinfo.buf, out_buf_ptr, in_bufinfo.len); + uint8_t *in = in_bufinfo.buf, *out = out_buf_ptr; + uint8_t *top = in + in_bufinfo.len; + for (; in < top; in += 16, out += 16) { + aes_process_ecb_impl(&self->ctx, in, out, encrypt); } + } else { + aes_process_cbc_impl(&self->ctx, in_bufinfo.buf, out_buf_ptr, in_bufinfo.len, encrypt); } if (out_buf != MP_OBJ_NULL) {