Skip to content

feat(gpu): add support for GPU-accelerated expand on the HL Api #2276

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,7 @@ check_typos: install_typos_checker
.PHONY: clippy_gpu # Run clippy lints on tfhe with "gpu" enabled
clippy_gpu: install_rs_check_toolchain
RUSTFLAGS="$(RUSTFLAGS)" cargo "$(CARGO_RS_CHECK_TOOLCHAIN)" clippy \
--features=boolean,shortint,integer,internal-keycache,gpu,pbs-stats,extended-types \
--features=boolean,shortint,integer,internal-keycache,gpu,pbs-stats,extended-types,zk-pok \
--all-targets \
-p $(TFHE_SPEC) -- --no-deps -D warnings

Expand Down Expand Up @@ -892,7 +892,7 @@ test_high_level_api: install_rs_build_toolchain

test_high_level_api_gpu: install_rs_build_toolchain install_cargo_nextest
RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_BUILD_TOOLCHAIN) nextest run --cargo-profile $(CARGO_PROFILE) \
--features=integer,internal-keycache,gpu -p $(TFHE_SPEC) \
--features=integer,internal-keycache,gpu,zk-pok -p $(TFHE_SPEC) \
-E "test(/high_level_api::.*gpu.*/)"

test_high_level_api_hpu: install_rs_build_toolchain install_cargo_nextest
Expand Down Expand Up @@ -1066,7 +1066,7 @@ check_compile_tests: install_rs_build_toolchain
.PHONY: check_compile_tests_benches_gpu # Build tests in debug without running them
check_compile_tests_benches_gpu: install_rs_build_toolchain
RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_BUILD_TOOLCHAIN) test --no-run \
--features=experimental,boolean,shortint,integer,internal-keycache,gpu \
--features=experimental,boolean,shortint,integer,internal-keycache,gpu,zk-pok \
-p $(TFHE_SPEC)
mkdir -p "$(TFHECUDA_BUILD)" && \
cd "$(TFHECUDA_BUILD)" && \
Expand Down
68 changes: 34 additions & 34 deletions tfhe-benchmark/benches/integer/zk_pke.rs
Original file line number Diff line number Diff line change
Expand Up @@ -418,11 +418,11 @@ fn cpu_pke_zk_verify(c: &mut Criterion, results_file: &Path) {
#[cfg(all(feature = "gpu", feature = "zk-pok"))]
mod cuda {
use super::*;
use benchmark::utilities::{cuda_local_keys, cuda_local_streams};
use benchmark::utilities::cuda_local_streams;
use criterion::BatchSize;
use itertools::Itertools;
use tfhe::core_crypto::gpu::{get_number_of_gpus, CudaStreams};
use tfhe::integer::gpu::key_switching_key::CudaKeySwitchingKey;
use tfhe::integer::gpu::key_switching_key::{CudaKeySwitchingKey, CudaKeySwitchingKeyMaterial};
use tfhe::integer::gpu::zk::CudaProvenCompactCiphertextList;
use tfhe::integer::gpu::CudaServerKey;
use tfhe::integer::CompressedServerKey;
Expand Down Expand Up @@ -451,14 +451,17 @@ mod cuda {
let param_name = param_name.as_str();
let cks = ClientKey::new(param_fhe);
let compressed_server_key = CompressedServerKey::new_radix_compressed_server_key(&cks);
let sk = compressed_server_key.decompress();
let gpu_sks = CudaServerKey::decompress_from_cpu(&compressed_server_key, &streams);

let compact_private_key = CompactPrivateKey::new(param_pke);
let pk = CompactPublicKey::new(&compact_private_key);
let d_ksk = CudaKeySwitchingKey::new(
(&compact_private_key, None),
(&cks, &gpu_sks),
param_ksk,
&streams,
let ksk = KeySwitchingKey::new((&compact_private_key, None), (&cks, &sk), param_ksk);
let d_ksk_material =
CudaKeySwitchingKeyMaterial::from_key_switching_key(&ksk, &streams);
let d_ksk = CudaKeySwitchingKey::from_cuda_key_switching_key_material(
&d_ksk_material,
&gpu_sks,
);

// We have a use case with 320 bits of metadata
Expand Down Expand Up @@ -609,7 +612,6 @@ mod cuda {
});
}
BenchmarkType::Throughput => {
let gpu_sks_vec = cuda_local_keys(&cks);
let gpu_count = get_number_of_gpus() as usize;

let elements = zk_throughput_num_elements();
Expand Down Expand Up @@ -637,20 +639,17 @@ mod cuda {
.collect::<Vec<_>>();

let local_streams = cuda_local_streams(num_block, elements as usize);
let d_ksk_vec = gpu_sks_vec
let d_ksk_material_vec = local_streams
.par_iter()
.zip(local_streams.par_iter())
.map(|(gpu_sks, local_stream)| {
CudaKeySwitchingKey::new(
(&compact_private_key, None),
(&cks, gpu_sks),
param_ksk,
.map(|local_stream| {
CudaKeySwitchingKeyMaterial::from_key_switching_key(
&ksk,
local_stream,
)
})
.collect::<Vec<_>>();

assert_eq!(d_ksk_vec.len(), gpu_count);
assert_eq!(d_ksk_material_vec.len(), gpu_count);

bench_group.bench_function(&bench_id_verify, |b| {
b.iter(|| {
Expand All @@ -673,14 +672,16 @@ mod cuda {
(gpu_cts, local_streams)
};

b.iter_batched(setup_encrypted_values, |(gpu_cts, local_streams)| {
gpu_cts.par_iter()
.zip(local_streams.par_iter())
.enumerate()
.for_each(|(i, (gpu_ct, local_stream))| {
gpu_ct
.expand_without_verification(&d_ksk_vec[i % gpu_count], local_stream)
.unwrap();
b.iter_batched(setup_encrypted_values,
|(gpu_cts, local_streams)| {
gpu_cts.par_iter().zip(local_streams.par_iter()).enumerate().for_each
(|(i, (gpu_ct, local_stream))| {
let d_ksk =
CudaKeySwitchingKey::from_cuda_key_switching_key_material(&d_ksk_material_vec[i % gpu_count], &gpu_sks);

gpu_ct
.expand_without_verification(&d_ksk, local_stream)
.unwrap();
});
}, BatchSize::SmallInput);
});
Expand All @@ -698,16 +699,15 @@ mod cuda {
(gpu_cts, local_streams)
};

b.iter_batched(setup_encrypted_values, |(gpu_cts, local_streams)| {
gpu_cts
.par_iter()
.zip(local_streams.par_iter())
.for_each(|(gpu_ct, local_stream)| {
gpu_ct
.verify_and_expand(
&crs, &pk, &metadata, &d_ksk, local_stream
)
.unwrap();
b.iter_batched(setup_encrypted_values,
|(gpu_cts, local_streams)| {
gpu_cts.par_iter().zip(local_streams.par_iter()).for_each
(|(gpu_ct, local_stream)| {
gpu_ct
.verify_and_expand(
&crs, &pk, &metadata, &d_ksk, local_stream,
)
.unwrap();
});
}, BatchSize::SmallInput);
});
Expand Down
82 changes: 82 additions & 0 deletions tfhe/docs/configuration/gpu_acceleration/zk-pok.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
# Zero-knowledge proofs

Zero-knowledge proofs (ZK) are a powerful tool to assert that the encryption of a message is correct, as discussed in [advanced features](../../fhe-computation/advanced-features/zk-pok.md).
However, computation is not possible on the type of ciphertexts it produces (i.e. `ProvenCompactCiphertextList`). This document explains how to use the GPU to accelerate the
preprocessing step needed to convert ciphertexts formatted for ZK to ciphertexts in the right format for computation purposes on GPU. This
operation is called "expansion".

## Proven compact ciphertext list

A proven compact list of ciphertexts can be seen as a compacted collection of ciphertexts for which encryption can be verified.
This verification is currently only supported on the CPU, but the expansion can be accelerated using the GPU.
This way, verification and expansion can be performed in parallel, efficiently using all the available computational resources.

## Supported types
Encrypted messages can be integers (like FheUint64) or booleans. The GPU backend does not currently support encrypted strings.

{% hint style="info" %}
You can enable this feature using the flag: `--features=zk-pok,gpu` when building **TFHE-rs**.
{% endhint %}


## Example

The following example shows how a client can encrypt and prove a ciphertext, and how a server can verify the proof, preprocess the ciphertext and run a computation on it on GPU:

```rust
use rand::random;
use tfhe::CompressedServerKey;
use tfhe::prelude::*;
use tfhe::set_server_key;
use tfhe::zk::{CompactPkeCrs, ZkComputeLoad};

pub fn main() -> Result<(), Box<dyn std::error::Error>> {
let params = tfhe::shortint::parameters::PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128;
// Indicate which parameters to use for the Compact Public Key encryption
let cpk_params = tfhe::shortint::parameters::PARAM_PKE_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128;
// And parameters allowing to keyswitch/cast to the computation parameters.
let casting_params = tfhe::shortint::parameters::PARAM_KEYSWITCH_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128;
// Enable the dedicated parameters on the config
let config = tfhe::ConfigBuilder::with_custom_parameters(params)
.use_dedicated_compact_public_key_parameters((cpk_params, casting_params)).build();

// The CRS should be generated in an offline phase then shared to all clients and the server
let crs = CompactPkeCrs::from_config(config, 64).unwrap();

// Then use TFHE-rs as usual
let client_key = tfhe::ClientKey::generate(config);
let compressed_server_key = CompressedServerKey::new(&client_key);
let gpu_server_key = compressed_server_key.decompress_to_gpu();

let public_key = tfhe::CompactPublicKey::try_new(&client_key).unwrap();
// This can be left empty, but if provided allows to tie the proof to arbitrary data
let metadata = [b'T', b'F', b'H', b'E', b'-', b'r', b's'];

let clear_a = random::<u64>();
let clear_b = random::<u64>();

let proven_compact_list = tfhe::ProvenCompactCiphertextList::builder(&public_key)
.push(clear_a)
.push(clear_b)
.build_with_proof_packed(&crs, &metadata, ZkComputeLoad::Verify)?;

// Server side
let result = {
set_server_key(gpu_server_key);

// Verify the ciphertexts
let expander =
proven_compact_list.verify_and_expand(&crs, &public_key, &metadata)?;
let a: tfhe::FheUint64 = expander.get(0)?.unwrap();
let b: tfhe::FheUint64 = expander.get(1)?.unwrap();

a + b
};

// Back on the client side
let a_plus_b: u64 = result.decrypt(&client_key);
assert_eq!(a_plus_b, clear_a.wrapping_add(clear_b));

Ok(())
}
```
2 changes: 1 addition & 1 deletion tfhe/src/core_crypto/gpu/entities/lwe_ciphertext_list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use crate::core_crypto::prelude::{
use tfhe_cuda_backend::cuda_bind::cuda_memcpy_async_gpu_to_gpu;

/// A structure representing a vector of LWE ciphertexts with 64 bits of precision on the GPU.
#[derive(Debug)]
#[derive(Clone, Debug)]
pub struct CudaLweCiphertextList<T: UnsignedInteger>(pub(crate) CudaLweList<T>);

#[allow(dead_code)]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@ use crate::core_crypto::prelude::{
pub struct CudaLweCompactCiphertextList<T: UnsignedInteger>(pub CudaLweList<T>);

impl<T: UnsignedInteger> CudaLweCompactCiphertextList<T> {
pub fn duplicate(&self, streams: &CudaStreams) -> Self {
Self(self.0.duplicate(streams))
}

pub fn from_lwe_compact_ciphertext_list<C: Container<Element = T>>(
h_ct: &LweCompactCiphertextList<C>,
streams: &CudaStreams,
Expand Down
1 change: 1 addition & 0 deletions tfhe/src/core_crypto/gpu/entities/lwe_keyswitch_key.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ use crate::core_crypto::prelude::{
UnsignedInteger,
};

#[derive(Clone)]
#[allow(dead_code)]
pub struct CudaLweKeyswitchKey<T: UnsignedInteger> {
pub(crate) d_vec: CudaVec<T>,
Expand Down
13 changes: 12 additions & 1 deletion tfhe/src/core_crypto/gpu/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -993,7 +993,7 @@ pub unsafe fn fourier_transform_backward_as_torus_f128_async<T: UnsignedInteger>
);
}

#[derive(Debug)]
#[derive(Clone, Debug)]
pub struct CudaLweList<T: UnsignedInteger> {
// Pointer to GPU data
pub d_vec: CudaVec<T>,
Expand All @@ -1005,6 +1005,17 @@ pub struct CudaLweList<T: UnsignedInteger> {
pub ciphertext_modulus: CiphertextModulus<T>,
}

impl<T: UnsignedInteger> CudaLweList<T> {
pub fn duplicate(&self, streams: &CudaStreams) -> Self {
Self {
d_vec: self.d_vec.duplicate(streams),
lwe_ciphertext_count: self.lwe_ciphertext_count,
lwe_dimension: self.lwe_dimension,
ciphertext_modulus: self.ciphertext_modulus,
}
}
}

#[derive(Debug, Clone)]
pub struct CudaGlweList<T: UnsignedInteger> {
// Pointer to GPU data
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@ use tfhe_versionable::{Upgrade, Version, VersionsDispatch};

use crate::{CompactCiphertextList, Tag};

#[cfg(feature = "zk-pok")]
use crate::ProvenCompactCiphertextList;

#[derive(Version)]
pub struct CompactCiphertextListV0(crate::integer::ciphertext::CompactCiphertextList);

Expand All @@ -17,9 +20,6 @@ impl Upgrade<CompactCiphertextList> for CompactCiphertextListV0 {
}
}

#[cfg(feature = "zk-pok")]
use crate::ProvenCompactCiphertextList;

#[derive(VersionsDispatch)]
pub enum CompactCiphertextListVersions {
V0(CompactCiphertextListV0),
Expand Down
Loading
Loading