Skip to content

Commit d5e00ba

Browse files
authored
Align with core/std on overflowing_sh* (#430)
Changes all shift functions which return an overflow flag (as a `Choice` or `ConstChoice`) to use the `overflowing_sh*` name prefix, which aligns with similar APIs in `core`/`std`. In their place, adds new `Uint::{shl, shr}` functions which provide the trait-like behavior (i.e. panic on overflow) but work in `const fn` contexts (and can now panic at compile time on overflow).
1 parent cc3f984 commit d5e00ba

File tree

14 files changed

+133
-93
lines changed

14 files changed

+133
-93
lines changed

benches/boxed_uint.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ fn bench_shifts(c: &mut Criterion) {
1919
group.bench_function("shl", |b| {
2020
b.iter_batched(
2121
|| BoxedUint::random(&mut OsRng, UINT_BITS),
22-
|x| x.shl(UINT_BITS / 2 + 10),
22+
|x| x.overflowing_shl(UINT_BITS / 2 + 10),
2323
BatchSize::SmallInput,
2424
)
2525
});
@@ -35,7 +35,7 @@ fn bench_shifts(c: &mut Criterion) {
3535
group.bench_function("shr", |b| {
3636
b.iter_batched(
3737
|| BoxedUint::random(&mut OsRng, UINT_BITS),
38-
|x| x.shr(UINT_BITS / 2 + 10),
38+
|x| x.overflowing_shr(UINT_BITS / 2 + 10),
3939
BatchSize::SmallInput,
4040
)
4141
});

benches/uint.rs

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -79,27 +79,35 @@ fn bench_shl(c: &mut Criterion) {
7979
let mut group = c.benchmark_group("left shift");
8080

8181
group.bench_function("shl_vartime, small, U2048", |b| {
82-
b.iter_batched(|| U2048::ONE, |x| x.shl_vartime(10), BatchSize::SmallInput)
82+
b.iter_batched(
83+
|| U2048::ONE,
84+
|x| x.overflowing_shl_vartime(10),
85+
BatchSize::SmallInput,
86+
)
8387
});
8488

8589
group.bench_function("shl_vartime, large, U2048", |b| {
8690
b.iter_batched(
8791
|| U2048::ONE,
88-
|x| black_box(x.shl_vartime(1024 + 10)),
92+
|x| black_box(x.overflowing_shl_vartime(1024 + 10)),
8993
BatchSize::SmallInput,
9094
)
9195
});
9296

9397
group.bench_function("shl_vartime_wide, large, U2048", |b| {
9498
b.iter_batched(
9599
|| (U2048::ONE, U2048::ONE),
96-
|x| Uint::shl_vartime_wide(x, 1024 + 10),
100+
|x| Uint::overflowing_shl_vartime_wide(x, 1024 + 10),
97101
BatchSize::SmallInput,
98102
)
99103
});
100104

101105
group.bench_function("shl, U2048", |b| {
102-
b.iter_batched(|| U2048::ONE, |x| x.shl(1024 + 10), BatchSize::SmallInput)
106+
b.iter_batched(
107+
|| U2048::ONE,
108+
|x| x.overflowing_shl(1024 + 10),
109+
BatchSize::SmallInput,
110+
)
103111
});
104112

105113
group.finish();
@@ -109,27 +117,35 @@ fn bench_shr(c: &mut Criterion) {
109117
let mut group = c.benchmark_group("right shift");
110118

111119
group.bench_function("shr_vartime, small, U2048", |b| {
112-
b.iter_batched(|| U2048::ONE, |x| x.shr_vartime(10), BatchSize::SmallInput)
120+
b.iter_batched(
121+
|| U2048::ONE,
122+
|x| x.overflowing_shr_vartime(10),
123+
BatchSize::SmallInput,
124+
)
113125
});
114126

115127
group.bench_function("shr_vartime, large, U2048", |b| {
116128
b.iter_batched(
117129
|| U2048::ONE,
118-
|x| x.shr_vartime(1024 + 10),
130+
|x| x.overflowing_shr_vartime(1024 + 10),
119131
BatchSize::SmallInput,
120132
)
121133
});
122134

123135
group.bench_function("shr_vartime_wide, large, U2048", |b| {
124136
b.iter_batched(
125137
|| (U2048::ONE, U2048::ONE),
126-
|x| Uint::shr_vartime_wide(x, 1024 + 10),
138+
|x| Uint::overflowing_shr_vartime_wide(x, 1024 + 10),
127139
BatchSize::SmallInput,
128140
)
129141
});
130142

131143
group.bench_function("shr, U2048", |b| {
132-
b.iter_batched(|| U2048::ONE, |x| x.shr(1024 + 10), BatchSize::SmallInput)
144+
b.iter_batched(
145+
|| U2048::ONE,
146+
|x| x.overflowing_shr(1024 + 10),
147+
BatchSize::SmallInput,
148+
)
133149
});
134150

135151
group.finish();

src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@
4545
//! U256::from_be_hex("ffffffff00000000ffffffffffffffffbce6faada7179e84f3b9cac2fc632551");
4646
//!
4747
//! // Compute `MODULUS` shifted right by 1 at compile time
48-
//! pub const MODULUS_SHR1: U256 = MODULUS.shr(1).0;
48+
//! pub const MODULUS_SHR1: U256 = MODULUS.shr(1);
4949
//! ```
5050
//!
5151
//! Note that large constant computations may accidentally trigger a the `const_eval_limit` of the compiler.

src/uint/boxed/div.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ impl BoxedUint {
7878
let bits_precision = self.bits_precision();
7979
let mut rem = self.clone();
8080
let mut quo = Self::zero_with_precision(bits_precision);
81-
let (mut c, _overflow) = rhs.shl(bits_precision - mb);
81+
let (mut c, _overflow) = rhs.overflowing_shl(bits_precision - mb);
8282
let mut i = bits_precision;
8383
let mut done = Choice::from(0u8);
8484

src/uint/boxed/shl.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ impl BoxedUint {
99
///
1010
/// Returns a zero and a truthy `Choice` if `shift >= self.bits_precision()`,
1111
/// or the result and a falsy `Choice` otherwise.
12-
pub fn shl(&self, shift: u32) -> (Self, Choice) {
12+
pub fn overflowing_shl(&self, shift: u32) -> (Self, Choice) {
1313
let mut result = self.clone();
1414
let overflow = result.overflowing_shl_assign(shift);
1515
(result, overflow)
@@ -125,7 +125,7 @@ impl Shl<u32> for &BoxedUint {
125125
type Output = BoxedUint;
126126

127127
fn shl(self, shift: u32) -> BoxedUint {
128-
let (result, overflow) = self.shl(shift);
128+
let (result, overflow) = self.overflowing_shl(shift);
129129
assert!(!bool::from(overflow), "attempt to shift left with overflow");
130130
result
131131
}
@@ -154,8 +154,8 @@ mod tests {
154154
fn shl() {
155155
let one = BoxedUint::one_with_precision(128);
156156

157-
assert_eq!(BoxedUint::from(2u8), one.shl(1).0);
158-
assert_eq!(BoxedUint::from(4u8), one.shl(2).0);
157+
assert_eq!(BoxedUint::from(2u8), &one << 1);
158+
assert_eq!(BoxedUint::from(4u8), &one << 2);
159159
assert_eq!(
160160
BoxedUint::from(0x80000000000000000u128),
161161
one.shl_vartime(67).unwrap()

src/uint/boxed/shr.rs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ impl BoxedUint {
99
///
1010
/// Returns a zero and a truthy `Choice` if `shift >= self.bits_precision()`,
1111
/// or the result and a falsy `Choice` otherwise.
12-
pub fn shr(&self, shift: u32) -> (Self, Choice) {
12+
pub fn overflowing_shr(&self, shift: u32) -> (Self, Choice) {
1313
let mut result = self.clone();
1414
let overflow = result.overflowing_shr_assign(shift);
1515
(result, overflow)
@@ -129,7 +129,7 @@ impl Shr<u32> for &BoxedUint {
129129
type Output = BoxedUint;
130130

131131
fn shr(self, shift: u32) -> BoxedUint {
132-
let (result, overflow) = self.shr(shift);
132+
let (result, overflow) = self.overflowing_shr(shift);
133133
assert!(
134134
!bool::from(overflow),
135135
"attempt to shift right with overflow"
@@ -163,10 +163,10 @@ mod tests {
163163
#[test]
164164
fn shr() {
165165
let n = BoxedUint::from(0x80000000000000000u128);
166-
assert_eq!(BoxedUint::zero(), n.shr(68).0);
167-
assert_eq!(BoxedUint::one(), n.shr(67).0);
168-
assert_eq!(BoxedUint::from(2u8), n.shr(66).0);
169-
assert_eq!(BoxedUint::from(4u8), n.shr(65).0);
166+
assert_eq!(BoxedUint::zero(), &n >> 68);
167+
assert_eq!(BoxedUint::one(), &n >> 67);
168+
assert_eq!(BoxedUint::from(2u8), &n >> 66);
169+
assert_eq!(BoxedUint::from(4u8), &n >> 65);
170170
}
171171

172172
#[test]

src/uint/div.rs

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ impl<const LIMBS: usize> Uint<LIMBS> {
2828
let mut rem = *self;
2929
let mut quo = Self::ZERO;
3030
// If there is overflow, it means `mb == 0`, so `rhs == 0`.
31-
let (mut c, _overflow) = rhs.0.shl(Self::BITS - mb);
31+
let (mut c, _overflow) = rhs.0.overflowing_shl(Self::BITS - mb);
3232

3333
let mut i = Self::BITS;
3434
let mut done = ConstChoice::FALSE;
@@ -64,7 +64,7 @@ impl<const LIMBS: usize> Uint<LIMBS> {
6464
let mut rem = *self;
6565
let mut quo = Self::ZERO;
6666
// If there is overflow, it means `mb == 0`, so `rhs == 0`.
67-
let (mut c, _overflow) = rhs.0.shl_vartime(bd);
67+
let (mut c, _overflow) = rhs.0.overflowing_shl_vartime(bd);
6868

6969
loop {
7070
let (mut r, borrow) = rem.sbb(&c, Limb::ZERO);
@@ -92,7 +92,7 @@ impl<const LIMBS: usize> Uint<LIMBS> {
9292
let mb = rhs.0.bits_vartime();
9393
let mut bd = Self::BITS - mb;
9494
let mut rem = *self;
95-
let (mut c, _overflow) = rhs.0.shl_vartime(bd);
95+
let (mut c, _overflow) = rhs.0.overflowing_shl_vartime(bd);
9696

9797
loop {
9898
let (r, borrow) = rem.sbb(&c, Limb::ZERO);
@@ -123,7 +123,7 @@ impl<const LIMBS: usize> Uint<LIMBS> {
123123
let (mut lower, mut upper) = lower_upper;
124124

125125
// Factor of the modulus, split into two halves
126-
let (mut c, _overflow) = Self::shl_vartime_wide((rhs.0, Uint::ZERO), bd);
126+
let (mut c, _overflow) = Self::overflowing_shl_vartime_wide((rhs.0, Uint::ZERO), bd);
127127

128128
loop {
129129
let (lower_sub, borrow) = lower.sbb(&c.0, Limb::ZERO);
@@ -135,7 +135,7 @@ impl<const LIMBS: usize> Uint<LIMBS> {
135135
break;
136136
}
137137
bd -= 1;
138-
let (new_c, _overflow) = Self::shr_vartime_wide(c, 1);
138+
let (new_c, _overflow) = Self::overflowing_shr_vartime_wide(c, 1);
139139
c = new_c;
140140
}
141141

@@ -634,8 +634,8 @@ mod tests {
634634
fn div() {
635635
let mut rng = ChaChaRng::from_seed([7u8; 32]);
636636
for _ in 0..25 {
637-
let (num, _) = U256::random(&mut rng).shr_vartime(128);
638-
let den = NonZero::new(U256::random(&mut rng).shr_vartime(128).0).unwrap();
637+
let (num, _) = U256::random(&mut rng).overflowing_shr_vartime(128);
638+
let den = NonZero::new(U256::random(&mut rng).overflowing_shr_vartime(128).0).unwrap();
639639
let n = num.checked_mul(den.as_ref());
640640
if n.is_some().into() {
641641
let (q, _) = n.unwrap().div_rem(&den);
@@ -724,7 +724,7 @@ mod tests {
724724
for _ in 0..25 {
725725
let num = U256::random(&mut rng);
726726
let k = rng.next_u32() % 256;
727-
let (den, _) = U256::ONE.shl_vartime(k);
727+
let (den, _) = U256::ONE.overflowing_shl_vartime(k);
728728

729729
let a = num.rem2k(k);
730730
let e = num.wrapping_rem(&den);

src/uint/inv_mod.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ impl<const LIMBS: usize> Uint<LIMBS> {
3030
// b_{i+1} = (b_i - a * X_i) / 2
3131
b = Self::select(&b, &b.wrapping_sub(self), x_i_choice).shr1();
3232
// Store the X_i bit in the result (x = x | (1 << X_i))
33-
let (shifted, _overflow) = Uint::from_word(x_i).shl_vartime(i);
33+
let (shifted, _overflow) = Uint::from_word(x_i).overflowing_shl_vartime(i);
3434
x = x.bitor(&shifted);
3535

3636
i += 1;
@@ -162,7 +162,7 @@ impl<const LIMBS: usize> Uint<LIMBS> {
162162
pub const fn inv_mod(&self, modulus: &Self) -> (Self, ConstChoice) {
163163
// Decompose `modulus = s * 2^k` where `s` is odd
164164
let k = modulus.trailing_zeros();
165-
let (s, _overflow) = modulus.shr(k);
165+
let (s, _overflow) = modulus.overflowing_shr(k);
166166

167167
// Decompose `self` into RNS with moduli `2^k` and `s` and calculate the inverses.
168168
// Using the fact that `(z^{-1} mod (m1 * m2)) mod m1 == z^{-1} mod m1`
@@ -178,7 +178,7 @@ impl<const LIMBS: usize> Uint<LIMBS> {
178178

179179
// This part is mod 2^k
180180
// Will not overflow since `modulus` is nonzero, and therefore `k < BITS`.
181-
let (shifted, _overflow) = Uint::ONE.shl(k);
181+
let (shifted, _overflow) = Uint::ONE.overflowing_shl(k);
182182
let mask = shifted.wrapping_sub(&Uint::ONE);
183183
let t = (b.wrapping_sub(&a).wrapping_mul(&m_odd_inv)).bitand(&mask);
184184

src/uint/mul.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ impl<const LIMBS: usize> Uint<LIMBS> {
135135

136136
// Double the current result, this accounts for the other half of the multiplication grid.
137137
// TODO: The top word is empty so we can also use a special purpose shl.
138-
(lo, hi) = Self::shl_vartime_wide((lo, hi), 1).0;
138+
(lo, hi) = Self::overflowing_shl_vartime_wide((lo, hi), 1).0;
139139

140140
// Handle the diagonal of the multiplication grid, which finishes the multiplication grid.
141141
let mut carry = Limb::ZERO;

0 commit comments

Comments
 (0)