Skip to content

Commit 7dd45a4

Browse files
committed
feat(gpu): add support for GPU-accelerated expand on the HL Api
- includes documentation about GPU's accelerated expand on the HL API - rework CudaKeySwitchingKey - Cloning the key is no longer necessary on the HL API
1 parent 259d125 commit 7dd45a4

File tree

17 files changed

+829
-211
lines changed

17 files changed

+829
-211
lines changed

Makefile

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -294,7 +294,7 @@ check_typos: install_typos_checker
294294
.PHONY: clippy_gpu # Run clippy lints on tfhe with "gpu" enabled
295295
clippy_gpu: install_rs_check_toolchain
296296
RUSTFLAGS="$(RUSTFLAGS)" cargo "$(CARGO_RS_CHECK_TOOLCHAIN)" clippy \
297-
--features=boolean,shortint,integer,internal-keycache,gpu,pbs-stats,extended-types \
297+
--features=boolean,shortint,integer,internal-keycache,gpu,pbs-stats,extended-types,zk-pok \
298298
--all-targets \
299299
-p $(TFHE_SPEC) -- --no-deps -D warnings
300300

@@ -892,7 +892,7 @@ test_high_level_api: install_rs_build_toolchain
892892

893893
test_high_level_api_gpu: install_rs_build_toolchain install_cargo_nextest
894894
RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_BUILD_TOOLCHAIN) nextest run --cargo-profile $(CARGO_PROFILE) \
895-
--features=integer,internal-keycache,gpu -p $(TFHE_SPEC) \
895+
--features=integer,internal-keycache,gpu,zk-pok -p $(TFHE_SPEC) \
896896
-E "test(/high_level_api::.*gpu.*/)"
897897

898898
test_high_level_api_hpu: install_rs_build_toolchain install_cargo_nextest
@@ -1066,7 +1066,7 @@ check_compile_tests: install_rs_build_toolchain
10661066
.PHONY: check_compile_tests_benches_gpu # Build tests in debug without running them
10671067
check_compile_tests_benches_gpu: install_rs_build_toolchain
10681068
RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_BUILD_TOOLCHAIN) test --no-run \
1069-
--features=experimental,boolean,shortint,integer,internal-keycache,gpu \
1069+
--features=experimental,boolean,shortint,integer,internal-keycache,gpu,zk-pok \
10701070
-p $(TFHE_SPEC)
10711071
mkdir -p "$(TFHECUDA_BUILD)" && \
10721072
cd "$(TFHECUDA_BUILD)" && \

tfhe-benchmark/benches/integer/zk_pke.rs

