Skip to content

Commit 996b95b

Browse files
authored
Merge pull request #16272 from MinaProtocol/volhovm/16112-lagrange-bases-one-by-one-optimisation
Improve lagrange retrieval performance & remove unsafes
2 parents 4f78a48 + f5c89f3 commit 996b95b

12 files changed

+84
-75
lines changed

src/lib/crypto/kimchi_backend/pasta/pallas_based_plonk.ml

+4-3
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,11 @@ end
3434
module R1CS_constraint_system =
3535
Kimchi_pasta_constraint_system.Pallas_constraint_system
3636

37-
let lagrange srs domain_log2 : _ Kimchi_types.poly_comm array =
37+
let lagrange (srs : Kimchi_bindings.Protocol.SRS.Fq.t) domain_log2 :
38+
_ Kimchi_types.poly_comm array =
3839
let domain_size = Int.pow 2 domain_log2 in
39-
Array.init domain_size ~f:(fun i ->
40-
Kimchi_bindings.Protocol.SRS.Fq.lagrange_commitment srs domain_size i )
40+
Kimchi_bindings.Protocol.SRS.Fq.lagrange_commitments_whole_domain srs
41+
domain_size
4142

4243
let with_lagrange f (vk : Verification_key.t) =
4344
f (lagrange vk.srs vk.domain.log_size_of_group) vk

src/lib/crypto/kimchi_backend/pasta/vesta_based_plonk.ml

+2-2
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,8 @@ module R1CS_constraint_system =
3535

3636
let lagrange srs domain_log2 : _ Kimchi_types.poly_comm array =
3737
let domain_size = Int.pow 2 domain_log2 in
38-
Array.init domain_size ~f:(fun i ->
39-
Kimchi_bindings.Protocol.SRS.Fp.lagrange_commitment srs domain_size i )
38+
Kimchi_bindings.Protocol.SRS.Fp.lagrange_commitments_whole_domain srs
39+
domain_size
4040

4141
let with_lagrange f (vk : Verification_key.t) =
4242
f (lagrange vk.srs vk.domain.log_size_of_group) vk

src/lib/crypto/kimchi_bindings/stubs/kimchi_bindings.ml

+12
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,12 @@ module Protocol = struct
107107
-> Pasta_bindings.Fq.t Kimchi_types.or_infinity Kimchi_types.poly_comm
108108
= "caml_fp_srs_lagrange_commitment"
109109

110+
external lagrange_commitments_whole_domain :
111+
t
112+
-> int
113+
-> Pasta_bindings.Fq.t Kimchi_types.or_infinity Kimchi_types.poly_comm
114+
array = "caml_fp_srs_lagrange_commitments_whole_domain"
115+
110116
external add_lagrange_basis : t -> int -> unit
111117
= "caml_fp_srs_add_lagrange_basis"
112118

@@ -156,6 +162,12 @@ module Protocol = struct
156162
-> Pasta_bindings.Fp.t Kimchi_types.or_infinity Kimchi_types.poly_comm
157163
= "caml_fq_srs_lagrange_commitment"
158164

165+
external lagrange_commitments_whole_domain :
166+
t
167+
-> int
168+
-> Pasta_bindings.Fp.t Kimchi_types.or_infinity Kimchi_types.poly_comm
169+
array = "caml_fq_srs_lagrange_commitments_whole_domain"
170+
159171
external add_lagrange_basis : t -> int -> unit
160172
= "caml_fq_srs_add_lagrange_basis"
161173

src/lib/crypto/kimchi_bindings/stubs/src/lagrange_basis.rs

+5-5
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,11 @@ use poly_commitment::{commitment::CommitmentCurve, srs::SRS};
66
use std::env;
77

88
pub trait WithLagrangeBasis<G: AffineRepr> {
9-
fn with_lagrange_basis(&mut self, domain: D<G::ScalarField>);
9+
fn with_lagrange_basis(&self, domain: D<G::ScalarField>);
1010
}
1111

