From 9766fddcdc29cb2b5a8657389b870703580a6e57 Mon Sep 17 00:00:00 2001 From: Damien George Date: Thu, 28 Dec 2017 14:02:36 +1100 Subject: [PATCH] py/mpz: Simplify handling of borrow and quo adjustment in mpn_div. The motivation behind this patch is to remove unreachable code in mpn_div. This unreachable code was added some time ago in 9a21d2e070c9ee0ef2c003f3a668e635c6ae4401, when a loop in mpn_div was copied and adjusted to work when mpz_dig_t was exactly half of the size of mpz_dbl_dig_t (a common case). The loop was copied correctly but it wasn't noticed at the time that the final part of the calculation of num-quo*den could be optimised, and hence unreachable code was left for a case that never occurred. The observation for the optimisation is that the initial value of quo in mpn_div is either exact or too large (never too small), and therefore the subtraction of quo*den from num may subtract exactly enough or too much (but never too little). Using this observation the part of the algorithm that handles the borrow value can be simplified, and most importantly this eliminates the unreachable code. The new code has been tested with DIG_SIZE=3 and DIG_SIZE=4 by dividing all possible combinations of non-negative integers with between 0 and 3 (inclusive) mpz digits. --- py/mpz.c | 114 +++++++++++++++++++++---------------------------------- py/mpz.h | 4 ++ 2 files changed, 48 insertions(+), 70 deletions(-) diff --git a/py/mpz.c b/py/mpz.c index 018a5454f8..78dee31325 100644 --- a/py/mpz.c +++ b/py/mpz.c @@ -537,83 +537,57 @@ STATIC void mpn_div(mpz_dig_t *num_dig, size_t *num_len, const mpz_dig_t *den_di // not to overflow the borrow variable. And the shifting of // borrow needs some special logic (it's a shift right with // round up). - - if (DIG_SIZE < 8 * sizeof(mpz_dbl_dig_t) / 2) { - const mpz_dig_t *d = den_dig; - mpz_dbl_dig_t d_norm = 0; - mpz_dbl_dig_signed_t borrow = 0; - - for (mpz_dig_t *n = num_dig - den_len; n < num_dig; ++n, ++d) { - d_norm = ((mpz_dbl_dig_t)*d << norm_shift) | (d_norm >> DIG_SIZE); - borrow += (mpz_dbl_dig_t)*n - (mpz_dbl_dig_t)quo * (d_norm & DIG_MASK); // will overflow if DIG_SIZE >= 8*sizeof(mpz_dbl_dig_t)/2 - *n = borrow & DIG_MASK; - borrow >>= DIG_SIZE; - } - borrow += *num_dig; // will overflow if DIG_SIZE >= 8*sizeof(mpz_dbl_dig_t)/2 - *num_dig = borrow & DIG_MASK; - borrow >>= DIG_SIZE; - - // adjust quotient if it is too big - for (; borrow != 0; --quo) { - d = den_dig; - d_norm = 0; - mpz_dbl_dig_t carry = 0; - for (mpz_dig_t *n = num_dig - den_len; n < num_dig; ++n, ++d) { - d_norm = ((mpz_dbl_dig_t)*d << norm_shift) | (d_norm >> DIG_SIZE); - carry += (mpz_dbl_dig_t)*n + (d_norm & DIG_MASK); - *n = carry & DIG_MASK; - carry >>= DIG_SIZE; - } - carry += *num_dig; - *num_dig = carry & DIG_MASK; - carry >>= DIG_SIZE; - - borrow += carry; - } - } else { // DIG_SIZE == 8 * sizeof(mpz_dbl_dig_t) / 2 - const mpz_dig_t *d = den_dig; - mpz_dbl_dig_t d_norm = 0; - mpz_dbl_dig_t borrow = 0; - - for (mpz_dig_t *n = num_dig - den_len; n < num_dig; ++n, ++d) { - d_norm = ((mpz_dbl_dig_t)*d << norm_shift) | (d_norm >> DIG_SIZE); - mpz_dbl_dig_t x = (mpz_dbl_dig_t)quo * (d_norm & DIG_MASK); - if (x >= *n || *n - x <= borrow) { - borrow += (mpz_dbl_dig_t)x - (mpz_dbl_dig_t)*n; - *n = (-borrow) & DIG_MASK; - borrow = (borrow >> DIG_SIZE) + ((borrow & DIG_MASK) == 0 ? 0 : 1); // shift-right with round-up - } else { - *n = ((mpz_dbl_dig_t)*n - (mpz_dbl_dig_t)x - (mpz_dbl_dig_t)borrow) & DIG_MASK; - borrow = 0; - } - } - if (borrow >= *num_dig) { - borrow -= (mpz_dbl_dig_t)*num_dig; - *num_dig = (-borrow) & DIG_MASK; + // + const mpz_dig_t *d = den_dig; + mpz_dbl_dig_t d_norm = 0; + mpz_dbl_dig_t borrow = 0; + for (mpz_dig_t *n = num_dig - den_len; n < num_dig; ++n, ++d) { + d_norm = ((mpz_dbl_dig_t)*d << norm_shift) | (d_norm >> DIG_SIZE); + mpz_dbl_dig_t x = (mpz_dbl_dig_t)quo * (d_norm & DIG_MASK); + #if DIG_SIZE < MPZ_DBL_DIG_SIZE / 2 + borrow += (mpz_dbl_dig_t)*n - x; // will overflow if DIG_SIZE >= MPZ_DBL_DIG_SIZE/2 + *n = borrow & DIG_MASK; + borrow = (mpz_dbl_dig_signed_t)borrow >> DIG_SIZE; + #else // DIG_SIZE == MPZ_DBL_DIG_SIZE / 2 + if (x >= *n || *n - x <= borrow) { + borrow += x - (mpz_dbl_dig_t)*n; + *n = (-borrow) & DIG_MASK; borrow = (borrow >> DIG_SIZE) + ((borrow & DIG_MASK) == 0 ? 0 : 1); // shift-right with round-up } else { - *num_dig = (*num_dig - borrow) & DIG_MASK; + *n = ((mpz_dbl_dig_t)*n - x - borrow) & DIG_MASK; borrow = 0; } + #endif + } - // adjust quotient if it is too big - for (; borrow != 0; --quo) { - d = den_dig; - d_norm = 0; - mpz_dbl_dig_t carry = 0; - for (mpz_dig_t *n = num_dig - den_len; n < num_dig; ++n, ++d) { - d_norm = ((mpz_dbl_dig_t)*d << norm_shift) | (d_norm >> DIG_SIZE); - carry += (mpz_dbl_dig_t)*n + (d_norm & DIG_MASK); - *n = carry & DIG_MASK; - carry >>= DIG_SIZE; - } - carry += (mpz_dbl_dig_t)*num_dig; - *num_dig = carry & DIG_MASK; + #if DIG_SIZE < MPZ_DBL_DIG_SIZE / 2 + // Borrow was negative in the above for-loop, make it positive for next if-block. + borrow = -borrow; + #endif + + // At this point we have either: + // + // 1. quo was the correct value and the most-sig-digit of num is exactly + // cancelled by borrow (borrow == *num_dig). In this case there is + // nothing more to do. + // + // 2. quo was too large, we subtracted too many den from num, and the + // most-sig-digit of num is 1 less than borrow (borrow == *num_dig + 1). + // In this case we must reduce quo and add back den to num until the + // carry from this operation cancels out the borrow. + // + borrow -= *num_dig; + for (; borrow != 0; --quo) { + d = den_dig; + d_norm = 0; + mpz_dbl_dig_t carry = 0; + for (mpz_dig_t *n = num_dig - den_len; n < num_dig; ++n, ++d) { + d_norm = ((mpz_dbl_dig_t)*d << norm_shift) | (d_norm >> DIG_SIZE); + carry += (mpz_dbl_dig_t)*n + (d_norm & DIG_MASK); + *n = carry & DIG_MASK; carry >>= DIG_SIZE; - - //assert(borrow >= carry); // enable this to check the logic - borrow -= carry; } + borrow -= carry; } // store this digit of the quotient diff --git a/py/mpz.h b/py/mpz.h index e2d0c30aac..3c36cac66b 100644 --- a/py/mpz.h +++ b/py/mpz.h @@ -55,18 +55,22 @@ #endif #if MPZ_DIG_SIZE > 16 +#define MPZ_DBL_DIG_SIZE (64) typedef uint32_t mpz_dig_t; typedef uint64_t mpz_dbl_dig_t; typedef int64_t mpz_dbl_dig_signed_t; #elif MPZ_DIG_SIZE > 8 +#define MPZ_DBL_DIG_SIZE (32) typedef uint16_t mpz_dig_t; typedef uint32_t mpz_dbl_dig_t; typedef int32_t mpz_dbl_dig_signed_t; #elif MPZ_DIG_SIZE > 4 +#define MPZ_DBL_DIG_SIZE (16) typedef uint8_t mpz_dig_t; typedef uint16_t mpz_dbl_dig_t; typedef int16_t mpz_dbl_dig_signed_t; #else +#define MPZ_DBL_DIG_SIZE (8) typedef uint8_t mpz_dig_t; typedef uint8_t mpz_dbl_dig_t; typedef int8_t mpz_dbl_dig_signed_t;