use crate::std_alloc::{Cow, Vec};
use core::cmp;
use core::cmp::Ordering::{self, Equal, Greater, Less};
use core::iter::repeat;
use core::mem;
use num_traits::{One, PrimInt, Zero};
#[cfg(all(use_addcarry, target_arch = "x86_64"))]
use core::arch::x86_64 as arch;
#[cfg(all(use_addcarry, target_arch = "x86"))]
use core::arch::x86 as arch;
use crate::biguint::biguint_from_vec;
use crate::biguint::BigUint;
use crate::bigint::BigInt;
use crate::bigint::Sign;
use crate::bigint::Sign::{Minus, NoSign, Plus};
use crate::big_digit::{self, BigDigit, DoubleBigDigit};
#[cfg(not(use_addcarry))]
use crate::big_digit::SignedDoubleBigDigit;
/// Computes `a + b + carry` with the x86 `_addcarry_u64` intrinsic, storing the low
/// 64 bits of the sum in `out` and returning the carry-out flag.
#[cfg(all(use_addcarry, u64_digit))]
#[inline]
fn adc(carry: u8, a: u64, b: u64, out: &mut u64) -> u8 {
    // SAFETY: the intrinsic reads only its integer arguments and writes through `out`,
    // which is a valid exclusive reference. `arch` is `core::arch::x86` or
    // `core::arch::x86_64` depending on the target (see the cfg'd imports).
    // NOTE(review): assumes the build only enables `use_addcarry` + `u64_digit` where
    // this intrinsic is available.
    unsafe { arch::_addcarry_u64(carry, a, b, out) }
}
/// Computes `a + b + carry` with the x86 `_addcarry_u32` intrinsic, storing the low
/// 32 bits of the sum in `out` and returning the carry-out flag.
#[cfg(all(use_addcarry, not(u64_digit)))]
#[inline]
fn adc(carry: u8, a: u32, b: u32, out: &mut u32) -> u8 {
    // SAFETY: the intrinsic reads only its integer arguments and writes through `out`,
    // which is a valid exclusive reference. `arch` is `core::arch::x86` or
    // `core::arch::x86_64` depending on the target (see the cfg'd imports).
    // NOTE(review): assumes the build only enables `use_addcarry` where this intrinsic
    // is available.
    unsafe { arch::_addcarry_u32(carry, a, b, out) }
}
/// Portable add-with-carry: writes the low digit of `a + b + carry` to `out` and
/// returns the high part (the carry out) as a `u8`.
#[cfg(not(use_addcarry))]
#[inline]
fn adc(carry: u8, a: BigDigit, b: BigDigit, out: &mut BigDigit) -> u8 {
    let widen = |digit: BigDigit| DoubleBigDigit::from(digit);
    let total = widen(a) + widen(b) + DoubleBigDigit::from(carry);
    let carry_out = total >> big_digit::BITS;
    *out = total as BigDigit;
    carry_out as u8
}
/// Computes `a - b - borrow` with the x86 `_subborrow_u64` intrinsic, storing the low
/// 64 bits of the difference in `out` and returning the borrow-out flag.
#[cfg(all(use_addcarry, u64_digit))]
#[inline]
fn sbb(borrow: u8, a: u64, b: u64, out: &mut u64) -> u8 {
    // SAFETY: the intrinsic reads only its integer arguments and writes through `out`,
    // which is a valid exclusive reference. NOTE(review): assumes the build only enables
    // `use_addcarry` + `u64_digit` where this intrinsic is available.
    unsafe { arch::_subborrow_u64(borrow, a, b, out) }
}
/// Computes `a - b - borrow` with the x86 `_subborrow_u32` intrinsic, storing the low
/// 32 bits of the difference in `out` and returning the borrow-out flag.
#[cfg(all(use_addcarry, not(u64_digit)))]
#[inline]
fn sbb(borrow: u8, a: u32, b: u32, out: &mut u32) -> u8 {
    // SAFETY: the intrinsic reads only its integer arguments and writes through `out`,
    // which is a valid exclusive reference. NOTE(review): assumes the build only enables
    // `use_addcarry` where this intrinsic is available.
    unsafe { arch::_subborrow_u32(borrow, a, b, out) }
}
/// Portable subtract-with-borrow: writes the low digit of `a - b - borrow` to `out`
/// and returns 1 if the true difference was negative, 0 otherwise.
#[cfg(not(use_addcarry))]
#[inline]
fn sbb(borrow: u8, a: BigDigit, b: BigDigit, out: &mut BigDigit) -> u8 {
    let minuend = SignedDoubleBigDigit::from(a);
    let subtrahend = SignedDoubleBigDigit::from(b) + SignedDoubleBigDigit::from(borrow);
    let diff = minuend - subtrahend;
    *out = diff as BigDigit;
    if diff < 0 {
        1
    } else {
        0
    }
}
/// Computes `*acc + a + b * c`, returns the low digit of that total and leaves the
/// high part in `*acc` as the carry for the next step.
#[inline]
pub(crate) fn mac_with_carry(
    a: BigDigit,
    b: BigDigit,
    c: BigDigit,
    acc: &mut DoubleBigDigit,
) -> BigDigit {
    let addend = DoubleBigDigit::from(a);
    let product = DoubleBigDigit::from(b) * DoubleBigDigit::from(c);
    let total = *acc + addend + product;
    *acc = total >> big_digit::BITS;
    total as BigDigit
}
/// Computes `*acc + a * b`, returns its low digit and stores the high part back into
/// `*acc` as the outgoing carry.
#[inline]
pub(crate) fn mul_with_carry(a: BigDigit, b: BigDigit, acc: &mut DoubleBigDigit) -> BigDigit {
    let total = *acc + DoubleBigDigit::from(a) * DoubleBigDigit::from(b);
    *acc = total >> big_digit::BITS;
    total as BigDigit
}
/// Divides the two-digit value `[hi, lo]` (hi is the high digit) by `divisor`,
/// returning `(quotient, remainder)`. Requires `hi < divisor` so the quotient fits
/// in one digit.
#[inline]
fn div_wide(hi: BigDigit, lo: BigDigit, divisor: BigDigit) -> (BigDigit, BigDigit) {
    debug_assert!(hi < divisor);

    let dividend = big_digit::to_doublebigdigit(hi, lo);
    let wide_divisor = DoubleBigDigit::from(divisor);
    let quotient = dividend / wide_divisor;
    let remainder = dividend % wide_divisor;
    (quotient as BigDigit, remainder as BigDigit)
}
/// Divides `[rem, digit]` by a divisor no larger than `HALF`, working one half-digit
/// at a time so every intermediate fits in a single `BigDigit`.
/// Returns `(quotient, remainder)`.
#[inline]
fn div_half(rem: BigDigit, digit: BigDigit, divisor: BigDigit) -> (BigDigit, BigDigit) {
    use crate::big_digit::{HALF, HALF_BITS};
    use num_integer::Integer;

    debug_assert!(rem < divisor && divisor <= HALF);

    let upper_half = digit >> HALF_BITS;
    let lower_half = digit & HALF;

    let (q_upper, mid_rem) = ((rem << HALF_BITS) | upper_half).div_rem(&divisor);
    let (q_lower, final_rem) = ((mid_rem << HALF_BITS) | lower_half).div_rem(&divisor);

    ((q_upper << HALF_BITS) | q_lower, final_rem)
}
/// Divides `a` by the single digit `b`, returning the normalized quotient and the
/// remainder. The quotient is written over `a`'s digits in place.
#[inline]
pub(crate) fn div_rem_digit(mut a: BigUint, b: BigDigit) -> (BigUint, BigDigit) {
    // Walk from the most significant digit down, threading the remainder through.
    let digits = a.data.iter_mut().rev();
    let rem = if b <= big_digit::HALF {
        digits.fold(0, |carried, d| {
            let (q, r) = div_half(carried, *d, b);
            *d = q;
            r
        })
    } else {
        digits.fold(0, |carried, d| {
            let (q, r) = div_wide(carried, *d, b);
            *d = q;
            r
        })
    };
    (a.normalized(), rem)
}
/// Returns `a % b` for a single-digit divisor `b`, without building the quotient.
#[inline]
pub(crate) fn rem_digit(a: &BigUint, b: BigDigit) -> BigDigit {
    // Most significant digit first; only the remainder is carried along.
    let digits = a.data.iter().rev();
    if b <= big_digit::HALF {
        digits.fold(0, |carried, &digit| div_half(carried, digit, b).1)
    } else {
        digits.fold(0, |carried, &digit| div_wide(carried, digit, b).1)
    }
}
#[inline]
pub(crate) fn __add2(a: &mut [BigDigit], b: &[BigDigit]) -> BigDigit {
debug_assert!(a.len() >= b.len());
let mut carry = 0;
let (a_lo, a_hi) = a.split_at_mut(b.len());
for (a, b) in a_lo.iter_mut().zip(b) {
carry = adc(carry, *a, *b, a);
}
if carry != 0 {
for a in a_hi {
carry = adc(carry, *a, 0, a);
if carry == 0 {
break;
}
}
}
carry as BigDigit
}
/// Adds `b` into `a` in place (little-endian digits).
///
/// The caller must make `a` wide enough that no carry escapes its top digit. This is
/// only checked by a debug assertion; in release builds an escaping carry is dropped.
pub(crate) fn add2(a: &mut [BigDigit], b: &[BigDigit]) {
    let carry = __add2(a, b);
    debug_assert!(carry == 0);
}
/// Subtracts `b` from `a` in place (little-endian digits).
///
/// # Panics
/// Panics if `b > a`: either the subtraction borrows out of `a`, or `b` has nonzero
/// digits beyond `a`'s length.
pub(crate) fn sub2(a: &mut [BigDigit], b: &[BigDigit]) {
    let shared = a.len().min(b.len());
    let (a_low, a_high) = a.split_at_mut(shared);
    let (b_low, b_high) = b.split_at(shared);

    let mut borrow = a_low
        .iter_mut()
        .zip(b_low)
        .fold(0, |br, (x, &y)| sbb(br, *x, y, x));

    // Propagate any outstanding borrow into the higher digits of `a`.
    for x in a_high {
        if borrow == 0 {
            break;
        }
        borrow = sbb(borrow, *x, 0, x);
    }

    assert!(
        borrow == 0 && b_high.iter().all(|&digit| digit == 0),
        "Cannot subtract b from a because b is larger than a."
    );
}
/// Computes `b = a - b` in place over equal-length little-endian slices and returns
/// the final borrow (nonzero when `b > a`).
#[inline]
pub(crate) fn __sub2rev(a: &[BigDigit], b: &mut [BigDigit]) -> u8 {
    debug_assert!(b.len() == a.len());
    a.iter()
        .zip(b)
        .fold(0, |borrow, (&ai, bi)| sbb(borrow, ai, *bi, bi))
}
/// Computes `b = a - b` in place, where `b` may be longer than `a`.
///
/// # Panics
/// Panics if `a` is longer than `b`, or if the result would be negative.
pub(crate) fn sub2rev(a: &[BigDigit], b: &mut [BigDigit]) {
    debug_assert!(b.len() >= a.len());

    let cut = a.len().min(b.len());
    let (a_lo, a_hi) = a.split_at(cut);
    let (b_lo, b_hi) = b.split_at_mut(cut);

    let borrow = __sub2rev(a_lo, b_lo);

    assert!(a_hi.is_empty());
    assert!(
        borrow == 0 && b_hi.iter().all(|&digit| digit == 0),
        "Cannot subtract b from a because b is larger than a."
    );
}
pub(crate) fn sub_sign(a: &[BigDigit], b: &[BigDigit]) -> (Sign, BigUint) {
let a = &a[..a.iter().rposition(|&x| x != 0).map_or(0, |i| i + 1)];
let b = &b[..b.iter().rposition(|&x| x != 0).map_or(0, |i| i + 1)];
match cmp_slice(a, b) {
Greater => {
let mut a = a.to_vec();
sub2(&mut a, b);
(Plus, biguint_from_vec(a))
}
Less => {
let mut b = b.to_vec();
sub2(&mut b, a);
(Minus, biguint_from_vec(b))
}
_ => (NoSign, Zero::zero()),
}
}
/// Multiply-accumulate by a single digit: `acc += b * c` (little-endian digits).
///
/// `acc` must be at least as long as `b` and wide enough to absorb the final carry.
/// Otherwise this panics, either in the slicing or in the trailing
/// "carry overflow during multiplication!" assertion.
pub(crate) fn mac_digit(acc: &mut [BigDigit], b: &[BigDigit], c: BigDigit) {
    if c == 0 {
        return;
    }

    let mut carry = 0;
    let (a_lo, a_hi) = acc.split_at_mut(b.len());

    for (a, &b) in a_lo.iter_mut().zip(b) {
        *a = mac_with_carry(*a, b, c, &mut carry);
    }

    let (carry_hi, carry_lo) = big_digit::from_doublebigdigit(carry);

    // Digits are little-endian, so the low half of the carry must come first.
    // NOTE: given the bounds of `mac_with_carry`, the carry should always fit in one
    // digit; the two-digit branch is defensive, but it must still use the right order.
    let final_carry = if carry_hi == 0 {
        __add2(a_hi, &[carry_lo])
    } else {
        __add2(a_hi, &[carry_lo, carry_hi])
    };
    assert_eq!(final_carry, 0, "carry overflow during multiplication!");
}
/// Computes `a -= b * c` in place over equal-length little-endian slices and returns
/// the borrow out of the top digit, i.e. how much still has to be subtracted from the
/// digit just above `a`.
fn sub_mul_digit_same_len(a: &mut [BigDigit], b: &[BigDigit], c: BigDigit) -> BigDigit {
    debug_assert!(a.len() == b.len());

    // The true running carry lies in [-MAX, 0]. It is stored offset by `MAX`
    // (offset_carry = carry + MAX) so that it fits in an unsigned digit.
    let mut offset_carry = big_digit::MAX;

    for (x, y) in a.iter_mut().zip(b) {
        // We want sum = x + carry - y * c, which can be negative. Adding
        // (MAX << BITS) keeps it in unsigned double-digit range:
        //   offset_sum = (MAX << BITS) + x + (offset_carry - MAX) - y * c
        let offset_sum = big_digit::to_doublebigdigit(big_digit::MAX, *x)
            - big_digit::MAX as DoubleBigDigit
            + offset_carry as DoubleBigDigit
            - *y as DoubleBigDigit * c as DoubleBigDigit;

        // The high half is the next carry (still offset by MAX); the low half is the
        // result digit.
        let (new_offset_carry, new_x) = big_digit::from_doublebigdigit(offset_sum);
        offset_carry = new_offset_carry;
        *x = new_x;
    }

    // Remove the offset: borrow = -carry.
    big_digit::MAX - offset_carry
}
/// Builds a non-negative `BigInt` from little-endian digits (copied).
fn bigint_from_slice(slice: &[BigDigit]) -> BigInt {
    let magnitude = biguint_from_vec(Vec::from(slice));
    magnitude.into()
}
/// Multi-digit multiply-accumulate: `acc += b * c` (all little-endian digits).
///
/// `acc` must be wide enough to hold the result. The helpers used here (`add2`,
/// `sub2`, `mac_digit`) panic or debug-assert otherwise. The algorithm is chosen from
/// the length of the shorter operand: schoolbook up to 32 digits, Karatsuba up to
/// 256 digits, Toom-3 beyond that.
fn mac3(acc: &mut [BigDigit], b: &[BigDigit], c: &[BigDigit]) {
    // x is the shorter operand, y the longer one.
    let (x, y) = if b.len() < c.len() { (b, c) } else { (c, b) };
    if x.len() <= 32 {
        // Schoolbook: add y * x[i], shifted left by i digits, for each digit of x.
        for (i, xi) in x.iter().enumerate() {
            mac_digit(&mut acc[i..], y, *xi);
        }
    } else if x.len() <= 256 {
        // Karatsuba. With digit base B and split point b:
        //   x = x1*B^b + x0,   y = y1*B^b + y0
        //   x*y = p2*B^(2b) + (p2 + p0 - p1)*B^b + p0
        // where p2 = x1*y1, p0 = x0*y0 and p1 = (x1 - x0)*(y1 - y0).
        let b = x.len() / 2;
        let (x0, x1) = x.split_at(b);
        let (y0, y1) = y.split_at(b);
        // One scratch buffer is reused for every partial product. x1 and y1 are the
        // larger halves, so sizing it from them covers all three products.
        let len = x1.len() + y1.len() + 1;
        let mut p = BigUint { data: vec![0; len] };
        // p2 = x1 * y1, added at offsets b and 2b.
        mac3(&mut p.data[..], x1, y1);
        // Trimming high zeros shortens the following additions.
        p.normalize();
        add2(&mut acc[b..], &p.data[..]);
        add2(&mut acc[b * 2..], &p.data[..]);
        // Reset the scratch buffer to `len` zeros.
        p.data.truncate(0);
        p.data.extend(repeat(0).take(len));
        // p0 = x0 * y0, added at offsets 0 and b.
        mac3(&mut p.data[..], x0, y0);
        p.normalize();
        add2(&mut acc[..], &p.data[..]);
        add2(&mut acc[b..], &p.data[..]);
        // p1 = (x1 - x0) * (y1 - y0) must be subtracted at offset b. It may be negative,
        // and `acc` is unsigned, so it is handled last, after the positive terms are in.
        let (j0_sign, j0) = sub_sign(x1, x0);
        let (j1_sign, j1) = sub_sign(y1, y0);
        match j0_sign * j1_sign {
            Plus => {
                // p1 > 0: compute it in the scratch buffer and subtract it.
                p.data.truncate(0);
                p.data.extend(repeat(0).take(len));
                mac3(&mut p.data[..], &j0.data[..], &j1.data[..]);
                p.normalize();
                sub2(&mut acc[b..], &p.data[..]);
            }
            Minus => {
                // p1 < 0: subtracting it means adding |p1|, which mac3 does directly.
                mac3(&mut acc[b..], &j0.data[..], &j1.data[..]);
            }
            // p1 == 0: nothing to subtract.
            NoSign => (),
        }
    } else {
        // Toom-3. Both operands are split into up to three pieces of `i` digits and
        // treated as polynomials in B^i:
        //   x(t) = x2*t^2 + x1*t + x0,   y(t) = y2*t^2 + y1*t + y0.
        // Pieces may be empty when x is much shorter than y.
        let i = y.len() / 3 + 1;
        let x0_len = cmp::min(x.len(), i);
        let x1_len = cmp::min(x.len() - x0_len, i);
        let y0_len = i;
        let y1_len = cmp::min(y.len() - y0_len, i);
        let x0 = bigint_from_slice(&x[..x0_len]);
        let x1 = bigint_from_slice(&x[x0_len..x0_len + x1_len]);
        let x2 = bigint_from_slice(&x[x0_len + x1_len..]);
        let y0 = bigint_from_slice(&y[..y0_len]);
        let y1 = bigint_from_slice(&y[y0_len..y0_len + y1_len]);
        let y2 = bigint_from_slice(&y[y0_len + y1_len..]);
        // Evaluate at t = 0, 1, -1, -2 and infinity:
        //   p  = x0 + x2,     p + x1 = x(1),   p2 = p - x1 = x(-1),
        //   (p2 + x2)*2 - x0 = x(-2)   (likewise for q / y).
        let p = &x0 + &x2;
        let q = &y0 + &y2;
        let p2 = &p - &x1;
        let q2 = &q - &y1;
        // Pointwise products: r0 = r(0), r1 = r(1), r2 = r(-1), r3 = r(-2), r4 = r(inf).
        let r0 = &x0 * &y0;
        let r4 = &x2 * &y2;
        let r1 = (p + x1) * (q + y1);
        let r2 = &p2 * &q2;
        let r3 = ((p2 + x2) * 2 - x0) * ((q2 + y2) * 2 - y0);
        // Interpolation: recover the middle coefficients (comp1..comp3) of the
        // degree-4 product polynomial. The divisions by 3 and 2 are exact.
        let mut comp3: BigInt = (r3 - &r1) / 3;
        let mut comp1: BigInt = (r1 - &r2) / 2;
        let mut comp2: BigInt = r2 - &r0;
        comp3 = (&comp2 - comp3) / 2 + &r4 * 2;
        comp2 += &comp1 - &r4;
        comp1 -= &comp3;
        // Recombine at t = B^i (a shift of i digits) and accumulate.
        let bits = u64::from(big_digit::BITS) * i as u64;
        let result = r0
            + (comp1 << bits)
            + (comp2 << (2 * bits))
            + (comp3 << (3 * bits))
            + (r4 << (4 * bits));
        // The product of two non-negative operands is non-negative.
        let result_pos = result.to_biguint().unwrap();
        add2(&mut acc[..], &result_pos.data);
    }
}
/// Returns `x * y` for two little-endian digit slices as a normalized `BigUint`.
pub(crate) fn mul3(x: &[BigDigit], y: &[BigDigit]) -> BigUint {
    // Zeroed accumulator of x.len() + y.len() + 1 digits for mac3 to add into.
    let mut product = BigUint {
        data: vec![0; x.len() + y.len() + 1],
    };
    mac3(&mut product.data, x, y);
    product.normalized()
}
/// Multiplies `a` by the single digit `b` in place and returns the carry digit that
/// overflowed past `a`'s most significant digit.
pub(crate) fn scalar_mul(a: &mut [BigDigit], b: BigDigit) -> BigDigit {
    let mut carry = 0;
    a.iter_mut()
        .for_each(|digit| *digit = mul_with_carry(*digit, b, &mut carry));
    carry as BigDigit
}
/// Returns `(u / d, u % d)`, consuming both operands so their storage can be reused.
///
/// # Panics
/// Panics if `d` is zero.
pub(crate) fn div_rem(mut u: BigUint, mut d: BigUint) -> (BigUint, BigUint) {
    if d.is_zero() {
        panic!("attempt to divide by zero")
    }
    if u.is_zero() {
        return (Zero::zero(), Zero::zero());
    }

    // Single-digit divisor: use the digit-by-digit fast path.
    if d.data.len() == 1 {
        let divisor = d.data[0];
        if divisor == 1 {
            return (u, Zero::zero());
        }
        let (quotient, remainder) = div_rem_digit(u, divisor);
        // Reuse `d` to hold the remainder.
        d.data.clear();
        d += remainder;
        return (quotient, d);
    }

    match u.cmp(&d) {
        Less => (Zero::zero(), u),
        Equal => {
            u.set_one();
            (u, Zero::zero())
        }
        Greater => {
            // Knuth's algorithm D needs the divisor's top bit set, so shift both
            // operands left, then shift the remainder back afterwards.
            let shift = d.data.last().unwrap().leading_zeros() as usize;
            let (q, r) = if shift == 0 {
                div_rem_core(u, &d)
            } else {
                div_rem_core(u << shift, &(d << shift))
            };
            (q, r >> shift)
        }
    }
}
pub(crate) fn div_rem_ref(u: &BigUint, d: &BigUint) -> (BigUint, BigUint) {
if d.is_zero() {
panic!("attempt to divide by zero")
}
if u.is_zero() {
return (Zero::zero(), Zero::zero());
}
if d.data.len() == 1 {
if d.data == [1] {
return (u.clone(), Zero::zero());
}
let (div, rem) = div_rem_digit(u.clone(), d.data[0]);
return (div, rem.into());
}
match u.cmp(d) {
Less => return (Zero::zero(), u.clone()),
Equal => return (One::one(), Zero::zero()),
Greater => {}
}
let shift = d.data.last().unwrap().leading_zeros() as usize;
let (q, r) = if shift == 0 {
div_rem_core(u.clone(), d)
} else {
div_rem_core(u << shift, &(d << shift))
};
(q, r >> shift)
}
/// Long division in the style of Knuth, TAOCP vol. 2 §4.3.1, Algorithm D.
///
/// Preconditions (debug-asserted): `a` has at least as many digits as `b`, `b` has at
/// least two digits, and the top bit of `b`'s most significant digit is set.
/// Returns `(a / b, a % b)`; the remainder reuses `a`'s storage.
fn div_rem_core(mut a: BigUint, b: &BigUint) -> (BigUint, BigUint) {
    debug_assert!(
        a.data.len() >= b.data.len()
            && b.data.len() > 1
            && b.data.last().unwrap().leading_zeros() == 0
    );
    // `a0` is the current top digit of the running remainder. It is kept outside
    // `a.data` (popped off each round) and starts as an implicit leading zero.
    let mut a0 = 0;
    // The two most significant digits of b: b0 is the top one, b1 the next.
    let b0 = *b.data.last().unwrap();
    let b1 = b.data[b.data.len() - 2];
    let q_len = a.data.len() - b.data.len() + 1;
    let mut q = BigUint {
        data: vec![0; q_len],
    };
    // Produce quotient digits from most to least significant.
    for j in (0..q_len).rev() {
        debug_assert!(a.data.len() == b.data.len() + j);
        // The next two remainder digits below a0.
        let a1 = *a.data.last().unwrap();
        let a2 = a.data[a.data.len() - 2];
        // First estimate q0 = [a0, a1] / b0 (a0 high), with r = [a0, a1] - q0 * b0.
        // When a0 == b0 that quotient would not fit in a digit, so use MAX:
        // [a0, a1] = b0 * MAX + (a0 + a1) in that case.
        let (mut q0, mut r) = if a0 < b0 {
            let (q0, r) = div_wide(a0, a1, b0);
            (q0, r as DoubleBigDigit)
        } else {
            debug_assert!(a0 == b0);
            (big_digit::MAX, a0 as DoubleBigDigit + a1 as DoubleBigDigit)
        };
        // Refine with b1: q0 is too large while (r << BITS) + a2 < q0 * b1. Once r no
        // longer fits in one digit the left side exceeds any q0 * b1, so stop.
        while r <= big_digit::MAX as DoubleBigDigit
            && big_digit::to_doublebigdigit(r as BigDigit, a2)
                < q0 as DoubleBigDigit * b1 as DoubleBigDigit
        {
            q0 -= 1;
            r += b0 as DoubleBigDigit;
        }
        // Subtract q0 * b from the top of the remainder. If that borrows more than the
        // detached top digit a0 can cover, q0 was still one too large: add b back once.
        let mut borrow = sub_mul_digit_same_len(&mut a.data[j..], &b.data, q0);
        if borrow > a0 {
            q0 -= 1;
            borrow -= __add2(&mut a.data[j..], &b.data);
        }
        // The borrow exactly cancels a0, i.e. the top digit is now zero.
        debug_assert!(borrow == a0);
        q.data[j] = q0;
        // Shift the window down one digit: the new top digit becomes a0.
        a0 = a.data.pop().unwrap();
    }
    // Re-attach the last detached digit; what remains of `a` is the remainder.
    a.data.push(a0);
    a.normalize();
    debug_assert!(a < *b);
    (q.normalized(), a)
}
/// "Find last set": the 1-based position of the highest set bit of `v`, or 0 when
/// `v` is zero (equivalently, the bit width minus the leading zero count).
pub(crate) fn fls<T: PrimInt>(v: T) -> u8 {
    let width_bits = mem::size_of::<T>() as u8 * 8;
    let leading = v.leading_zeros() as u8;
    width_bits - leading
}
/// Floor of the base-2 logarithm of `v`, i.e. the 0-based index of its highest set bit.
///
/// `v` must be nonzero. For zero, `fls` returns 0 and the subtraction underflows:
/// this panics when overflow checks are enabled and wraps otherwise.
pub(crate) fn ilog2<T: PrimInt>(v: T) -> u8 {
    fls(v) - 1
}
/// Shifts `n` left by `shift` bits.
///
/// # Panics
/// Panics if `shift` is negative, or if the whole-digit part of the shift does not
/// fit in `usize`.
#[inline]
pub(crate) fn biguint_shl<T: PrimInt>(n: Cow<'_, BigUint>, shift: T) -> BigUint {
    if shift < T::zero() {
        panic!("attempt to shift left with negative");
    }
    if n.is_zero() {
        return n.into_owned();
    }

    // Split the shift into whole digits plus a sub-digit bit count.
    let digit_bits = T::from(big_digit::BITS).unwrap();
    let whole_digits = (shift / digit_bits).to_usize().expect("capacity overflow");
    let bit_shift = (shift % digit_bits).to_u8().unwrap();
    biguint_shl2(n, whole_digits, bit_shift)
}
/// Shifts `n` left by `digits` whole digits plus `shift` extra bits
/// (`shift < big_digit::BITS`).
fn biguint_shl2(n: Cow<'_, BigUint>, digits: usize, shift: u8) -> BigUint {
    // Prepend `digits` zero digits (the low end is at index 0).
    let mut data = if digits == 0 {
        n.into_owned().data
    } else {
        let capacity = digits.saturating_add(n.data.len() + 1);
        let mut widened = Vec::with_capacity(capacity);
        widened.extend(repeat(0).take(digits));
        widened.extend(n.data.iter());
        widened
    };

    if shift > 0 {
        // Shift the original digits, carrying each digit's top bits into the next one.
        let back_shift = big_digit::BITS as u8 - shift;
        let spill = data[digits..].iter_mut().fold(0, |incoming, elem| {
            let outgoing = *elem >> back_shift;
            *elem = (*elem << shift) | incoming;
            outgoing
        });
        if spill != 0 {
            data.push(spill);
        }
    }

    biguint_from_vec(data)
}
/// Shifts `n` right by `shift` bits.
///
/// # Panics
/// Panics if `shift` is negative.
#[inline]
pub(crate) fn biguint_shr<T: PrimInt>(n: Cow<'_, BigUint>, shift: T) -> BigUint {
    if shift < T::zero() {
        panic!("attempt to shift right with negative");
    }
    if n.is_zero() {
        return n.into_owned();
    }

    let digit_bits = T::from(big_digit::BITS).unwrap();
    // A digit count that does not fit in usize is at least n's length, which
    // biguint_shr2 treats as "shift everything out".
    let whole_digits = (shift / digit_bits)
        .to_usize()
        .unwrap_or(core::usize::MAX);
    let bit_shift = (shift % digit_bits).to_u8().unwrap();
    biguint_shr2(n, whole_digits, bit_shift)
}
/// Shifts `n` right by `digits` whole digits plus `shift` extra bits
/// (`shift < big_digit::BITS`).
fn biguint_shr2(n: Cow<'_, BigUint>, digits: usize, shift: u8) -> BigUint {
    // Shifting out every digit leaves zero.
    if digits >= n.data.len() {
        let mut zeroed = n.into_owned();
        zeroed.set_zero();
        return zeroed;
    }

    // Drop the low `digits` digits, copying only when `n` is borrowed.
    let mut data = match n {
        Cow::Borrowed(borrowed) => borrowed.data[digits..].to_vec(),
        Cow::Owned(mut owned) => {
            owned.data.drain(..digits);
            owned.data
        }
    };

    if shift > 0 {
        // Walk from the top digit down, moving each digit's low bits into the one below.
        let back_shift = big_digit::BITS as u8 - shift;
        let mut incoming = 0;
        for elem in data.iter_mut().rev() {
            let outgoing = *elem << back_shift;
            *elem = (*elem >> shift) | incoming;
            incoming = outgoing;
        }
    }

    biguint_from_vec(data)
}
/// Compares two normalized (no high zero digit) little-endian digit slices as numbers.
///
/// A longer slice is larger. Equal lengths are compared from the most significant
/// digit down.
pub(crate) fn cmp_slice(a: &[BigDigit], b: &[BigDigit]) -> Ordering {
    debug_assert!(a.last() != Some(&0));
    debug_assert!(b.last() != Some(&0));
    a.len()
        .cmp(&b.len())
        .then_with(|| Iterator::cmp(a.iter().rev(), b.iter().rev()))
}
#[cfg(test)]
mod algorithm_tests {
    use crate::big_digit::BigDigit;
    use crate::{BigInt, BigUint};
    use num_traits::Num;

    /// `sub_sign` on raw digit slices must agree with signed `BigInt` subtraction,
    /// in both argument orders.
    #[test]
    fn test_sub_sign() {
        use super::sub_sign;

        // Recombine sub_sign's (sign, magnitude) pair into a signed integer.
        fn signed_difference(lhs: &[BigDigit], rhs: &[BigDigit]) -> BigInt {
            let (sign, magnitude) = sub_sign(lhs, rhs);
            BigInt::from_biguint(sign, magnitude)
        }

        let parse = |text: &str| BigUint::from_str_radix(text, 10).unwrap();
        let big = parse("265252859812191058636308480000000");
        let small = parse("26525285981219105863630848000000");
        let big_signed = BigInt::from(big.clone());
        let small_signed = BigInt::from(small.clone());

        assert_eq!(
            signed_difference(&big.data, &small.data),
            &big_signed - &small_signed
        );
        assert_eq!(
            signed_difference(&small.data, &big.data),
            &small_signed - &big_signed
        );
    }
}