diff --git a/py/mpz.c b/py/mpz.c index 9e60fc50d0..9c42878ff8 100644 --- a/py/mpz.c +++ b/py/mpz.c @@ -218,6 +218,38 @@ STATIC uint mpn_and(mpz_dig_t *idig, const mpz_dig_t *jdig, uint jlen, const mpz return idig + 1 - oidig; } +/* computes i = j & -k = j & (~k + 1) + returns number of digits in i + assumes enough memory in i; assumes normalised j, k + can have i, j, k pointing to same memory +*/ +STATIC uint mpn_and_neg(mpz_dig_t *idig, const mpz_dig_t *jdig, uint jlen, const mpz_dig_t *kdig, uint klen) { + mpz_dig_t *oidig = idig; + mpz_dbl_dig_t carry = 1; + + for (; jlen > 0 && klen > 0; --jlen, --klen, ++idig, ++jdig, ++kdig) { + carry += *kdig ^ DIG_MASK; + *idig = (*jdig & carry) & DIG_MASK; + carry >>= DIG_SIZE; + } + + for (; jlen > 0; --jlen, ++idig, ++jdig) { + carry += DIG_MASK; + *idig = (*jdig & carry) & DIG_MASK; + carry >>= DIG_SIZE; + } + + if (carry != 0) { + *idig = carry; + } else { + // remove trailing zeros + for (--idig; idig >= oidig && *idig == 0; --idig) { + } + } + + return idig + 1 - oidig; +} + /* computes i = j | k returns number of digits in i assumes enough memory in i; assumes normalised j, k; assumes jlen >= klen @@ -896,24 +928,35 @@ void mpz_sub_inpl(mpz_t *dest, const mpz_t *lhs, const mpz_t *rhs) { can have dest, lhs, rhs the same */ void mpz_and_inpl(mpz_t *dest, const mpz_t *lhs, const mpz_t *rhs) { - // make sure lhs has the most digits - if (lhs->len < rhs->len) { - const mpz_t *temp = lhs; - lhs = rhs; - rhs = temp; - } - if (lhs->neg == rhs->neg) { - mpz_need_dig(dest, rhs->len); - dest->len = mpn_and(dest->dig, lhs->dig, lhs->len, rhs->dig, rhs->len); + if (lhs->neg == 0) { + // make sure lhs has the most digits + if (lhs->len < rhs->len) { + const mpz_t *temp = lhs; + lhs = rhs; + rhs = temp; + } + // do the and'ing + mpz_need_dig(dest, rhs->len); + dest->len = mpn_and(dest->dig, lhs->dig, lhs->len, rhs->dig, rhs->len); + dest->neg = 0; + } else { + // TODO both args are negative + assert(0); + } } else { - mpz_need_dig(dest, lhs->len); - // TODO - assert(0); -// dest->len = mpn_and_neg(dest->dig, lhs->dig, lhs->len, rhs->dig, rhs->len); + // args have different sign + // make sure lhs is the positive arg + if (rhs->neg == 0) { + const mpz_t *temp = lhs; + lhs = rhs; + rhs = temp; + } + mpz_need_dig(dest, lhs->len + 1); + dest->len = mpn_and_neg(dest->dig, lhs->dig, lhs->len, rhs->dig, rhs->len); + assert(dest->len <= dest->alloc); + dest->neg = 0; } - - dest->neg = lhs->neg; } /* computes dest = lhs | rhs diff --git a/tests/basics/int-big-and.py b/tests/basics/int-big-and.py index 75fbd52884..a48848dbf0 100644 --- a/tests/basics/int-big-and.py +++ b/tests/basics/int-big-and.py @@ -2,7 +2,23 @@ print(0 & (1 << 80)) print(0 & (1 << 80) == 0) print(bool(0 & (1 << 80))) -#a = 0xfffffffffffffffffffffffffffff -#print(a & (1 << 80)) -#print((a & (1 << 80)) >> 80) -#print((a & (1 << 80)) >> 80 == 1) +a = 0xfffffffffffffffffffffffffffff +print(a & (1 << 80)) +print((a & (1 << 80)) >> 80) +print((a & (1 << 80)) >> 80 == 1) + +# test negative on rhs +a = 123456789012345678901234567890 +print(a & -1) +print(a & -2) +print(a & -2345678901234567890123456789) +print(a & (-a)) + +# test negative on lhs +a = 123456789012345678901234567890 +print(-1 & a) +print(-2 & a) +print(-2345678901234567890123456789 & a) +print((-a) & a) +print((-a) & 0xffffffff) +print((-a) & 0xffffffffffffffffffffffffffffffff)