From 6b3bc4d8c5bb3eacf1076417074bccdca73971d5 Mon Sep 17 00:00:00 2001 From: Angus Gratton Date: Tue, 20 Sep 2016 21:02:07 +1000 Subject: [PATCH] hwcrypto bignum: Implement multiplication modulo Fixes case where hardware bignum multiplication fails due to either operand >2048 bits. --- components/mbedtls/port/esp_bignum.c | 245 ++++++++++++++++-- .../mbedtls/port/include/mbedtls/esp_config.h | 4 +- 2 files changed, 223 insertions(+), 26 deletions(-) diff --git a/components/mbedtls/port/esp_bignum.c b/components/mbedtls/port/esp_bignum.c index caae0161f0..d6b79e32f5 100644 --- a/components/mbedtls/port/esp_bignum.c +++ b/components/mbedtls/port/esp_bignum.c @@ -52,6 +52,140 @@ static void esp_mpi_release_hardware( void ) _lock_release(&mpi_lock); } +/* Given a & b, determine u & v such that + + gcd(a,b) = d = au - bv + + Underlying algorithm comes from: + http://www.ucl.ac.uk/~ucahcjm/combopt/ext_gcd_python_programs.pdf + http://www.hackersdelight.org/hdcodetxt/mont64.c.txt + */ +static void extended_binary_gcd(const mbedtls_mpi *a, const mbedtls_mpi *b, + mbedtls_mpi *u, mbedtls_mpi *v) +{ + mbedtls_mpi ta, tb; + + mbedtls_mpi_init(&ta); + mbedtls_mpi_copy(&ta, a); + mbedtls_mpi_init(&tb); + mbedtls_mpi_copy(&tb, b); + + mbedtls_mpi_lset(u, 1); + mbedtls_mpi_lset(v, 0); + + /* Loop invariant: + ta = u*2*a - v*b. 
*/ + while (mbedtls_mpi_cmp_int(&ta, 0) != 0) { + mbedtls_mpi_shift_r(&ta, 1); + if (mbedtls_mpi_get_bit(u, 0) == 0) { + // Remove common factor of 2 in u & v + mbedtls_mpi_shift_r(u, 1); + mbedtls_mpi_shift_r(v, 1); + } + else { + /* u = (u + b) >> 1 */ + mbedtls_mpi_add_mpi(u, u, b); + mbedtls_mpi_shift_r(u, 1); + /* v = (v >> 1) + a */ + mbedtls_mpi_shift_r(v, 1); + mbedtls_mpi_add_mpi(v, v, a); + } + } + mbedtls_mpi_free(&ta); + mbedtls_mpi_free(&tb); + + /* u = u * 2, so 1 = u*a - v*b */ + mbedtls_mpi_shift_l(u, 1); +} + +/* inner part of MPI modular multiply, after Rinv & Mprime are calculated */ +static int mpi_mul_mpi_mod_inner(mbedtls_mpi *X, const mbedtls_mpi *A, const mbedtls_mpi *B, const mbedtls_mpi *M, mbedtls_mpi *Rinv, uint32_t Mprime, size_t num_words) +{ + int ret; + mbedtls_mpi TA, TB; + size_t num_bits = num_words * 32; + + mbedtls_mpi_grow(Rinv, num_words); + + /* TODO: fill memory blocks directly so this isn't needed */ + mbedtls_mpi_init(&TA); + mbedtls_mpi_copy(&TA, A); + mbedtls_mpi_grow(&TA, num_words); + A = &TA; + mbedtls_mpi_init(&TB); + mbedtls_mpi_copy(&TB, B); + mbedtls_mpi_grow(&TB, num_words); + B = &TB; + + esp_mpi_acquire_hardware(); + + if(ets_bigint_mod_mult_prepare(A->p, B->p, M->p, Mprime, + Rinv->p, num_bits, false)) { + mbedtls_mpi_grow(X, num_words); + ets_bigint_wait_finish(); + if(ets_bigint_mod_mult_getz(M->p, X->p, num_bits)) { + X->s = A->s * B->s; + ret = 0; + } else { + printf("ets_bigint_mod_mult_getz failed\n"); + ret = MBEDTLS_ERR_MPI_BAD_INPUT_DATA; + } + } else { + printf("ets_bigint_mod_mult_prepare failed\n"); + ret = MBEDTLS_ERR_MPI_BAD_INPUT_DATA; + } + esp_mpi_release_hardware(); + + /* unclear why this is necessary, but the result seems + to come back rotated 32 bits to the right... 
*/ + uint32_t last_word = X->p[num_words-1]; + X->p[num_words-1] = 0; + mbedtls_mpi_shift_l(X, 32); + X->p[0] = last_word; + + mbedtls_mpi_free(&TA); + mbedtls_mpi_free(&TB); + + return ret; +} + +/* X = (A * B) mod M + + Not an mbedTLS function + + num_bits guaranteed to be a multiple of 512 already. + + TODO: ensure M is odd + */ +int esp_mpi_mul_mpi_mod(mbedtls_mpi *X, const mbedtls_mpi *A, const mbedtls_mpi *B, const mbedtls_mpi *M, size_t num_bits) +{ + int ret = 0; + mbedtls_mpi RR, Rinv, Mprime; + uint32_t Mprime_int; + size_t num_words = num_bits / 32; + + /* Rinv & Mprime are calculated via extended binary gcd + algorithm, see references on extended_binary_gcd above. + */ + mbedtls_mpi_init(&Rinv); + mbedtls_mpi_init(&RR); + mbedtls_mpi_set_bit(&RR, num_bits+32, 1); + mbedtls_mpi_init(&Mprime); + extended_binary_gcd(&RR, M, &Rinv, &Mprime); + + /* M' is mod 2^32 */ + Mprime_int = Mprime.p[0]; + + ret = mpi_mul_mpi_mod_inner(X, A, B, M, &Rinv, Mprime_int, num_words); + + mbedtls_mpi_free(&RR); + mbedtls_mpi_free(&Mprime); + mbedtls_mpi_free(&Rinv); + + return ret; +} + + /* * Helper for mbedtls_mpi multiplication * copied/trimmed from mbedtls bignum.c @@ -223,6 +357,53 @@ static inline size_t hardware_words_needed(const mbedtls_mpi *mpi) return res; } + +/* Special-case multiply, where we use hardware montgomery mod + multiplication to solve the case where A or B are >2048 bits so + can't do standard multiplication. + + the modulus here is chosen with M=(2^num_bits-1) + to guarantee the output isn't actually modulo anything. 
This means + we don't need to calculate M' and Rinv, they are predictable + as follows: + M' = 1 + Rinv = (1 << (num_bits - 32)) + + (See RSA Accelerator section in Technical Reference for derivation + of M', Rinv) +*/ +static int esp_mpi_mult_mpi_failover_mod_mult(mbedtls_mpi *X, const mbedtls_mpi *A, const mbedtls_mpi *B, size_t num_words) + { + mbedtls_mpi M, Rinv; + int ret; + size_t mprime; + size_t num_bits = num_words * 32; + + mbedtls_mpi_init(&M); + mbedtls_mpi_init(&Rinv); + + /* TODO: it may be faster to just use 4096-bit arithmetic every time, + and make these constants rather than runtime + derived. */ + /* M = (2^num_bits)-1 */ + mbedtls_mpi_grow(&M, num_words); + for(int i = 0; i < num_words*32; i++) { + mbedtls_mpi_set_bit(&M, i, 1); + } + + /* Rinv = 2^(num_bits-32) */ + mbedtls_mpi_grow(&Rinv, num_words); + mbedtls_mpi_set_bit(&Rinv, num_bits - 32, 1); + + mprime = 1; + + ret = mpi_mul_mpi_mod_inner(X, A, B, &M, &Rinv, mprime, num_words); + + mbedtls_mpi_free(&M); + mbedtls_mpi_free(&Rinv); + return ret; + } + int mbedtls_mpi_mul_mpi( mbedtls_mpi *X, const mbedtls_mpi *A, const mbedtls_mpi *B ) { int ret = -1; @@ -236,6 +417,8 @@ int mbedtls_mpi_mul_mpi( mbedtls_mpi *X, const mbedtls_mpi *A, const mbedtls_mpi words_a = hardware_words_needed(A); words_b = hardware_words_needed(B); + words_mult = (words_a > words_b ? words_a : words_b); + /* Take a copy of A if either X == A OR if A isn't long enough to hold the number of words needed for hardware. @@ -248,47 +431,63 @@ int mbedtls_mpi_mul_mpi( mbedtls_mpi *X, const mbedtls_mpi *A, const mbedtls_mpi RAM. But we need to reimplement ets_bigint_mult_prepare() in software for this. 
*/ - if( X == A || A->n < words_a) { + if( X == A || A->n < words_mult) { MBEDTLS_MPI_CHK( mbedtls_mpi_copy( &TA, A ) ); - MBEDTLS_MPI_CHK( mbedtls_mpi_grow( &TA, words_a) ); + MBEDTLS_MPI_CHK( mbedtls_mpi_grow( &TA, words_mult) ); A = &TA; } /* Same for B */ - if( X == B || B->n < words_b ) { + if( X == B || B->n < words_mult ) { MBEDTLS_MPI_CHK( mbedtls_mpi_copy( &TB, B ) ); - MBEDTLS_MPI_CHK( mbedtls_mpi_grow( &TB, words_b) ); + MBEDTLS_MPI_CHK( mbedtls_mpi_grow( &TB, words_mult) ); B = &TB; } /* Result X has to have room for double the larger operand */ - words_mult = (words_a > words_b ? words_a : words_b); words_x = words_mult * 2; MBEDTLS_MPI_CHK( mbedtls_mpi_grow( X, words_x ) ); /* TODO: check if lset here is necessary, hardware should zero */ MBEDTLS_MPI_CHK( mbedtls_mpi_lset( X, 0 ) ); - esp_mpi_acquire_hardware(); + /* If either operand is over 2048 bits, we can't use the standard hardware multiplier + (it assumes result is double longest operand, and result is max 4096 bits.) + However, we can fail over to mod_mult for up to 4096 bits. + */ if(words_mult * 32 > 2048) { - printf("WARNING: %d bit operands (%d bits * %d bits) too large for hardware unit\n", words_mult * 32, mbedtls_mpi_bitlen(A), mbedtls_mpi_bitlen(B)); - } - - if (ets_bigint_mult_prepare(A->p, B->p, words_mult * 32)) { - ets_bigint_wait_finish(); - /* NB: argument to bigint_mult_getz is length of inputs, double this number (words_x) is - copied to output X->p. 
+ /* TODO: check if there's an overflow condition if words_a & words_b are both + the bit lengths of the operands, result could be 1 bit longer */ - if (ets_bigint_mult_getz(X->p, words_mult * 32) == true) { - ret = 0; - } else { - printf("ets_bigint_mult_getz failed\n"); - } - } else{ - printf("Baseline multiplication failed\n"); - } - esp_mpi_release_hardware(); + if((words_a + words_b) * 32 > 4096) { + printf("ERROR: %d bit operands (%d bits * %d bits) too large for hardware unit\n", words_mult * 32, mbedtls_mpi_bitlen(A), mbedtls_mpi_bitlen(B)); + ret = MBEDTLS_ERR_MPI_NOT_ACCEPTABLE; + } + else { + ret = esp_mpi_mult_mpi_failover_mod_mult(X, A, B, words_a + words_b); + } + } + else { - X->s = A->s * B->s; + /* normal mpi multiplication */ + esp_mpi_acquire_hardware(); + if (ets_bigint_mult_prepare(A->p, B->p, words_mult * 32)) { + ets_bigint_wait_finish(); + /* NB: argument to bigint_mult_getz is length of inputs, double this number (words_x) is + copied to output X->p. + */ + if (ets_bigint_mult_getz(X->p, words_mult * 32) == true) { + X->s = A->s * B->s; + ret = 0; + } else { + printf("ets_bigint_mult_getz failed\n"); + ret = MBEDTLS_ERR_MPI_NOT_ACCEPTABLE; + } + } else{ + printf("Baseline multiplication failed\n"); + ret = MBEDTLS_ERR_MPI_NOT_ACCEPTABLE; + } + esp_mpi_release_hardware(); + } cleanup: mbedtls_mpi_free( &TB ); mbedtls_mpi_free( &TA ); diff --git a/components/mbedtls/port/include/mbedtls/esp_config.h b/components/mbedtls/port/include/mbedtls/esp_config.h index 4ddd9821c4..e4f4af271a 100644 --- a/components/mbedtls/port/include/mbedtls/esp_config.h +++ b/components/mbedtls/port/include/mbedtls/esp_config.h @@ -250,10 +250,8 @@ /* The following MPI (bignum) functions have ESP32 hardware support, Uncommenting these macros will use the hardware-accelerated implementations. - - Disabled as number of limbs limited by bug. Internal TW#7112. */ -#define MBEDTLS_MPI_EXP_MOD_ALT +//#define MBEDTLS_MPI_EXP_MOD_ALT #define MBEDTLS_MPI_MUL_MPI_ALT /**