1212
impl WithLagrangeBasis<Vesta> for SRS<Vesta> {
13-
fn with_lagrange_basis(&mut self, domain: D<<Vesta as AffineRepr>::ScalarField>) {
13+
fn with_lagrange_basis(&self, domain: D<<Vesta as AffineRepr>::ScalarField>) {
1414
match env::var("LAGRANGE_CACHE_DIR") {
1515
Ok(_) => add_lagrange_basis_with_cache(self, domain, cache::get_vesta_file_cache()),
1616
Err(_) => {
@@ -21,7 +21,7 @@ impl WithLagrangeBasis<Vesta> for SRS<Vesta> {
2121
}
2222

2323
impl WithLagrangeBasis<Pallas> for SRS<Pallas> {
24-
fn with_lagrange_basis(&mut self, domain: D<<Pallas as AffineRepr>::ScalarField>) {
24+
fn with_lagrange_basis(&self, domain: D<<Pallas as AffineRepr>::ScalarField>) {
2525
match env::var("LAGRANGE_CACHE_DIR") {
2626
Ok(_) => add_lagrange_basis_with_cache(self, domain, cache::get_pallas_file_cache()),
2727
Err(_) => {
@@ -32,7 +32,7 @@ impl WithLagrangeBasis<Pallas> for SRS<Pallas> {
3232
}
3333

3434
fn add_lagrange_basis_with_cache<G: CommitmentCurve, C: LagrangeCache<G>>(
35-
srs: &mut SRS<G>,
35+
srs: &SRS<G>,
3636
domain: D<G::ScalarField>,
3737
cache: &C,
3838
) {
@@ -41,7 +41,7 @@ fn add_lagrange_basis_with_cache<G: CommitmentCurve, C: LagrangeCache<G>>(
4141
return;
4242
}
4343
if let Some(basis) = cache.load_lagrange_basis_from_cache(srs.g.len(), &domain) {
44-
srs.lagrange_bases.get_or_generate(n, || { basis });
44+
srs.lagrange_bases.get_or_generate(n, || basis);
4545
return;
4646
} else {
4747
let basis = srs.get_lagrange_basis(domain);

src/lib/crypto/kimchi_bindings/stubs/src/main.rs

+2
Original file line numberDiff line numberDiff line change
@@ -363,6 +363,7 @@ fn generate_kimchi_bindings(mut w: impl std::io::Write, env: &mut Env) {
363363
decl_func!(w, env, caml_fp_srs_write => "write");
364364
decl_func!(w, env, caml_fp_srs_read => "read");
365365
decl_func!(w, env, caml_fp_srs_lagrange_commitment => "lagrange_commitment");
366+
decl_func!(w, env, caml_fp_srs_lagrange_commitments_whole_domain => "lagrange_commitments_whole_domain");
366367
decl_func!(w, env, caml_fp_srs_add_lagrange_basis=> "add_lagrange_basis");
367368
decl_func!(w, env, caml_fp_srs_commit_evaluations => "commit_evaluations");
368369
decl_func!(w, env, caml_fp_srs_b_poly_commitment => "b_poly_commitment");
@@ -378,6 +379,7 @@ fn generate_kimchi_bindings(mut w: impl std::io::Write, env: &mut Env) {
378379
decl_func!(w, env, caml_fq_srs_write => "write");
379380
decl_func!(w, env, caml_fq_srs_read => "read");
380381
decl_func!(w, env, caml_fq_srs_lagrange_commitment => "lagrange_commitment");
382+
decl_func!(w, env, caml_fq_srs_lagrange_commitments_whole_domain => "lagrange_commitments_whole_domain");
381383
decl_func!(w, env, caml_fq_srs_add_lagrange_basis=> "add_lagrange_basis");
382384
decl_func!(w, env, caml_fq_srs_commit_evaluations => "commit_evaluations");
383385
decl_func!(w, env, caml_fq_srs_b_poly_commitment => "b_poly_commitment");

src/lib/crypto/kimchi_bindings/stubs/src/pasta_fp_plonk_index.rs

+1-6
Original file line numberDiff line numberDiff line change
@@ -87,12 +87,7 @@ pub fn caml_pasta_fp_plonk_index_create(
8787
// endo
8888
let (endo_q, _endo_r) = poly_commitment::srs::endos::<Pallas>();
8989

90-
// Unsafe if we are in a multi-core ocaml
91-
{
92-
let ptr: &mut poly_commitment::srs::SRS<Vesta> =
93-
unsafe { &mut *(std::sync::Arc::as_ptr(&srs.0) as *mut _) };
94-
ptr.with_lagrange_basis(cs.domain.d1);
95-
}
90+
srs.0.with_lagrange_basis(cs.domain.d1);
9691

9792
// create index
9893
let mut index = ProverIndex::<Vesta, OpeningProof<Vesta>>::create(cs, endo_q, srs.clone());

src/lib/crypto/kimchi_bindings/stubs/src/pasta_fp_plonk_proof.rs

+25-27
Original file line numberDiff line numberDiff line change
@@ -46,10 +46,13 @@ pub fn caml_pasta_fp_plonk_proof_create(
4646
prev_sgs: Vec<CamlGVesta>,
4747
) -> Result<CamlProofWithPublic<CamlGVesta, CamlFp>, ocaml::Error> {
4848
{
49-
let ptr: &mut poly_commitment::srs::SRS<Vesta> =
50-
unsafe { &mut *(std::sync::Arc::as_ptr(&index.as_ref().0.srs) as *mut _) };
51-
ptr.with_lagrange_basis(index.as_ref().0.cs.domain.d1);
49+
index
50+
.as_ref()
51+
.0
52+
.srs
53+
.with_lagrange_basis(index.as_ref().0.cs.domain.d1);
5254
}
55+
5356
let prev = if prev_challenges.is_empty() {
5457
Vec::new()
5558
} else {
@@ -112,9 +115,11 @@ pub fn caml_pasta_fp_plonk_proof_create_and_verify(
112115
prev_sgs: Vec<CamlGVesta>,
113116
) -> Result<CamlProofWithPublic<CamlGVesta, CamlFp>, ocaml::Error> {
114117
{
115-
let ptr: &mut poly_commitment::srs::SRS<Vesta> =
116-
unsafe { &mut *(std::sync::Arc::as_ptr(&index.as_ref().0.srs) as *mut _) };
117-
ptr.with_lagrange_basis(index.as_ref().0.cs.domain.d1);
118+
index
119+
.as_ref()
120+
.0
121+
.srs
122+
.with_lagrange_basis(index.as_ref().0.cs.domain.d1);
118123
}
119124
let prev = if prev_challenges.is_empty() {
120125
Vec::new()
@@ -199,7 +204,7 @@ pub fn caml_pasta_fp_plonk_proof_example_with_lookup(
199204
polynomial::COLUMNS,
200205
wires::Wire,
201206
};
202-
use poly_commitment::srs::{endos, SRS};
207+
use poly_commitment::srs::endos;
203208

204209
let num_gates = 1000;
205210
let num_tables: usize = 5;
@@ -276,8 +281,7 @@ pub fn caml_pasta_fp_plonk_proof_example_with_lookup(
276281
.build()
277282
.unwrap();
278283

279-
let ptr: &mut SRS<Vesta> = unsafe { &mut *(std::sync::Arc::as_ptr(&srs.0) as *mut _) };
280-
ptr.with_lagrange_basis(cs.domain.d1);
284+
srs.0.with_lagrange_basis(cs.domain.d1);
281285

282286
let (endo_q, _endo_r) = endos::<Pallas>();
283287
let index = ProverIndex::<Vesta, OpeningProof<Vesta>>::create(cs, endo_q, srs.0);
@@ -321,7 +325,7 @@ pub fn caml_pasta_fp_plonk_proof_example_with_foreign_field_mul(
321325
use num_bigint::BigUint;
322326
use num_bigint::RandBigInt;
323327
use o1_utils::{foreign_field::BigUintForeignFieldHelpers, FieldHelpers};
324-
use poly_commitment::srs::{endos, SRS};
328+
use poly_commitment::srs::endos;
325329
use rand::{rngs::StdRng, SeedableRng};
326330

327331
let foreign_field_modulus = Fq::modulus_biguint();
@@ -441,8 +445,7 @@ pub fn caml_pasta_fp_plonk_proof_example_with_foreign_field_mul(
441445
// Create constraint system
442446
let cs = ConstraintSystem::<Fp>::create(gates).build().unwrap();
443447

444-
let ptr: &mut SRS<Vesta> = unsafe { &mut *(std::sync::Arc::as_ptr(&srs.0) as *mut _) };
445-
ptr.with_lagrange_basis(cs.domain.d1);
448+
srs.0.with_lagrange_basis(cs.domain.d1);
446449

447450
let (endo_q, _endo_r) = endos::<Pallas>();
448451
let index = ProverIndex::<Vesta, OpeningProof<Vesta>>::create(cs, endo_q, srs.0);
@@ -478,7 +481,7 @@ pub fn caml_pasta_fp_plonk_proof_example_with_range_check(
478481
use num_bigint::BigUint;
479482
use num_bigint::RandBigInt;
480483
use o1_utils::{foreign_field::BigUintForeignFieldHelpers, BigUintFieldHelpers};
481-
use poly_commitment::srs::{endos, SRS};
484+
use poly_commitment::srs::endos;
482485
use rand::{rngs::StdRng, SeedableRng};
483486

484487
let rng = &mut StdRng::from_seed([255u8; 32]);
@@ -508,8 +511,7 @@ pub fn caml_pasta_fp_plonk_proof_example_with_range_check(
508511
// Create constraint system
509512
let cs = ConstraintSystem::<Fp>::create(gates).build().unwrap();
510513

511-
let ptr: &mut SRS<Vesta> = unsafe { &mut *(std::sync::Arc::as_ptr(&srs.0) as *mut _) };
512-
ptr.with_lagrange_basis(cs.domain.d1);
514+
srs.0.with_lagrange_basis(cs.domain.d1);
513515

514516
let (endo_q, _endo_r) = endos::<Pallas>();
515517
let index = ProverIndex::<Vesta, OpeningProof<Vesta>>::create(cs, endo_q, srs.0);
@@ -546,7 +548,7 @@ pub fn caml_pasta_fp_plonk_proof_example_with_range_check0(
546548
polynomials::{generic::GenericGateSpec, range_check},
547549
wires::Wire,
548550
};
549-
use poly_commitment::srs::{endos, SRS};
551+
use poly_commitment::srs::endos;
550552

551553
let gates = {
552554
// Public input row with value 0
@@ -581,8 +583,7 @@ pub fn caml_pasta_fp_plonk_proof_example_with_range_check0(
581583
// not sure if theres a smarter way instead of the double unwrap, but should be fine in the test
582584
let cs = ConstraintSystem::<Fp>::create(gates).build().unwrap();
583585

584-
let ptr: &mut SRS<Vesta> = unsafe { &mut *(std::sync::Arc::as_ptr(&srs.0) as *mut _) };
585-
ptr.with_lagrange_basis(cs.domain.d1);
586+
srs.0.with_lagrange_basis(cs.domain.d1);
586587

587588
let (endo_q, _endo_r) = endos::<Pallas>();
588589
let index = ProverIndex::<Vesta, OpeningProof<Vesta>>::create(cs, endo_q, srs.0);
@@ -625,7 +626,7 @@ pub fn caml_pasta_fp_plonk_proof_example_with_ffadd(
625626
wires::Wire,
626627
};
627628
use num_bigint::BigUint;
628-
use poly_commitment::srs::{endos, SRS};
629+
use poly_commitment::srs::endos;
629630

630631
// Includes a row to store value 1
631632
let num_public_inputs = 1;
@@ -706,8 +707,7 @@ pub fn caml_pasta_fp_plonk_proof_example_with_ffadd(
706707
.build()
707708
.unwrap();
708709

709-
let ptr: &mut SRS<Vesta> = unsafe { &mut *(std::sync::Arc::as_ptr(&srs.0) as *mut _) };
710-
ptr.with_lagrange_basis(cs.domain.d1);
710+
srs.0.with_lagrange_basis(cs.domain.d1);
711711

712712
let (endo_q, _endo_r) = endos::<Pallas>();
713713
let index = ProverIndex::<Vesta, OpeningProof<Vesta>>::create(cs, endo_q, srs.0);
@@ -747,7 +747,7 @@ pub fn caml_pasta_fp_plonk_proof_example_with_xor(
747747
polynomials::{generic::GenericGateSpec, xor},
748748
wires::Wire,
749749
};
750-
use poly_commitment::srs::{endos, SRS};
750+
use poly_commitment::srs::endos;
751751

752752
let num_public_inputs = 2;
753753

@@ -795,8 +795,7 @@ pub fn caml_pasta_fp_plonk_proof_example_with_xor(
795795
.build()
796796
.unwrap();
797797

798-
let ptr: &mut SRS<Vesta> = unsafe { &mut *(std::sync::Arc::as_ptr(&srs.0) as *mut _) };
799-
ptr.with_lagrange_basis(cs.domain.d1);
798+
srs.0.with_lagrange_basis(cs.domain.d1);
800799

801800
let (endo_q, _endo_r) = endos::<Pallas>();
802801
let index = ProverIndex::<Vesta, OpeningProof<Vesta>>::create(cs, endo_q, srs.0);
@@ -839,7 +838,7 @@ pub fn caml_pasta_fp_plonk_proof_example_with_rot(
839838
},
840839
wires::Wire,
841840
};
842-
use poly_commitment::srs::{endos, SRS};
841+
use poly_commitment::srs::endos;
843842

844843
// Includes the actual input of the rotation and a row with the zero value
845844
let num_public_inputs = 2;
@@ -889,8 +888,7 @@ pub fn caml_pasta_fp_plonk_proof_example_with_rot(
889888
.build()
890889
.unwrap();
891890

892-
let ptr: &mut SRS<Vesta> = unsafe { &mut *(std::sync::Arc::as_ptr(&srs.0) as *mut _) };
893-
ptr.with_lagrange_basis(cs.domain.d1);
891+
srs.0.with_lagrange_basis(cs.domain.d1);
894892

895893
let (endo_q, _endo_r) = endos::<Pallas>();
896894
let index = ProverIndex::<Vesta, OpeningProof<Vesta>>::create(cs, endo_q, srs.0);

src/lib/crypto/kimchi_bindings/stubs/src/pasta_fp_plonk_verifier_index.rs

+5-5
Original file line numberDiff line numberDiff line change
@@ -221,11 +221,11 @@ pub fn caml_pasta_fp_plonk_verifier_index_write(
221221
pub fn caml_pasta_fp_plonk_verifier_index_create(
222222
index: CamlPastaFpPlonkIndexPtr,
223223
) -> CamlPastaFpPlonkVerifierIndex {
224-
{
225-
let ptr: &mut poly_commitment::srs::SRS<Vesta> =
226-
unsafe { &mut *(std::sync::Arc::as_ptr(&index.as_ref().0.srs) as *mut _) };
227-
ptr.with_lagrange_basis(index.as_ref().0.cs.domain.d1);
228-
}
224+
index
225+
.as_ref()
226+
.0
227+
.srs
228+
.with_lagrange_basis(index.as_ref().0.cs.domain.d1);
229229
let verifier_index = index.as_ref().0.verifier_index();
230230
verifier_index.into()
231231
}

src/lib/crypto/kimchi_bindings/stubs/src/pasta_fq_plonk_index.rs

+1-6
Original file line numberDiff line numberDiff line change
@@ -86,12 +86,7 @@ pub fn caml_pasta_fq_plonk_index_create(
8686
// endo
8787
let (endo_q, _endo_r) = poly_commitment::srs::endos::<Vesta>();
8888

89-
// Unsafe if we are in a multi-core ocaml
90-
{
91-
let ptr: &mut poly_commitment::srs::SRS<Pallas> =
92-
unsafe { &mut *(std::sync::Arc::as_ptr(&srs.0) as *mut _) };
93-
ptr.with_lagrange_basis(cs.domain.d1);
94-
}
89+
srs.0.with_lagrange_basis(cs.domain.d1);
9590

9691
// create index
9792
let mut index = ProverIndex::<Pallas, OpeningProof<Pallas>>::create(cs, endo_q, srs.clone());

src/lib/crypto/kimchi_bindings/stubs/src/pasta_fq_plonk_proof.rs

+5-3
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,11 @@ pub fn caml_pasta_fq_plonk_proof_create(
4141
prev_sgs: Vec<CamlGPallas>,
4242
) -> Result<CamlProofWithPublic<CamlGPallas, CamlFq>, ocaml::Error> {
4343
{
44-
let ptr: &mut poly_commitment::srs::SRS<Pallas> =
45-
unsafe { &mut *(std::sync::Arc::as_ptr(&index.as_ref().0.srs) as *mut _) };
46-
ptr.with_lagrange_basis(index.as_ref().0.cs.domain.d1);
44+
index
45+
.as_ref()
46+
.0
47+
.srs
48+
.with_lagrange_basis(index.as_ref().0.cs.domain.d1);
4749
}
4850
let prev = if prev_challenges.is_empty() {
4951
Vec::new()

src/lib/crypto/kimchi_bindings/stubs/src/pasta_fq_plonk_verifier_index.rs

+5-5
Original file line numberDiff line numberDiff line change
@@ -220,11 +220,11 @@ pub fn caml_pasta_fq_plonk_verifier_index_write(
220220
pub fn caml_pasta_fq_plonk_verifier_index_create(
221221
index: CamlPastaFqPlonkIndexPtr,
222222
) -> CamlPastaFqPlonkVerifierIndex {
223-
{
224-
let ptr: &mut poly_commitment::srs::SRS<Pallas> =
225-
unsafe { &mut *(std::sync::Arc::as_ptr(&index.as_ref().0.srs) as *mut _) };
226-
ptr.with_lagrange_basis(index.as_ref().0.cs.domain.d1);
227-
}
223+
index
224+
.as_ref()
225+
.0
226+
.srs
227+
.with_lagrange_basis(index.as_ref().0.cs.domain.d1);
228228
let verifier_index = index.as_ref().0.verifier_index();
229229
verifier_index.into()
230230
}

0 commit comments

Comments
 (0)