Skip to content

Commit 3764ea3

Browse files
committed
feat(gpu): add support for GPU-accelerated expand on the HL Api
1 parent 0a27971 commit 3764ea3

File tree

15 files changed

+694
-108
lines changed

15 files changed

+694
-108
lines changed

Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -839,7 +839,7 @@ test_high_level_api: install_rs_build_toolchain
839839

840840
test_high_level_api_gpu: install_rs_build_toolchain install_cargo_nextest
841841
RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_BUILD_TOOLCHAIN) nextest run --cargo-profile $(CARGO_PROFILE) \
842-
--features=integer,internal-keycache,gpu -p $(TFHE_SPEC) \
842+
--features=integer,internal-keycache,gpu,zk-pok -p $(TFHE_SPEC) \
843843
-E "test(/high_level_api::.*gpu.*/)"
844844

845845
.PHONY: test_strings # Run the tests for strings ci

tfhe/src/core_crypto/gpu/entities/lwe_ciphertext_list.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ use crate::core_crypto::prelude::{
77
use tfhe_cuda_backend::cuda_bind::cuda_memcpy_async_gpu_to_gpu;
88

99
/// A structure representing a vector of LWE ciphertexts with 64 bits of precision on the GPU.
10-
#[derive(Debug)]
10+
#[derive(Clone, Debug)]
1111
pub struct CudaLweCiphertextList<T: UnsignedInteger>(pub(crate) CudaLweList<T>);
1212

1313
#[allow(dead_code)]

tfhe/src/core_crypto/gpu/entities/lwe_compact_ciphertext_list.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,10 @@ use crate::core_crypto::prelude::{
1010
pub struct CudaLweCompactCiphertextList<T: UnsignedInteger>(pub CudaLweList<T>);
1111

1212
impl<T: UnsignedInteger> CudaLweCompactCiphertextList<T> {
13+
pub fn duplicate(&self, streams: &CudaStreams) -> Self {
14+
Self(self.0.duplicate(streams))
15+
}
16+
1317
pub fn from_lwe_compact_ciphertext_list<C: Container<Element = T>>(
1418
h_ct: &LweCompactCiphertextList<C>,
1519
streams: &CudaStreams,

tfhe/src/core_crypto/gpu/entities/lwe_keyswitch_key.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ use crate::core_crypto::prelude::{
1111
};
1212

1313
#[allow(dead_code)]
14+
#[derive(Clone)]
1415
pub struct CudaLweKeyswitchKey<T: UnsignedInteger> {
1516
pub(crate) d_vec: CudaVec<T>,
1617
input_lwe_size: LweSize,

tfhe/src/core_crypto/gpu/mod.rs

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -914,7 +914,7 @@ pub unsafe fn fourier_transform_backward_as_torus_f128_async<T: UnsignedInteger>
914914
);
915915
}
916916

917-
#[derive(Debug)]
917+
#[derive(Clone, Debug)]
918918
pub struct CudaLweList<T: UnsignedInteger> {
919919
// Pointer to GPU data
920920
pub d_vec: CudaVec<T>,
@@ -926,6 +926,17 @@ pub struct CudaLweList<T: UnsignedInteger> {
926926
pub ciphertext_modulus: CiphertextModulus<T>,
927927
}
928928

929+
impl<T: UnsignedInteger> CudaLweList<T> {
930+
pub fn duplicate(&self, streams: &CudaStreams) -> Self {
931+
Self {
932+
d_vec: self.d_vec.duplicate(streams),
933+
lwe_ciphertext_count: self.lwe_ciphertext_count,
934+
lwe_dimension: self.lwe_dimension,
935+
ciphertext_modulus: self.ciphertext_modulus,
936+
}
937+
}
938+
}
939+
929940
#[derive(Debug, Clone)]
930941
pub struct CudaGlweList<T: UnsignedInteger> {
931942
// Pointer to GPU data

tfhe/src/high_level_api/backward_compatibility/compact_list.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@ use tfhe_versionable::{Upgrade, Version, VersionsDispatch};
33

44
use crate::{CompactCiphertextList, Tag};
55

6+
#[cfg(feature = "zk-pok")]
7+
use crate::ProvenCompactCiphertextList;
8+
69
#[derive(Version)]
710
pub struct CompactCiphertextListV0(crate::integer::ciphertext::CompactCiphertextList);
811

@@ -17,9 +20,6 @@ impl Upgrade<CompactCiphertextList> for CompactCiphertextListV0 {
1720
}
1821
}
1922

20-
#[cfg(feature = "zk-pok")]
21-
use crate::ProvenCompactCiphertextList;
22-
2323
#[derive(VersionsDispatch)]
2424
pub enum CompactCiphertextListVersions {
2525
V0(CompactCiphertextListV0),

0 commit comments

Comments
 (0)