Lines changed: 34 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -418,11 +418,11 @@ fn cpu_pke_zk_verify(c: &mut Criterion, results_file: &Path) {
418418
#[cfg(all(feature = "gpu", feature = "zk-pok"))]
419419
mod cuda {
420420
use super::*;
421-
use benchmark::utilities::{cuda_local_keys, cuda_local_streams};
421+
use benchmark::utilities::cuda_local_streams;
422422
use criterion::BatchSize;
423423
use itertools::Itertools;
424424
use tfhe::core_crypto::gpu::{get_number_of_gpus, CudaStreams};
425-
use tfhe::integer::gpu::key_switching_key::CudaKeySwitchingKey;
425+
use tfhe::integer::gpu::key_switching_key::{CudaKeySwitchingKey, CudaKeySwitchingKeyMaterial};
426426
use tfhe::integer::gpu::zk::CudaProvenCompactCiphertextList;
427427
use tfhe::integer::gpu::CudaServerKey;
428428
use tfhe::integer::CompressedServerKey;
@@ -451,14 +451,17 @@ mod cuda {
451451
let param_name = param_name.as_str();
452452
let cks = ClientKey::new(param_fhe);
453453
let compressed_server_key = CompressedServerKey::new_radix_compressed_server_key(&cks);
454+
let sk = compressed_server_key.decompress();
454455
let gpu_sks = CudaServerKey::decompress_from_cpu(&compressed_server_key, &streams);
456+
455457
let compact_private_key = CompactPrivateKey::new(param_pke);
456458
let pk = CompactPublicKey::new(&compact_private_key);
457-
let d_ksk = CudaKeySwitchingKey::new(
458-
(&compact_private_key, None),
459-
(&cks, &gpu_sks),
460-
param_ksk,
461-
&streams,
459+
let ksk = KeySwitchingKey::new((&compact_private_key, None), (&cks, &sk), param_ksk);
460+
let d_ksk_material =
461+
CudaKeySwitchingKeyMaterial::from_key_switching_key(&ksk, &streams);
462+
let d_ksk = CudaKeySwitchingKey::from_cuda_key_switching_key_material(
463+
&d_ksk_material,
464+
&gpu_sks,
462465
);
463466

464467
// We have a use case with 320 bits of metadata
@@ -609,7 +612,6 @@ mod cuda {
609612
});
610613
}
611614
BenchmarkType::Throughput => {
612-
let gpu_sks_vec = cuda_local_keys(&cks);
613615
let gpu_count = get_number_of_gpus() as usize;
614616

615617
let elements = zk_throughput_num_elements();
@@ -637,20 +639,17 @@ mod cuda {
637639
.collect::<Vec<_>>();
638640

639641
let local_streams = cuda_local_streams(num_block, elements as usize);
640-
let d_ksk_vec = gpu_sks_vec
642+
let d_ksk_material_vec = local_streams
641643
.par_iter()
642-
.zip(local_streams.par_iter())
643-
.map(|(gpu_sks, local_stream)| {
644-
CudaKeySwitchingKey::new(
645-
(&compact_private_key, None),
646-
(&cks, gpu_sks),
647-
param_ksk,
644+
.map(|local_stream| {
645+
CudaKeySwitchingKeyMaterial::from_key_switching_key(
646+
&ksk,
648647
local_stream,
649648
)
650649
})
651650
.collect::<Vec<_>>();
652651

653-
assert_eq!(d_ksk_vec.len(), gpu_count);
652+
assert_eq!(d_ksk_material_vec.len(), gpu_count);
654653

655654
bench_group.bench_function(&bench_id_verify, |b| {
656655
b.iter(|| {
@@ -673,14 +672,16 @@ mod cuda {
673672
(gpu_cts, local_streams)
674673
};
675674

676-
b.iter_batched(setup_encrypted_values, |(gpu_cts, local_streams)| {
677-
gpu_cts.par_iter()
678-
.zip(local_streams.par_iter())
679-
.enumerate()
680-
.for_each(|(i, (gpu_ct, local_stream))| {
681-
gpu_ct
682-
.expand_without_verification(&d_ksk_vec[i % gpu_count], local_stream)
683-
.unwrap();
675+
b.iter_batched(setup_encrypted_values,
676+
|(gpu_cts, local_streams)| {
677+
gpu_cts.par_iter().zip(local_streams.par_iter()).enumerate().for_each
678+
(|(i, (gpu_ct, local_stream))| {
679+
let d_ksk =
680+
CudaKeySwitchingKey::from_cuda_key_switching_key_material(&d_ksk_material_vec[i % gpu_count], &gpu_sks);
681+
682+
gpu_ct
683+
.expand_without_verification(&d_ksk, local_stream)
684+
.unwrap();
684685
});
685686
}, BatchSize::SmallInput);
686687
});
@@ -698,16 +699,15 @@ mod cuda {
698699
(gpu_cts, local_streams)
699700
};
700701

701-
b.iter_batched(setup_encrypted_values, |(gpu_cts, local_streams)| {
702-
gpu_cts
703-
.par_iter()
704-
.zip(local_streams.par_iter())
705-
.for_each(|(gpu_ct, local_stream)| {
706-
gpu_ct
707-
.verify_and_expand(
708-
&crs, &pk, &metadata, &d_ksk, local_stream
709-
)
710-
.unwrap();
702+
b.iter_batched(setup_encrypted_values,
703+
|(gpu_cts, local_streams)| {
704+
gpu_cts.par_iter().zip(local_streams.par_iter()).for_each
705+
(|(gpu_ct, local_stream)| {
706+
gpu_ct
707+
.verify_and_expand(
708+
&crs, &pk, &metadata, &d_ksk, local_stream,
709+
)
710+
.unwrap();
711711
});
712712
}, BatchSize::SmallInput);
713713
});
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
# Zero-knowledge proofs
2+
3+
Zero-knowledge proofs (ZK) are a powerful tool to assert that the encryption of a message is correct, as discussed in [advanced features](../../fhe-computation/advanced-features/zk-pok.md).
4+
However, computation is not possible on the type of ciphertexts it produces (i.e. `ProvenCompactCiphertextList`). This document explains how to use the GPU to accelerate the
5+
preprocessing step needed to convert ciphertexts formatted for ZK to ciphertexts in the right format for computation purposes on GPU. This
6+
operation is called "expansion".
7+
8+
## Proven compact ciphertext list
9+
10+
A proven compact list of ciphertexts can be seen as a compacted collection of ciphertexts which encryption can be verified.
11+
This verification is currently only supported on the CPU, but the expansion can be accelerated using the GPU.
12+
This way, verification and expansion can be performed in parallel, efficiently using all the available computational resources.
13+
14+
## Supported types
15+
Encrypted messages can be integers (like FheUint64) or booleans. The GPU backend does not currently support encrypted strings.
16+
17+
{% hint style="info" %}
18+
You can enable this feature using the flag: `--features=zk-pok,gpu` when building **TFHE-rs**.
19+
{% endhint %}
20+
21+
22+
## Example
23+
24+
The following example shows how a client can encrypt and prove a ciphertext, and how a server can verify the proof, preprocess the ciphertext and run a computation on it on GPU:
25+
26+
```rust
27+
use rand::random;
28+
use tfhe::CompressedServerKey;
29+
use tfhe::prelude::*;
30+
use tfhe::set_server_key;
31+
use tfhe::zk::{CompactPkeCrs, ZkComputeLoad};
32+
33+
pub fn main() -> Result<(), Box<dyn std::error::Error>> {
34+
let params = tfhe::shortint::parameters::PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128;
35+
// Indicate which parameters to use for the Compact Public Key encryption
36+
let cpk_params = tfhe::shortint::parameters::PARAM_PKE_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128;
37+
// And parameters allowing to keyswitch/cast to the computation parameters.
38+
let casting_params = tfhe::shortint::parameters::PARAM_KEYSWITCH_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128;
39+
// Enable the dedicated parameters on the config
40+
let config = tfhe::ConfigBuilder::with_custom_parameters(params)
41+
.use_dedicated_compact_public_key_parameters((cpk_params, casting_params)).build();
42+
43+
// The CRS should be generated in an offline phase then shared to all clients and the server
44+
let crs = CompactPkeCrs::from_config(config, 64).unwrap();
45+
46+
// Then use TFHE-rs as usual
47+
let client_key = tfhe::ClientKey::generate(config);
48+
let compressed_server_key = CompressedServerKey::new(&client_key);
49+
let gpu_server_key = compressed_server_key.decompress_to_gpu();
50+
51+
let public_key = tfhe::CompactPublicKey::try_new(&client_key).unwrap();
52+
// This can be left empty, but if provided allows to tie the proof to arbitrary data
53+
let metadata = [b'T', b'F', b'H', b'E', b'-', b'r', b's'];
54+
55+
let clear_a = random::<u64>();
56+
let clear_b = random::<u64>();
57+
58+
let proven_compact_list = tfhe::ProvenCompactCiphertextList::builder(&public_key)
59+
.push(clear_a)
60+
.push(clear_b)
61+
.build_with_proof_packed(&crs, &metadata, ZkComputeLoad::Verify)?;
62+
63+
// Server side
64+
let result = {
65+
set_server_key(gpu_server_key);
66+
67+
// Verify the ciphertexts
68+
let expander =
69+
proven_compact_list.verify_and_expand(&crs, &public_key, &metadata)?;
70+
let a: tfhe::FheUint64 = expander.get(0)?.unwrap();
71+
let b: tfhe::FheUint64 = expander.get(1)?.unwrap();
72+
73+
a + b
74+
};
75+
76+
// Back on the client side
77+
let a_plus_b: u64 = result.decrypt(&client_key);
78+
assert_eq!(a_plus_b, clear_a.wrapping_add(clear_b));
79+
80+
Ok(())
81+
}
82+
```

tfhe/src/core_crypto/gpu/entities/lwe_ciphertext_list.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ use crate::core_crypto::prelude::{
77
use tfhe_cuda_backend::cuda_bind::cuda_memcpy_async_gpu_to_gpu;
88

99
/// A structure representing a vector of LWE ciphertexts with 64 bits of precision on the GPU.
10-
#[derive(Debug)]
10+
#[derive(Clone, Debug)]
1111
pub struct CudaLweCiphertextList<T: UnsignedInteger>(pub(crate) CudaLweList<T>);
1212

1313
#[allow(dead_code)]

tfhe/src/core_crypto/gpu/entities/lwe_compact_ciphertext_list.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,10 @@ use crate::core_crypto::prelude::{
1010
pub struct CudaLweCompactCiphertextList<T: UnsignedInteger>(pub CudaLweList<T>);
1111

1212
impl<T: UnsignedInteger> CudaLweCompactCiphertextList<T> {
13+
pub fn duplicate(&self, streams: &CudaStreams) -> Self {
14+
Self(self.0.duplicate(streams))
15+
}
16+
1317
pub fn from_lwe_compact_ciphertext_list<C: Container<Element = T>>(
1418
h_ct: &LweCompactCiphertextList<C>,
1519
streams: &CudaStreams,

tfhe/src/core_crypto/gpu/entities/lwe_keyswitch_key.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ use crate::core_crypto::prelude::{
1010
UnsignedInteger,
1111
};
1212

13+
#[derive(Clone)]
1314
#[allow(dead_code)]
1415
pub struct CudaLweKeyswitchKey<T: UnsignedInteger> {
1516
pub(crate) d_vec: CudaVec<T>,

tfhe/src/core_crypto/gpu/mod.rs

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -993,7 +993,7 @@ pub unsafe fn fourier_transform_backward_as_torus_f128_async<T: UnsignedInteger>
993993
);
994994
}
995995

996-
#[derive(Debug)]
996+
#[derive(Clone, Debug)]
997997
pub struct CudaLweList<T: UnsignedInteger> {
998998
// Pointer to GPU data
999999
pub d_vec: CudaVec<T>,
@@ -1005,6 +1005,17 @@ pub struct CudaLweList<T: UnsignedInteger> {
10051005
pub ciphertext_modulus: CiphertextModulus<T>,
10061006
}
10071007

1008+
impl<T: UnsignedInteger> CudaLweList<T> {
1009+
pub fn duplicate(&self, streams: &CudaStreams) -> Self {
1010+
Self {
1011+
d_vec: self.d_vec.duplicate(streams),
1012+
lwe_ciphertext_count: self.lwe_ciphertext_count,
1013+
lwe_dimension: self.lwe_dimension,
1014+
ciphertext_modulus: self.ciphertext_modulus,
1015+
}
1016+
}
1017+
}
1018+
10081019
#[derive(Debug, Clone)]
10091020
pub struct CudaGlweList<T: UnsignedInteger> {
10101021
// Pointer to GPU data

tfhe/src/high_level_api/backward_compatibility/compact_list.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@ use tfhe_versionable::{Upgrade, Version, VersionsDispatch};
33

44
use crate::{CompactCiphertextList, Tag};
55

6+
#[cfg(feature = "zk-pok")]
7+
use crate::ProvenCompactCiphertextList;
8+
69
#[derive(Version)]
710
pub struct CompactCiphertextListV0(crate::integer::ciphertext::CompactCiphertextList);
811

@@ -17,9 +20,6 @@ impl Upgrade<CompactCiphertextList> for CompactCiphertextListV0 {
1720
}
1821
}
1922

20-
#[cfg(feature = "zk-pok")]
21-
use crate::ProvenCompactCiphertextList;
22-
2323
#[derive(VersionsDispatch)]
2424
pub enum CompactCiphertextListVersions {
2525
V0(CompactCiphertextListV0),

0 commit comments

Comments
 (0)