Skip to content

Commit cf38e44

Browse files
committed
feat(gpu): add support for GPU-accelerated expand on the HL Api
- includes documentation about GPU's accelerated expand on the HL API - rework CudaKeySwitchingKey - Cloning the key is no longer necessary on the HL API
1 parent a01949e commit cf38e44

File tree

17 files changed

+829
-211
lines changed

17 files changed

+829
-211
lines changed

Makefile

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -294,7 +294,7 @@ check_typos: install_typos_checker
294294
.PHONY: clippy_gpu # Run clippy lints on tfhe with "gpu" enabled
295295
clippy_gpu: install_rs_check_toolchain
296296
RUSTFLAGS="$(RUSTFLAGS)" cargo "$(CARGO_RS_CHECK_TOOLCHAIN)" clippy \
297-
--features=boolean,shortint,integer,internal-keycache,gpu,pbs-stats,extended-types \
297+
--features=boolean,shortint,integer,internal-keycache,gpu,pbs-stats,extended-types,zk-pok \
298298
--all-targets \
299299
-p $(TFHE_SPEC) -- --no-deps -D warnings
300300

@@ -892,7 +892,7 @@ test_high_level_api: install_rs_build_toolchain
892892

893893
test_high_level_api_gpu: install_rs_build_toolchain install_cargo_nextest
894894
RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_BUILD_TOOLCHAIN) nextest run --cargo-profile $(CARGO_PROFILE) \
895-
--features=integer,internal-keycache,gpu -p $(TFHE_SPEC) \
895+
--features=integer,internal-keycache,gpu,zk-pok -p $(TFHE_SPEC) \
896896
-E "test(/high_level_api::.*gpu.*/)"
897897

898898
test_high_level_api_hpu: install_rs_build_toolchain install_cargo_nextest
@@ -1066,7 +1066,7 @@ check_compile_tests: install_rs_build_toolchain
10661066
.PHONY: check_compile_tests_benches_gpu # Build tests in debug without running them
10671067
check_compile_tests_benches_gpu: install_rs_build_toolchain
10681068
RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_BUILD_TOOLCHAIN) test --no-run \
1069-
--features=experimental,boolean,shortint,integer,internal-keycache,gpu \
1069+
--features=experimental,boolean,shortint,integer,internal-keycache,gpu,zk-pok \
10701070
-p $(TFHE_SPEC)
10711071
mkdir -p "$(TFHECUDA_BUILD)" && \
10721072
cd "$(TFHECUDA_BUILD)" && \

tfhe-benchmark/benches/integer/zk_pke.rs

