// crypto_bigint/uint/mul/schoolbook.rs

use crate::Limb;

3/// Schoolbook multiplication a.k.a. long multiplication, i.e. the traditional method taught in
4/// schools.
5///
6/// The most efficient method for small numbers.
7#[inline(always)]
8#[track_caller]
9pub const fn mul_wide(lhs: &[Limb], rhs: &[Limb], lo: &mut [Limb], hi: &mut [Limb]) {
10    assert!(
11        lhs.len() == lo.len() && rhs.len() == hi.len(),
12        "schoolbook multiplication length mismatch"
13    );
14
15    let mut i = 0;
16
17    while i < lhs.len() {
18        let mut carry = Limb::ZERO;
19        let xi = lhs[i];
20        let mut j = 0;
21
22        while j < rhs.len() {
23            let k = i + j;
24
25            if k >= lhs.len() {
26                (hi[k - lhs.len()], carry) = xi.carrying_mul_add(rhs[j], hi[k - lhs.len()], carry);
27            } else {
28                (lo[k], carry) = xi.carrying_mul_add(rhs[j], lo[k], carry);
29            }
30
31            j += 1;
32        }
33
34        if i + j >= lhs.len() {
35            hi[i + j - lhs.len()] = carry;
36        } else {
37            lo[i + j] = carry;
38        }
39        i += 1;
40    }
41}
42
43/// Schoolbook multiplication which only calculates the lower limbs of the product.
44#[inline(always)]
45#[track_caller]
46pub const fn wrapping_mul_add(lhs: &[Limb], rhs: &[Limb], out: &mut [Limb]) -> Limb {
47    assert!(
48        lhs.len() + rhs.len() >= out.len(),
49        "wrapping schoolbook multiplication length mismatch"
50    );
51
52    let mut i = 0;
53    let mut meta_carry = Limb::ZERO;
54
55    while i < lhs.len() {
56        let mut carry = Limb::ZERO;
57        let xi = lhs[i];
58        let mut k = i;
59
60        loop {
61            let j = k - i;
62            if k >= out.len() {
63                meta_carry = meta_carry.wrapping_add(carry);
64                break;
65            } else if j == rhs.len() {
66                (out[k], meta_carry) = out[k].carrying_add(carry, meta_carry);
67                break;
68            }
69            (out[k], carry) = xi.carrying_mul_add(rhs[j], out[k], carry);
70            k += 1;
71        }
72
73        i += 1;
74    }
75
76    meta_carry
77}
78
79/// Schoolbook method of squaring.
80///
81/// Like schoolbook multiplication, but only considering half of the multiplication grid.
82#[inline(always)]
83#[track_caller]
84pub const fn square_wide(limbs: &[Limb], lo: &mut [Limb], hi: &mut [Limb]) {
85    // Translated from https://github.com/ucbrise/jedi-pairing/blob/c4bf151/include/core/bigint.hpp#L410
86    //
87    // Permission to relicense the resulting translation as Apache 2.0 + MIT was given
88    // by the original author Sam Kumar: https://github.com/RustCrypto/crypto-bigint/pull/133#discussion_r1056870411
89
90    assert!(
91        limbs.len() == lo.len() && lo.len() == hi.len(),
92        "schoolbook squaring length mismatch"
93    );
94
95    let mut i = 1;
96    while i < limbs.len() {
97        let mut j = 0;
98        let mut carry = Limb::ZERO;
99        let xi = limbs[i];
100
101        while j < i {
102            let k = i + j;
103
104            if k >= limbs.len() {
105                (hi[k - limbs.len()], carry) =
106                    xi.carrying_mul_add(limbs[j], hi[k - limbs.len()], carry);
107            } else {
108                (lo[k], carry) = xi.carrying_mul_add(limbs[j], lo[k], carry);
109            }
110
111            j += 1;
112        }
113
114        if (2 * i) < limbs.len() {
115            lo[2 * i] = carry;
116        } else {
117            hi[2 * i - limbs.len()] = carry;
118        }
119
120        i += 1;
121    }
122
123    // Double the current result, this accounts for the other half of the multiplication grid.
124    // The top word is empty, so we use a special purpose shl.
125    let mut carry = Limb::ZERO;
126    let mut i = 0;
127    while i < limbs.len() {
128        (lo[i].0, carry) = ((lo[i].0 << 1) | carry.0, lo[i].shr(Limb::BITS - 1));
129        i += 1;
130    }
131
132    let mut i = 0;
133    while i < limbs.len() - 1 {
134        (hi[i].0, carry) = ((hi[i].0 << 1) | carry.0, hi[i].shr(Limb::BITS - 1));
135        i += 1;
136    }
137    hi[limbs.len() - 1] = carry;
138
139    // Handle the diagonal of the multiplication grid, which finishes the multiplication grid.
140    let mut carry = Limb::ZERO;
141    let mut i = 0;
142    while i < limbs.len() {
143        let xi = limbs[i];
144        if (i * 2) < limbs.len() {
145            (lo[i * 2], carry) = xi.carrying_mul_add(xi, lo[i * 2], carry);
146        } else {
147            (hi[i * 2 - limbs.len()], carry) =
148                xi.carrying_mul_add(xi, hi[i * 2 - limbs.len()], carry);
149        }
150
151        if (i * 2 + 1) < limbs.len() {
152            (lo[i * 2 + 1], carry) = lo[i * 2 + 1].overflowing_add(carry);
153        } else {
154            (hi[i * 2 + 1 - limbs.len()], carry) =
155                hi[i * 2 + 1 - limbs.len()].overflowing_add(carry);
156        }
157
158        i += 1;
159    }
160}
161
162/// Schoolbook squaring which may calculate a limited number of limbs of the product.
163#[inline(always)]
164#[track_caller]
165pub const fn wrapping_square(limbs: &[Limb], out: &mut [Limb]) -> Limb {
166    assert!(
167        limbs.len() * 2 >= out.len(),
168        "schoolbook wrapping squaring length mismatch"
169    );
170
171    let mut i = 1;
172
173    while i < limbs.len() {
174        let mut carry = Limb::ZERO;
175        let xi = limbs[i];
176        let mut k = i;
177
178        while k < 2 * i && k < out.len() {
179            (out[k], carry) = xi.carrying_mul_add(limbs[k - i], out[k], carry);
180            k += 1;
181        }
182
183        if k < out.len() {
184            out[k] = carry;
185        }
186        i += 1;
187    }
188
189    // Double the current result and fill in the diagonal terms.
190    let mut carry = Limb::ZERO;
191    let mut limb;
192    let mut hi_bit = Limb::ZERO;
193    i = 0;
194    while i < out.len() {
195        (limb, hi_bit) = (out[i].shl(1).bitor(hi_bit), out[i].shr(Limb::HI_BIT));
196        (out[i], carry) = if i & 1 == 0 {
197            let i_div_2 = i >> 1;
198            limbs[i_div_2].carrying_mul_add(limbs[i_div_2], limb, carry)
199        } else {
200            limb.overflowing_add(carry)
201        };
202        i += 1;
203    }
204    carry.wrapping_add(hi_bit)
205}