diff --git a/Cargo.lock b/Cargo.lock index 41f9f3811..303de99b8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -7,10 +7,10 @@ name = "acir" version = "1.0.0-beta.11" source = "git+https://github.com/noir-lang/noir?rev=v1.0.0-beta.11#fd3925aaaeb76c76319f44590d135498ef41ea6c" dependencies = [ - "acir_field 1.0.0-beta.11", + "acir_field", "base64", "bincode 2.0.1", - "brillig 1.0.0-beta.11", + "brillig", "color-eyre", "flate2", "noir_protobuf", @@ -21,32 +21,11 @@ dependencies = [ "rmp-serde", "serde", "serde-big-array", - "strum 0.24.1", - "strum_macros 0.24.3", + "strum", + "strum_macros", "thiserror 1.0.69", ] -[[package]] -name = "acir" -version = "1.0.0-beta.19" -source = "git+https://github.com/noir-lang/noir.git?branch=master#efd7f97e84c659a92cc8b7c501c578fbb718df37" -dependencies = [ - "acir_field 1.0.0-beta.19", - "base64", - "brillig 1.0.0-beta.19", - "flate2", - "noirc_span", - "num-bigint", - "num-traits", - "num_enum", - "rmp-serde", - "serde", - "serde-big-array", - "strum 0.26.3", - "strum_macros 0.26.4", - "thiserror 2.0.18", -] - [[package]] name = "acir_field" version = "1.0.0-beta.11" @@ -61,28 +40,14 @@ dependencies = [ "serde", ] -[[package]] -name = "acir_field" -version = "1.0.0-beta.19" -source = "git+https://github.com/noir-lang/noir.git?branch=master#efd7f97e84c659a92cc8b7c501c578fbb718df37" -dependencies = [ - "ark-bn254", - "ark-ff 0.5.0", - "ark-std 0.5.0", - "cfg-if", - "hex", - "num-bigint", - "serde", -] - [[package]] name = "acvm" version = "1.0.0-beta.11" source = "git+https://github.com/noir-lang/noir?rev=v1.0.0-beta.11#fd3925aaaeb76c76319f44590d135498ef41ea6c" dependencies = [ - "acir 1.0.0-beta.11", - "acvm_blackbox_solver 1.0.0-beta.11", - "brillig_vm 1.0.0-beta.11", + "acir", + "acvm_blackbox_solver", + "brillig_vm", "fxhash", "indexmap 2.13.0", "serde", @@ -90,57 +55,24 @@ dependencies = [ "tracing", ] -[[package]] -name = "acvm" -version = "1.0.0-beta.19" -source = 
"git+https://github.com/noir-lang/noir.git?branch=master#efd7f97e84c659a92cc8b7c501c578fbb718df37" -dependencies = [ - "acir 1.0.0-beta.19", - "acvm_blackbox_solver 1.0.0-beta.19", - "brillig_vm 1.0.0-beta.19", - "indexmap 2.13.0", - "rustc-hash", - "serde", - "thiserror 2.0.18", - "tracing", -] - [[package]] name = "acvm_blackbox_solver" version = "1.0.0-beta.11" source = "git+https://github.com/noir-lang/noir?rev=v1.0.0-beta.11#fd3925aaaeb76c76319f44590d135498ef41ea6c" dependencies = [ - "acir 1.0.0-beta.11", - "blake2 0.10.6", + "acir", + "blake2", "blake3", - "k256 0.13.4", - "keccak 0.1.6", + "k256", + "keccak", "libaes", "log", "num-bigint", - "p256 0.13.2", - "sha2 0.10.9", + "p256", + "sha2", "thiserror 1.0.69", ] -[[package]] -name = "acvm_blackbox_solver" -version = "1.0.0-beta.19" -source = "git+https://github.com/noir-lang/noir.git?branch=master#efd7f97e84c659a92cc8b7c501c578fbb718df37" -dependencies = [ - "acir 1.0.0-beta.19", - "aes", - "blake2 0.11.0-rc.5", - "blake3", - "cbc", - "k256 0.14.0-rc.7", - "keccak 0.2.0-rc.2", - "log", - "p256 0.14.0-rc.7", - "sha2 0.11.0-rc.5", - "thiserror 2.0.18", -] - [[package]] name = "addr2line" version = "0.25.1" @@ -156,17 +88,6 @@ version = "2.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "320119579fcad9c21884f5c4861d16174d0e06250625266f50fe6898340abefa" -[[package]] -name = "aes" -version = "0.8.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b169f7a6d4742236a0a00c541b845991d0ac43e546831af1249753ab4c3aa3a0" -dependencies = [ - "cfg-if", - "cipher", - "cpufeatures 0.2.17", -] - [[package]] name = "ahash" version = "0.8.12" @@ -326,41 +247,6 @@ dependencies = [ "ark-std 0.5.0", ] -[[package]] -name = "ark-crypto-primitives" -version = "0.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1e0c292754729c8a190e50414fd1a37093c786c709899f29c9f7daccecfa855e" -dependencies = [ - "ahash", - 
"ark-crypto-primitives-macros", - "ark-ec", - "ark-ff 0.5.0", - "ark-relations", - "ark-serialize 0.5.0", - "ark-snark", - "ark-std 0.5.0", - "blake2 0.10.6", - "derivative", - "digest 0.10.7", - "fnv", - "hashbrown 0.14.5", - "merlin", - "rayon", - "sha2 0.10.9", -] - -[[package]] -name = "ark-crypto-primitives-macros" -version = "0.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e7e89fe77d1f0f4fe5b96dfc940923d88d17b6a773808124f21e764dfb063c6a" -dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.117", -] - [[package]] name = "ark-ec" version = "0.5.0" @@ -379,7 +265,6 @@ dependencies = [ "num-bigint", "num-integer", "num-traits", - "rayon", "zeroize", ] @@ -438,7 +323,6 @@ dependencies = [ "num-bigint", "num-traits", "paste", - "rayon", "zeroize", ] @@ -537,18 +421,6 @@ dependencies = [ "hashbrown 0.15.5", ] -[[package]] -name = "ark-relations" -version = "0.5.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ec46ddc93e7af44bcab5230937635b06fb5744464dd6a7e7b083e80ebd274384" -dependencies = [ - "ark-ff 0.5.0", - "ark-std 0.5.0", - "tracing", - "tracing-subscriber 0.2.25", -] - [[package]] name = "ark-serialize" version = "0.3.0" @@ -581,7 +453,6 @@ dependencies = [ "arrayvec", "digest 0.10.7", "num-bigint", - "rayon", ] [[package]] @@ -595,18 +466,6 @@ dependencies = [ "syn 2.0.117", ] -[[package]] -name = "ark-snark" -version = "0.5.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d368e2848c2d4c129ce7679a7d0d2d612b6a274d3ea6a13bad4445d61b381b88" -dependencies = [ - "ark-ff 0.5.0", - "ark-relations", - "ark-serialize 0.5.0", - "ark-std 0.5.0", -] - [[package]] name = "ark-std" version = "0.3.0" @@ -635,7 +494,6 @@ checksum = "246a225cc6131e9ee4f24619af0f19d67761fff15d7ccc22e42b80846e69449a" dependencies = [ "num-traits", "rand 0.8.5", - "rayon", ] [[package]] @@ -653,6 +511,45 @@ dependencies = [ "zeroize", ] +[[package]] +name = "asn1-rs" +version = "0.6.2" 
+source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5493c3bedbacf7fd7382c6346bbd66687d12bbaad3a89a2d2c303ee6cf20b048" +dependencies = [ + "asn1-rs-derive", + "asn1-rs-impl", + "displaydoc", + "nom", + "num-traits", + "rusticata-macros", + "thiserror 1.0.69", + "time", +] + +[[package]] +name = "asn1-rs-derive" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "965c2d33e53cb6b267e148a4cb0760bc01f4904c1cd4bb4002a085bb016d1490" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.117", + "synstructure", +] + +[[package]] +name = "asn1-rs-impl" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7b18050c2cd6fe86c3a76584ef5e0baf286d038cda203eb6223df2cc413565f7" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.117", +] + [[package]] name = "async-lsp" version = "0.2.3" @@ -789,12 +686,6 @@ version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4c7f02d4ea65f2c1853089ffd8d2787bdbc63de2f0d29dedbcf8ccdfa0ccd4cf" -[[package]] -name = "base16ct" -version = "1.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fd307490d624467aa6f74b0eabb77633d1f758a7b25f12bceb0b22e08d9726f6" - [[package]] name = "base64" version = "0.22.1" @@ -891,21 +782,22 @@ dependencies = [ ] [[package]] -name = "blake2" -version = "0.10.6" +name = "bitvec-nom2" +version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "46502ad458c9a52b69d4d4d32775c788b7a1b85e8bc9d482d92250fc0e3f8efe" +checksum = "d988fcc40055ceaa85edc55875a08f8abd29018582647fd82ad6128dba14a5f0" dependencies = [ - "digest 0.10.7", + "bitvec", + "nom", ] [[package]] name = "blake2" -version = "0.11.0-rc.5" +version = "0.10.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d52965399b470437fc7f4d4b51134668dbc96573fea6f1b83318a420e4605745" +checksum = 
"46502ad458c9a52b69d4d4d32775c788b7a1b85e8bc9d482d92250fc0e3f8efe" dependencies = [ - "digest 0.11.1", + "digest 0.10.7", ] [[package]] @@ -919,7 +811,7 @@ dependencies = [ "cc", "cfg-if", "constant_time_eq", - "cpufeatures 0.2.17", + "cpufeatures", "digest 0.10.7", "zeroize", ] @@ -933,24 +825,6 @@ dependencies = [ "generic-array", ] -[[package]] -name = "block-buffer" -version = "0.12.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cdd35008169921d80bc60d3d0ab416eecb028c4cd653352907921d95084790be" -dependencies = [ - "hybrid-array", -] - -[[package]] -name = "block-padding" -version = "0.3.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a8894febbff9f758034a5b8e12d87918f56dfc64a8e1fe757d65e29041538d93" -dependencies = [ - "generic-array", -] - [[package]] name = "bn254-multiplier" version = "0.1.0" @@ -979,8 +853,8 @@ name = "bn254_blackbox_solver" version = "1.0.0-beta.11" source = "git+https://github.com/noir-lang/noir?rev=v1.0.0-beta.11#fd3925aaaeb76c76319f44590d135498ef41ea6c" dependencies = [ - "acir 1.0.0-beta.11", - "acvm_blackbox_solver 1.0.0-beta.11", + "acir", + "acvm_blackbox_solver", "ark-bn254", "ark-ec", "ark-ff 0.5.0", @@ -1004,16 +878,7 @@ name = "brillig" version = "1.0.0-beta.11" source = "git+https://github.com/noir-lang/noir?rev=v1.0.0-beta.11#fd3925aaaeb76c76319f44590d135498ef41ea6c" dependencies = [ - "acir_field 1.0.0-beta.11", - "serde", -] - -[[package]] -name = "brillig" -version = "1.0.0-beta.19" -source = "git+https://github.com/noir-lang/noir.git?branch=master#efd7f97e84c659a92cc8b7c501c578fbb718df37" -dependencies = [ - "acir_field 1.0.0-beta.19", + "acir_field", "serde", ] @@ -1022,25 +887,13 @@ name = "brillig_vm" version = "1.0.0-beta.11" source = "git+https://github.com/noir-lang/noir?rev=v1.0.0-beta.11#fd3925aaaeb76c76319f44590d135498ef41ea6c" dependencies = [ - "acir 1.0.0-beta.11", - "acvm_blackbox_solver 1.0.0-beta.11", + "acir", + "acvm_blackbox_solver", 
"num-bigint", "num-traits", "thiserror 1.0.69", ] -[[package]] -name = "brillig_vm" -version = "1.0.0-beta.19" -source = "git+https://github.com/noir-lang/noir.git?branch=master#efd7f97e84c659a92cc8b7c501c578fbb718df37" -dependencies = [ - "acir 1.0.0-beta.19", - "acvm_blackbox_solver 1.0.0-beta.19", - "num-bigint", - "num-traits", - "thiserror 2.0.18", -] - [[package]] name = "bstr" version = "1.12.1" @@ -1092,15 +945,6 @@ version = "1.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1e748733b7cbc798e1434b6ac524f0c1ff2ab456fe201501e6497c8417a4fc33" -[[package]] -name = "cbc" -version = "0.1.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "26b52a9543ae338f279b96b0b9fed9c8093744685043739079ce85cd58f289a6" -dependencies = [ - "cipher", -] - [[package]] name = "cc" version = "1.2.56" @@ -1172,16 +1016,6 @@ dependencies = [ "half", ] -[[package]] -name = "cipher" -version = "0.4.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "773f3b9af64447d2ce9850330c473515014aa235e6a783b02db81ff39e4a3dad" -dependencies = [ - "crypto-common 0.1.7", - "inout", -] - [[package]] name = "clap" version = "4.5.60" @@ -1243,12 +1077,6 @@ dependencies = [ "winapi", ] -[[package]] -name = "cmov" -version = "0.5.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "de0758edba32d61d1fd9f4d69491b47604b91ee2f7e6b33de7e54ca4ebe55dc3" - [[package]] name = "cobs" version = "0.3.0" @@ -1357,12 +1185,6 @@ version = "0.9.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c2459377285ad874054d797f3ccebf984978aa39129f6eafde5cdc8315b612f8" -[[package]] -name = "const-oid" -version = "0.10.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a6ef517f0926dd24a1582492c791b6a4818a4d94e789a334894aa15b0d12f55c" - [[package]] name = "const_format" version = "0.2.35" @@ -1383,6 +1205,15 @@ dependencies = [ "unicode-xid", ] 
+[[package]] +name = "const_panic" +version = "0.2.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e262cdaac42494e3ae34c43969f9cdeb7da178bdb4b66fa6a1ea2edb4c8ae652" +dependencies = [ + "typewit", +] + [[package]] name = "constant_time_eq" version = "0.4.2" @@ -1424,12 +1255,6 @@ version = "0.8.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b" -[[package]] -name = "cpubits" -version = "0.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5ef0c543070d296ea414df2dd7625d1b24866ce206709d8a4a424f28377f5861" - [[package]] name = "cpufeatures" version = "0.2.17" @@ -1439,15 +1264,6 @@ dependencies = [ "libc", ] -[[package]] -name = "cpufeatures" -version = "0.3.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8b2a41393f66f16b0823bb79094d54ac5fbd34ab292ddafb9a0456ac9f87d201" -dependencies = [ - "libc", -] - [[package]] name = "crc32fast" version = "1.5.0" @@ -1515,22 +1331,6 @@ dependencies = [ "zeroize", ] -[[package]] -name = "crypto-bigint" -version = "0.7.0-rc.28" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "96dacf199529fb801ae62a9aafdc01b189e9504c0d1ee1512a4c16bcd8666a93" -dependencies = [ - "cpubits", - "ctutils", - "getrandom 0.4.2", - "hybrid-array", - "num-traits", - "rand_core 0.10.0", - "subtle", - "zeroize", -] - [[package]] name = "crypto-common" version = "0.1.7" @@ -1541,17 +1341,6 @@ dependencies = [ "typenum", ] -[[package]] -name = "crypto-common" -version = "0.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "77727bb15fa921304124b128af125e7e3b968275d1b108b379190264f4423710" -dependencies = [ - "getrandom 0.4.2", - "hybrid-array", - "rand_core 0.10.0", -] - [[package]] name = "csv" version = "1.4.0" @@ -1573,16 +1362,6 @@ dependencies = [ "memchr", ] -[[package]] -name = "ctutils" -version = 
"0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1005a6d4446f5120ef475ad3d2af2b30c49c2c9c6904258e3bb30219bebed5e4" -dependencies = [ - "cmov", - "subtle", -] - [[package]] name = "dap" version = "0.4.1-alpha1" @@ -1629,26 +1408,35 @@ dependencies = [ "syn 2.0.117", ] +[[package]] +name = "data-encoding" +version = "2.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d7a1e2f27636f116493b8b860f5546edb47c8d8f8ea73e1d2a20be88e28d1fea" + [[package]] name = "der" version = "0.7.10" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e7c1832837b905bbfb5101e07cc24c8deddf52f93225eee6ead5f4d63d53ddcb" dependencies = [ - "const-oid 0.9.6", - "pem-rfc7468 0.7.0", + "const-oid", + "pem-rfc7468", "zeroize", ] [[package]] -name = "der" -version = "0.8.0" +name = "der-parser" +version = "9.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "71fd89660b2dc699704064e59e9dba0147b903e85319429e131620d022be411b" +checksum = "5cd0a5c643689626bec213c4d8bd4d96acc8ffdb4ad4bb6bc16abf27d5f4b553" dependencies = [ - "const-oid 0.10.2", - "pem-rfc7468 1.0.0", - "zeroize", + "asn1-rs", + "displaydoc", + "nom", + "num-bigint", + "num-traits", + "rusticata-macros", ] [[package]] @@ -1698,24 +1486,12 @@ version = "0.10.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292" dependencies = [ - "block-buffer 0.10.4", - "const-oid 0.9.6", - "crypto-common 0.1.7", + "block-buffer", + "const-oid", + "crypto-common", "subtle", ] -[[package]] -name = "digest" -version = "0.11.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "285743a676ccb6b3e116bc14cc69319b957867930ae9c4822f8e0f54509d7243" -dependencies = [ - "block-buffer 0.12.0", - "const-oid 0.10.2", - "crypto-common 0.2.1", - "ctutils", -] - [[package]] name = "dirs" version = "4.0.0" @@ -1793,6 +1569,12 @@ 
dependencies = [ "syn 2.0.117", ] +[[package]] +name = "doc-comment" +version = "0.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "780955b8b195a21ab8e4ac6b60dd1dbdcec1dc6c51c0617964b08c81785e12c9" + [[package]] name = "dyn-clone" version = "1.0.20" @@ -1820,27 +1602,12 @@ version = "0.16.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ee27f32b5c5292967d2d4a9d7f1e0b0aed2c15daded5a60300e4abb9d8020bca" dependencies = [ - "der 0.7.10", + "der", "digest 0.10.7", - "elliptic-curve 0.13.8", - "rfc6979 0.4.0", - "signature 2.2.0", - "spki 0.7.3", -] - -[[package]] -name = "ecdsa" -version = "0.17.0-rc.16" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "91bbdd377139884fafcad8dc43a760a3e1e681aa26db910257fa6535b70e1829" -dependencies = [ - "der 0.8.0", - "digest 0.11.1", - "elliptic-curve 0.14.0-rc.28", - "rfc6979 0.5.0-rc.5", - "signature 3.0.0-rc.10", - "spki 0.8.0-rc.4", - "zeroize", + "elliptic-curve", + "rfc6979", + "signature", + "spki", ] [[package]] @@ -1867,38 +1634,16 @@ version = "0.13.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b5e6043086bf7973472e0c7dff2142ea0b680d30e18d9cc40f267efbf222bd47" dependencies = [ - "base16ct 0.2.0", - "crypto-bigint 0.5.5", + "base16ct", + "crypto-bigint", "digest 0.10.7", "ff", "generic-array", "group", - "pem-rfc7468 0.7.0", - "pkcs8 0.10.2", + "pem-rfc7468", + "pkcs8", "rand_core 0.6.4", - "sec1 0.7.3", - "subtle", - "zeroize", -] - -[[package]] -name = "elliptic-curve" -version = "0.14.0-rc.28" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bde7860544606d222fd6bd6d9f9a0773321bf78072a637e1d560a058c0031978" -dependencies = [ - "base16ct 1.0.0", - "crypto-bigint 0.7.0-rc.28", - "crypto-common 0.2.1", - "digest 0.11.1", - "hybrid-array", - "once_cell", - "pem-rfc7468 1.0.0", - "pkcs8 0.11.0-rc.11", - "rand_core 0.10.0", - "rustcrypto-ff", - "rustcrypto-group", - "sec1 
0.8.0-rc.13", + "sec1", "subtle", "zeroize", ] @@ -2130,7 +1875,7 @@ version = "1.0.0-beta.11" source = "git+https://github.com/noir-lang/noir?rev=v1.0.0-beta.11#fd3925aaaeb76c76319f44590d135498ef41ea6c" dependencies = [ "codespan-reporting", - "iter-extended 1.0.0-beta.11", + "iter-extended", "serde", ] @@ -2421,15 +2166,6 @@ version = "0.12.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888" -[[package]] -name = "hashbrown" -version = "0.14.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" -dependencies = [ - "allocator-api2", -] - [[package]] name = "hashbrown" version = "0.15.5" @@ -2508,15 +2244,6 @@ dependencies = [ "digest 0.10.7", ] -[[package]] -name = "hmac" -version = "0.13.0-rc.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ef451d73f36d8a3f93ad32c332ea01146c9650e1ec821a9b0e46c01277d544f8" -dependencies = [ - "digest 0.11.1", -] - [[package]] name = "http" version = "1.4.0" @@ -2562,17 +2289,6 @@ version = "1.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "df3b46402a9d5adb4c86a0cf463f42e19994e3ee891101b1841f30a545cb49a9" -[[package]] -name = "hybrid-array" -version = "0.4.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e1b229d73f5803b562cc26e4da0396c8610a4ee209f4fac8fa4f8d709166dc45" -dependencies = [ - "subtle", - "typenum", - "zeroize", -] - [[package]] name = "hyper" version = "1.8.1" @@ -2885,16 +2601,6 @@ dependencies = [ "libc", ] -[[package]] -name = "inout" -version = "0.1.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "879f10e63c20629ecabbb64a8010319738c66a5cd0c29b02d63d272b03751d01" -dependencies = [ - "block-padding", - "generic-array", -] - [[package]] name = "inplace-vec-builder" version = "0.1.1" @@ -2942,11 +2648,6 
@@ name = "iter-extended" version = "1.0.0-beta.11" source = "git+https://github.com/noir-lang/noir?rev=v1.0.0-beta.11#fd3925aaaeb76c76319f44590d135498ef41ea6c" -[[package]] -name = "iter-extended" -version = "1.0.0-beta.19" -source = "git+https://github.com/noir-lang/noir.git?branch=master#efd7f97e84c659a92cc8b7c501c578fbb718df37" - [[package]] name = "itertools" version = "0.10.5" @@ -3105,6 +2806,12 @@ dependencies = [ "thiserror 2.0.18", ] +[[package]] +name = "jzon" +version = "0.12.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "17ab85f84ca42c5ec520e6f3c9966ba1fd62909ce260f8837e248857d2560509" + [[package]] name = "k256" version = "0.13.4" @@ -3112,42 +2819,40 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f6e3919bbaa2945715f0bb6d3934a173d1e9a59ac23767fbaaef277265a7411b" dependencies = [ "cfg-if", - "ecdsa 0.16.9", - "elliptic-curve 0.13.8", + "ecdsa", + "elliptic-curve", "once_cell", - "sha2 0.10.9", - "signature 2.2.0", + "sha2", + "signature", ] [[package]] -name = "k256" -version = "0.14.0-rc.7" +name = "keccak" +version = "0.1.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "83da23da11f0b5db6f23d9280a84b3a33a746aa43ebb9270d6b445991da9cee3" +checksum = "cb26cec98cce3a3d96cbb7bced3c4b16e3d13f27ec56dbd62cbc8f39cfb9d653" dependencies = [ - "cpubits", - "ecdsa 0.17.0-rc.16", - "elliptic-curve 0.14.0-rc.28", - "sha2 0.11.0-rc.5", - "signature 3.0.0-rc.10", + "cpufeatures", ] [[package]] -name = "keccak" -version = "0.1.6" +name = "konst" +version = "0.3.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cb26cec98cce3a3d96cbb7bced3c4b16e3d13f27ec56dbd62cbc8f39cfb9d653" +checksum = "4381b9b00c55f251f2ebe9473aef7c117e96828def1a7cb3bd3f0f903c6894e9" dependencies = [ - "cpufeatures 0.2.17", + "const_panic", + "konst_kernel", + "typewit", ] [[package]] -name = "keccak" -version = "0.2.0-rc.2" +name = "konst_kernel" +version = "0.3.15" 
source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "882b69cb15b1f78b51342322a97ccd16f5123d1dc8a3da981a95244f488e8692" +checksum = "e4b1eb7788f3824c629b1116a7a9060d6e898c358ebff59070093d51103dcc3c" dependencies = [ - "cpufeatures 0.3.0", + "typewit", ] [[package]] @@ -3175,6 +2880,9 @@ name = "lazy_static" version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" +dependencies = [ + "spin 0.9.8", +] [[package]] name = "leb128fmt" @@ -3266,7 +2974,7 @@ dependencies = [ "generator", "scoped-tls", "tracing", - "tracing-subscriber 0.3.22", + "tracing-subscriber", ] [[package]] @@ -3333,11 +3041,10 @@ checksum = "47e1ffaa40ddd1f3ed91f717a33c8c0ee23fff369e3aa8772b9605cc1d22f4c3" [[package]] name = "mavros-artifacts" version = "0.1.0" -source = "git+https://github.com/reilabs/mavros?branch=split_main#2fe1fe3cafb2df3c46a220f3aa78f47acf63eeef" +source = "git+https://github.com/reilabs/mavros?rev=3e47fd58001a0109a0314bc080b5246fd807ba04#3e47fd58001a0109a0314bc080b5246fd807ba04" dependencies = [ "ark-bn254", "ark-ff 0.5.0", - "bincode 1.3.3", "serde", "tracing", ] @@ -3345,14 +3052,11 @@ dependencies = [ [[package]] name = "mavros-vm" version = "0.1.0" -source = "git+https://github.com/reilabs/mavros?branch=split_main#2fe1fe3cafb2df3c46a220f3aa78f47acf63eeef" +source = "git+https://github.com/reilabs/mavros?rev=3e47fd58001a0109a0314bc080b5246fd807ba04#3e47fd58001a0109a0314bc080b5246fd807ba04" dependencies = [ - "ark-bn254", "ark-ff 0.5.0", "mavros-artifacts", - "noirc_abi 1.0.0-beta.19", "opcode-gen", - "serde", "tracing", ] @@ -3371,24 +3075,18 @@ dependencies = [ "autocfg", ] -[[package]] -name = "merlin" -version = "3.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "58c38e2799fc0978b65dfff8023ec7843e2330bb462f19198840b34b6582397d" -dependencies = [ - "byteorder", - "keccak 0.1.6", - "rand_core 0.6.4", - 
"zeroize", -] - [[package]] name = "mime" version = "0.3.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a" +[[package]] +name = "minimal-lexical" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a" + [[package]] name = "miniz_oxide" version = "0.8.9" @@ -3433,16 +3131,16 @@ name = "nargo" version = "1.0.0-beta.11" source = "git+https://github.com/noir-lang/noir?rev=v1.0.0-beta.11#fd3925aaaeb76c76319f44590d135498ef41ea6c" dependencies = [ - "acvm 1.0.0-beta.11", + "acvm", "fm", - "iter-extended 1.0.0-beta.11", + "iter-extended", "jsonrpsee", "noir_greybox_fuzzer", - "noirc_abi 1.0.0-beta.11", + "noirc_abi", "noirc_driver", "noirc_errors", "noirc_frontend", - "noirc_printable_type 1.0.0-beta.11", + "noirc_printable_type", "rand 0.8.5", "rayon", "serde", @@ -3459,7 +3157,7 @@ name = "nargo_cli" version = "1.0.0-beta.11" source = "git+https://github.com/noir-lang/noir?rev=v1.0.0-beta.11#fd3925aaaeb76c76319f44590d135498ef41ea6c" dependencies = [ - "acvm 1.0.0-beta.11", + "acvm", "async-lsp", "bn254_blackbox_solver", "build-data", @@ -3471,7 +3169,7 @@ dependencies = [ "fm", "fs2", "fxhash", - "iter-extended 1.0.0-beta.11", + "iter-extended", "nargo", "nargo_expand", "nargo_fmt", @@ -3480,7 +3178,7 @@ dependencies = [ "noir_ast_fuzzer", "noir_debugger", "noir_lsp", - "noirc_abi 1.0.0-beta.11", + "noirc_abi", "noirc_artifacts", "noirc_artifacts_info", "noirc_driver", @@ -3503,7 +3201,7 @@ dependencies = [ "tower", "tracing", "tracing-appender", - "tracing-subscriber 0.3.22", + "tracing-subscriber", ] [[package]] @@ -3589,30 +3287,43 @@ dependencies = [ "memoffset", ] +[[package]] +name = "noir-bignum-paramgen" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"4a2f214558ab24dd9af1d905187b42024370600b9458be46ab34c9af2f11e441" +dependencies = [ + "hex", + "itoa", + "num-bigint-dig", + "num-integer", + "num-traits", +] + [[package]] name = "noir_artifact_cli" version = "1.0.0-beta.11" source = "git+https://github.com/noir-lang/noir?rev=v1.0.0-beta.11#fd3925aaaeb76c76319f44590d135498ef41ea6c" dependencies = [ - "acir 1.0.0-beta.11", - "acvm 1.0.0-beta.11", + "acir", + "acvm", "bn254_blackbox_solver", "clap", "color-eyre", "const_format", "fm", "nargo", - "noirc_abi 1.0.0-beta.11", + "noirc_abi", "noirc_artifacts", "noirc_artifacts_info", "noirc_driver", "noirc_errors", - "noirc_printable_type 1.0.0-beta.11", + "noirc_printable_type", "serde", "serde_json", "thiserror 1.0.69", "toml 0.7.8", - "tracing-subscriber 0.3.22", + "tracing-subscriber", ] [[package]] @@ -3620,18 +3331,18 @@ name = "noir_ast_fuzzer" version = "1.0.0-beta.11" source = "git+https://github.com/noir-lang/noir?rev=v1.0.0-beta.11#fd3925aaaeb76c76319f44590d135498ef41ea6c" dependencies = [ - "acir 1.0.0-beta.11", - "acvm 1.0.0-beta.11", + "acir", + "acvm", "arbitrary", "bn254_blackbox_solver", "build-data", "color-eyre", "im", - "iter-extended 1.0.0-beta.11", + "iter-extended", "log", "nargo", "noir_greybox_fuzzer", - "noirc_abi 1.0.0-beta.11", + "noirc_abi", "noirc_driver", "noirc_errors", "noirc_evaluator", @@ -3639,7 +3350,7 @@ dependencies = [ "proptest", "rand 0.8.5", "regex", - "strum 0.24.1", + "strum", ] [[package]] @@ -3647,7 +3358,7 @@ name = "noir_debugger" version = "1.0.0-beta.11" source = "git+https://github.com/noir-lang/noir?rev=v1.0.0-beta.11#fd3925aaaeb76c76319f44590d135498ef41ea6c" dependencies = [ - "acvm 1.0.0-beta.11", + "acvm", "bn254_blackbox_solver", "build-data", "codespan-reporting", @@ -3658,7 +3369,7 @@ dependencies = [ "noirc_artifacts", "noirc_driver", "noirc_errors", - "noirc_printable_type 1.0.0-beta.11", + "noirc_printable_type", "owo-colors", "thiserror 1.0.69", ] @@ -3668,10 +3379,10 @@ name = "noir_greybox_fuzzer" version 
= "1.0.0-beta.11" source = "git+https://github.com/noir-lang/noir?rev=v1.0.0-beta.11#fd3925aaaeb76c76319f44590d135498ef41ea6c" dependencies = [ - "acvm 1.0.0-beta.11", + "acvm", "build-data", "fm", - "noirc_abi 1.0.0-beta.11", + "noirc_abi", "noirc_artifacts", "num-traits", "proptest", @@ -3688,14 +3399,14 @@ name = "noir_lsp" version = "1.0.0-beta.11" source = "git+https://github.com/noir-lang/noir?rev=v1.0.0-beta.11#fd3925aaaeb76c76319f44590d135498ef41ea6c" dependencies = [ - "acvm 1.0.0-beta.11", + "acvm", "async-lsp", "codespan-lsp", "convert_case", "fm", "fuzzy-matcher", "fxhash", - "iter-extended 1.0.0-beta.11", + "iter-extended", "nargo", "nargo_expand", "nargo_fmt", @@ -3707,7 +3418,7 @@ dependencies = [ "rayon", "serde", "serde_json", - "strum 0.24.1", + "strum", "thiserror 1.0.69", "tower", "wasm-bindgen", @@ -3727,9 +3438,9 @@ name = "noirc_abi" version = "1.0.0-beta.11" source = "git+https://github.com/noir-lang/noir?rev=v1.0.0-beta.11#fd3925aaaeb76c76319f44590d135498ef41ea6c" dependencies = [ - "acvm 1.0.0-beta.11", - "iter-extended 1.0.0-beta.11", - "noirc_printable_type 1.0.0-beta.11", + "acvm", + "iter-extended", + "noirc_printable_type", "num-bigint", "num-traits", "serde", @@ -3738,22 +3449,6 @@ dependencies = [ "toml 0.7.8", ] -[[package]] -name = "noirc_abi" -version = "1.0.0-beta.19" -source = "git+https://github.com/noir-lang/noir.git?branch=master#efd7f97e84c659a92cc8b7c501c578fbb718df37" -dependencies = [ - "acvm 1.0.0-beta.19", - "iter-extended 1.0.0-beta.19", - "noirc_printable_type 1.0.0-beta.19", - "num-bigint", - "num-traits", - "serde", - "serde_json", - "thiserror 2.0.18", - "toml 0.8.23", -] - [[package]] name = "noirc_arena" version = "1.0.0-beta.11" @@ -3764,13 +3459,13 @@ name = "noirc_artifacts" version = "1.0.0-beta.11" source = "git+https://github.com/noir-lang/noir?rev=v1.0.0-beta.11#fd3925aaaeb76c76319f44590d135498ef41ea6c" dependencies = [ - "acvm 1.0.0-beta.11", + "acvm", "codespan-reporting", "fm", - "noirc_abi 
1.0.0-beta.11", + "noirc_abi", "noirc_driver", "noirc_errors", - "noirc_printable_type 1.0.0-beta.11", + "noirc_printable_type", "serde", ] @@ -3779,10 +3474,10 @@ name = "noirc_artifacts_info" version = "1.0.0-beta.11" source = "git+https://github.com/noir-lang/noir?rev=v1.0.0-beta.11#fd3925aaaeb76c76319f44590d135498ef41ea6c" dependencies = [ - "acir 1.0.0-beta.11", - "acvm 1.0.0-beta.11", + "acir", + "acvm", "clap", - "iter-extended 1.0.0-beta.11", + "iter-extended", "noirc_artifacts", "prettytable-rs", "rayon", @@ -3795,13 +3490,13 @@ name = "noirc_driver" version = "1.0.0-beta.11" source = "git+https://github.com/noir-lang/noir?rev=v1.0.0-beta.11#fd3925aaaeb76c76319f44590d135498ef41ea6c" dependencies = [ - "acvm 1.0.0-beta.11", + "acvm", "build-data", "clap", "fm", "fxhash", - "iter-extended 1.0.0-beta.11", - "noirc_abi 1.0.0-beta.11", + "iter-extended", + "noirc_abi", "noirc_errors", "noirc_evaluator", "noirc_frontend", @@ -3815,14 +3510,14 @@ name = "noirc_errors" version = "1.0.0-beta.11" source = "git+https://github.com/noir-lang/noir?rev=v1.0.0-beta.11#fd3925aaaeb76c76319f44590d135498ef41ea6c" dependencies = [ - "acvm 1.0.0-beta.11", + "acvm", "base64", "codespan", "codespan-reporting", "flate2", "fm", "fxhash", - "noirc_printable_type 1.0.0-beta.11", + "noirc_printable_type", "serde", "serde_json", "tracing", @@ -3833,17 +3528,17 @@ name = "noirc_evaluator" version = "1.0.0-beta.11" source = "git+https://github.com/noir-lang/noir?rev=v1.0.0-beta.11#fd3925aaaeb76c76319f44590d135498ef41ea6c" dependencies = [ - "acvm 1.0.0-beta.11", + "acvm", "bn254_blackbox_solver", "cfg-if", "chrono", "fm", "fxhash", "im", - "iter-extended 1.0.0-beta.11", + "iter-extended", "noirc_errors", "noirc_frontend", - "noirc_printable_type 1.0.0-beta.11", + "noirc_printable_type", "num-bigint", "num-integer", "num-traits", @@ -3863,16 +3558,16 @@ name = "noirc_frontend" version = "1.0.0-beta.11" source = 
"git+https://github.com/noir-lang/noir?rev=v1.0.0-beta.11#fd3925aaaeb76c76319f44590d135498ef41ea6c" dependencies = [ - "acvm 1.0.0-beta.11", + "acvm", "bn254_blackbox_solver", "cfg-if", "fm", "fxhash", "im", - "iter-extended 1.0.0-beta.11", + "iter-extended", "noirc_arena", "noirc_errors", - "noirc_printable_type 1.0.0-beta.11", + "noirc_printable_type", "num-bigint", "num-traits", "petgraph 0.8.3", @@ -3882,8 +3577,8 @@ dependencies = [ "serde_json", "small-ord-set", "smol_str", - "strum 0.24.1", - "strum_macros 0.24.3", + "strum", + "strum_macros", "thiserror 1.0.69", "tracing", ] @@ -3893,30 +3588,20 @@ name = "noirc_printable_type" version = "1.0.0-beta.11" source = "git+https://github.com/noir-lang/noir?rev=v1.0.0-beta.11#fd3925aaaeb76c76319f44590d135498ef41ea6c" dependencies = [ - "acvm 1.0.0-beta.11", - "iter-extended 1.0.0-beta.11", - "serde", - "serde_json", -] - -[[package]] -name = "noirc_printable_type" -version = "1.0.0-beta.19" -source = "git+https://github.com/noir-lang/noir.git?branch=master#efd7f97e84c659a92cc8b7c501c578fbb718df37" -dependencies = [ - "acvm 1.0.0-beta.19", - "iter-extended 1.0.0-beta.19", + "acvm", + "iter-extended", "serde", "serde_json", ] [[package]] -name = "noirc_span" -version = "1.0.0-beta.19" -source = "git+https://github.com/noir-lang/noir.git?branch=master#efd7f97e84c659a92cc8b7c501c578fbb718df37" +name = "nom" +version = "7.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d273983c5a657a70a3e8f2a01329822f3b8c8172b73826411a55751e404a0a4a" dependencies = [ - "codespan", - "serde", + "memchr", + "minimal-lexical", ] [[package]] @@ -3969,17 +3654,33 @@ version = "0.50.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7957b9740744892f114936ab4a57b3f487491bbeafaf8083688b16841a4240e5" dependencies = [ - "windows-sys 0.61.2", + "windows-sys 0.61.2", +] + +[[package]] +name = "num-bigint" +version = "0.4.6" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "a5e44f723f1133c9deac646763579fdb3ac745e418f2a7af9cd0c431da1f20b9" +dependencies = [ + "num-integer", + "num-traits", ] [[package]] -name = "num-bigint" -version = "0.4.6" +name = "num-bigint-dig" +version = "0.8.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a5e44f723f1133c9deac646763579fdb3ac745e418f2a7af9cd0c431da1f20b9" +checksum = "e661dda6640fad38e827a6d4a310ff4763082116fe217f279885c97f511bb0b7" dependencies = [ + "lazy_static", + "libm", "num-integer", + "num-iter", "num-traits", + "rand 0.8.5", + "smallvec", + "zeroize", ] [[package]] @@ -3997,6 +3698,17 @@ dependencies = [ "num-traits", ] +[[package]] +name = "num-iter" +version = "0.1.45" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1429034a0490724d0075ebb2bc9e875d6503c3cf69e235a8941aa757d83ef5bf" +dependencies = [ + "autocfg", + "num-integer", + "num-traits", +] + [[package]] name = "num-traits" version = "0.2.19" @@ -4044,6 +3756,15 @@ dependencies = [ "memchr", ] +[[package]] +name = "oid-registry" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a8d8034d9489cdaf79228eb9f6a3b8d7bb32ba00d6645ebd48eef4077ceb5bd9" +dependencies = [ + "asn1-rs", +] + [[package]] name = "once_cell" version = "1.21.3" @@ -4059,7 +3780,7 @@ checksum = "384b8ab6d37215f3c5301a95a4accb5d64aa607f1fcb26a11b5303878451b4fe" [[package]] name = "opcode-gen" version = "0.1.0" -source = "git+https://github.com/reilabs/mavros?branch=split_main#2fe1fe3cafb2df3c46a220f3aa78f47acf63eeef" +source = "git+https://github.com/reilabs/mavros?rev=3e47fd58001a0109a0314bc080b5246fd807ba04#3e47fd58001a0109a0314bc080b5246fd807ba04" dependencies = [ "proc-macro2", "quote", @@ -4134,23 +3855,10 @@ version = "0.13.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c9863ad85fa8f4460f9c48cb909d38a0d689dba1f6f6988a5e3e0d31071bcd4b" dependencies = 
[ - "ecdsa 0.16.9", - "elliptic-curve 0.13.8", - "primeorder 0.13.6", - "sha2 0.10.9", -] - -[[package]] -name = "p256" -version = "0.14.0-rc.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "018bfbb86e05fd70a83e985921241035ee09fcd369c4a2c3680b389a01d2ad28" -dependencies = [ - "ecdsa 0.17.0-rc.16", - "elliptic-curve 0.14.0-rc.28", - "primefield", - "primeorder 0.14.0-rc.7", - "sha2 0.11.0-rc.5", + "ecdsa", + "elliptic-curve", + "primeorder", + "sha2", ] [[package]] @@ -4355,6 +4063,37 @@ dependencies = [ "windows-link", ] +[[package]] +name = "passport-input-gen" +version = "0.1.0" +dependencies = [ + "anyhow", + "argh", + "ark-bn254", + "ark-ff 0.5.0", + "base64", + "chrono", + "hex", + "lazy_static", + "noir-bignum-paramgen", + "noirc_abi", + "poseidon2", + "provekit-common", + "provekit-prover", + "rasn", + "rasn-cms", + "rasn-pkix", + "rsa", + "serde", + "serde_json", + "sha2", + "signature", + "thiserror 2.0.18", + "tracing", + "tracing-subscriber", + "x509-parser", +] + [[package]] name = "paste" version = "1.0.15" @@ -4370,15 +4109,6 @@ dependencies = [ "base64ct", ] -[[package]] -name = "pem-rfc7468" -version = "1.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a6305423e0e7738146434843d1694d621cce767262b2a86910beab705e4493d9" -dependencies = [ - "base64ct", -] - [[package]] name = "percent-encoding" version = "2.3.2" @@ -4450,23 +4180,24 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" [[package]] -name = "pkcs8" -version = "0.10.2" +name = "pkcs1" +version = "0.7.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f950b2377845cebe5cf8b5165cb3cc1a5e0fa5cfa3e1f7f55707d8fd82e0a7b7" +checksum = "c8ffb9f10fa047879315e6625af03c164b16962a5368d724ed16323b68ace47f" dependencies = [ - "der 0.7.10", - "spki 0.7.3", + "der", + "pkcs8", + "spki", ] [[package]] name = "pkcs8" 
-version = "0.11.0-rc.11" +version = "0.10.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "12922b6296c06eb741b02d7b5161e3aaa22864af38dfa025a1a3ba3f68c84577" +checksum = "f950b2377845cebe5cf8b5165cb3cc1a5e0fa5cfa3e1f7f55707d8fd82e0a7b7" dependencies = [ - "der 0.8.0", - "spki 0.8.0-rc.4", + "der", + "spki", ] [[package]] @@ -4481,6 +4212,15 @@ version = "0.2.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b4596b6d070b27117e987119b4dac604f3c58cfb0b191112e24771b2faeac1a6" +[[package]] +name = "poseidon2" +version = "0.1.0" +dependencies = [ + "ark-bn254", + "ark-ff 0.5.0", + "ark-std 0.5.0", +] + [[package]] name = "postcard" version = "1.1.3" @@ -4542,36 +4282,13 @@ dependencies = [ "unicode-width", ] -[[package]] -name = "primefield" -version = "0.14.0-rc.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "93401c13cc7ff24684571cfca9d3cf9ebabfaf3d4b7b9963ade41ec54da196b5" -dependencies = [ - "crypto-bigint 0.7.0-rc.28", - "crypto-common 0.2.1", - "rand_core 0.10.0", - "rustcrypto-ff", - "subtle", - "zeroize", -] - [[package]] name = "primeorder" version = "0.13.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "353e1ca18966c16d9deb1c69278edbc5f194139612772bd9537af60ac231e1e6" dependencies = [ - "elliptic-curve 0.13.8", -] - -[[package]] -name = "primeorder" -version = "0.14.0-rc.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a0c5c8a39bcd764bfedf456e8d55e115fe86dda3e0f555371849f2a41cbc9706" -dependencies = [ - "elliptic-curve 0.14.0-rc.28", + "elliptic-curve", ] [[package]] @@ -4753,22 +4470,18 @@ checksum = "95067976aca6421a523e491fce939a3e65249bac4b977adee0ee9771568e8aa3" name = "provekit-bench" version = "0.1.0" dependencies = [ + "acir", "anyhow", "ark-ff 0.5.0", - "ark-std 0.5.0", "divan", "nargo", "nargo_cli", "nargo_toml", - "noir_artifact_cli", - "noirc_abi 1.0.0-beta.11", - "noirc_artifacts", 
"noirc_driver", "provekit-common", "provekit-prover", "provekit-r1cs-compiler", "provekit-verifier", - "rand 0.9.2", "serde", "test-case", "toml 0.8.23", @@ -4779,13 +4492,13 @@ dependencies = [ name = "provekit-cli" version = "0.1.0" dependencies = [ - "acir 1.0.0-beta.11", + "acir", "anyhow", "argh", "ark-ff 0.5.0", "base64", "hex", - "noirc_abi 1.0.0-beta.11", + "noirc_abi", "postcard", "provekit-common", "provekit-gnark", @@ -4793,11 +4506,10 @@ dependencies = [ "provekit-r1cs-compiler", "provekit-verifier", "rayon", - "serde", "serde_json", "tikv-jemallocator", "tracing", - "tracing-subscriber 0.3.22", + "tracing-subscriber", "tracing-tracy", ] @@ -4805,7 +4517,7 @@ dependencies = [ name = "provekit-common" version = "0.1.0" dependencies = [ - "acir 1.0.0-beta.11", + "acir", "anyhow", "ark-bn254", "ark-ff 0.5.0", @@ -4817,10 +4529,8 @@ dependencies = [ "itertools 0.14.0", "mavros-artifacts", "mavros-vm", - "noirc_abi 1.0.0-beta.11", + "noirc_abi", "postcard", - "rand 0.8.5", - "rand 0.9.2", "rayon", "ruint", "serde", @@ -4832,7 +4542,6 @@ dependencies = [ "whir", "xz2", "zerocopy", - "zeroize", "zstd", ] @@ -4840,10 +4549,8 @@ dependencies = [ name = "provekit-ffi" version = "0.1.0" dependencies = [ - "acir 1.0.0-beta.11", "anyhow", "libc", - "noirc_abi 1.0.0-beta.11", "parking_lot", "provekit-common", "provekit-prover", @@ -4866,9 +4573,8 @@ dependencies = [ name = "provekit-prover" version = "0.1.0" dependencies = [ - "acir 1.0.0-beta.11", + "acir", "anyhow", - "ark-crypto-primitives", "ark-ff 0.5.0", "ark-std 0.5.0", "bn254_blackbox_solver", @@ -4876,14 +4582,10 @@ dependencies = [ "mavros-vm", "nargo", "noir_artifact_cli", - "noirc_abi 1.0.0-beta.11", + "noirc_abi", + "num-bigint", "postcard", "provekit-common", - "rand 0.9.2", - "rayon", - "skyscraper", - "spongefish", - "spongefish-pow", "tracing", "whir", ] @@ -4892,17 +4594,17 @@ dependencies = [ name = "provekit-r1cs-compiler" version = "0.1.0" dependencies = [ - "acir 1.0.0-beta.11", + "acir", 
"anyhow", "ark-bn254", - "ark-crypto-primitives", "ark-ff 0.5.0", "ark-std 0.5.0", "bincode 1.3.3", "mavros-artifacts", - "noirc_abi 1.0.0-beta.11", + "noirc_abi", "noirc_artifacts", "ntt", + "poseidon2", "postcard", "provekit-common", "serde", @@ -5098,6 +4800,63 @@ version = "1.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "973443cf09a9c8656b574a866ab68dfa19f0867d0340648c7d2f6a71b8a8ea68" +[[package]] +name = "rasn" +version = "0.15.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c5379b720091e4bf4a9f118eb46f4ffb67bb8b7551649528c89e265cf880e748" +dependencies = [ + "arrayvec", + "bitvec", + "bitvec-nom2", + "bytes", + "chrono", + "either", + "jzon", + "konst", + "nom", + "num-bigint", + "num-integer", + "num-traits", + "once_cell", + "rasn-derive", + "snafu", +] + +[[package]] +name = "rasn-cms" +version = "0.15.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ee9c688bd3aa3db270834720ab22b2862cd07ed094c4b2262bfb74a91008681e" +dependencies = [ + "rasn", + "rasn-pkix", +] + +[[package]] +name = "rasn-derive" +version = "0.15.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e521162112419405837a6590b327f24707ce9f9b3ac9c9c4a4d10673b63abcd8" +dependencies = [ + "either", + "itertools 0.10.5", + "proc-macro2", + "quote", + "rayon", + "syn 1.0.109", + "uuid", +] + +[[package]] +name = "rasn-pkix" +version = "0.15.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d9f74a31343c2fd11da94025b8dcbeb96bfb207b4d480db99ad5554c117448fa" +dependencies = [ + "rasn", +] + [[package]] name = "rayon" version = "1.11.0" @@ -5263,17 +5022,7 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f8dd2a808d456c4a54e300a23e9f5a67e122c3024119acbfd73e3bf664491cb2" dependencies = [ - "hmac 0.12.1", - "subtle", -] - -[[package]] -name = "rfc6979" -version = "0.5.0-rc.5" -source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "23a3127ee32baec36af75b4107082d9bd823501ec14a4e016be4b6b37faa74ae" -dependencies = [ - "hmac 0.13.0-rc.5", + "hmac", "subtle", ] @@ -5320,6 +5069,27 @@ dependencies = [ "serde", ] +[[package]] +name = "rsa" +version = "0.9.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b8573f03f5883dcaebdfcf4725caa1ecb9c15b2ef50c43a07b816e06799bb12d" +dependencies = [ + "const-oid", + "digest 0.10.7", + "num-bigint-dig", + "num-integer", + "num-traits", + "pkcs1", + "pkcs8", + "rand_core 0.6.4", + "sha2", + "signature", + "spki", + "subtle", + "zeroize", +] + [[package]] name = "ruint" version = "1.17.2" @@ -5384,7 +5154,7 @@ version = "8.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5bcdef0be6fe7f6fa333b1073c949729274b05f123a0ad7efcb8efd878e5c3b1" dependencies = [ - "sha2 0.10.9", + "sha2", "walkdir", ] @@ -5425,24 +5195,12 @@ dependencies = [ ] [[package]] -name = "rustcrypto-ff" -version = "0.14.0-rc.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c5db129183b2c139d7d87d08be57cba626c715789db17aec65c8866bfd767d1f" -dependencies = [ - "rand_core 0.10.0", - "subtle", -] - -[[package]] -name = "rustcrypto-group" -version = "0.14.0-rc.0" +name = "rusticata-macros" +version = "4.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "57c4b1463f274a3ff6fb2f44da43e576cb9424367bd96f185ead87b52fe00523" +checksum = "faf0c4a6ece9950b9abdb62b1cfcf2a68b3b67a10ba445b3bb85be2a293d0632" dependencies = [ - "rand_core 0.10.0", - "rustcrypto-ff", - "subtle", + "nom", ] [[package]] @@ -5710,24 +5468,10 @@ version = "0.7.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d3e97a565f76233a6003f9f5c54be1d9c5bdfa3eccfb189469f11ec4901c47dc" dependencies = [ - "base16ct 0.2.0", - "der 0.7.10", + "base16ct", + "der", "generic-array", - "pkcs8 0.10.2", - "subtle", - "zeroize", -] - 
-[[package]] -name = "sec1" -version = "0.8.0-rc.13" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7a2400ed44a13193820aa528a19f376c3843141a8ce96ff34b11104cc79763f2" -dependencies = [ - "base16ct 1.0.0", - "ctutils", - "der 0.8.0", - "hybrid-array", + "pkcs8", "subtle", "zeroize", ] @@ -5918,22 +5662,11 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a7507d819769d01a365ab707794a4084392c824f54a7a6a7862f8c3d0892b283" dependencies = [ "cfg-if", - "cpufeatures 0.2.17", + "cpufeatures", "digest 0.10.7", "sha2-asm", ] -[[package]] -name = "sha2" -version = "0.11.0-rc.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7c5f3b1e2dc8aad28310d8410bd4d7e180eca65fca176c52ab00d364475d0024" -dependencies = [ - "cfg-if", - "cpufeatures 0.2.17", - "digest 0.11.1", -] - [[package]] name = "sha2-asm" version = "0.6.4" @@ -5952,7 +5685,7 @@ dependencies = [ "async-trait", "bytes", "hex", - "sha2 0.10.9", + "sha2", ] [[package]] @@ -5962,7 +5695,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "75872d278a8f37ef87fa0ddbda7802605cb18344497949862c0d4dcb291eba60" dependencies = [ "digest 0.10.7", - "keccak 0.1.6", + "keccak", ] [[package]] @@ -6006,16 +5739,6 @@ dependencies = [ "rand_core 0.6.4", ] -[[package]] -name = "signature" -version = "3.0.0-rc.10" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7f1880df446116126965eeec169136b2e0251dba37c6223bcc819569550edea3" -dependencies = [ - "digest 0.11.1", - "rand_core 0.10.0", -] - [[package]] name = "simd-adler32" version = "0.3.8" @@ -6108,6 +5831,29 @@ dependencies = [ "serde_core", ] +[[package]] +name = "snafu" +version = "0.7.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e4de37ad025c587a29e8f3f5605c00f70b98715ef90b9061a815b9e59e9042d6" +dependencies = [ + "backtrace", + "doc-comment", + "snafu-derive", +] + +[[package]] +name = "snafu-derive" 
+version = "0.7.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "990079665f075b699031e9c08fd3ab99be5029b96f3b78dc0709e8f77e4efebf" +dependencies = [ + "heck 0.4.1", + "proc-macro2", + "quote", + "syn 1.0.109", +] + [[package]] name = "socket2" version = "0.6.2" @@ -6149,17 +5895,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d91ed6c858b01f942cd56b37a94b3e0a1798290327d1236e4d9cf4eaca44d29d" dependencies = [ "base64ct", - "der 0.7.10", -] - -[[package]] -name = "spki" -version = "0.8.0-rc.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8baeff88f34ed0691978ec34440140e1572b68c7dd4a495fd14a3dc1944daa80" -dependencies = [ - "base64ct", - "der 0.8.0", + "der", ] [[package]] @@ -6171,10 +5907,10 @@ dependencies = [ "ark-serialize 0.5.0", "blake3", "digest 0.10.7", - "keccak 0.1.6", + "keccak", "p3-koala-bear", "rand 0.8.5", - "sha2 0.10.9", + "sha2", "sha3", "zeroize", ] @@ -6186,7 +5922,7 @@ source = "git+https://github.com/arkworks-rs/spongefish?rev=fcc277f8a857fdeeadd7 dependencies = [ "blake3", "bytemuck", - "keccak 0.1.6", + "keccak", "rand 0.8.5", "rayon", "spongefish", @@ -6234,12 +5970,6 @@ version = "0.24.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "063e6045c0e62079840579a7e47a355ae92f60eb74daaf156fb1e84ba164e63f" -[[package]] -name = "strum" -version = "0.26.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8fec0f0aef304996cf250b31b5a10dee7980c85da9d759361292b8bca5a18f06" - [[package]] name = "strum_macros" version = "0.24.3" @@ -6253,19 +5983,6 @@ dependencies = [ "syn 1.0.109", ] -[[package]] -name = "strum_macros" -version = "0.26.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4c6bee85a5a24955dc440386795aa378cd9cf82acd5f764469152d2270e581be" -dependencies = [ - "heck 0.5.0", - "proc-macro2", - "quote", - "rustversion", - "syn 2.0.117", -] - [[package]] name = 
"subtle" version = "2.6.1" @@ -6777,7 +6494,7 @@ dependencies = [ "crossbeam-channel", "thiserror 2.0.18", "time", - "tracing-subscriber 0.3.22", + "tracing-subscriber", ] [[package]] @@ -6808,7 +6525,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8b1581020d7a273442f5b45074a6a57d5757ad0a47dac0e9f0bd57b81936f3db" dependencies = [ "tracing", - "tracing-subscriber 0.3.22", + "tracing-subscriber", ] [[package]] @@ -6832,15 +6549,6 @@ dependencies = [ "tracing-core", ] -[[package]] -name = "tracing-subscriber" -version = "0.2.25" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0e0d2eaa99c3c2e41547cfa109e910a68ea03823cccad4a0525dcbc9b01e8c71" -dependencies = [ - "tracing-core", -] - [[package]] name = "tracing-subscriber" version = "0.3.22" @@ -6869,7 +6577,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0eaa1852afa96e0fe9e44caa53dc0bd2d9d05e0f2611ce09f97f8677af56e4ba" dependencies = [ "tracing-core", - "tracing-subscriber 0.3.22", + "tracing-subscriber", "tracy-client", ] @@ -6925,6 +6633,21 @@ version = "1.19.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "562d481066bde0658276a35467c4af00bdc6ee726305698a55b86e61d7ad82bb" +[[package]] +name = "typewit" +version = "1.14.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f8c1ae7cc0fdb8b842d65d127cb981574b0d2b249b74d1c7a2986863dc134f71" +dependencies = [ + "typewit_proc_macros", +] + +[[package]] +name = "typewit_proc_macros" +version = "1.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e36a83ea2b3c704935a01b4642946aadd445cea40b10935e3f8bd8052b8193d6" + [[package]] name = "ucd-trie" version = "0.1.7" @@ -7028,6 +6751,15 @@ version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" +[[package]] +name = "uuid" +version = "1.22.0" 
+source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a68d3c8f01c0cfa54a75291d83601161799e4a89a39e0929f4b0354d88757a37" +dependencies = [ + "getrandom 0.4.2", +] + [[package]] name = "valuable" version = "0.1.1" @@ -7067,13 +6799,13 @@ dependencies = [ "reqwest", "serde", "serde_json", - "sha2 0.10.9", + "sha2", "tokio", "tokio-util", "tower", "tower-http", "tracing", - "tracing-subscriber 0.3.22", + "tracing-subscriber", ] [[package]] @@ -7298,7 +7030,7 @@ dependencies = [ "blake3", "ciborium", "clap", - "const-oid 0.9.6", + "const-oid", "derive-where", "digest 0.10.7", "hex", @@ -7308,7 +7040,7 @@ dependencies = [ "rayon", "serde", "serde_json", - "sha2 0.10.9", + "sha2", "sha3", "spongefish", "static_assertions", @@ -7835,6 +7567,23 @@ dependencies = [ "tap", ] +[[package]] +name = "x509-parser" +version = "0.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fcbc162f30700d6f3f82a24bf7cc62ffe7caea42c0b2cba8bf7f3ae50cf51f69" +dependencies = [ + "asn1-rs", + "data-encoding", + "der-parser", + "lazy_static", + "nom", + "oid-registry", + "rusticata-macros", + "thiserror 1.0.69", + "time", +] + [[package]] name = "xz2" version = "0.1.7" diff --git a/Cargo.toml b/Cargo.toml index ca25762d1..25d8672f9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -113,6 +113,7 @@ chrono = "0.4.41" divan = "0.1.21" hex = "0.4.3" itertools = "0.14.0" +num-bigint = "0.4" paste = "1.0.15" postcard = { version = "1.1.1", features = ["use-std"] } primitive-types = "0.13.1" diff --git a/noir-examples/embedded_curve_msm/Nargo.toml b/noir-examples/embedded_curve_msm/Nargo.toml new file mode 100644 index 000000000..ec9891616 --- /dev/null +++ b/noir-examples/embedded_curve_msm/Nargo.toml @@ -0,0 +1,7 @@ +[package] +name = "embedded_curve_msm" +type = "bin" +authors = [""] +compiler_version = ">=0.22.0" + +[dependencies] diff --git a/noir-examples/embedded_curve_msm/Prover.toml b/noir-examples/embedded_curve_msm/Prover.toml new file mode 
100644 index 000000000..edf585681 --- /dev/null +++ b/noir-examples/embedded_curve_msm/Prover.toml @@ -0,0 +1,9 @@ +# ============================================================ +# MSM test vectors: result = s1 * G + s2 * G +# Grumpkin curve order n = 21888242871839275222246405745257275088696311157297823662689037894645226208583 +# ============================================================ +# n - 2, n - 3 (previously failing with [2]G offset) +scalar1_lo = "201385395114098847380338600778089168197" +scalar1_hi = "64323764613183177041862057485226039389" +scalar2_lo = "201385395114098847380338600778089168196" +scalar2_hi = "64323764613183177041862057485226039389" diff --git a/noir-examples/embedded_curve_msm/Prover_near_identity.toml b/noir-examples/embedded_curve_msm/Prover_near_identity.toml new file mode 100644 index 000000000..a156f96ee --- /dev/null +++ b/noir-examples/embedded_curve_msm/Prover_near_identity.toml @@ -0,0 +1,6 @@ +# MSM edge case: n-2 and n-3 (previously failing with [2]G offset) +# Grumpkin curve order n = 21888242871839275222246405745257275088696311157297823662689037894645226208583 +scalar1_lo = "201385395114098847380338600778089168197" +scalar1_hi = "64323764613183177041862057485226039389" +scalar2_lo = "201385395114098847380338600778089168196" +scalar2_hi = "64323764613183177041862057485226039389" diff --git a/noir-examples/embedded_curve_msm/Prover_near_order.toml b/noir-examples/embedded_curve_msm/Prover_near_order.toml new file mode 100644 index 000000000..d8ae04eb5 --- /dev/null +++ b/noir-examples/embedded_curve_msm/Prover_near_order.toml @@ -0,0 +1,6 @@ +# MSM edge case: near-max scalars (n-10 and n-20) +# Grumpkin curve order n = 21888242871839275222246405745257275088696311157297823662689037894645226208583 +scalar1_lo = "201385395114098847380338600778089168189" +scalar1_hi = "64323764613183177041862057485226039389" +scalar2_lo = "201385395114098847380338600778089168179" +scalar2_hi = "64323764613183177041862057485226039389" diff --git 
a/noir-examples/embedded_curve_msm/Prover_single_nonzero.toml b/noir-examples/embedded_curve_msm/Prover_single_nonzero.toml new file mode 100644 index 000000000..8455db356 --- /dev/null +++ b/noir-examples/embedded_curve_msm/Prover_single_nonzero.toml @@ -0,0 +1,5 @@ +# MSM edge case: one zero scalar, one non-zero (0*G + 5*G = 5*G) +scalar1_lo = "0" +scalar1_hi = "0" +scalar2_lo = "5" +scalar2_hi = "0" diff --git a/noir-examples/embedded_curve_msm/Prover_zero_scalars.toml b/noir-examples/embedded_curve_msm/Prover_zero_scalars.toml new file mode 100644 index 000000000..0bd8866c7 --- /dev/null +++ b/noir-examples/embedded_curve_msm/Prover_zero_scalars.toml @@ -0,0 +1,5 @@ +# MSM edge case: all-zero scalars (0*G + 0*G = point at infinity) +scalar1_lo = "0" +scalar1_hi = "0" +scalar2_lo = "0" +scalar2_hi = "0" diff --git a/noir-examples/embedded_curve_msm/src/main.nr b/noir-examples/embedded_curve_msm/src/main.nr new file mode 100644 index 000000000..19a193181 --- /dev/null +++ b/noir-examples/embedded_curve_msm/src/main.nr @@ -0,0 +1,52 @@ +use std::embedded_curve_ops::{ + EmbeddedCurvePoint, + EmbeddedCurveScalar, + multi_scalar_mul, +}; + +/// Exercises the MultiScalarMul ACIR blackbox with 2 Grumpkin points. +/// Computes s1 * G + s2 * G where G is the Grumpkin generator. 
+fn main( + scalar1_lo: Field, + scalar1_hi: Field, + scalar2_lo: Field, + scalar2_hi: Field, +) { + // Grumpkin generator + let g = EmbeddedCurvePoint { + x: 1, + y: 17631683881184975370165255887551781615748388533673675138860, + is_infinite: false, + }; + + let s1 = EmbeddedCurveScalar { lo: scalar1_lo, hi: scalar1_hi }; + let s2 = EmbeddedCurveScalar { lo: scalar2_lo, hi: scalar2_hi }; + + // MSM: result = s1 * G + s2 * G + let result = multi_scalar_mul([g, g], [s1, s2]); + + // Prevent dead-code elimination - forces the blackbox to be retained + // Using is_infinite as return value ensures the MSM is computed + assert(result.is_infinite == (scalar1_lo + scalar1_hi + scalar2_lo + scalar2_hi == 0)); +} + +#[test] +fn test_msm() { + // 3*G on Grumpkin + let expected_x = 18660890509582237958343981571981920822503400000196279471655180441138020044621; + let expected_y = 8902249110305491597038405103722863701255802573786510474664632793109847672620; + + main(1, 0, 2, 0); + + // Verify by computing independently: 3*G should match + let g = EmbeddedCurvePoint { + x: 1, + y: 17631683881184975370165255887551781615748388533673675138860, + is_infinite: false, + }; + let s3 = EmbeddedCurveScalar { lo: 3, hi: 0 }; + let check = multi_scalar_mul([g], [s3]); + + assert(check.x == expected_x); + assert(check.y == expected_y); +} diff --git a/noir-examples/native_msm/Nargo.toml b/noir-examples/native_msm/Nargo.toml new file mode 100644 index 000000000..6b16fd3ae --- /dev/null +++ b/noir-examples/native_msm/Nargo.toml @@ -0,0 +1,6 @@ +[package] +name = "native_msm" +type = "bin" +authors = [""] + +[dependencies] diff --git a/noir-examples/native_msm/Prover.toml b/noir-examples/native_msm/Prover.toml new file mode 100644 index 000000000..58c6933da --- /dev/null +++ b/noir-examples/native_msm/Prover.toml @@ -0,0 +1,5 @@ +# MSM: result = s1 * G + s2 * G = 1*G + 2*G = 3*G +scalar1_lo = "1" +scalar1_hi = "0" +scalar2_lo = "2" +scalar2_hi = "0" diff --git 
a/noir-examples/native_msm/src/main.nr b/noir-examples/native_msm/src/main.nr new file mode 100644 index 000000000..901722a4e --- /dev/null +++ b/noir-examples/native_msm/src/main.nr @@ -0,0 +1,357 @@ +global GRUMPKIN_GEN_Y: Field = 17631683881184975370165255887551781615748388533673675138860; + +// Hardcoded offset generators: offset = 5*G, offset_final = 2^252 * 5*G +// These are compile-time constants -- no unconstrained computation, no runtime verification needed. +global OFFSET_X: Field = 12229279139087521908560794489267966517139449915173592433539394009359081620359; +global OFFSET_Y: Field = 12096995292699515952722386974733884667125946823386040531322131902193094989869; +global OFFSET_FINAL_X: Field = 17097678145015848904467691187715743297134903912023447344174597163323183228319; +global OFFSET_FINAL_Y: Field = 14560299638432262069836824301755786319891239433999125203465365199349384123743; + +// BN254 scalar field modulus as wNAF slices (MSB first). +// Used for lexicographic range check: ensures wNAF integer < p, preventing mod-p ambiguity. 
+global MODULUS_SLICES: [u8; 64] = [ + 9, 8, 3, 2, 2, 7, 3, 9, 7, 0, 9, 8, 13, 0, 1, 4, + 13, 12, 2, 8, 2, 2, 13, 11, 4, 0, 12, 0, 10, 12, 2, 14, + 9, 4, 1, 9, 15, 4, 2, 4, 3, 12, 13, 12, 11, 8, 4, 8, + 10, 1, 15, 0, 15, 10, 12, 9, 15, 8, 0, 0, 0, 0, 0, 0, +]; + +struct GPoint { x: Field, y: Field } +struct GPointResult { x: Field, y: Field, is_infinity: bool } +struct Hint { lambda: Field, x3: Field, y3: Field } + +// ====== Constrained EC verification ====== + +// ~4 constraints: verify point doubling +fn c_double(p: GPoint, h: Hint) -> GPoint { + let xx = p.x * p.x; + assert(h.lambda * (p.y + p.y) == 3 * xx); // a=0 for Grumpkin + assert(h.lambda * h.lambda == h.x3 + p.x + p.x); + assert(h.lambda * (p.x - h.x3) == h.y3 + p.y); + GPoint { x: h.x3, y: h.y3 } +} + +// ~4 constraints: verify point addition (incomplete -- offset generator prevents edge cases) +fn c_add(p1: GPoint, p2: GPoint, h: Hint) -> GPoint { + assert(p1.x != p2.x); + assert(h.lambda * (p2.x - p1.x) == p2.y - p1.y); + assert(h.lambda * h.lambda == h.x3 + p1.x + p2.x); + assert(h.lambda * (p1.x - h.x3) == h.y3 + p1.y); + GPoint { x: h.x3, y: h.y3 } +} + +// ====== Unconstrained witness generation ====== + +unconstrained fn u_double(p: GPoint) -> (GPoint, Hint) { + let lambda = (3 * p.x * p.x) / (2 * p.y); + let x3 = lambda * lambda - 2 * p.x; + let y3 = lambda * (p.x - x3) - p.y; + (GPoint { x: x3, y: y3 }, Hint { lambda, x3, y3 }) +} + +unconstrained fn u_add(p1: GPoint, p2: GPoint) -> (GPoint, Hint) { + let lambda = (p2.y - p1.y) / (p2.x - p1.x); + let x3 = lambda * lambda - p1.x - p2.x; + let y3 = lambda * (p1.x - x3) - p1.y; + (GPoint { x: x3, y: y3 }, Hint { lambda, x3, y3 }) +} + +// Unconstrained complete addition hint (handles add, double, and inverse cases) +unconstrained fn u_complete_add_hint(p1: GPoint, p2: GPoint) -> Hint { + if p1.x == p2.x { + if p1.y == p2.y { + let (_, hint) = u_double(p1); + hint + } else { + Hint { lambda: 0, x3: 0, y3: 0 } + } + } else { + let (_, hint) = 
u_add(p1, p2); + hint + } +} + +// Unconstrained full addition hint (handles infinity inputs) +unconstrained fn u_full_add_hint( + p1: GPoint, p1_inf: bool, + p2: GPoint, p2_inf: bool, +) -> Hint { + if p1_inf | p2_inf { + Hint { lambda: 0, x3: 0, y3: 0 } + } else { + u_complete_add_hint(p1, p2) + } +} + +// Constrained complete addition: handles add, double, and inverse-point cases. +// Both inputs must be valid on-curve points (not identity). +// Uses `active * constraint == 0` pattern so constraints are trivially satisfied +// when the result is the identity (inverse-point case). +fn c_complete_add(p1: GPoint, p2: GPoint, h: Hint) -> GPointResult { + let x_eq: bool = p1.x == p2.x; + let y_eq: bool = p1.y == p2.y; + let is_infinity: bool = x_eq & !y_eq; + let is_double: bool = x_eq & y_eq; + let active: Field = (!is_infinity) as Field; + + let lambda_lhs = if is_double { p1.y + p1.y } else { p2.x - p1.x }; + let lambda_rhs = if is_double { 3 * p1.x * p1.x } else { p2.y - p1.y }; + assert(active * (h.lambda * lambda_lhs - lambda_rhs) == 0); + + // x3 verification: lambda^2 = x3 + x1 + x2 (same for add and double since x2=x1 when doubling) + assert(active * (h.lambda * h.lambda - h.x3 - p1.x - p2.x) == 0); + + // y3 verification: lambda * (x1 - x3) = y3 + y1 + assert(active * (h.lambda * (p1.x - h.x3) - h.y3 - p1.y) == 0); + + GPointResult { x: h.x3 * active, y: h.y3 * active, is_infinity } +} + +// Constrained full addition: handles all cases including identity inputs. +// This is used for the final MSM sum where either operand may be the identity. +fn c_full_add( + p1: GPoint, p1_inf: bool, + p2: GPoint, p2_inf: bool, + h: Hint, +) -> GPointResult { + let neither_inf = !p1_inf & !p2_inf; + let both_inf = p1_inf & p2_inf; + let only_p1_inf = p1_inf & !p2_inf; + + // EC constraints are only active when neither input is identity. + let ec_active: Field = neither_inf as Field; + + // Determine add/double/inverse case (only meaningful when neither is identity). 
+ // Guard with neither_inf so garbage coordinates from identity points don't affect predicates. + let x_eq: bool = (p1.x == p2.x) & neither_inf; + let y_eq: bool = (p1.y == p2.y) & neither_inf; + let is_inf_from_add: bool = x_eq & !y_eq; + let is_double: bool = x_eq & y_eq; + let arith_active: Field = ec_active * (!is_inf_from_add as Field); + + // Lambda, x3, y3 constraints (zeroed when inactive) + let lambda_lhs = if is_double { p1.y + p1.y } else { p2.x - p1.x }; + let lambda_rhs = if is_double { 3 * p1.x * p1.x } else { p2.y - p1.y }; + assert(arith_active * (h.lambda * lambda_lhs - lambda_rhs) == 0); + assert(arith_active * (h.lambda * h.lambda - h.x3 - p1.x - p2.x) == 0); + assert(arith_active * (h.lambda * (p1.x - h.x3) - h.y3 - p1.y) == 0); + + // Output selection + let result_is_inf: bool = both_inf | is_inf_from_add; + + let out_x = if result_is_inf { 0 } + else if only_p1_inf { p2.x } + else if p2_inf { p1.x } + else { h.x3 }; + let out_y = if result_is_inf { 0 } + else if only_p1_inf { p2.y } + else if p2_inf { p1.y } + else { h.y3 }; + + GPointResult { x: out_x, y: out_y, is_infinity: result_is_inf } +} + +unconstrained fn decompose_wnaf(scalar_lo: Field, scalar_hi: Field) -> ([u8; 64], bool) { + let lo_bytes = scalar_lo.to_le_bytes::<16>(); + let hi_bytes = scalar_hi.to_le_bytes::<16>(); + let mut nibbles: [u8; 64] = [0; 64]; + for i in 0..16 { + nibbles[2 * i] = lo_bytes[i] & 0x0F; + nibbles[2 * i + 1] = lo_bytes[i] >> 4; + } + for i in 0..16 { + nibbles[32 + 2 * i] = hi_bytes[i] & 0x0F; + nibbles[32 + 2 * i + 1] = hi_bytes[i] >> 4; + } + let skew: bool = (nibbles[0] & 1) == 0; + nibbles[0] = nibbles[0] + (skew as u8); + let mut slices: [u8; 64] = [0; 64]; + slices[63] = (nibbles[0] + 15) / 2; + for i in 1..64 { + let nibble = nibbles[i]; + slices[63 - i] = (nibble + 15) / 2; + if (nibble & 1) == 0 { + slices[63 - i] += 1; + slices[64 - i] -= 8; + } + } + (slices, skew) +} + +// 326 hints per scalar mul: 8 table + 1 init + 63*5 loop + 1 skew + 1 
final +unconstrained fn compute_transcript( + P: GPoint, slices: [u8; 64], skew: bool, + offset: GPoint, offset_final: GPoint, +) -> [Hint; 326] { + let mut h: [Hint; 326] = [Hint { lambda: 0, x3: 0, y3: 0 }; 326]; + let mut p: u32 = 0; + + // Table: 2P, then P+2P, 3P+2P, ... + let (d2, d2h) = u_double(P); + h[p] = d2h; p += 1; + let mut table: [GPoint; 16] = [GPoint { x: 0, y: 0 }; 16]; + table[8] = P; + table[7] = GPoint { x: P.x, y: 0 - P.y }; + let mut A = P; + for i in 1..8 { + let (s, sh) = u_add(A, d2); h[p] = sh; p += 1; + A = s; + table[8 + i] = A; + table[7 - i] = GPoint { x: A.x, y: 0 - A.y }; + } + + // Init: offset + T[slices[0]] + let (ir, ih) = u_add(offset, table[slices[0] as u32]); + h[p] = ih; p += 1; + let mut acc = ir; + + // 63 windows: 4 doubles + 1 add each + for _w in 1..64 { + for _ in 0..4 { let (d, dh) = u_double(acc); h[p] = dh; p += 1; acc = d; } + let tp = table[slices[_w] as u32]; + let (s, sh) = u_add(acc, tp); h[p] = sh; p += 1; acc = s; + } + + // Skew correction (always compute valid hint even if unused) + let neg_P = GPoint { x: P.x, y: 0 - P.y }; + let (sr, sh) = u_add(acc, neg_P); + h[p] = sh; p += 1; + if skew { acc = sr; } + + // Final offset subtraction (complete -- handles identity result when scalar = 0) + let neg_off = GPoint { x: offset_final.x, y: 0 - offset_final.y }; + h[p] = u_complete_add_hint(acc, neg_off); + + h +} + +// ====== Scalar range check ====== +// Lexicographic comparison: ensures wNAF slices represent an integer < field modulus. +// Without this, a prover could encode scalar + k*p (for k != 0) using valid 4-bit slices, +// since the Horner reconstruction only checks equality mod p. +// Mirrors noir_bigcurve's `compare_scalar_field_to_bignum`. 
+fn assert_slices_less_than_modulus(slices: [u8; 64]) { + let mut found_strictly_less: bool = false; + for i in 0..64 { + if !found_strictly_less { + let s = slices[i]; + let m = MODULUS_SLICES[i]; + // If we find a digit strictly less than modulus digit, scalar < modulus -- done. + if s as u8 < m { + found_strictly_less = true; + } else { + // If strictly greater at any position (without prior strictly-less), scalar >= modulus. + assert(s == m, "wNAF scalar exceeds field modulus"); + } + } + } + // If all digits equal, scalar == modulus, which is also invalid (must be strictly less). + assert(found_strictly_less, "wNAF scalar equals field modulus"); +} + +// ====== Main scalar multiplication ====== + +fn scalar_mul_wnaf(P: GPoint, scalar_lo: Field, scalar_hi: Field) -> GPointResult { + // 1. Decompose scalar into wNAF slices + // Safety: slices and skew are fully constrained below (range, reconstruction, and modulus bound) + let (slices, skew) = unsafe { decompose_wnaf(scalar_lo, scalar_hi) }; + + // Range check: each slice fits in 4 bits + for i in 0..64 { (slices[i] as Field).assert_max_bit_size::<4>(); } + + // Soundness fix #1: scalar range check -- slices represent integer < field modulus + assert_slices_less_than_modulus(slices); + + // Reconstruction check: wNAF Horner evaluation == scalar_lo + scalar_hi * 2^128 + let mut r: Field = 0; + for i in 0..64 { r = r * 16; r += (slices[i] as Field) * 2 - 15; } + r -= skew as Field; + + let lo_bits: [u1; 128] = scalar_lo.to_le_bits(); + let hi_bits: [u1; 128] = scalar_hi.to_le_bits(); + let mut expected: Field = 0; + let mut pow: Field = 1; + for i in 0..128 { expected += (lo_bits[i] as Field) * pow; pow *= 2; } + let two_128: Field = pow; + let mut hi_val: Field = 0; + pow = 1; + for i in 0..128 { hi_val += (hi_bits[i] as Field) * pow; pow *= 2; } + expected += hi_val * two_128; + assert(r == expected); + + // 2. 
Offset generators -- hardcoded compile-time constants (soundness fix #2) + let offset = GPoint { x: OFFSET_X, y: OFFSET_Y }; + let offset_final = GPoint { x: OFFSET_FINAL_X, y: OFFSET_FINAL_Y }; + + // 3. Transcript of EC operation hints + // Safety: every hint is verified by a constrained c_double or c_add call below + let hints = unsafe { compute_transcript(P, slices, skew, offset, offset_final) }; + let mut hp: u32 = 0; + + // 4. Build 16-entry lookup table: T[8]=P, T[9]=3P, ..., T[15]=15P, T[7]=-P, ..., T[0]=-15P + let d2 = c_double(P, hints[hp]); hp += 1; + let mut tx: [Field; 16] = [0; 16]; + let mut ty: [Field; 16] = [0; 16]; + tx[8] = P.x; ty[8] = P.y; + tx[7] = P.x; ty[7] = 0 - P.y; + let mut A = P; + for i in 1..8 { + A = c_add(A, d2, hints[hp]); hp += 1; + tx[8 + i] = A.x; ty[8 + i] = A.y; + tx[7 - i] = A.x; ty[7 - i] = 0 - A.y; + } + + // 5. Init accumulator: offset + T[slices[0]] + let first = GPoint { x: tx[slices[0] as u32], y: ty[slices[0] as u32] }; + let mut acc = c_add(offset, first, hints[hp]); hp += 1; + + // 6. Main wNAF loop: 63 windows * (4 doublings + 1 table add) + for _w in 1..64 { + acc = c_double(acc, hints[hp]); hp += 1; + acc = c_double(acc, hints[hp]); hp += 1; + acc = c_double(acc, hints[hp]); hp += 1; + acc = c_double(acc, hints[hp]); hp += 1; + let tp = GPoint { x: tx[slices[_w] as u32], y: ty[slices[_w] as u32] }; + acc = c_add(acc, tp, hints[hp]); hp += 1; + } + + // 7. Skew correction: if scalar was even, subtract P + let neg_P = GPoint { x: P.x, y: 0 - P.y }; + let skew_r = c_add(acc, neg_P, hints[hp]); hp += 1; + acc = if skew { skew_r } else { acc }; + + // 8. 
Subtract accumulated offset: result = acc - 2^252 * offset + // Uses complete addition to handle the identity result (scalar = 0 mod group_order) + let neg_off = GPoint { x: offset_final.x, y: 0 - offset_final.y }; + c_complete_add(acc, neg_off, hints[hp]) +} + +/// 2-point MSM on Grumpkin: s1*G + s2*G +fn main( + scalar1_lo: pub Field, scalar1_hi: pub Field, + scalar2_lo: pub Field, scalar2_hi: pub Field, +) -> pub (Field, Field, bool) { + let g = GPoint { x: 1, y: GRUMPKIN_GEN_Y }; + + let r1 = scalar_mul_wnaf(g, scalar1_lo, scalar1_hi); + let r2 = scalar_mul_wnaf(g, scalar2_lo, scalar2_hi); + + // Full addition: handles r1 == r2 (doubling), r1 == -r2 (identity), and identity inputs + let add_hint = unsafe { + u_full_add_hint( + GPoint { x: r1.x, y: r1.y }, r1.is_infinity, + GPoint { x: r2.x, y: r2.y }, r2.is_infinity, + ) + }; + let result = c_full_add( + GPoint { x: r1.x, y: r1.y }, r1.is_infinity, + GPoint { x: r2.x, y: r2.y }, r2.is_infinity, + add_hint, + ); + + // Verify result is on Grumpkin (skip for identity) + let on_curve = result.y * result.y - (result.x * result.x * result.x - 17); + assert((!result.is_infinity as Field) * on_curve == 0); + + (result.x, result.y, result.is_infinity) +} diff --git a/provekit/common/src/lib.rs b/provekit/common/src/lib.rs index 29ec40b5f..7a7de6b46 100644 --- a/provekit/common/src/lib.rs +++ b/provekit/common/src/lib.rs @@ -9,6 +9,7 @@ mod r1cs; pub mod skyscraper; pub mod sparse_matrix; mod transcript_sponge; +pub mod u256_arith; pub mod utils; mod verifier; mod whir_r1cs; diff --git a/provekit/common/src/u256_arith.rs b/provekit/common/src/u256_arith.rs new file mode 100644 index 000000000..a41c99463 --- /dev/null +++ b/provekit/common/src/u256_arith.rs @@ -0,0 +1,214 @@ +//! 256-bit unsigned integer modular arithmetic. +//! +//! Shared across r1cs-compiler (compile-time EC point precomputation) and +//! prover (witness solving). Pure `[u64; 4]` arithmetic with no external +//! dependencies. 
+ +/// 256-bit unsigned integer as 4 little-endian u64 limbs. +pub type U256 = [u64; 4]; + +/// Integer ceiling of log2. +/// ceil_log2(1) = 0, ceil_log2(2) = 1, ceil_log2(3) = 2, ceil_log2(4) = 2. +pub fn ceil_log2(n: u64) -> u32 { + assert!(n > 0, "ceil_log2(0) is undefined"); + u64::BITS - (n - 1).leading_zeros() +} + +/// Returns true if a >= b. +pub fn gte(a: &U256, b: &U256) -> bool { + for i in (0..4).rev() { + if a[i] > b[i] { + return true; + } + if a[i] < b[i] { + return false; + } + } + true // equal +} + +/// a + b, returns (result, carry). +pub fn add(a: &U256, b: &U256) -> (U256, bool) { + let mut result = [0u64; 4]; + let mut carry = 0u128; + for i in 0..4 { + carry += a[i] as u128 + b[i] as u128; + result[i] = carry as u64; + carry >>= 64; + } + (result, carry != 0) +} + +/// a - b, returns (result, borrow). +pub fn sub(a: &U256, b: &U256) -> (U256, bool) { + let mut result = [0u64; 4]; + let mut borrow = false; + for i in 0..4 { + let (d1, b1) = a[i].overflowing_sub(b[i]); + let (d2, b2) = d1.overflowing_sub(borrow as u64); + result[i] = d2; + borrow = b1 || b2; + } + (result, borrow) +} + +/// (a + b) mod p. +pub fn mod_add(a: &U256, b: &U256, p: &U256) -> U256 { + let (s, overflow) = add(a, b); + if overflow || gte(&s, p) { + sub(&s, p).0 + } else { + s + } +} + +/// (a - b) mod p. +pub fn mod_sub(a: &U256, b: &U256, p: &U256) -> U256 { + let (d, borrow) = sub(a, b); + if borrow { + add(&d, p).0 + } else { + d + } +} + +/// Schoolbook multiplication: 4×4 limbs → 8 limbs (256-bit × 256-bit → +/// 512-bit). +pub fn widening_mul(a: &U256, b: &U256) -> [u64; 8] { + let mut result = [0u64; 8]; + for i in 0..4 { + let mut carry = 0u128; + for j in 0..4 { + let prod = (a[i] as u128) * (b[j] as u128) + result[i + j] as u128 + carry; + result[i + j] = prod as u64; + carry = prod >> 64; + } + result[i + 4] = result[i + 4].wrapping_add(carry as u64); + } + result +} + +/// Reduce a 512-bit value mod a 256-bit prime using bit-by-bit long division. 
+pub fn reduce_wide(wide: &[u64; 8], p: &U256) -> U256 { + let mut total_bits = 0; + for i in (0..8).rev() { + if wide[i] != 0 { + total_bits = i * 64 + (64 - wide[i].leading_zeros() as usize); + break; + } + } + if total_bits == 0 { + return [0; 4]; + } + + let mut r = [0u64; 4]; + for bit_idx in (0..total_bits).rev() { + // Left shift r by 1 + let overflow = r[3] >> 63; + for j in (1..4).rev() { + r[j] = (r[j] << 1) | (r[j - 1] >> 63); + } + r[0] <<= 1; + + // Insert current bit + let word = bit_idx / 64; + let bit = bit_idx % 64; + r[0] |= (wide[word] >> bit) & 1; + + // If r >= p (or overflow from shift), subtract p + if overflow != 0 || gte(&r, p) { + r = sub(&r, p).0; + } + } + r +} + +/// (a * b) mod p. +pub fn mod_mul(a: &U256, b: &U256, p: &U256) -> U256 { + let wide = widening_mul(a, b); + reduce_wide(&wide, p) +} + +/// a^exp mod p using square-and-multiply. +pub fn mod_pow(base: &U256, exp: &U256, p: &U256) -> U256 { + let mut highest_bit = 0; + for i in (0..4).rev() { + if exp[i] != 0 { + highest_bit = i * 64 + (64 - exp[i].leading_zeros() as usize); + break; + } + } + if highest_bit == 0 { + return [1, 0, 0, 0]; + } + + let mut result: U256 = [1, 0, 0, 0]; + let mut base = *base; + for bit_idx in 0..highest_bit { + let word = bit_idx / 64; + let bit = bit_idx % 64; + if (exp[word] >> bit) & 1 == 1 { + result = mod_mul(&result, &base, p); + } + base = mod_mul(&base, &base, p); + } + result +} + +/// a^(-1) mod p via Fermat's little theorem: a^(p-2) mod p. 
+pub fn mod_inv(a: &U256, p: &U256) -> U256 { + let two: U256 = [2, 0, 0, 0]; + let exp = sub(p, &two).0; + mod_pow(a, &exp, p) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_add_no_carry() { + let a: U256 = [1, 0, 0, 0]; + let b: U256 = [2, 0, 0, 0]; + let (r, c) = add(&a, &b); + assert_eq!(r, [3, 0, 0, 0]); + assert!(!c); + } + + #[test] + fn test_add_carry() { + let a: U256 = [u64::MAX, 0, 0, 0]; + let b: U256 = [1, 0, 0, 0]; + let (r, c) = add(&a, &b); + assert_eq!(r, [0, 1, 0, 0]); + assert!(!c); + } + + #[test] + fn test_sub_no_borrow() { + let a: U256 = [5, 0, 0, 0]; + let b: U256 = [3, 0, 0, 0]; + let (r, borrow) = sub(&a, &b); + assert_eq!(r, [2, 0, 0, 0]); + assert!(!borrow); + } + + #[test] + fn test_mod_mul_small() { + let a: U256 = [7, 0, 0, 0]; + let b: U256 = [6, 0, 0, 0]; + let p: U256 = [11, 0, 0, 0]; + // 7 * 6 = 42 mod 11 = 9 + assert_eq!(mod_mul(&a, &b, &p), [9, 0, 0, 0]); + } + + #[test] + fn test_mod_inv_small() { + let a: U256 = [3, 0, 0, 0]; + let p: U256 = [11, 0, 0, 0]; + let inv = mod_inv(&a, &p); + // 3^(-1) mod 11 = 4 (since 3*4 = 12 = 1 mod 11) + assert_eq!(inv, [4, 0, 0, 0]); + assert_eq!(mod_mul(&a, &inv, &p), [1, 0, 0, 0]); + } +} diff --git a/provekit/common/src/witness/limbs.rs b/provekit/common/src/witness/limbs.rs new file mode 100644 index 000000000..34dee98f7 --- /dev/null +++ b/provekit/common/src/witness/limbs.rs @@ -0,0 +1,129 @@ +//! `Limbs`: fixed-capacity, `Copy` array of witness indices with push-based +//! construction. + +use serde::{Deserialize, Deserializer, Serialize, Serializer}; + +/// Maximum number of limbs supported. +pub const MAX_LIMBS: usize = 32; + +/// A fixed-capacity `Copy` array of witness indices, indexed by limb position. +/// +/// Construction uses `push()` to append elements sequentially, preventing +/// uninitialized access. Slots beyond `len` are never reachable through the +/// public API. 
+#[derive(Clone, Copy)] +pub struct Limbs { + data: [usize; MAX_LIMBS], + len: usize, +} + +impl Limbs { + /// Create an empty `Limbs`. Use [`Self::push`] to add elements. + pub fn new() -> Self { + Self { + data: [0; MAX_LIMBS], + len: 0, + } + } + + /// Append a witness index. Panics if capacity (`MAX_LIMBS`) is exceeded. + pub fn push(&mut self, value: usize) { + assert!( + self.len < MAX_LIMBS, + "Limbs overflow: cannot push beyond {MAX_LIMBS} elements" + ); + self.data[self.len] = value; + self.len += 1; + } + + /// Create a single-limb `Limbs` wrapping one witness index. + pub fn single(value: usize) -> Self { + let mut l = Self::new(); + l.push(value); + l + } + + /// View the active limbs as a slice. + pub fn as_slice(&self) -> &[usize] { + &self.data[..self.len] + } + + /// Number of active limbs. + #[allow(clippy::len_without_is_empty)] + pub fn len(&self) -> usize { + self.len + } +} + +impl Default for Limbs { + fn default() -> Self { + Self::new() + } +} + +impl std::fmt::Debug for Limbs { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_list().entries(self.as_slice().iter()).finish() + } +} + +impl PartialEq for Limbs { + fn eq(&self, other: &Self) -> bool { + self.len == other.len && self.data[..self.len] == other.data[..other.len] + } +} +impl Eq for Limbs {} + +impl From<&[usize]> for Limbs { + fn from(slice: &[usize]) -> Self { + assert!( + slice.len() <= MAX_LIMBS, + "Limbs: slice length {} exceeds MAX_LIMBS ({MAX_LIMBS})", + slice.len() + ); + let mut l = Self::new(); + for &v in slice { + l.push(v); + } + l + } +} + +impl std::ops::Index for Limbs { + type Output = usize; + fn index(&self, i: usize) -> &usize { + assert!( + i < self.len, + "Limbs index {i} out of bounds (len={})", + self.len + ); + &self.data[i] + } +} + +impl std::ops::IndexMut for Limbs { + fn index_mut(&mut self, i: usize) -> &mut usize { + assert!( + i < self.len, + "Limbs index {i} out of bounds (len={})", + self.len + ); + &mut self.data[i] 
+ } +} + +/// Serialize only the active elements (same wire format as `Vec`). +impl Serialize for Limbs { + fn serialize(&self, serializer: S) -> Result { + self.as_slice().serialize(serializer) + } +} + +/// Deserialize from a variable-length sequence (same wire format as +/// `Vec`). +impl<'de> Deserialize<'de> for Limbs { + fn deserialize>(deserializer: D) -> Result { + let v: Vec = Vec::deserialize(deserializer)?; + Ok(Limbs::from(v.as_slice())) + } +} diff --git a/provekit/common/src/witness/mod.rs b/provekit/common/src/witness/mod.rs index f7cf80db2..b0dc929c1 100644 --- a/provekit/common/src/witness/mod.rs +++ b/provekit/common/src/witness/mod.rs @@ -1,5 +1,6 @@ mod binops; mod digits; +mod limbs; mod ram; mod scheduling; mod witness_builder; @@ -16,11 +17,14 @@ use { pub use { binops::{BINOP_ATOMIC_BITS, BINOP_BITS, NUM_DIGITS}, digits::{decompose_into_digits, DigitalDecompositionWitnesses}, + limbs::{Limbs, MAX_LIMBS}, ram::{SpiceMemoryOperation, SpiceWitnesses}, - scheduling::{Layer, LayerType, LayeredWitnessBuilders, SplitError, SplitWitnessBuilders}, + scheduling::{ + Layer, LayerScheduler, LayerType, LayeredWitnessBuilders, SplitError, SplitWitnessBuilders, + }, witness_builder::{ - CombinedTableEntryInverseData, ConstantTerm, ProductLinearTerm, SumTerm, WitnessBuilder, - WitnessCoefficient, + CombinedTableEntryInverseData, ConstantTerm, NonNativeEcOp, ProductLinearTerm, SumTerm, + WitnessBuilder, WitnessCoefficient, }, witness_generator::NoirWitnessGenerator, }; diff --git a/provekit/common/src/witness/scheduling/dependency.rs b/provekit/common/src/witness/scheduling/dependency.rs index a5cbaefd6..7e511f834 100644 --- a/provekit/common/src/witness/scheduling/dependency.rs +++ b/provekit/common/src/witness/scheduling/dependency.rs @@ -1,7 +1,7 @@ use { crate::witness::{ - ConstantOrR1CSWitness, ConstantTerm, ProductLinearTerm, SumTerm, WitnessBuilder, - WitnessCoefficient, + ConstantOrR1CSWitness, ConstantTerm, NonNativeEcOp, ProductLinearTerm, 
SumTerm, + WitnessBuilder, WitnessCoefficient, }, std::collections::HashMap, }; @@ -76,9 +76,15 @@ impl DependencyInfo { | WitnessBuilder::Acir(..) | WitnessBuilder::Challenge(_) => vec![], WitnessBuilder::Sum(_, ops) => ops.iter().map(|SumTerm(_, idx)| *idx).collect(), + WitnessBuilder::SumQuotient { terms, .. } => { + terms.iter().map(|SumTerm(_, idx)| *idx).collect() + } WitnessBuilder::Product(_, a, b) => vec![*a, *b], WitnessBuilder::MultiplicitiesForRange(_, _, values) => values.clone(), - WitnessBuilder::Inverse(_, x) => vec![*x], + WitnessBuilder::Inverse(_, x) + | WitnessBuilder::SafeInverse(_, x) + | WitnessBuilder::ModularInverse(_, x, _) + | WitnessBuilder::IntegerQuotient(_, x, _) => vec![*x], WitnessBuilder::IndexedLogUpDenominator( _, sz, @@ -152,6 +158,28 @@ impl DependencyInfo { } v } + WitnessBuilder::MultiLimbMulModHint { + a_limbs, b_limbs, .. + } => { + let mut v = a_limbs.clone(); + v.extend(b_limbs); + v + } + WitnessBuilder::MultiLimbModularInverse { a_limbs, .. } => a_limbs.clone(), + WitnessBuilder::MultiLimbAddQuotient { + a_limbs, b_limbs, .. + } => { + let mut v = a_limbs.clone(); + v.extend(b_limbs); + v + } + WitnessBuilder::MultiLimbSubBorrow { + a_limbs, b_limbs, .. + } => { + let mut v = a_limbs.clone(); + v.extend(b_limbs); + v + } WitnessBuilder::BytePartition { x, .. } => vec![*x], WitnessBuilder::U32AdditionMulti(_, _, inputs) => inputs @@ -198,6 +226,32 @@ impl DependencyInfo { data.rs_cubed, ] } + WitnessBuilder::EcDoubleHint { px, py, .. } => vec![*px, *py], + WitnessBuilder::EcAddHint { x1, y1, x2, y2, .. } => vec![*x1, *y1, *x2, *y2], + WitnessBuilder::NonNativeEcHint { inputs, .. } => { + inputs.iter().flat_map(|l| l.as_slice()).copied().collect() + } + WitnessBuilder::FakeGLVHint { s_lo, s_hi, .. } => vec![*s_lo, *s_hi], + WitnessBuilder::EcScalarMulHint { + px_limbs, + py_limbs, + s_lo, + s_hi, + .. 
+ } => px_limbs + .iter() + .chain(py_limbs.iter()) + .copied() + .chain([*s_lo, *s_hi]) + .collect(), + WitnessBuilder::SelectWitness { + flag, + on_false, + on_true, + .. + } => vec![*flag, *on_false, *on_true], + WitnessBuilder::BooleanOr { a, b, .. } => vec![*a, *b], + WitnessBuilder::SignedBitHint { scalar, .. } => vec![*scalar], WitnessBuilder::ChunkDecompose { packed, .. } => vec![*packed], WitnessBuilder::SpreadWitness(_, input) => vec![*input], WitnessBuilder::SpreadBitExtract { sum_terms, .. } => { @@ -240,6 +294,9 @@ impl DependencyInfo { | WitnessBuilder::Challenge(idx) | WitnessBuilder::IndexedLogUpDenominator(idx, ..) | WitnessBuilder::Inverse(idx, _) + | WitnessBuilder::SafeInverse(idx, _) + | WitnessBuilder::ModularInverse(idx, ..) + | WitnessBuilder::IntegerQuotient(idx, ..) | WitnessBuilder::ProductLinearOperation(idx, ..) | WitnessBuilder::LogUpDenominator(idx, ..) | WitnessBuilder::LogUpInverse(idx, ..) @@ -254,6 +311,14 @@ impl DependencyInfo { | WitnessBuilder::SpreadWitness(idx, ..) | WitnessBuilder::SpreadLookupDenominator(idx, ..) | WitnessBuilder::SpreadTableQuotient { idx, .. } => vec![*idx], + WitnessBuilder::SumQuotient { output, .. } => vec![*output], + WitnessBuilder::SelectWitness { output, .. } + | WitnessBuilder::BooleanOr { output, .. } => vec![*output], + WitnessBuilder::SignedBitHint { + output_start, + num_bits, + .. + } => (*output_start..*output_start + *num_bits + 1).collect(), WitnessBuilder::MultiplicitiesForRange(start, range, _) => { (*start..*start + *range).collect() @@ -282,6 +347,47 @@ impl DependencyInfo { let n = 1usize << *num_bits; (*start..*start + n).collect() } + WitnessBuilder::MultiLimbMulModHint { + output_start, + num_limbs, + .. + } => { + let count = (4 * *num_limbs - 2) as usize; + (*output_start..*output_start + count).collect() + } + WitnessBuilder::MultiLimbModularInverse { + output_start, + num_limbs, + .. 
+ } => (*output_start..*output_start + *num_limbs as usize).collect(), + WitnessBuilder::EcDoubleHint { output_start, .. } => { + (*output_start..*output_start + 3).collect() + } + WitnessBuilder::EcAddHint { output_start, .. } => { + (*output_start..*output_start + 3).collect() + } + WitnessBuilder::NonNativeEcHint { + output_start, + num_limbs, + op, + .. + } => { + let count = match op { + NonNativeEcOp::Double | NonNativeEcOp::Add => (15 * *num_limbs - 6) as usize, + NonNativeEcOp::OnCurve => (9 * *num_limbs - 4) as usize, + }; + (*output_start..*output_start + count).collect() + } + WitnessBuilder::FakeGLVHint { output_start, .. } => { + (*output_start..*output_start + 4).collect() + } + WitnessBuilder::EcScalarMulHint { + output_start, + num_limbs, + .. + } => (*output_start..*output_start + 2 * *num_limbs as usize).collect(), + WitnessBuilder::MultiLimbAddQuotient { output, .. } => vec![*output], + WitnessBuilder::MultiLimbSubBorrow { output, .. } => vec![*output], WitnessBuilder::U32Addition(result_idx, carry_idx, ..) 
=> { vec![*result_idx, *carry_idx] } diff --git a/provekit/common/src/witness/scheduling/remapper.rs b/provekit/common/src/witness/scheduling/remapper.rs index 9503847a3..389a14358 100644 --- a/provekit/common/src/witness/scheduling/remapper.rs +++ b/provekit/common/src/witness/scheduling/remapper.rs @@ -2,8 +2,8 @@ use { crate::{ sparse_matrix::SparseMatrix, witness::{ - scheduling::DependencyInfo, ConstantOrR1CSWitness, ConstantTerm, ProductLinearTerm, - SumTerm, WitnessBuilder, WitnessCoefficient, + limbs::Limbs, scheduling::DependencyInfo, ConstantOrR1CSWitness, ConstantTerm, + ProductLinearTerm, SumTerm, WitnessBuilder, WitnessCoefficient, }, R1CS, }, @@ -115,6 +115,30 @@ impl WitnessIndexRemapper { WitnessBuilder::Inverse(idx, operand) => { WitnessBuilder::Inverse(self.remap(*idx), self.remap(*operand)) } + WitnessBuilder::SafeInverse(idx, operand) => { + WitnessBuilder::SafeInverse(self.remap(*idx), self.remap(*operand)) + } + WitnessBuilder::ModularInverse(idx, operand, modulus) => { + WitnessBuilder::ModularInverse(self.remap(*idx), self.remap(*operand), *modulus) + } + WitnessBuilder::IntegerQuotient(idx, dividend, divisor) => { + WitnessBuilder::IntegerQuotient(self.remap(*idx), self.remap(*dividend), *divisor) + } + WitnessBuilder::SumQuotient { + output, + terms, + divisor, + } => { + let new_terms = terms + .iter() + .map(|SumTerm(coeff, idx)| SumTerm(*coeff, self.remap(*idx))) + .collect(); + WitnessBuilder::SumQuotient { + output: self.remap(*output), + terms: new_terms, + divisor: *divisor, + } + } WitnessBuilder::ProductLinearOperation( idx, ProductLinearTerm(x, a, b), @@ -215,6 +239,64 @@ impl WitnessIndexRemapper { .collect(), ) } + WitnessBuilder::MultiLimbMulModHint { + output_start, + a_limbs, + b_limbs, + modulus, + limb_bits, + num_limbs, + } => WitnessBuilder::MultiLimbMulModHint { + output_start: self.remap(*output_start), + a_limbs: a_limbs.iter().map(|&w| self.remap(w)).collect(), + b_limbs: b_limbs.iter().map(|&w| 
self.remap(w)).collect(), + modulus: *modulus, + limb_bits: *limb_bits, + num_limbs: *num_limbs, + }, + WitnessBuilder::MultiLimbModularInverse { + output_start, + a_limbs, + modulus, + limb_bits, + num_limbs, + } => WitnessBuilder::MultiLimbModularInverse { + output_start: self.remap(*output_start), + a_limbs: a_limbs.iter().map(|&w| self.remap(w)).collect(), + modulus: *modulus, + limb_bits: *limb_bits, + num_limbs: *num_limbs, + }, + WitnessBuilder::MultiLimbAddQuotient { + output, + a_limbs, + b_limbs, + modulus, + limb_bits, + num_limbs, + } => WitnessBuilder::MultiLimbAddQuotient { + output: self.remap(*output), + a_limbs: a_limbs.iter().map(|&w| self.remap(w)).collect(), + b_limbs: b_limbs.iter().map(|&w| self.remap(w)).collect(), + modulus: *modulus, + limb_bits: *limb_bits, + num_limbs: *num_limbs, + }, + WitnessBuilder::MultiLimbSubBorrow { + output, + a_limbs, + b_limbs, + modulus, + limb_bits, + num_limbs, + } => WitnessBuilder::MultiLimbSubBorrow { + output: self.remap(*output), + a_limbs: a_limbs.iter().map(|&w| self.remap(w)).collect(), + b_limbs: b_limbs.iter().map(|&w| self.remap(w)).collect(), + modulus: *modulus, + limb_bits: *limb_bits, + num_limbs: *num_limbs, + }, WitnessBuilder::BytePartition { lo, hi, x, k } => WitnessBuilder::BytePartition { lo: self.remap(*lo), hi: self.remap(*hi), @@ -299,6 +381,117 @@ impl WitnessIndexRemapper { }, ) } + WitnessBuilder::EcDoubleHint { + output_start, + px, + py, + curve_a, + field_modulus_p, + } => WitnessBuilder::EcDoubleHint { + output_start: self.remap(*output_start), + px: self.remap(*px), + py: self.remap(*py), + curve_a: *curve_a, + field_modulus_p: *field_modulus_p, + }, + WitnessBuilder::EcAddHint { + output_start, + x1, + y1, + x2, + y2, + field_modulus_p, + } => WitnessBuilder::EcAddHint { + output_start: self.remap(*output_start), + x1: self.remap(*x1), + y1: self.remap(*y1), + x2: self.remap(*x2), + y2: self.remap(*y2), + field_modulus_p: *field_modulus_p, + }, + 
WitnessBuilder::NonNativeEcHint { + output_start, + op, + inputs, + curve_a, + curve_b, + field_modulus_p, + limb_bits, + num_limbs, + } => WitnessBuilder::NonNativeEcHint { + output_start: self.remap(*output_start), + op: op.clone(), + inputs: inputs + .iter() + .map(|l| { + let remapped: Vec = + l.as_slice().iter().map(|&w| self.remap(w)).collect(); + Limbs::from(remapped.as_slice()) + }) + .collect(), + curve_a: *curve_a, + curve_b: *curve_b, + field_modulus_p: *field_modulus_p, + limb_bits: *limb_bits, + num_limbs: *num_limbs, + }, + WitnessBuilder::FakeGLVHint { + output_start, + s_lo, + s_hi, + curve_order, + } => WitnessBuilder::FakeGLVHint { + output_start: self.remap(*output_start), + s_lo: self.remap(*s_lo), + s_hi: self.remap(*s_hi), + curve_order: *curve_order, + }, + WitnessBuilder::EcScalarMulHint { + output_start, + px_limbs, + py_limbs, + s_lo, + s_hi, + curve_a, + field_modulus_p, + num_limbs, + limb_bits, + } => WitnessBuilder::EcScalarMulHint { + output_start: self.remap(*output_start), + px_limbs: px_limbs.iter().map(|&w| self.remap(w)).collect(), + py_limbs: py_limbs.iter().map(|&w| self.remap(w)).collect(), + s_lo: self.remap(*s_lo), + s_hi: self.remap(*s_hi), + curve_a: *curve_a, + field_modulus_p: *field_modulus_p, + num_limbs: *num_limbs, + limb_bits: *limb_bits, + }, + WitnessBuilder::SelectWitness { + output, + flag, + on_false, + on_true, + } => WitnessBuilder::SelectWitness { + output: self.remap(*output), + flag: self.remap(*flag), + on_false: self.remap(*on_false), + on_true: self.remap(*on_true), + }, + WitnessBuilder::BooleanOr { output, a, b } => WitnessBuilder::BooleanOr { + output: self.remap(*output), + a: self.remap(*a), + b: self.remap(*b), + }, + WitnessBuilder::SignedBitHint { + output_start, + scalar, + num_bits, + } => WitnessBuilder::SignedBitHint { + output_start: self.remap(*output_start), + scalar: self.remap(*scalar), + num_bits: *num_bits, + }, WitnessBuilder::ChunkDecompose { output_start, packed, diff --git 
a/provekit/common/src/witness/witness_builder.rs b/provekit/common/src/witness/witness_builder.rs index 0628fc2e3..ee1af6f26 100644 --- a/provekit/common/src/witness/witness_builder.rs +++ b/provekit/common/src/witness/witness_builder.rs @@ -3,6 +3,7 @@ use { utils::{serde_ark, serde_ark_option}, witness::{ digits::DigitalDecompositionWitnesses, + limbs::Limbs, ram::SpiceWitnesses, scheduling::{ LayerScheduler, LayeredWitnessBuilders, SplitError, SplitWitnessBuilders, @@ -54,9 +55,21 @@ pub struct CombinedTableEntryInverseData { pub xor_out: FieldElement, } +/// Operation type for the unified non-native EC hint. #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub enum NonNativeEcOp { + /// Point doubling: inputs = \[\[px_limbs\], \[py_limbs\]\], outputs 15N-6 + Double, + /// Point addition: inputs = \[\[x1_limbs\], \[y1_limbs\], \[x2_limbs\], + /// \[y2_limbs\]\], outputs 15N-6 + Add, + /// On-curve check: inputs = \[\[px_limbs\], \[py_limbs\]\], outputs 9N-4 + OnCurve, +} + /// Indicates how to solve for a collection of R1CS witnesses in terms of /// earlier (i.e. already solved for) R1CS witnesses and/or ACIR witness values. +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] pub enum WitnessBuilder { /// Constant value, used for the constant one witness & e.g. static lookups /// (witness index, constant value) @@ -88,6 +101,23 @@ pub enum WitnessBuilder { /// The inverse of the value at a specified witness index /// (witness index, operand witness index) Inverse(usize, usize), + /// Safe inverse: like Inverse but handles zero by outputting 0. + /// Used by compute_is_zero where the input may be zero. Solved in the + /// Other layer (not batch-inverted), so zero inputs don't poison the batch. + /// (witness index, operand witness index) + SafeInverse(usize, usize), + /// The modular inverse of the value at a specified witness index, modulo + /// a given prime modulus. Computes a^{-1} mod m using Fermat's little + /// theorem (a^{m-2} mod m). 
Unlike Inverse (BN254 field inverse), this + /// operates as integer modular arithmetic. + /// (witness index, operand witness index, modulus) + ModularInverse(usize, usize, #[serde(with = "serde_ark")] FieldElement), + /// The integer quotient floor(dividend / divisor). Used by reduce_mod to + /// compute k = floor(v / m) so that v = k*m + result with 0 <= result < m. + /// Unlike field multiplication by the inverse, this performs true integer + /// division on the BigInteger representation. + /// (witness index, dividend witness index, divisor constant) + IntegerQuotient(usize, usize, #[serde(with = "serde_ark")] FieldElement), /// Products with linear operations on the witness indices. /// Fields are ProductLinearOperation(witness_idx, (index, a, b), (index, c, /// d)) such that we wish to compute (ax + b) * (cx + d). @@ -189,6 +219,61 @@ pub enum WitnessBuilder { /// Inverse of combined lookup table entry denominator (constant operands). /// Computes: 1 / (sz - lhs - rs*rhs - rs²*and_out - rs³*xor_out) CombinedTableEntryInverse(CombinedTableEntryInverseData), + /// Prover hint for multi-limb modular multiplication: (a * b) mod p. + /// Given inputs a and b as N-limb vectors (each limb `limb_bits` wide), + /// and a constant 256-bit modulus p, computes quotient q, remainder r, + /// and carry witnesses for schoolbook column verification. + /// + /// Outputs (4*num_limbs - 2) witnesses starting at output_start: + /// [0..N) q limbs (quotient) + /// [N..2N) r limbs (remainder) — OUTPUT + /// [2N..4N-2) carry witnesses (unsigned-offset) + MultiLimbMulModHint { + output_start: usize, + a_limbs: Vec, + b_limbs: Vec, + modulus: [u64; 4], + limb_bits: u32, + num_limbs: u32, + }, + /// Prover hint for multi-limb modular inverse: a^{-1} mod p. + /// Given input a as N-limb vector and constant modulus p, + /// computes the inverse via Fermat's little theorem (a^{p-2} mod p). + /// + /// Outputs num_limbs witnesses at output_start: inv limbs. 
+ MultiLimbModularInverse { + output_start: usize, + a_limbs: Vec, + modulus: [u64; 4], + limb_bits: u32, + num_limbs: u32, + }, + /// Prover hint for multi-limb addition quotient: q = floor((a + b) / p). + /// Given inputs a and b as N-limb vectors, and a constant modulus p, + /// computes q ∈ {0, 1}. + /// + /// Outputs 1 witness at output: q. + MultiLimbAddQuotient { + output: usize, + a_limbs: Vec, + b_limbs: Vec, + modulus: [u64; 4], + limb_bits: u32, + num_limbs: u32, + }, + /// Prover hint for multi-limb subtraction borrow: q = (a < b) ? 1 : 0. + /// Given inputs a and b as N-limb vectors, and a constant modulus p, + /// computes q ∈ {0, 1} indicating whether a borrow (adding p) is needed. + /// + /// Outputs 1 witness at output: q. + MultiLimbSubBorrow { + output: usize, + a_limbs: Vec, + b_limbs: Vec, + modulus: [u64; 4], + limb_bits: u32, + num_limbs: u32, + }, /// Decomposes a packed value into chunks of specified bit-widths. /// Given packed value and chunk_bits = [b0, b1, ..., bn]: /// packed = c0 + c1 * 2^b0 + c2 * 2^(b0+b1) + ... @@ -198,6 +283,118 @@ pub enum WitnessBuilder { packed: usize, chunk_bits: Vec, }, + /// Prover hint for FakeGLV scalar decomposition. + /// Given scalar s (from s_lo + s_hi * 2^128) and curve order n, + /// computes half_gcd(s, n) → (|s1|, |s2|, neg1, neg2) such that: + /// (-1)^neg1 * |s1| + (-1)^neg2 * |s2| * s ≡ 0 (mod n) + /// + /// Outputs 4 witnesses starting at output_start: + /// \[0\] |s1| (128-bit field element) + /// \[1\] |s2| (128-bit field element) + /// \[2\] neg1 (boolean: 0 or 1) + /// \[3\] neg2 (boolean: 0 or 1) + FakeGLVHint { + output_start: usize, + s_lo: usize, + s_hi: usize, + curve_order: [u64; 4], + }, + /// Prover hint for EC scalar multiplication: computes R = \[s\]P. + /// Given point P = (px, py) and scalar s = s_lo + s_hi * 2^128, + /// computes R = \[s\]P on the curve with parameter `curve_a` and + /// field modulus `field_modulus_p`. 
+ /// + /// When `num_limbs == 1`: inputs are single witnesses, outputs 2 + /// witnesses (R_x, R_y) as native field elements. + /// When `num_limbs >= 2`: inputs are limb witnesses, outputs + /// `2 * num_limbs` witnesses (R_x limbs then R_y limbs). + EcScalarMulHint { + output_start: usize, + px_limbs: Vec, + py_limbs: Vec, + s_lo: usize, + s_hi: usize, + curve_a: [u64; 4], + field_modulus_p: [u64; 4], + num_limbs: u32, + limb_bits: u32, + }, + /// Prover hint for EC point doubling on native field. + /// Given P = (px, py) and curve parameter `a`, computes: + /// lambda = (3*px^2 + a) / (2*py) mod p + /// x3 = lambda^2 - 2*px mod p + /// y3 = lambda * (px - x3) - py mod p + /// + /// Outputs 3 witnesses at output_start: lambda, x3, y3. + EcDoubleHint { + output_start: usize, + px: usize, + py: usize, + curve_a: [u64; 4], + field_modulus_p: [u64; 4], + }, + /// Prover hint for EC point addition on native field. + /// Given P1 = (x1, y1) and P2 = (x2, y2), computes: + /// lambda = (y2 - y1) / (x2 - x1) mod p + /// x3 = lambda^2 - x1 - x2 mod p + /// y3 = lambda * (x1 - x3) - y1 mod p + /// + /// Outputs 3 witnesses at output_start: lambda, x3, y3. + EcAddHint { + output_start: usize, + x1: usize, + y1: usize, + x2: usize, + y2: usize, + field_modulus_p: [u64; 4], + }, + /// Conditional select: output = on_false + flag * (on_true - on_false). + /// When flag=0, output=on_false; when flag=1, output=on_true. + /// (output, flag, on_false, on_true) + SelectWitness { + output: usize, + flag: usize, + on_false: usize, + on_true: usize, + }, + /// Boolean OR: output = a + b - a*b = 1 - (1-a)*(1-b). + /// (output, a, b) + BooleanOr { + output: usize, + a: usize, + b: usize, + }, + /// Unified prover hint for non-native EC operations (multi-limb). 
+ /// + /// `op` selects the operation: + /// - `Double`: inputs = \[\[px\], \[py\]\], outputs 15N-6 witnesses + /// - `Add`: inputs = \[\[x1\], \[y1\], \[x2\], \[y2\]\], outputs 15N-6 + /// witnesses + /// - `OnCurve`: inputs = \[\[px\], \[py\]\], outputs 9N-4 witnesses + NonNativeEcHint { + output_start: usize, + op: NonNativeEcOp, + inputs: Vec, + curve_a: [u64; 4], + curve_b: [u64; 4], + field_modulus_p: [u64; 4], + limb_bits: u32, + num_limbs: u32, + }, + /// Signed-bit decomposition hint for wNAF scalar multiplication. + /// Given scalar s with num_bits bits, computes sign-bits b_0..b_{n-1} + /// and skew ∈ {0,1} such that: + /// s + skew + (2^n - 1) = Σ b_i * 2^{i+1} + /// where d_i = 2*b_i - 1 ∈ {-1, +1}. + /// + /// Outputs (num_bits + 1) witnesses at output_start: + /// [0..num_bits) b_i sign bits + /// \[num_bits\] skew (0 if s is odd, 1 if s is even) + SignedBitHint { + output_start: usize, + scalar: usize, + num_bits: usize, + }, /// Computes spread(input): interleave bits with zeros. /// Output: 0 b_{n-1} 0 b_{n-2} ... 0 b_1 0 b_0 /// (witness index of output, witness index of input) @@ -239,6 +436,15 @@ pub enum WitnessBuilder { spread_val: FieldElement, multiplicity: usize, }, + /// Computes `floor(linear_combination(terms) / divisor)` as an integer + /// quotient. Used for carry/borrow computation in multi-limb arithmetic, + /// avoiding an intermediate `Sum` witness. + SumQuotient { + output: usize, + terms: Vec, + #[serde(with = "serde_ark")] + divisor: FieldElement, + }, } impl WitnessBuilder { @@ -260,6 +466,17 @@ impl WitnessBuilder { WitnessBuilder::ChunkDecompose { chunk_bits, .. } => chunk_bits.len(), WitnessBuilder::SpreadBitExtract { chunk_bits, .. } => chunk_bits.len(), WitnessBuilder::MultiplicitiesForSpread(_, num_bits, _) => 1usize << *num_bits, + WitnessBuilder::MultiLimbMulModHint { num_limbs, .. } => (4 * *num_limbs - 2) as usize, + WitnessBuilder::MultiLimbModularInverse { num_limbs, .. 
} => *num_limbs as usize, + WitnessBuilder::SignedBitHint { num_bits, .. } => *num_bits + 1, + WitnessBuilder::EcDoubleHint { .. } => 3, + WitnessBuilder::EcAddHint { .. } => 3, + WitnessBuilder::NonNativeEcHint { op, num_limbs, .. } => match op { + NonNativeEcOp::Double | NonNativeEcOp::Add => (15 * *num_limbs - 6) as usize, + NonNativeEcOp::OnCurve => (9 * *num_limbs - 4) as usize, + }, + WitnessBuilder::FakeGLVHint { .. } => 4, + WitnessBuilder::EcScalarMulHint { num_limbs, .. } => 2 * *num_limbs as usize, _ => 1, } diff --git a/provekit/prover/Cargo.toml b/provekit/prover/Cargo.toml index f8ab4e06e..bfea3d15d 100644 --- a/provekit/prover/Cargo.toml +++ b/provekit/prover/Cargo.toml @@ -29,6 +29,7 @@ whir.workspace = true # 3rd party anyhow.workspace = true +num-bigint.workspace = true postcard.workspace = true tracing.workspace = true mavros-vm.workspace = true diff --git a/provekit/prover/src/bigint_mod.rs b/provekit/prover/src/bigint_mod.rs new file mode 100644 index 000000000..b9d35d31f --- /dev/null +++ b/provekit/prover/src/bigint_mod.rs @@ -0,0 +1,1194 @@ +// Re-export shared 256-bit arithmetic from provekit_common. +// Names are aliased where the prover's historical API differs. +use provekit_common::u256_arith::ceil_log2; +pub use provekit_common::u256_arith::{ + mod_add, mod_inv as mod_inverse, mod_mul as mul_mod, mod_pow, mod_sub, widening_mul, +}; +/// BigInteger modular arithmetic on [u64; 4] limbs (256-bit). +/// +/// These helpers compute modular inverse via Fermat's little theorem: +/// a^{-1} = a^{m-2} mod m, using schoolbook multiplication and +/// square-and-multiply exponentiation. +use { + ark_ff::PrimeField, + num_bigint::{BigInt, Sign}, + provekit_common::FieldElement, +}; + +/// Compare 8-limb value with 4-limb value (zero-extended to 8 limbs). +/// Returns Ordering::Greater if wide > narrow, etc. 
#[cfg(test)]
fn cmp_wide_narrow(wide: &[u64; 8], narrow: &[u64; 4]) -> std::cmp::Ordering {
    // Any set bit in the upper four limbs makes the wide value strictly larger.
    if wide[4..].iter().any(|&limb| limb != 0) {
        return std::cmp::Ordering::Greater;
    }
    // Otherwise compare the low halves, most-significant limb first.
    for idx in (0..4).rev() {
        let ord = wide[idx].cmp(&narrow[idx]);
        if ord != std::cmp::Ordering::Equal {
            return ord;
        }
    }
    std::cmp::Ordering::Equal
}

/// Left-shift a 4-limb number by 1 bit. Returns the carry-out bit.
fn shift_left_one(a: &mut [u64; 4]) -> u64 {
    let mut incoming = 0u64;
    for limb in a.iter_mut() {
        let outgoing = *limb >> 63;
        *limb = (*limb << 1) | incoming;
        incoming = outgoing;
    }
    incoming
}

/// Compare two 4-limb numbers.
pub fn cmp_4limb(a: &[u64; 4], b: &[u64; 4]) -> std::cmp::Ordering {
    // Lexicographic comparison over limbs taken most-significant first.
    a.iter().rev().cmp(b.iter().rev())
}

/// Subtract b from a in-place (a -= b). Assumes a >= b.
fn sub_4limb_inplace(a: &mut [u64; 4], b: &[u64; 4]) {
    let mut borrow = 0u64;
    for (lhs, &rhs) in a.iter_mut().zip(b.iter()) {
        let (d1, u1) = lhs.overflowing_sub(rhs);
        let (d2, u2) = d1.overflowing_sub(borrow);
        *lhs = d2;
        borrow = u64::from(u1) + u64::from(u2);
    }
    debug_assert_eq!(borrow, 0, "subtraction underflow: a < b");
}

/// Integer division with remainder: dividend = quotient * divisor + remainder,
/// where 0 <= remainder < divisor. Uses bit-by-bit long division.
pub fn divmod(dividend: &[u64; 4], divisor: &[u64; 4]) -> ([u64; 4], [u64; 4]) {
    // Bit length of the dividend (0 when dividend == 0).
    let bit_len = dividend
        .iter()
        .rposition(|&limb| limb != 0)
        .map_or(0, |i| i * 64 + 64 - dividend[i].leading_zeros() as usize);
    if bit_len == 0 {
        return ([0u64; 4], [0u64; 4]);
    }

    let mut quotient = [0u64; 4];
    let mut remainder = [0u64; 4];

    // Classic restoring long division, MSB first.
    for bit in (0..bit_len).rev() {
        let overflow = shift_left_one(&mut remainder);
        debug_assert_eq!(overflow, 0, "remainder overflow during shift");

        // Bring down the next dividend bit.
        let (word, offset) = (bit / 64, bit % 64);
        remainder[0] |= (dividend[word] >> offset) & 1;

        // Subtract the divisor whenever it fits, recording a quotient bit.
        if cmp_4limb(&remainder, divisor) != std::cmp::Ordering::Less {
            sub_4limb_inplace(&mut remainder, divisor);
            quotient[word] |= 1u64 << offset;
        }
    }

    (quotient, remainder)
}

/// Subtract a small u64 value from a 4-limb number. Assumes a >= small.
pub fn sub_u64(a: &[u64; 4], small: u64) -> [u64; 4] {
    let mut out = *a;
    let (low, first_borrow) = out[0].overflowing_sub(small);
    out[0] = low;
    // Ripple the borrow through the higher limbs until it is absorbed.
    let mut borrow = first_borrow;
    let mut idx = 1;
    while borrow && idx < 4 {
        let (diff, next_borrow) = out[idx].overflowing_sub(1);
        out[idx] = diff;
        borrow = next_borrow;
        idx += 1;
    }
    out
}

/// Add two 4-limb (256-bit) numbers, returning a 5-limb result with carry.
pub fn add_4limb(a: &[u64; 4], b: &[u64; 4]) -> [u64; 5] {
    let mut out = [0u64; 5];
    let mut carry = 0u64;
    for i in 0..4 {
        // Widen to u128 so limb sum plus carry cannot overflow.
        let wide = a[i] as u128 + b[i] as u128 + carry as u128;
        out[i] = wide as u64;
        carry = (wide >> 64) as u64;
    }
    out[4] = carry;
    out
}

/// Add two 4-limb numbers in-place: a += b. Returns the carry-out.
pub fn add_4limb_inplace(a: &mut [u64; 4], b: &[u64; 4]) -> u64 {
    let mut carry = 0u64;
    for i in 0..4 {
        let wide = a[i] as u128 + b[i] as u128 + carry as u128;
        a[i] = wide as u64;
        carry = (wide >> 64) as u64;
    }
    carry
}

/// Subtract b from a in-place, returning true if a >= b (no underflow).
/// If a < b, the result is a += 2^256 - b (wrapping subtraction) and returns
/// false.
pub fn sub_4limb_checked(a: &mut [u64; 4], b: &[u64; 4]) -> bool {
    let mut borrow = 0u64;
    for i in 0..4 {
        let (d1, u1) = a[i].overflowing_sub(b[i]);
        let (d2, u2) = d1.overflowing_sub(borrow);
        a[i] = d2;
        borrow = u64::from(u1) + u64::from(u2);
    }
    borrow == 0
}

/// Returns true if val == 0.
pub fn is_zero(val: &[u64; 4]) -> bool {
    val.iter().all(|&limb| limb == 0)
}

/// Compute the bit mask for a limb of the given width.
pub fn limb_mask(limb_bits: u32) -> u128 {
    match limb_bits {
        // A 128-bit (or wider) limb keeps every bit.
        128.. => u128::MAX,
        bits => (1u128 << bits) - 1,
    }
}

/// Right-shift a 4-limb (256-bit) value by `bits` positions.
pub fn shr_256(val: &[u64; 4], bits: u32) -> [u64; 4] {
    if bits >= 256 {
        return [0; 4];
    }
    let words = (bits / 64) as usize;
    let rem = bits % 64;
    let mut out = [0u64; 4];
    for i in 0..(4 - words) {
        out[i] = val[i + words] >> rem;
        // Pull in the bits spilling down from the next-higher source limb.
        if rem > 0 && i + words + 1 < 4 {
            out[i] |= val[i + words + 1] << (64 - rem);
        }
    }
    out
}

// Decompose a 256-bit value into `num_limbs` limbs of `limb_bits` width.
// Returns u128 limb values (each < 2^limb_bits).
+pub fn decompose_to_u128_limbs(val: &[u64; 4], num_limbs: usize, limb_bits: u32) -> Vec { + let mask = limb_mask(limb_bits); + let mut limbs = Vec::with_capacity(num_limbs); + let mut remaining = *val; + for _ in 0..num_limbs { + let lo = remaining[0] as u128 | ((remaining[1] as u128) << 64); + limbs.push(lo & mask); + remaining = shr_256(&remaining, limb_bits); + } + limbs +} + +/// Convert u128 limbs to i128 limbs (for carry computation linear terms). +pub fn to_i128_limbs(limbs: &[u128]) -> Vec { + limbs.iter().map(|&v| v as i128).collect() +} + +/// Convert a `[u64; 8]` wide value to a `BigInt`. +fn wide_to_bigint(v: &[u64; 8]) -> BigInt { + let mut bytes = [0u8; 64]; + for (i, &limb) in v.iter().enumerate() { + bytes[i * 8..(i + 1) * 8].copy_from_slice(&limb.to_le_bytes()); + } + BigInt::from_bytes_le(Sign::Plus, &bytes) +} + +/// Convert a `[u64; 4]` to a `BigInt`. +fn u256_to_bigint(v: &[u64; 4]) -> BigInt { + let mut bytes = [0u8; 32]; + for (i, &limb) in v.iter().enumerate() { + bytes[i * 8..(i + 1) * 8].copy_from_slice(&limb.to_le_bytes()); + } + BigInt::from_bytes_le(Sign::Plus, &bytes) +} + +/// Convert a non-negative `BigInt` to `u128`. Panics if negative or too large. +fn bigint_to_u128(v: &BigInt) -> u128 { + assert!(v.sign() != Sign::Minus, "bigint_to_u128: negative value"); + let (_, bytes) = v.to_bytes_le(); + assert!(bytes.len() <= 16, "bigint_to_u128: value exceeds 128 bits"); + let mut buf = [0u8; 16]; + buf[..bytes.len()].copy_from_slice(&bytes); + u128::from_le_bytes(buf) +} + +/// Compute signed quotient q such that: +/// Σ lhs_products\[i\] * coeff_i + Σ lhs_linear\[j\] * coeff_j +/// - Σ rhs_products\[i\] * coeff_i - Σ rhs_linear\[j\] * coeff_j ≡ 0 (mod p) +/// +/// Returns (|q| limbs, is_negative) where q = (LHS - RHS) / p. 
+pub fn signed_quotient_wide( + lhs_products: &[(&[u64; 4], &[u64; 4], u64)], + rhs_products: &[(&[u64; 4], &[u64; 4], u64)], + lhs_linear: &[(&[u64; 4], u64)], + rhs_linear: &[(&[u64; 4], u64)], + p: &[u64; 4], + n: usize, + w: u32, +) -> (Vec, bool) { + fn accumulate_wide_products(terms: &[(&[u64; 4], &[u64; 4], u64)]) -> BigInt { + let mut acc = BigInt::from(0); + for &(a, b, coeff) in terms { + let prod = widening_mul(a, b); + acc += wide_to_bigint(&prod) * BigInt::from(coeff); + } + acc + } + + fn accumulate_wide_linear(terms: &[(&[u64; 4], u64)]) -> BigInt { + let mut acc = BigInt::from(0); + for &(val, coeff) in terms { + acc += u256_to_bigint(val) * BigInt::from(coeff); + } + acc + } + + let lhs = accumulate_wide_products(lhs_products) + accumulate_wide_linear(lhs_linear); + let rhs = accumulate_wide_products(rhs_products) + accumulate_wide_linear(rhs_linear); + + let diff = lhs - rhs; + let p_big = u256_to_bigint(p); + + let q_big = &diff / &p_big; + let rem = &diff - &q_big * &p_big; + assert_eq!( + rem, + BigInt::from(0), + "signed_quotient_wide: non-zero remainder" + ); + + let is_neg = q_big.sign() == Sign::Minus; + let q_abs_big = if is_neg { -&q_big } else { q_big }; + + // Decompose directly from BigInt into u128 limbs at `w` bits each, + // since the quotient may exceed 256 bits. + let limb_mask = (BigInt::from(1u64) << w) - 1; + let mut limbs = Vec::with_capacity(n); + let mut remaining = q_abs_big; + for _ in 0..n { + let limb_val = &remaining & &limb_mask; + limbs.push(bigint_to_u128(&limb_val)); + remaining >>= w; + } + assert_eq!( + remaining, + BigInt::from(0), + "quotient doesn't fit in {n} limbs at {w} bits" + ); + + (limbs, is_neg) +} + +/// Reconstruct a 256-bit value from u128 limb values packed at `limb_bits` +/// boundaries. 
+pub fn reconstruct_from_u128_limbs(limb_values: &[u128], limb_bits: u32) -> [u64; 4] { + let mut val = [0u64; 4]; + let mut bit_offset = 0u32; + for &limb_u128 in limb_values.iter() { + let word_start = (bit_offset / 64) as usize; + let bit_within = bit_offset % 64; + if word_start < 4 { + val[word_start] |= (limb_u128 as u64) << bit_within; + if word_start + 1 < 4 { + val[word_start + 1] |= (limb_u128 >> (64 - bit_within)) as u64; + } + if word_start + 2 < 4 && bit_within > 0 { + let upper = limb_u128 >> (128 - bit_within); + if upper > 0 { + val[word_start + 2] |= upper as u64; + } + } + } + bit_offset += limb_bits; + } + val +} + +/// Compute schoolbook carries for a*b = p*q + r verification in base +/// 2^limb_bits. Returns unsigned-offset carries ready to be written as +/// witnesses. +pub fn compute_mul_mod_carries( + a_limbs: &[u128], + b_limbs: &[u128], + p_limbs: &[u128], + q_limbs: &[u128], + r_limbs: &[u128], + limb_bits: u32, +) -> Vec { + let n = a_limbs.len(); + let w = limb_bits; + let num_carries = 2 * n - 2; + let carry_offset = BigInt::from(1u64) << (w + ceil_log2(n as u64) + 1); + let mut carries = Vec::with_capacity(num_carries); + let mut carry = BigInt::from(0); + + for k in 0..(2 * n - 1) { + let mut col_value = BigInt::from(0); + + // a*b products + for i in 0..n { + let j = k as isize - i as isize; + if j >= 0 && (j as usize) < n { + col_value += BigInt::from(a_limbs[i]) * BigInt::from(b_limbs[j as usize]); + } + } + + // Subtract p*q + r + for i in 0..n { + let j = k as isize - i as isize; + if j >= 0 && (j as usize) < n { + col_value -= BigInt::from(p_limbs[i]) * BigInt::from(q_limbs[j as usize]); + } + } + if k < n { + col_value -= BigInt::from(r_limbs[k]); + } + + col_value += &carry; + + if k < 2 * n - 2 { + let mask = (BigInt::from(1u64) << w) - 1; + debug_assert_eq!( + &col_value & &mask, + BigInt::from(0), + "non-zero remainder at column {k}" + ); + carry = &col_value >> w; + let stored = &carry + &carry_offset; + 
carries.push(bigint_to_u128(&stored)); + } + } + + carries +} + +/// Compute the number of bits needed for the half-GCD sub-scalars. +/// Returns `ceil(order_bits / 2)` where `order_bits` is the bit length of `n`. +pub fn half_gcd_bits(n: &[u64; 4]) -> u32 { + let mut order_bits = 0u32; + for i in (0..4).rev() { + if n[i] != 0 { + order_bits = (i as u32) * 64 + (64 - n[i].leading_zeros()); + break; + } + } + (order_bits + 1) / 2 +} + +/// Build the threshold value `2^half_bits` as a `[u64; 4]`. +fn build_threshold(half_bits: u32) -> [u64; 4] { + assert!(half_bits <= 255, "half_bits must be <= 255"); + let mut threshold = [0u64; 4]; + let word = (half_bits / 64) as usize; + let bit = half_bits % 64; + threshold[word] = 1u64 << bit; + threshold +} + +/// Half-GCD scalar decomposition for FakeGLV. +/// +/// Given scalar `s` and curve order `n`, finds `(|s1|, |s2|, neg1, neg2)` such +/// that: `(-1)^neg1 * |s1| + (-1)^neg2 * |s2| * s ≡ 0 (mod n)` +/// +/// Uses the extended GCD on `(n, s)`, stopping when the remainder drops below +/// `2^half_bits` where `half_bits = ceil(order_bits / 2)`. +/// Returns `(val1, val2, neg1, neg2)` where both fit in `half_bits` bits. 
+pub fn half_gcd(s: &[u64; 4], n: &[u64; 4]) -> ([u64; 4], [u64; 4], bool, bool) { + // Extended GCD on (n, s): + // We track: r_{i} = r_{i-2} - q_i * r_{i-1} + // t_{i} = t_{i-2} - q_i * t_{i-1} + // Starting: r_0 = n, r_1 = s, t_0 = 0, t_1 = 1 + // + // We want: t_i * s ≡ r_i (mod n) [up to sign] + // More precisely: t_i * s ≡ (-1)^{i+1} * r_i (mod n) + // + // The relation we verify is: sign_r * |r_i| + sign_t * |t_i| * s ≡ 0 (mod n) + + // Threshold: 2^half_bits where half_bits = ceil(order_bits / 2) + let half_bits = half_gcd_bits(n); + let threshold = build_threshold(half_bits); + + // r_prev = n, r_curr = s + let mut r_prev = *n; + let mut r_curr = *s; + + // t_prev = 0, t_curr = 1 + let mut t_prev = [0u64; 4]; + let mut t_curr = [1u64, 0, 0, 0]; + + // Track sign of t: t_curr_neg=false (t_1=1, positive) + let mut t_curr_neg = false; + + loop { + // Check if r_curr < threshold + if cmp_4limb(&r_curr, &threshold) == std::cmp::Ordering::Less { + break; + } + + if is_zero(&r_curr) { + break; + } + + // q = r_prev / r_curr, new_r = r_prev % r_curr + let (q, new_r) = divmod(&r_prev, &r_curr); + + // new_t = t_prev + q * t_curr (in terms of absolute values and signs) + // Since the GCD recurrence is: t_{i} = t_{i-2} - q_i * t_{i-1} + // In terms of absolute values with sign tracking: + // If t_prev and q*t_curr have the same sign → subtract magnitudes + // If they have different signs → add magnitudes + // new_t = |t_prev| +/- q * |t_curr|, with sign flips each + // iteration. + // + // The standard extended GCD recurrence gives: + // t_i = t_{i-2} - q_i * t_{i-1} + // We track magnitudes and sign bits separately. + + // Compute q * t_curr + let qt = mul_mod_no_reduce(&q, &t_curr); + + // new_t magnitude and sign: + // In the standard recurrence: new_t_val = t_prev_val - q * t_curr_val + // where t_prev_val = (-1)^t_prev_neg * |t_prev|, etc. + // + // But it's simpler to just track: alternating signs. + // In the half-GCD: t values alternate in sign. 
So: + // new_t = t_prev + q * t_curr (absolute addition since signs alternate) + let mut new_t = qt; + add_4limb_inplace(&mut new_t, &t_prev); + let new_t_neg = !t_curr_neg; + + r_prev = r_curr; + r_curr = new_r; + t_prev = t_curr; + t_curr = new_t; + t_curr_neg = new_t_neg; + } + + // At this point: r_curr < 2^half_bits and t_curr < ~2^half_bits (half-GCD + // property). + // + // From the extended GCD identity: t_i * s ≡ r_i (mod n) + // Rearranging: -r_i + t_i * s ≡ 0 (mod n) + // + // The circuit checks: (-1)^neg1 * |r_i| + (-1)^neg2 * |t_i| * s ≡ 0 (mod n) + // Since r_i is always non-negative, neg1 must always be true (negate r_i). + // neg2 must match the actual sign of t_i so that (-1)^neg2 * |t_i| = t_i. + + let val1 = r_curr; // |s1| = |r_i| + let val2 = t_curr; // |s2| = |t_i| + + let neg1 = true; // always negate r_i: -r_i + t_i * s ≡ 0 (mod n) + let neg2 = t_curr_neg; + + (val1, val2, neg1, neg2) +} + +/// Multiply two 4-limb values without modular reduction. +/// Returns the lower 4 limbs (ignoring overflow beyond 256 bits). +/// Used internally by half_gcd for q * t_curr where the result is known to fit. +fn mul_mod_no_reduce(a: &[u64; 4], b: &[u64; 4]) -> [u64; 4] { + let wide = widening_mul(a, b); + debug_assert!( + wide[4..].iter().all(|&x| x == 0), + "mul_mod_no_reduce overflow: upper limbs are non-zero" + ); + [wide[0], wide[1], wide[2], wide[3]] +} + +// --------------------------------------------------------------------------- +// Conversion helpers +// --------------------------------------------------------------------------- + +/// Convert a `[u64; 4]` bigint to a `FieldElement`. +pub fn bigint_to_fe(val: &[u64; 4]) -> FieldElement { + FieldElement::from_bigint(ark_ff::BigInt(*val)).unwrap() +} + +/// Read a `FieldElement` witness as a `[u64; 4]` bigint. +pub fn fe_to_bigint(fe: FieldElement) -> [u64; 4] { + fe.into_bigint().0 +} + +/// Reconstruct a 256-bit scalar from two 128-bit halves: `scalar = lo + hi * +/// 2^128`. 
+pub fn reconstruct_from_halves(lo: &[u64; 4], hi: &[u64; 4]) -> [u64; 4] { + [lo[0], lo[1], hi[0], hi[1]] +} + +/// EC point doubling with lambda exposed: returns (lambda, x3, y3). +/// +/// Used by the `EcDoubleHint` prover which needs lambda as a witness. +pub fn ec_point_double_with_lambda( + px: &[u64; 4], + py: &[u64; 4], + a: &[u64; 4], + p: &[u64; 4], +) -> ([u64; 4], [u64; 4], [u64; 4]) { + let x_sq = mul_mod(px, px, p); + let two_x_sq = mod_add(&x_sq, &x_sq, p); + let three_x_sq = mod_add(&two_x_sq, &x_sq, p); + let numerator = mod_add(&three_x_sq, a, p); + let two_y = mod_add(py, py, p); + let denom_inv = mod_inverse(&two_y, p); + let lambda = mul_mod(&numerator, &denom_inv, p); + + let lambda_sq = mul_mod(&lambda, &lambda, p); + let two_x = mod_add(px, px, p); + let x3 = mod_sub(&lambda_sq, &two_x, p); + + let x_minus_x3 = mod_sub(px, &x3, p); + let lambda_dx = mul_mod(&lambda, &x_minus_x3, p); + let y3 = mod_sub(&lambda_dx, py, p); + + (lambda, x3, y3) +} + +/// EC point doubling in affine coordinates on y^2 = x^3 + ax + b. +/// Returns (x3, y3) = 2*(px, py). +pub fn ec_point_double( + px: &[u64; 4], + py: &[u64; 4], + a: &[u64; 4], + p: &[u64; 4], +) -> ([u64; 4], [u64; 4]) { + let (_, x3, y3) = ec_point_double_with_lambda(px, py, a, p); + (x3, y3) +} + +/// EC point addition with lambda exposed: returns (lambda, x3, y3). +/// +/// Used by the `EcAddHint` prover which needs lambda as a witness. 
+pub fn ec_point_add_with_lambda( + p1x: &[u64; 4], + p1y: &[u64; 4], + p2x: &[u64; 4], + p2y: &[u64; 4], + p: &[u64; 4], +) -> ([u64; 4], [u64; 4], [u64; 4]) { + let numerator = mod_sub(p2y, p1y, p); + let denominator = mod_sub(p2x, p1x, p); + let denom_inv = mod_inverse(&denominator, p); + let lambda = mul_mod(&numerator, &denom_inv, p); + + let lambda_sq = mul_mod(&lambda, &lambda, p); + let x1_plus_x2 = mod_add(p1x, p2x, p); + let x3 = mod_sub(&lambda_sq, &x1_plus_x2, p); + + let x1_minus_x3 = mod_sub(p1x, &x3, p); + let lambda_dx = mul_mod(&lambda, &x1_minus_x3, p); + let y3 = mod_sub(&lambda_dx, p1y, p); + + (lambda, x3, y3) +} + +/// EC point addition in affine coordinates on y^2 = x^3 + ax + b. +/// Returns (x3, y3) = (p1x, p1y) + (p2x, p2y). Requires p1x != p2x. +pub fn ec_point_add( + p1x: &[u64; 4], + p1y: &[u64; 4], + p2x: &[u64; 4], + p2y: &[u64; 4], + p: &[u64; 4], +) -> ([u64; 4], [u64; 4]) { + let (_, x3, y3) = ec_point_add_with_lambda(p1x, p1y, p2x, p2y, p); + (x3, y3) +} + +/// EC scalar multiplication via double-and-add: returns \[scalar\]*P. +/// +/// # Panics +/// Panics if `scalar` is zero (the point at infinity is not representable in +/// affine coordinates). 
+pub fn ec_scalar_mul( + px: &[u64; 4], + py: &[u64; 4], + scalar: &[u64; 4], + a: &[u64; 4], + p: &[u64; 4], +) -> ([u64; 4], [u64; 4]) { + // Find highest set bit in scalar + let mut highest_bit = 0; + for i in (0..4).rev() { + if scalar[i] != 0 { + highest_bit = i * 64 + (64 - scalar[i].leading_zeros() as usize); + break; + } + } + if highest_bit == 0 { + // scalar == 0 → point at infinity (not representable in affine) + panic!("ec_scalar_mul: scalar is zero"); + } + + // Start from the MSB-1 and double-and-add + let mut rx = *px; + let mut ry = *py; + + for bit_pos in (0..highest_bit - 1).rev() { + // Double + let (dx, dy) = ec_point_double(&rx, &ry, a, p); + rx = dx; + ry = dy; + + // Add if bit is set + let limb_idx = bit_pos / 64; + let bit_idx = bit_pos % 64; + if (scalar[limb_idx] >> bit_idx) & 1 == 1 { + let (ax, ay) = ec_point_add(&rx, &ry, px, py, p); + rx = ax; + ry = ay; + } + } + + (rx, ry) +} + +/// Integer division of a 512-bit dividend by a 256-bit divisor. +/// Returns (quotient, remainder) where both fit in 256 bits. +/// Panics if the quotient would exceed 256 bits. +pub fn divmod_wide(dividend: &[u64; 8], divisor: &[u64; 4]) -> ([u64; 4], [u64; 4]) { + let mut highest_bit = 0; + for i in (0..8).rev() { + if dividend[i] != 0 { + highest_bit = i * 64 + (64 - dividend[i].leading_zeros() as usize); + break; + } + } + if highest_bit == 0 { + return ([0u64; 4], [0u64; 4]); + } + + let mut quotient = [0u64; 4]; + let mut remainder = [0u64; 4]; + + for bit_pos in (0..highest_bit).rev() { + let shift_carry = shift_left_one(&mut remainder); + + let limb_idx = bit_pos / 64; + let bit_idx = bit_pos % 64; + remainder[0] |= (dividend[limb_idx] >> bit_idx) & 1; + + // If shift_carry is set, the effective remainder is 2^256 + remainder, + // which is always > any 256-bit divisor, so we must subtract. 
+ if shift_carry != 0 || cmp_4limb(&remainder, divisor) != std::cmp::Ordering::Less { + // Subtract divisor with inline borrow tracking (handles the case + // where remainder < divisor but shift_carry provides the extra bit). + let mut borrow = 0u64; + for i in 0..4 { + let (d1, b1) = remainder[i].overflowing_sub(divisor[i]); + let (d2, b2) = d1.overflowing_sub(borrow); + remainder[i] = d2; + borrow = (b1 as u64) + (b2 as u64); + } + // When shift_carry was set, the borrow absorbs it (they cancel out). + debug_assert_eq!( + borrow, shift_carry, + "unexpected borrow in divmod_wide at bit_pos {}", + bit_pos + ); + + assert!(bit_pos < 256, "quotient exceeds 256 bits"); + quotient[bit_pos / 64] |= 1u64 << (bit_pos % 64); + } + } + + (quotient, remainder) +} + +/// Compute unsigned-offset carries for a general merged column equation. +/// +/// Each `product_set` entry is (a_limbs, b_limbs, coefficient): +/// LHS_terms = Σ coeff * Σ_{i+j=k} a\[i\]*b\[j\] +/// +/// Each `linear_set` entry is (limb_values, coefficient) for non-product terms: +/// LHS_terms += Σ coeff * val\[k\] (for k < val.len()) +/// +/// The equation verified is: +/// LHS + Σ p\[i\]*q_neg\[j\] = RHS + Σ p\[i\]*q_pos\[j\] + carry_chain +/// +/// `q_pos_limbs` and `q_neg_limbs` are both non-negative; at most one is +/// non-zero. 
+pub fn compute_ec_verification_carries( + product_sets: &[(&[u128], &[u128], i64)], + linear_terms: &[(Vec, i64)], // (limb_values extended to 2N-1, coefficient) + p_limbs: &[u128], + q_pos_limbs: &[u128], + q_neg_limbs: &[u128], + n: usize, + limb_bits: u32, + max_coeff_sum: u64, +) -> Vec { + let w = limb_bits; + let num_columns = 2 * n - 1; + let num_carries = num_columns - 1; + + let extra_bits = ceil_log2(max_coeff_sum * n as u64) + 1; + let carry_offset_bits = w + extra_bits; + let carry_offset = BigInt::from(1u64) << carry_offset_bits; + + let mut carries = Vec::with_capacity(num_carries); + let mut carry = BigInt::from(0); + + for k in 0..num_columns { + let mut col_value = BigInt::from(0); + + // Product terms + for &(a, b, coeff) in product_sets { + for i in 0..n { + let j = k as isize - i as isize; + if j >= 0 && (j as usize) < n { + col_value += + BigInt::from(coeff) * BigInt::from(a[i]) * BigInt::from(b[j as usize]); + } + } + } + + // Linear terms + for (vals, coeff) in linear_terms { + if k < vals.len() { + col_value += BigInt::from(*coeff) * BigInt::from(vals[k]); + } + } + + // p*q_neg on positive side, p*q_pos on negative side + for i in 0..n { + let j = k as isize - i as isize; + if j >= 0 && (j as usize) < n { + col_value += BigInt::from(p_limbs[i]) * BigInt::from(q_neg_limbs[j as usize]); + col_value -= BigInt::from(p_limbs[i]) * BigInt::from(q_pos_limbs[j as usize]); + } + } + + col_value += &carry; + + if k < num_carries { + let mask = (BigInt::from(1u64) << w) - 1; + debug_assert_eq!( + &col_value & &mask, + BigInt::from(0), + "non-zero remainder at column {k}: col_value={col_value}" + ); + carry = &col_value >> w; + let stored = &carry + &carry_offset; + carries.push(bigint_to_u128(&stored)); + } else { + debug_assert_eq!( + col_value, + BigInt::from(0), + "non-zero final column value: {col_value}" + ); + } + } + + carries +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_widening_mul_small() { + // 3 * 7 = 21 + let a = 
[3, 0, 0, 0]; + let b = [7, 0, 0, 0]; + let result = widening_mul(&a, &b); + assert_eq!(result[0], 21); + assert_eq!(result[1..], [0; 7]); + } + + #[test] + fn test_widening_mul_overflow() { + // u64::MAX * u64::MAX = (2^64-1)^2 = 2^128 - 2^65 + 1 + let a = [u64::MAX, 0, 0, 0]; + let b = [u64::MAX, 0, 0, 0]; + let result = widening_mul(&a, &b); + // (2^64-1)^2 = 0xFFFFFFFFFFFFFFFE_0000000000000001 + assert_eq!(result[0], 1); + assert_eq!(result[1], u64::MAX - 1); + assert_eq!(result[2..], [0; 6]); + } + + #[test] + fn test_reduce_wide_no_reduction() { + use provekit_common::u256_arith::reduce_wide; + // 5 mod 7 = 5 + let wide = [5, 0, 0, 0, 0, 0, 0, 0]; + let modulus = [7, 0, 0, 0]; + assert_eq!(reduce_wide(&wide, &modulus), [5, 0, 0, 0]); + } + + #[test] + fn test_reduce_wide_basic() { + use provekit_common::u256_arith::reduce_wide; + // 10 mod 7 = 3 + let wide = [10, 0, 0, 0, 0, 0, 0, 0]; + let modulus = [7, 0, 0, 0]; + assert_eq!(reduce_wide(&wide, &modulus), [3, 0, 0, 0]); + } + + #[test] + fn test_mul_mod_small() { + // (5 * 3) mod 7 = 15 mod 7 = 1 + let a = [5, 0, 0, 0]; + let b = [3, 0, 0, 0]; + let m = [7, 0, 0, 0]; + assert_eq!(mul_mod(&a, &b, &m), [1, 0, 0, 0]); + } + + #[test] + fn test_mod_pow_small() { + // 3^4 mod 7 = 81 mod 7 = 4 + let base = [3, 0, 0, 0]; + let exp = [4, 0, 0, 0]; + let m = [7, 0, 0, 0]; + assert_eq!(mod_pow(&base, &exp, &m), [4, 0, 0, 0]); + } + + #[test] + fn test_fermat_inverse_small() { + // Inverse of 3 mod 7: 3^{7-2} = 3^5 mod 7 = 243 mod 7 = 5 + // Check: 3 * 5 = 15 = 2*7 + 1 ≡ 1 (mod 7) ✓ + let a = [3, 0, 0, 0]; + let m = [7, 0, 0, 0]; + let exp = sub_u64(&m, 2); // m - 2 = 5 + let inv = mod_pow(&a, &exp, &m); + assert_eq!(inv, [5, 0, 0, 0]); + // Verify: a * inv mod m = 1 + assert_eq!(mul_mod(&a, &inv, &m), [1, 0, 0, 0]); + } + + #[test] + fn test_fermat_inverse_prime_23() { + // Inverse of 5 mod 23: 5^{21} mod 23 + // 5^{-1} mod 23 = 14 (because 5*14 = 70 = 3*23 + 1) + let a = [5, 0, 0, 0]; + let m = [23, 0, 0, 0]; + let 
exp = sub_u64(&m, 2); + let inv = mod_pow(&a, &exp, &m); + assert_eq!(inv, [14, 0, 0, 0]); + assert_eq!(mul_mod(&a, &inv, &m), [1, 0, 0, 0]); + } + + #[test] + fn test_sub_u64_basic() { + assert_eq!(sub_u64(&[10, 0, 0, 0], 3), [7, 0, 0, 0]); + } + + #[test] + fn test_sub_u64_borrow() { + // [0, 1, 0, 0] = 2^64; subtract 1 → [u64::MAX, 0, 0, 0] + assert_eq!(sub_u64(&[0, 1, 0, 0], 1), [u64::MAX, 0, 0, 0]); + } + + #[test] + fn test_fermat_inverse_large_prime() { + // Use a 128-bit prime: p = 2^127 - 1 = 170141183460469231731687303715884105727 + // In limbs: [u64::MAX, 2^63 - 1, 0, 0] + let p = [u64::MAX, (1u64 << 63) - 1, 0, 0]; + + // a = 42 + let a = [42, 0, 0, 0]; + let exp = sub_u64(&p, 2); + let inv = mod_pow(&a, &exp, &p); + + // Verify: a * inv mod p = 1 + assert_eq!(mul_mod(&a, &inv, &p), [1, 0, 0, 0]); + } + + #[test] + fn test_cmp_wide_narrow() { + let wide = [5, 0, 0, 0, 0, 0, 0, 0]; + let narrow = [5, 0, 0, 0]; + assert_eq!(cmp_wide_narrow(&wide, &narrow), std::cmp::Ordering::Equal); + + let wide_greater = [0, 0, 0, 0, 1, 0, 0, 0]; + assert_eq!( + cmp_wide_narrow(&wide_greater, &narrow), + std::cmp::Ordering::Greater + ); + } + + #[test] + fn test_mod_pow_zero_exp() { + // a^0 mod m = 1 + let base = [42, 0, 0, 0]; + let exp = [0, 0, 0, 0]; + let m = [7, 0, 0, 0]; + assert_eq!(mod_pow(&base, &exp, &m), [1, 0, 0, 0]); + } + + #[test] + fn test_mod_pow_one_exp() { + // a^1 mod m = a mod m + let base = [10, 0, 0, 0]; + let exp = [1, 0, 0, 0]; + let m = [7, 0, 0, 0]; + assert_eq!(mod_pow(&base, &exp, &m), [3, 0, 0, 0]); + } + + #[test] + fn test_divmod_exact() { + // 21 / 7 = 3 remainder 0 + let (q, r) = divmod(&[21, 0, 0, 0], &[7, 0, 0, 0]); + assert_eq!(q, [3, 0, 0, 0]); + assert_eq!(r, [0, 0, 0, 0]); + } + + #[test] + fn test_divmod_with_remainder() { + // 17 / 7 = 2 remainder 3 + let (q, r) = divmod(&[17, 0, 0, 0], &[7, 0, 0, 0]); + assert_eq!(q, [2, 0, 0, 0]); + assert_eq!(r, [3, 0, 0, 0]); + } + + #[test] + fn test_divmod_smaller_dividend() { + // 5 / 7 
= 0 remainder 5 + let (q, r) = divmod(&[5, 0, 0, 0], &[7, 0, 0, 0]); + assert_eq!(q, [0, 0, 0, 0]); + assert_eq!(r, [5, 0, 0, 0]); + } + + #[test] + fn test_divmod_zero_dividend() { + let (q, r) = divmod(&[0, 0, 0, 0], &[7, 0, 0, 0]); + assert_eq!(q, [0, 0, 0, 0]); + assert_eq!(r, [0, 0, 0, 0]); + } + + #[test] + fn test_divmod_large() { + // 2^64 / 3 = 6148914691236517205 remainder 1 + // 2^64 in limbs: [0, 1, 0, 0] + let (q, r) = divmod(&[0, 1, 0, 0], &[3, 0, 0, 0]); + assert_eq!(q, [6148914691236517205, 0, 0, 0]); + assert_eq!(r, [1, 0, 0, 0]); + // Verify: q * 3 + 1 = 2^64 + assert_eq!(6148914691236517205u64.wrapping_mul(3).wrapping_add(1), 0u64); + // wraps to 0 in u64 = 2^64 + } + + #[test] + fn test_divmod_consistency() { + // Verify dividend = quotient * divisor + remainder for various inputs + let cases: Vec<([u64; 4], [u64; 4])> = vec![ + ([100, 0, 0, 0], [7, 0, 0, 0]), + ([u64::MAX, 0, 0, 0], [1000, 0, 0, 0]), + ([0, 1, 0, 0], [u64::MAX, 0, 0, 0]), // 2^64 / (2^64 - 1) + ]; + for (dividend, divisor) in cases { + let (q, r) = divmod(÷nd, &divisor); + // Verify: q * divisor + r = dividend + let product = widening_mul(&q, &divisor); + // Add remainder to product + let mut sum = product; + let mut carry = 0u128; + for i in 0..4 { + let s = (sum[i] as u128) + (r[i] as u128) + carry; + sum[i] = s as u64; + carry = s >> 64; + } + for i in 4..8 { + let s = (sum[i] as u128) + carry; + sum[i] = s as u64; + carry = s >> 64; + } + // sum should equal dividend (zero-extended to 8 limbs) + let mut expected = [0u64; 8]; + expected[..4].copy_from_slice(÷nd); + assert_eq!(sum, expected, "dividend={dividend:?} divisor={divisor:?}"); + } + } + + #[test] + fn test_divmod_wide_small() { + // 21 / 7 = 3 remainder 0 (512-bit dividend) + let dividend = [21, 0, 0, 0, 0, 0, 0, 0]; + let divisor = [7, 0, 0, 0]; + let (q, r) = divmod_wide(÷nd, &divisor); + assert_eq!(q, [3, 0, 0, 0]); + assert_eq!(r, [0, 0, 0, 0]); + } + + #[test] + fn test_divmod_wide_large() { + // Compute a * b 
where a, b are 256-bit, then divide by a + // Should give quotient = b, remainder = 0 + let a = [0xffffffffffffffff, 0xffffffff, 0x0, 0xffffffff00000001]; // secp256r1 p + let b = [42, 0, 0, 0]; + let product = widening_mul(&a, &b); + let (q, r) = divmod_wide(&product, &a); + assert_eq!(q, b); + assert_eq!(r, [0, 0, 0, 0]); + } + + #[test] + fn test_divmod_wide_with_remainder() { + // (a * b + 5) / a = b remainder 5 + let a = [0xffffffffffffffff, 0xffffffff, 0x0, 0xffffffff00000001]; + let b = [100, 0, 0, 0]; + let mut product = widening_mul(&a, &b); + // Add 5 + let (sum, overflow) = product[0].overflowing_add(5); + product[0] = sum; + if overflow { + for i in 1..8 { + let (s, o) = product[i].overflowing_add(1); + product[i] = s; + if !o { + break; + } + } + } + let (q, r) = divmod_wide(&product, &a); + assert_eq!(q, b); + assert_eq!(r, [5, 0, 0, 0]); + } + + #[test] + fn test_divmod_wide_consistency() { + // Verify: q * divisor + r = dividend + let a = [ + 0x123456789abcdef0, + 0xfedcba9876543210, + 0x1111111111111111, + 0x2222222222222222, + ]; + let b = [0xaabbccdd, 0x11223344, 0x55667788, 0x99001122]; + let product = widening_mul(&a, &b); + let divisor = [0xffffffffffffffff, 0xffffffff, 0x0, 0xffffffff00000001]; + let (q, r) = divmod_wide(&product, &divisor); + + // Verify: q * divisor + r = product + let qd = widening_mul(&q, &divisor); + let mut sum = qd; + let mut carry = 0u128; + for i in 0..4 { + let s = (sum[i] as u128) + (r[i] as u128) + carry; + sum[i] = s as u64; + carry = s >> 64; + } + for i in 4..8 { + let s = (sum[i] as u128) + carry; + sum[i] = s as u64; + carry = s >> 64; + } + assert_eq!(sum, product); + } + + #[test] + fn test_half_gcd_small() { + // s = 42, n = 101 + let s = [42, 0, 0, 0]; + let n = [101, 0, 0, 0]; + let (val1, val2, neg1, neg2) = half_gcd(&s, &n); + + // Verify: (-1)^neg1 * val1 + (-1)^neg2 * val2 * s ≡ 0 (mod n) + let sign1: i128 = if neg1 { -1 } else { 1 }; + let sign2: i128 = if neg2 { -1 } else { 1 }; + let v1 = val1[0] 
as i128; + let v2 = val2[0] as i128; + let s_val = s[0] as i128; + let n_val = n[0] as i128; + let lhs = ((sign1 * v1 + sign2 * v2 * s_val) % n_val + n_val) % n_val; + assert_eq!(lhs, 0, "half_gcd relation failed for small values"); + } + + #[test] + fn test_half_gcd_grumpkin_order() { + // Grumpkin curve order (BN254 base field order) + let n = [ + 0x3c208c16d87cfd47_u64, + 0x97816a916871ca8d_u64, + 0xb85045b68181585d_u64, + 0x30644e72e131a029_u64, + ]; + // Some scalar + let s = [ + 0x123456789abcdef0_u64, + 0xfedcba9876543210_u64, + 0x1111111111111111_u64, + 0x2222222222222222_u64, + ]; + + let (val1, val2, neg1, neg2) = half_gcd(&s, &n); + + // val1 and val2 should be < 2^128 + assert_eq!(val1[2], 0, "val1 should be < 2^128"); + assert_eq!(val1[3], 0, "val1 should be < 2^128"); + assert_eq!(val2[2], 0, "val2 should be < 2^128"); + assert_eq!(val2[3], 0, "val2 should be < 2^128"); + + // Verify: (-1)^neg1 * val1 + (-1)^neg2 * val2 * s ≡ 0 (mod n) + // Use big integer arithmetic + let term2_full = widening_mul(&val2, &s); + let (_, term2_mod_n) = divmod_wide(&term2_full, &n); + + // Compute: sign1 * val1 + sign2 * term2_mod_n (mod n) + let effective1 = if neg1 { + // n - val1 + let mut result = n; + sub_4limb_checked(&mut result, &val1); + result + } else { + val1 + }; + let effective2 = if neg2 { + let mut result = n; + sub_4limb_checked(&mut result, &term2_mod_n); + result + } else { + term2_mod_n + }; + + let sum = add_4limb(&effective1, &effective2); + let sum4 = [sum[0], sum[1], sum[2], sum[3]]; + // sum might be >= n, so reduce + let (_, remainder) = if sum[4] > 0 { + // Sum overflows 256 bits, need wide divmod + let wide = [sum[0], sum[1], sum[2], sum[3], sum[4], 0, 0, 0]; + divmod_wide(&wide, &n) + } else { + divmod(&sum4, &n) + }; + assert_eq!( + remainder, + [0, 0, 0, 0], + "half_gcd relation failed for Grumpkin order" + ); + } +} diff --git a/provekit/prover/src/lib.rs b/provekit/prover/src/lib.rs index bbac4de0a..f0def8273 100644 --- 
a/provekit/prover/src/lib.rs +++ b/provekit/prover/src/lib.rs @@ -21,8 +21,9 @@ use { whir::transcript::{codecs::Empty, ProverState, VerifierMessage}, }; +pub mod bigint_mod; pub mod input_utils; -mod r1cs; +pub mod r1cs; mod whir_r1cs; mod witness; diff --git a/provekit/prover/src/witness/witness_builder.rs b/provekit/prover/src/witness/witness_builder.rs index db91e5e0a..b914b54f1 100644 --- a/provekit/prover/src/witness/witness_builder.rs +++ b/provekit/prover/src/witness/witness_builder.rs @@ -1,13 +1,22 @@ use { - crate::witness::{digits::DigitalDecompositionWitnessesSolver, ram::SpiceWitnessesSolver}, + crate::{ + bigint_mod::{ + add_4limb, bigint_to_fe, cmp_4limb, compute_ec_verification_carries, + compute_mul_mod_carries, decompose_to_u128_limbs, divmod, divmod_wide, + ec_point_add_with_lambda, ec_point_double_with_lambda, ec_scalar_mul, fe_to_bigint, + half_gcd, mod_pow, mul_mod, reconstruct_from_halves, reconstruct_from_u128_limbs, + signed_quotient_wide, sub_u64, to_i128_limbs, widening_mul, + }, + witness::{digits::DigitalDecompositionWitnessesSolver, ram::SpiceWitnessesSolver}, + }, acir::native_types::WitnessMap, - ark_ff::{BigInteger, PrimeField}, + ark_ff::{BigInteger, Field, PrimeField}, ark_std::Zero, provekit_common::{ utils::noir_to_native, witness::{ - compute_spread, ConstantOrR1CSWitness, ConstantTerm, ProductLinearTerm, SumTerm, - WitnessBuilder, WitnessCoefficient, + compute_spread, ConstantOrR1CSWitness, ConstantTerm, NonNativeEcOp, ProductLinearTerm, + SumTerm, WitnessBuilder, WitnessCoefficient, }, FieldElement, NoirElement, TranscriptSponge, }, @@ -23,6 +32,102 @@ pub trait WitnessBuilderSolver { ); } +/// Resolve a ConstantOrR1CSWitness to its FieldElement value. +fn resolve(witness: &[Option], v: &ConstantOrR1CSWitness) -> FieldElement { + match v { + ConstantOrR1CSWitness::Constant(c) => *c, + ConstantOrR1CSWitness::Witness(idx) => witness[*idx].unwrap(), + } +} + +/// Convert a u128 value to a FieldElement. 
+fn u128_to_fe(val: u128) -> FieldElement { + FieldElement::from_bigint(ark_ff::BigInt([val as u64, (val >> 64) as u64, 0, 0])).unwrap() +} + +/// Read witness limbs and reconstruct as [u64; 4]. +fn read_witness_limbs( + witness: &[Option], + indices: &[usize], + limb_bits: u32, +) -> [u64; 4] { + let limb_values: Vec = indices + .iter() + .map(|&idx| { + assert!( + idx < witness.len(), + "read_witness_limbs: index {idx} out of bounds (witness len {})", + witness.len() + ); + let bigint = witness[idx].unwrap().into_bigint().0; + bigint[0] as u128 | ((bigint[1] as u128) << 64) + }) + .collect(); + reconstruct_from_u128_limbs(&limb_values, limb_bits) +} + +/// Write u128 limb values as FieldElement witnesses starting at `start`. +fn write_limbs(witness: &mut [Option], start: usize, vals: &[u128]) { + for (i, &val) in vals.iter().enumerate() { + witness[start + i] = Some(u128_to_fe(val)); + } +} + +/// Split a signed quotient into `(q_pos, q_neg)` limb vectors. +fn split_quotient(q_abs: Vec, is_neg: bool, n: usize) -> (Vec, Vec) { + if is_neg { + (vec![0u128; n], q_abs) + } else { + (q_abs, vec![0u128; n]) + } +} + +/// Solve one column equation: quotient → split → write → carries → write. +/// +/// Computes `(LHS - RHS) / p` as a signed quotient, splits into (q_pos, q_neg), +/// computes verification carries, and writes all witnesses at `os + offset`. +/// Witness layout at offset: `[q_pos(N), q_neg(N), carries(2N-2)]`. 
+fn solve_and_write_equation( + witness: &mut [Option], + os: usize, + offset: usize, + n: usize, + w: u32, + max_coeff_sum: u64, + field_modulus_p: &[u64; 4], + p_l: &[u128], + sq_lhs_prods: &[(&[u64; 4], &[u64; 4], u64)], + sq_rhs_prods: &[(&[u64; 4], &[u64; 4], u64)], + sq_lhs_linear: &[(&[u64; 4], u64)], + sq_rhs_linear: &[(&[u64; 4], u64)], + carry_prods: &[(&[u128], &[u128], i64)], + carry_linear: &[(Vec, i64)], +) { + let (q_abs, is_neg) = signed_quotient_wide( + sq_lhs_prods, + sq_rhs_prods, + sq_lhs_linear, + sq_rhs_linear, + field_modulus_p, + n, + w, + ); + let (q_pos, q_neg) = split_quotient(q_abs, is_neg, n); + write_limbs(witness, os + offset, &q_pos); + write_limbs(witness, os + offset + n, &q_neg); + let carries = compute_ec_verification_carries( + carry_prods, + carry_linear, + p_l, + &q_pos, + &q_neg, + n, + w, + max_coeff_sum, + ); + write_limbs(witness, os + offset + 2 * n, &carries); +} + impl WitnessBuilderSolver for WitnessBuilder { fn solve( &self, @@ -65,6 +170,43 @@ impl WitnessBuilderSolver for WitnessBuilder { "Inverse/LogUpInverse should not be called - handled by batch inversion" ) } + WitnessBuilder::SafeInverse(witness_idx, operand_idx) => { + let val = witness[*operand_idx].unwrap(); + witness[*witness_idx] = Some(if val == FieldElement::zero() { + FieldElement::zero() + } else { + val.inverse().unwrap() + }); + } + WitnessBuilder::ModularInverse(witness_idx, operand_idx, modulus) => { + let a_limbs = fe_to_bigint(witness[*operand_idx].unwrap()); + let m_limbs = modulus.into_bigint().0; + let exp = sub_u64(&m_limbs, 2); + witness[*witness_idx] = Some(bigint_to_fe(&mod_pow(&a_limbs, &exp, &m_limbs))); + } + WitnessBuilder::IntegerQuotient(witness_idx, dividend_idx, divisor) => { + let d_limbs = fe_to_bigint(witness[*dividend_idx].unwrap()); + let m_limbs = divisor.into_bigint().0; + let (quotient, _) = divmod(&d_limbs, &m_limbs); + witness[*witness_idx] = Some(bigint_to_fe("ient)); + } + WitnessBuilder::SumQuotient { + output, + 
terms, + divisor, + } => { + let sum: FieldElement = terms + .iter() + .map(|SumTerm(coeff, idx)| { + let val = witness[*idx].unwrap(); + coeff.map_or(val, |c| c * val) + }) + .fold(FieldElement::zero(), |acc, x| acc + x); + let d_limbs = fe_to_bigint(sum); + let m_limbs = divisor.into_bigint().0; + let (quotient, _) = divmod(&d_limbs, &m_limbs); + witness[*output] = Some(bigint_to_fe("ient)); + } WitnessBuilder::IndexedLogUpDenominator( witness_idx, sz_challenge, @@ -145,18 +287,9 @@ impl WitnessBuilderSolver for WitnessBuilder { rhs, output, ) => { - let lhs = match lhs { - ConstantOrR1CSWitness::Constant(c) => *c, - ConstantOrR1CSWitness::Witness(witness_idx) => witness[*witness_idx].unwrap(), - }; - let rhs = match rhs { - ConstantOrR1CSWitness::Constant(c) => *c, - ConstantOrR1CSWitness::Witness(witness_idx) => witness[*witness_idx].unwrap(), - }; - let output = match output { - ConstantOrR1CSWitness::Constant(c) => *c, - ConstantOrR1CSWitness::Witness(witness_idx) => witness[*witness_idx].unwrap(), - }; + let lhs = resolve(witness, lhs); + let rhs = resolve(witness, rhs); + let output = resolve(witness, output); witness[*witness_idx] = Some( witness[*sz_challenge].unwrap() - (lhs @@ -175,22 +308,10 @@ impl WitnessBuilderSolver for WitnessBuilder { and_output, xor_output, ) => { - let lhs = match lhs { - ConstantOrR1CSWitness::Constant(c) => *c, - ConstantOrR1CSWitness::Witness(witness_idx) => witness[*witness_idx].unwrap(), - }; - let rhs = match rhs { - ConstantOrR1CSWitness::Constant(c) => *c, - ConstantOrR1CSWitness::Witness(witness_idx) => witness[*witness_idx].unwrap(), - }; - let and_out = match and_output { - ConstantOrR1CSWitness::Constant(c) => *c, - ConstantOrR1CSWitness::Witness(witness_idx) => witness[*witness_idx].unwrap(), - }; - let xor_out = match xor_output { - ConstantOrR1CSWitness::Constant(c) => *c, - ConstantOrR1CSWitness::Witness(witness_idx) => witness[*witness_idx].unwrap(), - }; + let lhs = resolve(witness, lhs); + let rhs = 
resolve(witness, rhs); + let and_out = resolve(witness, and_output); + let xor_out = resolve(witness, xor_output); // Encoding: sz - (lhs + rs*rhs + rs²*and_out + rs³*xor_out) witness[*witness_idx] = Some( witness[*sz_challenge].unwrap() @@ -203,18 +324,8 @@ impl WitnessBuilderSolver for WitnessBuilder { WitnessBuilder::MultiplicitiesForBinOp(witness_idx, atomic_bits, operands) => { let mut multiplicities = vec![0u32; 2usize.pow(2 * *atomic_bits)]; for (lhs, rhs) in operands { - let lhs = match lhs { - ConstantOrR1CSWitness::Constant(c) => *c, - ConstantOrR1CSWitness::Witness(witness_idx) => { - witness[*witness_idx].unwrap() - } - }; - let rhs = match rhs { - ConstantOrR1CSWitness::Constant(c) => *c, - ConstantOrR1CSWitness::Witness(witness_idx) => { - witness[*witness_idx].unwrap() - } - }; + let lhs = resolve(witness, lhs); + let rhs = resolve(witness, rhs); let index = (lhs.into_bigint().0[0] << *atomic_bits) + rhs.into_bigint().0[0]; multiplicities[index as usize] += 1; } @@ -223,14 +334,8 @@ impl WitnessBuilderSolver for WitnessBuilder { } } WitnessBuilder::U32Addition(result_witness_idx, carry_witness_idx, a, b) => { - let a_val = match a { - ConstantOrR1CSWitness::Constant(c) => *c, - ConstantOrR1CSWitness::Witness(idx) => witness[*idx].unwrap(), - }; - let b_val = match b { - ConstantOrR1CSWitness::Constant(c) => *c, - ConstantOrR1CSWitness::Witness(idx) => witness[*idx].unwrap(), - }; + let a_val = resolve(witness, a); + let b_val = resolve(witness, b); assert!( a_val.into_bigint().num_bits() <= 32, "a_val must be less than or equal to 32 bits, got {}", @@ -258,12 +363,7 @@ impl WitnessBuilderSolver for WitnessBuilder { // Sum all inputs as u64 to handle overflow. 
let mut sum: u64 = 0; for input in inputs { - let val = match input { - ConstantOrR1CSWitness::Constant(c) => c.into_bigint().0[0], - ConstantOrR1CSWitness::Witness(idx) => { - witness[*idx].unwrap().into_bigint().0[0] - } - }; + let val = resolve(witness, input).into_bigint().0[0]; assert!(val < (1u64 << 32), "input must be 32-bit"); sum += val; } @@ -274,14 +374,8 @@ impl WitnessBuilderSolver for WitnessBuilder { witness[*carry_witness_idx] = Some(FieldElement::from(quotient)); } WitnessBuilder::And(result_witness_idx, lh, rh) => { - let lh_val = match lh { - ConstantOrR1CSWitness::Constant(c) => *c, - ConstantOrR1CSWitness::Witness(witness_idx) => witness[*witness_idx].unwrap(), - }; - let rh_val = match rh { - ConstantOrR1CSWitness::Constant(c) => *c, - ConstantOrR1CSWitness::Witness(witness_idx) => witness[*witness_idx].unwrap(), - }; + let lh_val = resolve(witness, lh); + let rh_val = resolve(witness, rh); assert!( lh_val.into_bigint().num_bits() <= 32, "lh_val must be less than or equal to 32 bits, got {}", @@ -297,14 +391,8 @@ impl WitnessBuilderSolver for WitnessBuilder { )); } WitnessBuilder::Xor(result_witness_idx, lh, rh) => { - let lh_val = match lh { - ConstantOrR1CSWitness::Constant(c) => *c, - ConstantOrR1CSWitness::Witness(witness_idx) => witness[*witness_idx].unwrap(), - }; - let rh_val = match rh { - ConstantOrR1CSWitness::Constant(c) => *c, - ConstantOrR1CSWitness::Witness(witness_idx) => witness[*witness_idx].unwrap(), - }; + let lh_val = resolve(witness, lh); + let rh_val = resolve(witness, rh); assert!( lh_val.into_bigint().num_bits() <= 32, "lh_val must be less than or equal to 32 bits, got {}", @@ -319,9 +407,107 @@ impl WitnessBuilderSolver for WitnessBuilder { lh_val.into_bigint() ^ rh_val.into_bigint(), )); } + WitnessBuilder::MultiLimbMulModHint { + output_start, + a_limbs, + b_limbs, + modulus, + limb_bits, + num_limbs, + } => { + let n = *num_limbs as usize; + let w = *limb_bits; + + let a_val = read_witness_limbs(witness, a_limbs, 
w); + let b_val = read_witness_limbs(witness, b_limbs, w); + + let product = widening_mul(&a_val, &b_val); + let (q_val, r_val) = divmod_wide(&product, modulus); + + let q_limbs_vals = decompose_to_u128_limbs(&q_val, n, w); + let r_limbs_vals = decompose_to_u128_limbs(&r_val, n, w); + + let carries = compute_mul_mod_carries( + &decompose_to_u128_limbs(&a_val, n, w), + &decompose_to_u128_limbs(&b_val, n, w), + &decompose_to_u128_limbs(modulus, n, w), + &q_limbs_vals, + &r_limbs_vals, + w, + ); + + write_limbs(witness, *output_start, &q_limbs_vals); + write_limbs(witness, *output_start + n, &r_limbs_vals); + write_limbs(witness, *output_start + 2 * n, &carries); + } + WitnessBuilder::MultiLimbModularInverse { + output_start, + a_limbs, + modulus, + limb_bits, + num_limbs, + } => { + let n = *num_limbs as usize; + let w = *limb_bits; + + let a_val = read_witness_limbs(witness, a_limbs, w); + let exp = sub_u64(modulus, 2); + let inv = mod_pow(&a_val, &exp, modulus); + write_limbs(witness, *output_start, &decompose_to_u128_limbs(&inv, n, w)); + } + WitnessBuilder::MultiLimbAddQuotient { + output, + a_limbs, + b_limbs, + modulus, + limb_bits, + .. + } => { + let w = *limb_bits; + + let a_val = read_witness_limbs(witness, a_limbs, w); + let b_val = read_witness_limbs(witness, b_limbs, w); + + let sum = add_4limb(&a_val, &b_val); + let q = if sum[4] > 0 { + 1u64 + } else { + let sum4 = [sum[0], sum[1], sum[2], sum[3]]; + if cmp_4limb(&sum4, modulus) != std::cmp::Ordering::Less { + 1u64 + } else { + 0u64 + } + }; + + witness[*output] = Some(FieldElement::from(q)); + } + WitnessBuilder::MultiLimbSubBorrow { + output, + a_limbs, + b_limbs, + limb_bits, + .. 
+ } => { + let w = *limb_bits; + + let a_val = read_witness_limbs(witness, a_limbs, w); + let b_val = read_witness_limbs(witness, b_limbs, w); + + let q = if cmp_4limb(&a_val, &b_val) == std::cmp::Ordering::Less { + 1u64 + } else { + 0u64 + }; + + witness[*output] = Some(FieldElement::from(q)); + } WitnessBuilder::BytePartition { lo, hi, x, k } => { let x_val = witness[*x].unwrap().into_bigint().0[0]; - debug_assert!(x_val < 256, "BytePartition input must be 8-bit"); + assert!( + x_val < 256, + "BytePartition input must be 8-bit, got {x_val}" + ); let mask = (1u64 << *k) - 1; let lo_val = x_val & mask; @@ -330,6 +516,411 @@ impl WitnessBuilderSolver for WitnessBuilder { witness[*lo] = Some(FieldElement::from(lo_val)); witness[*hi] = Some(FieldElement::from(hi_val)); } + WitnessBuilder::FakeGLVHint { + output_start, + s_lo, + s_hi, + curve_order, + } => { + let s_val = reconstruct_from_halves( + &fe_to_bigint(witness[*s_lo].unwrap()), + &fe_to_bigint(witness[*s_hi].unwrap()), + ); + + let (val1, val2, neg1, neg2) = half_gcd(&s_val, curve_order); + + witness[*output_start] = Some(bigint_to_fe(&val1)); + witness[*output_start + 1] = Some(bigint_to_fe(&val2)); + witness[*output_start + 2] = Some(FieldElement::from(neg1 as u64)); + witness[*output_start + 3] = Some(FieldElement::from(neg2 as u64)); + } + WitnessBuilder::EcDoubleHint { + output_start, + px, + py, + curve_a, + field_modulus_p, + } => { + let px_val = fe_to_bigint(witness[*px].unwrap()); + let py_val = fe_to_bigint(witness[*py].unwrap()); + + let (lambda, x3, y3) = + ec_point_double_with_lambda(&px_val, &py_val, curve_a, field_modulus_p); + + witness[*output_start] = Some(bigint_to_fe(&lambda)); + witness[*output_start + 1] = Some(bigint_to_fe(&x3)); + witness[*output_start + 2] = Some(bigint_to_fe(&y3)); + } + WitnessBuilder::EcAddHint { + output_start, + x1, + y1, + x2, + y2, + field_modulus_p, + } => { + let x1_val = fe_to_bigint(witness[*x1].unwrap()); + let y1_val = 
fe_to_bigint(witness[*y1].unwrap()); + let x2_val = fe_to_bigint(witness[*x2].unwrap()); + let y2_val = fe_to_bigint(witness[*y2].unwrap()); + + let (lambda, x3, y3) = + ec_point_add_with_lambda(&x1_val, &y1_val, &x2_val, &y2_val, field_modulus_p); + + witness[*output_start] = Some(bigint_to_fe(&lambda)); + witness[*output_start + 1] = Some(bigint_to_fe(&x3)); + witness[*output_start + 2] = Some(bigint_to_fe(&y3)); + } + WitnessBuilder::NonNativeEcHint { + output_start, + op, + inputs, + curve_a, + curve_b, + field_modulus_p, + limb_bits, + num_limbs, + } => { + let n = *num_limbs as usize; + let w = *limb_bits; + let os = *output_start; + + let p_l = decompose_to_u128_limbs(field_modulus_p, n, w); + + match op { + NonNativeEcOp::Double => { + let px_val = read_witness_limbs(witness, inputs[0].as_slice(), w); + let py_val = read_witness_limbs(witness, inputs[1].as_slice(), w); + let (lam, x3v, y3v) = + ec_point_double_with_lambda(&px_val, &py_val, curve_a, field_modulus_p); + let ll = decompose_to_u128_limbs(&lam, n, w); + let xl = decompose_to_u128_limbs(&x3v, n, w); + let yl = decompose_to_u128_limbs(&y3v, n, w); + let pl = decompose_to_u128_limbs(&px_val, n, w); + let pyl = decompose_to_u128_limbs(&py_val, n, w); + let a_l = decompose_to_u128_limbs(curve_a, n, w); + write_limbs(witness, os, &ll); + write_limbs(witness, os + n, &xl); + write_limbs(witness, os + 2 * n, &yl); + + // Per-equation max_coeff_sum must match compiler + // (see hints_non_native.rs:point_double_verified_non_native) + let mcs_eq1 = 6 + 2 * n as u64; // λy(2)+xx(3)+a(1)+pq(2n) + let mcs_eq2 = 4 + 2 * n as u64; // λλ(1)+x3(1)+px(2)+pq(2n) + let mcs_eq3 = 4 + 2 * n as u64; // λΔx(1)+y3(1)+py(1)+r(1)+pq(2n) → 4+2n + + // Layout: [lambda(N), x3(N), y3(N), + // q1_pos(N), q1_neg(N), c1(2N-2), + // q2_pos(N), q2_neg(N), c2(2N-2), + // q3_pos(N), q3_neg(N), c3(2N-2)] + // Total: 15N-6 + + // Eq1: 2*λ*py - 3*px² - a = q1*p + solve_and_write_equation( + witness, + os, + 3 * n, + n, + w, + mcs_eq1, + 
field_modulus_p, + &p_l, + &[(&lam, &py_val, 2)], + &[(&px_val, &px_val, 3)], + &[], + &[(curve_a, 1)], + &[(&ll, &pyl, 2), (&pl, &pl, -3)], + &[(to_i128_limbs(&a_l), -1)], + ); + // Eq2: λ² - x3 - 2*px = q2*p + solve_and_write_equation( + witness, + os, + 7 * n - 2, + n, + w, + mcs_eq2, + field_modulus_p, + &p_l, + &[(&lam, &lam, 1)], + &[], + &[], + &[(&x3v, 1), (&px_val, 2)], + &[(&ll, &ll, 1)], + &[(to_i128_limbs(&xl), -1), (to_i128_limbs(&pl), -2)], + ); + // Eq3: λ*px - λ*x3 - y3 - py = q3*p + solve_and_write_equation( + witness, + os, + 11 * n - 4, + n, + w, + mcs_eq3, + field_modulus_p, + &p_l, + &[(&lam, &px_val, 1)], + &[(&lam, &x3v, 1)], + &[], + &[(&y3v, 1), (&py_val, 1)], + &[(&ll, &pl, 1), (&ll, &xl, -1)], + &[(to_i128_limbs(&yl), -1), (to_i128_limbs(&pyl), -1)], + ); + } + NonNativeEcOp::Add => { + let x1v = read_witness_limbs(witness, inputs[0].as_slice(), w); + let y1v = read_witness_limbs(witness, inputs[1].as_slice(), w); + let x2v = read_witness_limbs(witness, inputs[2].as_slice(), w); + let y2v = read_witness_limbs(witness, inputs[3].as_slice(), w); + let (lam, x3v, y3v) = + ec_point_add_with_lambda(&x1v, &y1v, &x2v, &y2v, field_modulus_p); + let ll = decompose_to_u128_limbs(&lam, n, w); + let xl = decompose_to_u128_limbs(&x3v, n, w); + let yl = decompose_to_u128_limbs(&y3v, n, w); + let x1l = decompose_to_u128_limbs(&x1v, n, w); + let y1l = decompose_to_u128_limbs(&y1v, n, w); + let x2l = decompose_to_u128_limbs(&x2v, n, w); + let y2l = decompose_to_u128_limbs(&y2v, n, w); + write_limbs(witness, os, &ll); + write_limbs(witness, os + n, &xl); + write_limbs(witness, os + 2 * n, &yl); + + // Must match compiler's max_coeff_sum + // (see hints_non_native.rs:point_add_verified_non_native) + let mcs = 4 + 2 * n as u64; // 1+1+1+1+2n for all 3 eqs + + // Layout: [lambda(N), x3(N), y3(N), + // q1_pos(N), q1_neg(N), c1(2N-2), + // q2_pos(N), q2_neg(N), c2(2N-2), + // q3_pos(N), q3_neg(N), c3(2N-2)] + // Total: 15N-6 + + // Eq1: λ*x2 - λ*x1 + y1 - y2 = 
q1*p + solve_and_write_equation( + witness, + os, + 3 * n, + n, + w, + mcs, + field_modulus_p, + &p_l, + &[(&lam, &x2v, 1)], + &[(&lam, &x1v, 1)], + &[(&y1v, 1)], + &[(&y2v, 1)], + &[(&ll, &x2l, 1), (&ll, &x1l, -1)], + &[(to_i128_limbs(&y2l), -1), (to_i128_limbs(&y1l), 1)], + ); + // Eq2: λ² - x3 - x1 - x2 = q2*p + solve_and_write_equation( + witness, + os, + 7 * n - 2, + n, + w, + mcs, + field_modulus_p, + &p_l, + &[(&lam, &lam, 1)], + &[], + &[], + &[(&x3v, 1), (&x1v, 1), (&x2v, 1)], + &[(&ll, &ll, 1)], + &[ + (to_i128_limbs(&xl), -1), + (to_i128_limbs(&x1l), -1), + (to_i128_limbs(&x2l), -1), + ], + ); + // Eq3: λ*x1 - λ*x3 - y3 - y1 = q3*p + solve_and_write_equation( + witness, + os, + 11 * n - 4, + n, + w, + mcs, + field_modulus_p, + &p_l, + &[(&lam, &x1v, 1)], + &[(&lam, &x3v, 1)], + &[], + &[(&y3v, 1), (&y1v, 1)], + &[(&ll, &x1l, 1), (&ll, &xl, -1)], + &[(to_i128_limbs(&yl), -1), (to_i128_limbs(&y1l), -1)], + ); + } + NonNativeEcOp::OnCurve => { + let px_val = read_witness_limbs(witness, inputs[0].as_slice(), w); + let py_val = read_witness_limbs(witness, inputs[1].as_slice(), w); + let x_sq_val = mul_mod(&px_val, &px_val, field_modulus_p); + let xsl = decompose_to_u128_limbs(&x_sq_val, n, w); + let pl = decompose_to_u128_limbs(&px_val, n, w); + let pyl = decompose_to_u128_limbs(&py_val, n, w); + write_limbs(witness, os, &xsl); + + let a_is_zero = curve_a.iter().all(|&v| v == 0); + // Per-equation max_coeff_sum must match compiler + // (see hints_non_native.rs:verify_on_curve_non_native) + let mcs_eq1: u64 = 2 + 2 * n as u64; // px·px(1)+x_sq(1)+pq(2n) + let mcs_eq2: u64 = if a_is_zero { + 3 + 2 * n as u64 // x³(1)+y²(1)+b(1)+pq(2n) + } else { + 4 + 2 * n as u64 // x³(1)+y²(1)+ax(1)+b(1)+pq(2n) + }; + + // Layout: [x_sq(N), + // q1_pos(N), q1_neg(N), c1(2N-2), + // q2_pos(N), q2_neg(N), c2(2N-2)] + // Total: 9N-4 + + // Eq1: px·px - x_sq = q1·p + solve_and_write_equation( + witness, + os, + n, + n, + w, + mcs_eq1, + field_modulus_p, + &p_l, + &[(&px_val, 
&px_val, 1)], + &[], + &[], + &[(&x_sq_val, 1)], + &[(&pl, &pl, 1)], + &[(to_i128_limbs(&xsl), -1)], + ); + + // Eq2: py·py - x_sq·px - a·px - b = q2·p + let a_l = decompose_to_u128_limbs(curve_a, n, w); + let b_l = decompose_to_u128_limbs(curve_b, n, w); + + let mut rhs_prods: Vec<(&[u64; 4], &[u64; 4], u64)> = + vec![(&x_sq_val, &px_val, 1)]; + if !a_is_zero { + rhs_prods.push((curve_a, &px_val, 1)); + } + let (q2_abs, q2_neg) = signed_quotient_wide( + &[(&py_val, &py_val, 1)], + &rhs_prods, + &[], + &[(curve_b, 1)], + field_modulus_p, + n, + w, + ); + + let (q2_pos, q2_neg) = split_quotient(q2_abs, q2_neg, n); + write_limbs(witness, os + 5 * n - 2, &q2_pos); + write_limbs(witness, os + 6 * n - 2, &q2_neg); + + let mut prod_sets: Vec<(&[u128], &[u128], i64)> = + vec![(&pyl, &pyl, 1), (&xsl, &pl, -1)]; + if !a_is_zero { + prod_sets.push((&a_l, &pl, -1)); + } + let c2 = compute_ec_verification_carries( + &prod_sets, + &[(to_i128_limbs(&b_l), -1)], + &p_l, + &q2_pos, + &q2_neg, + n, + w, + mcs_eq2, + ); + write_limbs(witness, os + 7 * n - 2, &c2); + } + } + } + WitnessBuilder::EcScalarMulHint { + output_start, + px_limbs, + py_limbs, + s_lo, + s_hi, + curve_a, + field_modulus_p, + num_limbs, + limb_bits, + } => { + let n = *num_limbs as usize; + let scalar = reconstruct_from_halves( + &fe_to_bigint(witness[*s_lo].unwrap()), + &fe_to_bigint(witness[*s_hi].unwrap()), + ); + + let px_val = if n == 1 { + fe_to_bigint(witness[px_limbs[0]].unwrap()) + } else { + read_witness_limbs(witness, px_limbs, *limb_bits) + }; + let py_val = if n == 1 { + fe_to_bigint(witness[py_limbs[0]].unwrap()) + } else { + read_witness_limbs(witness, py_limbs, *limb_bits) + }; + + let (rx, ry) = ec_scalar_mul(&px_val, &py_val, &scalar, curve_a, field_modulus_p); + + if n == 1 { + witness[*output_start] = Some(bigint_to_fe(&rx)); + witness[*output_start + 1] = Some(bigint_to_fe(&ry)); + } else { + let rx_limbs = decompose_to_u128_limbs(&rx, n, *limb_bits); + let ry_limbs = 
decompose_to_u128_limbs(&ry, n, *limb_bits); + write_limbs(witness, *output_start, &rx_limbs); + write_limbs(witness, *output_start + n, &ry_limbs); + } + } + WitnessBuilder::SelectWitness { + output, + flag, + on_false, + on_true, + } => { + let f = witness[*flag].unwrap(); + let a = witness[*on_false].unwrap(); + let b = witness[*on_true].unwrap(); + witness[*output] = Some(a + f * (b - a)); + } + WitnessBuilder::BooleanOr { output, a, b } => { + let a_val = witness[*a].unwrap(); + let b_val = witness[*b].unwrap(); + witness[*output] = Some(a_val + b_val - a_val * b_val); + } + WitnessBuilder::SignedBitHint { + output_start, + scalar, + num_bits, + } => { + assert!( + *num_bits <= 128, + "SignedBitHint: num_bits={} exceeds 128; scalar would be silently truncated", + num_bits + ); + let s_fe = witness[*scalar].unwrap(); + let s_big = s_fe.into_bigint().0; + let s_val: u128 = s_big[0] as u128 | ((s_big[1] as u128) << 64); + let n = *num_bits; + let skew: u128 = if s_val & 1 == 0 { 1 } else { 0 }; + let s_adj = s_val + skew; + // t = (s_adj + 2^n - 1) / 2 + // Both s_adj and 2^n-1 are odd, so sum is even. + // To avoid u128 overflow when n >= 128, rewrite as: + // t = (s_adj - 1) / 2 + (2^n - 1 + 1) / 2 = (s_adj - 1) / 2 + 2^(n-1) + let t = if n == 0 { + s_adj / 2 + } else { + (s_adj - 1) / 2 + (1u128 << (n - 1)) + }; + for i in 0..n { + witness[*output_start + i] = Some(FieldElement::from(((t >> i) & 1) as u64)); + } + witness[*output_start + n] = Some(FieldElement::from(skew as u64)); + } WitnessBuilder::CombinedTableEntryInverse(..) 
=> { unreachable!( "CombinedTableEntryInverse should not be called - handled by batch inversion" @@ -393,12 +984,7 @@ impl WitnessBuilderSolver for WitnessBuilder { let table_size = 1usize << *num_bits; let mut multiplicities = vec![0u32; table_size]; for query in queries { - let val = match query { - ConstantOrR1CSWitness::Constant(c) => c.into_bigint().0[0], - ConstantOrR1CSWitness::Witness(w) => { - witness[*w].unwrap().into_bigint().0[0] - } - }; + let val = resolve(witness, query).into_bigint().0[0]; multiplicities[val as usize] += 1; } for (i, count) in multiplicities.iter().enumerate() { @@ -408,14 +994,8 @@ impl WitnessBuilderSolver for WitnessBuilder { WitnessBuilder::SpreadLookupDenominator(idx, sz, rs, input, spread_output) => { let sz_val = witness[*sz].unwrap(); let rs_val = witness[*rs].unwrap(); - let input_val = match input { - ConstantOrR1CSWitness::Constant(c) => *c, - ConstantOrR1CSWitness::Witness(w) => witness[*w].unwrap(), - }; - let spread_val = match spread_output { - ConstantOrR1CSWitness::Constant(c) => *c, - ConstantOrR1CSWitness::Witness(w) => witness[*w].unwrap(), - }; + let input_val = resolve(witness, input); + let spread_val = resolve(witness, spread_output); // sz - (input + rs * spread_output) witness[*idx] = Some(sz_val - (input_val + rs_val * spread_val)); } diff --git a/provekit/r1cs-compiler/src/constraint_helpers.rs b/provekit/r1cs-compiler/src/constraint_helpers.rs new file mode 100644 index 000000000..9561fe7f3 --- /dev/null +++ b/provekit/r1cs-compiler/src/constraint_helpers.rs @@ -0,0 +1,133 @@ +//! General-purpose R1CS constraint helpers. + +use { + crate::noir_to_r1cs::NoirToR1CSCompiler, + ark_ff::{AdditiveGroup, Field}, + provekit_common::{ + witness::{ConstantTerm, SumTerm, WitnessBuilder}, + FieldElement, + }, +}; + +/// Constrains `flag` to be boolean: `flag * flag = flag`. 
+pub(crate) fn constrain_boolean(compiler: &mut NoirToR1CSCompiler, flag: usize) { + compiler.r1cs.add_constraint( + &[(FieldElement::ONE, flag)], + &[(FieldElement::ONE, flag)], + &[(FieldElement::ONE, flag)], + ); +} + +/// Single-witness conditional select: `out = on_false + flag * (on_true - +/// on_false)`. +/// +/// Uses a single witness + single R1CS constraint: +/// flag * (on_true - on_false) = result - on_false +pub(crate) fn select_witness( + compiler: &mut NoirToR1CSCompiler, + flag: usize, + on_false: usize, + on_true: usize, +) -> usize { + // When both branches are the same witness, result is trivially that witness. + if on_false == on_true { + return on_false; + } + let result = compiler.num_witnesses(); + compiler.add_witness_builder(WitnessBuilder::SelectWitness { + output: result, + flag, + on_false, + on_true, + }); + // flag * (on_true - on_false) = result - on_false + compiler.r1cs.add_constraint( + &[(FieldElement::ONE, flag)], + &[(FieldElement::ONE, on_true), (-FieldElement::ONE, on_false)], + &[(FieldElement::ONE, result), (-FieldElement::ONE, on_false)], + ); + result +} + +/// Packs bit witnesses into a digit: `d = Σ bits\[i\] * 2^i`. +pub(crate) fn pack_bits_helper(compiler: &mut NoirToR1CSCompiler, bits: &[usize]) -> usize { + let terms: Vec = bits + .iter() + .enumerate() + .map(|(i, &bit)| SumTerm(Some(FieldElement::from(1u128 << i)), bit)) + .collect(); + compiler.add_sum(terms) +} + +/// Computes `a OR b` for two boolean witnesses: `1 - (1 - a)(1 - b)`. +/// Does NOT constrain a or b to be boolean — caller must ensure that. 
+/// +/// Uses a single witness + single R1CS constraint: +/// (1 - a) * (1 - b) = 1 - result +pub(crate) fn compute_boolean_or(compiler: &mut NoirToR1CSCompiler, a: usize, b: usize) -> usize { + let one = compiler.witness_one(); + let result = compiler.num_witnesses(); + compiler.add_witness_builder(WitnessBuilder::BooleanOr { + output: result, + a, + b, + }); + // (1 - a) * (1 - b) = 1 - result + compiler.r1cs.add_constraint( + &[(FieldElement::ONE, one), (-FieldElement::ONE, a)], + &[(FieldElement::ONE, one), (-FieldElement::ONE, b)], + &[(FieldElement::ONE, one), (-FieldElement::ONE, result)], + ); + result +} + +/// Creates a constant witness with the given value, pinned by an R1CS +/// constraint so that a malicious prover cannot set it to an arbitrary value. +pub(crate) fn add_constant_witness( + compiler: &mut NoirToR1CSCompiler, + value: FieldElement, +) -> usize { + let w = compiler.num_witnesses(); + compiler.add_witness_builder(WitnessBuilder::Constant(ConstantTerm(w, value))); + // Pin: 1 * w = value * 1 (embeds the constant into the constraint matrix) + compiler.r1cs.add_constraint( + &[(FieldElement::ONE, compiler.witness_one())], + &[(FieldElement::ONE, w)], + &[(value, compiler.witness_one())], + ); + w +} + +/// Constrains a witness to equal a known constant value. +/// Uses the constant as an R1CS coefficient — no witness needed for the +/// expected value. Use this for identity checks where the witness must equal +/// a compile-time-known value. +pub(crate) fn constrain_to_constant( + compiler: &mut NoirToR1CSCompiler, + witness: usize, + value: FieldElement, +) { + compiler.r1cs.add_constraint( + &[(FieldElement::ONE, compiler.witness_one())], + &[(FieldElement::ONE, witness)], + &[(value, compiler.witness_one())], + ); +} + +/// Constrains two witnesses to be equal: `a - b = 0`. 
+pub(crate) fn constrain_equal(compiler: &mut NoirToR1CSCompiler, a: usize, b: usize) { + compiler.r1cs.add_constraint( + &[(FieldElement::ONE, compiler.witness_one())], + &[(FieldElement::ONE, a), (-FieldElement::ONE, b)], + &[(FieldElement::ZERO, compiler.witness_one())], + ); +} + +/// Constrains a witness to be zero: `w = 0`. +pub(crate) fn constrain_zero(compiler: &mut NoirToR1CSCompiler, w: usize) { + compiler.r1cs.add_constraint( + &[(FieldElement::ONE, compiler.witness_one())], + &[(FieldElement::ONE, w)], + &[(FieldElement::ZERO, compiler.witness_one())], + ); +} diff --git a/provekit/r1cs-compiler/src/digits.rs b/provekit/r1cs-compiler/src/digits.rs index 91c4e4128..657f7bd78 100644 --- a/provekit/r1cs-compiler/src/digits.rs +++ b/provekit/r1cs-compiler/src/digits.rs @@ -1,5 +1,6 @@ use { crate::noir_to_r1cs::NoirToR1CSCompiler, + ark_ff::Field, ark_std::One, provekit_common::{ witness::{DigitalDecompositionWitnesses, WitnessBuilder}, @@ -66,7 +67,8 @@ pub(crate) fn add_digital_decomposition( // Add the constraints for the digital recomposition let mut digit_multipliers = vec![FieldElement::one()]; for log_base in log_bases[..log_bases.len() - 1].iter() { - let multiplier = *digit_multipliers.last().unwrap() * FieldElement::from(1u64 << *log_base); + let multiplier = + *digit_multipliers.last().unwrap() * FieldElement::from(2u64).pow([*log_base as u64]); digit_multipliers.push(multiplier); } dd_struct diff --git a/provekit/r1cs-compiler/src/lib.rs b/provekit/r1cs-compiler/src/lib.rs index 7de8f899b..0b9890a8a 100644 --- a/provekit/r1cs-compiler/src/lib.rs +++ b/provekit/r1cs-compiler/src/lib.rs @@ -1,10 +1,12 @@ mod binops; +mod constraint_helpers; mod digits; mod memory; +pub mod msm; mod noir_proof_scheme; -mod noir_to_r1cs; +pub mod noir_to_r1cs; mod poseidon2; -mod range_check; +pub mod range_check; mod sha256_compression; mod spread; mod uints; diff --git a/provekit/r1cs-compiler/src/msm/cost_model.rs b/provekit/r1cs-compiler/src/msm/cost_model.rs 
new file mode 100644 index 000000000..463f09174 --- /dev/null +++ b/provekit/r1cs-compiler/src/msm/cost_model.rs @@ -0,0 +1,525 @@ +//! Analytical cost model for MSM parameter optimization. +//! +//! Follows the SHA256 pattern (`spread.rs:get_optimal_spread_width`): +//! `calculate_msm_witness_cost` estimates total cost, `get_optimal_msm_params` +//! searches the parameter space for the minimum. + +use { + super::{ceil_log2, SCALAR_HALF_BITS}, + std::collections::BTreeMap, +}; + +/// Per-point overhead witnesses shared across all MSM paths. +const DETECT_SKIP_WIT: usize = 8; +const GLV_HINT_WIT: usize = 4; + +fn ceil_div(a: usize, b: usize) -> usize { + (a + b - 1) / b +} + +/// Table building ops: (doubles, adds) for a signed-digit table. +fn table_build_ops(half_table_size: usize) -> (usize, usize) { + if half_table_size >= 2 { + (1, half_table_size - 1) + } else { + (0, 0) + } +} + +// --------------------------------------------------------------------------- +// Hint-verified EC cost primitives (non-native, num_limbs >= 2) +// +// Every hint-verified EC op follows the same pattern: +// 1. Allocate hint: result limb-vectors + per-equation (q_pos, q_neg, +// carries) +// 2. Compute N×N schoolbook product grids +// 3. Pin constant limb-vectors (curve params) +// 4. Verify via schoolbook column equations +// 5. Range-check hint outputs (limb-bit) and carries (carry-bit) +// 6. less-than-p check on each result vector +// +// The helpers below decompose costs into these structural components. +// --------------------------------------------------------------------------- + +/// Witnesses per limb from a less-than-p borrow chain: borrow + d_i. +const LTP_WIT_PER_LIMB: usize = 2; +/// Range checks per limb from a less-than-p check: borrow + d_i. +const LTP_RC_PER_LIMB: usize = 2; +/// Witnesses per limb from a multi-limb negate (p-y borrow chain): borrow + r. +const NEGATE_WIT_PER_LIMB: usize = 2; +/// Range checks per limb from a multi-limb negate: r. 
+const NEGATE_RC_PER_LIMB: usize = 1; + +/// Hint output witnesses: result vectors + per-equation quotient/carry layout. +/// +/// Each equation allocates q_pos(N) + q_neg(N) + carries(2N-2) = 4N-2 +/// witnesses. +fn hint_output_witnesses(n: usize, result_vecs: usize, num_equations: usize) -> usize { + result_vecs * n + num_equations * (4 * n - 2) +} + +/// Witnesses from N×N schoolbook product grids. +fn schoolbook_product_witnesses(n: usize, product_pairs: usize) -> usize { + product_pairs * n * n +} + +/// Limb-bit range checks from hint outputs: result vecs + quotient pairs +/// (pos+neg). +fn hint_limb_range_checks(n: usize, result_vecs: usize, num_equations: usize) -> usize { + (result_vecs + 2 * num_equations) * n +} + +/// Carry range checks from schoolbook column equations: (2N-2) per equation. +fn hint_carry_range_checks(n: usize, num_equations: usize) -> usize { + num_equations * (2 * n - 2) +} + +/// Carry range check bit-width for schoolbook column equations. +pub(crate) fn hint_carry_bits(limb_bits: u32, max_coeff_sum: u64, n: usize) -> u32 { + let extra_bits = ceil_log2(max_coeff_sum * n as u64) + 1; + limb_bits + extra_bits + 1 +} + +/// Maximum bit-width of a merged column equation value. +pub(crate) fn column_equation_max_bits(limb_bits: u32, max_coeff_sum: u64, n: usize) -> u32 { + let extra_bits = ceil_log2(max_coeff_sum * n as u64) + 1; + 2 * limb_bits + extra_bits + 1 +} + +/// Worst-case max coefficient sum across all hint-verified EC equations. +/// +/// Point double Eq1 has the highest: 2(λy) + 3(x²) + 1(a) + N(q_pos·p) + +/// N(q_neg·p) = 6+2N. +fn worst_case_ec_max_coeff(n: usize) -> u64 { + 6 + 2 * n as u64 +} + +/// Witness and range check costs for a single hint-verified EC operation. +struct HintVerifiedEcCost { + witnesses: usize, + rc_limb: usize, + rc_carry: usize, + carry_bits: u32, +} + +impl HintVerifiedEcCost { + /// Point doubling: λ·2y = 3x²+a, λ² = x3+2x, λ(x-x3) = y3+y. 
+ fn point_double(n: usize, limb_bits: u32) -> Self { + let num_ltp = 3; // λ, x3, y3 + Self { + witnesses: hint_output_witnesses(n, 3, 3) // λ,x3,y3 + 3 equations + + schoolbook_product_witnesses(n, 5) // λ×y, x×x, λ×λ, λ×x, λ×x3 + + n // pinned curve_a limbs + + num_ltp * LTP_WIT_PER_LIMB * n, + rc_limb: hint_limb_range_checks(n, 3, 3) + num_ltp * LTP_RC_PER_LIMB * n, + rc_carry: hint_carry_range_checks(n, 3), + carry_bits: hint_carry_bits(limb_bits, worst_case_ec_max_coeff(n), n), + } + } + + /// Point addition: λ(x2-x1) = y2-y1, λ² = x3+x1+x2, λ(x1-x3) = y3+y1. + fn point_add(n: usize, limb_bits: u32) -> Self { + let num_ltp = 3; // λ, x3, y3 + Self { + witnesses: hint_output_witnesses(n, 3, 3) // λ,x3,y3 + 3 equations + + schoolbook_product_witnesses(n, 4) // λ×x2, λ×x1, λ×λ, λ×x3 + + num_ltp * LTP_WIT_PER_LIMB * n, + rc_limb: hint_limb_range_checks(n, 3, 3) + num_ltp * LTP_RC_PER_LIMB * n, + rc_carry: hint_carry_range_checks(n, 3), + carry_bits: hint_carry_bits(limb_bits, 4 + 2 * n as u64, n), + } + } + + /// On-curve check: x² mod p, then y² = x³ + ax + b mod p. + fn on_curve(n: usize, limb_bits: u32) -> Self { + let num_ltp = 1; // x_sq + Self { + witnesses: hint_output_witnesses(n, 1, 2) // x_sq + 2 equations + + schoolbook_product_witnesses(n, 4) // x×x, y×y, xsq×x, a×x + + 2 * n // pinned curve_a + curve_b limbs + + num_ltp * LTP_WIT_PER_LIMB * n, + rc_limb: hint_limb_range_checks(n, 1, 2) + num_ltp * LTP_RC_PER_LIMB * n, + rc_carry: hint_carry_range_checks(n, 2), + carry_bits: hint_carry_bits(limb_bits, 5 + 2 * n as u64, n), + } + } + + /// Accumulate `count` of this op's range checks into `rc_map`. 
+ fn add_range_checks(&self, count: usize, limb_bits: u32, rc_map: &mut BTreeMap) { + *rc_map.entry(limb_bits).or_default() += count * self.rc_limb; + *rc_map.entry(self.carry_bits).or_default() += count * self.rc_carry; + } +} + +// --------------------------------------------------------------------------- +// Scalar relation cost +// --------------------------------------------------------------------------- + +/// Witnesses and range checks for scalar relation verification. +fn scalar_relation_cost( + native_field_bits: u32, + scalar_bits: usize, +) -> (usize, BTreeMap) { + let limb_bits = scalar_relation_limb_bits(native_field_bits, scalar_bits); + let n = ceil_div(scalar_bits, limb_bits as usize); + let half_bits = (scalar_bits + 1) / 2; + let half_limbs = ceil_div(half_bits, limb_bits as usize); + let scalar_half_limbs = ceil_div(SCALAR_HALF_BITS, limb_bits as usize); + + // Field op witnesses for 1 add + 1 sub + 1 mul (no inv), always multi-limb + // (N≥2 enforced by scalar_relation_limb_bits) + let field_ops_wit = n * n + 14 * n; // 2 × add/sub(1+4N) + 1 × mul(N²+6N-2) + + let has_cross = n > 1 && SCALAR_HALF_BITS % limb_bits as usize != 0; + let witnesses = 2 * scalar_half_limbs // s1, s2 digit decomposition + + has_cross as usize // cross-limb carry + + 2 * n // sign-extended recomposition + + field_ops_wit // add + sub + mul + + 2 // neg1, neg2 flag constants + + n // constrain_to_constant limbs + + 3; // compute_is_zero(s2): inv + product + is_zero + + // Only n limbs worth of scalar DD digits get range checks; unused digits + // are zero-constrained instead (soundness fix for small curves). 
+ let scalar_dd_rcs = n.min(2 * scalar_half_limbs); + let mut rc_map = BTreeMap::new(); + *rc_map.entry(limb_bits).or_default() += scalar_dd_rcs + 2 * half_limbs; + + // Field op range checks for 1 add + 1 sub + 1 mul (always multi-limb) + // add/sub: 2N each (×2 ops), mul: 3N + *rc_map.entry(limb_bits).or_default() += 7 * n; + let carry_bits = limb_bits + ceil_log2(n as u64) + 2; + *rc_map.entry(carry_bits).or_default() += 2 * n - 2; // mul carry chain + + (witnesses, rc_map) +} + +// --------------------------------------------------------------------------- +// MSM cost entry point +// --------------------------------------------------------------------------- + +/// Total estimated witness cost for an MSM. +pub fn calculate_msm_witness_cost( + native_field_bits: u32, + curve_modulus_bits: u32, + n_points: usize, + scalar_bits: usize, + window_size: usize, + limb_bits: u32, + is_native: bool, +) -> usize { + if is_native { + return calculate_msm_witness_cost_native(native_field_bits, n_points, scalar_bits); + } + + // --- Derived parameters --- + + let n = ceil_div(curve_modulus_bits as usize, limb_bits as usize); + assert!( + n >= 2, + "non-native MSM requires num_limbs >= 2, got {n} (limb_bits={limb_bits}, \ + curve_modulus_bits={curve_modulus_bits})" + ); + let half_bits = (scalar_bits + 1) / 2; + let w = window_size; + let half_table_size = 1usize << (w - 1); + let num_windows = ceil_div(half_bits, w); + + // --- Atomic op costs --- + + let ec_double = HintVerifiedEcCost::point_double(n, limb_bits); + let ec_add = HintVerifiedEcCost::point_add(n, limb_bits); + let ec_oncurve = HintVerifiedEcCost::on_curve(n, limb_bits); + let (sr_witnesses, sr_range_checks) = scalar_relation_cost(native_field_bits, scalar_bits); + let (tbl_d, tbl_a) = table_build_ops(half_table_size); + + let negate_wit = NEGATE_WIT_PER_LIMB * n; + // negate N y-limbs + N select_witness to pick y vs -y + let negate_and_select = negate_wit + n; + + // --- Per-point witnesses (grouped by 
pipeline phase) --- + + // Phase 1: preprocessing — sanitize, decompose, on-curve, y preparation + let preprocess = 2 * (half_bits + 1) // scalar bit decomposition + + DETECT_SKIP_WIT // degenerate-case detection + + (2 * n + 2) // sanitize selects (px, py, s_lo, s_hi) + + 4 * n // ec hint (2N outputs + 2N selects) + + GLV_HINT_WIT // FakeGLV hint + + 4 * n // point decomposition (always N≥2 here) + + 2 * ec_oncurve.witnesses // on-curve checks for P and R + + 2 * negate_and_select; // y pre-negate per half-scalar + + // Phase 2: table building — construct [P, 3P, 5P, ...] + mux selects + let table = 2 * (tbl_d * ec_double.witnesses + tbl_a * ec_add.witnesses) + + num_windows * 2 * half_table_size.saturating_sub(1) * 2 * n; + + // Phase 3: EC loop — doublings shared across windows, per-window adds + let doublings = num_windows * w * ec_double.witnesses; + let loop_body = + num_windows * 2 * (ec_add.witnesses + negate_and_select + 2 * w.saturating_sub(1)); + // per window × 2 half-scalars × (add + negate+select + XOR bits) + + // Phase 4: skew correction — per half-scalar: add + negate + point_select + let skew = 2 * (ec_add.witnesses + negate_wit + 2 * n); + + let per_point = preprocess + table + doublings + loop_body + skew + sr_witnesses; + + // --- Accumulation witnesses --- + + let shared_constants = 3 + 2 * n; // gen_x, gen_y, zero + offset(x,y) limbs + + // Per-point: add to accumulator + point_select(2N) for skip handling + let accum_per_point = ec_add.witnesses + 2 * n; + // Boolean product chain tracking all_skipped + let accum_skip_chain = n_points.saturating_sub(1); + // Offset subtraction: add + gen constants(3N, offset_x reused) + + // mask selects(2N) + init(2) + flags(2) + let accum_offset = ec_add.witnesses + 3 * n + 2 * n + 2 + 2; + let accum = n_points * accum_per_point + accum_skip_chain + accum_offset; + + // --- Range checks --- + + let mut rc_map: BTreeMap = BTreeMap::new(); + + // EC op range checks (per-point doubles + table doubles, adds, 
on-curve) + let doubles_count = num_windows * w + 2 * tbl_d; + let adds_count = 2 * tbl_a + num_windows * 2 + 2; + ec_double.add_range_checks(n_points * doubles_count, limb_bits, &mut rc_map); + ec_add.add_range_checks(n_points * adds_count, limb_bits, &mut rc_map); + ec_oncurve.add_range_checks(n_points * 2, limb_bits, &mut rc_map); + + // Accumulation adds + ec_add.add_range_checks(n_points + 1, limb_bits, &mut rc_map); + + // Negates per point: 2 y_eff + 2/window signed_lookup + 2 skew + let negate_count_pp = 2 + num_windows * 2 + 2; + *rc_map.entry(limb_bits).or_default() += n_points * negate_count_pp * NEGATE_RC_PER_LIMB * n; + // Point decomp limb RCs (2N) + scalar mul hint output limb RCs (2N) + *rc_map.entry(limb_bits).or_default() += n_points * 4 * n; + + // Scalar relation range checks + for (&bits, &count) in &sr_range_checks { + *rc_map.entry(bits).or_default() += n_points * count; + } + + let range_check_cost = crate::range_check::estimate_range_check_cost(&rc_map); + shared_constants + n_points * per_point + accum + range_check_cost +} + +// --------------------------------------------------------------------------- +// Native-field cost +// --------------------------------------------------------------------------- + +// Native-field hint-verified EC witness costs. +const NATIVE_DOUBLE: usize = 4; // hint(λ,x3,y3) + x_sq product +const NATIVE_ADD: usize = 3; // hint(λ,x3,y3) +const NATIVE_ON_CURVE: usize = 2; // x_sq + x_cu products +const NATIVE_NEGATE: usize = 1; // linear combination +const NATIVE_SELECT: usize = 1; // select_witness +const NATIVE_POINT_SELECT: usize = 2; // x + y select_witness + +/// Native-field MSM cost. 
+fn calculate_msm_witness_cost_native( + native_field_bits: u32, + n_points: usize, + scalar_bits: usize, +) -> usize { + let half_bits = (scalar_bits + 1) / 2; + let (sr_wit, sr_rc) = scalar_relation_cost(native_field_bits, scalar_bits); + + // Per-bit: double + 2 × (negate y_eff + signed select + add) + let per_bit = NATIVE_DOUBLE + 2 * (NATIVE_NEGATE + NATIVE_SELECT + NATIVE_ADD); + // Per half-scalar skew: negate + add + point_select + let per_skew = NATIVE_NEGATE + NATIVE_ADD + NATIVE_POINT_SELECT; + + let per_point = 2 * NATIVE_ON_CURVE // on-curve checks (2 points) + + 2 * (NATIVE_NEGATE + 2 * NATIVE_SELECT) // y pre-negate per half-scalar + + 2 * (half_bits + 1) // scalar bit decomposition + + DETECT_SKIP_WIT + 4 + 4 + GLV_HINT_WIT // sanitize + ec_hint + glv + + sr_wit + + half_bits * per_bit + + 2 * per_skew; + + let shared_constants = 3 + 2; // gen_x, gen_y, zero + offset(x,y) + + let accum_per_point = NATIVE_ADD + NATIVE_POINT_SELECT; + let accum = n_points * accum_per_point + + n_points.saturating_sub(1) // all_skipped products + + NATIVE_ADD + 2 + 2 + 2; // offset sub: add + 2 const + 2 sel + 2 mask + + // Range checks (only from scalar relation for native) + let mut rc_map: BTreeMap = BTreeMap::new(); + for (&bits, &count) in &sr_rc { + *rc_map.entry(bits).or_default() += n_points * count; + } + let range_check_cost = crate::range_check::estimate_range_check_cost(&rc_map); + + n_points * per_point + shared_constants + accum + range_check_cost +} + +// --------------------------------------------------------------------------- +// Parameter search +// --------------------------------------------------------------------------- + +/// Picks the widest limb size for scalar-relation arithmetic that fits the +/// native field. 
+pub(super) fn scalar_relation_limb_bits(native_field_bits: u32, order_bits: usize) -> u32 { + for n in 2..=super::MAX_LIMBS { + let lb = ((order_bits + n - 1) / n) as u32; + if column_equation_fits_native_field(native_field_bits, lb, n) { + return lb; + } + } + + panic!("native field too small for scalar relation verification"); +} + +/// Check whether schoolbook column equation values fit in the native field. +pub fn column_equation_fits_native_field( + native_field_bits: u32, + limb_bits: u32, + num_limbs: usize, +) -> bool { + if num_limbs <= 1 { + return true; + } + let max_coeff_sum = worst_case_ec_max_coeff(num_limbs); + column_equation_max_bits(limb_bits, max_coeff_sum, num_limbs) < native_field_bits +} + +/// Search for optimal `(limb_bits, window_size, num_limbs)` minimizing +/// witness cost. +pub fn get_optimal_msm_params( + native_field_bits: u32, + curve_modulus_bits: u32, + n_points: usize, + scalar_bits: usize, + is_native: bool, +) -> (u32, usize, usize) { + if is_native { + return (native_field_bits, 1, 1); + } + + let max_limb_bits = (native_field_bits.saturating_sub(4)) / 2; + let mut best_cost = usize::MAX; + let mut best_limb_bits = max_limb_bits.min(86); + let mut best_window = 4; + let mut best_num_limbs = ceil_div(curve_modulus_bits as usize, best_limb_bits as usize); + + for lb in 8..=max_limb_bits { + let num_limbs = ceil_div(curve_modulus_bits as usize, lb as usize); + // Non-native path requires num_limbs >= 2 (hint-verified EC ops) + if num_limbs < 2 { + continue; + } + if !column_equation_fits_native_field(native_field_bits, lb, num_limbs) { + continue; + } + for ws in 2..=8usize { + let cost = calculate_msm_witness_cost( + native_field_bits, + curve_modulus_bits, + n_points, + scalar_bits, + ws, + lb, + false, + ); + if cost < best_cost { + best_cost = cost; + best_limb_bits = lb; + best_window = ws; + best_num_limbs = num_limbs; + } + } + } + + (best_limb_bits, best_window, best_num_limbs) +} + +#[cfg(test)] +mod tests { + use 
super::*; + + #[test] + fn test_optimal_params_bn254_native() { + let (limb_bits, window_size, num_limbs) = get_optimal_msm_params(254, 254, 1, 256, true); + assert_eq!(limb_bits, 254); + assert_eq!(window_size, 1, "native path uses signed-bit wNAF (w=1)"); + assert_eq!(num_limbs, 1, "native path uses 1 limb"); + } + + #[test] + fn test_optimal_params_secp256r1() { + let (limb_bits, window_size, num_limbs) = get_optimal_msm_params(254, 256, 1, 256, false); + assert!( + column_equation_fits_native_field(254, limb_bits, num_limbs), + "optimizer selected unsound limb_bits={limb_bits} (N={num_limbs})" + ); + assert!(window_size >= 2 && window_size <= 8); + } + + #[test] + fn test_optimal_params_goldilocks() { + let (limb_bits, window_size, num_limbs) = get_optimal_msm_params(254, 64, 1, 64, false); + assert!( + num_limbs >= 2, + "non-native path requires num_limbs >= 2, got {num_limbs}" + ); + assert!( + column_equation_fits_native_field(254, limb_bits, num_limbs), + "optimizer selected unsound limb_bits={limb_bits} (N={num_limbs})" + ); + assert!(window_size >= 2 && window_size <= 8); + } + + #[test] + fn test_column_equation_soundness_boundary() { + // With N=3, max_coeff_sum = 6+2*3 = 12, extra_bits = ceil(log2(36))+1 = 7 + // Need: 2*W + 7 + 1 < 254 → W < 123 + assert!(column_equation_fits_native_field(254, 122, 3)); + assert!(!column_equation_fits_native_field(254, 123, 3)); + assert!(!column_equation_fits_native_field(254, 124, 3)); + } + + #[test] + fn test_secp256r1_limb_bits_not_unsound() { + let (limb_bits, _, num_limbs) = get_optimal_msm_params(254, 256, 1, 256, false); + assert!( + column_equation_fits_native_field(254, limb_bits, num_limbs), + "secp256r1 limb_bits={limb_bits} (N={num_limbs}) doesn't fit native field" + ); + } + + #[test] + fn test_scalar_relation_cost_grumpkin() { + let (sr, rc) = scalar_relation_cost(254, 256); + assert_eq!(sr, 70, "grumpkin scalar_relation witnesses changed: {sr}"); + let total_rc: usize = rc.values().sum(); + 
assert_eq!( + total_rc, 32, + "grumpkin scalar_relation range checks changed: {total_rc}" + ); + } + + #[test] + fn test_scalar_relation_cost_small_curve() { + let (sr, _) = scalar_relation_cost(254, 64); + assert_eq!( + sr, 51, + "64-bit curve scalar_relation witnesses changed: {sr}" + ); + } + + #[test] + fn test_estimate_range_check_cost_basic() { + use crate::range_check::estimate_range_check_cost; + + assert_eq!(estimate_range_check_cost(&BTreeMap::new()), 0); + + let mut checks = BTreeMap::new(); + checks.insert(8u32, 100usize); + let cost = estimate_range_check_cost(&checks); + assert!(cost > 0, "expected nonzero cost for 100 8-bit checks"); + } +} diff --git a/provekit/r1cs-compiler/src/msm/curve/grumpkin.rs b/provekit/r1cs-compiler/src/msm/curve/grumpkin.rs new file mode 100644 index 000000000..2a7f10222 --- /dev/null +++ b/provekit/r1cs-compiler/src/msm/curve/grumpkin.rs @@ -0,0 +1,67 @@ +use super::Curve; + +/// Grumpkin: BN254 cycle-companion curve. +/// Base field = BN254 scalar field, order = BN254 base field order. +/// Equation: y² = x³ − 17 +pub struct Grumpkin; + +impl Curve for Grumpkin { + fn field_modulus_p(&self) -> [u64; 4] { + [ + 0x43e1f593f0000001, + 0x2833e84879b97091, + 0xb85045b68181585d, + 0x30644e72e131a029, + ] + } + fn curve_order_n(&self) -> [u64; 4] { + [ + 0x3c208c16d87cfd47, + 0x97816a916871ca8d, + 0xb85045b68181585d, + 0x30644e72e131a029, + ] + } + fn curve_a(&self) -> [u64; 4] { + [0; 4] + } + fn curve_b(&self) -> [u64; 4] { + [ + 0x43e1f593effffff0, + 0x2833e84879b97091, + 0xb85045b68181585d, + 0x30644e72e131a029, + ] + } + fn generator(&self) -> ([u64; 4], [u64; 4]) { + ([1, 0, 0, 0], [ + 0x833fc48d823f272c, + 0x2d270d45f1181294, + 0xcf135e7506a45d63, + 0x0000000000000002, + ]) + } + /// Offset point for accumulation blinding. + /// + /// NUMS (nothing-up-my-sleeve) construction: + /// `x = SHA256("provekit-grumpkin-offset")` interpreted as big-endian + /// integer mod p, incremented until y² = x³ + b has a square root. 
+ /// Canonical (smaller) y is chosen. Reproducible via + /// `scripts/verify_offset_points.py`. + fn offset_point(&self) -> ([u64; 4], [u64; 4]) { + ( + [ + 0x0c7f59b08d3ed494, + 0xc9c7cc25211e2d7a, + 0x39c65342a2e5e9f2, + 0x121b63f644122c3d, + ], + [ + 0xdbecdeb7a68f782d, + 0x10f1f9045c0bc912, + 0x1cd40a11a67012e1, + 0x00767fcc149fc6b3, + ], + ) + } +} diff --git a/provekit/r1cs-compiler/src/msm/curve/mod.rs b/provekit/r1cs-compiler/src/msm/curve/mod.rs new file mode 100644 index 000000000..e2d258012 --- /dev/null +++ b/provekit/r1cs-compiler/src/msm/curve/mod.rs @@ -0,0 +1,323 @@ +use { + ark_ff::{AdditiveGroup, PrimeField}, + provekit_common::{ + u256_arith::{mod_add, mod_inv, mod_mul, mod_sub}, + FieldElement, + }, +}; + +mod grumpkin; +mod secp256r1; + +pub use {grumpkin::Grumpkin, provekit_common::u256_arith::U256, secp256r1::Secp256r1}; + +// --------------------------------------------------------------------------- +// Curve trait — the only thing a new curve needs to implement +// --------------------------------------------------------------------------- + +/// Elliptic curve definition for MSM circuit compilation. +/// +/// Each supported curve is a zero-sized struct implementing this trait. +/// Only the 6 required methods (curve constants) must be provided; +/// all derived properties and decomposition helpers have default +/// implementations. +pub trait Curve { + // ===== Required: curve constants ===== + + /// Base field modulus p as 4 × u64 limbs (256-bit, little-endian). + fn field_modulus_p(&self) -> U256; + /// Scalar field order n as 4 × u64 limbs. + fn curve_order_n(&self) -> U256; + /// Weierstrass curve parameter a. + fn curve_a(&self) -> U256; + /// Weierstrass curve parameter b. + fn curve_b(&self) -> U256; + /// Generator point (x, y). + fn generator(&self) -> (U256, U256); + /// Offset point for accumulation (x, y). 
+ fn offset_point(&self) -> (U256, U256); + + // ===== Provided: derived properties ===== + + /// Number of bits in the field modulus. + fn modulus_bits(&self) -> u32 { + bit_length_u256(&self.field_modulus_p()) + } + + /// Returns true if the curve's base field equals the native field + /// (currently BN254 scalar field, but determined dynamically from + /// `FieldElement::MODULUS`). + fn is_native_field(&self) -> bool { + self.field_modulus_p() == FieldElement::MODULUS.0 + } + + /// Number of bits in the curve order n. + fn curve_order_bits(&self) -> u32 { + bit_length_u256(&self.curve_order_n()) + } + + /// Number of bits for the GLV half-scalar: `ceil(order_bits / 2)`. + fn glv_half_bits(&self) -> u32 { + (self.curve_order_bits() + 1) / 2 + } + + /// Convert modulus to a native field element. + fn p_native_fe(&self) -> FieldElement { + curve_native_point_fe(&self.field_modulus_p()) + } + + // ===== Provided: limb decomposition helpers ===== + + fn p_limbs(&self, limb_bits: u32, num_limbs: usize) -> Vec { + decompose_to_limbs(&self.field_modulus_p(), limb_bits, num_limbs) + } + fn p_minus_1_limbs(&self, limb_bits: u32, num_limbs: usize) -> Vec { + decompose_to_limbs( + &sub_one_u64_4(&self.field_modulus_p()), + limb_bits, + num_limbs, + ) + } + fn curve_a_limbs(&self, limb_bits: u32, num_limbs: usize) -> Vec { + decompose_to_limbs(&self.curve_a(), limb_bits, num_limbs) + } + fn curve_b_limbs(&self, limb_bits: u32, num_limbs: usize) -> Vec { + decompose_to_limbs(&self.curve_b(), limb_bits, num_limbs) + } + fn curve_order_n_limbs(&self, limb_bits: u32, num_limbs: usize) -> Vec { + decompose_to_limbs(&self.curve_order_n(), limb_bits, num_limbs) + } + fn curve_order_n_minus_1_limbs(&self, limb_bits: u32, num_limbs: usize) -> Vec { + decompose_to_limbs(&sub_one_u64_4(&self.curve_order_n()), limb_bits, num_limbs) + } + fn generator_x_limbs(&self, limb_bits: u32, num_limbs: usize) -> Vec { + decompose_to_limbs(&self.generator().0, limb_bits, num_limbs) + } + fn 
offset_x_limbs(&self, limb_bits: u32, num_limbs: usize) -> Vec { + decompose_to_limbs(&self.offset_point().0, limb_bits, num_limbs) + } + fn offset_y_limbs(&self, limb_bits: u32, num_limbs: usize) -> Vec { + decompose_to_limbs(&self.offset_point().1, limb_bits, num_limbs) + } + + /// Compute `[2^n_doublings] * offset_point` on the curve (compile-time + /// only). + fn accumulated_offset(&self, n_doublings: usize) -> (U256, U256) { + let p = self.field_modulus_p(); + let a = self.curve_a(); + let mut x = self.offset_point().0; + let mut y = self.offset_point().1; + for _ in 0..n_doublings { + let (x3, y3) = ec_point_double(&x, &y, &a, &p); + x = x3; + y = y3; + } + (x, y) + } +} + +/// Compute bit length of a 256-bit value. +fn bit_length_u256(val: &U256) -> u32 { + for i in (0..4).rev() { + if val[i] != 0 { + return (i as u32) * 64 + (64 - val[i].leading_zeros()); + } + } + 0 +} + +// --------------------------------------------------------------------------- +// Free functions +// --------------------------------------------------------------------------- + +/// Decompose a 256-bit value into `num_limbs` limbs of `limb_bits` width each, +/// returned as FieldElements. +pub fn decompose_to_limbs(val: &U256, limb_bits: u32, num_limbs: usize) -> Vec { + assert!(limb_bits > 0, "limb_bits must be positive"); + // Special case: when a single limb needs > 128 bits, FieldElement::from(u128) + // would truncate. Use from_sign_and_limbs to preserve the full value. 
+ if num_limbs == 1 && limb_bits > 128 { + return vec![curve_native_point_fe(val)]; + } + + let mask: u128 = if limb_bits >= 128 { + u128::MAX + } else { + (1u128 << limb_bits) - 1 + }; + let mut result = vec![FieldElement::ZERO; num_limbs]; + let mut remaining = *val; + for item in result.iter_mut() { + let lo = remaining[0] as u128 | ((remaining[1] as u128) << 64); + *item = FieldElement::from(lo & mask); + // Shift remaining right by limb_bits + if limb_bits >= 256 { + remaining = [0; 4]; + } else { + let mut shifted = [0u64; 4]; + let word_shift = (limb_bits / 64) as usize; + let bit_shift = limb_bits % 64; + for i in 0..4 { + if i + word_shift < 4 { + shifted[i] = remaining[i + word_shift] >> bit_shift; + if bit_shift > 0 && i + word_shift + 1 < 4 { + shifted[i] |= remaining[i + word_shift + 1] << (64 - bit_shift); + } + } + } + remaining = shifted; + } + } + result +} + +/// Subtract 1 from a U256 value. +fn sub_one_u64_4(val: &U256) -> U256 { + let mut result = *val; + for limb in result.iter_mut() { + if *limb > 0 { + *limb -= 1; + return result; + } + *limb = u64::MAX; // borrow + } + result +} + +/// Converts a 256-bit value ([u64; 4]) into a single native field element. +pub fn curve_native_point_fe(val: &U256) -> FieldElement { + FieldElement::from_sign_and_limbs(true, val) +} + +/// Negate a field element: compute `-val mod p` (i.e., `p - val`). +/// Returns `[0; 4]` when `val` is zero. +pub fn negate_field_element(val: &U256, modulus: &U256) -> U256 { + if *val == [0u64; 4] { + return [0u64; 4]; + } + // val is in [1, p-1], so p - val is in [1, p-1] — no borrow. + let mut result = [0u64; 4]; + let mut borrow = false; + for i in 0..4 { + let (d1, b1) = modulus[i].overflowing_sub(val[i]); + let (d2, b2) = d1.overflowing_sub(borrow as u64); + result[i] = d2; + borrow = b1 || b2; + } + debug_assert!(!borrow, "negate_field_element: val >= modulus"); + result +} + +/// EC point doubling on y² = x³ + ax + b (compile-time precomputation only). 
+fn ec_point_double(x: &U256, y: &U256, a: &U256, p: &U256) -> (U256, U256) { + // lambda = (3*x^2 + a) / (2*y) + let x_sq = mod_mul(x, x, p); + let two_x_sq = mod_add(&x_sq, &x_sq, p); + let three_x_sq = mod_add(&two_x_sq, &x_sq, p); + let num = mod_add(&three_x_sq, a, p); + let two_y = mod_add(y, y, p); + let denom_inv = mod_inv(&two_y, p); + let lambda = mod_mul(&num, &denom_inv, p); + + // x3 = lambda^2 - 2*x + let lambda_sq = mod_mul(&lambda, &lambda, p); + let two_x = mod_add(x, x, p); + let x3 = mod_sub(&lambda_sq, &two_x, p); + + // y3 = lambda * (x - x3) - y + let x_minus_x3 = mod_sub(x, &x3, p); + let lambda_dx = mod_mul(&lambda, &x_minus_x3, p); + let y3 = mod_sub(&lambda_dx, y, p); + + (x3, y3) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_offset_point_on_curve_grumpkin() { + let c = Grumpkin; + let x = curve_native_point_fe(&c.offset_point().0); + let y = curve_native_point_fe(&c.offset_point().1); + let b = curve_native_point_fe(&c.curve_b()); + // Grumpkin: y^2 = x^3 + b (a=0) + assert_eq!(y * y, x * x * x + b, "offset point not on Grumpkin"); + } + + #[test] + fn test_accumulated_offset_single_double_grumpkin() { + let c = Grumpkin; + let (x4, y4) = c.accumulated_offset(1); + let x = curve_native_point_fe(&x4); + let y = curve_native_point_fe(&y4); + let b = curve_native_point_fe(&c.curve_b()); + // Should still be on curve + assert_eq!(y * y, x * x * x + b, "[2]*offset not on Grumpkin"); + } + + #[test] + fn test_accumulated_offset_256_on_curve() { + let c = Grumpkin; + let (x, y) = c.accumulated_offset(256); + let xfe = curve_native_point_fe(&x); + let yfe = curve_native_point_fe(&y); + let b = curve_native_point_fe(&c.curve_b()); + assert_eq!(yfe * yfe, xfe * xfe * xfe + b, "[2^257]G not on Grumpkin"); + } + + #[test] + fn test_offset_point_on_curve_secp256r1() { + let c = Secp256r1; + let p = &c.field_modulus_p(); + let x = &c.offset_point().0; + let y = &c.offset_point().1; + let a = &c.curve_a(); + let b = &c.curve_b(); 
+ // y^2 = x^3 + a*x + b (mod p) + let y_sq = mod_mul(y, y, p); + let x_sq = mod_mul(x, x, p); + let x_cubed = mod_mul(&x_sq, x, p); + let ax = mod_mul(a, x, p); + let x3_plus_ax = mod_add(&x_cubed, &ax, p); + let rhs = mod_add(&x3_plus_ax, b, p); + assert_eq!(y_sq, rhs, "offset point not on secp256r1"); + } + + #[test] + fn test_accumulated_offset_secp256r1() { + let c = Secp256r1; + let p = &c.field_modulus_p(); + let a = &c.curve_a(); + let b = &c.curve_b(); + let (x, y) = c.accumulated_offset(256); + // Verify the accumulated offset is on the curve + let y_sq = mod_mul(&y, &y, p); + let x_sq = mod_mul(&x, &x, p); + let x_cubed = mod_mul(&x_sq, &x, p); + let ax = mod_mul(a, &x, p); + let x3_plus_ax = mod_add(&x_cubed, &ax, p); + let rhs = mod_add(&x3_plus_ax, b, p); + assert_eq!(y_sq, rhs, "accumulated offset not on secp256r1"); + } + + #[test] + fn test_fe_roundtrip() { + // Verify from_sign_and_limbs / into_bigint roundtrip + let val: [u64; 4] = [42, 0, 0, 0]; + let fe = curve_native_point_fe(&val); + let back = fe.into_bigint().0; + assert_eq!(val, back, "roundtrip failed for small value"); + + let val2: [u64; 4] = [ + 0x6d8bc688cdbffffe, + 0x19a74caa311e13d4, + 0xddeb49cdaa36306d, + 0x06ce1b0827aafa85, + ]; + let fe2 = curve_native_point_fe(&val2); + let back2 = fe2.into_bigint().0; + assert_eq!(val2, back2, "roundtrip failed for offset x"); + } +} diff --git a/provekit/r1cs-compiler/src/msm/curve/secp256r1.rs b/provekit/r1cs-compiler/src/msm/curve/secp256r1.rs new file mode 100644 index 000000000..62bd565a8 --- /dev/null +++ b/provekit/r1cs-compiler/src/msm/curve/secp256r1.rs @@ -0,0 +1,74 @@ +use super::Curve; + +/// SECP256R1 (NIST P-256). 
+/// Equation: y² = x³ + ax + b +pub struct Secp256r1; + +impl Curve for Secp256r1 { + fn field_modulus_p(&self) -> [u64; 4] { + [0xffffffffffffffff, 0xffffffff, 0x0, 0xffffffff00000001] + } + fn curve_order_n(&self) -> [u64; 4] { + [ + 0xf3b9cac2fc632551, + 0xbce6faada7179e84, + 0xffffffffffffffff, + 0xffffffff00000000, + ] + } + fn curve_a(&self) -> [u64; 4] { + [ + 0xfffffffffffffffc, + 0x00000000ffffffff, + 0x0000000000000000, + 0xffffffff00000001, + ] + } + fn curve_b(&self) -> [u64; 4] { + [ + 0x3bce3c3e27d2604b, + 0x651d06b0cc53b0f6, + 0xb3ebbd55769886bc, + 0x5ac635d8aa3a93e7, + ] + } + fn generator(&self) -> ([u64; 4], [u64; 4]) { + ( + [ + 0xf4a13945d898c296, + 0x77037d812deb33a0, + 0xf8bce6e563a440f2, + 0x6b17d1f2e12c4247, + ], + [ + 0xcbb6406837bf51f5, + 0x2bce33576b315ece, + 0x8ee7eb4a7c0f9e16, + 0x4fe342e2fe1a7f9b, + ], + ) + } + /// Offset point for accumulation blinding. + /// + /// NUMS (nothing-up-my-sleeve) construction: + /// `x = SHA256("provekit-secp256r1-offset")` interpreted as big-endian + /// integer mod p, incremented until y² = x³ + ax + b has a square root. + /// Canonical (smaller) y is chosen. Reproducible via + /// `scripts/verify_offset_points.py`. + fn offset_point(&self) -> ([u64; 4], [u64; 4]) { + ( + [ + 0x3b8d6e63154ac0b8, + 0x9d50c8f4c290feb5, + 0x27080c391ced0ac0, + 0x24d812942f1c942a, + ], + [ + 0x1d028e001bc65cb8, + 0xc4cb905df8bd1f90, + 0x9f519d447e4a2d9d, + 0x7c9e0b6ce248a7a0, + ], + ) + } +} diff --git a/provekit/r1cs-compiler/src/msm/ec_points/hints_native.rs b/provekit/r1cs-compiler/src/msm/ec_points/hints_native.rs new file mode 100644 index 000000000..9162f6395 --- /dev/null +++ b/provekit/r1cs-compiler/src/msm/ec_points/hints_native.rs @@ -0,0 +1,138 @@ +//! Native-field hint-verified EC operations. 
use {
    crate::{msm::multi_limb_ops::EcFieldParams, noir_to_r1cs::NoirToR1CSCompiler},
    ark_ff::{Field, PrimeField},
    provekit_common::{witness::WitnessBuilder, FieldElement},
};

/// Hint-verified point doubling for native field.
///
/// The prover supplies `(lambda, x3, y3)` as hint witnesses; three R1CS
/// constraints then verify the affine short-Weierstrass doubling formulas:
///
///   lambda * (2*py)    = 3*px^2 + a     (tangent slope)
///   lambda^2           = x3 + 2*px      (output x)
///   lambda * (px - x3) = y3 + py        (output y)
///
/// Returns the witness indices of `(x3, y3)`.
#[must_use]
pub fn point_double_verified_native(
    compiler: &mut NoirToR1CSCompiler,
    px: usize,
    py: usize,
    params: &EcFieldParams,
) -> (usize, usize) {
    // Allocate hint witnesses: layout is [lambda, x3, y3] at hint_start.
    let hint_start = compiler.num_witnesses();
    compiler.add_witness_builder(WitnessBuilder::EcDoubleHint {
        output_start: hint_start,
        px,
        py,
        curve_a: params.ec.curve_a_raw,
        field_modulus_p: params.modulus_raw,
    });
    let lambda = hint_start;
    let x3 = hint_start + 1;
    let y3 = hint_start + 2;

    // x_sq = px * px (1W + 1C)
    let x_sq = compiler.add_product(px, px);

    // Constraint: lambda * (2 * py) = 3 * x_sq + a
    // A = [lambda], B = [2*py], C = [3*x_sq + a_const]
    // NOTE(review): if py == 0 (an order-2 point) this leaves lambda
    // unconstrained — presumably unreachable for the supported curves; confirm.
    let a_fe = FieldElement::from_bigint(ark_ff::BigInt(params.ec.curve_a_raw))
        .expect("curve_a must fit in native field");
    let three = FieldElement::from(3u64);
    let two = FieldElement::from(2u64);
    compiler
        .r1cs
        .add_constraint(&[(FieldElement::ONE, lambda)], &[(two, py)], &[
            (three, x_sq),
            (a_fe, compiler.witness_one()),
        ]);

    // Constraint: lambda^2 = x3 + 2*px
    compiler.r1cs.add_constraint(
        &[(FieldElement::ONE, lambda)],
        &[(FieldElement::ONE, lambda)],
        &[(FieldElement::ONE, x3), (two, px)],
    );

    // Constraint: lambda * (px - x3) = y3 + py
    compiler.r1cs.add_constraint(
        &[(FieldElement::ONE, lambda)],
        &[(FieldElement::ONE, px), (-FieldElement::ONE, x3)],
        &[(FieldElement::ONE, y3), (FieldElement::ONE, py)],
    );

    (x3, y3)
}

/// Hint-verified point addition for native field.
#[must_use]
/// The prover supplies `(lambda, x3, y3)` as hint witnesses; three R1CS
/// constraints verify the affine chord-addition formulas:
///
///   lambda * (x2 - x1) = y2 - y1        (chord slope)
///   lambda^2           = x3 + x1 + x2   (output x)
///   lambda * (x1 - x3) = y3 + y1        (output y)
///
/// Returns the witness indices of `(x3, y3)`.
///
/// NOTE(review): when x1 == x2 the first constraint degenerates to
/// 0 = y2 - y1 and lambda becomes unconstrained — callers are expected to
/// guarantee P1 != ±P2; confirm at call sites.
pub fn point_add_verified_native(
    compiler: &mut NoirToR1CSCompiler,
    x1: usize,
    y1: usize,
    x2: usize,
    y2: usize,
    params: &EcFieldParams,
) -> (usize, usize) {
    // Allocate hint witnesses: layout is [lambda, x3, y3] at hint_start.
    let hint_start = compiler.num_witnesses();
    compiler.add_witness_builder(WitnessBuilder::EcAddHint {
        output_start: hint_start,
        x1,
        y1,
        x2,
        y2,
        field_modulus_p: params.modulus_raw,
    });
    let lambda = hint_start;
    let x3 = hint_start + 1;
    let y3 = hint_start + 2;

    // Constraint: lambda * (x2 - x1) = y2 - y1
    compiler.r1cs.add_constraint(
        &[(FieldElement::ONE, lambda)],
        &[(FieldElement::ONE, x2), (-FieldElement::ONE, x1)],
        &[(FieldElement::ONE, y2), (-FieldElement::ONE, y1)],
    );

    // Constraint: lambda^2 = x3 + x1 + x2
    compiler.r1cs.add_constraint(
        &[(FieldElement::ONE, lambda)],
        &[(FieldElement::ONE, lambda)],
        &[
            (FieldElement::ONE, x3),
            (FieldElement::ONE, x1),
            (FieldElement::ONE, x2),
        ],
    );

    // Constraint: lambda * (x1 - x3) = y3 + y1
    compiler.r1cs.add_constraint(
        &[(FieldElement::ONE, lambda)],
        &[(FieldElement::ONE, x1), (-FieldElement::ONE, x3)],
        &[(FieldElement::ONE, y3), (FieldElement::ONE, y1)],
    );

    (x3, y3)
}

/// On-curve check for native field: y² = x³ + a·x + b.
+pub fn verify_on_curve_native( + compiler: &mut NoirToR1CSCompiler, + x: usize, + y: usize, + params: &EcFieldParams, +) { + let x_sq = compiler.add_product(x, x); + let x_cu = compiler.add_product(x_sq, x); + + let a_fe = FieldElement::from_bigint(ark_ff::BigInt(params.ec.curve_a_raw)) + .expect("curve_a must fit in native field"); + let b_fe = FieldElement::from_bigint(ark_ff::BigInt(params.ec.curve_b_raw)) + .expect("curve_b must fit in native field"); + + // y * y = x_cu + a*x + b + compiler + .r1cs + .add_constraint(&[(FieldElement::ONE, y)], &[(FieldElement::ONE, y)], &[ + (FieldElement::ONE, x_cu), + (a_fe, x), + (b_fe, compiler.witness_one()), + ]); +} diff --git a/provekit/r1cs-compiler/src/msm/ec_points/hints_non_native.rs b/provekit/r1cs-compiler/src/msm/ec_points/hints_non_native.rs new file mode 100644 index 000000000..0e3ea31f4 --- /dev/null +++ b/provekit/r1cs-compiler/src/msm/ec_points/hints_non_native.rs @@ -0,0 +1,462 @@ +//! Non-native hint-verified EC operations (multi-limb schoolbook). +//! +//! Each EC op allocates a prover hint and verifies it via schoolbook +//! column equations, avoiding step-by-step field inversions. + +use { + crate::{ + msm::{ + cost_model::{column_equation_max_bits, hint_carry_bits}, + multi_limb_arith::{emit_schoolbook_column_equations, less_than_p_check_multi}, + multi_limb_ops::{allocate_pinned_constant_limbs, EcFieldParams}, + Limbs, + }, + noir_to_r1cs::NoirToR1CSCompiler, + }, + ark_ff::{Field, PrimeField}, + provekit_common::{ + witness::{NonNativeEcOp, WitnessBuilder}, + FieldElement, + }, + std::collections::BTreeMap, +}; + +/// Collect witness indices from `start..start+len`. +fn witness_range(start: usize, len: usize) -> Vec { + (start..start + len).collect() +} + +/// Allocate N×N product witnesses for `a\[i\]*b\[j\]`. 
+fn make_products(compiler: &mut NoirToR1CSCompiler, a: &[usize], b: &[usize]) -> Vec> { + let n = a.len(); + debug_assert_eq!(n, b.len()); + let mut prods = vec![vec![0usize; n]; n]; + for i in 0..n { + for j in 0..n { + prods[i][j] = compiler.add_product(a[i], b[j]); + } + } + prods +} + +/// Range-check limb witnesses at `limb_bits` and carry witnesses at +/// `carry_range_bits`. +fn range_check_limbs_and_carries( + range_checks: &mut BTreeMap>, + limb_vecs: &[&[usize]], + carry_vecs: &[&[usize]], + limb_bits: u32, + carry_range_bits: u32, +) { + for limbs in limb_vecs { + for &w in *limbs { + range_checks.entry(limb_bits).or_default().push(w); + } + } + for carries in carry_vecs { + for &c in *carries { + range_checks.entry(carry_range_bits).or_default().push(c); + } + } +} + +/// Soundness check: verify that merged column equations fit the native field. +fn check_column_equation_fits(limb_bits: u32, max_coeff_sum: u64, n: usize, op_name: &str) { + let max_bits = column_equation_max_bits(limb_bits, max_coeff_sum, n); + assert!( + max_bits < FieldElement::MODULUS_BIT_SIZE, + "{op_name} column equation overflow: limb_bits={limb_bits}, n={n}, needs {max_bits} bits", + ); +} + +// --------------------------------------------------------------------------- +// Parsed hint layout for 3-equation EC ops (point double and add) +// --------------------------------------------------------------------------- + +/// Parsed hint layout: [lambda(N), x3(N), y3(N), +/// q_pos\[0\](N), q_neg\[0\](N), carries\[0\](2N-2), +/// q_pos\[1\](N), q_neg\[1\](N), carries\[1\](2N-2), +/// q_pos\[2\](N), q_neg\[2\](N), carries\[2\](2N-2)] +/// Total: 15N-6 witnesses. 
struct EcHint3Eq {
    // Slope limbs (N witnesses).
    lambda: Vec<usize>,
    // Output x-coordinate limbs (N witnesses).
    x3: Vec<usize>,
    // Output y-coordinate limbs (N witnesses).
    y3: Vec<usize>,
    // Per-equation positive quotient limbs (N each).
    q_pos: [Vec<usize>; 3],
    // Per-equation negative quotient limbs (N each).
    q_neg: [Vec<usize>; 3],
    // Per-equation schoolbook column carries (2N-2 each).
    carries: [Vec<usize>; 3],
}

impl EcHint3Eq {
    /// Parse the fixed hint layout starting at witness index `os` for
    /// `n` limbs. Offsets per equation group: outputs occupy [os, os+3N),
    /// then each of the 3 equations takes N (q_pos) + N (q_neg) + 2N-2
    /// (carries) = 4N-2 witnesses, for a total of 3N + 3*(4N-2) = 15N-6.
    fn parse(os: usize, n: usize) -> Self {
        Self {
            lambda: witness_range(os, n),
            x3: witness_range(os + n, n),
            y3: witness_range(os + 2 * n, n),
            q_pos: [
                witness_range(os + 3 * n, n),
                witness_range(os + 7 * n - 2, n),
                witness_range(os + 11 * n - 4, n),
            ],
            q_neg: [
                witness_range(os + 4 * n, n),
                witness_range(os + 8 * n - 2, n),
                witness_range(os + 12 * n - 4, n),
            ],
            carries: [
                witness_range(os + 5 * n, 2 * n - 2),
                witness_range(os + 9 * n - 2, 2 * n - 2),
                witness_range(os + 13 * n - 4, 2 * n - 2),
            ],
        }
    }

    /// Range-check all hint outputs and verify lambda/x3/y3 < p.
    ///
    /// `max_coeff_sum` is the worst-case coefficient sum across the three
    /// equations; it determines the carry range-check width. Returns the
    /// output point `(x3, y3)` as `Limbs`.
    fn range_check_and_verify(
        &self,
        compiler: &mut NoirToR1CSCompiler,
        range_checks: &mut BTreeMap<u32, Vec<usize>>,
        max_coeff_sum: u64,
        params: &EcFieldParams,
    ) -> (Limbs, Limbs) {
        let n = params.num_limbs;
        // Carry width is derived from the merged-column worst case.
        let crb = hint_carry_bits(params.limb_bits, max_coeff_sum, n);
        range_check_limbs_and_carries(
            range_checks,
            &[
                &self.lambda,
                &self.x3,
                &self.y3,
                &self.q_pos[0],
                &self.q_neg[0],
                &self.q_pos[1],
                &self.q_neg[1],
                &self.q_pos[2],
                &self.q_neg[2],
            ],
            &[&self.carries[0], &self.carries[1], &self.carries[2]],
            params.limb_bits,
            crb,
        );
        // Outputs must be canonical representatives (< p), not just limb-wise
        // in range.
        for v in [&self.lambda, &self.x3, &self.y3] {
            less_than_p_check_multi(compiler, range_checks, Limbs::from(v.as_slice()), params);
        }
        (
            Limbs::from(self.x3.as_slice()),
            Limbs::from(self.y3.as_slice()),
        )
    }
}

/// Hint-verified on-curve check: y² = x³ + ax + b (mod p).
pub fn verify_on_curve_non_native(
    compiler: &mut NoirToR1CSCompiler,
    range_checks: &mut BTreeMap<u32, Vec<usize>>,
    px: Limbs,
    py: Limbs,
    params: &EcFieldParams,
) {
    let n = params.num_limbs;
    assert!(n >= 2, "hint-verified on-curve check requires n >= 2");

    // When a == 0 the a·px product term drops out of Eq2, shrinking the
    // worst-case column coefficient sum by one.
    let a_is_zero = params.ec.curve_a_raw.iter().all(|&v| v == 0);

    // Worst-case coefficient sum used for carry sizing.
    // NOTE(review): this is one larger than the per-equation maxima computed
    // below (3+2N / 4+2N) — presumably deliberate slack; confirm against
    // hint_carry_bits' rounding.
    let max_coeff_sum: u64 = if a_is_zero {
        4 + 2 * n as u64
    } else {
        5 + 2 * n as u64
    };
    check_column_equation_fits(params.limb_bits, max_coeff_sum, n, "On-curve");

    // Allocate hint
    let os = compiler.num_witnesses();
    compiler.add_witness_builder(WitnessBuilder::NonNativeEcHint {
        output_start: os,
        op: NonNativeEcOp::OnCurve,
        inputs: vec![px, py],
        curve_a: params.ec.curve_a_raw,
        curve_b: params.ec.curve_b_raw,
        field_modulus_p: params.modulus_raw,
        limb_bits: params.limb_bits,
        num_limbs: n as u32,
    });

    // Parse hint layout: [x_sq(N), q1_pos(N), q1_neg(N), c1(2N-2),
    // q2_pos(N), q2_neg(N), c2(2N-2)]
    // Total: 9N-4
    let x_sq = witness_range(os, n);
    let q1_pos = witness_range(os + n, n);
    let q1_neg = witness_range(os + 2 * n, n);
    let c1 = witness_range(os + 3 * n, 2 * n - 2);
    let q2_pos = witness_range(os + 5 * n - 2, n);
    let q2_neg = witness_range(os + 6 * n - 2, n);
    let c2 = witness_range(os + 7 * n - 2, 2 * n - 2);

    // Eq1: px·px - x_sq = q1·p  (i.e. x_sq ≡ px² mod p)
    let prod_px_px = make_products(compiler, px.as_slice(), px.as_slice());

    let max_coeff_eq1: u64 = 1 + 1 + 2 * n as u64;
    emit_schoolbook_column_equations(
        compiler,
        &[(&prod_px_px, FieldElement::ONE)],
        &[(&x_sq, -FieldElement::ONE)],
        &q1_pos,
        Some(&q1_neg),
        &c1,
        &params.p_limbs,
        n,
        params.limb_bits,
        max_coeff_eq1,
    );

    // Eq2: py·py - x_sq·px - a·px - b = q2·p  (i.e. py² ≡ px³ + a·px + b mod p)
    let prod_py_py = make_products(compiler, py.as_slice(), py.as_slice());
    let prod_xsq_px = make_products(compiler, &x_sq, px.as_slice());
    // b enters as pinned constant limbs in the linear part.
    let b_limbs = allocate_pinned_constant_limbs(compiler, &params.ec.curve_b_limbs[..n]);

    if a_is_zero {
        let max_coeff_eq2: u64 = 1 + 1 + 1 + 2 * n as u64;
        emit_schoolbook_column_equations(
            compiler,
            &[
                (&prod_py_py, FieldElement::ONE),
                (&prod_xsq_px, -FieldElement::ONE),
            ],
            &[(&b_limbs, -FieldElement::ONE)],
            &q2_pos,
            Some(&q2_neg),
            &c2,
            &params.p_limbs,
            n,
            params.limb_bits,
            max_coeff_eq2,
        );
    } else {
        let a_limbs = allocate_pinned_constant_limbs(compiler, &params.ec.curve_a_limbs[..n]);
        let prod_a_px = make_products(compiler, &a_limbs, px.as_slice());

        let max_coeff_eq2: u64 = 1 + 1 + 1 + 1 + 2 * n as u64;
        emit_schoolbook_column_equations(
            compiler,
            &[
                (&prod_py_py, FieldElement::ONE),
                (&prod_xsq_px, -FieldElement::ONE),
                (&prod_a_px, -FieldElement::ONE),
            ],
            &[(&b_limbs, -FieldElement::ONE)],
            &q2_pos,
            Some(&q2_neg),
            &c2,
            &params.p_limbs,
            n,
            params.limb_bits,
            max_coeff_eq2,
        );
    }

    // Range checks on hint outputs
    let crb = hint_carry_bits(params.limb_bits, max_coeff_sum, n);
    range_check_limbs_and_carries(
        range_checks,
        &[&x_sq, &q1_pos, &q1_neg, &q2_pos, &q2_neg],
        &[&c1, &c2],
        params.limb_bits,
        crb,
    );

    // Less-than-p check for x_sq (the only non-quotient intermediate that
    // must be canonical for Eq2 to be sound).
    less_than_p_check_multi(compiler, range_checks, Limbs::from(x_sq.as_slice()), params);
}

/// Hint-verified point doubling for non-native field (multi-limb).
+pub fn point_double_verified_non_native( + compiler: &mut NoirToR1CSCompiler, + range_checks: &mut BTreeMap>, + px: Limbs, + py: Limbs, + params: &EcFieldParams, +) -> (Limbs, Limbs) { + let n = params.num_limbs; + assert!(n >= 2, "hint-verified non-native requires n >= 2"); + + let max_coeff_sum: u64 = 2 + 3 + 1 + 2 * n as u64; // λy(2) + xx(3) + a(1) + pq_pos(N) + pq_neg(N) + check_column_equation_fits(params.limb_bits, max_coeff_sum, n, "Merged EC double"); + + // Allocate hint + let os = compiler.num_witnesses(); + compiler.add_witness_builder(WitnessBuilder::NonNativeEcHint { + output_start: os, + op: NonNativeEcOp::Double, + inputs: vec![px, py], + curve_a: params.ec.curve_a_raw, + curve_b: [0; 4], // unused for double + field_modulus_p: params.modulus_raw, + limb_bits: params.limb_bits, + num_limbs: n as u32, + }); + + let h = EcHint3Eq::parse(os, n); + let px_s = px.as_slice(); + let py_s = py.as_slice(); + + // Eq1: 2*lambda*py - 3*px*px - a = q1*p + let prod_lam_py = make_products(compiler, &h.lambda, py_s); + let prod_px_px = make_products(compiler, px_s, px_s); + let a_limbs = allocate_pinned_constant_limbs(compiler, ¶ms.ec.curve_a_limbs[..n]); + + emit_schoolbook_column_equations( + compiler, + &[ + (&prod_lam_py, FieldElement::from(2u64)), + (&prod_px_px, -FieldElement::from(3u64)), + ], + &[(&a_limbs, -FieldElement::ONE)], + &h.q_pos[0], + Some(&h.q_neg[0]), + &h.carries[0], + ¶ms.p_limbs, + n, + params.limb_bits, + 2 + 3 + 1 + 2 * n as u64, + ); + + // Eq2: lambda² - x3 - 2*px = q2*p + let prod_lam_lam = make_products(compiler, &h.lambda, &h.lambda); + + emit_schoolbook_column_equations( + compiler, + &[(&prod_lam_lam, FieldElement::ONE)], + &[ + (&h.x3, -FieldElement::ONE), + (px_s, -FieldElement::from(2u64)), + ], + &h.q_pos[1], + Some(&h.q_neg[1]), + &h.carries[1], + ¶ms.p_limbs, + n, + params.limb_bits, + 1 + 1 + 2 + 2 * n as u64, + ); + + // Eq3: lambda*px - lambda*x3 - y3 - py = q3*p + let prod_lam_px = make_products(compiler, &h.lambda, 
px_s); + let prod_lam_x3 = make_products(compiler, &h.lambda, &h.x3); + + emit_schoolbook_column_equations( + compiler, + &[ + (&prod_lam_px, FieldElement::ONE), + (&prod_lam_x3, -FieldElement::ONE), + ], + &[(&h.y3, -FieldElement::ONE), (py_s, -FieldElement::ONE)], + &h.q_pos[2], + Some(&h.q_neg[2]), + &h.carries[2], + ¶ms.p_limbs, + n, + params.limb_bits, + 1 + 1 + 1 + 1 + 2 * n as u64, + ); + + // Worst-case max_coeff across eqs: Eq1 = 6+2N + h.range_check_and_verify(compiler, range_checks, 6 + 2 * n as u64, params) +} + +/// Hint-verified point addition for non-native field (multi-limb). +pub fn point_add_verified_non_native( + compiler: &mut NoirToR1CSCompiler, + range_checks: &mut BTreeMap>, + x1: Limbs, + y1: Limbs, + x2: Limbs, + y2: Limbs, + params: &EcFieldParams, +) -> (Limbs, Limbs) { + let n = params.num_limbs; + assert!(n >= 2, "hint-verified non-native requires n >= 2"); + + let max_coeff: u64 = 1 + 1 + 1 + 1 + 2 * n as u64; // all 3 eqs: 1+1+1+1+2N + check_column_equation_fits(params.limb_bits, max_coeff, n, "EC add"); + + let os = compiler.num_witnesses(); + compiler.add_witness_builder(WitnessBuilder::NonNativeEcHint { + output_start: os, + op: NonNativeEcOp::Add, + inputs: vec![x1, y1, x2, y2], + curve_a: [0; 4], // unused for add + curve_b: [0; 4], // unused for add + field_modulus_p: params.modulus_raw, + limb_bits: params.limb_bits, + num_limbs: n as u32, + }); + + let h = EcHint3Eq::parse(os, n); + let x1_s = x1.as_slice(); + let y1_s = y1.as_slice(); + let x2_s = x2.as_slice(); + let y2_s = y2.as_slice(); + + // Eq1: lambda*x2 - lambda*x1 - y2 + y1 = q1*p + let prod_lam_x2 = make_products(compiler, &h.lambda, x2_s); + let prod_lam_x1 = make_products(compiler, &h.lambda, x1_s); + + emit_schoolbook_column_equations( + compiler, + &[ + (&prod_lam_x2, FieldElement::ONE), + (&prod_lam_x1, -FieldElement::ONE), + ], + &[(y2_s, -FieldElement::ONE), (y1_s, FieldElement::ONE)], + &h.q_pos[0], + Some(&h.q_neg[0]), + &h.carries[0], + ¶ms.p_limbs, + n, + 
params.limb_bits, + max_coeff, + ); + + // Eq2: lambda² - x3 - x1 - x2 = q2*p + let prod_lam_lam = make_products(compiler, &h.lambda, &h.lambda); + + emit_schoolbook_column_equations( + compiler, + &[(&prod_lam_lam, FieldElement::ONE)], + &[ + (&h.x3, -FieldElement::ONE), + (x1_s, -FieldElement::ONE), + (x2_s, -FieldElement::ONE), + ], + &h.q_pos[1], + Some(&h.q_neg[1]), + &h.carries[1], + ¶ms.p_limbs, + n, + params.limb_bits, + max_coeff, + ); + + // Eq3: lambda*x1 - lambda*x3 - y3 - y1 = q3*p + // Reuse prod_lam_x1 from Eq1 + let prod_lam_x3 = make_products(compiler, &h.lambda, &h.x3); + + emit_schoolbook_column_equations( + compiler, + &[ + (&prod_lam_x1, FieldElement::ONE), + (&prod_lam_x3, -FieldElement::ONE), + ], + &[(&h.y3, -FieldElement::ONE), (y1_s, -FieldElement::ONE)], + &h.q_pos[2], + Some(&h.q_neg[2]), + &h.carries[2], + ¶ms.p_limbs, + n, + params.limb_bits, + max_coeff, + ); + + // max_coeff across all 3 eqs = 4+2N + h.range_check_and_verify(compiler, range_checks, 4 + 2 * n as u64, params) +} diff --git a/provekit/r1cs-compiler/src/msm/ec_points/mod.rs b/provekit/r1cs-compiler/src/msm/ec_points/mod.rs new file mode 100644 index 000000000..0b0d92649 --- /dev/null +++ b/provekit/r1cs-compiler/src/msm/ec_points/mod.rs @@ -0,0 +1,158 @@ +//! Elliptic curve point operations for MSM, dispatched via [`EcOps`]. + +mod hints_native; +mod hints_non_native; +mod tables; + +pub(super) use tables::{scalar_mul_merged_glv, MergedGlvPoint}; +use { + super::{ + multi_limb_ops::{EcFieldParams, FieldArith, MultiLimbField, NativeSingleField}, + EcPoint, Limbs, + }, + crate::noir_to_r1cs::NoirToR1CSCompiler, + std::collections::BTreeMap, +}; + +// --------------------------------------------------------------------------- +// EcOps trait — strategy interface for EC point arithmetic +// --------------------------------------------------------------------------- + +/// Strategy for constraining elliptic curve operations in the circuit. 
///
/// Each impl specifies its associated `Field: FieldArith` type, so
/// `MultiLimbOps` gets both field and EC ops
/// without EC types needing to re-implement field arithmetic.
pub trait EcOps {
    /// The field arithmetic strategy paired with this EC strategy.
    type Field: FieldArith;

    /// Point doubling: computes 2P.
    fn point_double(
        compiler: &mut NoirToR1CSCompiler,
        range_checks: &mut BTreeMap<u32, Vec<usize>>,
        params: &EcFieldParams,
        p: EcPoint,
    ) -> EcPoint;

    /// Point addition: computes P1 + P2 (requires P1 ≠ ±P2).
    fn point_add(
        compiler: &mut NoirToR1CSCompiler,
        range_checks: &mut BTreeMap<u32, Vec<usize>>,
        params: &EcFieldParams,
        p1: EcPoint,
        p2: EcPoint,
    ) -> EcPoint;

    /// On-curve verification: constrains y² = x³ + ax + b.
    fn verify_on_curve(
        compiler: &mut NoirToR1CSCompiler,
        range_checks: &mut BTreeMap<u32, Vec<usize>>,
        params: &EcFieldParams,
        p: EcPoint,
    );
}

// ---------------------------------------------------------------------------
// NativeEcOps — hint-verified via raw R1CS (num_limbs=1)
// ---------------------------------------------------------------------------

/// Native-field EC operations via hint-verified R1CS constraints.
pub(crate) struct NativeEcOps;

impl EcOps for NativeEcOps {
    type Field = NativeSingleField;

    // Native coordinates are a single limb, so each EcPoint carries exactly
    // one witness index per coordinate (`p.x[0]`, `p.y[0]`). Range checks are
    // unused on this path: the hint constraints live directly in R1CS.
    fn point_double(
        compiler: &mut NoirToR1CSCompiler,
        _range_checks: &mut BTreeMap<u32, Vec<usize>>,
        params: &EcFieldParams,
        p: EcPoint,
    ) -> EcPoint {
        let (x3, y3) = hints_native::point_double_verified_native(compiler, p.x[0], p.y[0], params);
        EcPoint {
            x: Limbs::single(x3),
            y: Limbs::single(y3),
        }
    }

    fn point_add(
        compiler: &mut NoirToR1CSCompiler,
        _range_checks: &mut BTreeMap<u32, Vec<usize>>,
        params: &EcFieldParams,
        p1: EcPoint,
        p2: EcPoint,
    ) -> EcPoint {
        let (x3, y3) = hints_native::point_add_verified_native(
            compiler, p1.x[0], p1.y[0], p2.x[0], p2.y[0], params,
        );
        EcPoint {
            x: Limbs::single(x3),
            y: Limbs::single(y3),
        }
    }

    fn verify_on_curve(
        compiler: &mut NoirToR1CSCompiler,
        _range_checks: &mut BTreeMap<u32, Vec<usize>>,
        params: &EcFieldParams,
        p: EcPoint,
    ) {
        hints_native::verify_on_curve_native(compiler, p.x[0], p.y[0], params);
    }
}

// ---------------------------------------------------------------------------
// NonNativeEcOps — hint-verified via schoolbook column equations (num_limbs≥2)
// ---------------------------------------------------------------------------

/// Non-native EC operations via hint-verified schoolbook column equations.
+pub(crate) struct NonNativeEcOps; + +impl EcOps for NonNativeEcOps { + type Field = MultiLimbField; + + fn point_double( + compiler: &mut NoirToR1CSCompiler, + range_checks: &mut BTreeMap>, + params: &EcFieldParams, + p: EcPoint, + ) -> EcPoint { + let (x3, y3) = hints_non_native::point_double_verified_non_native( + compiler, + range_checks, + p.x, + p.y, + params, + ); + EcPoint { x: x3, y: y3 } + } + + fn point_add( + compiler: &mut NoirToR1CSCompiler, + range_checks: &mut BTreeMap>, + params: &EcFieldParams, + p1: EcPoint, + p2: EcPoint, + ) -> EcPoint { + let (x3, y3) = hints_non_native::point_add_verified_non_native( + compiler, + range_checks, + p1.x, + p1.y, + p2.x, + p2.y, + params, + ); + EcPoint { x: x3, y: y3 } + } + + fn verify_on_curve( + compiler: &mut NoirToR1CSCompiler, + range_checks: &mut BTreeMap>, + params: &EcFieldParams, + p: EcPoint, + ) { + hints_non_native::verify_on_curve_non_native(compiler, range_checks, p.x, p.y, params); + } +} diff --git a/provekit/r1cs-compiler/src/msm/ec_points/tables.rs b/provekit/r1cs-compiler/src/msm/ec_points/tables.rs new file mode 100644 index 000000000..897d6356f --- /dev/null +++ b/provekit/r1cs-compiler/src/msm/ec_points/tables.rs @@ -0,0 +1,212 @@ +//! Point table construction and lookup for windowed scalar multiplication. +//! +//! Builds tables of point multiples and performs lookups using bit witnesses +//! for both unsigned and signed-digit windowed approaches. + +use { + super::EcOps, + crate::msm::{ + multi_limb_ops::{EcFieldParams, MultiLimbOps}, + EcPoint, + }, + ark_ff::Field, + provekit_common::{witness::SumTerm, FieldElement}, +}; + +/// Builds a signed point table of odd multiples: +/// T\[0\] = P, T\[1\] = 3P, T\[2\] = 5P, ..., T\[k-1\] = (2k-1)P +/// where k = `half_table_size`. 
+fn build_signed_point_table( + ops: &mut MultiLimbOps<'_, '_, E::Field, E, EcFieldParams>, + p: EcPoint, + half_table_size: usize, +) -> Vec { + assert!(half_table_size >= 1); + let mut table = Vec::with_capacity(half_table_size); + table.push(p); // T[0] = 1*P + if half_table_size >= 2 { + let two_p = ops.point_double(p); // 2P + for i in 1..half_table_size { + table.push(ops.point_add(table[i - 1], two_p)); + } + } + table +} + +/// Selects T\[d\] from a point table using bit witnesses, where `d = Σ +/// bits\[i\] * 2^i`, via binary tree of point selects. +/// +/// When `constrain_bits` is true, each bit is boolean-constrained. When +/// false, bits are assumed already constrained. +fn table_lookup( + ops: &mut MultiLimbOps<'_, '_, E::Field, E, EcFieldParams>, + table: &[EcPoint], + bits: &[usize], + constrain_bits: bool, +) -> EcPoint { + assert_eq!(table.len(), 1 << bits.len()); + let mut current: Vec = table.to_vec(); + // Process bits from MSB to LSB + for &bit in bits.iter().rev() { + if constrain_bits { + ops.constrain_flag(bit); + } + let half = current.len() / 2; + let mut next = Vec::with_capacity(half); + for i in 0..half { + next.push(ops.point_select_unchecked(bit, current[i], current[i + half])); + } + current = next; + } + current[0] +} + +/// Signed-digit table lookup: selects from a table of odd multiples, +/// conditionally negating y based on the sign bit. +/// +/// `sign_bit` must be boolean-constrained by the caller. 
fn signed_table_lookup<E: EcOps>(
    ops: &mut MultiLimbOps<'_, '_, E::Field, E, EcFieldParams>,
    table: &[EcPoint],
    index_bits: &[usize],
    sign_bit: usize,
) -> EcPoint {
    let pt = if index_bits.is_empty() {
        // w=1: single entry, no lookup needed
        assert_eq!(table.len(), 1);
        table[0]
    } else {
        // Derive effective index bits: idx_i = 1 - b_i - MSB + 2*b_i*MSB.
        // For boolean b_i and MSB this evaluates to XNOR(b_i, MSB) — the
        // lower bits are complemented exactly when the window MSB is 0.
        // NOTE(review): the `xor_bits` name suggests XOR, but the formula is
        // XNOR; this appears to match a signed-digit recoding of the form
        // d = 2*window + 1 - 2^w — confirm against the scalar recoding and
        // select_unchecked's flag convention.
        let one_w = ops.witness_one();
        let two = FieldElement::from(2u64);
        let xor_bits: Vec<usize> = index_bits
            .iter()
            .map(|&bit| {
                let prod = ops.product(bit, sign_bit);
                ops.sum(vec![
                    SumTerm(Some(FieldElement::ONE), one_w),
                    SumTerm(Some(-FieldElement::ONE), bit),
                    SumTerm(Some(-FieldElement::ONE), sign_bit),
                    SumTerm(Some(two), prod),
                ])
            })
            .collect();

        // Derived bits are boolean by construction (affine combination of
        // booleans landing in {0,1}), so skip redundant constraints.
        table_lookup(ops, table, &xor_bits, false)
    };

    // Conditionally negate y according to the sign bit.
    let neg_y = ops.negate(pt.y);
    let eff_y = ops.select_unchecked(sign_bit, neg_y, pt.y);

    EcPoint { x: pt.x, y: eff_y }
}

/// Per-point data for merged multi-point GLV scalar multiplication.
pub(in crate::msm) struct MergedGlvPoint {
    /// Point P (effective, post-negation)
    pub p: EcPoint,
    /// Signed-bit decomposition of |s1| (half-scalar for P), LSB first
    pub s1_bits: Vec<usize>,
    /// Skew correction witness for s1 branch (boolean)
    pub s1_skew: usize,
    /// Point R (effective, post-negation)
    pub r: EcPoint,
    /// Signed-bit decomposition of |s2| (half-scalar for R), LSB first
    pub s2_bits: Vec<usize>,
    /// Skew correction witness for s2 branch (boolean)
    pub s2_skew: usize,
}

/// Merged multi-point GLV scalar multiplication with shared doublings
/// and signed-digit windows.
+pub(in crate::msm) fn scalar_mul_merged_glv( + ops: &mut MultiLimbOps<'_, '_, E::Field, E, EcFieldParams>, + points: &[MergedGlvPoint], + window_size: usize, + offset: EcPoint, +) -> EcPoint { + assert!(!points.is_empty()); + let n = points[0].s1_bits.len(); + let w = window_size; + let half_table_size = 1usize << (w - 1); + + // Build signed point tables (odd multiples) for all points upfront + let tables: Vec<(Vec, Vec)> = points + .iter() + .map(|pt| { + let tp = build_signed_point_table(ops, pt.p, half_table_size); + let tr = build_signed_point_table(ops, pt.r, half_table_size); + (tp, tr) + }) + .collect(); + + let num_windows = (n + w - 1) / w; + let mut acc = offset; + + // Process all windows from MSB down to LSB + for i in (0..num_windows).rev() { + let bit_start = i * w; + let bit_end = std::cmp::min(bit_start + w, n); + let actual_w = bit_end - bit_start; + + // w shared doublings on the accumulator (shared across ALL points) + let mut doubled_acc = acc; + for _ in 0..w { + doubled_acc = ops.point_double(doubled_acc); + } + + let mut cur = doubled_acc; + + // For each point: P branch + R branch (signed-digit lookup) + for (pt, (table_p, table_r)) in points.iter().zip(tables.iter()) { + // --- P branch (s1 window) --- + let s1_window_bits = &pt.s1_bits[bit_start..bit_end]; + let sign_bit_p = s1_window_bits[actual_w - 1]; // MSB + let index_bits_p = &s1_window_bits[..actual_w - 1]; // lower bits + let actual_table_p = if actual_w < w { + &table_p[..1 << (actual_w - 1)] + } else { + &table_p[..] 
+ }; + let looked_up_p = signed_table_lookup(ops, actual_table_p, index_bits_p, sign_bit_p); + // All signed digits are non-zero — no is_zero check needed + cur = ops.point_add(cur, looked_up_p); + + // --- R branch (s2 window) --- + let s2_window_bits = &pt.s2_bits[bit_start..bit_end]; + let sign_bit_r = s2_window_bits[actual_w - 1]; // MSB + let index_bits_r = &s2_window_bits[..actual_w - 1]; // lower bits + let actual_table_r = if actual_w < w { + &table_r[..1 << (actual_w - 1)] + } else { + &table_r[..] + }; + let looked_up_r = signed_table_lookup(ops, actual_table_r, index_bits_r, sign_bit_r); + cur = ops.point_add(cur, looked_up_r); + } + + acc = cur; + } + + // Skew corrections + for pt in points { + // P branch skew + let neg_py = ops.negate(pt.p.y); + let sub_p = ops.point_add(acc, EcPoint { + x: pt.p.x, + y: neg_py, + }); + acc = ops.point_select_unchecked(pt.s1_skew, acc, sub_p); + + // R branch skew + let neg_ry = ops.negate(pt.r.y); + let sub_r = ops.point_add(acc, EcPoint { + x: pt.r.x, + y: neg_ry, + }); + acc = ops.point_select_unchecked(pt.s2_skew, acc, sub_r); + } + + acc +} diff --git a/provekit/r1cs-compiler/src/msm/mod.rs b/provekit/r1cs-compiler/src/msm/mod.rs new file mode 100644 index 000000000..f3e21eb1a --- /dev/null +++ b/provekit/r1cs-compiler/src/msm/mod.rs @@ -0,0 +1,208 @@ +pub mod cost_model; +pub mod curve; +pub(crate) mod ec_points; +pub(crate) mod multi_limb_arith; +pub(crate) mod multi_limb_ops; +mod pipeline; +mod sanitize; +mod scalar_relation; + +pub use provekit_common::witness::{Limbs, MAX_LIMBS}; +use { + crate::{ + constraint_helpers::{add_constant_witness, constrain_boolean}, + noir_to_r1cs::NoirToR1CSCompiler, + }, + ark_ff::{AdditiveGroup, Field, PrimeField}, + curve::Curve, + ec_points::{NativeEcOps, NonNativeEcOps}, + provekit_common::{ + witness::{ConstantOrR1CSWitness, ConstantTerm, WitnessBuilder}, + FieldElement, + }, + std::collections::BTreeMap, +}; + +/// An elliptic curve point as named `(x, y)` coordinate 
limbs. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct EcPoint { + pub x: Limbs, + pub y: Limbs, +} + +/// Scalar inputs are split into two 128-bit halves (s_lo, s_hi). +pub(crate) const SCALAR_HALF_BITS: usize = 128; + +pub(crate) use provekit_common::u256_arith::ceil_log2; + +// --------------------------------------------------------------------------- +// MSM entry point +// --------------------------------------------------------------------------- + +/// MSM outputs in multi-limb form. +pub struct MsmLimbedOutputs { + pub out_x_limbs: Vec, + pub out_y_limbs: Vec, + pub out_inf: usize, +} + +/// MSM circuit configuration parameters. +pub(crate) struct MsmConfig { + pub num_limbs: usize, + pub limb_bits: u32, + pub window_size: usize, +} + +/// Compiles MSM operations for any curve implementing the `Curve` trait. +pub fn add_msm_with_curve( + compiler: &mut NoirToR1CSCompiler, + msm_ops: Vec<( + Vec, + Vec, + MsmLimbedOutputs, + )>, + range_checks: &mut BTreeMap>, + curve: &C, +) { + if msm_ops.is_empty() { + return; + } + + let native_bits = provekit_common::FieldElement::MODULUS_BIT_SIZE; + let curve_bits = curve.modulus_bits(); + let is_native = curve.is_native_field(); + let scalar_bits = curve.curve_order_bits() as usize; + + // Use first op's output limbs to estimate n_points for cost model + let first_num_limbs = msm_ops[0].2.out_x_limbs.len(); + let stride = 2 * first_num_limbs + 1; + let n_points: usize = msm_ops.iter().map(|(pts, ..)| pts.len() / stride).sum(); + + let (limb_bits, window_size, num_limbs) = cost_model::get_optimal_msm_params( + native_bits, + curve_bits, + n_points, + scalar_bits, + is_native, + ); + + assert_eq!( + first_num_limbs, num_limbs, + "output limb count ({first_num_limbs}) doesn't match cost model num_limbs ({num_limbs})" + ); + + let config = MsmConfig { + num_limbs, + limb_bits, + window_size, + }; + + // Dispatch once — the entire pipeline is monomorphized for the chosen strategy. 
+ if is_native { + add_msm_inner::(compiler, msm_ops, range_checks, curve, &config); + } else { + add_msm_inner::(compiler, msm_ops, range_checks, curve, &config); + } +} + +/// Inner MSM loop, monomorphized for a specific EC strategy `E`. +fn add_msm_inner( + compiler: &mut NoirToR1CSCompiler, + msm_ops: Vec<( + Vec, + Vec, + MsmLimbedOutputs, + )>, + range_checks: &mut BTreeMap>, + curve: &C, + config: &MsmConfig, +) { + let stride = 2 * config.num_limbs + 1; + + for (points, scalars, outputs) in msm_ops { + assert!( + points.len() % stride == 0, + "points length must be a multiple of {stride} (2*{}+1)", + config.num_limbs + ); + let n = points.len() / stride; + assert_eq!(scalars.len(), 2 * n, "scalars length must be 2x n_points"); + assert_eq!(outputs.out_x_limbs.len(), config.num_limbs); + assert_eq!(outputs.out_y_limbs.len(), config.num_limbs); + + let point_wits: Vec = points.iter().map(|p| resolve_input(compiler, p)).collect(); + let scalar_wits: Vec = scalars.iter().map(|s| resolve_input(compiler, s)).collect(); + + pipeline::process_multi_point::( + compiler, + &point_wits, + &scalar_wits, + &outputs, + n, + config, + range_checks, + curve, + ); + } +} + +/// Resolves a `ConstantOrR1CSWitness` to a witness index. +#[must_use] +fn resolve_input(compiler: &mut NoirToR1CSCompiler, input: &ConstantOrR1CSWitness) -> usize { + match input { + ConstantOrR1CSWitness::Witness(idx) => *idx, + ConstantOrR1CSWitness::Constant(value) => { + let w = compiler.num_witnesses(); + compiler.add_witness_builder(WitnessBuilder::Constant(ConstantTerm(w, *value))); + w + } + } +} + +// --------------------------------------------------------------------------- +// Signed-bit decomposition (shared by native and non-native paths) +// --------------------------------------------------------------------------- + +/// Signed-bit decomposition for wNAF scalar multiplication. 
+pub(crate) fn decompose_signed_bits( + compiler: &mut NoirToR1CSCompiler, + scalar: usize, + num_bits: usize, +) -> (Vec, usize) { + let start = compiler.num_witnesses(); + compiler.add_witness_builder(WitnessBuilder::SignedBitHint { + output_start: start, + scalar, + num_bits, + }); + let bits: Vec = (start..start + num_bits).collect(); + let skew = start + num_bits; + + // Boolean-constrain each bit and skew + for &b in &bits { + constrain_boolean(compiler, b); + } + constrain_boolean(compiler, skew); + + // Reconstruction: scalar + skew + (2^n - 1) = Σ b_i * 2^{i+1} + // Rearranged as: scalar + skew + (2^n - 1) - Σ b_i * 2^{i+1} = 0 + let one = compiler.witness_one(); + let two = FieldElement::from(2u64); + let constant = two.pow([num_bits as u64]) - FieldElement::ONE; + let mut b_terms: Vec<(FieldElement, usize)> = bits + .iter() + .enumerate() + .map(|(i, &b)| (-two.pow([(i + 1) as u64]), b)) + .collect(); + b_terms.push((FieldElement::ONE, scalar)); + b_terms.push((FieldElement::ONE, skew)); + b_terms.push((constant, one)); + compiler + .r1cs + .add_constraint(&[(FieldElement::ONE, one)], &b_terms, &[( + FieldElement::ZERO, + one, + )]); + + (bits, skew) +} diff --git a/provekit/r1cs-compiler/src/msm/multi_limb_arith.rs b/provekit/r1cs-compiler/src/msm/multi_limb_arith.rs new file mode 100644 index 000000000..36e32fef6 --- /dev/null +++ b/provekit/r1cs-compiler/src/msm/multi_limb_arith.rs @@ -0,0 +1,485 @@ +//! N-limb modular arithmetic for EC field operations. +//! +//! Provides add/sub/mul/negate mod p for both single-limb (N=1) and multi-limb +//! (N≥2) representations, plus `compute_is_zero` and `less_than_p` checks. 
+ +use { + super::{ceil_log2, multi_limb_ops::ModulusParams, Limbs}, + crate::noir_to_r1cs::NoirToR1CSCompiler, + ark_ff::{AdditiveGroup, Field, PrimeField}, + provekit_common::{ + witness::{SumTerm, WitnessBuilder}, + FieldElement, + }, + std::collections::BTreeMap, +}; + +/// Distinguishes modular addition from subtraction in the shared core. +enum ModularOp { + Add, + Sub, +} + +/// Checks if value is zero or not (used by all N values). +/// Returns a boolean witness: 1 if zero, 0 if non-zero. +/// +/// Uses SafeInverse (not Inverse) because the input value may be zero. +/// SafeInverse outputs 0 when the input is 0, and is solved in the Other +/// layer (not batch-inverted), so zero inputs don't poison the batch. +#[must_use] +pub fn compute_is_zero(compiler: &mut NoirToR1CSCompiler, value: usize) -> usize { + let value_inv = compiler.num_witnesses(); + compiler.add_witness_builder(WitnessBuilder::SafeInverse(value_inv, value)); + + let value_mul_value_inv = compiler.num_witnesses(); + compiler.add_witness_builder(WitnessBuilder::Product( + value_mul_value_inv, + value, + value_inv, + )); + + let is_zero = compiler.num_witnesses(); + compiler.add_witness_builder(WitnessBuilder::Sum(is_zero, vec![ + SumTerm(Some(FieldElement::ONE), compiler.witness_one()), + SumTerm(Some(-FieldElement::ONE), value_mul_value_inv), + ])); + + // v × v^(-1) = 1 - is_zero + compiler.r1cs.add_constraint( + &[(FieldElement::ONE, value)], + &[(FieldElement::ONE, value_inv)], + &[ + (FieldElement::ONE, compiler.witness_one()), + (-FieldElement::ONE, is_zero), + ], + ); + // v × is_zero = 0 + compiler.r1cs.add_constraint( + &[(FieldElement::ONE, value)], + &[(FieldElement::ONE, is_zero)], + &[(FieldElement::ZERO, compiler.witness_one())], + ); + + is_zero +} + +// --------------------------------------------------------------------------- +// N≥2 multi-limb path (generalization of wide_ops.rs) +// --------------------------------------------------------------------------- + +/// Shared 
core for `add_mod_p_multi` and `sub_mod_p_multi`. +fn add_sub_mod_p_core( + compiler: &mut NoirToR1CSCompiler, + range_checks: &mut BTreeMap>, + op: ModularOp, + a: Limbs, + b: Limbs, + params: &ModulusParams, +) -> Limbs { + let n = a.len(); + assert!(n >= 2, "add/sub_mod_p_multi requires n >= 2, got n={n}"); + let w1 = compiler.witness_one(); + + // Witness: q ∈ {0, 1} + let q = compiler.num_witnesses(); + match op { + ModularOp::Add => { + compiler.add_witness_builder(WitnessBuilder::MultiLimbAddQuotient { + output: q, + a_limbs: a.as_slice().to_vec(), + b_limbs: b.as_slice().to_vec(), + modulus: params.modulus_raw, + limb_bits: params.limb_bits, + num_limbs: n as u32, + }); + } + ModularOp::Sub => { + compiler.add_witness_builder(WitnessBuilder::MultiLimbSubBorrow { + output: q, + a_limbs: a.as_slice().to_vec(), + b_limbs: b.as_slice().to_vec(), + modulus: params.modulus_raw, + limb_bits: params.limb_bits, + num_limbs: n as u32, + }); + } + } + // q is boolean + compiler + .r1cs + .add_constraint(&[(FieldElement::ONE, q)], &[(FieldElement::ONE, q)], &[( + FieldElement::ONE, + q, + )]); + + let mut r = Limbs::new(); + let mut carry_prev: Option = None; + + for i in 0..n { + // Combine w1 terms to avoid duplicate column indices. + // The offset 2^W is folded into w1_coeff (with -1 for carry_prev + // which also uses w1 implicitly via SumTerm(None, ...)). 
+ let w1_coeff = if carry_prev.is_some() { + params.two_pow_w - FieldElement::ONE + } else { + params.two_pow_w + }; + let mut terms = vec![SumTerm(None, a[i])]; + match op { + ModularOp::Add => { + terms.push(SumTerm(None, b[i])); + terms.push(SumTerm(Some(w1_coeff), w1)); + terms.push(SumTerm(Some(-params.p_limbs[i]), q)); + } + ModularOp::Sub => { + terms.push(SumTerm(Some(-FieldElement::ONE), b[i])); + terms.push(SumTerm(Some(params.p_limbs[i]), q)); + terms.push(SumTerm(Some(w1_coeff), w1)); + } + } + if let Some(carry) = carry_prev { + terms.push(SumTerm(None, carry)); + } + + // carry = floor(sum(terms) / 2^W) + let carry = compiler.num_witnesses(); + compiler.add_witness_builder(WitnessBuilder::SumQuotient { + output: carry, + terms: terms.clone(), + divisor: params.two_pow_w, + }); + + // Merged constraint: r[i] = sum(terms) - carry * 2^W + terms.push(SumTerm(Some(-params.two_pow_w), carry)); + r.push(compiler.add_sum(terms)); + carry_prev = Some(carry); + } + + less_than_p_check_multi(compiler, range_checks, r, params); + + r +} + +/// (a + b) mod p for multi-limb values. +#[must_use] +pub fn add_mod_p_multi( + compiler: &mut NoirToR1CSCompiler, + range_checks: &mut BTreeMap>, + a: Limbs, + b: Limbs, + params: &ModulusParams, +) -> Limbs { + add_sub_mod_p_core(compiler, range_checks, ModularOp::Add, a, b, params) +} + +/// Negate a multi-limb value: computes `p - y` directly via borrow chain. +#[must_use] +pub fn negate_mod_p_multi( + compiler: &mut NoirToR1CSCompiler, + range_checks: &mut BTreeMap>, + y: Limbs, + params: &ModulusParams, +) -> Limbs { + let n = y.len(); + assert!(n >= 2, "negate_mod_p_multi requires n >= 2, got n={n}"); + let w1 = compiler.witness_one(); + + let mut r = Limbs::new(); + let mut borrow_prev: Option = None; + + for i in 0..n { + // Combine w1 terms to avoid duplicate column indices. 
+ let w1_coeff = if borrow_prev.is_some() { + params.p_limbs[i] + params.two_pow_w - FieldElement::ONE + } else { + params.p_limbs[i] + params.two_pow_w + }; + let mut terms = vec![ + SumTerm(Some(w1_coeff), w1), + SumTerm(Some(-FieldElement::ONE), y[i]), + ]; + if let Some(borrow) = borrow_prev { + terms.push(SumTerm(None, borrow)); + } + + // borrow = floor(sum(terms) / 2^W) + let borrow = compiler.num_witnesses(); + compiler.add_witness_builder(WitnessBuilder::SumQuotient { + output: borrow, + terms: terms.clone(), + divisor: params.two_pow_w, + }); + + // Merged constraint: r[i] = sum(terms) - borrow * 2^W + terms.push(SumTerm(Some(-params.two_pow_w), borrow)); + let ri = compiler.add_sum(terms); + r.push(ri); + + // Range check r[i] — ensures borrow is uniquely determined + range_checks.entry(params.limb_bits).or_default().push(ri); + + borrow_prev = Some(borrow); + } + + r +} + +/// (a - b) mod p for multi-limb values. +#[must_use] +pub fn sub_mod_p_multi( + compiler: &mut NoirToR1CSCompiler, + range_checks: &mut BTreeMap>, + a: Limbs, + b: Limbs, + params: &ModulusParams, +) -> Limbs { + add_sub_mod_p_core(compiler, range_checks, ModularOp::Sub, a, b, params) +} + +/// (a * b) mod p for multi-limb values using schoolbook multiplication. +#[must_use] +pub fn mul_mod_p_multi( + compiler: &mut NoirToR1CSCompiler, + range_checks: &mut BTreeMap>, + a: Limbs, + b: Limbs, + params: &ModulusParams, +) -> Limbs { + let n = a.len(); + let limb_bits = params.limb_bits; + assert!(n >= 2, "mul_mod_p_multi requires n >= 2, got n={n}"); + + // Soundness: column values must fit the native field. + { + let ceil_log2_n = ceil_log2(n as u64); + let max_bits = 2 * limb_bits + ceil_log2_n + 3; + assert!( + max_bits < FieldElement::MODULUS_BIT_SIZE, + "Schoolbook column equation overflow: limb_bits={limb_bits}, n={n} limbs requires \ + {max_bits} bits, but native field is only {} bits. 
Use smaller limb_bits.", + FieldElement::MODULUS_BIT_SIZE, + ); + } + + let num_carries = 2 * n - 2; + // Carry offset uses max_coeff_sum=1 (products only). The soundness check + // above already verified the full column value fits the native field. + let max_coeff_sum: u64 = 1; + let extra_bits = ceil_log2(max_coeff_sum * n as u64) + 1; + let carry_offset_bits = limb_bits + extra_bits; + + // Step 1: Allocate hint witnesses (q limbs, r limbs, carries) + let os = compiler.num_witnesses(); + compiler.add_witness_builder(WitnessBuilder::MultiLimbMulModHint { + output_start: os, + a_limbs: a.as_slice().to_vec(), + b_limbs: b.as_slice().to_vec(), + modulus: params.modulus_raw, + limb_bits, + num_limbs: n as u32, + }); + + // q[0..n), r[n..2n), carries[2n..4n-2) + let q: Vec = (0..n).map(|i| os + i).collect(); + let r_indices: Vec = (0..n).map(|i| os + n + i).collect(); + let cu: Vec = (0..num_carries).map(|i| os + 2 * n + i).collect(); + + // Step 2: Product witnesses for a[i]*b[j] (n² R1CS constraints) + let mut ab_products = vec![vec![0usize; n]; n]; + for i in 0..n { + for j in 0..n { + ab_products[i][j] = compiler.add_product(a[i], b[j]); + } + } + + // Step 3: Column equations (2n-1 R1CS constraints) + // Equation: a·b - r = p·q (r on LHS with negative coeff, unsigned quotient) + emit_schoolbook_column_equations( + compiler, + &[(&ab_products, FieldElement::ONE)], + &[(&r_indices, -FieldElement::ONE)], + &q, + None, // unsigned quotient — no q_neg + &cu, + ¶ms.p_limbs, + n, + limb_bits, + max_coeff_sum, + ); + + // Step 4: less-than-p check and range checks on r + let mut r_limbs = Limbs::new(); + for &ri in &r_indices { + r_limbs.push(ri); + } + less_than_p_check_multi(compiler, range_checks, r_limbs, params); + + // Step 5: Range checks for q limbs and carries + for i in 0..n { + range_checks.entry(limb_bits).or_default().push(q[i]); + } + // Carry range: limb_bits + extra_bits + 1 (carry_offset_bits + 1) + let carry_range_bits = carry_offset_bits + 1; + for &c 
in &cu { + range_checks.entry(carry_range_bits).or_default().push(c); + } + + r_limbs +} + +/// Proves r < p by decomposing (p-1) - r via borrow propagation. +pub fn less_than_p_check_multi( + compiler: &mut NoirToR1CSCompiler, + range_checks: &mut BTreeMap>, + r: Limbs, + params: &ModulusParams, +) { + let n = r.len(); + let w1 = compiler.witness_one(); + let mut borrow_prev: Option = None; + for i in 0..n { + // Combine w1 terms to avoid duplicate column indices. + let w1_coeff = if borrow_prev.is_some() { + params.p_minus_1_limbs[i] + params.two_pow_w - FieldElement::ONE + } else { + params.p_minus_1_limbs[i] + params.two_pow_w + }; + let mut terms = vec![ + SumTerm(Some(w1_coeff), w1), + SumTerm(Some(-FieldElement::ONE), r[i]), + ]; + if let Some(borrow) = borrow_prev { + terms.push(SumTerm(None, borrow)); + } + + // borrow = floor(sum(terms) / 2^W) + let borrow = compiler.num_witnesses(); + compiler.add_witness_builder(WitnessBuilder::SumQuotient { + output: borrow, + terms: terms.clone(), + divisor: params.two_pow_w, + }); + + // Merged constraint: d[i] = sum(terms) - borrow * 2^W + terms.push(SumTerm(Some(-params.two_pow_w), borrow)); + let d_i = compiler.add_sum(terms); + + // Range check r[i] and d[i] + range_checks.entry(params.limb_bits).or_default().push(r[i]); + range_checks.entry(params.limb_bits).or_default().push(d_i); + + borrow_prev = Some(borrow); + } + + // Constrain final carry = 1 (valid r < p). + if let Some(final_borrow) = borrow_prev { + compiler.r1cs.add_constraint( + &[(FieldElement::ONE, compiler.witness_one())], + &[(FieldElement::ONE, final_borrow)], + &[(FieldElement::ONE, compiler.witness_one())], + ); + } +} + +// --------------------------------------------------------------------------- +// Schoolbook column equations (shared by mul_mod_p_multi and non-native EC +// hints) +// --------------------------------------------------------------------------- + +/// Merge terms with the same witness index by summing their coefficients. 
+fn merge_terms(terms: &[(FieldElement, usize)]) -> Vec<(FieldElement, usize)> { + let mut map: BTreeMap = BTreeMap::new(); + for &(coeff, idx) in terms { + *map.entry(idx).or_insert(FieldElement::ZERO) += coeff; + } + map.into_iter().map(|(idx, c)| (c, idx)).collect() +} + +/// Emit `2N-1` R1CS constraints verifying a schoolbook column equation +/// with unsigned-offset carry chain. +pub(in crate::msm) fn emit_schoolbook_column_equations( + compiler: &mut NoirToR1CSCompiler, + product_sets: &[(&[Vec], FieldElement)], + linear_limbs: &[(&[usize], FieldElement)], + q_pos_witnesses: &[usize], + q_neg_witnesses: Option<&[usize]>, + carry_witnesses: &[usize], + p_limbs: &[FieldElement], + n: usize, + limb_bits: u32, + max_coeff_sum: u64, +) { + let w1 = compiler.witness_one(); + let two_pow_w = FieldElement::from(2u64).pow([limb_bits as u64]); + + // Carry offset scaled for the merged equation's coefficients + let extra_bits = ceil_log2(max_coeff_sum * n as u64) + 1; + let carry_offset_bits = limb_bits + extra_bits; + let carry_offset_fe = FieldElement::from(2u64).pow([carry_offset_bits as u64]); + let offset_w = FieldElement::from(2u64).pow([(carry_offset_bits + limb_bits) as u64]); + let offset_w_minus_carry = offset_w - carry_offset_fe; + + let num_columns = 2 * n - 1; + + for k in 0..num_columns { + // LHS: Σ coeff * products[i][j] for i+j=k + Σ p[i]*q_neg[j] + carry_in + offset + let mut lhs_terms: Vec<(FieldElement, usize)> = Vec::new(); + + for &(products, coeff) in product_sets { + for i in 0..n { + let j_val = k as isize - i as isize; + if j_val >= 0 && (j_val as usize) < n { + lhs_terms.push((coeff, products[i][j_val as usize])); + } + } + } + + // Add linear terms (for k < limbs.len() only) + for &(limbs, coeff) in linear_limbs { + if k < limbs.len() { + lhs_terms.push((coeff, limbs[k])); + } + } + + // Add p*q_neg on the LHS (when using split quotients) + if let Some(q_neg) = q_neg_witnesses { + for i in 0..n { + let j_val = k as isize - i as isize; + if 
j_val >= 0 && (j_val as usize) < n { + lhs_terms.push((p_limbs[i], q_neg[j_val as usize])); + } + } + } + + // Add carry_in and offset + if k > 0 { + lhs_terms.push((FieldElement::ONE, carry_witnesses[k - 1])); + lhs_terms.push((offset_w_minus_carry, w1)); + } else { + lhs_terms.push((offset_w, w1)); + } + + // RHS: Σ p[i]*q_pos[j] for i+j=k + carry_out * W (or offset at last column) + let mut rhs_terms: Vec<(FieldElement, usize)> = Vec::new(); + for i in 0..n { + let j_val = k as isize - i as isize; + if j_val >= 0 && (j_val as usize) < n { + rhs_terms.push((p_limbs[i], q_pos_witnesses[j_val as usize])); + } + } + + if k < num_columns - 1 { + rhs_terms.push((two_pow_w, carry_witnesses[k])); + } else { + // Last column: balance with offset_w (no outgoing carry) + rhs_terms.push((offset_w, w1)); + } + + // Merge terms with the same witness index (products may share cached witnesses) + let lhs_merged = merge_terms(&lhs_terms); + let rhs_merged = merge_terms(&rhs_terms); + compiler + .r1cs + .add_constraint(&lhs_merged, &[(FieldElement::ONE, w1)], &rhs_merged); + } +} diff --git a/provekit/r1cs-compiler/src/msm/multi_limb_ops.rs b/provekit/r1cs-compiler/src/msm/multi_limb_ops.rs new file mode 100644 index 000000000..bd61e6815 --- /dev/null +++ b/provekit/r1cs-compiler/src/msm/multi_limb_ops.rs @@ -0,0 +1,473 @@ +//! `MultiLimbOps` — field and EC arithmetic parameterized by runtime limb +//! count. + +use { + super::{ec_points::EcOps, multi_limb_arith, EcPoint, Limbs}, + crate::{ + constraint_helpers::{constrain_boolean, select_witness}, + noir_to_r1cs::NoirToR1CSCompiler, + }, + ark_ff::{AdditiveGroup, Field}, + provekit_common::{ + witness::{ConstantTerm, SumTerm, WitnessBuilder}, + FieldElement, + }, + std::{collections::BTreeMap, marker::PhantomData}, +}; + +/// Shared modulus parameters for multi-limb field arithmetic. +/// +/// Used by both EC operations (mod field_modulus_p) and scalar relations +/// (mod curve_order_n). 
+pub struct ModulusParams { + pub num_limbs: usize, + pub limb_bits: u32, + pub p_limbs: Vec, + pub p_minus_1_limbs: Vec, + pub two_pow_w: FieldElement, + pub modulus_raw: [u64; 4], +} + +/// EC-specific curve constants — only needed by EC point operations. +pub struct CurveEcParams { + pub curve_a_limbs: Vec, + pub curve_a_raw: [u64; 4], + pub curve_b_limbs: Vec, + pub curve_b_raw: [u64; 4], +} + +/// Full parameters for EC field operations (modulus + curve constants). +/// +/// Derefs to `ModulusParams` so modulus fields are accessible directly. +pub struct EcFieldParams { + pub modulus: ModulusParams, + pub ec: CurveEcParams, +} + +impl std::ops::Deref for EcFieldParams { + type Target = ModulusParams; + fn deref(&self) -> &ModulusParams { + &self.modulus + } +} + +impl AsRef for ModulusParams { + fn as_ref(&self) -> &ModulusParams { + self + } +} + +impl AsRef for EcFieldParams { + fn as_ref(&self) -> &ModulusParams { + &self.modulus + } +} + +// --------------------------------------------------------------------------- +// FieldArith trait — strategy interface for field arithmetic dispatch +// --------------------------------------------------------------------------- + +/// Strategy for constraining field arithmetic operations in the circuit. 
+pub trait FieldArith { + fn field_add( + compiler: &mut NoirToR1CSCompiler, + range_checks: &mut BTreeMap>, + params: &ModulusParams, + a: Limbs, + b: Limbs, + ) -> Limbs; + + fn field_sub( + compiler: &mut NoirToR1CSCompiler, + range_checks: &mut BTreeMap>, + params: &ModulusParams, + a: Limbs, + b: Limbs, + ) -> Limbs; + + fn field_mul( + compiler: &mut NoirToR1CSCompiler, + range_checks: &mut BTreeMap>, + params: &ModulusParams, + a: Limbs, + b: Limbs, + ) -> Limbs; + + fn field_negate( + compiler: &mut NoirToR1CSCompiler, + range_checks: &mut BTreeMap>, + params: &ModulusParams, + value: Limbs, + ) -> Limbs; +} + +// --------------------------------------------------------------------------- +// NativeSingleField — native R1CS arithmetic (num_limbs=1) +// --------------------------------------------------------------------------- + +/// Native-field arithmetic: single-limb R1CS add/sub/mul/negate. +pub struct NativeSingleField; + +impl FieldArith for NativeSingleField { + fn field_add( + compiler: &mut NoirToR1CSCompiler, + _range_checks: &mut BTreeMap>, + _params: &ModulusParams, + a: Limbs, + b: Limbs, + ) -> Limbs { + let r = if a[0] == b[0] { + compiler.add_sum(vec![SumTerm(Some(FieldElement::from(2u64)), a[0])]) + } else { + compiler.add_sum(vec![SumTerm(None, a[0]), SumTerm(None, b[0])]) + }; + Limbs::single(r) + } + + fn field_sub( + compiler: &mut NoirToR1CSCompiler, + _range_checks: &mut BTreeMap>, + _params: &ModulusParams, + a: Limbs, + b: Limbs, + ) -> Limbs { + let r = if a[0] == b[0] { + compiler.add_sum(vec![SumTerm(Some(FieldElement::ZERO), a[0])]) + } else { + compiler.add_sum(vec![ + SumTerm(None, a[0]), + SumTerm(Some(-FieldElement::ONE), b[0]), + ]) + }; + Limbs::single(r) + } + + fn field_mul( + compiler: &mut NoirToR1CSCompiler, + _range_checks: &mut BTreeMap>, + _params: &ModulusParams, + a: Limbs, + b: Limbs, + ) -> Limbs { + let r = compiler.add_product(a[0], b[0]); + Limbs::single(r) + } + + fn field_negate( + compiler: &mut 
NoirToR1CSCompiler, + _range_checks: &mut BTreeMap>, + _params: &ModulusParams, + value: Limbs, + ) -> Limbs { + let r = compiler.add_sum(vec![SumTerm(Some(-FieldElement::ONE), value[0])]); + Limbs::single(r) + } +} + +// --------------------------------------------------------------------------- +// MultiLimbField — schoolbook multi-limb arithmetic (num_limbs≥2) +// --------------------------------------------------------------------------- + +/// Multi-limb field arithmetic: schoolbook add/sub/mul/negate with carry +/// chains. +pub struct MultiLimbField; + +impl FieldArith for MultiLimbField { + fn field_add( + compiler: &mut NoirToR1CSCompiler, + range_checks: &mut BTreeMap>, + params: &ModulusParams, + a: Limbs, + b: Limbs, + ) -> Limbs { + multi_limb_arith::add_mod_p_multi(compiler, range_checks, a, b, params) + } + + fn field_sub( + compiler: &mut NoirToR1CSCompiler, + range_checks: &mut BTreeMap>, + params: &ModulusParams, + a: Limbs, + b: Limbs, + ) -> Limbs { + multi_limb_arith::sub_mod_p_multi(compiler, range_checks, a, b, params) + } + + fn field_mul( + compiler: &mut NoirToR1CSCompiler, + range_checks: &mut BTreeMap>, + params: &ModulusParams, + a: Limbs, + b: Limbs, + ) -> Limbs { + multi_limb_arith::mul_mod_p_multi(compiler, range_checks, a, b, params) + } + + fn field_negate( + compiler: &mut NoirToR1CSCompiler, + range_checks: &mut BTreeMap>, + params: &ModulusParams, + value: Limbs, + ) -> Limbs { + multi_limb_arith::negate_mod_p_multi(compiler, range_checks, value, params) + } +} + +impl ModulusParams { + /// Build params for scalar relation verification (mod curve_order_n). 
+ pub fn for_curve_order( + num_limbs: usize, + limb_bits: u32, + curve: &C, + ) -> Self { + let two_pow_w = FieldElement::from(2u64).pow([limb_bits as u64]); + Self { + num_limbs, + limb_bits, + p_limbs: curve.curve_order_n_limbs(limb_bits, num_limbs), + p_minus_1_limbs: curve.curve_order_n_minus_1_limbs(limb_bits, num_limbs), + two_pow_w, + modulus_raw: curve.curve_order_n(), + } + } +} + +impl EcFieldParams { + /// Build params for EC field operations (mod field_modulus_p). + pub fn for_field_modulus( + num_limbs: usize, + limb_bits: u32, + curve: &C, + ) -> Self { + let two_pow_w = FieldElement::from(2u64).pow([limb_bits as u64]); + Self { + modulus: ModulusParams { + num_limbs, + limb_bits, + p_limbs: curve.p_limbs(limb_bits, num_limbs), + p_minus_1_limbs: curve.p_minus_1_limbs(limb_bits, num_limbs), + two_pow_w, + modulus_raw: curve.field_modulus_p(), + }, + ec: CurveEcParams { + curve_a_limbs: curve.curve_a_limbs(limb_bits, num_limbs), + curve_a_raw: curve.curve_a(), + curve_b_limbs: curve.curve_b_limbs(limb_bits, num_limbs), + curve_b_raw: curve.curve_b(), + }, + } + } +} + +/// Allocate a pinned constant witness embedded in the constraint matrix. +#[must_use] +pub fn allocate_pinned_constant(compiler: &mut NoirToR1CSCompiler, value: FieldElement) -> usize { + let w = compiler.num_witnesses(); + compiler.add_witness_builder(WitnessBuilder::Constant(ConstantTerm(w, value))); + compiler.r1cs.add_constraint( + &[(FieldElement::ONE, compiler.witness_one())], + &[(FieldElement::ONE, w)], + &[(value, compiler.witness_one())], + ); + w +} + +/// Allocate pinned constant witnesses from pre-decomposed `FieldElement` limbs. 
+pub fn allocate_pinned_constant_limbs( + compiler: &mut NoirToR1CSCompiler, + limb_values: &[FieldElement], +) -> Vec { + limb_values + .iter() + .map(|&val| allocate_pinned_constant(compiler, val)) + .collect() +} + +/// Unified field + EC operations struct, parameterized by: +/// - `F`: field arithmetic strategy (`FieldArith`) +/// - `E`: EC point arithmetic strategy (`EcOps`) — `()` when EC ops not needed +/// - `P`: params type — `ModulusParams` for field-only, `EcFieldParams` for EC +pub struct MultiLimbOps<'a, 'p, F, E = (), P: AsRef = ModulusParams> { + pub(in crate::msm) compiler: &'a mut NoirToR1CSCompiler, + pub(in crate::msm) range_checks: &'a mut BTreeMap>, + pub params: &'p P, + _field: PhantomData<(F, E)>, +} + +// ----------------------------------------------------------------- +// Helper methods — available for any F (no trait bound required) +// ----------------------------------------------------------------- + +impl<'a, 'p, F, E, P: AsRef> MultiLimbOps<'a, 'p, F, E, P> { + /// Construct a new `MultiLimbOps`. + pub fn new( + compiler: &'a mut NoirToR1CSCompiler, + range_checks: &'a mut BTreeMap>, + params: &'p P, + ) -> Self { + Self { + compiler, + range_checks, + params, + _field: PhantomData, + } + } + + /// Access the modulus parameters. + fn m(&self) -> &ModulusParams { + self.params.as_ref() + } + + fn n(&self) -> usize { + self.m().num_limbs + } + + /// Returns the witness index for the constant-one wire. + #[must_use] + pub fn witness_one(&self) -> usize { + self.compiler.witness_one() + } + + /// Returns the current number of allocated witnesses. + #[must_use] + pub fn num_witnesses(&self) -> usize { + self.compiler.num_witnesses() + } + + /// Allocates a product witness: `out = a * b`. + #[must_use] + pub fn product(&mut self, a: usize, b: usize) -> usize { + self.compiler.add_product(a, b) + } + + /// Allocates a linear combination witness: `out = Σ terms`. 
+ #[must_use] + pub fn sum(&mut self, terms: Vec) -> usize { + self.compiler.add_sum(terms) + } + + /// Registers a witness builder for deferred witness solving. + pub fn add_witness_builder(&mut self, builder: WitnessBuilder) { + self.compiler.add_witness_builder(builder); + } + + /// Registers a range check: `witness` must fit in `bits` bits. + pub fn register_range_check(&mut self, bits: u32, witness: usize) { + self.range_checks.entry(bits).or_default().push(witness); + } + + /// Constrains `flag` to be boolean (`flag * flag = flag`). + pub fn constrain_flag(&mut self, flag: usize) { + constrain_boolean(self.compiler, flag); + } + + /// Conditional select without boolean constraint on `flag`. + /// Caller must ensure `flag` is already constrained boolean. + #[must_use] + pub fn select_unchecked(&mut self, flag: usize, on_false: Limbs, on_true: Limbs) -> Limbs { + let n = self.n(); + let mut out = Limbs::new(); + for i in 0..n { + out.push(select_witness(self.compiler, flag, on_false[i], on_true[i])); + } + out + } + + /// Conditional select: returns `on_true` if `flag` is 1, `on_false` if + /// `flag` is 0. Constrains `flag` to be boolean. + #[must_use] + pub fn select(&mut self, flag: usize, on_false: Limbs, on_true: Limbs) -> Limbs { + self.constrain_flag(flag); + self.select_unchecked(flag, on_false, on_true) + } + + /// Conditional point select without boolean constraint on `flag`. + /// Returns `on_true` if `flag=1`, `on_false` if `flag=0`. + /// Caller must ensure `flag` is already constrained boolean. + #[must_use] + pub fn point_select_unchecked( + &mut self, + flag: usize, + on_false: EcPoint, + on_true: EcPoint, + ) -> EcPoint { + EcPoint { + x: self.select_unchecked(flag, on_false.x, on_true.x), + y: self.select_unchecked(flag, on_false.y, on_true.y), + } + } + + /// Returns a constant field element from its limb decomposition. 
+ #[must_use] + pub fn constant_limbs(&mut self, limbs: &[FieldElement]) -> Limbs { + let n = self.n(); + assert_eq!( + limbs.len(), + n, + "constant_limbs: expected {n} limbs, got {}", + limbs.len() + ); + Limbs::from(allocate_pinned_constant_limbs(self.compiler, limbs).as_slice()) + } +} + +// ----------------------------------------------------------------- +// Field arithmetic — available when F: FieldArith +// ----------------------------------------------------------------- + +impl> MultiLimbOps<'_, '_, F, E, P> { + /// Negate a multi-limb value: computes `p - value (mod p)`. + #[must_use] + pub fn negate(&mut self, value: Limbs) -> Limbs { + F::field_negate( + self.compiler, + self.range_checks, + self.params.as_ref(), + value, + ) + } + + #[must_use] + pub fn add(&mut self, a: Limbs, b: Limbs) -> Limbs { + assert_eq!(a.len(), self.n(), "add: a.len() != num_limbs"); + assert_eq!(b.len(), self.n(), "add: b.len() != num_limbs"); + F::field_add(self.compiler, self.range_checks, self.params.as_ref(), a, b) + } + + #[must_use] + pub fn sub(&mut self, a: Limbs, b: Limbs) -> Limbs { + assert_eq!(a.len(), self.n(), "sub: a.len() != num_limbs"); + assert_eq!(b.len(), self.n(), "sub: b.len() != num_limbs"); + F::field_sub(self.compiler, self.range_checks, self.params.as_ref(), a, b) + } + + #[must_use] + pub fn mul(&mut self, a: Limbs, b: Limbs) -> Limbs { + assert_eq!(a.len(), self.n(), "mul: a.len() != num_limbs"); + assert_eq!(b.len(), self.n(), "mul: b.len() != num_limbs"); + F::field_mul(self.compiler, self.range_checks, self.params.as_ref(), a, b) + } +} + +// ----------------------------------------------------------------- +// EC point operations — available when E: EcOps, P = EcFieldParams +// ----------------------------------------------------------------- + +impl MultiLimbOps<'_, '_, F, E, EcFieldParams> { + /// Point doubling: computes 2P. 
+ #[must_use] + pub fn point_double(&mut self, p: EcPoint) -> EcPoint { + E::point_double(self.compiler, self.range_checks, self.params, p) + } + + /// Point addition: computes P1 + P2 (requires P1 ≠ ±P2). + #[must_use] + pub fn point_add(&mut self, p1: EcPoint, p2: EcPoint) -> EcPoint { + E::point_add(self.compiler, self.range_checks, self.params, p1, p2) + } + + /// On-curve verification: constrains y² = x³ + ax + b. + pub fn verify_on_curve(&mut self, p: EcPoint) { + E::verify_on_curve(self.compiler, self.range_checks, self.params, p); + } +} diff --git a/provekit/r1cs-compiler/src/msm/pipeline.rs b/provekit/r1cs-compiler/src/msm/pipeline.rs new file mode 100644 index 000000000..13cea8e46 --- /dev/null +++ b/provekit/r1cs-compiler/src/msm/pipeline.rs @@ -0,0 +1,415 @@ +//! Unified MSM pipeline for all curves, generic over `E: EcOps`. +//! +//! Orchestrates 4 phases: preprocessing, scalar mul verification, +//! scalar relations, and accumulation. + +use { + super::{ + curve::{self, Curve}, + ec_points::{self, EcOps}, + multi_limb_ops::{EcFieldParams, MultiLimbOps}, + sanitize::{ + emit_ec_scalar_mul_hint_and_sanitize_multi_limb, emit_fakeglv_hint, + sanitize_point_scalar_multi_limb, + }, + scalar_relation, EcPoint, Limbs, MsmConfig, MsmLimbedOutputs, + }, + crate::{ + constraint_helpers::{ + add_constant_witness, constrain_equal, constrain_to_constant, select_witness, + }, + noir_to_r1cs::NoirToR1CSCompiler, + }, + ark_ff::AdditiveGroup, + curve::decompose_to_limbs as decompose_to_limbs_pub, + provekit_common::FieldElement, + std::collections::BTreeMap, +}; + +/// EC-aware `MultiLimbOps` with both field (`E::Field`) and EC (`E`) ops +/// available. +type EcOpsCtx<'a, 'p, E> = MultiLimbOps<'a, 'p, ::Field, E, EcFieldParams>; + +// --------------------------------------------------------------------------- +// Phase 1 output +// --------------------------------------------------------------------------- + +/// Per-point scalar relation witness indices. 
+struct ScalarRelationInputs { + s_lo: usize, + s_hi: usize, + s1: usize, + s2: usize, + neg1: usize, + neg2: usize, +} + +/// Per-point data collected during Phase 1 preprocessing. +struct PreprocessedData { + all_skipped: usize, + merged_points: Vec, + scalar_rel_inputs: Vec, + accum_inputs: Vec<(EcPoint, usize)>, +} + +// --------------------------------------------------------------------------- +// Public entry point +// --------------------------------------------------------------------------- + +/// Unified multi-point MSM with per-limb I/O. +pub(super) fn process_multi_point( + compiler: &mut NoirToR1CSCompiler, + point_wits: &[usize], + scalar_wits: &[usize], + outputs: &MsmLimbedOutputs, + n_points: usize, + config: &MsmConfig, + range_checks: &mut BTreeMap>, + curve: &impl Curve, +) { + let num_limbs = config.num_limbs; + let limb_bits = config.limb_bits; + let one = compiler.witness_one(); + let zero_witness = add_constant_witness(compiler, FieldElement::ZERO); + + // Generator as limbs + let gen_x_fe_limbs = decompose_to_limbs_pub(&curve.generator().0, limb_bits, num_limbs); + let gen_y_fe_limbs = decompose_to_limbs_pub(&curve.generator().1, limb_bits, num_limbs); + let gen_x_limb_wits: Vec = gen_x_fe_limbs + .iter() + .map(|&v| add_constant_witness(compiler, v)) + .collect(); + let gen_y_limb_wits: Vec = gen_y_fe_limbs + .iter() + .map(|&v| add_constant_witness(compiler, v)) + .collect(); + + // Build params once for all operations + let params = EcFieldParams::for_field_modulus(num_limbs, limb_bits, curve); + // Offset point as limbs for accumulation + let offset_x_values = curve.offset_x_limbs(limb_bits, num_limbs); + let offset_y_values = curve.offset_y_limbs(limb_bits, num_limbs); + + // Phase 1: Per-point preprocessing + let data = preprocess_points::( + compiler, + range_checks, + ¶ms, + point_wits, + scalar_wits, + &gen_x_limb_wits, + &gen_y_limb_wits, + zero_witness, + one, + n_points, + config, + curve, + ); + + // Phase 2: Per-point scalar 
mul verification + verify_scalar_muls::( + compiler, + range_checks, + ¶ms, + &data.merged_points, + &offset_x_values, + &offset_y_values, + config, + curve, + ); + + // Phase 3: Per-point scalar relations + for sr in &data.scalar_rel_inputs { + scalar_relation::verify_scalar_relation( + compiler, + range_checks, + sr.s_lo, + sr.s_hi, + sr.s1, + sr.s2, + sr.neg1, + sr.neg2, + curve, + ); + } + + // Phase 4: Accumulation + output constraining + accumulate_and_constrain_outputs::( + compiler, + range_checks, + ¶ms, + &data.accum_inputs, + outputs, + data.all_skipped, + &offset_x_values, + &offset_y_values, + zero_witness, + config, + curve, + ); +} + +// --------------------------------------------------------------------------- +// Phase 1: Per-point preprocessing +// --------------------------------------------------------------------------- + +/// Per-point preprocessing: sanitize, on-curve check, scalar decomposition, +/// y-negation. +fn preprocess_points( + compiler: &mut NoirToR1CSCompiler, + range_checks: &mut BTreeMap>, + params: &EcFieldParams, + point_wits: &[usize], + scalar_wits: &[usize], + gen_x_limb_wits: &[usize], + gen_y_limb_wits: &[usize], + zero_witness: usize, + one: usize, + n_points: usize, + config: &MsmConfig, + curve: &impl Curve, +) -> PreprocessedData { + let num_limbs = config.num_limbs; + let limb_bits = config.limb_bits; + let stride = 2 * num_limbs + 1; + let mut all_skipped: Option = None; + let mut merged_points: Vec = Vec::new(); + let mut scalar_rel_inputs: Vec = Vec::new(); + let mut accum_inputs: Vec<(EcPoint, usize)> = Vec::new(); + + for i in 0..n_points { + // Extract point limbs and inf flag from limbed layout + let base = i * stride; + let mut px_limbs = Limbs::new(); + let mut py_limbs = Limbs::new(); + for j in 0..num_limbs { + let px_j = point_wits[base + j]; + let py_j = point_wits[base + num_limbs + j]; + px_limbs.push(px_j); + py_limbs.push(py_j); + // Non-native: range-check each limb + if num_limbs > 1 { + 
range_checks.entry(limb_bits).or_default().push(px_j); + range_checks.entry(limb_bits).or_default().push(py_j); + } + } + let inf_flag = point_wits[base + 2 * num_limbs]; + + // Sanitize point-scalar pair + let san = sanitize_point_scalar_multi_limb( + compiler, + px_limbs, + py_limbs, + scalar_wits[2 * i], + scalar_wits[2 * i + 1], + inf_flag, + gen_x_limb_wits, + gen_y_limb_wits, + zero_witness, + one, + ); + + // Track all_skipped + all_skipped = Some(match all_skipped { + None => san.is_skip, + Some(prev) => compiler.add_product(prev, san.is_skip), + }); + + // EcScalarMulHint with multi-limb inputs/outputs + let (rx, ry) = emit_ec_scalar_mul_hint_and_sanitize_multi_limb( + compiler, + &san, + gen_x_limb_wits, + gen_y_limb_wits, + limb_bits, + range_checks, + curve, + ); + + let p = EcPoint { + x: san.px_limbs, + y: san.py_limbs, + }; + let r = EcPoint { x: rx, y: ry }; + + // On-curve checks + { + let mut ops = EcOpsCtx::::new(&mut *compiler, &mut *range_checks, params); + ops.verify_on_curve(p); + ops.verify_on_curve(r); + } + + // FakeGLV decomposition and y-negation + let (s1_witness, s2_witness, neg1_witness, neg2_witness); + let (py_effective, ry_effective, s1_bits, s2_bits, s1_skew, s2_skew); + { + let mut ops = EcOpsCtx::::new(&mut *compiler, &mut *range_checks, params); + + // FakeGLVHint → |s1|, |s2|, neg1, neg2 + (s1_witness, s2_witness, neg1_witness, neg2_witness) = + emit_fakeglv_hint(ops.compiler, san.s_lo, san.s_hi, curve); + + // Signed-bit decomposition of |s1|, |s2| + let half_bits = curve.glv_half_bits() as usize; + (s1_bits, s1_skew) = super::decompose_signed_bits(ops.compiler, s1_witness, half_bits); + (s2_bits, s2_skew) = super::decompose_signed_bits(ops.compiler, s2_witness, half_bits); + + // Conditionally negate y-coordinates + let neg_py = ops.negate(p.y); + let neg_ry = ops.negate(r.y); + py_effective = ops.select(neg1_witness, p.y, neg_py); + ry_effective = ops.select(neg2_witness, r.y, neg_ry); + } + + 
merged_points.push(ec_points::MergedGlvPoint { + p: EcPoint { + x: p.x, + y: py_effective, + }, + s1_bits, + s1_skew, + r: EcPoint { + x: r.x, + y: ry_effective, + }, + s2_bits, + s2_skew, + }); + + scalar_rel_inputs.push(ScalarRelationInputs { + s_lo: san.s_lo, + s_hi: san.s_hi, + s1: s1_witness, + s2: s2_witness, + neg1: neg1_witness, + neg2: neg2_witness, + }); + accum_inputs.push((r, san.is_skip)); + } + + PreprocessedData { + all_skipped: all_skipped.expect("MSM must have at least one point"), + merged_points, + scalar_rel_inputs, + accum_inputs, + } +} + +// --------------------------------------------------------------------------- +// Phase 2: Per-point scalar mul verification +// --------------------------------------------------------------------------- + +/// Per-point scalar mul verification with independent identity checks. +fn verify_scalar_muls( + compiler: &mut NoirToR1CSCompiler, + range_checks: &mut BTreeMap>, + params: &EcFieldParams, + merged_points: &[ec_points::MergedGlvPoint], + offset_x_values: &[FieldElement], + offset_y_values: &[FieldElement], + config: &MsmConfig, + curve: &impl Curve, +) { + let num_limbs = config.num_limbs; + let limb_bits = config.limb_bits; + let window_size = config.window_size; + let half_bits = curve.glv_half_bits() as usize; + let mut ops = EcOpsCtx::::new(&mut *compiler, &mut *range_checks, params); + + // Expected accumulated offset + let glv_num_windows = (half_bits + window_size - 1) / window_size; + let glv_n_doublings = glv_num_windows * window_size; + let (acc_off_x_raw, acc_off_y_raw) = curve.accumulated_offset(glv_n_doublings); + let acc_off_x_values = decompose_to_limbs_pub(&acc_off_x_raw, limb_bits, num_limbs); + let acc_off_y_values = decompose_to_limbs_pub(&acc_off_y_raw, limb_bits, num_limbs); + + // Allocate offset once + let offset = EcPoint { + x: ops.constant_limbs(offset_x_values), + y: ops.constant_limbs(offset_y_values), + }; + + for pt in merged_points { + let glv_acc = 
ec_points::scalar_mul_merged_glv( + &mut ops, + std::slice::from_ref(pt), + window_size, + offset, + ); + + // Per-point identity check + for j in 0..num_limbs { + constrain_to_constant(ops.compiler, glv_acc.x[j], acc_off_x_values[j]); + constrain_to_constant(ops.compiler, glv_acc.y[j], acc_off_y_values[j]); + } + } +} + +// --------------------------------------------------------------------------- +// Phase 4: Accumulation + output constraining +// --------------------------------------------------------------------------- + +/// Accumulates per-point scalar-mul results, subtracts the offset, and +/// constrains the final coordinates to the output witnesses. +fn accumulate_and_constrain_outputs( + compiler: &mut NoirToR1CSCompiler, + range_checks: &mut BTreeMap>, + params: &EcFieldParams, + accum_inputs: &[(EcPoint, usize)], + outputs: &MsmLimbedOutputs, + all_skipped: usize, + offset_x_values: &[FieldElement], + offset_y_values: &[FieldElement], + zero_witness: usize, + config: &MsmConfig, + curve: &impl Curve, +) { + let num_limbs = config.num_limbs; + let limb_bits = config.limb_bits; + let mut ops = EcOpsCtx::::new(&mut *compiler, &mut *range_checks, params); + // Allocate offset limbs once + let offset_x = ops.constant_limbs(offset_x_values); + let offset_y = ops.constant_limbs(offset_y_values); + let mut acc = EcPoint { + x: offset_x, + y: offset_y, + }; + + for &(r, is_skip) in accum_inputs { + let cand = ops.point_add(acc, r); + acc = ops.point_select_unchecked(is_skip, cand, acc); + } + + // Offset subtraction + let neg_offset_y_raw = + curve::negate_field_element(&curve.offset_point().1, &curve.field_modulus_p()); + let neg_offset_y_values = curve::decompose_to_limbs(&neg_offset_y_raw, limb_bits, num_limbs); + + let gen_x_limb_values = curve.generator_x_limbs(limb_bits, num_limbs); + let neg_gen_y_raw = curve::negate_field_element(&curve.generator().1, &curve.field_modulus_p()); + let neg_gen_y_values = curve::decompose_to_limbs(&neg_gen_y_raw, 
limb_bits, num_limbs); + + let sub_pt = EcPoint { + x: { + let g_x = ops.constant_limbs(&gen_x_limb_values); + ops.select(all_skipped, offset_x, g_x) + }, + y: { + let neg_off_y = ops.constant_limbs(&neg_offset_y_values); + let neg_g_y = ops.constant_limbs(&neg_gen_y_values); + ops.select(all_skipped, neg_off_y, neg_g_y) + }, + }; + + let result = ops.point_add(acc, sub_pt); + let compiler = ops.compiler; + + // Output constraining + for j in 0..num_limbs { + let masked_x = select_witness(compiler, all_skipped, result.x[j], zero_witness); + let masked_y = select_witness(compiler, all_skipped, result.y[j], zero_witness); + constrain_equal(compiler, outputs.out_x_limbs[j], masked_x); + constrain_equal(compiler, outputs.out_y_limbs[j], masked_y); + } + constrain_equal(compiler, outputs.out_inf, all_skipped); +} diff --git a/provekit/r1cs-compiler/src/msm/sanitize.rs b/provekit/r1cs-compiler/src/msm/sanitize.rs new file mode 100644 index 000000000..7f3364c6c --- /dev/null +++ b/provekit/r1cs-compiler/src/msm/sanitize.rs @@ -0,0 +1,148 @@ +//! Degenerate-case detection and sanitization helpers for MSM point-scalar +//! pairs. + +use { + super::{curve::Curve, Limbs}, + crate::{ + constraint_helpers::{compute_boolean_or, constrain_boolean, select_witness}, + msm::multi_limb_arith::compute_is_zero, + noir_to_r1cs::NoirToR1CSCompiler, + }, + provekit_common::witness::WitnessBuilder, +}; + +/// Detects whether a point-scalar pair is degenerate (scalar=0 or point at +/// infinity). +fn detect_skip( + compiler: &mut NoirToR1CSCompiler, + s_lo: usize, + s_hi: usize, + inf_flag: usize, +) -> usize { + constrain_boolean(compiler, inf_flag); + let is_zero_s_lo = compute_is_zero(compiler, s_lo); + let is_zero_s_hi = compute_is_zero(compiler, s_hi); + let s_is_zero = compiler.add_product(is_zero_s_lo, is_zero_s_hi); + compute_boolean_or(compiler, s_is_zero, inf_flag) +} + +/// Allocates a FakeGLV hint and returns `(s1, s2, neg1, neg2)` witness indices. 
+pub(super) fn emit_fakeglv_hint( + compiler: &mut NoirToR1CSCompiler, + s_lo: usize, + s_hi: usize, + curve: &C, +) -> (usize, usize, usize, usize) { + let glv_start = compiler.num_witnesses(); + compiler.add_witness_builder(WitnessBuilder::FakeGLVHint { + output_start: glv_start, + s_lo, + s_hi, + curve_order: curve.curve_order_n(), + }); + (glv_start, glv_start + 1, glv_start + 2, glv_start + 3) +} + +/// Sanitized point-scalar inputs (limbed representation). +pub(super) struct SanitizedInputsMultiLimb { + pub px_limbs: Limbs, + pub py_limbs: Limbs, + pub s_lo: usize, + pub s_hi: usize, + pub is_skip: usize, +} + +/// Sanitize a point-scalar pair, replacing degenerate cases with the +/// generator. +pub(super) fn sanitize_point_scalar_multi_limb( + compiler: &mut NoirToR1CSCompiler, + px_limbs: Limbs, + py_limbs: Limbs, + s_lo: usize, + s_hi: usize, + inf_flag: usize, + gen_x_limb_wits: &[usize], + gen_y_limb_wits: &[usize], + zero: usize, + one: usize, +) -> SanitizedInputsMultiLimb { + let n = px_limbs.len(); + let is_skip = detect_skip(compiler, s_lo, s_hi, inf_flag); + + let mut san_px = Limbs::new(); + let mut san_py = Limbs::new(); + for i in 0..n { + san_px.push(select_witness( + compiler, + is_skip, + px_limbs[i], + gen_x_limb_wits[i], + )); + san_py.push(select_witness( + compiler, + is_skip, + py_limbs[i], + gen_y_limb_wits[i], + )); + } + + SanitizedInputsMultiLimb { + px_limbs: san_px, + py_limbs: san_py, + s_lo: select_witness(compiler, is_skip, s_lo, one), + s_hi: select_witness(compiler, is_skip, s_hi, zero), + is_skip, + } +} + +/// Emit an `EcScalarMulHint` and sanitize the output limbs. 
+pub(super) fn emit_ec_scalar_mul_hint_and_sanitize_multi_limb( + compiler: &mut NoirToR1CSCompiler, + san: &SanitizedInputsMultiLimb, + gen_x_limb_wits: &[usize], + gen_y_limb_wits: &[usize], + limb_bits: u32, + range_checks: &mut std::collections::BTreeMap>, + curve: &C, +) -> (Limbs, Limbs) { + let num_limbs = san.px_limbs.len(); + let hint_start = compiler.num_witnesses(); + compiler.add_witness_builder(WitnessBuilder::EcScalarMulHint { + output_start: hint_start, + px_limbs: san.px_limbs.as_slice().to_vec(), + py_limbs: san.py_limbs.as_slice().to_vec(), + s_lo: san.s_lo, + s_hi: san.s_hi, + curve_a: curve.curve_a(), + field_modulus_p: curve.field_modulus_p(), + num_limbs: num_limbs as u32, + limb_bits, + }); + + let mut rx = Limbs::new(); + let mut ry = Limbs::new(); + for i in 0..num_limbs { + let rx_hint = hint_start + i; + let ry_hint = hint_start + num_limbs + i; + // Range-check hint output limbs (native field elements don't need it) + if num_limbs > 1 { + range_checks.entry(limb_bits).or_default().push(rx_hint); + range_checks.entry(limb_bits).or_default().push(ry_hint); + } + // Sanitize: select between hint output and generator + rx.push(select_witness( + compiler, + san.is_skip, + rx_hint, + gen_x_limb_wits[i], + )); + ry.push(select_witness( + compiler, + san.is_skip, + ry_hint, + gen_y_limb_wits[i], + )); + } + + (rx, ry) +} diff --git a/provekit/r1cs-compiler/src/msm/scalar_relation.rs b/provekit/r1cs-compiler/src/msm/scalar_relation.rs new file mode 100644 index 000000000..93e29c277 --- /dev/null +++ b/provekit/r1cs-compiler/src/msm/scalar_relation.rs @@ -0,0 +1,234 @@ +//! Scalar relation verification: (-1)^neg1 * |s1| + (-1)^neg2 * |s2| * s ≡ 0 +//! (mod n). +//! +//! Shared by both the native and non-native MSM paths. 
+ +use { + super::{ + cost_model, + curve::Curve, + multi_limb_arith::compute_is_zero, + multi_limb_ops::{ModulusParams, MultiLimbField, MultiLimbOps}, + Limbs, SCALAR_HALF_BITS, + }, + crate::{ + constraint_helpers::constrain_zero, + digits::{add_digital_decomposition, DigitalDecompositionWitnessesBuilder}, + noir_to_r1cs::NoirToR1CSCompiler, + }, + ark_ff::{AdditiveGroup, Field, PrimeField}, + provekit_common::{ + witness::{ConstantTerm, SumTerm, WitnessBuilder}, + FieldElement, + }, + std::collections::BTreeMap, +}; + +/// Compute digit widths for decomposing `total_bits` into chunks of at most +/// `max_width` bits. The last chunk may be smaller. +fn limb_widths(total_bits: usize, max_width: u32) -> Vec { + let n = (total_bits + max_width as usize - 1) / max_width as usize; + (0..n) + .map(|i| { + let remaining = total_bits - i * max_width as usize; + remaining.min(max_width as usize) + }) + .collect() +} + +/// Verifies the scalar relation: (-1)^neg1 * |s1| + (-1)^neg2 * |s2| * s ≡ 0 +/// (mod n), and enforces s2 ≠ 0. 
+pub(super) fn verify_scalar_relation( + compiler: &mut NoirToR1CSCompiler, + range_checks: &mut BTreeMap>, + s_lo: usize, + s_hi: usize, + s1_witness: usize, + s2_witness: usize, + neg1_witness: usize, + neg2_witness: usize, + curve: &C, +) { + let order_bits = curve.curve_order_bits() as usize; + let limb_bits = + cost_model::scalar_relation_limb_bits(FieldElement::MODULUS_BIT_SIZE, order_bits); + let num_limbs = (order_bits + limb_bits as usize - 1) / limb_bits as usize; + let half_bits = curve.glv_half_bits() as usize; + + let params = ModulusParams::for_curve_order(num_limbs, limb_bits, curve); + let mut ops = MultiLimbOps::::new(compiler, range_checks, ¶ms); + + let s_limbs = decompose_scalar_from_halves(&mut ops, s_lo, s_hi, num_limbs, limb_bits); + let s1_limbs = decompose_half_scalar(&mut ops, s1_witness, num_limbs, half_bits, limb_bits); + let s2_limbs = decompose_half_scalar(&mut ops, s2_witness, num_limbs, half_bits, limb_bits); + + let product = ops.mul(s2_limbs, s_limbs); + + // Sign handling: when signs match check s1+product=0, otherwise s1-product=0. + // XOR = neg1 + neg2 - 2*neg1*neg2 gives 0 for same signs, 1 for different. + let sum = ops.add(s1_limbs, product); + let diff = ops.sub(s1_limbs, product); + + let xor_prod = ops.product(neg1_witness, neg2_witness); + let xor = ops.sum(vec![ + SumTerm(None, neg1_witness), + SumTerm(None, neg2_witness), + SumTerm(Some(-FieldElement::from(2u64)), xor_prod), + ]); + + let effective = ops.select_unchecked(xor, sum, diff); + for i in 0..num_limbs { + constrain_zero(ops.compiler, effective[i]); + } + + // Soundness: s2 must be non-zero. If s2=0 the relation degenerates to + // s1≡0 (mod n) which is trivially satisfiable with s1=0, leaving the + // hint-supplied result point R unconstrained. + let s2_is_zero = compute_is_zero(ops.compiler, s2_witness); + constrain_zero(ops.compiler, s2_is_zero); +} + +/// Decompose a 256-bit scalar from two 128-bit halves into `num_limbs` limbs. 
+/// +/// When `limb_bits` divides 128 (e.g. 64), limb boundaries align with the +/// s_lo/s_hi split. Otherwise (e.g. 85-bit limbs), one limb straddles bit 128 +/// and is assembled from a partial s_lo digit and a partial s_hi digit. +/// +/// For small curves where `num_limbs * limb_bits < 256`, the digits beyond the +/// used limbs are constrained to zero. This ensures the scalar fits in the +/// representation and prevents truncation attacks. +fn decompose_scalar_from_halves( + ops: &mut MultiLimbOps<'_, '_, MultiLimbField>, + s_lo: usize, + s_hi: usize, + num_limbs: usize, + limb_bits: u32, +) -> Limbs { + let lo_tail = SCALAR_HALF_BITS % limb_bits as usize; + + if lo_tail == 0 { + let widths = limb_widths(SCALAR_HALF_BITS, limb_bits); + let dd_lo = add_digital_decomposition(ops.compiler, widths.clone(), vec![s_lo]); + let dd_hi = add_digital_decomposition(ops.compiler, widths.clone(), vec![s_hi]); + let mut limbs = Limbs::new(); + let from_lo = widths.len().min(num_limbs); + for (_, &w) in widths.iter().enumerate().take(from_lo) { + let idx = dd_lo.get_digit_witness_index(limbs.len(), 0); + limbs.push(idx); + ops.register_range_check(w as u32, idx); + } + let from_hi = (num_limbs - from_lo).min(widths.len()); + for (i, &w) in widths.iter().enumerate().take(from_hi) { + let idx = dd_hi.get_digit_witness_index(i, 0); + limbs.push(idx); + ops.register_range_check(w as u32, idx); + } + + // Constrain unused dd_lo digits to zero (small curves where num_limbs + // covers fewer than 128 bits of s_lo). + for i in from_lo..widths.len() { + constrain_zero(ops.compiler, dd_lo.get_digit_witness_index(i, 0)); + } + // Constrain unused dd_hi digits to zero (small curves where the upper + // half is partially or entirely unused). 
+ for i in from_hi..widths.len() { + constrain_zero(ops.compiler, dd_hi.get_digit_witness_index(i, 0)); + } + + limbs + } else { + // Example: 85-bit limbs, 254-bit order → + // s_lo DD [85, 43], s_hi DD [42, 86] + // L0 = s_lo[0..85), L1 = s_lo[85..128) | s_hi[0..42), L2 = s_hi[42..128) + let hi_head = limb_bits as usize - lo_tail; + let hi_rest = SCALAR_HALF_BITS - hi_head; + let lo_full = SCALAR_HALF_BITS / limb_bits as usize; + + let lo_widths = limb_widths(SCALAR_HALF_BITS, limb_bits); + let hi_widths = vec![hi_head, hi_rest]; + + let dd_lo = add_digital_decomposition(ops.compiler, lo_widths.clone(), vec![s_lo]); + let dd_hi = add_digital_decomposition(ops.compiler, hi_widths, vec![s_hi]); + let mut limbs = Limbs::new(); + + let lo_used = lo_full.min(num_limbs); + for i in 0..lo_used { + let idx = dd_lo.get_digit_witness_index(i, 0); + limbs.push(idx); + ops.register_range_check(limb_bits, idx); + } + + // Cross-boundary limb and hi_rest, only if num_limbs needs them. + let needs_cross = num_limbs > lo_full; + let needs_hi_rest = num_limbs > lo_full + 1 && hi_rest > 0; + + if needs_cross { + let shift = FieldElement::from(2u64).pow([lo_tail as u64]); + let lo_digit = dd_lo.get_digit_witness_index(lo_full, 0); + let hi_digit = dd_hi.get_digit_witness_index(0, 0); + let cross_val = ops.sum(vec![ + SumTerm(None, lo_digit), + SumTerm(Some(shift), hi_digit), + ]); + limbs.push(cross_val); + ops.register_range_check(lo_tail as u32, lo_digit); + ops.register_range_check(hi_head as u32, hi_digit); + } + + if needs_hi_rest { + let idx = dd_hi.get_digit_witness_index(1, 0); + limbs.push(idx); + ops.register_range_check(hi_rest as u32, idx); + } + + // Constrain unused digits to zero for small curves. + // dd_lo: digits beyond lo_used that aren't part of the cross-boundary. + if !needs_cross { + // The tail digit of dd_lo and all dd_hi digits are unused. 
+ for i in lo_used..lo_widths.len() { + constrain_zero(ops.compiler, dd_lo.get_digit_witness_index(i, 0)); + } + constrain_zero(ops.compiler, dd_hi.get_digit_witness_index(0, 0)); + if hi_rest > 0 { + constrain_zero(ops.compiler, dd_hi.get_digit_witness_index(1, 0)); + } + } else if !needs_hi_rest && hi_rest > 0 { + // Cross-boundary is used but hi_rest digit is unused. + constrain_zero(ops.compiler, dd_hi.get_digit_witness_index(1, 0)); + } + + limbs + } +} + +/// Decompose a half-scalar witness into `num_limbs` limbs, zero-padding the +/// upper limbs beyond `half_bits`. +fn decompose_half_scalar( + ops: &mut MultiLimbOps<'_, '_, MultiLimbField>, + witness: usize, + num_limbs: usize, + half_bits: usize, + limb_bits: u32, +) -> Limbs { + let widths = limb_widths(half_bits, limb_bits); + let dd = add_digital_decomposition(ops.compiler, widths.clone(), vec![witness]); + let mut limbs = Limbs::new(); + + for (i, &w) in widths.iter().enumerate() { + let idx = dd.get_digit_witness_index(i, 0); + limbs.push(idx); + ops.register_range_check(w as u32, idx); + } + + for _ in widths.len()..num_limbs { + let w = ops.num_witnesses(); + ops.add_witness_builder(WitnessBuilder::Constant(ConstantTerm( + w, + FieldElement::ZERO, + ))); + limbs.push(w); + constrain_zero(ops.compiler, w); + } + + limbs +} diff --git a/provekit/r1cs-compiler/src/noir_to_r1cs.rs b/provekit/r1cs-compiler/src/noir_to_r1cs.rs index 189eb4693..f9a6247a1 100644 --- a/provekit/r1cs-compiler/src/noir_to_r1cs.rs +++ b/provekit/r1cs-compiler/src/noir_to_r1cs.rs @@ -3,6 +3,7 @@ use { binops::add_combined_binop_constraints, digits::{add_digital_decomposition, DigitalDecompositionWitnessesBuilder}, memory::{add_ram_checking, add_rom_checking, MemoryBlock, MemoryOperation}, + msm::{add_msm_with_curve, MsmLimbedOutputs}, poseidon2::add_poseidon2_permutation, range_check::add_range_checks, sha256_compression::add_sha256_compression, @@ -88,12 +89,17 @@ pub struct R1CSBreakdown { pub poseidon2_constraints: usize, /// 
Witnesses from Poseidon2 permutation pub poseidon2_witnesses: usize, + + /// Constraints from multi-scalar multiplication + pub msm_constraints: usize, + /// Witnesses from multi-scalar multiplication + pub msm_witnesses: usize, } /// Compiles an ACIR circuit into an [R1CS] instance, comprising of the A, B, /// and C R1CS matrices, along with the witness vector. -pub(crate) struct NoirToR1CSCompiler { - pub(crate) r1cs: R1CS, +pub struct NoirToR1CSCompiler { + pub r1cs: R1CS, /// Indicates how to solve for each R1CS witness pub witness_builders: Vec, @@ -136,7 +142,7 @@ pub fn noir_to_r1cs_with_breakdown( } impl NoirToR1CSCompiler { - pub(crate) fn new() -> Self { + pub fn new() -> Self { let mut r1cs = R1CS::new(); // Grow the matrices to account for the constant one witness. r1cs.add_witnesses(1); @@ -457,6 +463,7 @@ impl NoirToR1CSCompiler { let mut xor_ops = vec![]; let mut sha256_compression_ops = vec![]; let mut poseidon2_ops = vec![]; + let mut msm_ops = vec![]; let mut breakdown = R1CSBreakdown::default(); @@ -627,6 +634,24 @@ impl NoirToR1CSCompiler { output_witnesses, )); } + BlackBoxFuncCall::MultiScalarMul { + points, + scalars, + outputs, + } => { + let point_wits: Vec = points + .iter() + .map(|inp| self.fetch_constant_or_r1cs_witness(inp.input())) + .collect(); + let scalar_wits: Vec = scalars + .iter() + .map(|inp| self.fetch_constant_or_r1cs_witness(inp.input())) + .collect(); + let out_x = self.fetch_r1cs_witness_index(outputs.0); + let out_y = self.fetch_r1cs_witness_index(outputs.1); + let out_inf = self.fetch_r1cs_witness_index(outputs.2); + msm_ops.push((point_wits, scalar_wits, (out_x, out_y, out_inf))); + } _ => { unimplemented!("Other black box function: {:?}", black_box_func_call); } @@ -718,6 +743,27 @@ impl NoirToR1CSCompiler { breakdown.poseidon2_constraints = self.r1cs.num_constraints() - constraints_before_poseidon; breakdown.poseidon2_witnesses = self.num_witnesses() - witnesses_before_poseidon; + let constraints_before_msm = 
self.r1cs.num_constraints(); + let witnesses_before_msm = self.num_witnesses(); + let limbed_msm_ops = msm_ops + .into_iter() + .map(|(points, scalars, (out_x, out_y, out_inf))| { + (points, scalars, MsmLimbedOutputs { + out_x_limbs: vec![out_x], + out_y_limbs: vec![out_y], + out_inf, + }) + }) + .collect(); + add_msm_with_curve( + self, + limbed_msm_ops, + &mut range_checks, + &crate::msm::curve::Grumpkin, + ); + breakdown.msm_constraints = self.r1cs.num_constraints() - constraints_before_msm; + breakdown.msm_witnesses = self.num_witnesses() - witnesses_before_msm; + breakdown.range_ops_total = range_checks.values().map(|v| v.len()).sum(); let constraints_before_range = self.r1cs.num_constraints(); let witnesses_before_range = self.num_witnesses(); diff --git a/provekit/r1cs-compiler/src/range_check.rs b/provekit/r1cs-compiler/src/range_check.rs index f76fe94c3..763576ac4 100644 --- a/provekit/r1cs-compiler/src/range_check.rs +++ b/provekit/r1cs-compiler/src/range_check.rs @@ -139,13 +139,46 @@ fn get_optimal_base_width(collected: &[RangeCheckRequest]) -> u32 { optimal_width } +/// Estimates total witness cost for resolving range checks without +/// constructing actual R1CS constraints. +/// +/// Takes a map of `bit_width → count` (number of witnesses needing that +/// range check). Uses the same optimal-base-width search and +/// LogUp-vs-naive cost model as [`add_range_checks`], but operates on +/// aggregate counts rather than concrete witness indices. +pub(crate) fn estimate_range_check_cost(checks: &BTreeMap) -> usize { + if checks.is_empty() { + return 0; + } + + // Create synthetic RangeCheckRequests with unique dummy indices. 
+ let mut collected: Vec = Vec::new(); + let mut dummy_idx = 0usize; + for (&bits, &count) in checks { + for _ in 0..count { + collected.push(RangeCheckRequest { + witness_idx: dummy_idx, + bits, + }); + dummy_idx += 1; + } + } + + if collected.is_empty() { + return 0; + } + + let base_width = get_optimal_base_width(&collected); + calculate_witness_cost(base_width, &collected) +} + /// Add witnesses and constraints that ensure that the values of the witness /// belong to a range 0..2^k (for some k). /// /// Uses dynamic base width optimization: all range check requests are /// collected, and the optimal decomposition base width is determined by /// minimizing the total witness count (memory cost). The search evaluates -/// every base width from [MIN_BASE_WIDTH] to [MAX_BASE_WIDTH]. For each +/// every base width from \[MIN_BASE_WIDTH\] to \[MAX_BASE_WIDTH\]. For each /// candidate, the cost model picks the cheaper of LogUp and naive for /// every atomic bucket. /// @@ -156,7 +189,7 @@ fn get_optimal_base_width(collected: &[RangeCheckRequest]) -> u32 { /// /// `range_checks` is a map from the number of bits k to the vector of /// witness indices that are to be constrained within the range [0..2^k]. -pub(crate) fn add_range_checks( +pub fn add_range_checks( r1cs: &mut NoirToR1CSCompiler, range_checks: BTreeMap>, ) -> Option { diff --git a/scripts/verify_offset_points.py b/scripts/verify_offset_points.py new file mode 100644 index 000000000..ffe5a3943 --- /dev/null +++ b/scripts/verify_offset_points.py @@ -0,0 +1,230 @@ +#!/usr/bin/env python3 +""" +Verify MSM offset points are on-curve and reproducible via SHA256 +try-and-increment (NUMS construction). + +Each offset point is generated as: + 1. x = SHA256(seed) interpreted as big-endian integer mod p + 2. Increment x until y² = x³ + ax + b (mod p) has a square root + 3. 
Pick the canonical (smaller) y + +Usage: python3 scripts/verify_offset_points.py +""" + +from __future__ import annotations + +import hashlib + + +def to_int(limbs): + """Convert [u64; 4] little-endian limbs to a Python int.""" + result = 0 + for i, limb in enumerate(limbs): + result |= limb << (64 * i) + return result + + +def mod_sqrt(a, p): + """Tonelli-Shanks modular square root. Returns sqrt or None.""" + if a % p == 0: + return 0 + if pow(a, (p - 1) // 2, p) != 1: + return None + + # p = 3 (mod 4) shortcut + if p % 4 == 3: + return pow(a, (p + 1) // 4, p) + + # Factor out powers of 2: p - 1 = Q * 2^S + Q, S = p - 1, 0 + while Q % 2 == 0: + Q //= 2 + S += 1 + + # Find a quadratic non-residue + z = 2 + while pow(z, (p - 1) // 2, p) != p - 1: + z += 1 + + M = S + c = pow(z, Q, p) + t = pow(a, Q, p) + R = pow(a, (Q + 1) // 2, p) + + while True: + if t == 1: + return R + i = 1 + tmp = (t * t) % p + while tmp != 1: + tmp = (tmp * tmp) % p + i += 1 + b = pow(c, 1 << (M - i - 1), p) + M = i + c = (b * b) % p + t = (t * c) % p + R = (R * b) % p + + +def try_and_increment(seed, p, a, b, max_attempts=1000): + """ + NUMS point generation via try-and-increment. + SHA256(seed) -> x candidate, increment until y^2 = x^3 + ax + b is a QR. 
+ """ + h = hashlib.sha256(seed.encode()).digest() + x = int.from_bytes(h, "big") % p + + for attempt in range(max_attempts): + rhs = (pow(x, 3, p) + a * x + b) % p + y = mod_sqrt(rhs, p) + if y is not None: + # Pick the smaller y (canonical) + if y > p - y: + y = p - y + return x, y, attempt + x = (x + 1) % p + + return None, None, max_attempts + + +# ========================================================================= +# Curve definitions (must match Rust constants in curve/grumpkin.rs and +# curve/secp256r1.rs) +# ========================================================================= + +CURVES = { + "grumpkin": { + "p": to_int( + [ + 0x43E1F593F0000001, + 0x2833E84879B97091, + 0xB85045B68181585D, + 0x30644E72E131A029, + ] + ), + "a": 0, + "b": to_int( + [ + 0x43E1F593EFFFFFF0, + 0x2833E84879B97091, + 0xB85045B68181585D, + 0x30644E72E131A029, + ] + ), + "offset_x": to_int( + [ + 0x0C7F59B08D3ED494, + 0xC9C7CC25211E2D7A, + 0x39C65342A2E5E9F2, + 0x121B63F644122C3D, + ] + ), + "offset_y": to_int( + [ + 0xDBECDEB7A68F782D, + 0x10F1F9045C0BC912, + 0x1CD40A11A67012E1, + 0x00767FCC149FC6B3, + ] + ), + "seed": "provekit-grumpkin-offset", + }, + "secp256r1": { + "p": to_int([0xFFFFFFFFFFFFFFFF, 0xFFFFFFFF, 0x0, 0xFFFFFFFF00000001]), + "a": to_int( + [ + 0xFFFFFFFFFFFFFFFC, + 0x00000000FFFFFFFF, + 0x0000000000000000, + 0xFFFFFFFF00000001, + ] + ), + "b": to_int( + [ + 0x3BCE3C3E27D2604B, + 0x651D06B0CC53B0F6, + 0xB3EBBD55769886BC, + 0x5AC635D8AA3A93E7, + ] + ), + "offset_x": to_int( + [ + 0x3B8D6E63154AC0B8, + 0x9D50C8F4C290FEB5, + 0x27080C391CED0AC0, + 0x24D812942F1C942A, + ] + ), + "offset_y": to_int( + [ + 0x1D028E001BC65CB8, + 0xC4CB905DF8BD1F90, + 0x9F519D447E4A2D9D, + 0x7C9E0B6CE248A7A0, + ] + ), + "seed": "provekit-secp256r1-offset", + }, +} + + +def verify_on_curve(name, curve): + p, a, b = curve["p"], curve["a"], curve["b"] + x, y = curve["offset_x"], curve["offset_y"] + lhs = pow(y, 2, p) + rhs = (pow(x, 3, p) + a * x % p + b) % p + ok = lhs == rhs + 
print(" on-curve: %s" % ("PASS" if ok else "FAIL")) + if not ok: + print(" y^2 mod p = %d" % lhs) + print(" x^3+ax+b mod p = %d" % rhs) + return ok + + +def verify_reproduction(name, curve): + p, a, b = curve["p"], curve["a"], curve["b"] + seed = curve["seed"] + expected_x = curve["offset_x"] + expected_y = curve["offset_y"] + + x, y, attempts = try_and_increment(seed, p, a, b) + if x is None: + print(" reproduce: FAIL (no point found in 1000 attempts)") + return False + + # Check both y and p-y (either sign is valid) + match_x = x == expected_x + match_y = y == expected_y or (p - y) == expected_y + + if match_x and match_y: + print(' reproduce: PASS (SHA256("%s") + %d increments)' % (seed, attempts)) + return True + else: + print(" reproduce: MISMATCH") + print(" expected x: 0x%064x" % expected_x) + print(" got x: 0x%064x" % x) + print(" expected y: 0x%064x" % expected_y) + print(" got y: 0x%064x" % y) + print(' (after %d increments from SHA256("%s"))' % (attempts, seed)) + return False + + +def main(): + all_ok = True + for name, curve in CURVES.items(): + print("\n%s:" % name) + on_curve = verify_on_curve(name, curve) + reproduced = verify_reproduction(name, curve) + if not on_curve or not reproduced: + all_ok = False + + print() + if all_ok: + print("All offset points verified: on-curve and reproducible from seed.") + else: + print("SOME CHECKS FAILED.") + return 0 if all_ok else 1 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/tooling/provekit-bench/Cargo.toml b/tooling/provekit-bench/Cargo.toml index 52b52b57d..12c100371 100644 --- a/tooling/provekit-bench/Cargo.toml +++ b/tooling/provekit-bench/Cargo.toml @@ -16,6 +16,7 @@ provekit-r1cs-compiler.workspace = true provekit-verifier.workspace = true # Noir language +acir.workspace = true nargo.workspace = true nargo_cli.workspace = true nargo_toml.workspace = true @@ -23,10 +24,12 @@ noirc_driver.workspace = true # 3rd party anyhow.workspace = true +ark-ff.workspace = true 
divan.workspace = true serde.workspace = true test-case.workspace = true toml.workspace = true +whir.workspace = true [lints] workspace = true diff --git a/tooling/provekit-bench/tests/compiler.rs b/tooling/provekit-bench/tests/compiler.rs index 828b84b93..cfb12b4cf 100644 --- a/tooling/provekit-bench/tests/compiler.rs +++ b/tooling/provekit-bench/tests/compiler.rs @@ -23,7 +23,7 @@ struct NargoTomlPackage { name: String, } -fn test_noir_compiler(test_case_path: impl AsRef) { +fn test_noir_compiler(test_case_path: impl AsRef, witness_file: &str) { let test_case_path = test_case_path.as_ref(); compile_workspace(test_case_path).expect("Compiling workspace"); @@ -36,7 +36,7 @@ fn test_noir_compiler(test_case_path: impl AsRef) { let package_name = nargo_toml.package.name; let circuit_path = test_case_path.join(format!("target/{package_name}.json")); - let witness_file_path = test_case_path.join("Prover.toml"); + let witness_file_path = test_case_path.join(witness_file); let schema = NoirCompiler::from_file(&circuit_path, provekit_common::HashConfig::default()) .expect("Reading proof scheme"); @@ -69,19 +69,57 @@ pub fn compile_workspace(workspace_path: impl AsRef) -> Result Ok(workspace) } -#[test_case("../../noir-examples/noir-r1cs-test-programs/acir_assert_zero")] -#[test_case("../../noir-examples/noir-r1cs-test-programs/simplest-read-only-memory")] -#[test_case("../../noir-examples/noir-r1cs-test-programs/read-only-memory")] -#[test_case("../../noir-examples/noir-r1cs-test-programs/range-check-u8")] -#[test_case("../../noir-examples/noir-r1cs-test-programs/range-check-u16")] -#[test_case("../../noir-examples/noir-r1cs-test-programs/range-check-mixed-bases")] -#[test_case("../../noir-examples/noir-r1cs-test-programs/read-write-memory")] -#[test_case("../../noir-examples/noir-r1cs-test-programs/conditional-write")] -#[test_case("../../noir-examples/noir-r1cs-test-programs/bin-opcode")] -#[test_case("../../noir-examples/noir-r1cs-test-programs/small-sha")] 
-#[test_case("../../noir-examples/noir-r1cs-test-programs/bounded-vec")] -#[test_case("../../noir-examples/noir-r1cs-test-programs/brillig-unconstrained")] -#[test_case("../../noir-examples/noir-passport-monolithic/complete_age_check"; "complete_age_check")] -fn case_noir(path: &str) { - test_noir_compiler(path); +#[test_case( + "../../noir-examples/noir-r1cs-test-programs/acir_assert_zero", + "Prover.toml" +)] +#[test_case( + "../../noir-examples/noir-r1cs-test-programs/simplest-read-only-memory", + "Prover.toml" +)] +#[test_case( + "../../noir-examples/noir-r1cs-test-programs/read-only-memory", + "Prover.toml" +)] +#[test_case( + "../../noir-examples/noir-r1cs-test-programs/range-check-u8", + "Prover.toml" +)] +#[test_case( + "../../noir-examples/noir-r1cs-test-programs/range-check-u16", + "Prover.toml" +)] +#[test_case( + "../../noir-examples/noir-r1cs-test-programs/range-check-mixed-bases", + "Prover.toml" +)] +#[test_case( + "../../noir-examples/noir-r1cs-test-programs/read-write-memory", + "Prover.toml" +)] +#[test_case( + "../../noir-examples/noir-r1cs-test-programs/conditional-write", + "Prover.toml" +)] +#[test_case( + "../../noir-examples/noir-r1cs-test-programs/bin-opcode", + "Prover.toml" +)] +#[test_case("../../noir-examples/noir-r1cs-test-programs/small-sha", "Prover.toml")] +#[test_case( + "../../noir-examples/noir-r1cs-test-programs/bounded-vec", + "Prover.toml" +)] +#[test_case( + "../../noir-examples/noir-r1cs-test-programs/brillig-unconstrained", + "Prover.toml" +)] +#[test_case("../../noir-examples/noir-passport-monolithic/complete_age_check", "Prover.toml"; "complete_age_check")] +#[test_case("../../noir-examples/embedded_curve_msm", "Prover.toml"; "embedded_curve_msm")] +#[test_case("../../noir-examples/embedded_curve_msm", "Prover_zero_scalars.toml"; "msm_zero_scalars")] +#[test_case("../../noir-examples/embedded_curve_msm", "Prover_single_nonzero.toml"; "msm_single_nonzero")] +#[test_case("../../noir-examples/embedded_curve_msm", 
"Prover_near_order.toml"; "msm_near_order")] +#[test_case("../../noir-examples/embedded_curve_msm", "Prover_near_identity.toml"; "msm_near_identity")] +fn case_noir(path: &str, witness_file: &str) { + test_noir_compiler(path, witness_file); } diff --git a/tooling/provekit-bench/tests/msm_witness_solving.rs b/tooling/provekit-bench/tests/msm_witness_solving.rs new file mode 100644 index 000000000..db9b60afa --- /dev/null +++ b/tooling/provekit-bench/tests/msm_witness_solving.rs @@ -0,0 +1,527 @@ +//! End-to-end MSM witness solving tests for non-native curves (secp256r1). +//! +//! These tests verify that the full pipeline works correctly: +//! 1. Compile MSM circuit (R1CS + witness builders) +//! 2. Set initial witness values (point coordinates as limbs + scalar) +//! 3. Solve all derived witnesses via the witness builder layer scheduler +//! 4. Check R1CS satisfaction: A·w ⊙ B·w = C·w for all constraints +//! +//! All tests use the **limbed API** (`add_msm_with_curve`) where +//! point coordinates are multi-limb witnesses, supporting arbitrary +//! secp256r1 coordinates (including those exceeding BN254 Fr). + +use { + acir::native_types::WitnessMap, + ark_ff::{PrimeField, Zero}, + provekit_common::{ + witness::{ConstantOrR1CSWitness, LayerScheduler, WitnessBuilder}, + FieldElement, NoirElement, TranscriptSponge, + }, + provekit_prover::{bigint_mod::ec_scalar_mul, r1cs::solve_witness_vec}, + provekit_r1cs_compiler::{ + msm::{ + add_msm_with_curve, + cost_model::get_optimal_msm_params, + curve::{decompose_to_limbs, Curve, Secp256r1}, + MsmLimbedOutputs, + }, + noir_to_r1cs::NoirToR1CSCompiler, + range_check::add_range_checks, + }, + std::collections::BTreeMap, + whir::transcript::{codecs::Empty, DomainSeparator, ProverState}, +}; + +// --------------------------------------------------------------------------- +// Helpers +// --------------------------------------------------------------------------- + +/// Convert a [u64; 4] to a FieldElement. 
Panics if value exceeds BN254 Fr. +/// Only used for scalars (128-bit halves that always fit). +fn u256_to_fe(v: &[u64; 4]) -> FieldElement { + FieldElement::from_bigint(ark_ff::BigInt(*v)) + .unwrap_or_else(|| panic!("Value exceeds BN254 Fr: {v:?}")) +} + +/// Split a 256-bit scalar into (lo_128, hi_128) as [u64; 4] values. +fn split_scalar(s: &[u64; 4]) -> ([u64; 4], [u64; 4]) { + let lo = [s[0], s[1], 0, 0]; + let hi = [s[2], s[3], 0, 0]; + (lo, hi) +} + +/// Verify R1CS satisfaction: for each constraint row, A·w * B·w == C·w. +fn check_r1cs_satisfaction( + r1cs: &provekit_common::R1CS, + witness: &[FieldElement], +) -> anyhow::Result<()> { + use anyhow::ensure; + + ensure!( + witness.len() == r1cs.num_witnesses(), + "witness size {} != expected {}", + witness.len(), + r1cs.num_witnesses() + ); + + let a = r1cs.a() * witness; + let b = r1cs.b() * witness; + let c = r1cs.c() * witness; + for (row, ((a_val, b_val), c_val)) in a.into_iter().zip(b).zip(c).enumerate() { + ensure!( + a_val * b_val == c_val, + "Constraint {row} failed: a={a_val:?}, b={b_val:?}, a*b={:?}, c={c_val:?}", + a_val * b_val + ); + } + Ok(()) +} + +/// Create a dummy transcript for witness solving (no challenges needed). +fn dummy_transcript() -> ProverState { + let ds = DomainSeparator::protocol(&()).instance(&Empty); + ProverState::new(&ds, TranscriptSponge::default()) +} + +/// Solve all witness builders given initial witness values. 
+fn solve_witnesses( + builders: &[WitnessBuilder], + num_witnesses: usize, + initial_values: &[(usize, FieldElement)], +) -> Vec { + let layers = LayerScheduler::new(builders).build_layers(); + let mut witness: Vec> = vec![None; num_witnesses]; + + for &(idx, val) in initial_values { + witness[idx] = Some(val); + } + + let acir_map = WitnessMap::::new(); + let mut transcript = dummy_transcript(); + solve_witness_vec(&mut witness, layers, &acir_map, &mut transcript); + + witness + .into_iter() + .enumerate() + .map(|(i, w)| w.unwrap_or_else(|| panic!("Witness {i} was not solved"))) + .collect() +} + +/// Compute the (num_limbs, limb_bits) that the compiler will use for this +/// curve, so the test can decompose coordinates the same way. +fn msm_params_for_curve(curve: &impl Curve, n_points: usize) -> (usize, u32) { + let native_bits = FieldElement::MODULUS_BIT_SIZE; + let curve_bits = curve.modulus_bits(); + let is_native = curve.is_native_field(); + let scalar_bits = curve.curve_order_bits() as usize; + let (limb_bits, _window_size, num_limbs) = + get_optimal_msm_params(native_bits, curve_bits, n_points, scalar_bits, is_native); + (num_limbs, limb_bits) +} + +/// Decompose a [u64; 4] value into field-element limbs. +fn u256_to_limb_fes(v: &[u64; 4], limb_bits: u32, num_limbs: usize) -> Vec { + decompose_to_limbs(v, limb_bits, num_limbs) +} + +// --------------------------------------------------------------------------- +// Single-point limbed MSM test runner +// --------------------------------------------------------------------------- + +/// Compile and solve a single-point MSM circuit using the limbed API. +/// +/// When `expected_inf` is true, the expected output is point at infinity +/// (all output limbs zero, out_inf = 1). 
+fn run_single_point_msm_test_limbed( + px: &[u64; 4], + py: &[u64; 4], + inf: bool, + scalar: &[u64; 4], + expected_x: &[u64; 4], + expected_y: &[u64; 4], + expected_inf: bool, +) { + let curve = Secp256r1; + let (num_limbs, limb_bits) = msm_params_for_curve(&curve, 1); + let (s_lo, s_hi) = split_scalar(scalar); + let stride = 2 * num_limbs + 1; + + let px_fes = u256_to_limb_fes(px, limb_bits, num_limbs); + let py_fes = u256_to_limb_fes(py, limb_bits, num_limbs); + let ex_fes = u256_to_limb_fes(expected_x, limb_bits, num_limbs); + let ey_fes = u256_to_limb_fes(expected_y, limb_bits, num_limbs); + + let mut compiler = NoirToR1CSCompiler::new(); + let mut range_checks: BTreeMap> = BTreeMap::new(); + + let base = compiler.num_witnesses(); + let total_input_wits = stride + 2 + stride; + compiler.r1cs.add_witnesses(total_input_wits); + + let points: Vec = (0..stride) + .map(|j| ConstantOrR1CSWitness::Witness(base + j)) + .collect(); + let slo_w = base + stride; + let shi_w = base + stride + 1; + let scalars = vec![ + ConstantOrR1CSWitness::Witness(slo_w), + ConstantOrR1CSWitness::Witness(shi_w), + ]; + + let out_base = base + stride + 2; + let out_x_limbs: Vec = (0..num_limbs).map(|j| out_base + j).collect(); + let out_y_limbs: Vec = (0..num_limbs).map(|j| out_base + num_limbs + j).collect(); + let out_inf = out_base + 2 * num_limbs; + + let outputs = MsmLimbedOutputs { + out_x_limbs: out_x_limbs.clone(), + out_y_limbs: out_y_limbs.clone(), + out_inf, + }; + let msm_ops = vec![(points, scalars, outputs)]; + add_msm_with_curve(&mut compiler, msm_ops, &mut range_checks, &curve); + add_range_checks(&mut compiler, range_checks); + + let num_witnesses = compiler.num_witnesses(); + + // Set initial witness values + let mut initial_values = vec![(0, FieldElement::from(1u64))]; + for (j, fe) in px_fes.iter().enumerate() { + initial_values.push((base + j, *fe)); + } + for (j, fe) in py_fes.iter().enumerate() { + initial_values.push((base + num_limbs + j, *fe)); + } + let inf_fe 
= if inf { + FieldElement::from(1u64) + } else { + FieldElement::zero() + }; + initial_values.push((base + 2 * num_limbs, inf_fe)); + initial_values.push((slo_w, u256_to_fe(&s_lo))); + initial_values.push((shi_w, u256_to_fe(&s_hi))); + for (j, fe) in ex_fes.iter().enumerate() { + initial_values.push((out_x_limbs[j], *fe)); + } + for (j, fe) in ey_fes.iter().enumerate() { + initial_values.push((out_y_limbs[j], *fe)); + } + let out_inf_fe = if expected_inf { + FieldElement::from(1u64) + } else { + FieldElement::zero() + }; + initial_values.push((out_inf, out_inf_fe)); + + let witness = solve_witnesses(&compiler.witness_builders, num_witnesses, &initial_values); + + check_r1cs_satisfaction(&compiler.r1cs, &witness) + .expect("R1CS satisfaction check failed (limbed)"); +} + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +/// Single-point MSM using the secp256r1 generator directly. +/// The generator's x-coordinate exceeds BN254 Fr. +#[test] +fn test_single_point_generator() { + let curve = Secp256r1; + let gx = curve.generator().0; + let gy = curve.generator().1; + let scalar: [u64; 4] = [7, 0, 0, 0]; + let (ex, ey) = ec_scalar_mul( + &gx, + &gy, + &scalar, + &curve.curve_a(), + &curve.field_modulus_p(), + ); + + run_single_point_msm_test_limbed(&gx, &gy, false, &scalar, &ex, &ey, false); +} + +/// Scalar = 1: result should equal the input point. +#[test] +fn test_scalar_one() { + let curve = Secp256r1; + let gx = curve.generator().0; + let gy = curve.generator().1; + let scalar: [u64; 4] = [1, 0, 0, 0]; + + // 1·G = G + run_single_point_msm_test_limbed(&gx, &gy, false, &scalar, &gx, &gy, false); +} + +/// Large scalar spanning both lo and hi halves of the 256-bit representation. 
+#[test] +fn test_large_scalar() { + let curve = Secp256r1; + let gx = curve.generator().0; + let gy = curve.generator().1; + let scalar: [u64; 4] = [0xcafebabe, 0x12345678, 0x42, 0]; + let (ex, ey) = ec_scalar_mul( + &gx, + &gy, + &scalar, + &curve.curve_a(), + &curve.field_modulus_p(), + ); + + run_single_point_msm_test_limbed(&gx, &gy, false, &scalar, &ex, &ey, false); +} + +/// Zero scalar: result should be point at infinity. +#[test] +fn test_zero_scalar() { + let curve = Secp256r1; + let gx = curve.generator().0; + let gy = curve.generator().1; + let zero_scalar: [u64; 4] = [0, 0, 0, 0]; + let zero_point: [u64; 4] = [0, 0, 0, 0]; + + run_single_point_msm_test_limbed( + &gx, + &gy, + false, + &zero_scalar, + &zero_point, + &zero_point, + true, + ); +} + +/// Point at infinity as input: result should be point at infinity regardless +/// of scalar. +#[test] +fn test_point_at_infinity_input() { + let curve = Secp256r1; + // Use generator coords as placeholder (they're ignored due to inf=1 select) + let gx = curve.generator().0; + let gy = curve.generator().1; + let scalar: [u64; 4] = [42, 0, 0, 0]; + let zero_point: [u64; 4] = [0, 0, 0, 0]; + + run_single_point_msm_test_limbed(&gx, &gy, true, &scalar, &zero_point, &zero_point, true); +} + +/// Non-trivial point (2·G) with a moderate scalar, verifying the full +/// wNAF + FakeGLV pipeline. +#[test] +fn test_arbitrary_point_and_scalar() { + let curve = Secp256r1; + let gx = curve.generator().0; + let gy = curve.generator().1; + let a = &curve.curve_a(); + let p = &curve.field_modulus_p(); + + // P = 2·G + let (px, py) = ec_scalar_mul(&gx, &gy, &[2, 0, 0, 0], a, p); + let scalar: [u64; 4] = [17, 0, 0, 0]; + // Expected: 17·(2G) = 34G + let (ex, ey) = ec_scalar_mul(&gx, &gy, &[34, 0, 0, 0], a, p); + + run_single_point_msm_test_limbed(&px, &py, false, &scalar, &ex, &ey, false); +} + +/// Two-point MSM: s1·P1 + s2·P2 with arbitrary coordinates. 
+#[test] +fn test_two_point_msm() { + let curve = Secp256r1; + let gx = curve.generator().0; + let gy = curve.generator().1; + let a = &curve.curve_a(); + let p = &curve.field_modulus_p(); + let (num_limbs, limb_bits) = msm_params_for_curve(&curve, 2); + let stride = 2 * num_limbs + 1; + + // P1 = 3·G, P2 = 5·G + let (p1x, p1y) = ec_scalar_mul(&gx, &gy, &[3, 0, 0, 0], a, p); + let (p2x, p2y) = ec_scalar_mul(&gx, &gy, &[5, 0, 0, 0], a, p); + let s1: [u64; 4] = [2, 0, 0, 0]; + let s2: [u64; 4] = [3, 0, 0, 0]; + // Expected: 2·(3G) + 3·(5G) = 6G + 15G = 21G + let (ex, ey) = ec_scalar_mul(&gx, &gy, &[21, 0, 0, 0], a, p); + + let (s1_lo, s1_hi) = split_scalar(&s1); + let (s2_lo, s2_hi) = split_scalar(&s2); + + let p1x_fes = u256_to_limb_fes(&p1x, limb_bits, num_limbs); + let p1y_fes = u256_to_limb_fes(&p1y, limb_bits, num_limbs); + let p2x_fes = u256_to_limb_fes(&p2x, limb_bits, num_limbs); + let p2y_fes = u256_to_limb_fes(&p2y, limb_bits, num_limbs); + let ex_fes = u256_to_limb_fes(&ex, limb_bits, num_limbs); + let ey_fes = u256_to_limb_fes(&ey, limb_bits, num_limbs); + + let mut compiler = NoirToR1CSCompiler::new(); + let mut range_checks: BTreeMap> = BTreeMap::new(); + + let base = compiler.num_witnesses(); + let total = 2 * stride + 4 + stride; + compiler.r1cs.add_witnesses(total); + + let points: Vec = (0..2 * stride) + .map(|j| ConstantOrR1CSWitness::Witness(base + j)) + .collect(); + let scalar_base = base + 2 * stride; + let scalars = vec![ + ConstantOrR1CSWitness::Witness(scalar_base), + ConstantOrR1CSWitness::Witness(scalar_base + 1), + ConstantOrR1CSWitness::Witness(scalar_base + 2), + ConstantOrR1CSWitness::Witness(scalar_base + 3), + ]; + let out_base = scalar_base + 4; + let out_x_limbs: Vec = (0..num_limbs).map(|j| out_base + j).collect(); + let out_y_limbs: Vec = (0..num_limbs).map(|j| out_base + num_limbs + j).collect(); + let out_inf = out_base + 2 * num_limbs; + + let outputs = MsmLimbedOutputs { + out_x_limbs: out_x_limbs.clone(), + out_y_limbs: 
out_y_limbs.clone(), + out_inf, + }; + let msm_ops = vec![(points, scalars, outputs)]; + add_msm_with_curve(&mut compiler, msm_ops, &mut range_checks, &curve); + add_range_checks(&mut compiler, range_checks); + + let num_witnesses = compiler.num_witnesses(); + + let mut initial_values = vec![(0, FieldElement::from(1u64))]; + for (j, fe) in p1x_fes.iter().enumerate() { + initial_values.push((base + j, *fe)); + } + for (j, fe) in p1y_fes.iter().enumerate() { + initial_values.push((base + num_limbs + j, *fe)); + } + initial_values.push((base + 2 * num_limbs, FieldElement::zero())); + let p2_base = base + stride; + for (j, fe) in p2x_fes.iter().enumerate() { + initial_values.push((p2_base + j, *fe)); + } + for (j, fe) in p2y_fes.iter().enumerate() { + initial_values.push((p2_base + num_limbs + j, *fe)); + } + initial_values.push((p2_base + 2 * num_limbs, FieldElement::zero())); + initial_values.push((scalar_base, u256_to_fe(&s1_lo))); + initial_values.push((scalar_base + 1, u256_to_fe(&s1_hi))); + initial_values.push((scalar_base + 2, u256_to_fe(&s2_lo))); + initial_values.push((scalar_base + 3, u256_to_fe(&s2_hi))); + for (j, fe) in ex_fes.iter().enumerate() { + initial_values.push((out_x_limbs[j], *fe)); + } + for (j, fe) in ey_fes.iter().enumerate() { + initial_values.push((out_y_limbs[j], *fe)); + } + initial_values.push((out_inf, FieldElement::zero())); + + let witness = solve_witnesses(&compiler.witness_builders, num_witnesses, &initial_values); + println!(">>> number of witnesses : {:?}", witness.len()); + println!( + ">>> number of constraints : {:?}", + compiler.r1cs.num_constraints() + ); + + check_r1cs_satisfaction(&compiler.r1cs, &witness) + .expect("R1CS satisfaction check failed for two-point MSM"); +} + +/// Two-point MSM where one scalar is zero — only the non-zero point +/// should contribute. 
+#[test] +fn test_two_point_one_zero_scalar() { + let curve = Secp256r1; + let gx = curve.generator().0; + let gy = curve.generator().1; + let a = &curve.curve_a(); + let p = &curve.field_modulus_p(); + let (num_limbs, limb_bits) = msm_params_for_curve(&curve, 2); + let stride = 2 * num_limbs + 1; + + // P1 = G (scalar=5), P2 = 2G (scalar=0) + let (p2x, p2y) = ec_scalar_mul(&gx, &gy, &[2, 0, 0, 0], a, p); + let s1: [u64; 4] = [5, 0, 0, 0]; + let s2: [u64; 4] = [0, 0, 0, 0]; + // Expected: 5·G + 0·(2G) = 5G + let (ex, ey) = ec_scalar_mul(&gx, &gy, &[5, 0, 0, 0], a, p); + + let (s1_lo, s1_hi) = split_scalar(&s1); + let (s2_lo, s2_hi) = split_scalar(&s2); + + let p1x_fes = u256_to_limb_fes(&gx, limb_bits, num_limbs); + let p1y_fes = u256_to_limb_fes(&gy, limb_bits, num_limbs); + let p2x_fes = u256_to_limb_fes(&p2x, limb_bits, num_limbs); + let p2y_fes = u256_to_limb_fes(&p2y, limb_bits, num_limbs); + let ex_fes = u256_to_limb_fes(&ex, limb_bits, num_limbs); + let ey_fes = u256_to_limb_fes(&ey, limb_bits, num_limbs); + + let mut compiler = NoirToR1CSCompiler::new(); + let mut range_checks: BTreeMap> = BTreeMap::new(); + + let base = compiler.num_witnesses(); + let total = 2 * stride + 4 + stride; + compiler.r1cs.add_witnesses(total); + + let points: Vec = (0..2 * stride) + .map(|j| ConstantOrR1CSWitness::Witness(base + j)) + .collect(); + let scalar_base = base + 2 * stride; + let scalars = vec![ + ConstantOrR1CSWitness::Witness(scalar_base), + ConstantOrR1CSWitness::Witness(scalar_base + 1), + ConstantOrR1CSWitness::Witness(scalar_base + 2), + ConstantOrR1CSWitness::Witness(scalar_base + 3), + ]; + let out_base = scalar_base + 4; + let out_x_limbs: Vec = (0..num_limbs).map(|j| out_base + j).collect(); + let out_y_limbs: Vec = (0..num_limbs).map(|j| out_base + num_limbs + j).collect(); + let out_inf = out_base + 2 * num_limbs; + + let outputs = MsmLimbedOutputs { + out_x_limbs: out_x_limbs.clone(), + out_y_limbs: out_y_limbs.clone(), + out_inf, + }; + let msm_ops = 
vec![(points, scalars, outputs)]; + add_msm_with_curve(&mut compiler, msm_ops, &mut range_checks, &curve); + add_range_checks(&mut compiler, range_checks); + + let num_witnesses = compiler.num_witnesses(); + + let mut initial_values = vec![(0, FieldElement::from(1u64))]; + // P1 limbs (generator) + for (j, fe) in p1x_fes.iter().enumerate() { + initial_values.push((base + j, *fe)); + } + for (j, fe) in p1y_fes.iter().enumerate() { + initial_values.push((base + num_limbs + j, *fe)); + } + initial_values.push((base + 2 * num_limbs, FieldElement::zero())); + // P2 limbs + let p2_base = base + stride; + for (j, fe) in p2x_fes.iter().enumerate() { + initial_values.push((p2_base + j, *fe)); + } + for (j, fe) in p2y_fes.iter().enumerate() { + initial_values.push((p2_base + num_limbs + j, *fe)); + } + initial_values.push((p2_base + 2 * num_limbs, FieldElement::zero())); + // Scalars + initial_values.push((scalar_base, u256_to_fe(&s1_lo))); + initial_values.push((scalar_base + 1, u256_to_fe(&s1_hi))); + initial_values.push((scalar_base + 2, u256_to_fe(&s2_lo))); + initial_values.push((scalar_base + 3, u256_to_fe(&s2_hi))); + // Expected output limbs + for (j, fe) in ex_fes.iter().enumerate() { + initial_values.push((out_x_limbs[j], *fe)); + } + for (j, fe) in ey_fes.iter().enumerate() { + initial_values.push((out_y_limbs[j], *fe)); + } + initial_values.push((out_inf, FieldElement::zero())); + + let witness = solve_witnesses(&compiler.witness_builders, num_witnesses, &initial_values); + + check_r1cs_satisfaction(&compiler.r1cs, &witness) + .expect("R1CS satisfaction check failed for two-point MSM with one zero scalar"); +}