diff --git a/src/cuzk/msm.rs b/src/cuzk/msm.rs index 21705f2..dfbd196 100644 --- a/src/cuzk/msm.rs +++ b/src/cuzk/msm.rs @@ -12,39 +12,17 @@ use crate::cuzk::gpu::{ get_adapter, get_device, read_from_gpu, }; use crate::cuzk::shader_manager::ShaderManager; +use crate::cuzk::utils::compute_p; use crate::cuzk::utils::to_biguint_le; use crate::{points_to_bytes, scalars_to_bytes}; use super::utils::bytes_to_field; -use super::utils::calc_bitwidth; use super::utils::{MiscParams, compute_misc_params}; -use ff::Field; - -/// Calculate the number of words in the field characteristic -pub fn calc_num_words(word_size: usize) -> usize { - let p_bit_length = calc_bitwidth(&P); - let mut num_words = p_bit_length / word_size; - while num_words * word_size < p_bit_length { - num_words += 1; - } - num_words -} +use ff::{Field, PrimeField}; /// 13-bit limbs. pub const WORD_SIZE: usize = 13; -/// Field characteristic -pub static P: Lazy = Lazy::new(|| { - BigUint::from_str_radix( - "21888242871839275222246405745257275088696311157297823662689037894645226208583", - 10, - ) - .expect("Invalid modulus") -}); - -/// Miscellaneous parameters -pub static PARAMS: Lazy = Lazy::new(|| compute_misc_params(&P, WORD_SIZE)); - fn pad_scalars(scalars: &[C::Scalar]) -> Vec { let n = scalars.len(); let l = n.next_power_of_two(); @@ -73,6 +51,8 @@ fn pad_points(points: &[C]) -> Vec { * 2022: https://eprint.iacr.org/2022/1321.pdf */ pub async fn compute_msm(points: &[C], scalars: &[C::Scalar]) -> C::Curve { + let p = compute_p::(); + let params = compute_misc_params(&p, WORD_SIZE); let padded_scalars = pad_scalars::(scalars); let padded_points = pad_points::(points); let input_size = padded_scalars.len(); @@ -80,12 +60,12 @@ pub async fn compute_msm(points: &[C], scalars: &[C::Scalar]) -> let num_columns = 1 << chunk_size; let num_rows = input_size.div_ceil(num_columns); let num_subtasks = 256_usize.div_ceil(chunk_size); - let num_words = PARAMS.num_words; + let num_words = params.num_words; let point_bytes = points_to_bytes(&padded_points); let scalar_bytes = scalars_to_bytes(&padded_scalars); - let shader_manager = ShaderManager::new(WORD_SIZE, chunk_size, input_size); + let shader_manager = ShaderManager::new(WORD_SIZE, chunk_size, input_size, ¶ms); let adapter = get_adapter().await; let (device, queue) = get_device(&adapter).await; @@ -350,12 +330,13 @@ pub async fn compute_msm(points: &[C], scalars: &[C::Scalar]) -> device.destroy(); let mut points = vec![]; + let r_inv = params.clone().rinv; let g_points_x = bytemuck::cast_slice::(&data[0]) .chunks(num_words) .map(|x| { let x_biguint_montgomery = to_biguint_le(x, num_words, WORD_SIZE as u32); - let x_biguint = x_biguint_montgomery * &PARAMS.rinv % P.clone(); + let x_biguint = x_biguint_montgomery * &r_inv % p.clone(); bytes_to_field(&x_biguint.to_bytes_le()) }) @@ -364,7 +345,7 @@ pub async fn compute_msm(points: &[C], scalars: &[C::Scalar]) -> .chunks(num_words) .map(|y| { let y_biguint_montgomery = to_biguint_le(y, num_words, WORD_SIZE as u32); - let y_biguint = y_biguint_montgomery * &PARAMS.rinv % P.clone(); + let y_biguint = y_biguint_montgomery * &r_inv % p.clone(); bytes_to_field(&y_biguint.to_bytes_le()) }) @@ -373,7 +354,7 @@ pub async fn compute_msm(points: &[C], scalars: &[C::Scalar]) -> .chunks(num_words) .map(|z| { let z_biguint_montgomery = to_biguint_le(z, num_words, WORD_SIZE as u32); - let z_biguint = z_biguint_montgomery * &PARAMS.rinv % P.clone(); + let z_biguint = z_biguint_montgomery * &r_inv % p.clone(); bytes_to_field(&z_biguint.to_bytes_le()) }) diff --git a/src/cuzk/shader_manager.rs b/src/cuzk/shader_manager.rs index 5201928..32d6905 100644 --- a/src/cuzk/shader_manager.rs +++ b/src/cuzk/shader_manager.rs @@ -43,10 +43,9 @@ pub static TEST_FIELD_SHADER: Lazy = pub static TEST_POINT_SHADER: Lazy = Lazy::new(|| include_str!("wgsl/test/test_point.wgsl").to_string()); -use crate::cuzk::utils::{calc_bitwidth, gen_mu_limbs, gen_one_limbs, gen_p_limbs, gen_rinv_limbs}; +use crate::cuzk::utils::{calc_bitwidth, gen_mu_limbs, gen_one_limbs, gen_p_limbs, gen_rinv_limbs, MiscParams}; use super::{ - msm::{P, PARAMS}, utils::{gen_p_limbs_plus_one, gen_r_limbs, gen_zero_limbs}, }; @@ -71,13 +70,13 @@ pub struct ShaderManager { impl ShaderManager { /// Create a new shader manager - pub fn new(word_size: usize, chunk_size: usize, input_size: usize) -> Self { - let p_bit_length = calc_bitwidth(&P); - let num_words = PARAMS.num_words; - let r = PARAMS.r.clone(); - let rinv = PARAMS.rinv.clone(); - println!("P: {P:?}"); - println!("P limbs: {}", gen_p_limbs(&P, num_words, word_size)); + pub fn new(word_size: usize, chunk_size: usize, input_size: usize, params: &MiscParams) -> Self { + let p_bit_length = calc_bitwidth(¶ms.p); + let num_words = params.num_words; + let r = params.r.clone(); + let rinv = params.rinv.clone(); + println!("P: {:?}", params.p); + println!("P limbs: {}", gen_p_limbs(¶ms.p, num_words, word_size)); println!("W_MASK: {:?}", (1 << word_size) - 1); println!("R limbs: {}", gen_r_limbs(&r, num_words, word_size)); Self { @@ -86,15 +85,15 @@ impl ShaderManager { input_size, num_words, index_shift: 1 << (chunk_size - 1), - p_limbs: gen_p_limbs(&P, num_words, word_size), - p_limbs_plus_one: gen_p_limbs_plus_one(&P, num_words, word_size), + p_limbs: gen_p_limbs(¶ms.p, num_words, word_size), + p_limbs_plus_one: gen_p_limbs_plus_one(¶ms.p, num_words, word_size), zero_limbs: gen_zero_limbs(num_words), one_limbs: gen_one_limbs(num_words), slack: num_words * word_size - p_bit_length, w_mask: (1 << word_size) - 1, - n0: PARAMS.n0, + n0: params.n0, r_limbs: gen_r_limbs(&r, num_words, word_size), - mu_limbs: gen_mu_limbs(&P, num_words, word_size), + mu_limbs: gen_mu_limbs(¶ms.p, num_words, word_size), rinv_limbs: gen_rinv_limbs(&rinv, num_words, word_size), } } diff --git a/src/cuzk/test/utils.rs b/src/cuzk/test/utils.rs index cf899c4..e6afa4b 100644 --- a/src/cuzk/test/utils.rs +++ b/src/cuzk/test/utils.rs @@ -1,6 +1,6 @@ use ff::PrimeField; use group::{prime::PrimeCurveAffine, Group}; -use halo2curves::bn256::{Fr, G1, G1Affine}; +use halo2curves::CurveAffine; use crate::cuzk::utils::to_words_le_from_field; @@ -30,15 +30,15 @@ pub fn get_element(arr: &[i32], id: i32) -> i32 { } } -pub fn get_point_element(arr: &[G1Affine], id: i32) -> G1Affine { +pub fn get_point_element(arr: &[C], id: i32) -> C { if id < 0 { if (arr.len() as i32 + id) < 0 { - return G1Affine::identity(); + return C::identity(); } arr[arr.len() + id as usize] } else { if id >= arr.len() as i32 { - return G1Affine::identity(); + return C::identity(); } arr[id as usize] } @@ -137,18 +137,22 @@ pub fn decompose_scalars_signed( signed_slices[i] = limbs[i] as i32 + carry; if signed_slices[i] >= l / 2 { signed_slices[i] = -(l - signed_slices[i]); - if signed_slices[i] == -0 { - signed_slices[i] = 0; - } + // if signed_slices[i] == 0 { + // signed_slices[i] = 0; + // } carry = 1; } else { carry = 0; } } - // We do not need to handle the case where the final carry equals 1, as the highest word of the field modulus (0x12ab) is smaller than 2^{16-1} if carry == 1 { - panic!("final carry is 1"); + // TODO: Review this + // panic!("final carry is 1"); + println!("Carrying 1"); + println!("Scalar: {:?}", scalar); + println!("Limbs: {:?}", limbs); + signed_slices.push(carry); } as_limbs.push(signed_slices.iter().map(|x| x + shift).collect()); } @@ -163,19 +167,19 @@ pub fn decompose_scalars_signed( /** * Perform SMVP with signed bucket indices */ -pub fn cpu_smvp_signed( +pub fn cpu_smvp_signed( subtask_idx: usize, input_size: usize, num_columns: usize, chunk_size: usize, all_csc_col_ptr: &[i32], all_csc_val_idxs: &[i32], - points: &[G1Affine], -) -> Vec { + points: &[C], +) -> Vec { let l = 1 << chunk_size; let h = l / 2; - let zero = G1::identity(); - let mut buckets: Vec = vec![zero; num_columns / 2]; + let zero = C::Curve::identity(); + let mut buckets: Vec = vec![zero; num_columns / 2]; let rp_offset = subtask_idx * (num_columns + 1); @@ -197,7 +201,7 @@ pub fn cpu_smvp_signed( let idx = subtask_idx as i32 * input_size as i32 + k; let val = get_element(all_csc_val_idxs, idx); let point = get_point_element(points, val); - sum += G1::from(point); + sum += C::Curve::from(point); } let bucket_idx; @@ -219,23 +223,23 @@ pub fn cpu_smvp_signed( } /// Serial bucket reduction -pub fn serial_bucket_reduction(buckets: &[G1]) -> G1 { +pub fn serial_bucket_reduction(buckets: &[C::Curve]) -> C::Curve { let mut indices = vec![]; for i in 1..buckets.len() { indices.push(i); } indices.push(0); - let mut bucket_sum = G1::identity(); + let mut bucket_sum = C::Curve::identity(); for i in 1..buckets.len() + 1 { - let b = buckets[indices[i - 1]] * Fr::from(i as u64); + let b = buckets[indices[i - 1]] * C::Scalar::from(i as u64); bucket_sum += b; } bucket_sum } /// Perform running sum in the classic fashion - one siumulated thread only -pub fn running_sum_bucket_reduction(buckets: &[G1]) -> G1 { +pub fn running_sum_bucket_reduction(buckets: &[C::Curve]) -> C::Curve { let n = buckets.len(); let mut m = buckets[0]; let mut g = m; @@ -252,9 +256,9 @@ pub fn running_sum_bucket_reduction(buckets: &[G1]) -> G1 { /// Perform running sum with simulated parallelism. It is up to the caller /// to add the resulting points. -pub fn parallel_bucket_reduction(buckets: &[G1], num_threads: usize) -> Vec { +pub fn parallel_bucket_reduction(buckets: &[C::Curve], num_threads: usize) -> Vec { let buckets_per_thread = buckets.len() / num_threads; - let mut bucket_sums: Vec = vec![]; + let mut bucket_sums: Vec = vec![]; for thread_id in 0..num_threads { let idx = if thread_id == 0 { @@ -275,7 +279,7 @@ pub fn parallel_bucket_reduction(buckets: &[G1], num_threads: usize) -> Vec let s = buckets_per_thread * (num_threads - thread_id - 1); if s > 0 { - g += m * Fr::from(s as u64); + g += m * C::Scalar::from(s as u64); } bucket_sums.push(g); @@ -284,13 +288,13 @@ pub fn parallel_bucket_reduction(buckets: &[G1], num_threads: usize) -> Vec } /// The first part of the parallel bucket reduction algo -pub fn parallel_bucket_reduction_1( - buckets: &[G1], +pub fn parallel_bucket_reduction_1( + buckets: &[C::Curve], num_threads: usize, -) -> (Vec, Vec) { +) -> (Vec, Vec) { let buckets_per_thread = buckets.len() / num_threads; - let mut g_points: Vec = vec![]; - let mut m_points: Vec = vec![]; + let mut g_points: Vec = vec![]; + let mut m_points: Vec = vec![]; for thread_id in 0..num_threads { let idx = if thread_id == 0 { @@ -316,21 +320,21 @@ pub fn parallel_bucket_reduction_1( } /// The second part of the parallel bucket reduction algo -pub fn parallel_bucket_reduction_2( - g_points: Vec, - m_points: Vec, +pub fn parallel_bucket_reduction_2( + g_points: Vec, + m_points: Vec, num_buckets: usize, num_threads: usize, -) -> Vec { +) -> Vec { let buckets_per_thread = num_buckets / num_threads; - let mut result: Vec = vec![]; + let mut result: Vec = vec![]; for thread_id in 0..num_threads { let mut g = g_points[thread_id]; let m = m_points[thread_id]; let s = buckets_per_thread * (num_threads - thread_id - 1); if s > 0 { - g += m * Fr::from(s as u64); + g += m * C::Scalar::from(s as u64); } result.push(g); } diff --git a/src/cuzk/utils.rs b/src/cuzk/utils.rs index f5e582b..9375796 100644 --- a/src/cuzk/utils.rs +++ b/src/cuzk/utils.rs @@ -1,11 +1,22 @@ -use crate::cuzk::msm::{P, calc_num_words}; use ff::{Field, PrimeField}; use halo2curves::CurveAffine; use num_bigint::{BigInt, BigUint, Sign}; -use num_traits::One; +use num_traits::{Num, One}; #[cfg(target_arch = "wasm32")] use web_sys::console; +pub fn compute_p() -> BigUint { + // Trim 0x prefix + let modulus = C::Base::MODULUS; + let modulus_str = if modulus.starts_with("0x") { + &modulus[2..] + } else { + modulus + }; + + BigUint::from_str_radix(modulus_str, 16).unwrap() +} + /// Convert a field element to bytes pub fn field_to_bytes(value: &F) -> Vec { let s_bytes = value.to_repr(); @@ -127,6 +138,7 @@ pub fn field_to_u8_vec_for_gpu( /// Convert a vector of bytes into a vector of field elements pub fn u8s_to_fields_without_assertion( + p: &BigUint, u8s: &[u8], num_words: usize, word_size: usize, @@ -135,15 +147,16 @@ pub fn u8s_to_fields_without_assertion( let mut result = vec![]; for i in 0..(u8s.len() / num_u8s_per_scalar) { - let p = i * num_u8s_per_scalar; - let s = u8s[p..p + num_u8s_per_scalar].to_vec(); - result.push(u8s_to_field_without_assertion(&s, num_words, word_size)); + let t = i * num_u8s_per_scalar; + let s = u8s[t..t + num_u8s_per_scalar].to_vec(); + result.push(u8s_to_field_without_assertion(p, &s, num_words, word_size)); } result } /// Convert a vector of bytes into a field element pub fn u8s_to_field_without_assertion( + p: &BigUint, u8s: &[u8], num_words: usize, word_size: usize, @@ -153,11 +166,12 @@ pub fn u8s_to_field_without_assertion( for i in (0..a.len()).step_by(2) { limbs.push(a[i]); } - from_words_le_without_assertion(&limbs, num_words, word_size) + from_words_le_without_assertion(p,&limbs, num_words, word_size) } /// Convert u16 limbs into a field element pub fn from_words_le_without_assertion( + p: &BigUint, limbs: &[u16], num_words: usize, word_size: usize, @@ -169,7 +183,7 @@ pub fn from_words_le_without_assertion( let exponent = (num_words - i - 1) * word_size; let limb = limbs[num_words - i - 1]; val += BigUint::from(2u32).pow(exponent as u32) * BigUint::from(limb); - if val == *P { + if val == *p { val = BigUint::ZERO; } } @@ -186,11 +200,18 @@ pub fn points_to_bytes_for_gpu( ) -> Vec { g.iter() .flat_map(|affine| { - let coords = affine.coordinates().unwrap(); - let x = field_to_u8_vec_for_gpu(coords.x(), num_words, word_size); - let y = field_to_u8_vec_for_gpu(coords.y(), num_words, word_size); - let z = field_to_u8_vec_for_gpu(&C::Base::ONE, num_words, word_size); - [x, y, z].concat() + if affine.is_identity().into() { + let x = field_to_u8_vec_for_gpu(&C::Base::ZERO, num_words, word_size); + let y = field_to_u8_vec_for_gpu(&C::Base::ONE, num_words, word_size); + let z = field_to_u8_vec_for_gpu(&C::Base::ZERO, num_words, word_size); + return [x, y, z].concat(); + } else { + let coords = affine.coordinates().unwrap(); + let x = field_to_u8_vec_for_gpu(coords.x(), num_words, word_size); + let y = field_to_u8_vec_for_gpu(coords.y(), num_words, word_size); + let z = field_to_u8_vec_for_gpu(&C::Base::ONE, num_words, word_size); + [x, y, z].concat() + } }) .collect::>() } @@ -348,18 +369,29 @@ pub fn calc_rinv_and_n0(p: &BigUint, r: &BigUint, log_limb_size: u32) -> (BigUin } /// Miscellaneous parameters for the WebGPU shader -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct MiscParams { pub num_words: usize, pub n0: u32, pub r: BigUint, pub rinv: BigUint, + pub p: BigUint, +} + +/// Calculate the number of words in the field characteristic +pub fn calc_num_words(p: &BigUint, word_size: usize) -> usize { + let p_bit_length = calc_bitwidth(p); + let mut num_words = p_bit_length / word_size; + while num_words * word_size < p_bit_length { + num_words += 1; + } + num_words } /// Compute miscellaneous parameters for the WebGPU shader pub fn compute_misc_params(p: &BigUint, word_size: usize) -> MiscParams { assert!(word_size > 0); - let num_words = calc_num_words(word_size); + let num_words = calc_num_words(p, word_size); let r = BigUint::one() << (num_words * word_size); let res = calc_rinv_and_n0(p, &r, word_size as u32); let rinv = res.0; @@ -369,6 +401,7 @@ pub fn compute_misc_params(p: &BigUint, word_size: usize) -> MiscParams { n0, r: r % p, rinv, + p: p.clone(), } } @@ -384,20 +417,22 @@ pub fn debug(s: &str) { #[cfg(test)] mod tests { - use halo2curves::bn256::{Fq, Fr}; + use halo2curves::bn256::{Fq, Fr, Bn256, G1Affine}; + use ff::{Field, PrimeField}; use num_traits::Num; use rand::thread_rng; use super::*; - use crate::cuzk::msm::{PARAMS, WORD_SIZE}; + use crate::cuzk::msm::WORD_SIZE; use crate::sample_scalars; #[test] fn test_to_words_le_from_le_bytes() { + let p = compute_p::(); let val = sample_scalars::(1)[0]; let bytes = field_to_bytes(&val); for word_size in 13..17 { - let num_words = calc_num_words(word_size); + let num_words = calc_num_words(&p, word_size); let v = BigUint::from_bytes_le(&bytes); let limbs = to_words_le(&v, num_words, word_size); @@ -408,16 +443,18 @@ mod tests { #[test] fn test_gen_p_limbs() { - let p = P.clone(); - let num_words = calc_num_words(13); - let p_limbs = gen_p_limbs(&p, num_words, 13); + let p = compute_p::(); + let num_words = calc_num_words(&p, WORD_SIZE); + let p_limbs = gen_p_limbs(&p, num_words, WORD_SIZE); println!("{}", p_limbs); } #[test] fn test_gen_r_limbs() { - let r = PARAMS.r.clone(); - let num_words = calc_num_words(WORD_SIZE); + let p = compute_p::(); + let params = compute_misc_params(&p, WORD_SIZE); + let r = params.r.clone(); + let num_words = calc_num_words(&p, WORD_SIZE); let r_limbs = gen_r_limbs(&r, num_words, WORD_SIZE); println!("{}", r_limbs); } @@ -425,12 +462,13 @@ mod tests { #[test] fn test_field_to_u8_vec_for_gpu() { // random + let p = compute_p::(); let mut rng = thread_rng(); let a = Fq::random(&mut rng); for word_size in 13..17 { - let num_words = calc_num_words(word_size); + let num_words = calc_num_words(&p, word_size); let bytes = field_to_u8_vec_for_gpu(&a, num_words, word_size); - let a_from_bytes = u8s_to_field_without_assertion(&bytes, num_words, word_size); + let a_from_bytes = u8s_to_field_without_assertion(&p, &bytes, num_words, word_size); assert_eq!(a, a_from_bytes); } } diff --git a/src/cuzk/wgsl/curve/ec.template.wgsl b/src/cuzk/wgsl/curve/ec.template.wgsl index 004eff5..bc56742 100644 --- a/src/cuzk/wgsl/curve/ec.template.wgsl +++ b/src/cuzk/wgsl/curve/ec.template.wgsl @@ -60,7 +60,7 @@ fn point_add(p: Point, q: Point) -> Point { if (field_eq(S1, S2)) { return point_double(p); } else { - return POINT_IDENTITY; + return get_paf(); } } @@ -86,7 +86,7 @@ fn point_add(p: Point, q: Point) -> Point { } fn scalar_mul(p: Point, k: BigInt) -> Point { - var r: Point = POINT_IDENTITY; + var r: Point = get_paf(); var t: Point = p; for (var i = 0u; i < NUM_WORDS; i = i + 1u) { var k_s = k.limbs[i]; @@ -101,12 +101,15 @@ fn scalar_mul(p: Point, k: BigInt) -> Point { return r; } -/// Point negation only involves multiplying the X and T coordinates by -1 in +/// Point negation only involves multiplying the Y coordinate by -1 in /// the field. fn negate_point(point: Point) -> Point { var p = get_p(); var y = point.y; - var neg_y: BigInt; + if (field_eq(y, ZERO)) { + return point; + } + var neg_y = ZERO; bigint_sub(&p, &y, &neg_y); return Point(point.x, neg_y, point.z); } @@ -114,16 +117,18 @@ fn negate_point(point: Point) -> Point { fn get_paf() -> Point { var result: Point; + let r = get_r(); result.x = ZERO; - result.y = ONE; + result.y = r; result.z = ZERO; return result; } + /// This double-and-add code is adapted from the ZPrize test harness: /// https://github.com/demox-labs/webgpu-msm/blob/main/src/reference/webgpu/wgsl/Curve.ts#L78. fn double_and_add(point: Point, scalar: u32) -> Point { /// Set result to the point at infinity. - var result: Point = POINT_IDENTITY; // get_paf(); + var result: Point = get_paf(); var s = scalar; var temp = point; diff --git a/src/cuzk/wgsl/cuzk/smvp.template.wgsl b/src/cuzk/wgsl/cuzk/smvp.template.wgsl index 59b9891..177bf8c 100644 --- a/src/cuzk/wgsl/cuzk/smvp.template.wgsl +++ b/src/cuzk/wgsl/cuzk/smvp.template.wgsl @@ -48,7 +48,8 @@ fn main(@builtin(global_invocation_id) global_id: vec3) { /// Define custom subtask_idx. let subtask_idx = (id / h); - var inf = POINT_IDENTITY; + var inf = get_paf(); + var z = get_r(); let rp_offset = (subtask_idx + subtask_offset) * (num_columns + 1u); @@ -72,7 +73,7 @@ fn main(@builtin(global_invocation_id) global_id: vec3) { var x = new_point_x[idx]; var y = new_point_y[idx]; - var z = get_r(); + let pt = Point(x, y, z); sum = point_add(sum, pt); @@ -86,7 +87,8 @@ fn main(@builtin(global_invocation_id) global_id: vec3) { var bucket_idx = 0u; if (h > row_idx) { bucket_idx = h - row_idx; - sum = negate_point(sum); + let neg_sum = negate_point(sum); + sum = neg_sum; } else { bucket_idx = row_idx - h; } @@ -112,6 +114,4 @@ fn main(@builtin(global_invocation_id) global_id: vec3) { bucket_z[bi] = sum.z; } } - - {{{ recompile }}} } diff --git a/src/cuzk/wgsl/test/test_point.wgsl b/src/cuzk/wgsl/test/test_point.wgsl index 41fab5e..bb8e5ec 100644 --- a/src/cuzk/wgsl/test/test_point.wgsl +++ b/src/cuzk/wgsl/test/test_point.wgsl @@ -83,6 +83,6 @@ fn test_point_add_identity( var ar_y = field_mul(&ay, &r); var ar_z = field_mul(&az, &r); var p_a = Point(ar_x, ar_y, ar_z); - var p_b = POINT_IDENTITY; + var p_b = get_paf(); result = point_add(p_a, p_b); } \ No newline at end of file diff --git a/src/lib.rs b/src/lib.rs index 0837e67..b4f25c7 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -15,6 +15,7 @@ use js_sys::Array; use crate::cuzk::utils::field_to_bytes; use wasm_bindgen::prelude::*; use wasm_bindgen_futures; + /// Sample random scalars pub fn sample_scalars(n: usize) -> Vec { let mut rng = thread_rng(); @@ -71,7 +72,7 @@ pub async fn run_webgpu_msm( } #[wasm_bindgen] -pub async fn run_webgpu_msm_web( +pub async fn run_webgpu_msm_web_bn256( sample_size: usize, _callback: js_sys::Function, ) -> Array { @@ -98,7 +99,35 @@ pub async fn run_webgpu_msm_web( } #[wasm_bindgen] -pub async fn run_cpu_msm_web( +pub async fn run_webgpu_msm_web_pallas( + sample_size: usize, + _callback: js_sys::Function, +) -> Array { + use halo2curves::pasta::pallas::{Affine as PallasAffine, Scalar as PallasScalar}; + let start = now(); + debug(&format!("Testing with sample size: {sample_size}")); + let points = sample_points::(sample_size); + let scalars = sample_scalars::(sample_size); + debug(&format!("Sampling points and scalars took {} ms", now() - start)); + + let start = now(); + let result = compute_msm(&points, &scalars).await; + let msm_elapsed = now() - start; + debug(&format!("GPU MSM Elapsed: {} ms", msm_elapsed)); + let coords = result.to_affine().coordinates().unwrap(); + + let x_str = format!("{:?}", coords.x()); + let y_str = format!("{:?}", coords.y()); + + let arr = Array::new(); + arr.push(&JsValue::from(x_str)); + arr.push(&JsValue::from(y_str)); + arr.push(&JsValue::from(msm_elapsed)); + arr +} + +#[wasm_bindgen] +pub async fn run_cpu_msm_web_bn256( sample_size: usize, _callback: js_sys::Function, ) -> Array { @@ -130,37 +159,46 @@ pub mod tests_wasm_pack { use super::*; use halo2curves::bn256::{Fr, G1Affine}; - - + use halo2curves::pasta::pallas::{Affine as PallasAffine, Scalar as PallasScalar}; + use halo2curves::secp256k1::{Secp256k1Affine, Fq as Secp256k1Fq}; + use halo2curves::secq256k1::{Secq256k1Affine, Fq as Secq256k1Fq}; #[wasm_bindgen] extern "C" { #[wasm_bindgen(js_namespace = performance)] fn now() -> f64; } - pub async fn test_webgpu_msm_cuzk(sample_size: usize) { + pub async fn test_webgpu_msm_cuzk(sample_size: usize) { debug(&format!("Testing with sample size: {sample_size}")); - let points = sample_points::(sample_size); - let scalars = sample_scalars::(sample_size); + let points = sample_points::(sample_size); + let scalars = sample_scalars::(sample_size); let cpu_start = now(); let fast = cpu_msm(&points, &scalars); debug(&format!("CPU Elapsed: {} ms", now() - cpu_start)); let result_start = now(); - let result = run_webgpu_msm::(&points, &scalars).await; + let result = run_webgpu_msm::(&points, &scalars).await; debug(&format!("GPU Elapsed: {} ms", now() - result_start)); debug(&format!("Result: {result:?}")); assert_eq!(fast, result); } - #[test] - fn test_webgpu_msm_cuzk_cpu() { - let input_size = 65537; - let scalars = sample_scalars::(input_size); - let points = sample_points::(input_size); + pub async fn test_webgpu_msm_cuzk_bn256(sample_size: usize) { + test_webgpu_msm_cuzk::(sample_size).await; + } + + pub async fn test_webgpu_msm_cuzk_pallas(sample_size: usize) { + test_webgpu_msm_cuzk::(sample_size).await; + } + + pub async fn test_webgpu_msm_cuzk_secp256k1(sample_size: usize) { + test_webgpu_msm_cuzk::(sample_size).await; + } - let result = pollster::block_on(run_webgpu_msm::(&points, &scalars)); + pub async fn test_webgpu_msm_cuzk_secq256k1(sample_size: usize) { + test_webgpu_msm_cuzk::(sample_size).await; } + } diff --git a/tests/cuzk.rs b/tests/cuzk.rs index e01a7aa..78c42a6 100644 --- a/tests/cuzk.rs +++ b/tests/cuzk.rs @@ -2,19 +2,20 @@ #[cfg(test)] mod tests { use halo2curves::bn256::{Fr, G1Affine, G1}; + use halo2curves::pasta::pallas::{Affine as PallasAffine, Scalar as PallasScalar, Point as PallasPoint}; + use halo2curves::secp256k1::{Secp256k1Affine, Fq as Secp256k1Fq, Secp256k1}; + use halo2curves::secq256k1::Secq256k1Affine; + use halo2curves::CurveAffine; use msm_webgpu::cuzk::test::utils::*; use msm_webgpu::{cpu_msm, sample_points, sample_scalars}; use group::{Curve, Group}; - use rand::Rng; - #[test] - fn test_cuzk() { + fn test_cuzk() { - // let input_size = rand::thread_rng().gen_range(1 << 16..1 << 20); - let input_size: usize = (1 << 16) + 4; + let input_size: usize = 1 << 8; let next_power_of_two = input_size.next_power_of_two(); - let scalars = sample_scalars::(input_size); - let points = sample_points::(input_size); + let scalars = sample_scalars::(input_size); + let points = sample_points::(input_size); let input_size = next_power_of_two; @@ -49,11 +50,11 @@ mod tests { &points, ); - let buckets_sum_serial = serial_bucket_reduction(&buckets); - let buckets_sum_rs = running_sum_bucket_reduction(&buckets); + let buckets_sum_serial = serial_bucket_reduction::(&buckets); + let buckets_sum_rs = running_sum_bucket_reduction::(&buckets); - let mut bucket_sum = G1::identity(); - for b in parallel_bucket_reduction(&buckets, 4) { + let mut bucket_sum = C::Curve::identity(); + for b in parallel_bucket_reduction::(&buckets, 4) { bucket_sum = bucket_sum + b; } @@ -63,11 +64,11 @@ mod tests { bucket_sums.push(bucket_sum); let num_buckets = buckets.len(); - let (g_points, m_points) = parallel_bucket_reduction_1(&buckets, 4); + let (g_points, m_points) = parallel_bucket_reduction_1::(&buckets, 4); - let p_result = parallel_bucket_reduction_2(g_points, m_points, num_buckets, 4); + let p_result = parallel_bucket_reduction_2::(g_points, m_points, num_buckets, 4); - let mut bucket_sum_2 = G1::identity(); + let mut bucket_sum_2 = C::Curve::identity(); for b in p_result { bucket_sum_2 = bucket_sum_2 + b; } @@ -81,7 +82,7 @@ mod tests { let m = 1 << chunk_size; let mut result = bucket_sums[bucket_sums.len() - 1]; for i in (0..bucket_sums.len() - 1).rev() { - result = result * Fr::from(m as u64); + result = result * C::Scalar::from(m as u64); result = result + bucket_sums[i]; } @@ -90,7 +91,28 @@ mod tests { let expected = cpu_msm(&points, &scalars); let expected_affine = expected.to_affine(); - assert_eq!(result_affine.x, expected_affine.x); - assert_eq!(result_affine.y, expected_affine.y); + assert_eq!(result_affine, expected_affine); + } + + + #[test] + fn test_cuzk_bn256() { + test_cuzk::(); + } + + #[test] + fn test_cuzk_pallas() { + test_cuzk::(); } + + #[test] + fn test_cuzk_secp256k1() { + test_cuzk::(); + } + + #[test] + fn test_cuzk_secq256k1() { + test_cuzk::(); + } + } diff --git a/tests/decompose_shader.rs b/tests/decompose_shader.rs index 2f17215..64ce433 100644 --- a/tests/decompose_shader.rs +++ b/tests/decompose_shader.rs @@ -5,9 +5,9 @@ use wgpu::CommandEncoderDescriptor; use msm_webgpu::cuzk::{ gpu::{get_adapter, get_device, read_from_gpu_test}, - msm::{P, PARAMS, WORD_SIZE, convert_point_coords_and_decompose_shaders}, + msm::{convert_point_coords_and_decompose_shaders, WORD_SIZE}, shader_manager::ShaderManager, - utils::{bytes_to_field, debug, to_biguint_le}, + utils::{bytes_to_field, compute_misc_params, compute_p, debug, to_biguint_le}, }; use msm_webgpu::{points_to_bytes, scalars_to_bytes}; @@ -15,12 +15,14 @@ async fn decompose_shader( points: &[C], scalars: &[C::Scalar], ) -> (Vec, Vec) { + let p = compute_p::(); + let params = compute_misc_params(&p, WORD_SIZE); let input_size = scalars.len(); let chunk_size = if input_size >= 65536 { 16 } else { 4 }; let num_columns = 1 << chunk_size; let num_rows = input_size.div_ceil(num_columns); let num_subtasks = 256_usize.div_ceil(chunk_size); - let num_words = PARAMS.num_words; + let num_words = params.num_words; debug(&format!("Input size: {input_size}")); debug(&format!("Chunk size: {chunk_size}")); debug(&format!("Num columns: {num_columns}")); @@ -28,12 +30,12 @@ async fn decompose_shader( debug(&format!("Num subtasks: {num_subtasks}")); debug(&format!("Num words: {num_words}")); debug(&format!("Word size: {WORD_SIZE}")); - debug(&format!("Params: {PARAMS:?}")); + debug(&format!("Params: {params:?}")); let point_bytes = points_to_bytes(points); let scalar_bytes = scalars_to_bytes(scalars); - let shader_manager = ShaderManager::new(WORD_SIZE, chunk_size, input_size); + let shader_manager = ShaderManager::new(WORD_SIZE, chunk_size, input_size, ¶ms); let adapter = get_adapter().await; let (device, queue) = get_device(&adapter).await; @@ -109,8 +111,8 @@ async fn decompose_shader( let p_x_biguint_montgomery = to_biguint_le(x, num_words, WORD_SIZE as u32); let p_y_biguint_montgomery = to_biguint_le(y, num_words, WORD_SIZE as u32); - let p_x_biguint = p_x_biguint_montgomery * &PARAMS.rinv % P.clone(); - let p_y_biguint = p_y_biguint_montgomery * &PARAMS.rinv % P.clone(); + let p_x_biguint = p_x_biguint_montgomery * ¶ms.rinv % p.clone(); + let p_y_biguint = p_y_biguint_montgomery * ¶ms.rinv % p.clone(); let p_x_field = bytes_to_field(&p_x_biguint.to_bytes_le()); let p_y_field = bytes_to_field(&p_y_biguint.to_bytes_le()); @@ -143,17 +145,32 @@ pub async fn run_webgpu_decompose_async( #[cfg(test)] mod tests { use super::*; + use halo2curves::secp256k1::Secp256k1Affine; use msm_webgpu::{sample_points, sample_scalars}; use halo2curves::bn256::{Fr, G1Affine}; + use halo2curves::pasta::pallas::{Affine as PallasAffine, Scalar as PallasScalar}; - #[test] - fn test_decompose() { + fn test_decompose() { let input_size = 1 << 16; - let scalars = sample_scalars::(input_size); - let points = sample_points::(input_size); + let scalars = sample_scalars::(input_size); + let points = sample_points::(input_size); - let (result_points, _result_scalars) = run_webgpu_decompose::(&points, &scalars); + let (result_points, _result_scalars) = run_webgpu_decompose::(&points, &scalars); assert_eq!(result_points, points); } + #[test] + fn test_decompose_bn256() { + test_decompose::(); + } + + #[test] + fn test_decompose_pallas() { + test_decompose::(); + } + + #[test] + fn test_decompose_secp256k1() { + test_decompose::(); + } } diff --git a/tests/field.rs b/tests/field.rs index 3758a84..5167776 100644 --- a/tests/field.rs +++ b/tests/field.rs @@ -1,6 +1,8 @@ use std::time::Instant; use ff::PrimeField; +use num_bigint::BigUint; +use num_traits::Num; use wgpu::CommandEncoderDescriptor; use msm_webgpu::cuzk::{ @@ -9,24 +11,26 @@ use msm_webgpu::cuzk::{ create_compute_pipeline, create_storage_buffer, execute_pipeline, get_adapter, get_device, read_from_gpu_test, }, - msm::{PARAMS, WORD_SIZE}, + msm::WORD_SIZE, shader_manager::ShaderManager, - utils::{bytes_to_field, field_to_u8_vec_for_gpu, to_biguint_le}, + utils::{bytes_to_field, compute_misc_params, field_to_u8_vec_for_gpu, to_biguint_le}, }; async fn field_op(op: &str, a: F, b: F) -> F { - let a_bytes = field_to_u8_vec_for_gpu(&a, PARAMS.num_words, WORD_SIZE); - let b_bytes = field_to_u8_vec_for_gpu(&b, PARAMS.num_words, WORD_SIZE); + let p = BigUint::from_str_radix(&F::MODULUS[2..], 16).unwrap(); + let params = compute_misc_params(&p, WORD_SIZE); + let a_bytes = field_to_u8_vec_for_gpu(&a, params.num_words, WORD_SIZE); + let b_bytes = field_to_u8_vec_for_gpu(&b, params.num_words, WORD_SIZE); let input_size = 1; let chunk_size = if input_size >= 65536 { 16 } else { 4 }; - let num_words = PARAMS.num_words; + let num_words = params.num_words; println!("Input size: {input_size}"); println!("Chunk size: {chunk_size}"); println!("Num words: {num_words}"); println!("Word size: {WORD_SIZE}"); - println!("Params: {PARAMS:?}"); + println!("Params: {params:?}"); - let shader_manager = ShaderManager::new(WORD_SIZE, chunk_size, input_size); + let shader_manager = ShaderManager::new(WORD_SIZE, chunk_size, input_size, ¶ms); let adapter = get_adapter().await; let (device, queue) = get_device(&adapter).await; @@ -98,83 +102,131 @@ pub async fn run_webgpu_field_op_async(op: &str, a: F, b: F) -> F #[cfg(test)] mod tests { use msm_webgpu::{ - cuzk::{msm::calc_num_words, utils::u8s_to_field_without_assertion}, + cuzk::utils::{calc_num_words, compute_p, u8s_to_field_without_assertion}, sample_scalars, }; use super::*; use ff::Field; - use halo2curves::bn256::Fq; + use halo2curves::{bn256::{Fq, G1Affine}, CurveAffine}; + use halo2curves::pasta::pallas::{Scalar as PallasFq, Affine as PallasAffine}; use rand::thread_rng; - #[test] - fn test_webgpu_field_add() { - let scalars = sample_scalars::(50); + fn test_webgpu_field_add() { + let scalars = sample_scalars::(50); for scalar in scalars.chunks(2) { let a = scalar[0]; let b = scalar[1]; let fast = a + b; - let result = run_webgpu_field_op::("test_field_add", a, b); + let result = run_webgpu_field_op::("test_field_add", a, b); println!("Result: {:?}", result); assert_eq!(fast, result); } } + #[test] + fn test_webgpu_field_add_bn256() { + test_webgpu_field_add::(); + } #[test] - fn test_webgpu_field_sub() { - let scalars = sample_scalars::(50); + fn test_webgpu_field_add_pallas() { + test_webgpu_field_add::(); + } + + + fn test_webgpu_field_sub() { + let scalars = sample_scalars::(50); for scalar in scalars.chunks(2) { let a = scalar[0]; let b = scalar[1]; let fast = a - b; - let result = run_webgpu_field_op::("test_field_sub", a, b); + let result = run_webgpu_field_op::("test_field_sub", a, b); println!("Result: {:?}", result); assert_eq!(fast, result); } } + #[test] + fn test_webgpu_field_sub_bn256() { + test_webgpu_field_sub::(); + } #[test] - fn test_webgpu_field_mul() { - let mut rng = thread_rng(); - let a = Fq::random(&mut rng); - let b = Fq::random(&mut rng); + fn test_webgpu_field_sub_pallas() { + test_webgpu_field_sub::(); + } - let fast = a * b; - let result = run_webgpu_field_op::("test_field_mul", a, b); + fn test_webgpu_field_mul() { + let scalars = sample_scalars::(50); + for scalar in scalars.chunks(2) { + let a = scalar[0]; + let b = scalar[1]; + + let fast = a * b; + + let result = run_webgpu_field_op::("test_field_mul", a, b); - println!("Result: {:?}", result); - assert_eq!(fast, result); + println!("Result: {:?}", result); + assert_eq!(fast, result); + } + } + #[test] + fn test_webgpu_field_mul_bn256() { + test_webgpu_field_mul::(); } #[test] - fn test_webgpu_field_barret_mul() { - let mut rng = thread_rng(); - let a = Fq::random(&mut rng); - let b = Fq::random(&mut rng); + fn test_webgpu_field_mul_pallas() { + test_webgpu_field_mul::(); + } + + fn test_webgpu_field_barret_mul() { + let scalars = sample_scalars::(50); + for scalar in scalars.chunks(2) { + let a = scalar[0]; + let b = scalar[1]; - let fast = a; - let result = run_webgpu_field_op::("test_barret_mul", a, b); + let fast = a; + let result = run_webgpu_field_op::("test_barret_mul", a, b); - println!("Result: {:?}", result); - assert_eq!(fast, result); + println!("Result: {:?}", result); + assert_eq!(fast, result); + } + } + #[test] + fn test_webgpu_field_barret_mul_bn256() { + test_webgpu_field_barret_mul::(); } #[test] - fn test_field_to_u8_vec_for_gpu() { - // random + fn test_webgpu_field_barret_mul_pallas() { + test_webgpu_field_barret_mul::(); + } + + fn test_field_to_u8_vec_for_gpu() { + let p = compute_p::(); let mut rng = thread_rng(); - let a = Fq::random(&mut rng); + let a = C::Scalar::random(&mut rng); for word_size in 13..17 { - let num_words = calc_num_words(word_size); + let num_words = calc_num_words(&p, word_size); let bytes = field_to_u8_vec_for_gpu(&a, num_words, word_size); - let a_from_bytes = u8s_to_field_without_assertion(&bytes, num_words, word_size); + let a_from_bytes = u8s_to_field_without_assertion(&p, &bytes, num_words, word_size); assert_eq!(a, a_from_bytes); } } + + #[test] + fn test_field_to_u8_vec_for_gpu_bn256() { + test_field_to_u8_vec_for_gpu::(); + } + + #[test] + fn test_field_to_u8_vec_for_gpu_pallas() { + test_field_to_u8_vec_for_gpu::(); + } } diff --git a/tests/point.rs b/tests/point.rs index 8edf914..8b34451 100644 --- a/tests/point.rs +++ b/tests/point.rs @@ -9,25 +9,27 @@ use msm_webgpu::cuzk::{ create_bind_group_layout, create_compute_pipeline, create_storage_buffer, execute_pipeline, get_adapter, get_device, read_from_gpu_test, }, - msm::{P, PARAMS, WORD_SIZE}, + msm::WORD_SIZE, shader_manager::ShaderManager, - utils::{bytes_to_field, points_to_bytes_for_gpu, to_biguint_le}, + utils::{bytes_to_field, compute_misc_params, compute_p, points_to_bytes_for_gpu, to_biguint_le}, }; async fn point_op(op: &str, a: C, b: C, scalar: u32) -> C::Curve { - let a_bytes = points_to_bytes_for_gpu(&[a], PARAMS.num_words, WORD_SIZE); - let b_bytes = points_to_bytes_for_gpu(&[b], PARAMS.num_words, WORD_SIZE); + let p = compute_p::(); + let params = compute_misc_params(&p, WORD_SIZE); + let a_bytes = points_to_bytes_for_gpu(&[a], params.num_words, WORD_SIZE); + let b_bytes = points_to_bytes_for_gpu(&[b], params.num_words, WORD_SIZE); let scalar_bytes = scalar.to_le_bytes(); let input_size = 1; let chunk_size = if input_size >= 65536 { 16 } else { 4 }; - let num_words = PARAMS.num_words; + let num_words = params.num_words; println!("Input size: {input_size}"); println!("Chunk size: {chunk_size}"); println!("Num words: {num_words}"); println!("Word size: {WORD_SIZE}"); - println!("Params: {PARAMS:?}"); + println!("Params: {params:?}"); - let shader_manager = ShaderManager::new(WORD_SIZE, chunk_size, input_size); + let shader_manager = ShaderManager::new(WORD_SIZE, chunk_size, input_size, ¶ms); let adapter = get_adapter().await; let (device, queue) = get_device(&adapter).await; @@ -81,10 +83,10 @@ async fn point_op(op: &str, a: C, b: C, scalar: u32) -> C::Curve println!("Data length: {:?}", data_u32.len()); let results = data_u32 - .chunks(20) + .chunks(num_words) .map(|chunk| { let biguint_montgomery = to_biguint_le(chunk, num_words, WORD_SIZE as u32); - let biguint = biguint_montgomery * &PARAMS.rinv % P.clone(); + let biguint = biguint_montgomery * ¶ms.rinv % p.clone(); let field: <::CurveExt as CurveExt>::Base = bytes_to_field(&biguint.to_bytes_le()); field @@ -117,71 +119,137 @@ pub async fn run_webgpu_point_op_async( #[cfg(test)] mod tests { use super::*; - use group::Curve; + use ff::{Field, PrimeField}; + use group::{Curve, Group}; use group::cofactor::CofactorCurveAffine; use halo2curves::bn256::{Fr, G1Affine}; + use halo2curves::pasta::pallas::{Affine as PallasAffine, Point as PallasPoint, Scalar as PallasScalar, Base as PallasBase}; + use halo2curves::secp256k1::{Fq as Secp256k1Fq, Secp256k1, Secp256k1Affine}; + use msm_webgpu::cuzk::utils::gen_p_limbs; + use num_bigint::BigUint; + use num_traits::Num; use rand::{Rng, thread_rng}; - #[test] - fn test_webgpu_point_add() { + fn test_webgpu_point_add() { let mut rng = thread_rng(); - let a = G1Affine::random(&mut rng); + let a = C::Curve::random(&mut rng).to_affine(); println!("a: {:?}", a); - let b = G1Affine::random(&mut rng); + let b = C::Curve::random(&mut rng).to_affine(); println!("b: {:?}", b); let fast = a + b; - let result = run_webgpu_point_op::("test_point_add", a, b, 0); + let result = run_webgpu_point_op::("test_point_add", a, b, 0); println!("Result: {:?}", result); assert_eq!(fast, result); } #[test] - fn test_webgpu_point_add_identity() { + fn test_webgpu_point_add_bn256() { + test_webgpu_point_add::(); + } + + #[test] + fn test_webgpu_point_add_pallas() { + test_webgpu_point_add::(); + } + + #[test] + fn test_webgpu_point_add_secp256k1() { + test_webgpu_point_add::(); + } + + fn test_webgpu_point_add_identity() { let mut rng = thread_rng(); - let a = G1Affine::random(&mut rng); + let a = C::Curve::random(&mut rng).to_affine(); println!("a: {:?}", a); - let b = G1Affine::identity(); + let b = C::identity(); println!("b: {:?}", b); let fast = a + b; - let result = run_webgpu_point_op::("test_point_add_identity", a, b, 0); + let result = run_webgpu_point_op::("test_point_add_identity", a, b, 0); println!("Result: {:?}", result); assert_eq!(fast, result); } #[test] - fn test_webgpu_point_negate() { + fn test_webgpu_point_add_identity_bn256() { + test_webgpu_point_add_identity::(); + } + + #[test] + fn test_webgpu_point_add_identity_pallas() { + test_webgpu_point_add_identity::(); + } + + #[test] + fn test_webgpu_point_add_identity_secp256k1() { + test_webgpu_point_add_identity::(); + } + + + fn test_webgpu_point_negate() { + + for _ in 0..1000 { let mut rng = thread_rng(); - let a = G1Affine::random(&mut rng); + let a = C::Curve::random(&mut rng).to_affine(); println!("a: {:?}", a); let fast = -a; - let result = run_webgpu_point_op::("test_negate_point", a, a, 0); + let result = run_webgpu_point_op::("test_negate_point", a, a, 0); println!("Result: {:?}", result); assert_eq!(fast, result.to_affine()); + } } #[test] - fn test_webgpu_point_double_and_add() { + fn test_webgpu_point_negate_bn256() { + test_webgpu_point_negate::(); + } + + #[test] + fn test_webgpu_point_negate_pallas() { + test_webgpu_point_negate::(); + } + + #[test] + fn test_webgpu_point_negate_secp256k1() { + test_webgpu_point_negate::(); + } + + fn test_webgpu_point_double_and_add() { let mut rng = thread_rng(); - let a = G1Affine::random(&mut rng); + let a = C::Curve::random(&mut rng).to_affine(); println!("a: {:?}", a); // random u32 let scalar = rng.gen_range(0..u32::MAX); println!("scalar: {:?}", scalar); - let fast = a * Fr::from(scalar as u64); + let fast = a * C::Scalar::from(scalar as u64); - let result = run_webgpu_point_op::("test_double_and_add", a, a, scalar); + let result = run_webgpu_point_op::("test_double_and_add", a, a, scalar); println!("Result: {:?}", result); assert_eq!(fast, result); } + + #[test] + fn test_webgpu_point_double_and_add_bn256() { + test_webgpu_point_double_and_add::(); + } + + #[test] + fn test_webgpu_point_double_and_add_pallas() { + test_webgpu_point_double_and_add::(); + } + + #[test] + fn test_webgpu_point_double_and_add_secp256k1() { + test_webgpu_point_double_and_add::(); + } } diff --git a/tests/smvp_shader.rs b/tests/smvp_shader.rs index c8f1387..3eeabd5 100644 --- a/tests/smvp_shader.rs +++ b/tests/smvp_shader.rs @@ -7,10 +7,10 @@ use wgpu::CommandEncoderDescriptor; use msm_webgpu::cuzk::{ gpu::{create_storage_buffer, get_adapter, get_device, read_from_gpu_test}, msm::{ - P, PARAMS, WORD_SIZE, convert_point_coords_and_decompose_shaders, smvp_gpu, transpose_gpu, + convert_point_coords_and_decompose_shaders, smvp_gpu, transpose_gpu, WORD_SIZE }, shader_manager::ShaderManager, - utils::{bytes_to_field, debug, to_biguint_le}, + utils::{bytes_to_field, compute_misc_params, compute_p, debug, to_biguint_le}, }; use msm_webgpu::{points_to_bytes, scalars_to_bytes}; @@ -18,12 +18,14 @@ async fn smvp_shader( points: &[C], scalars: &[C::Scalar], ) -> Vec { + let p = compute_p::(); + let params = compute_misc_params(&p, WORD_SIZE); let input_size = scalars.len(); let chunk_size = if input_size >= 65536 { 16 } else { 4 }; let num_columns = 1 << chunk_size; let num_rows = input_size.div_ceil(num_columns); let num_subtasks = 256_usize.div_ceil(chunk_size); - let num_words = PARAMS.num_words; + let num_words = params.num_words; debug(&format!("Input size: {input_size}")); debug(&format!("Chunk size: {chunk_size}")); debug(&format!("Num columns: {num_columns}")); @@ -31,12 +33,12 @@ async fn smvp_shader( debug(&format!("Num subtasks: {num_subtasks}")); debug(&format!("Num words: {num_words}")); debug(&format!("Word size: {WORD_SIZE}")); - println!("Params: {PARAMS:?}"); + println!("Params: {params:?}"); let point_bytes = points_to_bytes(points); let scalar_bytes = scalars_to_bytes(scalars); - let shader_manager = ShaderManager::new(WORD_SIZE, chunk_size, input_size); + let shader_manager = ShaderManager::new(WORD_SIZE, chunk_size, input_size, ¶ms); let adapter = get_adapter().await; let (device, queue) = get_device(&adapter).await; @@ -189,7 +191,6 @@ async fn smvp_shader( ); let smvp_shader = shader_manager.gen_smvp_shader(s_workgroup_size, num_columns); - debug(&format!("SMVP shader: {smvp_shader}")); debug(&format!( "s_num_x_workgroups / (num_subtasks / num_subtask_chunk_size): {:?}", s_num_x_workgroups / (num_subtasks / num_subtask_chunk_size) @@ -231,11 +232,11 @@ async fn smvp_shader( // Destroy the GPU device object. device.destroy(); - let p_x = bytemuck::cast_slice::(&data[0]).chunks(20); - let p_y = bytemuck::cast_slice::(&data[1]).chunks(20); - let p_z = bytemuck::cast_slice::(&data[2]).chunks(20); + let p_x = bytemuck::cast_slice::(&data[0]).chunks(num_words); + let p_y = bytemuck::cast_slice::(&data[1]).chunks(num_words); + let p_z = bytemuck::cast_slice::(&data[2]).chunks(num_words); - + println!("Num words: {num_words:?}"); zip(zip(p_x, p_y), p_z) .enumerate() .map(|(i, ((x, y), z))| { @@ -243,13 +244,23 @@ async fn smvp_shader( let p_y_biguint_montgomery = to_biguint_le(y, num_words, WORD_SIZE as u32); let p_z_biguint_montgomery = to_biguint_le(z, num_words, WORD_SIZE as u32); - let p_x_biguint = p_x_biguint_montgomery * &PARAMS.rinv % P.clone(); - let p_y_biguint = p_y_biguint_montgomery * &PARAMS.rinv % P.clone(); - let p_z_biguint = p_z_biguint_montgomery * &PARAMS.rinv % P.clone(); + let p_x_biguint = p_x_biguint_montgomery * ¶ms.rinv % p.clone(); + let p_y_biguint = p_y_biguint_montgomery * ¶ms.rinv % p.clone(); + let p_z_biguint = p_z_biguint_montgomery * ¶ms.rinv % p.clone(); let p_x_field = bytes_to_field(&p_x_biguint.to_bytes_le()); let p_y_field = bytes_to_field(&p_y_biguint.to_bytes_le()); let p_z_field = bytes_to_field(&p_z_biguint.to_bytes_le()); - let p = C::Curve::new_jacobian(p_x_field, p_y_field, p_z_field).unwrap(); + let p_opt = C::Curve::new_jacobian(p_x_field, p_y_field, p_z_field); + let p = if p_opt.is_some().into() { + p_opt.unwrap() + } else { + println!("Index: {i:?}"); + println!("P x: {p_x_field:?}"); + println!("P y: {p_y_field:?}"); + println!("P z: {p_z_field:?}"); + + panic!("Bad point"); + }; if p.is_identity().into() && i < 15 { println!("Index: {i:?}"); println!("P x: {p_x_field:?}"); @@ -288,12 +299,13 @@ mod tests { use super::*; use halo2curves::bn256::{Fr, G1Affine}; + use halo2curves::pasta::pallas::{Affine as PallasAffine, Scalar as PallasScalar}; + use halo2curves::secp256k1::{Secp256k1Affine, Fq as Secp256k1Fq}; - #[test] - fn test_webgpu_smvp_shader() { + fn test_webgpu_smvp_shader() { let input_size = 1 << 16; - let scalars = sample_scalars::(input_size); - let points = sample_points::(input_size); + let scalars = sample_scalars::(input_size); + let points = sample_points::(input_size); let chunk_size = if input_size >= 65536 { 16 } else { 4 }; let num_columns = 1 << chunk_size; @@ -311,7 +323,7 @@ mod tests { input_size, ); - let result_bucket_sums = run_webgpu_smvp_shader::(&points, &scalars); + let result_bucket_sums = run_webgpu_smvp_shader::(&points, &scalars); println!("Result bucket sums length: {:?}", result_bucket_sums.len()); let mut bucket_sums = vec![]; @@ -332,4 +344,19 @@ mod tests { } assert_eq!(result_bucket_sums, bucket_sums); } + + #[test] + fn test_webgpu_smvp_shader_bn256() { + test_webgpu_smvp_shader::(); + } + + #[test] + fn test_webgpu_smvp_shader_pallas() { + test_webgpu_smvp_shader::(); + } + + #[test] + fn test_webgpu_smvp_shader_secp256k1() { + test_webgpu_smvp_shader::(); + } } diff --git a/tests/test_webgpu_msm_cuzk_16.rs b/tests/test_webgpu_msm_cuzk_16_bn256.rs similarity index 54% rename from tests/test_webgpu_msm_cuzk_16.rs rename to tests/test_webgpu_msm_cuzk_16_bn256.rs index c255829..2c9a668 100644 --- a/tests/test_webgpu_msm_cuzk_16.rs +++ b/tests/test_webgpu_msm_cuzk_16_bn256.rs @@ -1,13 +1,13 @@ #[cfg(test)] mod tests_wasm_pack_16 { - use msm_webgpu::tests_wasm_pack::test_webgpu_msm_cuzk; + use msm_webgpu::tests_wasm_pack::test_webgpu_msm_cuzk_bn256; use wasm_bindgen_test::wasm_bindgen_test; use wasm_bindgen_test::*; wasm_bindgen_test_configure!(run_in_browser); #[wasm_bindgen_test] - async fn test_webgpu_msm_cuzk_16() { - test_webgpu_msm_cuzk(1 << 16).await; + async fn test_webgpu_msm_cuzk_16_bn256() { + test_webgpu_msm_cuzk_bn256(1 << 16).await; } } diff --git a/tests/test_webgpu_msm_cuzk_16_pallas.rs b/tests/test_webgpu_msm_cuzk_16_pallas.rs new file mode 100644 index 0000000..903a27f --- /dev/null +++ b/tests/test_webgpu_msm_cuzk_16_pallas.rs @@ -0,0 +1,13 @@ +#[cfg(test)] +mod tests_wasm_pack_16 { + use msm_webgpu::tests_wasm_pack::test_webgpu_msm_cuzk_pallas; + use wasm_bindgen_test::wasm_bindgen_test; + use wasm_bindgen_test::*; + + wasm_bindgen_test_configure!(run_in_browser); + + #[wasm_bindgen_test] + async fn test_webgpu_msm_cuzk_16_pallas() { + test_webgpu_msm_cuzk_pallas(1 << 16).await; + } +} diff --git a/tests/test_webgpu_msm_cuzk_16_secp256k1.rs b/tests/test_webgpu_msm_cuzk_16_secp256k1.rs new file mode 100644 index 0000000..77912bb --- /dev/null +++ b/tests/test_webgpu_msm_cuzk_16_secp256k1.rs @@ -0,0 +1,13 @@ +#[cfg(test)] +mod tests_wasm_pack_16 { + use msm_webgpu::tests_wasm_pack::test_webgpu_msm_cuzk_secp256k1; + use wasm_bindgen_test::wasm_bindgen_test; + use wasm_bindgen_test::*; + + wasm_bindgen_test_configure!(run_in_browser); + + #[wasm_bindgen_test] + async fn test_webgpu_msm_cuzk_16_secp256k1() { + test_webgpu_msm_cuzk_secp256k1(1 << 16).await; + } +} diff --git a/tests/test_webgpu_msm_cuzk_16_secq256k1.rs b/tests/test_webgpu_msm_cuzk_16_secq256k1.rs new file mode 100644 index 0000000..6ab8362 --- /dev/null +++ b/tests/test_webgpu_msm_cuzk_16_secq256k1.rs @@ -0,0 +1,13 @@ +#[cfg(test)] +mod tests_wasm_pack_16 { + use msm_webgpu::tests_wasm_pack::test_webgpu_msm_cuzk_secq256k1; + use wasm_bindgen_test::wasm_bindgen_test; + use wasm_bindgen_test::*; + + wasm_bindgen_test_configure!(run_in_browser); + + #[wasm_bindgen_test] + async fn test_webgpu_msm_cuzk_16_secq256k1() { + test_webgpu_msm_cuzk_secq256k1(1 << 16).await; + } +} diff --git a/tests/test_webgpu_msm_cuzk_17.rs b/tests/test_webgpu_msm_cuzk_17.rs index 8a605cf..8011d65 100644 --- a/tests/test_webgpu_msm_cuzk_17.rs +++ b/tests/test_webgpu_msm_cuzk_17.rs @@ -1,6 +1,6 @@ #[cfg(test)] mod tests_wasm_pack_17 { - use msm_webgpu::tests_wasm_pack::test_webgpu_msm_cuzk; + use msm_webgpu::tests_wasm_pack::test_webgpu_msm_cuzk_bn256; use wasm_bindgen_test::wasm_bindgen_test; use wasm_bindgen_test::*; @@ -8,6 +8,6 @@ mod tests_wasm_pack_17 { #[wasm_bindgen_test] async fn test_webgpu_msm_cuzk_17() { - test_webgpu_msm_cuzk(1 << 17).await; + test_webgpu_msm_cuzk_bn256(1 << 17).await; } } diff --git a/tests/test_webgpu_msm_cuzk_18.rs b/tests/test_webgpu_msm_cuzk_18.rs index 332fd43..bd323f2 100644 --- a/tests/test_webgpu_msm_cuzk_18.rs +++ b/tests/test_webgpu_msm_cuzk_18.rs @@ -1,6 +1,6 @@ #[cfg(test)] mod tests_wasm_pack_18 { - use msm_webgpu::tests_wasm_pack::test_webgpu_msm_cuzk; + use msm_webgpu::tests_wasm_pack::test_webgpu_msm_cuzk_bn256; use wasm_bindgen_test::wasm_bindgen_test; use wasm_bindgen_test::*; @@ -8,6 +8,6 @@ mod tests_wasm_pack_18 { #[wasm_bindgen_test] async fn test_webgpu_msm_cuzk_18() { - test_webgpu_msm_cuzk(1 << 18).await; + test_webgpu_msm_cuzk_bn256(1 << 18).await; } } diff --git a/tests/test_webgpu_msm_cuzk_19.rs b/tests/test_webgpu_msm_cuzk_19.rs index 37fe3f8..180be72 100644 --- a/tests/test_webgpu_msm_cuzk_19.rs +++ b/tests/test_webgpu_msm_cuzk_19.rs @@ -1,6 +1,6 @@ #[cfg(test)] mod tests_wasm_pack_19 { - use msm_webgpu::tests_wasm_pack::test_webgpu_msm_cuzk; + use msm_webgpu::tests_wasm_pack::test_webgpu_msm_cuzk_bn256; use wasm_bindgen_test::wasm_bindgen_test; use wasm_bindgen_test::*; @@ -8,6 +8,6 @@ mod tests_wasm_pack_19 { #[wasm_bindgen_test] async fn test_webgpu_msm_cuzk_19() { - test_webgpu_msm_cuzk(1 << 19).await; + test_webgpu_msm_cuzk_bn256(1 << 19).await; } } diff --git a/tests/test_webgpu_msm_cuzk_20.rs b/tests/test_webgpu_msm_cuzk_20.rs index 1c528b1..f05bb2f 100644 --- a/tests/test_webgpu_msm_cuzk_20.rs +++ b/tests/test_webgpu_msm_cuzk_20.rs @@ -1,6 +1,6 @@ #[cfg(test)] mod tests_wasm_pack_20 { - use msm_webgpu::tests_wasm_pack::test_webgpu_msm_cuzk; + use msm_webgpu::tests_wasm_pack::test_webgpu_msm_cuzk_bn256; use wasm_bindgen_test::wasm_bindgen_test; use wasm_bindgen_test::*; @@ -8,6 +8,6 @@ mod tests_wasm_pack_20 { #[wasm_bindgen_test] async fn test_webgpu_msm_cuzk_20() { - test_webgpu_msm_cuzk(1 << 20).await; + test_webgpu_msm_cuzk_bn256(1 << 20).await; } } diff --git a/tests/test_webgpu_msm_cuzk_random.rs b/tests/test_webgpu_msm_cuzk_random.rs index 6b1b231..bba403b 100644 --- a/tests/test_webgpu_msm_cuzk_random.rs +++ b/tests/test_webgpu_msm_cuzk_random.rs @@ -1,6 +1,6 @@ #[cfg(test)] mod tests_wasm_pack_16 { - use msm_webgpu::tests_wasm_pack::test_webgpu_msm_cuzk; + use msm_webgpu::tests_wasm_pack::test_webgpu_msm_cuzk_bn256; use rand::Rng; use wasm_bindgen_test::wasm_bindgen_test; use wasm_bindgen_test::*; @@ -11,6 +11,6 @@ mod tests_wasm_pack_16 { async fn test_webgpu_msm_cuzk_random() { // Random between 2^16 and 2^20 let sample_size = rand::thread_rng().gen_range(1 << 16..1 << 20); - test_webgpu_msm_cuzk(sample_size).await; + test_webgpu_msm_cuzk_bn256(sample_size).await; } } diff --git a/tests/transpose_shader.rs b/tests/transpose_shader.rs index 127f990..33328f6 100644 --- a/tests/transpose_shader.rs +++ b/tests/transpose_shader.rs @@ -5,9 +5,9 @@ use wgpu::CommandEncoderDescriptor; use msm_webgpu::cuzk::{ gpu::{get_adapter, get_device, read_from_gpu_test}, - msm::{PARAMS, WORD_SIZE, convert_point_coords_and_decompose_shaders, transpose_gpu}, + msm::{convert_point_coords_and_decompose_shaders, transpose_gpu, WORD_SIZE}, shader_manager::ShaderManager, - utils::debug, + utils::{compute_misc_params, compute_p, debug}, }; use msm_webgpu::{points_to_bytes, scalars_to_bytes}; @@ -15,12 +15,14 @@ async fn transpose_shader( points: &[C], scalars: &[C::Scalar], ) -> (Vec, Vec) { + let p = compute_p::(); + let params = compute_misc_params(&p, WORD_SIZE); let input_size = scalars.len(); let chunk_size = if input_size >= 65536 { 16 } else { 4 }; let num_columns = 1 << chunk_size; let num_rows = input_size.div_ceil(num_columns); let num_subtasks = 256_usize.div_ceil(chunk_size); - let num_words = PARAMS.num_words; + let num_words = params.num_words; debug(&format!("Input size: {input_size}")); debug(&format!("Chunk size: {chunk_size}")); debug(&format!("Num columns: {num_columns}")); @@ -28,12 +30,12 @@ async fn transpose_shader( debug(&format!("Num subtasks: {num_subtasks}")); debug(&format!("Num words: {num_words}")); debug(&format!("Word size: {WORD_SIZE}")); - println!("Params: {PARAMS:?}"); + println!("Params: {params:?}"); let point_bytes = points_to_bytes(points); let scalar_bytes = scalars_to_bytes(scalars); - let shader_manager = ShaderManager::new(WORD_SIZE, chunk_size, input_size); + let shader_manager = ShaderManager::new(WORD_SIZE, chunk_size, input_size, ¶ms); let adapter = get_adapter().await; let (device, queue) = get_device(&adapter).await;