Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve lagrange retrieval performance & remove unsafes #16272

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions src/lib/crypto/kimchi_backend/pasta/pallas_based_plonk.ml
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,11 @@ end
module R1CS_constraint_system =
Kimchi_pasta_constraint_system.Pallas_constraint_system

let lagrange srs domain_log2 : _ Kimchi_types.poly_comm array =
let lagrange (srs : Kimchi_bindings.Protocol.SRS.Fq.t) domain_log2 :
_ Kimchi_types.poly_comm array =
let domain_size = Int.pow 2 domain_log2 in
Array.init domain_size ~f:(fun i ->
Kimchi_bindings.Protocol.SRS.Fq.lagrange_commitment srs domain_size i )
Kimchi_bindings.Protocol.SRS.Fq.lagrange_commitments_whole_domain srs
domain_size

let with_lagrange f (vk : Verification_key.t) =
f (lagrange vk.srs vk.domain.log_size_of_group) vk
Expand Down
4 changes: 2 additions & 2 deletions src/lib/crypto/kimchi_backend/pasta/vesta_based_plonk.ml
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ module R1CS_constraint_system =

let lagrange srs domain_log2 : _ Kimchi_types.poly_comm array =
let domain_size = Int.pow 2 domain_log2 in
Array.init domain_size ~f:(fun i ->
Kimchi_bindings.Protocol.SRS.Fp.lagrange_commitment srs domain_size i )
Kimchi_bindings.Protocol.SRS.Fp.lagrange_commitments_whole_domain srs
domain_size

let with_lagrange f (vk : Verification_key.t) =
f (lagrange vk.srs vk.domain.log_size_of_group) vk
Expand Down
12 changes: 12 additions & 0 deletions src/lib/crypto/kimchi_bindings/stubs/kimchi_bindings.ml
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,12 @@ module Protocol = struct
-> Pasta_bindings.Fq.t Kimchi_types.or_infinity Kimchi_types.poly_comm
= "caml_fp_srs_lagrange_commitment"

external lagrange_commitments_whole_domain :
t
-> int
-> Pasta_bindings.Fq.t Kimchi_types.or_infinity Kimchi_types.poly_comm
array = "caml_fp_srs_lagrange_commitments_whole_domain"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I find it a bit confusing that the one using Pasta_bindings.Fq.t is called caml_fp_srs_lagrange... , but I see this is the format being followed across this file so I suppose that notation is expected.


external add_lagrange_basis : t -> int -> unit
= "caml_fp_srs_add_lagrange_basis"

Expand Down Expand Up @@ -156,6 +162,12 @@ module Protocol = struct
-> Pasta_bindings.Fp.t Kimchi_types.or_infinity Kimchi_types.poly_comm
= "caml_fq_srs_lagrange_commitment"

external lagrange_commitments_whole_domain :
t
-> int
-> Pasta_bindings.Fp.t Kimchi_types.or_infinity Kimchi_types.poly_comm
array = "caml_fq_srs_lagrange_commitments_whole_domain"

external add_lagrange_basis : t -> int -> unit
= "caml_fq_srs_add_lagrange_basis"

Expand Down
10 changes: 5 additions & 5 deletions src/lib/crypto/kimchi_bindings/stubs/src/lagrange_basis.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@ use poly_commitment::{commitment::CommitmentCurve, srs::SRS};
use std::env;

pub trait WithLagrangeBasis<G: AffineRepr> {
fn with_lagrange_basis(&mut self, domain: D<G::ScalarField>);
fn with_lagrange_basis(&self, domain: D<G::ScalarField>);
}

