diff --git a/Makefile b/Makefile index 29d482d947..c57050c41a 100644 --- a/Makefile +++ b/Makefile @@ -294,7 +294,7 @@ check_typos: install_typos_checker .PHONY: clippy_gpu # Run clippy lints on tfhe with "gpu" enabled clippy_gpu: install_rs_check_toolchain RUSTFLAGS="$(RUSTFLAGS)" cargo "$(CARGO_RS_CHECK_TOOLCHAIN)" clippy \ - --features=boolean,shortint,integer,internal-keycache,gpu,pbs-stats,extended-types \ + --features=boolean,shortint,integer,internal-keycache,gpu,pbs-stats,extended-types,zk-pok \ --all-targets \ -p $(TFHE_SPEC) -- --no-deps -D warnings @@ -892,7 +892,7 @@ test_high_level_api: install_rs_build_toolchain test_high_level_api_gpu: install_rs_build_toolchain install_cargo_nextest RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_BUILD_TOOLCHAIN) nextest run --cargo-profile $(CARGO_PROFILE) \ - --features=integer,internal-keycache,gpu -p $(TFHE_SPEC) \ + --features=integer,internal-keycache,gpu,zk-pok -p $(TFHE_SPEC) \ -E "test(/high_level_api::.*gpu.*/)" test_high_level_api_hpu: install_rs_build_toolchain install_cargo_nextest @@ -1066,7 +1066,7 @@ check_compile_tests: install_rs_build_toolchain .PHONY: check_compile_tests_benches_gpu # Build tests in debug without running them check_compile_tests_benches_gpu: install_rs_build_toolchain RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_BUILD_TOOLCHAIN) test --no-run \ - --features=experimental,boolean,shortint,integer,internal-keycache,gpu \ + --features=experimental,boolean,shortint,integer,internal-keycache,gpu,zk-pok \ -p $(TFHE_SPEC) mkdir -p "$(TFHECUDA_BUILD)" && \ cd "$(TFHECUDA_BUILD)" && \ diff --git a/tfhe-benchmark/benches/integer/zk_pke.rs b/tfhe-benchmark/benches/integer/zk_pke.rs index 22e119b50a..6976fc6a44 100644 --- a/tfhe-benchmark/benches/integer/zk_pke.rs +++ b/tfhe-benchmark/benches/integer/zk_pke.rs @@ -418,11 +418,11 @@ fn cpu_pke_zk_verify(c: &mut Criterion, results_file: &Path) { #[cfg(all(feature = "gpu", feature = "zk-pok"))] mod cuda { use super::*; - use benchmark::utilities::{cuda_local_keys, cuda_local_streams}; + use benchmark::utilities::cuda_local_streams; use criterion::BatchSize; use itertools::Itertools; use tfhe::core_crypto::gpu::{get_number_of_gpus, CudaStreams}; - use tfhe::integer::gpu::key_switching_key::CudaKeySwitchingKey; + use tfhe::integer::gpu::key_switching_key::{CudaKeySwitchingKey, CudaKeySwitchingKeyMaterial}; use tfhe::integer::gpu::zk::CudaProvenCompactCiphertextList; use tfhe::integer::gpu::CudaServerKey; use tfhe::integer::CompressedServerKey; @@ -451,14 +451,17 @@ mod cuda { let param_name = param_name.as_str(); let cks = ClientKey::new(param_fhe); let compressed_server_key = CompressedServerKey::new_radix_compressed_server_key(&cks); + let sk = compressed_server_key.decompress(); let gpu_sks = CudaServerKey::decompress_from_cpu(&compressed_server_key, &streams); + let compact_private_key = CompactPrivateKey::new(param_pke); let pk = CompactPublicKey::new(&compact_private_key); - let d_ksk = CudaKeySwitchingKey::new( - (&compact_private_key, None), - (&cks, &gpu_sks), - param_ksk, - &streams, + let ksk = KeySwitchingKey::new((&compact_private_key, None), (&cks, &sk), param_ksk); + let d_ksk_material = + CudaKeySwitchingKeyMaterial::from_key_switching_key(&ksk, &streams); + let d_ksk = CudaKeySwitchingKey::from_cuda_key_switching_key_material( + &d_ksk_material, + &gpu_sks, ); // We have a use case with 320 bits of metadata @@ -609,7 +612,6 @@ mod cuda { }); } BenchmarkType::Throughput => { - let gpu_sks_vec = cuda_local_keys(&cks); let gpu_count = get_number_of_gpus() as usize; let elements = zk_throughput_num_elements(); @@ -637,20 +639,17 @@ mod cuda { .collect::>(); let local_streams = cuda_local_streams(num_block, elements as usize); - let d_ksk_vec = gpu_sks_vec + let d_ksk_material_vec = local_streams .par_iter() - .zip(local_streams.par_iter()) - .map(|(gpu_sks, local_stream)| { - CudaKeySwitchingKey::new( - (&compact_private_key, None), - (&cks, gpu_sks), - param_ksk, + .map(|local_stream| { + CudaKeySwitchingKeyMaterial::from_key_switching_key( + &ksk, local_stream, ) }) .collect::>(); - assert_eq!(d_ksk_vec.len(), gpu_count); + assert_eq!(d_ksk_material_vec.len(), gpu_count); bench_group.bench_function(&bench_id_verify, |b| { b.iter(|| { @@ -673,14 +672,16 @@ mod cuda { (gpu_cts, local_streams) }; - b.iter_batched(setup_encrypted_values, |(gpu_cts, local_streams)| { - gpu_cts.par_iter() - .zip(local_streams.par_iter()) - .enumerate() - .for_each(|(i, (gpu_ct, local_stream))| { - gpu_ct - .expand_without_verification(&d_ksk_vec[i % gpu_count], local_stream) - .unwrap(); + b.iter_batched(setup_encrypted_values, + |(gpu_cts, local_streams)| { + gpu_cts.par_iter().zip(local_streams.par_iter()).enumerate().for_each + (|(i, (gpu_ct, local_stream))| { + let d_ksk = + CudaKeySwitchingKey::from_cuda_key_switching_key_material(&d_ksk_material_vec[i % gpu_count], &gpu_sks); + + gpu_ct + .expand_without_verification(&d_ksk, local_stream) + .unwrap(); }); }, BatchSize::SmallInput); }); @@ -698,16 +699,15 @@ mod cuda { (gpu_cts, local_streams) }; - b.iter_batched(setup_encrypted_values, |(gpu_cts, local_streams)| { - gpu_cts - .par_iter() - .zip(local_streams.par_iter()) - .for_each(|(gpu_ct, local_stream)| { - gpu_ct - .verify_and_expand( - &crs, &pk, &metadata, &d_ksk, local_stream - ) - .unwrap(); + b.iter_batched(setup_encrypted_values, + |(gpu_cts, local_streams)| { + gpu_cts.par_iter().zip(local_streams.par_iter()).for_each + (|(gpu_ct, local_stream)| { + gpu_ct + .verify_and_expand( + &crs, &pk, &metadata, &d_ksk, local_stream, + ) + .unwrap(); }); }, BatchSize::SmallInput); }); diff --git a/tfhe/docs/configuration/gpu_acceleration/zk-pok.md b/tfhe/docs/configuration/gpu_acceleration/zk-pok.md new file mode 100644 index 0000000000..8ef6158d25 --- /dev/null +++ b/tfhe/docs/configuration/gpu_acceleration/zk-pok.md @@ -0,0 +1,82 @@ +# Zero-knowledge proofs + +Zero-knowledge proofs (ZK) are a powerful tool to assert that the encryption of a message is correct, as discussed in [advanced features](../../fhe-computation/advanced-features/zk-pok.md). +However, computation is not possible on the type of ciphertexts it produces (i.e. `ProvenCompactCiphertextList`). This document explains how to use the GPU to accelerate the +preprocessing step needed to convert ciphertexts formatted for ZK to ciphertexts in the right format for computation purposes on GPU. This +operation is called "expansion". + +## Proven compact ciphertext list + +A proven compact list of ciphertexts can be seen as a compacted collection of ciphertexts for which encryption can be verified. +This verification is currently only supported on the CPU, but the expansion can be accelerated using the GPU. +This way, verification and expansion can be performed in parallel, efficiently using all the available computational resources. + +## Supported types +Encrypted messages can be integers (like FheUint64) or booleans. The GPU backend does not currently support encrypted strings. + +{% hint style="info" %} +You can enable this feature using the flag: `--features=zk-pok,gpu` when building **TFHE-rs**. +{% endhint %} + + +## Example + +The following example shows how a client can encrypt and prove a ciphertext, and how a server can verify the proof, preprocess the ciphertext and run a computation on it on GPU: + +```rust +use rand::random; +use tfhe::CompressedServerKey; +use tfhe::prelude::*; +use tfhe::set_server_key; +use tfhe::zk::{CompactPkeCrs, ZkComputeLoad}; + +pub fn main() -> Result<(), Box> { + let params = tfhe::shortint::parameters::PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128; + // Indicate which parameters to use for the Compact Public Key encryption + let cpk_params = tfhe::shortint::parameters::PARAM_PKE_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128; + // And parameters allowing to keyswitch/cast to the computation parameters. + let casting_params = tfhe::shortint::parameters::PARAM_KEYSWITCH_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128; + // Enable the dedicated parameters on the config + let config = tfhe::ConfigBuilder::with_custom_parameters(params) + .use_dedicated_compact_public_key_parameters((cpk_params, casting_params)).build(); + + // The CRS should be generated in an offline phase then shared to all clients and the server + let crs = CompactPkeCrs::from_config(config, 64).unwrap(); + + // Then use TFHE-rs as usual + let client_key = tfhe::ClientKey::generate(config); + let compressed_server_key = CompressedServerKey::new(&client_key); + let gpu_server_key = compressed_server_key.decompress_to_gpu(); + + let public_key = tfhe::CompactPublicKey::try_new(&client_key).unwrap(); + // This can be left empty, but if provided allows to tie the proof to arbitrary data + let metadata = [b'T', b'F', b'H', b'E', b'-', b'r', b's']; + + let clear_a = random::(); + let clear_b = random::(); + + let proven_compact_list = tfhe::ProvenCompactCiphertextList::builder(&public_key) + .push(clear_a) + .push(clear_b) + .build_with_proof_packed(&crs, &metadata, ZkComputeLoad::Verify)?; + + // Server side + let result = { + set_server_key(gpu_server_key); + + // Verify the ciphertexts + let expander = + proven_compact_list.verify_and_expand(&crs, &public_key, &metadata)?; + let a: tfhe::FheUint64 = expander.get(0)?.unwrap(); + let b: tfhe::FheUint64 = expander.get(1)?.unwrap(); + + a + b + }; + + // Back on the client side + let a_plus_b: u64 = result.decrypt(&client_key); + assert_eq!(a_plus_b, clear_a.wrapping_add(clear_b)); + + Ok(()) +} +``` diff --git a/tfhe/src/core_crypto/gpu/entities/lwe_ciphertext_list.rs b/tfhe/src/core_crypto/gpu/entities/lwe_ciphertext_list.rs index b40760d15a..7736bef08e 100644 --- a/tfhe/src/core_crypto/gpu/entities/lwe_ciphertext_list.rs +++ b/tfhe/src/core_crypto/gpu/entities/lwe_ciphertext_list.rs @@ -7,7 +7,7 @@ use crate::core_crypto::prelude::{ use tfhe_cuda_backend::cuda_bind::cuda_memcpy_async_gpu_to_gpu; /// A structure representing a vector of LWE ciphertexts with 64 bits of precision on the GPU. -#[derive(Debug)] +#[derive(Clone, Debug)] pub struct CudaLweCiphertextList(pub(crate) CudaLweList); #[allow(dead_code)] diff --git a/tfhe/src/core_crypto/gpu/entities/lwe_compact_ciphertext_list.rs b/tfhe/src/core_crypto/gpu/entities/lwe_compact_ciphertext_list.rs index a52fa779b2..c369cb3c30 100644 --- a/tfhe/src/core_crypto/gpu/entities/lwe_compact_ciphertext_list.rs +++ b/tfhe/src/core_crypto/gpu/entities/lwe_compact_ciphertext_list.rs @@ -10,6 +10,10 @@ use crate::core_crypto::prelude::{ pub struct CudaLweCompactCiphertextList(pub CudaLweList); impl CudaLweCompactCiphertextList { + pub fn duplicate(&self, streams: &CudaStreams) -> Self { + Self(self.0.duplicate(streams)) + } + pub fn from_lwe_compact_ciphertext_list>( h_ct: &LweCompactCiphertextList, streams: &CudaStreams, diff --git a/tfhe/src/core_crypto/gpu/entities/lwe_keyswitch_key.rs b/tfhe/src/core_crypto/gpu/entities/lwe_keyswitch_key.rs index fd37d09a03..1fc9ddb4a2 100644 --- a/tfhe/src/core_crypto/gpu/entities/lwe_keyswitch_key.rs +++ b/tfhe/src/core_crypto/gpu/entities/lwe_keyswitch_key.rs @@ -10,6 +10,7 @@ use crate::core_crypto::prelude::{ UnsignedInteger, }; +#[derive(Clone)] #[allow(dead_code)] pub struct CudaLweKeyswitchKey { pub(crate) d_vec: CudaVec, diff --git a/tfhe/src/core_crypto/gpu/mod.rs b/tfhe/src/core_crypto/gpu/mod.rs index 100e72b978..ad0669cd0c 100644 --- a/tfhe/src/core_crypto/gpu/mod.rs +++ b/tfhe/src/core_crypto/gpu/mod.rs @@ -993,7 +993,7 @@ pub unsafe fn fourier_transform_backward_as_torus_f128_async ); } -#[derive(Debug)] +#[derive(Clone, Debug)] pub struct CudaLweList { // Pointer to GPU data pub d_vec: CudaVec, @@ -1005,6 +1005,17 @@ pub struct CudaLweList { pub ciphertext_modulus: CiphertextModulus, } +impl CudaLweList { + pub fn duplicate(&self, streams: &CudaStreams) -> Self { + Self { + d_vec: self.d_vec.duplicate(streams), + lwe_ciphertext_count: self.lwe_ciphertext_count, + lwe_dimension: self.lwe_dimension, + ciphertext_modulus: self.ciphertext_modulus, + } + } +} + #[derive(Debug, Clone)] pub struct CudaGlweList { // Pointer to GPU data diff --git a/tfhe/src/high_level_api/backward_compatibility/compact_list.rs b/tfhe/src/high_level_api/backward_compatibility/compact_list.rs index cc38b5a16a..b34f5ba32c 100644 --- a/tfhe/src/high_level_api/backward_compatibility/compact_list.rs +++ b/tfhe/src/high_level_api/backward_compatibility/compact_list.rs @@ -3,6 +3,9 @@ use tfhe_versionable::{Upgrade, Version, VersionsDispatch}; use crate::{CompactCiphertextList, Tag}; +#[cfg(feature = "zk-pok")] +use crate::ProvenCompactCiphertextList; + #[derive(Version)] pub struct CompactCiphertextListV0(crate::integer::ciphertext::CompactCiphertextList); @@ -17,9 +20,6 @@ impl Upgrade for CompactCiphertextListV0 { } } -#[cfg(feature = "zk-pok")] -use crate::ProvenCompactCiphertextList; - #[derive(VersionsDispatch)] pub enum CompactCiphertextListVersions { V0(CompactCiphertextListV0), diff --git a/tfhe/src/high_level_api/compact_list.rs b/tfhe/src/high_level_api/compact_list.rs index 6415b4a9c6..5516668d80 100644 --- a/tfhe/src/high_level_api/compact_list.rs +++ b/tfhe/src/high_level_api/compact_list.rs @@ -1,8 +1,4 @@ -use tfhe_versionable::Versionize; - use crate::backward_compatibility::compact_list::CompactCiphertextListVersions; -#[cfg(feature = "zk-pok")] -use crate::backward_compatibility::compact_list::ProvenCompactCiphertextListVersions; use crate::conformance::ParameterSetConformant; use crate::core_crypto::commons::math::random::{Deserialize, Serialize}; use crate::core_crypto::prelude::Numeric; @@ -10,7 +6,7 @@ use crate::high_level_api::global_state; use crate::high_level_api::keys::InternalServerKey; use crate::high_level_api::traits::Tagged; use crate::integer::block_decomposition::DecomposableInto; -use crate::integer::ciphertext::{Compactable, DataKind, Expandable}; +use crate::integer::ciphertext::{Compactable, DataKind}; use crate::integer::encryption::KnowsMessageModulus; use crate::integer::parameters::{ CompactCiphertextListConformanceParams, IntegerCompactCiphertextListExpansionMode, @@ -18,9 +14,14 @@ use crate::integer::parameters::{ use crate::named::Named; use crate::prelude::CiphertextList; use crate::shortint::MessageModulus; +use crate::HlExpandable; +use tfhe_versionable::Versionize; #[cfg(feature = "zk-pok")] pub use zk::ProvenCompactCiphertextList; +#[cfg(feature = "gpu")] +use crate::high_level_api::global_state::with_thread_local_cuda_streams; + #[cfg(feature = "zk-pok")] use crate::zk::{CompactPkeCrs, ZkComputeLoad}; use crate::{CompactPublicKey, Tag}; @@ -173,7 +174,7 @@ impl CompactCiphertextList { self.inner .expand(sks.integer_compact_ciphertext_list_expansion_mode()) .map(|inner| CompactCiphertextListExpander { - inner, + inner: InnerCompactCiphertextListExpander::Cpu(inner), tag: self.tag.clone(), }) } @@ -183,9 +184,11 @@ impl CompactCiphertextList { if !self.inner.is_packed() && !self.inner.needs_casting() { // No ServerKey required, short-circuit to avoid the global state call return Ok(CompactCiphertextListExpander { - inner: self - .inner - .expand(IntegerCompactCiphertextListExpansionMode::NoCastingAndNoUnpacking)?, + inner: InnerCompactCiphertextListExpander::Cpu( + self.inner.expand( + IntegerCompactCiphertextListExpansionMode::NoCastingAndNoUnpacking, + )?, + ), tag: self.tag.clone(), }); } @@ -196,7 +199,7 @@ impl CompactCiphertextList { .inner .expand(cpu_key.integer_compact_ciphertext_list_expansion_mode()) .map(|inner| CompactCiphertextListExpander { - inner, + inner: InnerCompactCiphertextListExpander::Cpu(inner), tag: self.tag.clone(), }), #[cfg(any(feature = "gpu", feature = "hpu"))] @@ -228,17 +231,134 @@ impl ParameterSetConformant for CompactCiphertextList { #[cfg(feature = "zk-pok")] mod zk { use super::*; + use crate::backward_compatibility::compact_list::ProvenCompactCiphertextListVersions; use crate::conformance::ParameterSetConformant; + use crate::high_level_api::global_state::device_of_internal_keys; use crate::integer::ciphertext::IntegerProvenCompactCiphertextListConformanceParams; - use crate::zk::CompactPkeCrs; + #[cfg(feature = "gpu")] + use crate::integer::gpu::key_switching_key::CudaKeySwitchingKey; + #[cfg(feature = "gpu")] + use crate::integer::gpu::zk::CudaProvenCompactCiphertextList; + use serde::Serializer; + + pub enum InnerProvenCompactCiphertextList { + Cpu(crate::integer::ciphertext::ProvenCompactCiphertextList), + #[cfg(feature = "gpu")] + Cuda(crate::integer::gpu::zk::CudaProvenCompactCiphertextList), + } + + impl Clone for InnerProvenCompactCiphertextList { + fn clone(&self) -> Self { + match self { + Self::Cpu(inner) => Self::Cpu(inner.clone()), + #[cfg(feature = "gpu")] + Self::Cuda(inner) => { + with_thread_local_cuda_streams(|streams| Self::Cuda(inner.duplicate(streams))) + } + } + } + } #[derive(Clone, Serialize, Deserialize, Versionize)] #[versionize(ProvenCompactCiphertextListVersions)] pub struct ProvenCompactCiphertextList { - pub(crate) inner: crate::integer::ciphertext::ProvenCompactCiphertextList, + pub(crate) inner: InnerProvenCompactCiphertextList, pub(crate) tag: Tag, } + impl InnerProvenCompactCiphertextList { + pub(crate) fn on_cpu(&self) -> &crate::integer::ciphertext::ProvenCompactCiphertextList { + match self { + Self::Cpu(inner) => inner, + #[cfg(feature = "gpu")] + Self::Cuda(inner) => &inner.h_proved_lists, + } + } + + fn move_to_device(&mut self, device: crate::Device) { + let new_value = match (&self, device) { + (Self::Cpu(_), crate::Device::Cpu) => None, + #[cfg(feature = "gpu")] + (Self::Cuda(cuda_ct), crate::Device::CudaGpu) => { + with_thread_local_cuda_streams(|streams| { + if cuda_ct.gpu_indexes() == streams.gpu_indexes() { + None + } else { + Some(Self::Cuda(cuda_ct.duplicate(streams))) + } + }) + } + #[cfg(feature = "gpu")] + (Self::Cuda(cuda_ct), crate::Device::Cpu) => { + let cpu_ct = cuda_ct.h_proved_lists.clone(); + Some(Self::Cpu(cpu_ct)) + } + #[cfg(feature = "gpu")] + (Self::Cpu(cpu_ct), crate::Device::CudaGpu) => { + let cuda_ct = with_thread_local_cuda_streams(|streams| { + CudaProvenCompactCiphertextList::from_proven_compact_ciphertext_list( + cpu_ct, streams, + ) + }); + Some(Self::Cuda(cuda_ct)) + } + }; + + if let Some(v) = new_value { + *self = v; + } + } + } + + impl serde::Serialize for InnerProvenCompactCiphertextList { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + self.on_cpu().serialize(serializer) + } + } + + impl<'de> serde::Deserialize<'de> for InnerProvenCompactCiphertextList { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + let mut new = + crate::integer::ciphertext::ProvenCompactCiphertextList::deserialize(deserializer) + .map(Self::Cpu)?; + + if let Some(device) = device_of_internal_keys() { + new.move_to_device(device); + } + + Ok(new) + } + } + use tfhe_versionable::{Unversionize, UnversionizeError, VersionizeOwned}; + impl Versionize for InnerProvenCompactCiphertextList { + type Versioned<'vers> = + ::VersionedOwned; + fn versionize(&self) -> Self::Versioned<'_> { + self.on_cpu().clone().versionize_owned() + } + } + impl VersionizeOwned for InnerProvenCompactCiphertextList { + type VersionedOwned = + ::VersionedOwned; + fn versionize_owned(self) -> Self::VersionedOwned { + self.on_cpu().clone().versionize_owned() + } + } + + impl Unversionize for InnerProvenCompactCiphertextList { + fn unversionize(versioned: Self::VersionedOwned) -> Result { + Ok(Self::Cpu( + crate::integer::ciphertext::ProvenCompactCiphertextList::unversionize(versioned)?, + )) + } + } + impl Tagged for ProvenCompactCiphertextList { fn tag(&self) -> &Tag { &self.tag @@ -258,7 +378,7 @@ mod zk { } pub fn len(&self) -> usize { - self.inner.len() + self.inner.on_cpu().len() } pub fn is_empty(&self) -> bool { @@ -266,8 +386,9 @@ mod zk { } pub fn get_kind_of(&self, index: usize) -> Option { - self.inner.get_kind_of(index).and_then(|data_kind| { - crate::FheTypes::from_data_kind(data_kind, self.inner.ct_list.message_modulus()) + let inner_cpu = self.inner.on_cpu(); + inner_cpu.get_kind_of(index).and_then(|data_kind| { + crate::FheTypes::from_data_kind(data_kind, inner_cpu.ct_list.message_modulus()) }) } @@ -277,7 +398,7 @@ mod zk { pk: &CompactPublicKey, metadata: &[u8], ) -> crate::zk::ZkVerificationOutcome { - self.inner.verify(crs, &pk.key.key, metadata) + self.inner.on_cpu().verify(crs, &pk.key.key, metadata) } pub fn verify_and_expand( @@ -286,36 +407,112 @@ mod zk { pk: &CompactPublicKey, metadata: &[u8], ) -> crate::Result { - // For WASM - if !self.inner.is_packed() && !self.inner.needs_casting() { - // No ServerKey required, short circuit to avoid the global state call - return Ok(CompactCiphertextListExpander { - inner: self.inner.verify_and_expand( + #[allow(irrefutable_let_patterns)] + if let InnerProvenCompactCiphertextList::Cpu(inner) = &self.inner { + // For WASM + if !inner.is_packed() && !inner.needs_casting() { + let expander = inner.verify_and_expand( crs, &pk.key.key, metadata, IntegerCompactCiphertextListExpansionMode::NoCastingAndNoUnpacking, - )?, - tag: self.tag.clone(), - }); + )?; + // No ServerKey required, short circuit to avoid the global state call + return Ok(CompactCiphertextListExpander { + inner: InnerCompactCiphertextListExpander::Cpu(expander), + tag: self.tag.clone(), + }); + } } global_state::try_with_internal_keys(|maybe_keys| match maybe_keys { None => Err(crate::high_level_api::errors::UninitializedServerKey.into()), - Some(InternalServerKey::Cpu(cpu_key)) => self - .inner - .verify_and_expand( - crs, - &pk.key.key, - metadata, - cpu_key.integer_compact_ciphertext_list_expansion_mode(), - ) - .map(|expander| CompactCiphertextListExpander { - inner: expander, - tag: self.tag.clone(), - }), - #[cfg(any(feature = "gpu", feature = "hpu"))] - Some(_) => Err(crate::Error::new("Expected a CPU server key".to_string())), + Some(InternalServerKey::Cpu(cpu_key)) => match &self.inner { + InnerProvenCompactCiphertextList::Cpu(inner) => inner + .verify_and_expand( + crs, + &pk.key.key, + metadata, + cpu_key.integer_compact_ciphertext_list_expansion_mode(), + ) + .map(|expander| CompactCiphertextListExpander { + inner: InnerCompactCiphertextListExpander::Cpu(expander), + tag: self.tag.clone(), + }), + #[cfg(feature = "gpu")] + InnerProvenCompactCiphertextList::Cuda(inner) => inner + .h_proved_lists + .verify_and_expand( + crs, + &pk.key.key, + metadata, + cpu_key.integer_compact_ciphertext_list_expansion_mode(), + ) + .map(|expander| CompactCiphertextListExpander { + inner: InnerCompactCiphertextListExpander::Cpu(expander), + tag: self.tag.clone(), + }), + }, + #[cfg(feature = "gpu")] + Some(InternalServerKey::Cuda(gpu_key)) => match &self.inner { + InnerProvenCompactCiphertextList::Cuda(inner) => { + with_thread_local_cuda_streams(|streams| { + let ksk = CudaKeySwitchingKey { + key_switching_key_material: gpu_key + .key + .cpk_key_switching_key_material + .as_ref() + .unwrap(), + dest_server_key: &gpu_key.key.key, + }; + let expander = inner.verify_and_expand( + crs, + &pk.key.key, + metadata, + &ksk, + streams, + )?; + + Ok(CompactCiphertextListExpander { + inner: InnerCompactCiphertextListExpander::Cuda(expander), + tag: self.tag.clone(), + }) + }) + } + InnerProvenCompactCiphertextList::Cpu(cpu_inner) => { + with_thread_local_cuda_streams(|streams| { + let gpu_proven_ct = CudaProvenCompactCiphertextList::from_proven_compact_ciphertext_list( + cpu_inner, streams, + ); + with_thread_local_cuda_streams(|streams| { + let ksk = CudaKeySwitchingKey { + key_switching_key_material: gpu_key + .key + .cpk_key_switching_key_material + .as_ref() + .unwrap(), + dest_server_key: &gpu_key.key.key, + }; + let expander = gpu_proven_ct.verify_and_expand( + crs, + &pk.key.key, + metadata, + &ksk, + streams, + )?; + + Ok(CompactCiphertextListExpander { + inner: InnerCompactCiphertextListExpander::Cuda(expander), + tag: self.tag.clone(), + }) + }) + }) + } + }, + #[cfg(feature = "hpu")] + Some(InternalServerKey::Hpu(_)) => Err(crate::error!( + "Hpu does not support ProvenCompactCiphertextList" + )), }) } @@ -324,30 +521,84 @@ mod zk { /// /// If you are here you were probably looking for it: use at your own risks. pub fn expand_without_verification(&self) -> crate::Result { - // For WASM - if !self.inner.is_packed() && !self.inner.needs_casting() { - // No ServerKey required, short circuit to avoid the global state call - return Ok(CompactCiphertextListExpander { - inner: self.inner.expand_without_verification( - IntegerCompactCiphertextListExpansionMode::NoCastingAndNoUnpacking, - )?, - tag: self.tag.clone(), - }); + #[allow(irrefutable_let_patterns)] + if let InnerProvenCompactCiphertextList::Cpu(inner) = &self.inner { + // For WASM + if !inner.is_packed() && !inner.needs_casting() { + // No ServerKey required, short circuit to avoid the global state call + return Ok(CompactCiphertextListExpander { + inner: InnerCompactCiphertextListExpander::Cpu( + inner.expand_without_verification( + IntegerCompactCiphertextListExpansionMode::NoCastingAndNoUnpacking, + )?, + ), + tag: self.tag.clone(), + }); + } } - global_state::try_with_internal_keys(|maybe_keys| match maybe_keys { + global_state::try_with_internal_keys(|maybe_keys| { + match maybe_keys { None => Err(crate::high_level_api::errors::UninitializedServerKey.into()), - Some(InternalServerKey::Cpu(cpu_key)) => self - .inner - .expand_without_verification( - cpu_key.integer_compact_ciphertext_list_expansion_mode(), - ) - .map(|expander| CompactCiphertextListExpander { - inner: expander, - tag: self.tag.clone(), - }), - #[cfg(any(feature = "gpu", feature = "hpu"))] - Some(_) => Err(crate::Error::new("Expected a CPU server key".to_string())), + Some(InternalServerKey::Cpu(cpu_key)) => match &self.inner { + InnerProvenCompactCiphertextList::Cpu(inner) => inner + .expand_without_verification( + cpu_key.integer_compact_ciphertext_list_expansion_mode(), + ) + .map(|expander| CompactCiphertextListExpander { + inner: InnerCompactCiphertextListExpander::Cpu(expander), + tag: self.tag.clone(), + }), + #[cfg(feature = "gpu")] + InnerProvenCompactCiphertextList::Cuda(_) => { + Err(crate::Error::new("Tried expanding a ProvenCompactCiphertextList on the GPU, but the set ServerKey is a ServerKey".to_string())) + } + }, + #[cfg(feature = "gpu")] + Some(InternalServerKey::Cuda(gpu_key)) => match &self.inner { + InnerProvenCompactCiphertextList::Cuda(inner) => { + with_thread_local_cuda_streams(|streams| { + let ksk = CudaKeySwitchingKey { + key_switching_key_material: gpu_key + .key + .cpk_key_switching_key_material + .as_ref() + .unwrap(), + dest_server_key: &gpu_key.key.key, + }; + let expander = inner.expand_without_verification(&ksk, streams)?; + + Ok(CompactCiphertextListExpander { + inner: InnerCompactCiphertextListExpander::Cuda(expander), + tag: self.tag.clone(), + }) + }) + } + InnerProvenCompactCiphertextList::Cpu(inner) => { + with_thread_local_cuda_streams(|streams| { + let gpu_proven_ct = CudaProvenCompactCiphertextList::from_proven_compact_ciphertext_list( + inner, streams, + ); + let ksk = CudaKeySwitchingKey { + key_switching_key_material: gpu_key + .key + .cpk_key_switching_key_material + .as_ref() + .unwrap(), + dest_server_key: &gpu_key.key.key, + }; + let expander = gpu_proven_ct.expand_without_verification(&ksk, streams)?; + + Ok(CompactCiphertextListExpander { + inner: InnerCompactCiphertextListExpander::Cuda(expander), + tag: self.tag.clone(), + }) + }) + } + }, + #[cfg(feature = "hpu")] + Some(InternalServerKey::Hpu(_)) => Err(crate::error!("Hpu does not support ProvenCompactCiphertextList")), + } }) } } @@ -356,9 +607,7 @@ mod zk { type ParameterSet = IntegerProvenCompactCiphertextListConformanceParams; fn is_conformant(&self, parameter_set: &Self::ParameterSet) -> bool { - let Self { inner, tag: _ } = self; - - inner.is_conformant(parameter_set) + self.inner.on_cpu().is_conformant(parameter_set) } } @@ -367,7 +616,7 @@ mod zk { use super::*; use crate::integer::ciphertext::IntegerProvenCompactCiphertextListConformanceParams; use crate::shortint::parameters::*; - use crate::zk::CompactPkeCrs; + use rand::{thread_rng, Rng}; #[test] @@ -409,14 +658,24 @@ mod zk { } } +pub enum InnerCompactCiphertextListExpander { + Cpu(crate::integer::ciphertext::CompactCiphertextListExpander), + #[cfg(feature = "gpu")] + Cuda(crate::integer::gpu::ciphertext::compact_list::CudaCompactCiphertextListExpander), +} + pub struct CompactCiphertextListExpander { - pub(in crate::high_level_api) inner: crate::integer::ciphertext::CompactCiphertextListExpander, + pub inner: InnerCompactCiphertextListExpander, tag: Tag, } impl CiphertextList for CompactCiphertextListExpander { fn len(&self) -> usize { - self.inner.len() + match &self.inner { + InnerCompactCiphertextListExpander::Cpu(inner) => inner.len(), + #[cfg(feature = "gpu")] + InnerCompactCiphertextListExpander::Cuda(inner) => inner.len(), + } } fn is_empty(&self) -> bool { @@ -424,16 +683,33 @@ impl CiphertextList for CompactCiphertextListExpander { } fn get_kind_of(&self, index: usize) -> Option { - self.inner.get_kind_of(index).and_then(|data_kind| { - crate::FheTypes::from_data_kind(data_kind, self.inner.message_modulus()) - }) + match &self.inner { + InnerCompactCiphertextListExpander::Cpu(inner) => { + inner.get_kind_of(index).and_then(|data_kind| { + crate::FheTypes::from_data_kind(data_kind, inner.message_modulus()) + }) + } + #[cfg(feature = "gpu")] + InnerCompactCiphertextListExpander::Cuda(inner) => { + inner.get_kind_of(index).and_then(|data_kind| { + crate::FheTypes::from_data_kind(data_kind, inner.message_modulus(index)?) + }) + } + } } fn get(&self, index: usize) -> crate::Result> where - T: Expandable + Tagged, + T: HlExpandable + Tagged, { - let mut expanded = self.inner.get::(index); + let mut expanded = match &self.inner { + InnerCompactCiphertextListExpander::Cpu(inner) => inner.get::(index), + #[cfg(feature = "gpu")] + InnerCompactCiphertextListExpander::Cuda(inner) => { + with_thread_local_cuda_streams(|streams| inner.get::(index, streams)) + } + }; + if let Ok(Some(inner)) = &mut expanded { inner.tag_mut().set_data(self.tag.data()); } @@ -531,7 +807,6 @@ impl CompactCiphertextListBuilder { }) .expect("Internal error, invalid parameters should not have been allowed") } - #[cfg(feature = "zk-pok")] pub fn build_with_proof_packed( &self, @@ -542,7 +817,10 @@ impl CompactCiphertextListBuilder { self.inner .build_with_proof_packed(crs, metadata, compute_load) .map(|proved_list| ProvenCompactCiphertextList { - inner: proved_list, + inner: + crate::high_level_api::compact_list::zk::InnerProvenCompactCiphertextList::Cpu( + proved_list, + ), tag: self.tag.clone(), }) } diff --git a/tfhe/src/high_level_api/keys/inner.rs b/tfhe/src/high_level_api/keys/inner.rs index cd968baab3..49124f9e53 100644 --- a/tfhe/src/high_level_api/keys/inner.rs +++ b/tfhe/src/high_level_api/keys/inner.rs @@ -315,6 +315,9 @@ impl IntegerServerKey { #[cfg(feature = "gpu")] pub struct IntegerCudaServerKey { pub(crate) key: crate::integer::gpu::CudaServerKey, + #[allow(dead_code)] + pub(crate) cpk_key_switching_key_material: + Option, pub(crate) compression_key: Option, pub(crate) decompression_key: diff --git a/tfhe/src/high_level_api/keys/server.rs b/tfhe/src/high_level_api/keys/server.rs index c0dd431e8b..e7e8f059f5 100644 --- a/tfhe/src/high_level_api/keys/server.rs +++ b/tfhe/src/high_level_api/keys/server.rs @@ -8,6 +8,8 @@ use super::ClientKey; use crate::backward_compatibility::keys::{CompressedServerKeyVersions, ServerKeyVersions}; use crate::conformance::ParameterSetConformant; #[cfg(feature = "gpu")] +use crate::core_crypto::gpu::lwe_keyswitch_key::CudaLweKeyswitchKey; +#[cfg(feature = "gpu")] use crate::core_crypto::gpu::{synchronize_devices, CudaStreams}; #[cfg(feature = "gpu")] use crate::high_level_api::keys::inner::IntegerCudaServerKey; @@ -292,6 +294,21 @@ impl CompressedServerKey { &self.integer_key.key, &streams, ); + let cpk_key_switching_key_material = self + .integer_key + .cpk_key_switching_key_material + .as_ref() + .map(|cpk_ksk_material| { + let ksk_material = cpk_ksk_material.decompress(); + let d_ksk = CudaLweKeyswitchKey::from_lwe_keyswitch_key( + &ksk_material.material.key_switching_key, + &streams, + ); + CudaKeySwitchingKeyMaterial { + lwe_keyswitch_key: d_ksk, + destination_key: ksk_material.material.destination_key, + } + }); let compression_key: Option< crate::integer::gpu::list_compression::server_keys::CudaCompressionKey, > = self @@ -328,6 +345,7 @@ impl CompressedServerKey { CudaServerKey { key: Arc::new(IntegerCudaServerKey { key, + cpk_key_switching_key_material, compression_key, decompression_key, }), @@ -460,6 +478,9 @@ mod hpu { use crate::high_level_api::keys::inner::IntegerServerKeyConformanceParams; +#[cfg(feature = "gpu")] +use crate::integer::gpu::key_switching_key::CudaKeySwitchingKeyMaterial; + impl ParameterSetConformant for ServerKey { type ParameterSet = IntegerServerKeyConformanceParams; diff --git a/tfhe/src/high_level_api/tests/tags_on_entities.rs b/tfhe/src/high_level_api/tests/tags_on_entities.rs index ff46180e9f..b8f2a8dd3b 100644 --- a/tfhe/src/high_level_api/tests/tags_on_entities.rs +++ b/tfhe/src/high_level_api/tests/tags_on_entities.rs @@ -143,6 +143,115 @@ fn test_tag_propagation_zk_pok() { } } +#[test] +#[cfg(feature = "zk-pok")] +#[cfg(feature = "gpu")] +fn test_tag_propagation_zk_pok_gpu() { + use crate::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128; + let config = + ConfigBuilder::with_custom_parameters(PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128) + .use_dedicated_compact_public_key_parameters(( + PARAM_PKE_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128, + PARAM_KEYSWITCH_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128, + )) + .build(); + let crs = crate::zk::CompactPkeCrs::from_config(config, (2 * 32) + (2 * 64) + 2).unwrap(); + + let metadata = [b'h', b'l', b'a', b'p', b'i']; + + let mut cks = ClientKey::generate(config); + let tag_value = random(); + cks.tag_mut().set_u64(tag_value); + let cks = serialize_then_deserialize(&cks); + assert_eq!(cks.tag().as_u64(), tag_value); + + let compressed_server_key = CompressedServerKey::new(&cks); + let gpu_sks = compressed_server_key.decompress_to_gpu(); + assert_eq!(gpu_sks.tag(), cks.tag()); + set_server_key(gpu_sks); + + let cpk = CompactPublicKey::new(&cks); + assert_eq!(cpk.tag(), cks.tag()); + + let mut builder = CompactCiphertextList::builder(&cpk); + + let list_packed = builder + .push(32u32) + .push(1u32) + .push(-1i64) + .push(i64::MIN) + .push(false) + .push(true) + .build_with_proof_packed(&crs, &metadata, crate::zk::ZkComputeLoad::Proof) + .unwrap(); + + let expander = list_packed + .verify_and_expand(&crs, &cpk, &metadata) + .unwrap(); + + { + let au32: FheUint32 = expander.get(0).unwrap().unwrap(); + let bu32: FheUint32 = expander.get(1).unwrap().unwrap(); + assert_eq!(au32.tag(), cks.tag()); + assert_eq!(bu32.tag(), cks.tag()); + + let cu32 = au32 + bu32; + assert_eq!(cu32.tag(), cks.tag()); + } + + { + let ai64: FheInt64 = expander.get(2).unwrap().unwrap(); + let bi64: FheInt64 = expander.get(3).unwrap().unwrap(); + assert_eq!(ai64.tag(), cks.tag()); + assert_eq!(bi64.tag(), cks.tag()); + + let ci64 = ai64 + bi64; + assert_eq!(ci64.tag(), cks.tag()); + } + + { + let abool: FheBool = expander.get(4).unwrap().unwrap(); + let bbool: FheBool = expander.get(5).unwrap().unwrap(); + assert_eq!(abool.tag(), cks.tag()); + assert_eq!(bbool.tag(), cks.tag()); + + let cbool = abool & bbool; + assert_eq!(cbool.tag(), cks.tag()); + } + + let unverified_expander = list_packed.expand_without_verification().unwrap(); + + { + let au32: FheUint32 = unverified_expander.get(0).unwrap().unwrap(); + let bu32: FheUint32 = unverified_expander.get(1).unwrap().unwrap(); + assert_eq!(au32.tag(), cks.tag()); + assert_eq!(bu32.tag(), cks.tag()); + + let cu32 = au32 + bu32; + assert_eq!(cu32.tag(), cks.tag()); + } + + { + let ai64: FheInt64 = unverified_expander.get(2).unwrap().unwrap(); + let bi64: FheInt64 = unverified_expander.get(3).unwrap().unwrap(); + assert_eq!(ai64.tag(), cks.tag()); + assert_eq!(bi64.tag(), cks.tag()); + + let ci64 = ai64 + bi64; + assert_eq!(ci64.tag(), cks.tag()); + } + + { + let abool: FheBool = unverified_expander.get(4).unwrap().unwrap(); + let bbool: FheBool = unverified_expander.get(5).unwrap().unwrap(); + assert_eq!(abool.tag(), cks.tag()); + assert_eq!(bbool.tag(), cks.tag()); + + let cbool = abool & bbool; + assert_eq!(cbool.tag(), cks.tag()); + } +} + #[test] #[cfg(feature = "gpu")] fn test_tag_propagation_gpu() { diff --git a/tfhe/src/integer/ciphertext/compact_list.rs b/tfhe/src/integer/ciphertext/compact_list.rs index 5c551f23dc..50654ead74 100644 --- a/tfhe/src/integer/ciphertext/compact_list.rs +++ b/tfhe/src/integer/ciphertext/compact_list.rs @@ -310,7 +310,7 @@ pub struct CompactCiphertextListExpander { } impl CompactCiphertextListExpander { - fn new(expanded_blocks: Vec, info: Vec) -> Self { + pub(crate) fn new(expanded_blocks: Vec, info: Vec) -> Self { Self { expanded_blocks, info, diff --git a/tfhe/src/integer/gpu/ciphertext/compact_list.rs b/tfhe/src/integer/gpu/ciphertext/compact_list.rs index 7b20f04444..a7631bc525 100644 --- a/tfhe/src/integer/gpu/ciphertext/compact_list.rs +++ b/tfhe/src/integer/gpu/ciphertext/compact_list.rs @@ -1,13 +1,16 @@ +use crate::core_crypto::commons::traits::contiguous_entity_container::ContiguousEntityContainer; use crate::core_crypto::gpu::lwe_ciphertext_list::CudaLweCiphertextList; use crate::core_crypto::gpu::lwe_compact_ciphertext_list::CudaLweCompactCiphertextList; use crate::core_crypto::gpu::CudaStreams; -use crate::integer::ciphertext::DataKind; +use crate::core_crypto::prelude::LweCiphertext; +use crate::integer::ciphertext::{CompactCiphertextListExpander, DataKind}; use crate::integer::gpu::ciphertext::compressed_ciphertext_list::CudaExpandable; use crate::integer::gpu::ciphertext::info::{CudaBlockInfo, CudaRadixCiphertextInfo}; use crate::integer::gpu::ciphertext::CudaRadixCiphertext; use crate::shortint::ciphertext::CompactCiphertextList; use crate::shortint::parameters::{CompactCiphertextListExpansionKind, Degree}; -use crate::shortint::{CarryModulus, MessageModulus}; +use crate::shortint::{CarryModulus, Ciphertext, MessageModulus}; +use itertools::Itertools; pub struct CudaCompactCiphertextList { pub(crate) d_ct_list: CudaLweCompactCiphertextList, @@ -17,12 +20,25 @@ pub struct CudaCompactCiphertextList { pub(crate) expansion_kind: CompactCiphertextListExpansionKind, } +impl CudaCompactCiphertextList { + pub fn duplicate(&self, streams: &CudaStreams) -> Self { + Self { + d_ct_list: self.d_ct_list.duplicate(streams), + degree: self.degree, + message_modulus: self.message_modulus, + carry_modulus: self.carry_modulus, + expansion_kind: self.expansion_kind, + } + } +} + #[derive(Clone)] pub struct CudaCompactCiphertextListInfo { pub info: CudaBlockInfo, pub data_kind: DataKind, } +#[derive(Clone)] pub struct CudaCompactCiphertextListExpander { pub(crate) expanded_blocks: CudaLweCiphertextList, pub(crate) blocks_info: Vec, @@ -39,6 +55,29 @@ impl CudaCompactCiphertextListExpander { } } + pub fn len(&self) -> usize { + self.expanded_blocks.lwe_ciphertext_count().0 + } + + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + + pub fn get_kind_of(&self, index: usize) -> Option { + let blocks = self.blocks_info.get(index)?; + Some(blocks.data_kind) + } + + pub fn message_modulus(&self, index: usize) -> Option { + let blocks = self.blocks_info.get(index)?; + Some(blocks.info.message_modulus) + } + + pub fn carry_modulus(&self, index: usize) -> Option { + let blocks = self.blocks_info.get(index)?; + Some(blocks.info.carry_modulus) + } + fn blocks_of( &self, index: usize, @@ -92,6 +131,44 @@ impl CudaCompactCiphertextListExpander { .map(|(blocks, kind)| T::from_expanded_blocks(blocks, kind)) .transpose() } + + pub fn to_compact_ciphertext_list_expander( + &self, + streams: &CudaStreams, + ) -> CompactCiphertextListExpander { + let lwe_ciphertext_list = self.expanded_blocks.to_lwe_ciphertext_list(streams); + let ciphertext_modulus = self.expanded_blocks.ciphertext_modulus(); + + let expanded_blocks = lwe_ciphertext_list + .iter() + .zip(self.blocks_info.clone()) + .map(|(ct, info)| { + let lwe = LweCiphertext::from_container(ct.as_ref().to_vec(), ciphertext_modulus); + Ciphertext::new( + lwe, + info.info.degree, + info.info.noise_level, + info.info.message_modulus, + info.info.carry_modulus, + info.info.atomic_pattern, + ) + }) + .collect_vec(); + let info = self + .blocks_info + .iter() + .map(|ct_info| ct_info.data_kind) + .collect_vec(); + + CompactCiphertextListExpander::new(expanded_blocks, info) + } + + pub fn duplicate(&self) -> Self { + Self { + expanded_blocks: self.expanded_blocks.clone(), + blocks_info: self.blocks_info.iter().cloned().collect_vec(), + } + } } impl CudaCompactCiphertextList { diff --git a/tfhe/src/integer/gpu/key_switching_key/mod.rs b/tfhe/src/integer/gpu/key_switching_key/mod.rs index 449d4fab13..3d326b7090 100644 --- a/tfhe/src/integer/gpu/key_switching_key/mod.rs +++ b/tfhe/src/integer/gpu/key_switching_key/mod.rs @@ -1,60 +1,47 @@ use crate::core_crypto::gpu::lwe_keyswitch_key::CudaLweKeyswitchKey; use crate::core_crypto::gpu::CudaStreams; -use crate::integer::client_key::secret_encryption_key::SecretEncryptionKeyView; use crate::integer::gpu::CudaServerKey; -use crate::integer::ClientKey; -use crate::shortint::engine::ShortintEngine; -use crate::shortint::parameters::ShortintKeySwitchingParameters; +use crate::integer::key_switching_key::KeySwitchingKey; use crate::shortint::EncryptionKeyChoice; +#[derive(Clone)] #[allow(dead_code)] -pub struct CudaKeySwitchingKey<'keys> { - pub(crate) key_switching_key: CudaLweKeyswitchKey, - pub(crate) dest_server_key: &'keys CudaServerKey, +pub struct CudaKeySwitchingKeyMaterial { + pub(crate) lwe_keyswitch_key: CudaLweKeyswitchKey, pub(crate) destination_key: EncryptionKeyChoice, } -impl<'keys> CudaKeySwitchingKey<'keys> { - pub fn new<'input_key, InputEncryptionKey>( - input_key_pair: (InputEncryptionKey, Option<&'keys CudaServerKey>), - output_key_pair: (&'keys ClientKey, &'keys CudaServerKey), - params: ShortintKeySwitchingParameters, - streams: &CudaStreams, - ) -> Self - where - InputEncryptionKey: Into>, - { - let input_secret_key: SecretEncryptionKeyView<'_> = input_key_pair.0.into(); +#[allow(dead_code)] +pub struct CudaKeySwitchingKey<'key> { + pub(crate) key_switching_key_material: &'key CudaKeySwitchingKeyMaterial, + pub(crate) dest_server_key: &'key CudaServerKey, +} - // Creation of the key switching key - let key_switching_key = ShortintEngine::with_thread_local_mut(|engine| { - engine.new_key_switching_key(&input_secret_key.key, output_key_pair.0.as_ref(), params) - }); - let d_key_switching_key = - CudaLweKeyswitchKey::from_lwe_keyswitch_key(&key_switching_key, streams); - let full_message_modulus_input = - input_secret_key.key.carry_modulus.0 * input_secret_key.key.message_modulus.0; - let full_message_modulus_output = output_key_pair.0.key.parameters.carry_modulus().0 - * output_key_pair.0.key.parameters.message_modulus().0; - assert!( - full_message_modulus_input.is_power_of_two() - && full_message_modulus_output.is_power_of_two(), - "Cannot create casting key if the full messages moduli are not a power of 2" +impl CudaKeySwitchingKeyMaterial { + pub fn from_key_switching_key( + key_switching_key: &KeySwitchingKey, + streams: &CudaStreams, + ) -> Self { + let key_switching_key_material = &key_switching_key.key.key_switching_key_material; + let d_lwe_keyswich_key = CudaLweKeyswitchKey::from_lwe_keyswitch_key( + &key_switching_key_material.key_switching_key, + streams, ); - if full_message_modulus_input > full_message_modulus_output { - assert!( - input_key_pair.1.is_some(), - "Trying to build a integer::gpu::KeySwitchingKey \ - going from a large modulus {full_message_modulus_input} \ - to a smaller modulus {full_message_modulus_output} \ - without providing a source CudaServerKey, this is not supported" - ); + Self { + lwe_keyswitch_key: d_lwe_keyswich_key, + destination_key: key_switching_key_material.destination_key, } + } +} - CudaKeySwitchingKey { - key_switching_key: d_key_switching_key, - dest_server_key: output_key_pair.1, - destination_key: params.destination_key, +impl<'key> CudaKeySwitchingKey<'key> { + pub fn from_cuda_key_switching_key_material( + key_switching_key_material: &'key CudaKeySwitchingKeyMaterial, + dest_server_key: &'key CudaServerKey, + ) -> Self { + Self { + key_switching_key_material, + dest_server_key, } } } diff --git a/tfhe/src/integer/gpu/zk/mod.rs b/tfhe/src/integer/gpu/zk/mod.rs index 7001807350..fa3eb9596f 100644 --- a/tfhe/src/integer/gpu/zk/mod.rs +++ b/tfhe/src/integer/gpu/zk/mod.rs @@ -18,6 +18,7 @@ use crate::integer::{CompactPublicKey, ProvenCompactCiphertextList}; use crate::shortint::ciphertext::{Degree, NoiseLevel}; use crate::shortint::AtomicPatternKind; use crate::zk::CompactPkeCrs; +use crate::GpuIndex; use itertools::Itertools; use rayon::iter::{IntoParallelRefIterator, ParallelIterator}; use tfhe_cuda_backend::cuda_bind::cuda_memcpy_async_gpu_to_gpu; @@ -28,6 +29,28 @@ pub struct CudaProvenCompactCiphertextList { } impl CudaProvenCompactCiphertextList { + pub fn duplicate(&self, streams: &CudaStreams) -> Self { + Self { + h_proved_lists: self.h_proved_lists.clone(), + d_compact_lists: self + .d_compact_lists + .iter() + .map(|ct_list| ct_list.duplicate(streams)) + .collect_vec(), + } + } + + pub fn gpu_indexes(&self) -> &[GpuIndex] { + self.d_compact_lists + .first() + .unwrap() + .d_ct_list + .0 + .d_vec + .gpu_indexes + .as_slice() + } + unsafe fn flatten_async( slice_ciphertext_list: &[CudaCompactCiphertextList], streams: &CudaStreams, @@ -152,13 +175,9 @@ impl CudaProvenCompactCiphertextList { DataKind::Boolean => 1, DataKind::Signed(x) => *x, DataKind::Unsigned(x) => *x, - _ => panic!("DataKind not supported on GPUs"), + DataKind::String { .. } => panic!("DataKind not supported on GPUs"), }; - std::iter::repeat(match data_kind { - DataKind::Boolean => true, - _ => false, - }) - .take(repetitions) + std::iter::repeat_n(matches!(data_kind, DataKind::Boolean), repetitions) }) .collect_vec(); @@ -194,25 +213,22 @@ impl CudaProvenCompactCiphertextList { streams, ); - let d_input = &CudaProvenCompactCiphertextList::flatten_async( - self.d_compact_lists.as_slice(), - streams, - ); - let casting_key = &key.key_switching_key; - let sks = key.dest_server_key; + let d_input = &Self::flatten_async(self.d_compact_lists.as_slice(), streams); + let casting_key = &key.key_switching_key_material; + let sks = &key.dest_server_key; let computing_ks_key = &key.dest_server_key.key_switching_key; - let casting_key_type: KsType = key.destination_key.into(); + let casting_key_type: KsType = casting_key.destination_key.into(); match &sks.bootstrapping_key { CudaBootstrappingKey::Classic(d_bsk) => { expand_async( streams, &mut d_output, - &d_input, + d_input, &d_bsk.d_vec, &computing_ks_key.d_vec, - &casting_key.d_vec, + &casting_key.lwe_keyswitch_key.d_vec, sks.message_modulus, sks.carry_modulus, d_bsk.glwe_dimension(), @@ -220,10 +236,16 @@ impl CudaProvenCompactCiphertextList { d_bsk.input_lwe_dimension(), computing_ks_key.decomposition_level_count(), computing_ks_key.decomposition_base_log(), - casting_key.input_key_lwe_size().to_lwe_dimension(), - casting_key.output_key_lwe_size().to_lwe_dimension(), - casting_key.decomposition_level_count(), - casting_key.decomposition_base_log(), + casting_key + .lwe_keyswitch_key + .input_key_lwe_size() + .to_lwe_dimension(), + casting_key + .lwe_keyswitch_key + .output_key_lwe_size() + .to_lwe_dimension(), + casting_key.lwe_keyswitch_key.decomposition_level_count(), + casting_key.lwe_keyswitch_key.decomposition_base_log(), d_bsk.decomp_level_count, d_bsk.decomp_base_log, PBSType::Classical, @@ -238,13 +260,10 @@ impl CudaProvenCompactCiphertextList { expand_async( streams, &mut d_output, - &CudaProvenCompactCiphertextList::flatten_async( - self.d_compact_lists.as_slice(), - streams, - ), + &Self::flatten_async(self.d_compact_lists.as_slice(), streams), &d_multibit_bsk.d_vec, &computing_ks_key.d_vec, - &casting_key.d_vec, + &casting_key.lwe_keyswitch_key.d_vec, sks.message_modulus, sks.carry_modulus, d_multibit_bsk.glwe_dimension(), @@ -252,10 +271,16 @@ impl CudaProvenCompactCiphertextList { d_multibit_bsk.input_lwe_dimension(), computing_ks_key.decomposition_level_count(), computing_ks_key.decomposition_base_log(), - casting_key.input_key_lwe_size().to_lwe_dimension(), - casting_key.output_key_lwe_size().to_lwe_dimension(), - casting_key.decomposition_level_count(), - casting_key.decomposition_base_log(), + casting_key + .lwe_keyswitch_key + .input_key_lwe_size() + .to_lwe_dimension(), + casting_key + .lwe_keyswitch_key + .output_key_lwe_size() + .to_lwe_dimension(), + casting_key.lwe_keyswitch_key.decomposition_level_count(), + casting_key.lwe_keyswitch_key.decomposition_base_log(), d_multibit_bsk.decomp_level_count, d_multibit_bsk.decomp_base_log, PBSType::MultiBit, @@ -294,7 +319,7 @@ impl CudaProvenCompactCiphertextList { }) .collect(); - CudaProvenCompactCiphertextList { + Self { h_proved_lists: h_proved_lists.clone(), d_compact_lists, } @@ -328,6 +353,18 @@ impl CudaProvenCompactCiphertextList { } } +impl<'de> serde::Deserialize<'de> for CudaProvenCompactCiphertextList { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + let cpu_ct = ProvenCompactCiphertextList::deserialize(deserializer)?; + let streams = CudaStreams::new_multi_gpu(); + + Ok(Self::from_proven_compact_ciphertext_list(&cpu_ct, &streams)) + } +} + #[cfg(feature = "zk-pok")] #[cfg(test)] mod tests { @@ -344,9 +381,12 @@ mod tests { use crate::integer::ciphertext::{CompactCiphertextList, DataKind}; use crate::integer::gpu::ciphertext::boolean_value::CudaBooleanBlock; use crate::integer::gpu::ciphertext::CudaUnsignedRadixCiphertext; - use crate::integer::gpu::key_switching_key::CudaKeySwitchingKey; + use crate::integer::gpu::key_switching_key::{ + CudaKeySwitchingKey, CudaKeySwitchingKeyMaterial, + }; use crate::integer::gpu::zk::CudaProvenCompactCiphertextList; use crate::integer::gpu::CudaServerKey; + use crate::integer::key_switching_key::KeySwitchingKey; use crate::integer::{ ClientKey, CompactPrivateKey, CompactPublicKey, CompressedServerKey, ProvenCompactCiphertextList, @@ -405,15 +445,16 @@ mod tests { let compressed_server_key = CompressedServerKey::new_radix_compressed_server_key(&cks); let streams = CudaStreams::new_multi_gpu(); + let sk = compressed_server_key.decompress(); let gpu_sk = CudaServerKey::decompress_from_cpu(&compressed_server_key, &streams); let compact_private_key = CompactPrivateKey::new(pke_params); - let d_ksk = CudaKeySwitchingKey::new( - (&compact_private_key, None), - (&cks, &gpu_sk), - ksk_params, - &streams, - ); + let ksk = KeySwitchingKey::new((&compact_private_key, None), (&cks, &sk), ksk_params); + let d_ksk_material = + CudaKeySwitchingKeyMaterial::from_key_switching_key(&ksk, &streams); + let d_ksk = + CudaKeySwitchingKey::from_cuda_key_switching_key_material(&d_ksk_material, &gpu_sk); + let pk = CompactPublicKey::new(&compact_private_key); let msgs = (0..512) @@ -493,17 +534,17 @@ mod tests { .unwrap(); let cks = ClientKey::new(fhe_params); let compressed_server_key = CompressedServerKey::new_radix_compressed_server_key(&cks); - + let sk = compressed_server_key.decompress(); let streams = CudaStreams::new_multi_gpu(); let gpu_sk = CudaServerKey::decompress_from_cpu(&compressed_server_key, &streams); let compact_private_key = CompactPrivateKey::new(pke_params); - let d_ksk = CudaKeySwitchingKey::new( - (&compact_private_key, None), - (&cks, &gpu_sk), - ksk_params, - &streams, - ); + let ksk = KeySwitchingKey::new((&compact_private_key, None), (&cks, &sk), ksk_params); + let d_ksk_material = + CudaKeySwitchingKeyMaterial::from_key_switching_key(&ksk, &streams); + let d_ksk = + CudaKeySwitchingKey::from_cuda_key_switching_key_material(&d_ksk_material, &gpu_sk); + let pk = CompactPublicKey::new(&compact_private_key); let msgs = (0..2).map(|_| random::()).collect::>(); @@ -583,17 +624,17 @@ mod tests { .unwrap(); let cks = ClientKey::new(fhe_params); let compressed_server_key = CompressedServerKey::new_radix_compressed_server_key(&cks); - + let sk = compressed_server_key.decompress(); let streams = CudaStreams::new_multi_gpu(); let gpu_sk = CudaServerKey::decompress_from_cpu(&compressed_server_key, &streams); let compact_private_key = CompactPrivateKey::new(pke_params); - let d_ksk = CudaKeySwitchingKey::new( - (&compact_private_key, None), - (&cks, &gpu_sk), - ksk_params, - &streams, - ); + let ksk = KeySwitchingKey::new((&compact_private_key, None), (&cks, &sk), ksk_params); + let d_ksk_material = + CudaKeySwitchingKeyMaterial::from_key_switching_key(&ksk, &streams); + let d_ksk = + CudaKeySwitchingKey::from_cuda_key_switching_key_material(&d_ksk_material, &gpu_sk); + let pk = CompactPublicKey::new(&compact_private_key); let msgs = (0..2).map(|_| random::()).collect::>(); diff --git a/tfhe/src/test_user_docs.rs b/tfhe/src/test_user_docs.rs index d53444f322..b1a03c1793 100644 --- a/tfhe/src/test_user_docs.rs +++ b/tfhe/src/test_user_docs.rs @@ -239,6 +239,10 @@ mod test_gpu_doc { "../docs/configuration/gpu_acceleration/multi_gpu.md", configuration_gpu_acceleration_multi_gpu_device_selection ); + doctest!( + "../docs/configuration/gpu_acceleration/zk-pok.md", + configuration_gpu_acceleration_zk_pok + ); } #[cfg(feature = "hpu")]