Skip to content

Commit db83658

Browse files
authored
Extract FFT code into its own crate (#252)
* Remove fft from `math` module. * Make a separate fft crate to avoid circular dependencies. * Bring back benches. * Move tests to gpu. * Fix error handling. * Remove more repeated code. * Rename abstractions to ops. * Move metal tests back to gpu. * Fix linting error. * Remove duplicate tests. * nit.
1 parent be6b800 commit db83658

36 files changed

+469
-457
lines changed

.rusty-hook.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
[hooks]
2-
pre-commit = "cargo test && cargo clippy --all-targets --all-features -- -D warnings && cargo fmt --all -- --check"
2+
pre-commit = "cargo test && cargo clippy --all-targets -- -D warnings && cargo fmt --all -- --check"
33

44
[logging]
55
verbose = true

Cargo.toml

+1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
[workspace]
22
members = [
33
"math",
4+
"fft",
45
"crypto",
56
"proving_system/stark",
67
"proving_system/plonk",

fft/Cargo.toml

+23
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
[package]
2+
name = "lambdaworks-fft"
3+
version = "0.1.0"
4+
edition = "2021"
5+
6+
[dependencies]
7+
rand = "0.8.5"
8+
thiserror = "1.0.38"
9+
lambdaworks-math = { path = "../math" }
10+
lambdaworks-gpu = { path = "../gpu", optional = true }
11+
12+
[dev-dependencies]
13+
proptest = "1.1.0"
14+
criterion = "0.4.0"
15+
objc = "0.2.7"
16+
17+
[features]
18+
metal = ["dep:lambdaworks-gpu"]
19+
cuda = ["dep:lambdaworks-gpu"]
20+
21+
[[bench]]
22+
name = "fft_benchmarks"
23+
harness = false

math/benches/benchmarks/fft.rs renamed to fft/benches/benchmarks/fft.rs

+13-6
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,15 @@
11
use criterion::Criterion;
2+
use lambdaworks_fft::{
3+
bit_reversing::in_place_bit_reverse_permute,
4+
fft_iterative::{in_place_nr_2radix_fft, in_place_rn_2radix_fft},
5+
polynomial::FFTPoly,
6+
roots_of_unity::get_twiddles,
7+
};
28
use lambdaworks_math::{
3-
fft::{bit_reversing::*, fft_iterative::*},
4-
field::{element::FieldElement, traits::IsTwoAdicField},
5-
field::{fields::fft_friendly::stark_252_prime_field::Stark252PrimeField, traits::RootsConfig},
9+
field::{
10+
element::FieldElement, fields::fft_friendly::stark_252_prime_field::Stark252PrimeField,
11+
traits::RootsConfig,
12+
},
613
polynomial::Polynomial,
714
unsigned_integer::element::UnsignedInteger,
815
};
@@ -29,8 +36,8 @@ pub fn fft_benchmarks(c: &mut Criterion) {
2936
group.throughput(criterion::Throughput::Elements(1 << order));
3037

3138
let input = rand_field_elements(order);
32-
let twiddles_bitrev = F::get_twiddles(order, RootsConfig::BitReverse).unwrap();
33-
let twiddles_nat = F::get_twiddles(order, RootsConfig::Natural).unwrap();
39+
let twiddles_bitrev = get_twiddles(order, RootsConfig::BitReverse).unwrap();
40+
let twiddles_nat = get_twiddles(order, RootsConfig::Natural).unwrap();
3441

3542
// the objective is to bench ordered FFT, that's why a bitrev permutation is added.
3643
group.bench_with_input("Sequential from NR radix2", &input, |bench, input| {
@@ -61,7 +68,7 @@ pub fn twiddles_benchmarks(c: &mut Criterion) {
6168

6269
group.bench_with_input("Sequential", &order, |bench, order| {
6370
bench.iter(|| {
64-
F::get_twiddles(*order, RootsConfig::Natural).unwrap();
71+
get_twiddles::<F>(*order, RootsConfig::Natural).unwrap();
6572
});
6673
});
6774
}

fft/benches/benchmarks/mod.rs

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
pub mod fft;
File renamed without changes.
File renamed without changes.

math/src/fft/errors.rs renamed to fft/src/errors.rs

+12-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
use lambdaworks_math::field::errors::FieldError;
12
use thiserror::Error;
23

34
#[derive(Debug, Error)]
@@ -6,7 +7,6 @@ pub enum FFTError {
67
InvalidOrder(String),
78
#[error("Could not calculate {1} root of unity")]
89
RootOfUnityError(String, u64),
9-
1010
#[error("Couldn't find a system default device for Metal")]
1111
MetalDeviceNotFound(),
1212
#[error("Couldn't create a new Metal library: {0}")]
@@ -16,3 +16,14 @@ pub enum FFTError {
1616
#[error("Couldn't create a new Metal compute pipeline: {0}")]
1717
MetalPipelineError(String),
1818
}
19+
20+
impl From<FieldError> for FFTError {
21+
fn from(error: FieldError) -> Self {
22+
match error {
23+
FieldError::DivisionByZero => {
24+
panic!("Can't divide by zero during FFT");
25+
}
26+
FieldError::RootOfUnityError(error, order) => FFTError::RootOfUnityError(error, order),
27+
}
28+
}
29+
}

math/src/fft/fft_iterative.rs renamed to fft/src/fft_iterative.rs

+7-7
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use crate::field::{element::FieldElement, traits::IsTwoAdicField};
1+
use lambdaworks_math::field::{element::FieldElement, traits::IsTwoAdicField};
22

33
/// In-Place Radix-2 NR DIT FFT algorithm over a slice of two-adic field elements.
44
/// It's required that the twiddle factors are in bit-reverse order. Else this function will not
@@ -93,10 +93,10 @@ where
9393

9494
#[cfg(test)]
9595
mod tests {
96-
use crate::fft::helpers::log2;
97-
use crate::fft::test_helpers::naive_matrix_dft_test;
98-
use crate::field::test_fields::u64_test_field::U64TestField;
99-
use crate::{fft::bit_reversing::in_place_bit_reverse_permute, field::traits::RootsConfig};
96+
use crate::helpers::log2;
97+
use crate::test_helpers::naive_matrix_dft_test;
98+
use crate::{bit_reversing::in_place_bit_reverse_permute, roots_of_unity::get_twiddles};
99+
use lambdaworks_math::field::{test_fields::u64_test_field::U64TestField, traits::RootsConfig};
100100
use proptest::{collection, prelude::*};
101101

102102
use super::*;
@@ -127,7 +127,7 @@ mod tests {
127127
let expected = naive_matrix_dft_test(&coeffs);
128128

129129
let order = log2(coeffs.len()).unwrap();
130-
let twiddles = F::get_twiddles(order, RootsConfig::BitReverse).unwrap();
130+
let twiddles = get_twiddles(order, RootsConfig::BitReverse).unwrap();
131131

132132
let mut result = coeffs;
133133
in_place_nr_2radix_fft(&mut result, &twiddles);
@@ -142,7 +142,7 @@ mod tests {
142142
let expected = naive_matrix_dft_test(&coeffs);
143143

144144
let order = log2(coeffs.len()).unwrap();
145-
let twiddles = F::get_twiddles(order, RootsConfig::Natural).unwrap();
145+
let twiddles = get_twiddles(order, RootsConfig::Natural).unwrap();
146146

147147
let mut result = coeffs;
148148
in_place_bit_reverse_permute(&mut result);
File renamed without changes.
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
1-
pub mod abstractions;
21
pub mod bit_reversing;
32
pub mod errors;
43
pub mod fft_iterative;
54
pub(crate) mod helpers;
5+
pub mod ops;
6+
pub mod polynomial;
7+
pub mod roots_of_unity;
68

79
#[cfg(test)]
810
pub(crate) mod test_helpers;

math/src/fft/abstractions.rs renamed to fft/src/ops.rs

+6-4
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
1-
use crate::field::{
1+
use lambdaworks_math::field::{
22
element::FieldElement,
33
traits::{IsTwoAdicField, RootsConfig},
44
};
55

6+
use crate::roots_of_unity::get_twiddles;
7+
68
use super::{
79
bit_reversing::in_place_bit_reverse_permute, errors::FFTError,
810
fft_iterative::in_place_nr_2radix_fft, helpers::log2,
@@ -16,7 +18,7 @@ pub fn fft_with_blowup<F: IsTwoAdicField>(
1618
) -> Result<Vec<FieldElement<F>>, FFTError> {
1719
let domain_size = coeffs.len() * blowup_factor;
1820
let order = log2(domain_size)?;
19-
let twiddles = F::get_twiddles(order, RootsConfig::BitReverse)?;
21+
let twiddles = get_twiddles(order, RootsConfig::BitReverse)?;
2022

2123
let mut results = coeffs.to_vec();
2224
results.resize(domain_size, FieldElement::zero());
@@ -33,7 +35,7 @@ pub fn fft<F: IsTwoAdicField>(
3335
coeffs: &[FieldElement<F>],
3436
) -> Result<Vec<FieldElement<F>>, FFTError> {
3537
let order = log2(coeffs.len())?;
36-
let twiddles = F::get_twiddles(order, RootsConfig::BitReverse)?;
38+
let twiddles = get_twiddles(order, RootsConfig::BitReverse)?;
3739

3840
let mut results = coeffs.to_vec();
3941
in_place_nr_2radix_fft(&mut results, &twiddles);
@@ -48,7 +50,7 @@ pub fn inverse_fft<F: IsTwoAdicField>(
4850
coeffs: &[FieldElement<F>],
4951
) -> Result<Vec<FieldElement<F>>, FFTError> {
5052
let order = log2(coeffs.len())?;
51-
let twiddles = F::get_twiddles(order, RootsConfig::BitReverseInversed)?;
53+
let twiddles = get_twiddles(order, RootsConfig::BitReverseInversed)?;
5254

5355
let mut results = coeffs.to_vec();
5456
in_place_nr_2radix_fft(&mut results, &twiddles);

0 commit comments

Comments
 (0)