1use alloc::vec::Vec;
2use core::mem;
3use core::ops::Shl;
4use num_traits::One;
5
6use crate::big_digit::{self, BigDigit, DoubleBigDigit};
7use crate::biguint::BigUint;
8
9struct MontyReducer {
10    n0inv: BigDigit,
11}
12
13fn inv_mod_alt(b: BigDigit) -> BigDigit {
16    assert_ne!(b & 1, 0);
17
18    let mut k0 = BigDigit::wrapping_sub(2, b);
19    let mut t = b - 1;
20    let mut i = 1;
21    while i < big_digit::BITS {
22        t = t.wrapping_mul(t);
23        k0 = k0.wrapping_mul(t + 1);
24
25        i <<= 1;
26    }
27    debug_assert_eq!(k0.wrapping_mul(b), 1);
28    k0.wrapping_neg()
29}
30
31impl MontyReducer {
32    fn new(n: &BigUint) -> Self {
33        let n0inv = inv_mod_alt(n.data[0]);
34        MontyReducer { n0inv }
35    }
36}
37
38#[allow(clippy::many_single_char_names)]
46fn montgomery(x: &BigUint, y: &BigUint, m: &BigUint, k: BigDigit, n: usize) -> BigUint {
47    assert!(
52        x.data.len() == n && y.data.len() == n && m.data.len() == n,
53        "{:?} {:?} {:?} {}",
54        x,
55        y,
56        m,
57        n
58    );
59
60    let mut z = BigUint::ZERO;
61    z.data.resize(n * 2, 0);
62
63    let mut c: BigDigit = 0;
64    for i in 0..n {
65        let c2 = add_mul_vvw(&mut z.data[i..n + i], &x.data, y.data[i]);
66        let t = z.data[i].wrapping_mul(k);
67        let c3 = add_mul_vvw(&mut z.data[i..n + i], &m.data, t);
68        let cx = c.wrapping_add(c2);
69        let cy = cx.wrapping_add(c3);
70        z.data[n + i] = cy;
71        if cx < c2 || cy < c3 {
72            c = 1;
73        } else {
74            c = 0;
75        }
76    }
77
78    if c == 0 {
79        z.data = z.data[n..].to_vec();
80    } else {
81        {
82            let (first, second) = z.data.split_at_mut(n);
83            sub_vv(first, second, &m.data);
84        }
85        z.data = z.data[..n].to_vec();
86    }
87
88    z
89}
90
91#[inline(always)]
92fn add_mul_vvw(z: &mut [BigDigit], x: &[BigDigit], y: BigDigit) -> BigDigit {
93    let mut c = 0;
94    for (zi, xi) in z.iter_mut().zip(x.iter()) {
95        let (z1, z0) = mul_add_www(*xi, y, *zi);
96        let (c_, zi_) = add_ww(z0, c, 0);
97        *zi = zi_;
98        c = c_ + z1;
99    }
100
101    c
102}
103
104#[inline(always)]
106fn sub_vv(z: &mut [BigDigit], x: &[BigDigit], y: &[BigDigit]) -> BigDigit {
107    let mut c = 0;
108    for (i, (xi, yi)) in x.iter().zip(y.iter()).enumerate().take(z.len()) {
109        let zi = xi.wrapping_sub(*yi).wrapping_sub(c);
110        z[i] = zi;
111        c = ((yi & !xi) | ((yi | !xi) & zi)) >> (big_digit::BITS - 1)
113    }
114
115    c
116}
117
118#[inline(always)]
120fn add_ww(x: BigDigit, y: BigDigit, c: BigDigit) -> (BigDigit, BigDigit) {
121    let yc = y.wrapping_add(c);
122    let z0 = x.wrapping_add(yc);
123    let z1 = if z0 < x || yc < y { 1 } else { 0 };
124
125    (z1, z0)
126}
127
128#[inline(always)]
130fn mul_add_www(x: BigDigit, y: BigDigit, c: BigDigit) -> (BigDigit, BigDigit) {
131    let z = x as DoubleBigDigit * y as DoubleBigDigit + c as DoubleBigDigit;
132    ((z >> big_digit::BITS) as BigDigit, z as BigDigit)
133}
134
135#[allow(clippy::many_single_char_names)]
137pub(super) fn monty_modpow(x: &BigUint, y: &BigUint, m: &BigUint) -> BigUint {
138    assert!(m.data[0] & 1 == 1);
139    let mr = MontyReducer::new(m);
140    let num_words = m.data.len();
141
142    let mut x = x.clone();
143
144    if x.data.len() > num_words {
147        x %= m;
148        }
150    if x.data.len() < num_words {
151        x.data.resize(num_words, 0);
152    }
153
154    let mut rr = BigUint::one();
156    rr = (rr.shl(2 * num_words as u64 * u64::from(big_digit::BITS))) % m;
157    if rr.data.len() < num_words {
158        rr.data.resize(num_words, 0);
159    }
160    let mut one = BigUint::one();
162    one.data.resize(num_words, 0);
163
164    let n = 4;
165    let mut powers = Vec::with_capacity(1 << n);
167    powers.push(montgomery(&one, &rr, m, mr.n0inv, num_words));
168    powers.push(montgomery(&x, &rr, m, mr.n0inv, num_words));
169    for i in 2..1 << n {
170        let r = montgomery(&powers[i - 1], &powers[1], m, mr.n0inv, num_words);
171        powers.push(r);
172    }
173
174    let mut z = powers[0].clone();
176    z.data.resize(num_words, 0);
177    let mut zz = BigUint::ZERO;
178    zz.data.resize(num_words, 0);
179
180    for i in (0..y.data.len()).rev() {
182        let mut yi = y.data[i];
183        let mut j = 0;
184        while j < big_digit::BITS {
185            if i != y.data.len() - 1 || j != 0 {
186                zz = montgomery(&z, &z, m, mr.n0inv, num_words);
187                z = montgomery(&zz, &zz, m, mr.n0inv, num_words);
188                zz = montgomery(&z, &z, m, mr.n0inv, num_words);
189                z = montgomery(&zz, &zz, m, mr.n0inv, num_words);
190            }
191            zz = montgomery(
192                &z,
193                &powers[(yi >> (big_digit::BITS - n)) as usize],
194                m,
195                mr.n0inv,
196                num_words,
197            );
198            mem::swap(&mut z, &mut zz);
199            yi <<= n;
200            j += n;
201        }
202    }
203
204    zz = montgomery(&z, &one, m, mr.n0inv, num_words);
206
207    zz.normalize();
208    if zz >= *m {
211        zz -= m;
219        if zz >= *m {
220            zz %= m;
221        }
222    }
223
224    zz.normalize();
225    zz
226}