impl WithLagrangeBasis<Vesta> for SRS<Vesta> {
fn with_lagrange_basis(&mut self, domain: D<<Vesta as AffineRepr>::ScalarField>) {
fn with_lagrange_basis(&self, domain: D<<Vesta as AffineRepr>::ScalarField>) {
match env::var("LAGRANGE_CACHE_DIR") {
Ok(_) => add_lagrange_basis_with_cache(self, domain, cache::get_vesta_file_cache()),
Err(_) => {
Expand All @@ -21,7 +21,7 @@ impl WithLagrangeBasis<Vesta> for SRS<Vesta> {
}

impl WithLagrangeBasis<Pallas> for SRS<Pallas> {
fn with_lagrange_basis(&mut self, domain: D<<Pallas as AffineRepr>::ScalarField>) {
fn with_lagrange_basis(&self, domain: D<<Pallas as AffineRepr>::ScalarField>) {
match env::var("LAGRANGE_CACHE_DIR") {
Ok(_) => add_lagrange_basis_with_cache(self, domain, cache::get_pallas_file_cache()),
Err(_) => {
Expand All @@ -32,7 +32,7 @@ impl WithLagrangeBasis<Pallas> for SRS<Pallas> {
}

fn add_lagrange_basis_with_cache<G: CommitmentCurve, C: LagrangeCache<G>>(
srs: &mut SRS<G>,
srs: &SRS<G>,
domain: D<G::ScalarField>,
cache: &C,
) {
Expand All @@ -41,7 +41,7 @@ fn add_lagrange_basis_with_cache<G: CommitmentCurve, C: LagrangeCache<G>>(
return;
}
if let Some(basis) = cache.load_lagrange_basis_from_cache(srs.g.len(), &domain) {
srs.lagrange_bases.get_or_generate(n, || { basis });
srs.lagrange_bases.get_or_generate(n, || basis);
return;
} else {
let basis = srs.get_lagrange_basis(domain);
Expand Down
2 changes: 2 additions & 0 deletions src/lib/crypto/kimchi_bindings/stubs/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,7 @@ fn generate_kimchi_bindings(mut w: impl std::io::Write, env: &mut Env) {
decl_func!(w, env, caml_fp_srs_write => "write");
decl_func!(w, env, caml_fp_srs_read => "read");
decl_func!(w, env, caml_fp_srs_lagrange_commitment => "lagrange_commitment");
decl_func!(w, env, caml_fp_srs_lagrange_commitments_whole_domain => "lagrange_commitments_whole_domain");
decl_func!(w, env, caml_fp_srs_add_lagrange_basis=> "add_lagrange_basis");
decl_func!(w, env, caml_fp_srs_commit_evaluations => "commit_evaluations");
decl_func!(w, env, caml_fp_srs_b_poly_commitment => "b_poly_commitment");
Expand All @@ -378,6 +379,7 @@ fn generate_kimchi_bindings(mut w: impl std::io::Write, env: &mut Env) {
decl_func!(w, env, caml_fq_srs_write => "write");
decl_func!(w, env, caml_fq_srs_read => "read");
decl_func!(w, env, caml_fq_srs_lagrange_commitment => "lagrange_commitment");
decl_func!(w, env, caml_fq_srs_lagrange_commitments_whole_domain => "lagrange_commitments_whole_domain");
decl_func!(w, env, caml_fq_srs_add_lagrange_basis=> "add_lagrange_basis");
decl_func!(w, env, caml_fq_srs_commit_evaluations => "commit_evaluations");
decl_func!(w, env, caml_fq_srs_b_poly_commitment => "b_poly_commitment");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,12 +87,7 @@ pub fn caml_pasta_fp_plonk_index_create(
// endo
let (endo_q, _endo_r) = poly_commitment::srs::endos::<Pallas>();

// Unsafe if we are in a multi-core ocaml
{
let ptr: &mut poly_commitment::srs::SRS<Vesta> =
unsafe { &mut *(std::sync::Arc::as_ptr(&srs.0) as *mut _) };
ptr.with_lagrange_basis(cs.domain.d1);
}
srs.0.with_lagrange_basis(cs.domain.d1);

// create index
let mut index = ProverIndex::<Vesta, OpeningProof<Vesta>>::create(cs, endo_q, srs.clone());
Expand Down
52 changes: 25 additions & 27 deletions src/lib/crypto/kimchi_bindings/stubs/src/pasta_fp_plonk_proof.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,13 @@ pub fn caml_pasta_fp_plonk_proof_create(
prev_sgs: Vec<CamlGVesta>,
) -> Result<CamlProofWithPublic<CamlGVesta, CamlFp>, ocaml::Error> {
{
let ptr: &mut poly_commitment::srs::SRS<Vesta> =
unsafe { &mut *(std::sync::Arc::as_ptr(&index.as_ref().0.srs) as *mut _) };
ptr.with_lagrange_basis(index.as_ref().0.cs.domain.d1);
index
.as_ref()
.0
.srs
.with_lagrange_basis(index.as_ref().0.cs.domain.d1);
}

let prev = if prev_challenges.is_empty() {
Vec::new()
} else {
Expand Down Expand Up @@ -112,9 +115,11 @@ pub fn caml_pasta_fp_plonk_proof_create_and_verify(
prev_sgs: Vec<CamlGVesta>,
) -> Result<CamlProofWithPublic<CamlGVesta, CamlFp>, ocaml::Error> {
{
let ptr: &mut poly_commitment::srs::SRS<Vesta> =
unsafe { &mut *(std::sync::Arc::as_ptr(&index.as_ref().0.srs) as *mut _) };
ptr.with_lagrange_basis(index.as_ref().0.cs.domain.d1);
index
.as_ref()
.0
.srs
.with_lagrange_basis(index.as_ref().0.cs.domain.d1);
}
let prev = if prev_challenges.is_empty() {
Vec::new()
Expand Down Expand Up @@ -199,7 +204,7 @@ pub fn caml_pasta_fp_plonk_proof_example_with_lookup(
polynomial::COLUMNS,
wires::Wire,
};
use poly_commitment::srs::{endos, SRS};
use poly_commitment::srs::endos;

let num_gates = 1000;
let num_tables: usize = 5;
Expand Down Expand Up @@ -276,8 +281,7 @@ pub fn caml_pasta_fp_plonk_proof_example_with_lookup(
.build()
.unwrap();

let ptr: &mut SRS<Vesta> = unsafe { &mut *(std::sync::Arc::as_ptr(&srs.0) as *mut _) };
ptr.with_lagrange_basis(cs.domain.d1);
srs.0.with_lagrange_basis(cs.domain.d1);

let (endo_q, _endo_r) = endos::<Pallas>();
let index = ProverIndex::<Vesta, OpeningProof<Vesta>>::create(cs, endo_q, srs.0);
Expand Down Expand Up @@ -321,7 +325,7 @@ pub fn caml_pasta_fp_plonk_proof_example_with_foreign_field_mul(
use num_bigint::BigUint;
use num_bigint::RandBigInt;
use o1_utils::{foreign_field::BigUintForeignFieldHelpers, FieldHelpers};
use poly_commitment::srs::{endos, SRS};
use poly_commitment::srs::endos;
use rand::{rngs::StdRng, SeedableRng};

let foreign_field_modulus = Fq::modulus_biguint();
Expand Down Expand Up @@ -441,8 +445,7 @@ pub fn caml_pasta_fp_plonk_proof_example_with_foreign_field_mul(
// Create constraint system
let cs = ConstraintSystem::<Fp>::create(gates).build().unwrap();

let ptr: &mut SRS<Vesta> = unsafe { &mut *(std::sync::Arc::as_ptr(&srs.0) as *mut _) };
ptr.with_lagrange_basis(cs.domain.d1);
srs.0.with_lagrange_basis(cs.domain.d1);

let (endo_q, _endo_r) = endos::<Pallas>();
let index = ProverIndex::<Vesta, OpeningProof<Vesta>>::create(cs, endo_q, srs.0);
Expand Down Expand Up @@ -478,7 +481,7 @@ pub fn caml_pasta_fp_plonk_proof_example_with_range_check(
use num_bigint::BigUint;
use num_bigint::RandBigInt;
use o1_utils::{foreign_field::BigUintForeignFieldHelpers, BigUintFieldHelpers};
use poly_commitment::srs::{endos, SRS};
use poly_commitment::srs::endos;
use rand::{rngs::StdRng, SeedableRng};

let rng = &mut StdRng::from_seed([255u8; 32]);
Expand Down Expand Up @@ -508,8 +511,7 @@ pub fn caml_pasta_fp_plonk_proof_example_with_range_check(
// Create constraint system
let cs = ConstraintSystem::<Fp>::create(gates).build().unwrap();

let ptr: &mut SRS<Vesta> = unsafe { &mut *(std::sync::Arc::as_ptr(&srs.0) as *mut _) };
ptr.with_lagrange_basis(cs.domain.d1);
srs.0.with_lagrange_basis(cs.domain.d1);

let (endo_q, _endo_r) = endos::<Pallas>();
let index = ProverIndex::<Vesta, OpeningProof<Vesta>>::create(cs, endo_q, srs.0);
Expand Down Expand Up @@ -546,7 +548,7 @@ pub fn caml_pasta_fp_plonk_proof_example_with_range_check0(
polynomials::{generic::GenericGateSpec, range_check},
wires::Wire,
};
use poly_commitment::srs::{endos, SRS};
use poly_commitment::srs::endos;

let gates = {
// Public input row with value 0
Expand Down Expand Up @@ -581,8 +583,7 @@ pub fn caml_pasta_fp_plonk_proof_example_with_range_check0(
// not sure if theres a smarter way instead of the double unwrap, but should be fine in the test
let cs = ConstraintSystem::<Fp>::create(gates).build().unwrap();

let ptr: &mut SRS<Vesta> = unsafe { &mut *(std::sync::Arc::as_ptr(&srs.0) as *mut _) };
ptr.with_lagrange_basis(cs.domain.d1);
srs.0.with_lagrange_basis(cs.domain.d1);

let (endo_q, _endo_r) = endos::<Pallas>();
let index = ProverIndex::<Vesta, OpeningProof<Vesta>>::create(cs, endo_q, srs.0);
Expand Down Expand Up @@ -625,7 +626,7 @@ pub fn caml_pasta_fp_plonk_proof_example_with_ffadd(
wires::Wire,
};
use num_bigint::BigUint;
use poly_commitment::srs::{endos, SRS};
use poly_commitment::srs::endos;

// Includes a row to store value 1
let num_public_inputs = 1;
Expand Down Expand Up @@ -706,8 +707,7 @@ pub fn caml_pasta_fp_plonk_proof_example_with_ffadd(
.build()
.unwrap();

let ptr: &mut SRS<Vesta> = unsafe { &mut *(std::sync::Arc::as_ptr(&srs.0) as *mut _) };
ptr.with_lagrange_basis(cs.domain.d1);
srs.0.with_lagrange_basis(cs.domain.d1);

let (endo_q, _endo_r) = endos::<Pallas>();
let index = ProverIndex::<Vesta, OpeningProof<Vesta>>::create(cs, endo_q, srs.0);
Expand Down Expand Up @@ -747,7 +747,7 @@ pub fn caml_pasta_fp_plonk_proof_example_with_xor(
polynomials::{generic::GenericGateSpec, xor},
wires::Wire,
};
use poly_commitment::srs::{endos, SRS};
use poly_commitment::srs::endos;

let num_public_inputs = 2;

Expand Down Expand Up @@ -795,8 +795,7 @@ pub fn caml_pasta_fp_plonk_proof_example_with_xor(
.build()
.unwrap();

let ptr: &mut SRS<Vesta> = unsafe { &mut *(std::sync::Arc::as_ptr(&srs.0) as *mut _) };
ptr.with_lagrange_basis(cs.domain.d1);
srs.0.with_lagrange_basis(cs.domain.d1);

let (endo_q, _endo_r) = endos::<Pallas>();
let index = ProverIndex::<Vesta, OpeningProof<Vesta>>::create(cs, endo_q, srs.0);
Expand Down Expand Up @@ -839,7 +838,7 @@ pub fn caml_pasta_fp_plonk_proof_example_with_rot(
},
wires::Wire,
};
use poly_commitment::srs::{endos, SRS};
use poly_commitment::srs::endos;

// Includes the actual input of the rotation and a row with the zero value
let num_public_inputs = 2;
Expand Down Expand Up @@ -889,8 +888,7 @@ pub fn caml_pasta_fp_plonk_proof_example_with_rot(
.build()
.unwrap();

let ptr: &mut SRS<Vesta> = unsafe { &mut *(std::sync::Arc::as_ptr(&srs.0) as *mut _) };
ptr.with_lagrange_basis(cs.domain.d1);
srs.0.with_lagrange_basis(cs.domain.d1);

let (endo_q, _endo_r) = endos::<Pallas>();
let index = ProverIndex::<Vesta, OpeningProof<Vesta>>::create(cs, endo_q, srs.0);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -221,11 +221,11 @@ pub fn caml_pasta_fp_plonk_verifier_index_write(
pub fn caml_pasta_fp_plonk_verifier_index_create(
index: CamlPastaFpPlonkIndexPtr,
) -> CamlPastaFpPlonkVerifierIndex {
{
let ptr: &mut poly_commitment::srs::SRS<Vesta> =
unsafe { &mut *(std::sync::Arc::as_ptr(&index.as_ref().0.srs) as *mut _) };
ptr.with_lagrange_basis(index.as_ref().0.cs.domain.d1);
}
index
.as_ref()
.0
.srs
.with_lagrange_basis(index.as_ref().0.cs.domain.d1);
let verifier_index = index.as_ref().0.verifier_index();
verifier_index.into()
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,12 +86,7 @@ pub fn caml_pasta_fq_plonk_index_create(
// endo
let (endo_q, _endo_r) = poly_commitment::srs::endos::<Vesta>();

// Unsafe if we are in a multi-core ocaml
{
let ptr: &mut poly_commitment::srs::SRS<Pallas> =
unsafe { &mut *(std::sync::Arc::as_ptr(&srs.0) as *mut _) };
ptr.with_lagrange_basis(cs.domain.d1);
}
srs.0.with_lagrange_basis(cs.domain.d1);

// create index
let mut index = ProverIndex::<Pallas, OpeningProof<Pallas>>::create(cs, endo_q, srs.clone());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,11 @@ pub fn caml_pasta_fq_plonk_proof_create(
prev_sgs: Vec<CamlGPallas>,
) -> Result<CamlProofWithPublic<CamlGPallas, CamlFq>, ocaml::Error> {
{
let ptr: &mut poly_commitment::srs::SRS<Pallas> =
unsafe { &mut *(std::sync::Arc::as_ptr(&index.as_ref().0.srs) as *mut _) };
ptr.with_lagrange_basis(index.as_ref().0.cs.domain.d1);
index
.as_ref()
.0
.srs
.with_lagrange_basis(index.as_ref().0.cs.domain.d1);
}
let prev = if prev_challenges.is_empty() {
Vec::new()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -220,11 +220,11 @@ pub fn caml_pasta_fq_plonk_verifier_index_write(
pub fn caml_pasta_fq_plonk_verifier_index_create(
index: CamlPastaFqPlonkIndexPtr,
) -> CamlPastaFqPlonkVerifierIndex {
{
let ptr: &mut poly_commitment::srs::SRS<Pallas> =
unsafe { &mut *(std::sync::Arc::as_ptr(&index.as_ref().0.srs) as *mut _) };
ptr.with_lagrange_basis(index.as_ref().0.cs.domain.d1);
}
index
.as_ref()
.0
.srs
.with_lagrange_basis(index.as_ref().0.cs.domain.d1);
let verifier_index = index.as_ref().0.verifier_index();
verifier_index.into()
}
Expand Down
Loading