py/mpz: Simplify handling of borrow and quo adjustment in mpn_div.

The motivation behind this patch is to remove unreachable code in mpn_div.
This unreachable code was added some time ago in
9a21d2e070, when a loop in mpn_div was copied
and adjusted to work when mpz_dig_t was exactly half of the size of
mpz_dbl_dig_t (a common case).  The loop was copied correctly but it wasn't
noticed at the time that the final part of the calculation of num-quo*den
could be optimised, and hence unreachable code was left for a case that
never occurred.

The observation for the optimisation is that the initial value of quo in
mpn_div is either exact or too large (never too small), and therefore the
subtraction of quo*den from num may subtract exactly enough or too much
(but never too little).  Using this observation the part of the algorithm
that handles the borrow value can be simplified, and most importantly this
eliminates the unreachable code.

The new code has been tested with DIG_SIZE=3 and DIG_SIZE=4 by dividing all
possible combinations of non-negative integers with between 0 and 3
(inclusive) mpz digits.
pull/3514/merge
Damien George 2017-12-28 14:02:36 +11:00
rodzic c7cb1dfcb9
commit 9766fddcdc
2 zmienionych plików z 48 dodań i 70 usunięć

114
py/mpz.c
Wyświetl plik

@ -537,83 +537,57 @@ STATIC void mpn_div(mpz_dig_t *num_dig, size_t *num_len, const mpz_dig_t *den_di
// not to overflow the borrow variable. And the shifting of
// borrow needs some special logic (it's a shift right with
// round up).
if (DIG_SIZE < 8 * sizeof(mpz_dbl_dig_t) / 2) {
const mpz_dig_t *d = den_dig;
mpz_dbl_dig_t d_norm = 0;
mpz_dbl_dig_signed_t borrow = 0;
for (mpz_dig_t *n = num_dig - den_len; n < num_dig; ++n, ++d) {
d_norm = ((mpz_dbl_dig_t)*d << norm_shift) | (d_norm >> DIG_SIZE);
borrow += (mpz_dbl_dig_t)*n - (mpz_dbl_dig_t)quo * (d_norm & DIG_MASK); // will overflow if DIG_SIZE >= 8*sizeof(mpz_dbl_dig_t)/2
*n = borrow & DIG_MASK;
borrow >>= DIG_SIZE;
}
borrow += *num_dig; // will overflow if DIG_SIZE >= 8*sizeof(mpz_dbl_dig_t)/2
*num_dig = borrow & DIG_MASK;
borrow >>= DIG_SIZE;
// adjust quotient if it is too big
for (; borrow != 0; --quo) {
d = den_dig;
d_norm = 0;
mpz_dbl_dig_t carry = 0;
for (mpz_dig_t *n = num_dig - den_len; n < num_dig; ++n, ++d) {
d_norm = ((mpz_dbl_dig_t)*d << norm_shift) | (d_norm >> DIG_SIZE);
carry += (mpz_dbl_dig_t)*n + (d_norm & DIG_MASK);
*n = carry & DIG_MASK;
carry >>= DIG_SIZE;
}
carry += *num_dig;
*num_dig = carry & DIG_MASK;
carry >>= DIG_SIZE;
borrow += carry;
}
} else { // DIG_SIZE == 8 * sizeof(mpz_dbl_dig_t) / 2
const mpz_dig_t *d = den_dig;
mpz_dbl_dig_t d_norm = 0;
mpz_dbl_dig_t borrow = 0;
for (mpz_dig_t *n = num_dig - den_len; n < num_dig; ++n, ++d) {
d_norm = ((mpz_dbl_dig_t)*d << norm_shift) | (d_norm >> DIG_SIZE);
mpz_dbl_dig_t x = (mpz_dbl_dig_t)quo * (d_norm & DIG_MASK);
if (x >= *n || *n - x <= borrow) {
borrow += (mpz_dbl_dig_t)x - (mpz_dbl_dig_t)*n;
*n = (-borrow) & DIG_MASK;
borrow = (borrow >> DIG_SIZE) + ((borrow & DIG_MASK) == 0 ? 0 : 1); // shift-right with round-up
} else {
*n = ((mpz_dbl_dig_t)*n - (mpz_dbl_dig_t)x - (mpz_dbl_dig_t)borrow) & DIG_MASK;
borrow = 0;
}
}
if (borrow >= *num_dig) {
borrow -= (mpz_dbl_dig_t)*num_dig;
*num_dig = (-borrow) & DIG_MASK;
//
const mpz_dig_t *d = den_dig;
mpz_dbl_dig_t d_norm = 0;
mpz_dbl_dig_t borrow = 0;
for (mpz_dig_t *n = num_dig - den_len; n < num_dig; ++n, ++d) {
d_norm = ((mpz_dbl_dig_t)*d << norm_shift) | (d_norm >> DIG_SIZE);
mpz_dbl_dig_t x = (mpz_dbl_dig_t)quo * (d_norm & DIG_MASK);
#if DIG_SIZE < MPZ_DBL_DIG_SIZE / 2
borrow += (mpz_dbl_dig_t)*n - x; // will overflow if DIG_SIZE >= MPZ_DBL_DIG_SIZE/2
*n = borrow & DIG_MASK;
borrow = (mpz_dbl_dig_signed_t)borrow >> DIG_SIZE;
#else // DIG_SIZE == MPZ_DBL_DIG_SIZE / 2
if (x >= *n || *n - x <= borrow) {
borrow += x - (mpz_dbl_dig_t)*n;
*n = (-borrow) & DIG_MASK;
borrow = (borrow >> DIG_SIZE) + ((borrow & DIG_MASK) == 0 ? 0 : 1); // shift-right with round-up
} else {
*num_dig = (*num_dig - borrow) & DIG_MASK;
*n = ((mpz_dbl_dig_t)*n - x - borrow) & DIG_MASK;
borrow = 0;
}
#endif
}
// adjust quotient if it is too big
for (; borrow != 0; --quo) {
d = den_dig;
d_norm = 0;
mpz_dbl_dig_t carry = 0;
for (mpz_dig_t *n = num_dig - den_len; n < num_dig; ++n, ++d) {
d_norm = ((mpz_dbl_dig_t)*d << norm_shift) | (d_norm >> DIG_SIZE);
carry += (mpz_dbl_dig_t)*n + (d_norm & DIG_MASK);
*n = carry & DIG_MASK;
carry >>= DIG_SIZE;
}
carry += (mpz_dbl_dig_t)*num_dig;
*num_dig = carry & DIG_MASK;
#if DIG_SIZE < MPZ_DBL_DIG_SIZE / 2
// Borrow was negative in the above for-loop, make it positive for next if-block.
borrow = -borrow;
#endif
// At this point we have either:
//
// 1. quo was the correct value and the most-sig-digit of num is exactly
// cancelled by borrow (borrow == *num_dig). In this case there is
// nothing more to do.
//
// 2. quo was too large, we subtracted too many den from num, and the
// most-sig-digit of num is 1 less than borrow (borrow == *num_dig + 1).
// In this case we must reduce quo and add back den to num until the
// carry from this operation cancels out the borrow.
//
borrow -= *num_dig;
for (; borrow != 0; --quo) {
d = den_dig;
d_norm = 0;
mpz_dbl_dig_t carry = 0;
for (mpz_dig_t *n = num_dig - den_len; n < num_dig; ++n, ++d) {
d_norm = ((mpz_dbl_dig_t)*d << norm_shift) | (d_norm >> DIG_SIZE);
carry += (mpz_dbl_dig_t)*n + (d_norm & DIG_MASK);
*n = carry & DIG_MASK;
carry >>= DIG_SIZE;
//assert(borrow >= carry); // enable this to check the logic
borrow -= carry;
}
borrow -= carry;
}
// store this digit of the quotient

Wyświetl plik

@ -55,18 +55,22 @@
#endif
#if MPZ_DIG_SIZE > 16
#define MPZ_DBL_DIG_SIZE (64)
typedef uint32_t mpz_dig_t;
typedef uint64_t mpz_dbl_dig_t;
typedef int64_t mpz_dbl_dig_signed_t;
#elif MPZ_DIG_SIZE > 8
#define MPZ_DBL_DIG_SIZE (32)
typedef uint16_t mpz_dig_t;
typedef uint32_t mpz_dbl_dig_t;
typedef int32_t mpz_dbl_dig_signed_t;
#elif MPZ_DIG_SIZE > 4
#define MPZ_DBL_DIG_SIZE (16)
typedef uint8_t mpz_dig_t;
typedef uint16_t mpz_dbl_dig_t;
typedef int16_t mpz_dbl_dig_signed_t;
#else
#define MPZ_DBL_DIG_SIZE (8)
typedef uint8_t mpz_dig_t;
typedef uint8_t mpz_dbl_dig_t;
typedef int8_t mpz_dbl_dig_signed_t;