Lines changed: 34 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -430,11 +430,11 @@ fn cpu_pke_zk_verify(c: &mut Criterion, results_file: &Path) {
430430
#[cfg(all(feature = "gpu", feature = "zk-pok"))]
431431
mod cuda {
432432
use super::*;
433-
use benchmark::utilities::{cuda_local_keys, cuda_local_streams};
433+
use benchmark::utilities::cuda_local_streams;
434434
use criterion::BatchSize;
435435
use itertools::Itertools;
436436
use tfhe::core_crypto::gpu::{get_number_of_gpus, CudaStreams};
437-
use tfhe::integer::gpu::key_switching_key::CudaKeySwitchingKey;
437+
use tfhe::integer::gpu::key_switching_key::{CudaKeySwitchingKey, CudaKeySwitchingKeyMaterial};
438438
use tfhe::integer::gpu::zk::CudaProvenCompactCiphertextList;
439439
use tfhe::integer::gpu::CudaServerKey;
440440
use tfhe::integer::CompressedServerKey;
@@ -463,14 +463,17 @@ mod cuda {
463463
let param_name = param_name.as_str();
464464
let cks = ClientKey::new(param_fhe);
465465
let compressed_server_key = CompressedServerKey::new_radix_compressed_server_key(&cks);
466+
let sk = compressed_server_key.decompress();
466467
let gpu_sks = CudaServerKey::decompress_from_cpu(&compressed_server_key, &streams);
468+
467469
let compact_private_key = CompactPrivateKey::new(param_pke);
468470
let pk = CompactPublicKey::new(&compact_private_key);
469-
let d_ksk = CudaKeySwitchingKey::new(
470-
(&compact_private_key, None),
471-
(&cks, &gpu_sks),
472-
param_ksk,
473-
&streams,
471+
let ksk = KeySwitchingKey::new((&compact_private_key, None), (&cks, &sk), param_ksk);
472+
let d_ksk_material =
473+
CudaKeySwitchingKeyMaterial::from_key_switching_key(&ksk, &streams);
474+
let d_ksk = CudaKeySwitchingKey::from_cuda_key_switching_key_material(
475+
&d_ksk_material,
476+
&gpu_sks,
474477
);
475478

476479
// We have a use case with 320 bits of metadata
@@ -621,7 +624,6 @@ mod cuda {
621624
});
622625
}
623626
BenchmarkType::Throughput => {
624-
let gpu_sks_vec = cuda_local_keys(&cks);
625627
let gpu_count = get_number_of_gpus() as usize;
626628

627629
// Execute the operation once to know its cost.
@@ -666,20 +668,17 @@ mod cuda {
666668
.collect::<Vec<_>>();
667669

668670
let local_streams = cuda_local_streams(num_block, elements as usize);
669-
let d_ksk_vec = gpu_sks_vec
671+
let d_ksk_material_vec = local_streams
670672
.par_iter()
671-
.zip(local_streams.par_iter())
672-
.map(|(gpu_sks, local_stream)| {
673-
CudaKeySwitchingKey::new(
674-
(&compact_private_key, None),
675-
(&cks, gpu_sks),
676-
param_ksk,
673+
.map(|local_stream| {
674+
CudaKeySwitchingKeyMaterial::from_key_switching_key(
675+
&ksk,
677676
local_stream,
678677
)
679678
})
680679
.collect::<Vec<_>>();
681680

682-
assert_eq!(d_ksk_vec.len(), gpu_count);
681+
assert_eq!(d_ksk_material_vec.len(), gpu_count);
683682

684683
bench_group.bench_function(&bench_id_verify, |b| {
685684
b.iter(|| {
@@ -702,14 +701,16 @@ mod cuda {
702701
(gpu_cts, local_streams)
703702
};
704703

705-
b.iter_batched(setup_encrypted_values, |(gpu_cts, local_streams)| {
706-
gpu_cts.par_iter()
707-
.zip(local_streams.par_iter())
708-
.enumerate()
709-
.for_each(|(i, (gpu_ct, local_stream))| {
710-
gpu_ct
711-
.expand_without_verification(&d_ksk_vec[i % gpu_count], local_stream)
712-
.unwrap();
704+
b.iter_batched(setup_encrypted_values,
705+
|(gpu_cts, local_streams)| {
706+
gpu_cts.par_iter().zip(local_streams.par_iter()).enumerate().for_each
707+
(|(i, (gpu_ct, local_stream))| {
708+
let d_ksk =
709+
CudaKeySwitchingKey::from_cuda_key_switching_key_material(&d_ksk_material_vec[i % gpu_count], &gpu_sks);
710+
711+
gpu_ct
712+
.expand_without_verification(&d_ksk, local_stream)
713+
.unwrap();
713714
});
714715
}, BatchSize::SmallInput);
715716
});
@@ -727,16 +728,15 @@ mod cuda {
727728
(gpu_cts, local_streams)
728729
};
729730

730-
b.iter_batched(setup_encrypted_values, |(gpu_cts, local_streams)| {
731-
gpu_cts
732-
.par_iter()
733-
.zip(local_streams.par_iter())
734-
.for_each(|(gpu_ct, local_stream)| {
735-
gpu_ct
736-
.verify_and_expand(
737-
&crs, &pk, &metadata, &d_ksk, local_stream
738-
)
739-
.unwrap();
731+
b.iter_batched(setup_encrypted_values,
732+
|(gpu_cts, local_streams)| {
733+
gpu_cts.par_iter().zip(local_streams.par_iter()).for_each
734+
(|(gpu_ct, local_stream)| {
735+
gpu_ct
736+
.verify_and_expand(
737+
&crs, &pk, &metadata, &d_ksk, local_stream,
738+
)
739+
.unwrap();
740740
});
741741
}, BatchSize::SmallInput);
742742
});
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
# Zero-knowledge proofs
2+
3+
Zero-knowledge proofs (ZK) are a powerful tool to assert that the encryption of a message is correct, as discussed in [advanced features](../../fhe-computation/advanced-features/zk-pok.md).
4+
However, computation is not possible on the type of ciphertexts it produces (i.e. `ProvenCompactCiphertextList`). This document explains how to use the GPU to accelerate the
5+
preprocessing step needed to convert ciphertexts formatted for ZK to ciphertexts in the right format for computation purposes on GPU. This
6+
operation is called "expansion".
7+
8+
## Proven compact ciphertext list
9+
10+
A proven compact list of ciphertexts can be seen as a compacted collection of ciphertexts which encryption can be verified.
11+
This verification is currently only supported on the CPU, but the expansion can be accelerated using the GPU.
12+
This way, verification and expansion can be performed in parallel, efficiently using all the available computational resources.
13+
14+
## Supported types
15+
Encrypted messages can be integers (like FheUint64) or booleans. The GPU backend does not currently support encrypted strings.
16+
17+
{% hint style="info" %}
18+
You can enable this feature using the flag: `--features=zk-pok,gpu` when building **TFHE-rs**.
19+
{% endhint %}
20+
21+
22+
## Example
23+
24+
The following example shows how a client can encrypt and prove a ciphertext, and how a server can verify the proof, preprocess the ciphertext and run a computation on it on GPU:
25+
26+
```rust
27+
use rand::random;
28+
use tfhe::CompressedServerKey;
29+
use tfhe::prelude::*;
30+
use tfhe::set_server_key;
31+
use tfhe::zk::{CompactPkeCrs, ZkComputeLoad};
32+
33+
pub fn main() -> Result<(), Box<dyn std::error::Error>> {
34+
let params = tfhe::shortint::parameters::PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128;
35+
// Indicate which parameters to use for the Compact Public Key encryption
36+
let cpk_params = tfhe::shortint::parameters::PARAM_PKE_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128;
37+
// And parameters allowing to keyswitch/cast to the computation parameters.
38+
let casting_params = tfhe::shortint::parameters::PARAM_KEYSWITCH_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128;
39+
// Enable the dedicated parameters on the config
40+
let config = tfhe::ConfigBuilder::with_custom_parameters(params)
41+
.use_dedicated_compact_public_key_parameters((cpk_params, casting_params)).build();
42+
43+
// The CRS should be generated in an offline phase then shared to all clients and the server
44+
let crs = CompactPkeCrs::from_config(config, 64).unwrap();
45+
46+
// Then use TFHE-rs as usual
47+
let client_key = tfhe::ClientKey::generate(config);
48+
let compressed_server_key = CompressedServerKey::new(&client_key);
49+
let gpu_server_key = compressed_server_key.decompress_to_gpu();
50+
51+
let public_key = tfhe::CompactPublicKey::try_new(&client_key).unwrap();
52+
// This can be left empty, but if provided allows to tie the proof to arbitrary data
53+
let metadata = [b'T', b'F', b'H', b'E', b'-', b'r', b's'];
54+
55+
let clear_a = random::<u64>();
56+
let clear_b = random::<u64>();
57+
58+
let proven_compact_list = tfhe::ProvenCompactCiphertextList::builder(&public_key)
59+
.push(clear_a)
60+
.push(clear_b)
61+
.build_with_proof_packed(&crs, &metadata, ZkComputeLoad::Verify)?;
62+
63+
// Server side
64+
let result = {
65+
set_server_key(gpu_server_key);
66+
67+
// Verify the ciphertexts
68+
let expander =
69+
proven_compact_list.verify_and_expand(&crs, &public_key, &metadata)?;
70+
let a: tfhe::FheUint64 = expander.get(0)?.unwrap();
71+
let b: tfhe::FheUint64 = expander.get(1)?.unwrap();
72+
73+
a + b
74+
};
75+
76+
// Back on the client side
77+
let a_plus_b: u64 = result.decrypt(&client_key);
78+
assert_eq!(a_plus_b, clear_a.wrapping_add(clear_b));
79+
80+
Ok(())
81+
}
82+
```

tfhe/src/core_crypto/gpu/entities/lwe_ciphertext_list.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ use crate::core_crypto::prelude::{
77
use tfhe_cuda_backend::cuda_bind::cuda_memcpy_async_gpu_to_gpu;
88

99
/// A structure representing a vector of LWE ciphertexts with 64 bits of precision on the GPU.
10-
#[derive(Debug)]
10+
#[derive(Clone, Debug)]
1111
pub struct CudaLweCiphertextList<T: UnsignedInteger>(pub(crate) CudaLweList<T>);
1212

1313
#[allow(dead_code)]

tfhe/src/core_crypto/gpu/entities/lwe_compact_ciphertext_list.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,10 @@ use crate::core_crypto::prelude::{
1010
pub struct CudaLweCompactCiphertextList<T: UnsignedInteger>(pub CudaLweList<T>);
1111

1212
impl<T: UnsignedInteger> CudaLweCompactCiphertextList<T> {
13+
pub fn duplicate(&self, streams: &CudaStreams) -> Self {
14+
Self(self.0.duplicate(streams))
15+
}
16+
1317
pub fn from_lwe_compact_ciphertext_list<C: Container<Element = T>>(
1418
h_ct: &LweCompactCiphertextList<C>,
1519
streams: &CudaStreams,

tfhe/src/core_crypto/gpu/entities/lwe_keyswitch_key.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ use crate::core_crypto::prelude::{
1010
UnsignedInteger,
1111
};
1212

13+
#[derive(Clone)]
1314
#[allow(dead_code)]
1415
pub struct CudaLweKeyswitchKey<T: UnsignedInteger> {
1516
pub(crate) d_vec: CudaVec<T>,

tfhe/src/core_crypto/gpu/mod.rs

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -993,7 +993,7 @@ pub unsafe fn fourier_transform_backward_as_torus_f128_async<T: UnsignedInteger>
993993
);
994994
}
995995

996-
#[derive(Debug)]
996+
#[derive(Clone, Debug)]
997997
pub struct CudaLweList<T: UnsignedInteger> {
998998
// Pointer to GPU data
999999
pub d_vec: CudaVec<T>,
@@ -1005,6 +1005,17 @@ pub struct CudaLweList<T: UnsignedInteger> {
10051005
pub ciphertext_modulus: CiphertextModulus<T>,
10061006
}
10071007

1008+
impl<T: UnsignedInteger> CudaLweList<T> {
1009+
pub fn duplicate(&self, streams: &CudaStreams) -> Self {
1010+
Self {
1011+
d_vec: self.d_vec.duplicate(streams),
1012+
lwe_ciphertext_count: self.lwe_ciphertext_count,
1013+
lwe_dimension: self.lwe_dimension,
1014+
ciphertext_modulus: self.ciphertext_modulus,
1015+
}
1016+
}
1017+
}
1018+
10081019
#[derive(Debug, Clone)]
10091020
pub struct CudaGlweList<T: UnsignedInteger> {
10101021
// Pointer to GPU data

tfhe/src/high_level_api/backward_compatibility/compact_list.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@ use tfhe_versionable::{Upgrade, Version, VersionsDispatch};
33

44
use crate::{CompactCiphertextList, Tag};
55

6+
#[cfg(feature = "zk-pok")]
7+
use crate::ProvenCompactCiphertextList;
8+
69
#[derive(Version)]
710
pub struct CompactCiphertextListV0(crate::integer::ciphertext::CompactCiphertextList);
811

@@ -17,9 +20,6 @@ impl Upgrade<CompactCiphertextList> for CompactCiphertextListV0 {
1720
}
1821
}
1922

20-
#[cfg(feature = "zk-pok")]
21-
use crate::ProvenCompactCiphertextList;
22-
2323
#[derive(VersionsDispatch)]
2424
pub enum CompactCiphertextListVersions {
2525
V0(CompactCiphertextListV0),

0 commit comments

Comments
 (0)