Compare commits

...

27 Commits

Author SHA1 Message Date
Agnes Leroy
de710cb2fb Debug with long run tests 2025-05-28 12:15:52 +02:00
Agnes Leroy
59a78c76a9 fix(gpu): fix build after shift/rotate mem tracking merge 2025-05-28 12:08:09 +02:00
Pedro Alves
1025246b17 fix(gpu): fix a linking problem on Hopper GPUs 2025-05-28 09:27:33 +02:00
Agnes Leroy
338e9eaeef feat(gpu): add memory tracking functions for shift/rotate 2025-05-28 09:26:27 +02:00
David Testé
0bec4d2ba1 chore(ci): pin rust-toolchain action to v1 2025-05-27 17:31:33 +02:00
David Testé
c5fab98900 chore(ci): add token to do online workflow security checks 2025-05-27 17:31:33 +02:00
Nicolas Sarlin
14e1ee5bd3 fix(gpu): build with hpu and zk features 2025-05-27 16:10:38 +02:00
Pedro Alves
52bc778629 feat(gpu): completely remove the internal CUDA_STREAMS in the HL API
- From now on the streams stored in the available cuda server key are the ones to be used
2025-05-27 10:29:34 -03:00
Pedro Alves
10405c9836 feat(gpu): improve test_specific_gpu_selection() so it always tests all possible GPU configurations 2025-05-27 10:29:34 -03:00
Pedro Alves
5eaf6cec55 feat(gpu): reintroduce the feature that allows a user to perform computation on multi-gpu using a custom selection of GPUs
This reverts commit a7d8d2b1d4.
2025-05-27 10:29:34 -03:00
Agnes Leroy
3bfacc1e9d chore(bench): add swap throughput benchmark 2025-05-27 12:08:31 +02:00
Agnes Leroy
a47a418d41 chore(gpu): rework dex bench to prepare throughput benchmark 2025-05-27 12:08:31 +02:00
David Testé
75b3141e19 chore(ci): fix command parsing for gpu benchmark common workflow
Quote escaping was flawed and would generate an array containing a single string instead of several strings separated by commas.
2025-05-27 10:14:06 +02:00
Agnes Leroy
d01328e0fe fix(gpu): fix overflow error in clear inputs remainder in long run tests 2025-05-26 22:51:18 +02:00
Agnes Leroy
6e102b5fa1 chore(gpu): fix oom error in ci 2025-05-26 22:50:55 +02:00
Pedro Alves
8aa6fa514e fix(gpu): add missing error checks after some kernels 2025-05-26 16:29:23 -03:00
Nicolas Sarlin
21a19cd3c5 chore(shortint): modswitch noise reduction key upgrade without clone 2025-05-26 16:53:35 +02:00
Nicolas Sarlin
f51c70d536 feat(shortint): adds generic client key for atomic pattern support 2025-05-26 16:53:35 +02:00
Agnes Leroy
66e3c02838 feat(gpu): add memory tracking functions for comparisons 2025-05-23 14:37:39 +02:00
Pedro Alves
408e81c45a feat(gpu): add support for GPU-accelerated expand on the HL Api
- includes documentation about GPU-accelerated expand on the HL API
- rework CudaKeySwitchingKey
- Cloning the key is no longer necessary on the HL API
2025-05-23 11:54:29 +02:00
dependabot[bot]
4152906c5d chore(deps): bump actions/upload-artifact from 4.6.0 to 4.6.2
Bumps [actions/upload-artifact](https://github.com/actions/upload-artifact) from 4.6.0 to 4.6.2.
- [Release notes](https://github.com/actions/upload-artifact/releases)
- [Commits](https://github.com/actions/upload-artifact/compare/v4.6.0...ea165f8d65b6e75b540449e92b4886f43607fa02)

---
updated-dependencies:
- dependency-name: actions/upload-artifact
  dependency-version: 4.6.2
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
2025-05-23 11:23:02 +02:00
dependabot[bot]
9fc8a0b5bc chore(deps): bump codecov/codecov-action from 5.4.2 to 5.4.3
Bumps [codecov/codecov-action](https://github.com/codecov/codecov-action) from 5.4.2 to 5.4.3.
- [Release notes](https://github.com/codecov/codecov-action/releases)
- [Changelog](https://github.com/codecov/codecov-action/blob/main/CHANGELOG.md)
- [Commits](ad3126e916...18283e04ce)

---
updated-dependencies:
- dependency-name: codecov/codecov-action
  dependency-version: 5.4.3
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
2025-05-23 11:22:55 +02:00
dependabot[bot]
5dc3e59d13 chore(deps): bump zgosalvez/github-actions-ensure-sha-pinned-actions
Bumps [zgosalvez/github-actions-ensure-sha-pinned-actions](https://github.com/zgosalvez/github-actions-ensure-sha-pinned-actions) from 3.0.23 to 3.0.25.
- [Release notes](https://github.com/zgosalvez/github-actions-ensure-sha-pinned-actions/releases)
- [Commits](4830be28ce...fc87bb5b5a)

---
updated-dependencies:
- dependency-name: zgosalvez/github-actions-ensure-sha-pinned-actions
  dependency-version: 3.0.25
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
2025-05-23 11:22:48 +02:00
Nicolas Sarlin
b40996a7e5 chore(shortint): prepare the v1.3 params folder 2025-05-23 10:57:56 +02:00
Pedro Alves
b066ef19fa fix(gpu): fix the internal benchmark 2025-05-23 10:32:24 +02:00
Nicolas Sarlin
25d008bae8 fix(bench): add missing internal keycache feature 2025-05-22 16:14:30 +02:00
David Testé
2749c1088c chore(ci): handle multi directories for parameters records 2025-05-22 15:03:02 +02:00
201 changed files with 10745 additions and 2382 deletions

View File

@@ -84,7 +84,7 @@ jobs:
run: |
# Use Sed to extract a value from a string, this cannot be done with the ${variable//search/replace} pattern.
# shellcheck disable=SC2001
PARSED_COMMAND=$(echo "${INPUTS_COMMAND}" | sed 's/[[:space:]]*,[[:space:]]*/\\", \\"/g')
PARSED_COMMAND=$(echo "${INPUTS_COMMAND}" | sed 's/[[:space:]]*,[[:space:]]*/\", \"/g')
echo "COMMAND=[\"${PARSED_COMMAND}\"]" >> "${GITHUB_ENV}"
- name: Set single operations flavor

View File

@@ -44,7 +44,7 @@ jobs:
} >> "${GITHUB_ENV}"
- name: Install rust
uses: dtolnay/rust-toolchain@a54c7afa936fefeb4456b2dd8068152669aa8203
uses: dtolnay/rust-toolchain@888c2e1ea69ab0d4330cbf0af1ecc7b68f368cc1
with:
toolchain: nightly
@@ -76,7 +76,7 @@ jobs:
REF_NAME: ${{ github.ref_name }}
- name: Upload parsed results artifact
uses: actions/upload-artifact@65c4c4a1ddee5b72f698fdd19549f0f0fb45cf08
uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02
with:
name: ${{ github.sha }}_integer_benchmarks
path: ${{ env.RESULTS_FILENAME }}

View File

@@ -38,9 +38,11 @@ jobs:
- name: Check workflows security
run: |
make check_workflow_security
env:
GH_TOKEN: ${{ env.CHECKOUT_TOKEN }}
- name: Ensure SHA pinned actions
uses: zgosalvez/github-actions-ensure-sha-pinned-actions@4830be28ce81da52ec70d65c552a7403821d98d4 # v3.0.23
uses: zgosalvez/github-actions-ensure-sha-pinned-actions@fc87bb5b5a97953d987372e74478de634726b3e5 # v3.0.25
with:
allowlist: |
slsa-framework/slsa-github-generator

View File

@@ -90,7 +90,7 @@ jobs:
make test_shortint_cov
- name: Upload tfhe coverage to Codecov
uses: codecov/codecov-action@ad3126e916f78f00edff4ed0317cf185271ccc2d
uses: codecov/codecov-action@18283e04ce6e62d37312384ff67231eb8fd56d24
if: steps.changed-files.outputs.tfhe_any_changed == 'true'
with:
token: ${{ secrets.CODECOV_TOKEN }}
@@ -104,7 +104,7 @@ jobs:
make test_integer_cov
- name: Upload tfhe coverage to Codecov
uses: codecov/codecov-action@ad3126e916f78f00edff4ed0317cf185271ccc2d
uses: codecov/codecov-action@18283e04ce6e62d37312384ff67231eb8fd56d24
if: steps.changed-files.outputs.tfhe_any_changed == 'true'
with:
token: ${{ secrets.CODECOV_TOKEN }}

View File

@@ -149,7 +149,7 @@ jobs:
- name: Run High Level API Tests
run: |
BIG_TESTS_INSTANCE=FALSE make test_high_level_api_gpu
make test_high_level_api_gpu
slack-notify:
name: Slack Notification

View File

@@ -294,7 +294,7 @@ check_typos: install_typos_checker
.PHONY: clippy_gpu # Run clippy lints on tfhe with "gpu" enabled
clippy_gpu: install_rs_check_toolchain
RUSTFLAGS="$(RUSTFLAGS)" cargo "$(CARGO_RS_CHECK_TOOLCHAIN)" clippy \
--features=boolean,shortint,integer,internal-keycache,gpu,pbs-stats,extended-types \
--features=boolean,shortint,integer,internal-keycache,gpu,pbs-stats,extended-types,zk-pok \
--all-targets \
-p $(TFHE_SPEC) -- --no-deps -D warnings
@@ -315,7 +315,7 @@ clippy_hpu: install_rs_check_toolchain
.PHONY: clippy_gpu_hpu # Run clippy lints on tfhe with "gpu" and "hpu" enabled
clippy_gpu_hpu: install_rs_check_toolchain
RUSTFLAGS="$(RUSTFLAGS)" cargo "$(CARGO_RS_CHECK_TOOLCHAIN)" clippy \
--features=boolean,shortint,integer,internal-keycache,gpu,hpu,pbs-stats,extended-types \
--features=boolean,shortint,integer,internal-keycache,gpu,hpu,pbs-stats,extended-types,zk-pok \
--all-targets \
-p $(TFHE_SPEC) -- --no-deps -D warnings
@@ -899,7 +899,7 @@ test_high_level_api: install_rs_build_toolchain
test_high_level_api_gpu: install_rs_build_toolchain install_cargo_nextest
RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_BUILD_TOOLCHAIN) nextest run --cargo-profile $(CARGO_PROFILE) \
--features=integer,internal-keycache,gpu -p $(TFHE_SPEC) \
--test-threads=4 --features=integer,internal-keycache,gpu,zk-pok -p $(TFHE_SPEC) \
-E "test(/high_level_api::.*gpu.*/)"
test_high_level_api_hpu: install_rs_build_toolchain install_cargo_nextest
@@ -1073,7 +1073,7 @@ check_compile_tests: install_rs_build_toolchain
.PHONY: check_compile_tests_benches_gpu # Build tests in debug without running them
check_compile_tests_benches_gpu: install_rs_build_toolchain
RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_BUILD_TOOLCHAIN) test --no-run \
--features=experimental,boolean,shortint,integer,internal-keycache,gpu \
--features=experimental,boolean,shortint,integer,internal-keycache,gpu,zk-pok \
-p $(TFHE_SPEC)
mkdir -p "$(TFHECUDA_BUILD)" && \
cd "$(TFHECUDA_BUILD)" && \

View File

@@ -8,7 +8,7 @@ extern std::mutex m;
extern bool p2p_enabled;
extern "C" {
int32_t cuda_setup_multi_gpu();
int32_t cuda_setup_multi_gpu(int device_0_id);
}
// Define a variant type that can be either a vector or a single pointer

View File

@@ -522,7 +522,8 @@ template <typename Torus>
bool has_support_to_cuda_programmable_bootstrap_tbc(uint32_t num_samples,
uint32_t glwe_dimension,
uint32_t polynomial_size,
uint32_t level_count);
uint32_t level_count,
uint32_t max_shared_memory);
#ifdef __CUDACC__
__device__ inline int get_start_ith_ggsw(int i, uint32_t polynomial_size,

View File

@@ -492,6 +492,7 @@ __host__ void host_fourier_transform_forward_as_integer_f128(
batch_convert_u128_to_f128_as_integer<params>
<<<grid_size, block_size, 0, stream>>>(d_re0, d_re1, d_im0, d_im1,
d_standard);
check_cuda_error(cudaGetLastError());
// call negacyclic 128 bit forward fft.
if (full_sm) {
@@ -503,6 +504,7 @@ __host__ void host_fourier_transform_forward_as_integer_f128(
<<<grid_size, block_size, shared_memory_size, stream>>>(
d_re0, d_re1, d_im0, d_im1, d_re0, d_re1, d_im0, d_im1, buffer);
}
check_cuda_error(cudaGetLastError());
cuda_memcpy_async_to_cpu(re0, d_re0, N / 2 * sizeof(double), stream,
gpu_index);

View File

@@ -261,6 +261,8 @@ void cuda_fourier_polynomial_mul(void *stream_v, uint32_t gpu_index,
default:
break;
}
check_cuda_error(cudaGetLastError());
cuda_drop_async(buffer, stream, gpu_index);
}

View File

@@ -279,6 +279,7 @@ void cuda_convert_lwe_programmable_bootstrap_key(cudaStream_t stream,
PANIC("Cuda error (convert KSK): unsupported polynomial size. Supported "
"N's are powers of two in the interval [256..16384].")
}
check_cuda_error(cudaGetLastError());
cuda_drop_async(d_bsk, stream, gpu_index);
cuda_drop_async(buffer, stream, gpu_index);
@@ -315,6 +316,7 @@ void convert_u128_to_f128_and_forward_fft_128(cudaStream_t stream,
// convert u128 into 4 x double
batch_convert_u128_to_f128_strided_as_torus<params>
<<<grid_size, block_size, 0, stream>>>(d_bsk, d_standard);
check_cuda_error(cudaGetLastError());
// call negacyclic 128 bit forward fft.
if (full_sm) {
@@ -326,6 +328,7 @@ void convert_u128_to_f128_and_forward_fft_128(cudaStream_t stream,
<<<grid_size, block_size, shared_memory_size, stream>>>(d_bsk, d_bsk,
buffer);
}
check_cuda_error(cudaGetLastError());
cuda_drop_async(buffer, stream, gpu_index);
}

View File

@@ -848,4 +848,7 @@ template uint64_t scratch_cuda_programmable_bootstrap_tbc<uint64_t>(
uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t level_count,
uint32_t input_lwe_ciphertext_count, bool allocate_gpu_memory,
bool allocate_ms_array);
template bool
supports_distributed_shared_memory_on_classic_programmable_bootstrap<
__uint128_t>(uint32_t polynomial_size, uint32_t max_shared_memory);
#endif

View File

@@ -6,7 +6,8 @@
std::mutex m;
bool p2p_enabled = false;
int32_t cuda_setup_multi_gpu() {
// Enable bidirectional p2p access between all available GPUs and device_0_id
int32_t cuda_setup_multi_gpu(int device_0_id) {
int num_gpus = cuda_get_number_of_gpus();
if (num_gpus == 0)
PANIC("GPU error: the number of GPUs should be > 0.")
@@ -18,11 +19,13 @@ int32_t cuda_setup_multi_gpu() {
omp_set_nested(1);
int has_peer_access_to_device_0;
for (int i = 1; i < num_gpus; i++) {
check_cuda_error(
cudaDeviceCanAccessPeer(&has_peer_access_to_device_0, i, 0));
check_cuda_error(cudaDeviceCanAccessPeer(&has_peer_access_to_device_0,
i, device_0_id));
if (has_peer_access_to_device_0) {
cuda_set_device(i);
check_cuda_error(cudaDeviceEnablePeerAccess(0, 0));
check_cuda_error(cudaDeviceEnablePeerAccess(device_0_id, 0));
cuda_set_device(device_0_id);
check_cuda_error(cudaDeviceEnablePeerAccess(i, 0));
}
num_used_gpus += 1;
}

View File

@@ -168,7 +168,7 @@ BENCHMARK_DEFINE_F(MultiBitBootstrap_u64, TbcMultiBit)
(benchmark::State &st) {
if (!has_support_to_cuda_programmable_bootstrap_tbc_multi_bit<uint64_t>(
input_lwe_ciphertext_count, glwe_dimension, polynomial_size,
pbs_level)) {
pbs_level, cuda_get_max_shared_memory(0))) {
st.SkipWithError("Configuration not supported for tbc operation");
return;
}
@@ -256,7 +256,7 @@ BENCHMARK_DEFINE_F(ClassicalBootstrap_u64, TbcPBC)
(benchmark::State &st) {
if (!has_support_to_cuda_programmable_bootstrap_tbc<uint64_t>(
input_lwe_ciphertext_count, glwe_dimension, polynomial_size,
pbs_level)) {
pbs_level, cuda_get_max_shared_memory(0))) {
st.SkipWithError("Configuration not supported for tbc operation");
return;
}

View File

@@ -65,7 +65,7 @@ public:
number_of_inputs = (int)GetParam().number_of_inputs;
// Enable Multi-GPU logic
gpu_count = cuda_setup_multi_gpu();
gpu_count = cuda_setup_multi_gpu(0);
active_gpu_count = std::min((uint)number_of_inputs, gpu_count);
for (uint gpu_i = 0; gpu_i < active_gpu_count; gpu_i++) {
streams.push_back(cuda_create_stream(gpu_i));

View File

@@ -101,6 +101,6 @@ extern "C" {
pub fn cuda_drop_async(ptr: *mut c_void, stream: *mut c_void, gpu_index: u32);
pub fn cuda_setup_multi_gpu() -> i32;
pub fn cuda_setup_multi_gpu(gpu_index: u32) -> i32;
} // extern "C"

View File

@@ -14,6 +14,8 @@ import sys
ONE_HOUR_IN_SECONDS = 3600
ONE_SECOND_IN_NANOSECONDS = 1e9
# These are directories where crypto parameters records can be stored.
BENCHMARK_DIRS = ["tfhe-benchmark", "tfhe-zk-pok"]
parser = argparse.ArgumentParser()
parser.add_argument(
@@ -348,8 +350,18 @@ def get_parameters(bench_id, directory):
:return: :class:`tuple` as ``(benchmark parameters, display name, operator type)``
"""
params_dir = pathlib.Path("tfhe-benchmark", "benchmarks_parameters", bench_id)
params = _parse_file_to_json(params_dir, "parameters.json")
for dirname in BENCHMARK_DIRS:
params_dir = pathlib.Path(dirname, "benchmarks_parameters", bench_id)
try:
params = _parse_file_to_json(params_dir, "parameters.json")
except FileNotFoundError:
continue
else:
break
else:
raise FileNotFoundError(
f"file not found: '[...]/benchmarks_parameters/{bench_id}/parameters.json'"
)
display_name = params.pop("display_name")
operator = params.pop("operator_type")

View File

@@ -140,8 +140,8 @@ if [[ "${backend}" == "gpu" ]]; then
test_threads=8
doctest_threads=8
else
test_threads=1
doctest_threads=1
test_threads=4
doctest_threads=4
fi
fi

View File

@@ -138,11 +138,13 @@ pub fn test_shortint_clientkey(
let key: ClientKey = load_and_unversionize(dir, test, format)?;
if test_params != key.parameters {
if test_params != key.parameters() {
Err(test.failure(
format!(
"Invalid {} parameters:\n Expected :\n{:?}\nGot:\n{:?}",
format, test_params, key.parameters
format,
test_params,
key.parameters()
),
format,
))
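
The same field-to-accessor change recurs in the benches and tutorials further below. A minimal sketch of the new calling pattern, assuming the shortint `ClientKey` re-exported through `tfhe::shortint::prelude` (the import path is an assumption; the accessor call itself is taken from the diff):

```rust
use tfhe::shortint::prelude::*;

// Minimal sketch of the accessor change: `parameters` is now a method
// rather than a public field.
fn message_modulus_of(cks: &ClientKey) -> u64 {
    // Before: cks.parameters.message_modulus().0
    // After:
    cks.parameters().message_modulus().0
}
```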

View File

@@ -136,7 +136,7 @@ required-features = ["shortint"]
name = "pbs128-bench"
path = "benches/core_crypto/pbs128_bench.rs"
harness = false
required-features = ["shortint"]
required-features = ["shortint", "internal-keycache"]
[[bin]]
name = "boolean_key_sizes"

File diff suppressed because it is too large

View File

@@ -418,11 +418,11 @@ fn cpu_pke_zk_verify(c: &mut Criterion, results_file: &Path) {
#[cfg(all(feature = "gpu", feature = "zk-pok"))]
mod cuda {
use super::*;
use benchmark::utilities::{cuda_local_keys, cuda_local_streams};
use benchmark::utilities::cuda_local_streams;
use criterion::BatchSize;
use itertools::Itertools;
use tfhe::core_crypto::gpu::{get_number_of_gpus, CudaStreams};
use tfhe::integer::gpu::key_switching_key::CudaKeySwitchingKey;
use tfhe::integer::gpu::key_switching_key::{CudaKeySwitchingKey, CudaKeySwitchingKeyMaterial};
use tfhe::integer::gpu::zk::CudaProvenCompactCiphertextList;
use tfhe::integer::gpu::CudaServerKey;
use tfhe::integer::CompressedServerKey;
@@ -451,14 +451,17 @@ mod cuda {
let param_name = param_name.as_str();
let cks = ClientKey::new(param_fhe);
let compressed_server_key = CompressedServerKey::new_radix_compressed_server_key(&cks);
let sk = compressed_server_key.decompress();
let gpu_sks = CudaServerKey::decompress_from_cpu(&compressed_server_key, &streams);
let compact_private_key = CompactPrivateKey::new(param_pke);
let pk = CompactPublicKey::new(&compact_private_key);
let d_ksk = CudaKeySwitchingKey::new(
(&compact_private_key, None),
(&cks, &gpu_sks),
param_ksk,
&streams,
let ksk = KeySwitchingKey::new((&compact_private_key, None), (&cks, &sk), param_ksk);
let d_ksk_material =
CudaKeySwitchingKeyMaterial::from_key_switching_key(&ksk, &streams);
let d_ksk = CudaKeySwitchingKey::from_cuda_key_switching_key_material(
&d_ksk_material,
&gpu_sks,
);
// We have a use case with 320 bits of metadata
@@ -609,7 +612,6 @@ mod cuda {
});
}
BenchmarkType::Throughput => {
let gpu_sks_vec = cuda_local_keys(&cks);
let gpu_count = get_number_of_gpus() as usize;
let elements = zk_throughput_num_elements();
@@ -637,20 +639,17 @@ mod cuda {
.collect::<Vec<_>>();
let local_streams = cuda_local_streams(num_block, elements as usize);
let d_ksk_vec = gpu_sks_vec
let d_ksk_material_vec = local_streams
.par_iter()
.zip(local_streams.par_iter())
.map(|(gpu_sks, local_stream)| {
CudaKeySwitchingKey::new(
(&compact_private_key, None),
(&cks, gpu_sks),
param_ksk,
.map(|local_stream| {
CudaKeySwitchingKeyMaterial::from_key_switching_key(
&ksk,
local_stream,
)
})
.collect::<Vec<_>>();
assert_eq!(d_ksk_vec.len(), gpu_count);
assert_eq!(d_ksk_material_vec.len(), gpu_count);
bench_group.bench_function(&bench_id_verify, |b| {
b.iter(|| {
@@ -673,14 +672,16 @@ mod cuda {
(gpu_cts, local_streams)
};
b.iter_batched(setup_encrypted_values, |(gpu_cts, local_streams)| {
gpu_cts.par_iter()
.zip(local_streams.par_iter())
.enumerate()
.for_each(|(i, (gpu_ct, local_stream))| {
gpu_ct
.expand_without_verification(&d_ksk_vec[i % gpu_count], local_stream)
.unwrap();
b.iter_batched(setup_encrypted_values,
|(gpu_cts, local_streams)| {
gpu_cts.par_iter().zip(local_streams.par_iter()).enumerate().for_each
(|(i, (gpu_ct, local_stream))| {
let d_ksk =
CudaKeySwitchingKey::from_cuda_key_switching_key_material(&d_ksk_material_vec[i % gpu_count], &gpu_sks);
gpu_ct
.expand_without_verification(&d_ksk, local_stream)
.unwrap();
});
}, BatchSize::SmallInput);
});
@@ -698,16 +699,15 @@ mod cuda {
(gpu_cts, local_streams)
};
b.iter_batched(setup_encrypted_values, |(gpu_cts, local_streams)| {
gpu_cts
.par_iter()
.zip(local_streams.par_iter())
.for_each(|(gpu_ct, local_stream)| {
gpu_ct
.verify_and_expand(
&crs, &pk, &metadata, &d_ksk, local_stream
)
.unwrap();
b.iter_batched(setup_encrypted_values,
|(gpu_cts, local_streams)| {
gpu_cts.par_iter().zip(local_streams.par_iter()).for_each
(|(gpu_ct, local_stream)| {
gpu_ct
.verify_and_expand(
&crs, &pk, &metadata, &d_ksk, local_stream,
)
.unwrap();
});
}, BatchSize::SmallInput);
});
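
To summarize the rework above: the key switching key is now built once on the CPU, its material is uploaded to the GPU once, and the lightweight `CudaKeySwitchingKey` view is rebuilt on demand against a server key. A hedged sketch of that flow using the names as they appear in the diff (exact module paths are assumptions):

```rust
use tfhe::core_crypto::gpu::CudaStreams;
use tfhe::integer::gpu::key_switching_key::{
    CudaKeySwitchingKey, CudaKeySwitchingKeyMaterial,
};
use tfhe::integer::gpu::CudaServerKey;
use tfhe::integer::key_switching_key::KeySwitchingKey;

// Upload the CPU-built KSK once, then derive a cheap per-use view
// bound to a GPU server key (pattern from the diff above).
fn upload_ksk_material(
    ksk: &KeySwitchingKey,
    gpu_sks: &CudaServerKey,
    streams: &CudaStreams,
) -> CudaKeySwitchingKeyMaterial {
    // One-time upload of the key material to the GPU.
    let d_ksk_material =
        CudaKeySwitchingKeyMaterial::from_key_switching_key(ksk, streams);
    // The view borrows the material; the throughput benchmark rebuilds
    // it per GPU from per-stream copies of the material.
    let _d_ksk = CudaKeySwitchingKey::from_cuda_key_switching_key_material(
        &d_ksk_material,
        gpu_sks,
    );
    d_ksk_material
}
```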

View File

@@ -27,7 +27,7 @@ fn bench_server_key_unary_function<F>(
let mut rng = rand::thread_rng();
let modulus = cks.parameters.message_modulus().0;
let modulus = cks.parameters().message_modulus().0;
let clear_text = rng.gen::<u64>() % modulus;
@@ -70,7 +70,7 @@ fn bench_server_key_binary_function<F>(
let mut rng = rand::thread_rng();
let modulus = cks.parameters.message_modulus().0;
let modulus = cks.parameters().message_modulus().0;
let clear_0 = rng.gen::<u64>() % modulus;
let clear_1 = rng.gen::<u64>() % modulus;
@@ -115,7 +115,7 @@ fn bench_server_key_binary_scalar_function<F>(
let mut rng = rand::thread_rng();
let modulus = cks.parameters.message_modulus().0;
let modulus = cks.parameters().message_modulus().0;
let clear_0 = rng.gen::<u64>() % modulus;
let clear_1 = rng.gen::<u64>() % modulus;
@@ -159,7 +159,7 @@ fn bench_server_key_binary_scalar_division_function<F>(
let mut rng = rand::thread_rng();
let modulus = cks.parameters.message_modulus().0;
let modulus = cks.parameters().message_modulus().0;
assert_ne!(modulus, 1);
let clear_0 = rng.gen::<u64>() % modulus;
@@ -200,7 +200,7 @@ fn carry_extract_bench(c: &mut Criterion) {
let mut rng = rand::thread_rng();
let modulus = cks.parameters.message_modulus().0;
let modulus = cks.parameters().message_modulus().0;
let clear_0 = rng.gen::<u64>() % modulus;
@@ -236,7 +236,7 @@ fn programmable_bootstrapping_bench(c: &mut Criterion) {
let mut rng = rand::thread_rng();
let modulus = cks.parameters.message_modulus().0;
let modulus = cks.parameters().message_modulus().0;
let acc = sks.generate_lookup_table(|x| x);

View File

@@ -0,0 +1,82 @@
# Zero-knowledge proofs
Zero-knowledge proofs (ZK) are a powerful tool to assert that the encryption of a message is correct, as discussed in [advanced features](../../fhe-computation/advanced-features/zk-pok.md).
However, computation is not possible on the ciphertext type this encryption produces (i.e., `ProvenCompactCiphertextList`). This document explains how to use the GPU to accelerate the preprocessing step that converts ciphertexts from the ZK format into the format required for GPU computation. This operation is called "expansion".
## Proven compact ciphertext list
A proven compact ciphertext list can be seen as a compacted collection of ciphertexts whose encryption can be verified.
This verification is currently only supported on the CPU, but the expansion can be accelerated using the GPU.
This way, verification and expansion can be performed in parallel, efficiently using all the available computational resources.
## Supported types
Encrypted messages can be integers (like `FheUint64`) or booleans. The GPU backend does not currently support encrypted strings.
{% hint style="info" %}
You can enable this feature using the flag: `--features=zk-pok,gpu` when building **TFHE-rs**.
{% endhint %}
## Example
The following example shows how a client can encrypt and prove a ciphertext, and how a server can verify the proof, preprocess the ciphertext, and run a computation on it on the GPU:
```rust
use rand::random;
use tfhe::CompressedServerKey;
use tfhe::prelude::*;
use tfhe::set_server_key;
use tfhe::zk::{CompactPkeCrs, ZkComputeLoad};
pub fn main() -> Result<(), Box<dyn std::error::Error>> {
let params = tfhe::shortint::parameters::PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128;
// Indicate which parameters to use for the Compact Public Key encryption
let cpk_params = tfhe::shortint::parameters::PARAM_PKE_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128;
// And parameters that allow keyswitching/casting to the computation parameters.
let casting_params = tfhe::shortint::parameters::PARAM_KEYSWITCH_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128;
// Enable the dedicated parameters on the config
let config = tfhe::ConfigBuilder::with_custom_parameters(params)
.use_dedicated_compact_public_key_parameters((cpk_params, casting_params)).build();
// The CRS should be generated in an offline phase, then shared with all clients and the server
let crs = CompactPkeCrs::from_config(config, 64).unwrap();
// Then use TFHE-rs as usual
let client_key = tfhe::ClientKey::generate(config);
let compressed_server_key = CompressedServerKey::new(&client_key);
let gpu_server_key = compressed_server_key.decompress_to_gpu();
let public_key = tfhe::CompactPublicKey::try_new(&client_key).unwrap();
// This can be left empty, but if provided it ties the proof to arbitrary data
let metadata = [b'T', b'F', b'H', b'E', b'-', b'r', b's'];
let clear_a = random::<u64>();
let clear_b = random::<u64>();
let proven_compact_list = tfhe::ProvenCompactCiphertextList::builder(&public_key)
.push(clear_a)
.push(clear_b)
.build_with_proof_packed(&crs, &metadata, ZkComputeLoad::Verify)?;
// Server side
let result = {
set_server_key(gpu_server_key);
// Verify the ciphertexts
let expander =
proven_compact_list.verify_and_expand(&crs, &public_key, &metadata)?;
let a: tfhe::FheUint64 = expander.get(0)?.unwrap();
let b: tfhe::FheUint64 = expander.get(1)?.unwrap();
a + b
};
// Back on the client side
let a_plus_b: u64 = result.decrypt(&client_key);
assert_eq!(a_plus_b, clear_a.wrapping_add(clear_b));
Ok(())
}
```

View File

@@ -114,7 +114,7 @@ fn main() {
let msg1 = 1;
let msg2 = 0;
let modulus = client_key.parameters.message_modulus().0;
let modulus = client_key.parameters().message_modulus().0;
// We use the client key to encrypt two messages:
let ct_1 = client_key.encrypt(msg1);

View File

@@ -86,7 +86,7 @@ fn main() {
let msg1 = 1;
let msg2 = 0;
let modulus = client_key.parameters.message_modulus().0;
let modulus = client_key.parameters().message_modulus().0;
// We use the client key to encrypt two messages:
let ct_1 = client_key.encrypt(msg1);

View File

@@ -59,7 +59,7 @@ fn main() {
let msg2 = 3;
let scalar = 4;
let modulus = client_key.parameters.message_modulus().0;
let modulus = client_key.parameters().message_modulus().0;
// We use the client key to encrypt two messages:
let mut ct_1 = client_key.encrypt(msg1);
@@ -91,7 +91,7 @@ fn main() {
let msg2 = 3;
let scalar = 4;
let modulus = client_key.parameters.message_modulus().0;
let modulus = client_key.parameters().message_modulus().0;
// We use the client key to encrypt two messages:
let mut ct_1 = client_key.encrypt(msg1);
@@ -134,7 +134,7 @@ fn main() {
let msg2 = 3;
let scalar = 4;
let modulus = client_key.parameters.message_modulus().0;
let modulus = client_key.parameters().message_modulus().0;
// We use the client key to encrypt two messages:
let mut ct_1 = client_key.encrypt(msg1);
@@ -168,7 +168,7 @@ fn main() {
let msg2 = 3;
let scalar = 4;
let modulus = client_key.parameters.message_modulus().0;
let modulus = client_key.parameters().message_modulus().0;
// We use the client key to encrypt two messages:
let mut ct_1 = client_key.encrypt(msg1);
@@ -244,7 +244,7 @@ fn main() {
let msg1 = 2;
let msg2 = 1;
let modulus = client_key.parameters.message_modulus().0;
let modulus = client_key.parameters().message_modulus().0;
// We use the private client key to encrypt two messages:
let ct_1 = client_key.encrypt(msg1);
@@ -275,7 +275,7 @@ fn main() {
let msg1 = 2;
let msg2 = 1;
let modulus = client_key.parameters.message_modulus().0;
let modulus = client_key.parameters().message_modulus().0;
// We use the private client key to encrypt two messages:
let ct_1 = client_key.encrypt(msg1);
@@ -306,7 +306,7 @@ fn main() {
let msg1 = 2;
let msg2 = 1;
let modulus = client_key.parameters.message_modulus().0;
let modulus = client_key.parameters().message_modulus().0;
// We use the private client key to encrypt two messages:
let ct_1 = client_key.encrypt(msg1);
@@ -365,7 +365,7 @@ fn main() {
let msg1 = 3;
let msg2 = 2;
let modulus = client_key.parameters.message_modulus().0;
let modulus = client_key.parameters().message_modulus().0;
// We use the private client key to encrypt two messages:
let ct_1 = client_key.encrypt(msg1);

View File

@@ -7,7 +7,7 @@ use crate::core_crypto::prelude::{
use tfhe_cuda_backend::cuda_bind::cuda_memcpy_async_gpu_to_gpu;
/// A structure representing a vector of LWE ciphertexts with 64 bits of precision on the GPU.
#[derive(Debug)]
#[derive(Clone, Debug)]
pub struct CudaLweCiphertextList<T: UnsignedInteger>(pub(crate) CudaLweList<T>);
#[allow(dead_code)]

View File

@@ -10,6 +10,10 @@ use crate::core_crypto::prelude::{
pub struct CudaLweCompactCiphertextList<T: UnsignedInteger>(pub CudaLweList<T>);
impl<T: UnsignedInteger> CudaLweCompactCiphertextList<T> {
pub fn duplicate(&self, streams: &CudaStreams) -> Self {
Self(self.0.duplicate(streams))
}
pub fn from_lwe_compact_ciphertext_list<C: Container<Element = T>>(
h_ct: &LweCompactCiphertextList<C>,
streams: &CudaStreams,

View File

@@ -10,6 +10,7 @@ use crate::core_crypto::prelude::{
UnsignedInteger,
};
#[derive(Clone)]
#[allow(dead_code)]
pub struct CudaLweKeyswitchKey<T: UnsignedInteger> {
pub(crate) d_vec: CudaVec<T>,

View File

@@ -17,7 +17,6 @@ pub use entities::*;
use std::ffi::c_void;
use tfhe_cuda_backend::bindings::*;
use tfhe_cuda_backend::cuda_bind::*;
pub struct CudaStreams {
pub ptr: Vec<*mut c_void>,
pub gpu_indexes: Vec<GpuIndex>,
@@ -30,7 +29,7 @@ unsafe impl Sync for CudaStreams {}
impl CudaStreams {
/// Create a new `CudaStreams` structure with as many GPUs as there are on the machine
pub fn new_multi_gpu() -> Self {
let gpu_count = setup_multi_gpu();
let gpu_count = setup_multi_gpu(GpuIndex::new(0));
let mut gpu_indexes = Vec::with_capacity(gpu_count as usize);
let mut ptr_array = Vec::with_capacity(gpu_count as usize);
@@ -43,6 +42,22 @@ impl CudaStreams {
gpu_indexes,
}
}
/// Create a new `CudaStreams` structure with the GPUs with id provided in a list
pub fn new_multi_gpu_with_indexes(indexes: &[GpuIndex]) -> Self {
let _gpu_count = setup_multi_gpu(indexes[0]);
let mut gpu_indexes = Vec::with_capacity(indexes.len());
let mut ptr_array = Vec::with_capacity(indexes.len());
for &i in indexes {
ptr_array.push(unsafe { cuda_create_stream(i.get()) });
gpu_indexes.push(i);
}
Self {
ptr: ptr_array,
gpu_indexes,
}
}
/// Create a new `CudaStreams` structure with one GPU, whose index corresponds to the one given
/// as input
pub fn new_single_gpu(gpu_index: GpuIndex) -> Self {
@@ -88,6 +103,14 @@ impl CudaStreams {
}
}
impl Clone for CudaStreams {
fn clone(&self) -> Self {
// The `new_multi_gpu_with_indexes()` function is used here to adapt to any specific type of
// streams being cloned (single, multi, or custom)
Self::new_multi_gpu_with_indexes(self.gpu_indexes.as_slice())
}
}
impl Drop for CudaStreams {
fn drop(&mut self) {
for (i, &s) in self.ptr.iter().enumerate() {
@@ -959,7 +982,7 @@ pub unsafe fn fourier_transform_backward_as_torus_f128_async<T: UnsignedInteger>
);
}
#[derive(Debug)]
#[derive(Clone, Debug)]
pub struct CudaLweList<T: UnsignedInteger> {
// Pointer to GPU data
pub d_vec: CudaVec<T>,
@@ -971,6 +994,17 @@ pub struct CudaLweList<T: UnsignedInteger> {
pub ciphertext_modulus: CiphertextModulus<T>,
}
impl<T: UnsignedInteger> CudaLweList<T> {
pub fn duplicate(&self, streams: &CudaStreams) -> Self {
Self {
d_vec: self.d_vec.duplicate(streams),
lwe_ciphertext_count: self.lwe_ciphertext_count,
lwe_dimension: self.lwe_dimension,
ciphertext_modulus: self.ciphertext_modulus,
}
}
}
#[derive(Debug, Clone)]
pub struct CudaGlweList<T: UnsignedInteger> {
// Pointer to GPU data
@@ -1002,8 +1036,8 @@ pub fn get_number_of_gpus() -> u32 {
}
/// Setup multi-GPU and return the number of GPUs used
pub fn setup_multi_gpu() -> u32 {
unsafe { cuda_setup_multi_gpu() as u32 }
pub fn setup_multi_gpu(device_0_id: GpuIndex) -> u32 {
unsafe { cuda_setup_multi_gpu(device_0_id.get()) as u32 }
}
/// Synchronize device
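
A hedged usage sketch of the new constructor and the `Clone` impl above, assuming `GpuIndex` is exported alongside `CudaStreams` (the import path is an assumption) and that the listed devices exist:

```rust
use tfhe::core_crypto::gpu::{CudaStreams, GpuIndex};

// Sketch only: pins computation to a custom GPU selection; peer-to-peer
// access is set up relative to the first index in the list.
fn custom_selection() {
    let streams = CudaStreams::new_multi_gpu_with_indexes(&[
        GpuIndex::new(0),
        GpuIndex::new(1),
    ]);
    // `clone` re-creates fresh streams on the same GPU indexes
    // (single, multi, or custom selection), per the `Clone` impl above.
    let cloned = streams.clone();
    assert_eq!(cloned.gpu_indexes.len(), streams.gpu_indexes.len());
}
```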

View File

@@ -7,7 +7,6 @@ use crate::array::stride::{ParStridedIter, ParStridedIterMut, StridedIter};
use crate::array::traits::TensorSlice;
use crate::high_level_api::array::{ArrayBackend, BackendDataContainer, BackendDataContainerMut};
use crate::high_level_api::global_state;
use crate::high_level_api::global_state::with_thread_local_cuda_streams;
use crate::integer::gpu::ciphertext::boolean_value::CudaBooleanBlock;
use crate::prelude::{FheDecrypt, FheTryEncrypt};
use crate::{ClientKey, FheBoolId};
@@ -25,10 +24,15 @@ pub type GpuFheBoolSliceMut<'a> =
pub struct GpuBooleanSlice<'a>(pub(crate) &'a [CudaBooleanBlock]);
pub struct GpuBooleanSliceMut<'a>(pub(crate) &'a mut [CudaBooleanBlock]);
pub struct GpuBooleanOwned(pub(crate) Vec<CudaBooleanBlock>);
use crate::high_level_api::global_state::with_cuda_internal_keys;
impl Clone for GpuBooleanOwned {
fn clone(&self) -> Self {
with_thread_local_cuda_streams(|streams| {
// When cloning, we assume that the intention is to return a ciphertext that lies in the GPU
// 0 defined in the set server key. Hence, we use the server key to get the streams instead
// of those inside the ciphertext itself
with_cuda_internal_keys(|key| {
let streams = &key.streams;
Self(self.0.iter().map(|elem| elem.duplicate(streams)).collect())
})
}
@@ -83,7 +87,8 @@ impl BackendDataContainer for GpuBooleanSlice<'_> {
}
fn into_owned(self) -> <Self::Backend as ArrayBackend>::Owned {
with_thread_local_cuda_streams(|streams| {
with_cuda_internal_keys(|key| {
let streams = &key.streams;
GpuBooleanOwned(self.0.iter().map(|elem| elem.duplicate(streams)).collect())
})
}
@@ -104,7 +109,8 @@ impl BackendDataContainer for GpuBooleanSliceMut<'_> {
}
fn into_owned(self) -> <Self::Backend as ArrayBackend>::Owned {
with_thread_local_cuda_streams(|streams| {
with_cuda_internal_keys(|key| {
let streams = &key.streams;
GpuBooleanOwned(self.0.iter().map(|elem| elem.duplicate(streams)).collect())
})
}
@@ -156,14 +162,13 @@ impl BitwiseArrayBackend for GpuFheBoolArrayBackend {
rhs: TensorSlice<'_, Self::Slice<'a>>,
) -> Self::Owned {
GpuBooleanOwned(global_state::with_cuda_internal_keys(|cuda_key| {
with_thread_local_cuda_streams(|streams| {
lhs.par_iter()
.zip(rhs.par_iter())
.map(|(lhs, rhs)| {
CudaBooleanBlock(cuda_key.pbs_key().bitand(&lhs.0, &rhs.0, streams))
})
.collect::<Vec<_>>()
})
let streams = &cuda_key.streams;
lhs.par_iter()
.zip(rhs.par_iter())
.map(|(lhs, rhs)| {
CudaBooleanBlock(cuda_key.pbs_key().bitand(&lhs.0, &rhs.0, streams))
})
.collect::<Vec<_>>()
}))
}
@@ -172,14 +177,13 @@ impl BitwiseArrayBackend for GpuFheBoolArrayBackend {
rhs: TensorSlice<'_, Self::Slice<'a>>,
) -> Self::Owned {
GpuBooleanOwned(global_state::with_cuda_internal_keys(|cuda_key| {
with_thread_local_cuda_streams(|streams| {
lhs.par_iter()
.zip(rhs.par_iter())
.map(|(lhs, rhs)| {
CudaBooleanBlock(cuda_key.pbs_key().bitor(&lhs.0, &rhs.0, streams))
})
.collect::<Vec<_>>()
})
let streams = &cuda_key.streams;
lhs.par_iter()
.zip(rhs.par_iter())
.map(|(lhs, rhs)| {
CudaBooleanBlock(cuda_key.pbs_key().bitor(&lhs.0, &rhs.0, streams))
})
.collect::<Vec<_>>()
}))
}
@@ -188,24 +192,22 @@ impl BitwiseArrayBackend for GpuFheBoolArrayBackend {
rhs: TensorSlice<'_, Self::Slice<'a>>,
) -> Self::Owned {
GpuBooleanOwned(global_state::with_cuda_internal_keys(|cuda_key| {
with_thread_local_cuda_streams(|streams| {
lhs.par_iter()
.zip(rhs.par_iter())
.map(|(lhs, rhs)| {
CudaBooleanBlock(cuda_key.pbs_key().bitxor(&lhs.0, &rhs.0, streams))
})
.collect::<Vec<_>>()
})
let streams = &cuda_key.streams;
lhs.par_iter()
.zip(rhs.par_iter())
.map(|(lhs, rhs)| {
CudaBooleanBlock(cuda_key.pbs_key().bitxor(&lhs.0, &rhs.0, streams))
})
.collect::<Vec<_>>()
}))
}
fn bitnot(lhs: TensorSlice<'_, Self::Slice<'_>>) -> Self::Owned {
GpuBooleanOwned(global_state::with_cuda_internal_keys(|cuda_key| {
with_thread_local_cuda_streams(|streams| {
lhs.par_iter()
.map(|lhs| CudaBooleanBlock(cuda_key.pbs_key().bitnot(&lhs.0, streams)))
.collect::<Vec<_>>()
})
let streams = &cuda_key.streams;
lhs.par_iter()
.map(|lhs| CudaBooleanBlock(cuda_key.pbs_key().bitnot(&lhs.0, streams)))
.collect::<Vec<_>>()
}))
}
}
@@ -216,16 +218,13 @@ impl ClearBitwiseArrayBackend<bool> for GpuFheBoolArrayBackend {
rhs: TensorSlice<'_, &'_ [bool]>,
) -> Self::Owned {
GpuBooleanOwned(global_state::with_cuda_internal_keys(|cuda_key| {
with_thread_local_cuda_streams(|streams| {
lhs.par_iter()
.zip(rhs.par_iter().copied())
.map(|(lhs, rhs)| {
CudaBooleanBlock(
cuda_key.pbs_key().scalar_bitand(&lhs.0, rhs as u8, streams),
)
})
.collect::<Vec<_>>()
})
let streams = &cuda_key.streams;
lhs.par_iter()
.zip(rhs.par_iter().copied())
.map(|(lhs, rhs)| {
CudaBooleanBlock(cuda_key.pbs_key().scalar_bitand(&lhs.0, rhs as u8, streams))
})
.collect::<Vec<_>>()
}))
}
@@ -234,16 +233,13 @@ impl ClearBitwiseArrayBackend<bool> for GpuFheBoolArrayBackend {
rhs: TensorSlice<'_, &'_ [bool]>,
) -> Self::Owned {
GpuBooleanOwned(global_state::with_cuda_internal_keys(|cuda_key| {
with_thread_local_cuda_streams(|streams| {
lhs.par_iter()
.zip(rhs.par_iter().copied())
.map(|(lhs, rhs)| {
CudaBooleanBlock(
cuda_key.pbs_key().scalar_bitor(&lhs.0, rhs as u8, streams),
)
})
.collect::<Vec<_>>()
})
let streams = &cuda_key.streams;
lhs.par_iter()
.zip(rhs.par_iter().copied())
.map(|(lhs, rhs)| {
CudaBooleanBlock(cuda_key.pbs_key().scalar_bitor(&lhs.0, rhs as u8, streams))
})
.collect::<Vec<_>>()
}))
}
@@ -252,16 +248,13 @@ impl ClearBitwiseArrayBackend<bool> for GpuFheBoolArrayBackend {
rhs: TensorSlice<'_, &'_ [bool]>,
) -> Self::Owned {
GpuBooleanOwned(global_state::with_cuda_internal_keys(|cuda_key| {
with_thread_local_cuda_streams(|streams| {
lhs.par_iter()
.zip(rhs.par_iter().copied())
.map(|(lhs, rhs)| {
CudaBooleanBlock(
cuda_key.pbs_key().scalar_bitxor(&lhs.0, rhs as u8, streams),
)
})
.collect::<Vec<_>>()
})
let streams = &cuda_key.streams;
lhs.par_iter()
.zip(rhs.par_iter().copied())
.map(|(lhs, rhs)| {
CudaBooleanBlock(cuda_key.pbs_key().scalar_bitxor(&lhs.0, rhs as u8, streams))
})
.collect::<Vec<_>>()
}))
}
}
@@ -270,7 +263,8 @@ impl FheTryEncrypt<&[bool], ClientKey> for GpuFheBoolArray {
type Error = crate::Error;
fn try_encrypt(values: &[bool], cks: &ClientKey) -> Result<Self, Self::Error> {
let encrypted = with_thread_local_cuda_streams(|streams| {
let encrypted = with_cuda_internal_keys(|key| {
let streams = &key.streams;
values
.iter()
.copied()
@@ -285,7 +279,8 @@ impl FheTryEncrypt<&[bool], ClientKey> for GpuFheBoolArray {
impl FheDecrypt<Vec<bool>> for GpuFheBoolSlice<'_> {
fn decrypt(&self, key: &ClientKey) -> Vec<bool> {
with_thread_local_cuda_streams(|streams| {
with_cuda_internal_keys(|cuda_key| {
let streams = &cuda_key.streams;
self.elems
.0
.iter()
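
The HL API array backends above now pull streams from the configured server key instead of a thread-local pool, so results are created on the GPUs the key was set up with. A minimal crate-internal sketch of the pattern, using names from the diff (`with_cuda_internal_keys` is not public API):

```rust
use crate::high_level_api::global_state::with_cuda_internal_keys;
use crate::integer::gpu::ciphertext::boolean_value::CudaBooleanBlock;

// Duplicate ciphertexts onto the streams owned by the set server key,
// so the copies live on the GPUs that key was configured for.
fn duplicate_blocks(blocks: &[CudaBooleanBlock]) -> Vec<CudaBooleanBlock> {
    with_cuda_internal_keys(|key| {
        let streams = &key.streams;
        blocks.iter().map(|b| b.duplicate(streams)).collect()
    })
}
```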

View File

@@ -13,7 +13,7 @@ use crate::array::traits::{
};
use crate::core_crypto::gpu::CudaStreams;
use crate::high_level_api::global_state;
use crate::high_level_api::global_state::with_thread_local_cuda_streams;
use crate::high_level_api::global_state::with_cuda_internal_keys;
use crate::high_level_api::integers::{FheIntId, FheUintId};
use crate::integer::block_decomposition::{
DecomposableInto, RecomposableFrom, RecomposableSignedInteger,
@@ -53,7 +53,8 @@ where
T: CudaIntegerRadixCiphertext,
{
fn clone(&self) -> Self {
with_thread_local_cuda_streams(|streams| {
with_cuda_internal_keys(|key| {
let streams = &key.streams;
Self(self.0.iter().map(|elem| elem.duplicate(streams)).collect())
})
}
@@ -108,12 +109,11 @@ where
F: Send + Sync + Fn(&crate::integer::gpu::CudaServerKey, &T, &T, &CudaStreams) -> T,
{
GpuOwned(global_state::with_cuda_internal_keys(|cuda_key| {
with_thread_local_cuda_streams(|streams| {
lhs.par_iter()
.zip(rhs.par_iter())
.map(|(lhs, rhs)| op(cuda_key.pbs_key(), lhs, rhs, streams))
.collect::<Vec<_>>()
})
let streams = &cuda_key.streams;
lhs.par_iter()
.zip(rhs.par_iter())
.map(|(lhs, rhs)| op(cuda_key.pbs_key(), lhs, rhs, streams))
.collect::<Vec<_>>()
}))
}
@@ -170,12 +170,11 @@ where
F: Send + Sync + Fn(&crate::integer::gpu::CudaServerKey, &T, Clear, &CudaStreams) -> T,
{
GpuOwned(global_state::with_cuda_internal_keys(|cuda_key| {
with_thread_local_cuda_streams(|streams| {
lhs.par_iter()
.zip(rhs.par_iter())
.map(|(lhs, rhs)| op(cuda_key.pbs_key(), lhs, *rhs, streams))
.collect::<Vec<_>>()
})
let streams = &cuda_key.streams;
lhs.par_iter()
.zip(rhs.par_iter())
.map(|(lhs, rhs)| op(cuda_key.pbs_key(), lhs, *rhs, streams))
.collect::<Vec<_>>()
}))
}
@@ -336,11 +335,10 @@ where
fn bitnot(lhs: TensorSlice<'_, Self::Slice<'_>>) -> Self::Owned {
GpuOwned(global_state::with_cuda_internal_keys(|cuda_key| {
with_thread_local_cuda_streams(|streams| {
lhs.par_iter()
.map(|lhs| cuda_key.pbs_key().bitnot(lhs, streams))
.collect::<Vec<_>>()
})
let streams = &cuda_key.streams;
lhs.par_iter()
.map(|lhs| cuda_key.pbs_key().bitnot(lhs, streams))
.collect::<Vec<_>>()
}))
}
}
@@ -439,7 +437,8 @@ where
}
fn into_owned(self) -> <Self::Backend as ArrayBackend>::Owned {
with_thread_local_cuda_streams(|streams| {
with_cuda_internal_keys(|key| {
let streams = &key.streams;
GpuOwned(self.0.iter().map(|elem| elem.duplicate(streams)).collect())
})
}
@@ -463,7 +462,8 @@ where
}
fn into_owned(self) -> <Self::Backend as ArrayBackend>::Owned {
with_thread_local_cuda_streams(|streams| {
with_cuda_internal_keys(|key| {
let streams = &key.streams;
GpuOwned(self.0.iter().map(|elem| elem.duplicate(streams)).collect())
})
}
@@ -492,7 +492,8 @@ where
fn try_encrypt(clears: &'a [Clear], key: &ClientKey) -> Result<Self, Self::Error> {
let num_blocks = Id::num_blocks(key.message_modulus());
Ok(Self::new(
with_thread_local_cuda_streams(|streams| {
with_cuda_internal_keys(|cuda_key| {
let streams = &cuda_key.streams;
clears
.iter()
.copied()
@@ -527,7 +528,8 @@ where
));
}
let num_blocks = Id::num_blocks(key.message_modulus());
let elems = with_thread_local_cuda_streams(|streams| {
let elems = with_cuda_internal_keys(|cuda_key| {
let streams = &cuda_key.streams;
clears
.iter()
.copied()
@@ -570,7 +572,8 @@ where
Clear: RecomposableFrom<u64> + UnsignedNumeric,
{
fn decrypt(&self, key: &ClientKey) -> Vec<Clear> {
with_thread_local_cuda_streams(|streams| {
with_cuda_internal_keys(|cuda_key| {
let streams = &cuda_key.streams;
self.as_tensor_slice()
.iter()
.map(|ct: &CudaUnsignedRadixCiphertext| {
@@ -591,7 +594,8 @@ where
fn try_encrypt(clears: &'a [Clear], key: &ClientKey) -> Result<Self, Self::Error> {
let num_blocks = Id::num_blocks(key.message_modulus());
Ok(Self::new(
with_thread_local_cuda_streams(|streams| {
with_cuda_internal_keys(|cuda_key| {
let streams = &cuda_key.streams;
clears
.iter()
.copied()
@@ -634,7 +638,8 @@ where
Clear: RecomposableSignedInteger,
{
fn decrypt(&self, key: &ClientKey) -> Vec<Clear> {
with_thread_local_cuda_streams(|streams| {
with_cuda_internal_keys(|cuda_key| {
let streams = &cuda_key.streams;
self.elems
.0
.iter()

View File

@@ -13,8 +13,6 @@ pub(in crate::high_level_api) mod traits;
use crate::array::traits::TensorSlice;
use crate::high_level_api::array::traits::HasClear;
use crate::high_level_api::global_state;
#[cfg(feature = "gpu")]
use crate::high_level_api::global_state::with_thread_local_cuda_streams;
use crate::high_level_api::integers::{FheIntId, FheUintId};
use crate::high_level_api::keys::InternalServerKey;
use crate::{FheBool, FheId, FheInt, FheUint, Tag};
@@ -369,7 +367,8 @@ pub fn fhe_uint_array_eq<Id: FheUintId>(lhs: &[FheUint<Id>], rhs: &[FheUint<Id>]
FheBool::new(result, cpu_key.tag.clone())
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(gpu_key) => with_thread_local_cuda_streams(|streams| {
InternalServerKey::Cuda(gpu_key) => {
let streams = &gpu_key.streams;
let tmp_lhs = lhs
.iter()
.map(|fhe_uint| fhe_uint.clone().ciphertext.into_gpu(streams))
@@ -381,7 +380,7 @@ pub fn fhe_uint_array_eq<Id: FheUintId>(lhs: &[FheUint<Id>], rhs: &[FheUint<Id>]
let result = gpu_key.key.key.all_eq_slices(&tmp_lhs, &tmp_rhs, streams);
FheBool::new(result, gpu_key.tag.clone())
}),
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
panic!("Hpu does not support Array yet.")
@@ -410,7 +409,8 @@ pub fn fhe_uint_array_contains_sub_slice<Id: FheUintId>(
FheBool::new(result, cpu_key.tag.clone())
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(gpu_key) => with_thread_local_cuda_streams(|streams| {
InternalServerKey::Cuda(gpu_key) => {
let streams = &gpu_key.streams;
let tmp_lhs = lhs
.iter()
.map(|fhe_uint| fhe_uint.clone().ciphertext.into_gpu(streams))
@@ -425,7 +425,7 @@ pub fn fhe_uint_array_contains_sub_slice<Id: FheUintId>(
.key
.contains_sub_slice(&tmp_lhs, &tmp_pattern, streams);
FheBool::new(result, gpu_key.tag.clone())
}),
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
panic!("Hpu does not support Array yet.")

View File

@@ -3,6 +3,9 @@ use tfhe_versionable::{Upgrade, Version, VersionsDispatch};
use crate::{CompactCiphertextList, Tag};
#[cfg(feature = "zk-pok")]
use crate::ProvenCompactCiphertextList;
#[derive(Version)]
pub struct CompactCiphertextListV0(crate::integer::ciphertext::CompactCiphertextList);
@@ -17,9 +20,6 @@ impl Upgrade<CompactCiphertextList> for CompactCiphertextListV0 {
}
}
#[cfg(feature = "zk-pok")]
use crate::ProvenCompactCiphertextList;
#[derive(VersionsDispatch)]
pub enum CompactCiphertextListVersions {
V0(CompactCiphertextListV0),

View File

@@ -3,8 +3,6 @@ use crate::backward_compatibility::booleans::FheBoolVersions;
use crate::conformance::ParameterSetConformant;
use crate::core_crypto::prelude::{SignedNumeric, UnsignedNumeric};
use crate::high_level_api::global_state;
#[cfg(feature = "gpu")]
use crate::high_level_api::global_state::with_thread_local_cuda_streams;
use crate::high_level_api::integers::{FheInt, FheIntId, FheUint, FheUintId};
use crate::high_level_api::keys::InternalServerKey;
use crate::high_level_api::traits::{FheEq, IfThenElse, ScalarIfThenElse, Tagged};
@@ -408,7 +406,8 @@ impl ScalarIfThenElse<&Self, &Self> for FheBool {
(InnerBoolean::Cpu(new_ct), key.tag.clone())
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
InternalServerKey::Cuda(cuda_key) => {
let streams = &cuda_key.streams;
let inner = cuda_key.key.key.if_then_else(
&CudaBooleanBlock(self.ciphertext.on_gpu(streams).duplicate(streams)),
&*ct_then.ciphertext.on_gpu(streams),
@@ -417,7 +416,7 @@ impl ScalarIfThenElse<&Self, &Self> for FheBool {
);
let boolean_inner = CudaBooleanBlock(inner);
(InnerBoolean::Cuda(boolean_inner), cuda_key.tag.clone())
}),
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
panic!("Hpu does not support if_then_else with clear input")
@@ -449,7 +448,8 @@ where
FheUint::new(inner, cpu_sks.tag.clone())
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
InternalServerKey::Cuda(cuda_key) => {
let streams = &cuda_key.streams;
let inner = cuda_key.key.key.if_then_else(
&CudaBooleanBlock(self.ciphertext.on_gpu(streams).duplicate(streams)),
&*ct_then.ciphertext.on_gpu(streams),
@@ -458,7 +458,7 @@ where
);
FheUint::new(inner, cuda_key.tag.clone())
}),
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(device) => {
let hpu_then = ct_then.ciphertext.on_hpu(device);
@@ -506,7 +506,8 @@ impl<Id: FheIntId> IfThenElse<FheInt<Id>> for FheBool {
FheInt::new(new_ct, key.tag.clone())
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
InternalServerKey::Cuda(cuda_key) => {
let streams = &cuda_key.streams;
let inner = cuda_key.key.key.if_then_else(
&CudaBooleanBlock(self.ciphertext.on_gpu(streams).duplicate(streams)),
&*ct_then.ciphertext.on_gpu(streams),
@@ -515,7 +516,7 @@ impl<Id: FheIntId> IfThenElse<FheInt<Id>> for FheBool {
);
FheInt::new(inner, cuda_key.tag.clone())
}),
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
panic!("Hpu does not support signed integers")
@@ -537,7 +538,8 @@ impl IfThenElse<Self> for FheBool {
(InnerBoolean::Cpu(new_ct), key.tag.clone())
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
InternalServerKey::Cuda(cuda_key) => {
let streams = &cuda_key.streams;
let inner = cuda_key.key.key.if_then_else(
&CudaBooleanBlock(self.ciphertext.on_gpu(streams).duplicate(streams)),
&*ct_then.ciphertext.on_gpu(streams),
@@ -546,7 +548,7 @@ impl IfThenElse<Self> for FheBool {
);
let boolean_inner = CudaBooleanBlock(inner);
(InnerBoolean::Cuda(boolean_inner), cuda_key.tag.clone())
}),
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
panic!("Hpu does not support bool if then else")
@@ -600,7 +602,8 @@ where
Self::new(ciphertext, key.tag.clone())
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
InternalServerKey::Cuda(cuda_key) => {
let streams = &cuda_key.streams;
let inner = cuda_key.key.key.eq(
&*self.ciphertext.on_gpu(streams),
&other.borrow().ciphertext.on_gpu(streams),
@@ -608,7 +611,7 @@ where
);
let ciphertext = InnerBoolean::Cuda(inner);
Self::new(ciphertext, cuda_key.tag.clone())
}),
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
panic!("Hpu does not support FheBool::eq")
@@ -646,7 +649,8 @@ where
Self::new(ciphertext, key.tag.clone())
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
InternalServerKey::Cuda(cuda_key) => {
let streams = &cuda_key.streams;
let inner = cuda_key.key.key.ne(
&*self.ciphertext.on_gpu(streams),
&other.borrow().ciphertext.on_gpu(streams),
@@ -654,7 +658,7 @@ where
);
let ciphertext = InnerBoolean::Cuda(inner);
Self::new(ciphertext, cuda_key.tag.clone())
}),
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
panic!("Hpu does not support FheBool::ne")
@@ -695,14 +699,15 @@ impl FheEq<bool> for FheBool {
)
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
InternalServerKey::Cuda(cuda_key) => {
let streams = &cuda_key.streams;
let inner = cuda_key.key.key.scalar_eq(
&*self.ciphertext.on_gpu(streams),
u8::from(other),
streams,
);
(InnerBoolean::Cuda(inner), cuda_key.tag.clone())
}),
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
panic!("Hpu does not support FheBool::eq with a bool")
@@ -742,14 +747,15 @@ impl FheEq<bool> for FheBool {
)
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
InternalServerKey::Cuda(cuda_key) => {
let streams = &cuda_key.streams;
let inner = cuda_key.key.key.scalar_ne(
&*self.ciphertext.on_gpu(streams),
u8::from(other),
streams,
);
(InnerBoolean::Cuda(inner), cuda_key.tag.clone())
}),
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
panic!("Hpu does not support FheBool::ne with a bool")
@@ -820,7 +826,8 @@ where
(InnerBoolean::Cpu(inner_ct), key.tag.clone())
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
InternalServerKey::Cuda(cuda_key) => {
let streams = &cuda_key.streams;
let inner_ct = cuda_key.key.key.bitand(
&*self.ciphertext.on_gpu(streams),
&rhs.borrow().ciphertext.on_gpu(streams),
@@ -833,7 +840,7 @@ where
)),
cuda_key.tag.clone(),
)
}),
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
panic!("Hpu does not support bitand (&)")
@@ -909,7 +916,8 @@ where
)
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
InternalServerKey::Cuda(cuda_key) => {
let streams = &cuda_key.streams;
let inner_ct = cuda_key.key.key.bitor(
&*self.ciphertext.on_gpu(streams),
&rhs.borrow().ciphertext.on_gpu(streams),
@@ -921,7 +929,7 @@ where
)),
cuda_key.tag.clone(),
)
}),
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
panic!("Hpu does not support bitor (|)")
@@ -997,7 +1005,8 @@ where
)
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
InternalServerKey::Cuda(cuda_key) => {
let streams = &cuda_key.streams;
let inner_ct = cuda_key.key.key.bitxor(
&*self.ciphertext.on_gpu(streams),
&rhs.borrow().ciphertext.on_gpu(streams),
@@ -1009,7 +1018,7 @@ where
)),
cuda_key.tag.clone(),
)
}),
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
panic!("Hpu does not support bitxor (^)")
@@ -1077,7 +1086,8 @@ impl BitAnd<bool> for &FheBool {
)
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
InternalServerKey::Cuda(cuda_key) => {
let streams = &cuda_key.streams;
let inner_ct = cuda_key.key.key.scalar_bitand(
&*self.ciphertext.on_gpu(streams),
u8::from(rhs),
@@ -1089,7 +1099,7 @@ impl BitAnd<bool> for &FheBool {
)),
cuda_key.tag.clone(),
)
}),
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
panic!("hpu does not bitand (&) with a bool")
@@ -1157,7 +1167,8 @@ impl BitOr<bool> for &FheBool {
)
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
InternalServerKey::Cuda(cuda_key) => {
let streams = &cuda_key.streams;
let inner_ct = cuda_key.key.key.scalar_bitor(
&*self.ciphertext.on_gpu(streams),
u8::from(rhs),
@@ -1169,7 +1180,7 @@ impl BitOr<bool> for &FheBool {
)),
cuda_key.tag.clone(),
)
}),
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
panic!("hpu does not bitor (|) with a bool")
@@ -1237,7 +1248,8 @@ impl BitXor<bool> for &FheBool {
)
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
InternalServerKey::Cuda(cuda_key) => {
let streams = &cuda_key.streams;
let inner_ct = cuda_key.key.key.scalar_bitxor(
&*self.ciphertext.on_gpu(streams),
u8::from(rhs),
@@ -1249,7 +1261,7 @@ impl BitXor<bool> for &FheBool {
)),
cuda_key.tag.clone(),
)
}),
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
panic!("hpu does not bitxor (^) with a bool")
@@ -1445,13 +1457,14 @@ where
);
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
InternalServerKey::Cuda(cuda_key) => {
let streams = &cuda_key.streams;
cuda_key.key.key.bitand_assign(
self.ciphertext.as_gpu_mut(streams),
&*rhs.ciphertext.on_gpu(streams),
streams,
);
}),
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
panic!("Hpu does not support bitand assign (&=)")
@@ -1492,13 +1505,14 @@ where
);
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
InternalServerKey::Cuda(cuda_key) => {
let streams = &cuda_key.streams;
cuda_key.key.key.bitor_assign(
self.ciphertext.as_gpu_mut(streams),
&rhs.ciphertext.on_gpu(streams),
streams,
);
}),
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
panic!("Hpu does not support bitor assign (|=)")
@@ -1539,13 +1553,14 @@ where
);
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
InternalServerKey::Cuda(cuda_key) => {
let streams = &cuda_key.streams;
cuda_key.key.key.bitxor_assign(
self.ciphertext.as_gpu_mut(streams),
&rhs.ciphertext.on_gpu(streams),
streams,
);
}),
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
panic!("Hpu does not support bitxor assign (^=)")
@@ -1580,13 +1595,14 @@ impl BitAndAssign<bool> for FheBool {
.scalar_bitand_assign(&mut self.ciphertext.as_cpu_mut().0, u8::from(rhs));
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
InternalServerKey::Cuda(cuda_key) => {
let streams = &cuda_key.streams;
cuda_key.key.key.scalar_bitand_assign(
self.ciphertext.as_gpu_mut(streams),
u8::from(rhs),
streams,
);
}),
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
panic!("Hpu does not support bitand assign (&=) with a bool")
@@ -1621,13 +1637,14 @@ impl BitOrAssign<bool> for FheBool {
.scalar_bitor_assign(&mut self.ciphertext.as_cpu_mut().0, u8::from(rhs));
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
InternalServerKey::Cuda(cuda_key) => {
let streams = &cuda_key.streams;
cuda_key.key.key.scalar_bitor_assign(
self.ciphertext.as_gpu_mut(streams),
u8::from(rhs),
streams,
);
}),
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
panic!("Hpu does not support bitor assign (|=) with a bool")
@@ -1662,13 +1679,14 @@ impl BitXorAssign<bool> for FheBool {
.scalar_bitxor_assign(&mut self.ciphertext.as_cpu_mut().0, u8::from(rhs));
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
InternalServerKey::Cuda(cuda_key) => {
let streams = &cuda_key.streams;
cuda_key.key.key.scalar_bitxor_assign(
self.ciphertext.as_gpu_mut(streams),
u8::from(rhs),
streams,
);
}),
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
panic!("Hpu does not support bitor assign (^=) with a bool")
@@ -1729,7 +1747,8 @@ impl std::ops::Not for &FheBool {
(InnerBoolean::Cpu(inner), key.tag.clone())
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
InternalServerKey::Cuda(cuda_key) => {
let streams = &cuda_key.streams;
let inner =
cuda_key
.key
@@ -1741,7 +1760,7 @@ impl std::ops::Not for &FheBool {
)),
cuda_key.tag.clone(),
)
}),
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
panic!("Hpu does not support bitnot (!)")

View File

@@ -1,8 +1,6 @@
use super::base::FheBool;
use crate::high_level_api::booleans::inner::InnerBoolean;
use crate::high_level_api::global_state;
#[cfg(feature = "gpu")]
use crate::high_level_api::global_state::with_thread_local_cuda_streams;
use crate::high_level_api::keys::InternalServerKey;
#[cfg(feature = "gpu")]
use crate::integer::gpu::ciphertext::boolean_value::CudaBooleanBlock;
@@ -90,7 +88,8 @@ impl FheTryTrivialEncrypt<bool> for FheBool {
(ct, key.tag.clone())
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
InternalServerKey::Cuda(cuda_key) => {
let streams = &cuda_key.streams;
let inner: CudaUnsignedRadixCiphertext =
cuda_key
.key
@@ -100,7 +99,7 @@ impl FheTryTrivialEncrypt<bool> for FheBool {
inner.into_inner(),
));
(ct, cuda_key.tag.clone())
}),
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
panic!("Hpu does not support trivial encryption")


@@ -4,9 +4,9 @@ use crate::core_crypto::gpu::CudaStreams;
use crate::high_level_api::details::MaybeCloned;
use crate::high_level_api::global_state;
#[cfg(feature = "gpu")]
use crate::high_level_api::global_state::{
with_thread_local_cuda_streams, with_thread_local_cuda_streams_for_gpu_indexes,
};
use crate::high_level_api::global_state::with_cuda_internal_keys;
#[cfg(feature = "gpu")]
use crate::high_level_api::global_state::with_thread_local_cuda_streams_for_gpu_indexes;
use crate::integer::BooleanBlock;
use crate::Device;
use serde::{Deserializer, Serializer};
@@ -36,7 +36,9 @@ impl Clone for InnerBoolean {
Self::Cpu(inner) => Self::Cpu(inner.clone()),
#[cfg(feature = "gpu")]
Self::Cuda(inner) => {
with_thread_local_cuda_streams(|streams| Self::Cuda(inner.duplicate(streams)))
with_thread_local_cuda_streams_for_gpu_indexes(inner.gpu_indexes(), |streams| {
Self::Cuda(inner.duplicate(streams))
})
}
#[cfg(feature = "hpu")]
Self::Hpu(inner) => Self::Hpu(inner.clone()),
@@ -251,7 +253,8 @@ impl InnerBoolean {
#[cfg(feature = "gpu")]
// We may not be on the correct Cuda device
if let Self::Cuda(cuda_ct) = self {
with_thread_local_cuda_streams(|streams| {
with_cuda_internal_keys(|key| {
let streams = &key.streams;
if cuda_ct.gpu_indexes() != streams.gpu_indexes() {
*cuda_ct = cuda_ct.duplicate(streams);
}
@@ -273,7 +276,8 @@ impl InnerBoolean {
}
#[cfg(feature = "gpu")]
Device::CudaGpu => {
let new_inner = with_thread_local_cuda_streams(|streams| {
let new_inner = with_cuda_internal_keys(|key| {
let streams = &key.streams;
crate::integer::gpu::ciphertext::boolean_value::CudaBooleanBlock::from_boolean_block(
&cpu_ct, streams,
)


@@ -1,7 +1,5 @@
use super::{FheBool, InnerBoolean};
use crate::high_level_api::global_state;
#[cfg(feature = "gpu")]
use crate::high_level_api::global_state::with_thread_local_cuda_streams;
use crate::high_level_api::keys::InternalServerKey;
#[cfg(feature = "gpu")]
use crate::integer::gpu::ciphertext::boolean_value::CudaBooleanBlock;
@@ -41,7 +39,8 @@ impl FheBool {
)
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
InternalServerKey::Cuda(cuda_key) => {
let streams = &cuda_key.streams;
let d_ct: CudaUnsignedRadixCiphertext = cuda_key
.key
.key
@@ -52,7 +51,7 @@ impl FheBool {
)),
cuda_key.tag.clone(),
)
}),
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
panic!("Hpu does not support random bool generation")


@@ -1,8 +1,4 @@
use tfhe_versionable::Versionize;
use crate::backward_compatibility::compact_list::CompactCiphertextListVersions;
#[cfg(feature = "zk-pok")]
use crate::backward_compatibility::compact_list::ProvenCompactCiphertextListVersions;
use crate::conformance::ParameterSetConformant;
use crate::core_crypto::commons::math::random::{Deserialize, Serialize};
use crate::core_crypto::prelude::Numeric;
@@ -10,7 +6,7 @@ use crate::high_level_api::global_state;
use crate::high_level_api::keys::InternalServerKey;
use crate::high_level_api::traits::Tagged;
use crate::integer::block_decomposition::DecomposableInto;
use crate::integer::ciphertext::{Compactable, DataKind, Expandable};
use crate::integer::ciphertext::{Compactable, DataKind};
use crate::integer::encryption::KnowsMessageModulus;
use crate::integer::parameters::{
CompactCiphertextListConformanceParams, IntegerCompactCiphertextListExpansionMode,
@@ -18,9 +14,14 @@ use crate::integer::parameters::{
use crate::named::Named;
use crate::prelude::CiphertextList;
use crate::shortint::MessageModulus;
use crate::HlExpandable;
use tfhe_versionable::Versionize;
#[cfg(feature = "zk-pok")]
pub use zk::ProvenCompactCiphertextList;
#[cfg(feature = "gpu")]
use crate::high_level_api::global_state::with_cuda_internal_keys;
#[cfg(feature = "zk-pok")]
use crate::zk::{CompactPkeCrs, ZkComputeLoad};
use crate::{CompactPublicKey, Tag};
@@ -173,7 +174,7 @@ impl CompactCiphertextList {
self.inner
.expand(sks.integer_compact_ciphertext_list_expansion_mode())
.map(|inner| CompactCiphertextListExpander {
inner,
inner: InnerCompactCiphertextListExpander::Cpu(inner),
tag: self.tag.clone(),
})
}
@@ -183,9 +184,11 @@ impl CompactCiphertextList {
if !self.inner.is_packed() && !self.inner.needs_casting() {
// No ServerKey required, short-circuit to avoid the global state call
return Ok(CompactCiphertextListExpander {
inner: self
.inner
.expand(IntegerCompactCiphertextListExpansionMode::NoCastingAndNoUnpacking)?,
inner: InnerCompactCiphertextListExpander::Cpu(
self.inner.expand(
IntegerCompactCiphertextListExpansionMode::NoCastingAndNoUnpacking,
)?,
),
tag: self.tag.clone(),
});
}
@@ -196,7 +199,7 @@ impl CompactCiphertextList {
.inner
.expand(cpu_key.integer_compact_ciphertext_list_expansion_mode())
.map(|inner| CompactCiphertextListExpander {
inner,
inner: InnerCompactCiphertextListExpander::Cpu(inner),
tag: self.tag.clone(),
}),
#[cfg(any(feature = "gpu", feature = "hpu"))]
@@ -228,17 +231,144 @@ impl ParameterSetConformant for CompactCiphertextList {
#[cfg(feature = "zk-pok")]
mod zk {
use super::*;
use crate::backward_compatibility::compact_list::ProvenCompactCiphertextListVersions;
use crate::conformance::ParameterSetConformant;
use crate::high_level_api::global_state::device_of_internal_keys;
use crate::integer::ciphertext::IntegerProvenCompactCiphertextListConformanceParams;
use crate::zk::CompactPkeCrs;
#[cfg(feature = "gpu")]
use crate::integer::gpu::key_switching_key::CudaKeySwitchingKey;
#[cfg(feature = "gpu")]
use crate::integer::gpu::zk::CudaProvenCompactCiphertextList;
use serde::Serializer;
pub enum InnerProvenCompactCiphertextList {
Cpu(crate::integer::ciphertext::ProvenCompactCiphertextList),
#[cfg(feature = "gpu")]
Cuda(crate::integer::gpu::zk::CudaProvenCompactCiphertextList),
}
impl Clone for InnerProvenCompactCiphertextList {
fn clone(&self) -> Self {
match self {
Self::Cpu(inner) => Self::Cpu(inner.clone()),
#[cfg(feature = "gpu")]
Self::Cuda(inner) => with_cuda_internal_keys(|keys| {
let streams = &keys.streams;
Self::Cuda(inner.duplicate(streams))
}),
}
}
}
#[derive(Clone, Serialize, Deserialize, Versionize)]
#[versionize(ProvenCompactCiphertextListVersions)]
pub struct ProvenCompactCiphertextList {
pub(crate) inner: crate::integer::ciphertext::ProvenCompactCiphertextList,
pub(crate) inner: InnerProvenCompactCiphertextList,
pub(crate) tag: Tag,
}
impl InnerProvenCompactCiphertextList {
pub(crate) fn on_cpu(&self) -> &crate::integer::ciphertext::ProvenCompactCiphertextList {
match self {
Self::Cpu(inner) => inner,
#[cfg(feature = "gpu")]
Self::Cuda(inner) => &inner.h_proved_lists,
}
}
#[allow(clippy::unnecessary_wraps)] // Method can return an error if hpu is enabled
fn move_to_device(&mut self, device: crate::Device) -> Result<(), crate::Error> {
let new_value = match (&self, device) {
(Self::Cpu(_), crate::Device::Cpu) => None,
#[cfg(feature = "gpu")]
(Self::Cuda(cuda_ct), crate::Device::CudaGpu) => with_cuda_internal_keys(|keys| {
let streams = &keys.streams;
if cuda_ct.gpu_indexes() == streams.gpu_indexes() {
None
} else {
Some(Self::Cuda(cuda_ct.duplicate(streams)))
}
}),
#[cfg(feature = "gpu")]
(Self::Cuda(cuda_ct), crate::Device::Cpu) => {
let cpu_ct = cuda_ct.h_proved_lists.clone();
Some(Self::Cpu(cpu_ct))
}
#[cfg(feature = "gpu")]
(Self::Cpu(cpu_ct), crate::Device::CudaGpu) => {
let cuda_ct = with_cuda_internal_keys(|keys| {
let streams = &keys.streams;
CudaProvenCompactCiphertextList::from_proven_compact_ciphertext_list(
cpu_ct, streams,
)
});
Some(Self::Cuda(cuda_ct))
}
#[cfg(feature = "hpu")]
(_, crate::Device::Hpu) => {
return Err(crate::error!(
"Hpu does not support ProvenCompactCiphertextList"
))
}
};
if let Some(v) = new_value {
*self = v;
}
Ok(())
}
}
impl serde::Serialize for InnerProvenCompactCiphertextList {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
self.on_cpu().serialize(serializer)
}
}
impl<'de> serde::Deserialize<'de> for InnerProvenCompactCiphertextList {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
let mut new =
crate::integer::ciphertext::ProvenCompactCiphertextList::deserialize(deserializer)
.map(Self::Cpu)?;
if let Some(device) = device_of_internal_keys() {
new.move_to_device(device)
.map_err(serde::de::Error::custom)?;
}
Ok(new)
}
}
use tfhe_versionable::{Unversionize, UnversionizeError, VersionizeOwned};
impl Versionize for InnerProvenCompactCiphertextList {
type Versioned<'vers> =
<crate::integer::ciphertext::ProvenCompactCiphertextList as VersionizeOwned>::VersionedOwned;
fn versionize(&self) -> Self::Versioned<'_> {
self.on_cpu().clone().versionize_owned()
}
}
impl VersionizeOwned for InnerProvenCompactCiphertextList {
type VersionedOwned =
<crate::integer::ciphertext::ProvenCompactCiphertextList as VersionizeOwned>::VersionedOwned;
fn versionize_owned(self) -> Self::VersionedOwned {
self.on_cpu().clone().versionize_owned()
}
}
impl Unversionize for InnerProvenCompactCiphertextList {
fn unversionize(versioned: Self::VersionedOwned) -> Result<Self, UnversionizeError> {
Ok(Self::Cpu(
crate::integer::ciphertext::ProvenCompactCiphertextList::unversionize(versioned)?,
))
}
}
impl Tagged for ProvenCompactCiphertextList {
fn tag(&self) -> &Tag {
&self.tag
@@ -258,7 +388,7 @@ mod zk {
}
pub fn len(&self) -> usize {
self.inner.len()
self.inner.on_cpu().len()
}
pub fn is_empty(&self) -> bool {
@@ -266,8 +396,9 @@ mod zk {
}
pub fn get_kind_of(&self, index: usize) -> Option<crate::FheTypes> {
self.inner.get_kind_of(index).and_then(|data_kind| {
crate::FheTypes::from_data_kind(data_kind, self.inner.ct_list.message_modulus())
let inner_cpu = self.inner.on_cpu();
inner_cpu.get_kind_of(index).and_then(|data_kind| {
crate::FheTypes::from_data_kind(data_kind, inner_cpu.ct_list.message_modulus())
})
}
@@ -277,7 +408,7 @@ mod zk {
pk: &CompactPublicKey,
metadata: &[u8],
) -> crate::zk::ZkVerificationOutcome {
self.inner.verify(crs, &pk.key.key, metadata)
self.inner.on_cpu().verify(crs, &pk.key.key, metadata)
}
pub fn verify_and_expand(
@@ -286,36 +417,112 @@ mod zk {
pk: &CompactPublicKey,
metadata: &[u8],
) -> crate::Result<CompactCiphertextListExpander> {
// For WASM
if !self.inner.is_packed() && !self.inner.needs_casting() {
// No ServerKey required, short circuit to avoid the global state call
return Ok(CompactCiphertextListExpander {
inner: self.inner.verify_and_expand(
#[allow(irrefutable_let_patterns)]
if let InnerProvenCompactCiphertextList::Cpu(inner) = &self.inner {
// For WASM
if !inner.is_packed() && !inner.needs_casting() {
let expander = inner.verify_and_expand(
crs,
&pk.key.key,
metadata,
IntegerCompactCiphertextListExpansionMode::NoCastingAndNoUnpacking,
)?,
tag: self.tag.clone(),
});
)?;
// No ServerKey required, short circuit to avoid the global state call
return Ok(CompactCiphertextListExpander {
inner: InnerCompactCiphertextListExpander::Cpu(expander),
tag: self.tag.clone(),
});
}
}
global_state::try_with_internal_keys(|maybe_keys| match maybe_keys {
None => Err(crate::high_level_api::errors::UninitializedServerKey.into()),
Some(InternalServerKey::Cpu(cpu_key)) => self
.inner
.verify_and_expand(
crs,
&pk.key.key,
metadata,
cpu_key.integer_compact_ciphertext_list_expansion_mode(),
)
.map(|expander| CompactCiphertextListExpander {
inner: expander,
tag: self.tag.clone(),
}),
#[cfg(any(feature = "gpu", feature = "hpu"))]
Some(_) => Err(crate::Error::new("Expected a CPU server key".to_string())),
Some(InternalServerKey::Cpu(cpu_key)) => match &self.inner {
InnerProvenCompactCiphertextList::Cpu(inner) => inner
.verify_and_expand(
crs,
&pk.key.key,
metadata,
cpu_key.integer_compact_ciphertext_list_expansion_mode(),
)
.map(|expander| CompactCiphertextListExpander {
inner: InnerCompactCiphertextListExpander::Cpu(expander),
tag: self.tag.clone(),
}),
#[cfg(feature = "gpu")]
InnerProvenCompactCiphertextList::Cuda(inner) => inner
.h_proved_lists
.verify_and_expand(
crs,
&pk.key.key,
metadata,
cpu_key.integer_compact_ciphertext_list_expansion_mode(),
)
.map(|expander| CompactCiphertextListExpander {
inner: InnerCompactCiphertextListExpander::Cpu(expander),
tag: self.tag.clone(),
}),
},
#[cfg(feature = "gpu")]
Some(InternalServerKey::Cuda(gpu_key)) => match &self.inner {
InnerProvenCompactCiphertextList::Cuda(inner) => {
with_cuda_internal_keys(|keys| {
let streams = &keys.streams;
let ksk = CudaKeySwitchingKey {
key_switching_key_material: gpu_key
.key
.cpk_key_switching_key_material
.as_ref()
.unwrap(),
dest_server_key: &gpu_key.key.key,
};
let expander = inner.verify_and_expand(
crs,
&pk.key.key,
metadata,
&ksk,
streams,
)?;
Ok(CompactCiphertextListExpander {
inner: InnerCompactCiphertextListExpander::Cuda(expander),
tag: self.tag.clone(),
})
})
}
InnerProvenCompactCiphertextList::Cpu(cpu_inner) => {
with_cuda_internal_keys(|keys| {
let streams = &keys.streams;
let gpu_proven_ct = CudaProvenCompactCiphertextList::from_proven_compact_ciphertext_list(
cpu_inner, streams,
);
let ksk = CudaKeySwitchingKey {
key_switching_key_material: gpu_key
.key
.cpk_key_switching_key_material
.as_ref()
.unwrap(),
dest_server_key: &gpu_key.key.key,
};
let expander = gpu_proven_ct.verify_and_expand(
crs,
&pk.key.key,
metadata,
&ksk,
streams,
)?;
Ok(CompactCiphertextListExpander {
inner: InnerCompactCiphertextListExpander::Cuda(expander),
tag: self.tag.clone(),
})
})
}
},
#[cfg(feature = "hpu")]
Some(InternalServerKey::Hpu(_)) => Err(crate::error!(
"Hpu does not support ProvenCompactCiphertextList"
)),
})
}
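Taken together, the four match arms above mean verification and expansion now work whichever side the proven list and the server key live on: a CPU list is uploaded to the GPU when a CUDA key is set, and a GPU list falls back to its host copy under a CPU key. A sketch of the GPU path, assuming `config` was built with parameters that support compact-list proofs (error handling elided):

```rust
use tfhe::prelude::*;
use tfhe::zk::{CompactPkeCrs, ZkComputeLoad};
use tfhe::{
    set_server_key, ClientKey, CompactCiphertextList, CompactPublicKey, CompressedServerKey,
    Config, FheUint64,
};

fn gpu_verified_expand(config: Config) -> Result<(), Box<dyn std::error::Error>> {
    // Assumes `config` uses parameters compatible with compact-list proofs.
    let client_key = ClientKey::generate(config);
    let public_key = CompactPublicKey::new(&client_key);
    set_server_key(CompressedServerKey::new(&client_key).decompress_to_gpu());

    let crs = CompactPkeCrs::from_config(config, 64)?;
    let metadata = b"hl-api";

    let proven = CompactCiphertextList::builder(&public_key)
        .push(42u64)
        .build_with_proof_packed(&crs, metadata, ZkComputeLoad::Verify)?;

    // With a CUDA server key set, the list is moved to the GPU, a key
    // switching key view is built, and verification + expansion run there.
    let expander = proven.verify_and_expand(&crs, &public_key, metadata)?;
    let x: FheUint64 = expander.get(0)?.unwrap();
    let _ = x;
    Ok(())
}
```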
@@ -324,30 +531,86 @@ mod zk {
///
/// If you are here you were probably looking for it: use at your own risks.
pub fn expand_without_verification(&self) -> crate::Result<CompactCiphertextListExpander> {
// For WASM
if !self.inner.is_packed() && !self.inner.needs_casting() {
// No ServerKey required, short circuit to avoid the global state call
return Ok(CompactCiphertextListExpander {
inner: self.inner.expand_without_verification(
IntegerCompactCiphertextListExpansionMode::NoCastingAndNoUnpacking,
)?,
tag: self.tag.clone(),
});
#[allow(irrefutable_let_patterns)]
if let InnerProvenCompactCiphertextList::Cpu(inner) = &self.inner {
// For WASM
if !inner.is_packed() && !inner.needs_casting() {
// No ServerKey required, short circuit to avoid the global state call
return Ok(CompactCiphertextListExpander {
inner: InnerCompactCiphertextListExpander::Cpu(
inner.expand_without_verification(
IntegerCompactCiphertextListExpansionMode::NoCastingAndNoUnpacking,
)?,
),
tag: self.tag.clone(),
});
}
}
global_state::try_with_internal_keys(|maybe_keys| match maybe_keys {
global_state::try_with_internal_keys(|maybe_keys| {
match maybe_keys {
None => Err(crate::high_level_api::errors::UninitializedServerKey.into()),
Some(InternalServerKey::Cpu(cpu_key)) => self
.inner
.expand_without_verification(
cpu_key.integer_compact_ciphertext_list_expansion_mode(),
)
.map(|expander| CompactCiphertextListExpander {
inner: expander,
tag: self.tag.clone(),
}),
#[cfg(any(feature = "gpu", feature = "hpu"))]
Some(_) => Err(crate::Error::new("Expected a CPU server key".to_string())),
Some(InternalServerKey::Cpu(cpu_key)) => match &self.inner {
InnerProvenCompactCiphertextList::Cpu(inner) => inner
.expand_without_verification(
cpu_key.integer_compact_ciphertext_list_expansion_mode(),
)
.map(|expander| CompactCiphertextListExpander {
inner: InnerCompactCiphertextListExpander::Cpu(expander),
tag: self.tag.clone(),
}),
#[cfg(feature = "gpu")]
InnerProvenCompactCiphertextList::Cuda(_) => {
Err(crate::Error::new("Tried expanding a ProvenCompactCiphertextList on the GPU, but the set ServerKey is a ServerKey".to_string()))
}
},
#[cfg(feature = "gpu")]
Some(InternalServerKey::Cuda(gpu_key)) => match &self.inner {
InnerProvenCompactCiphertextList::Cuda(inner) => {
with_cuda_internal_keys(|keys| {
let streams = &keys.streams;
let ksk = CudaKeySwitchingKey {
key_switching_key_material: gpu_key
.key
.cpk_key_switching_key_material
.as_ref()
.unwrap(),
dest_server_key: &gpu_key.key.key,
};
let expander = inner.expand_without_verification(&ksk, streams)?;
Ok(CompactCiphertextListExpander {
inner: InnerCompactCiphertextListExpander::Cuda(expander),
tag: self.tag.clone(),
})
})
}
InnerProvenCompactCiphertextList::Cpu(inner) => {
with_cuda_internal_keys(|keys| {
let streams = &keys.streams;
let gpu_proven_ct = CudaProvenCompactCiphertextList::from_proven_compact_ciphertext_list(
inner, streams,
);
let ksk = CudaKeySwitchingKey {
key_switching_key_material: gpu_key
.key
.cpk_key_switching_key_material
.as_ref()
.unwrap(),
dest_server_key: &gpu_key.key.key,
};
let expander = gpu_proven_ct.expand_without_verification(&ksk, streams)?;
Ok(CompactCiphertextListExpander {
inner: InnerCompactCiphertextListExpander::Cuda(expander),
tag: self.tag.clone(),
})
})
}
},
#[cfg(feature = "hpu")]
Some(InternalServerKey::Hpu(_)) => Err(crate::error!("Hpu does not support ProvenCompactCiphertextList")),
}
})
}
}
@@ -356,9 +619,7 @@ mod zk {
type ParameterSet = IntegerProvenCompactCiphertextListConformanceParams;
fn is_conformant(&self, parameter_set: &Self::ParameterSet) -> bool {
let Self { inner, tag: _ } = self;
inner.is_conformant(parameter_set)
self.inner.on_cpu().is_conformant(parameter_set)
}
}
@@ -367,7 +628,7 @@ mod zk {
use super::*;
use crate::integer::ciphertext::IntegerProvenCompactCiphertextListConformanceParams;
use crate::shortint::parameters::*;
use crate::zk::CompactPkeCrs;
use rand::{thread_rng, Rng};
#[test]
@@ -409,14 +670,24 @@ mod zk {
}
}
pub enum InnerCompactCiphertextListExpander {
Cpu(crate::integer::ciphertext::CompactCiphertextListExpander),
#[cfg(feature = "gpu")]
Cuda(crate::integer::gpu::ciphertext::compact_list::CudaCompactCiphertextListExpander),
}
pub struct CompactCiphertextListExpander {
pub(in crate::high_level_api) inner: crate::integer::ciphertext::CompactCiphertextListExpander,
pub inner: InnerCompactCiphertextListExpander,
tag: Tag,
}
impl CiphertextList for CompactCiphertextListExpander {
fn len(&self) -> usize {
self.inner.len()
match &self.inner {
InnerCompactCiphertextListExpander::Cpu(inner) => inner.len(),
#[cfg(feature = "gpu")]
InnerCompactCiphertextListExpander::Cuda(inner) => inner.len(),
}
}
fn is_empty(&self) -> bool {
@@ -424,16 +695,34 @@ impl CiphertextList for CompactCiphertextListExpander {
}
fn get_kind_of(&self, index: usize) -> Option<crate::FheTypes> {
self.inner.get_kind_of(index).and_then(|data_kind| {
crate::FheTypes::from_data_kind(data_kind, self.inner.message_modulus())
})
match &self.inner {
InnerCompactCiphertextListExpander::Cpu(inner) => {
inner.get_kind_of(index).and_then(|data_kind| {
crate::FheTypes::from_data_kind(data_kind, inner.message_modulus())
})
}
#[cfg(feature = "gpu")]
InnerCompactCiphertextListExpander::Cuda(inner) => {
inner.get_kind_of(index).and_then(|data_kind| {
crate::FheTypes::from_data_kind(data_kind, inner.message_modulus(index)?)
})
}
}
}
fn get<T>(&self, index: usize) -> crate::Result<Option<T>>
where
T: Expandable + Tagged,
T: HlExpandable + Tagged,
{
let mut expanded = self.inner.get::<T>(index);
let mut expanded = match &self.inner {
InnerCompactCiphertextListExpander::Cpu(inner) => inner.get::<T>(index),
#[cfg(feature = "gpu")]
InnerCompactCiphertextListExpander::Cuda(inner) => with_cuda_internal_keys(|keys| {
let streams = &keys.streams;
inner.get::<T>(index, streams)
}),
};
if let Ok(Some(inner)) = &mut expanded {
inner.tag_mut().set_data(self.tag.data());
}
@@ -531,7 +820,6 @@ impl CompactCiphertextListBuilder {
})
.expect("Internal error, invalid parameters should not have been allowed")
}
#[cfg(feature = "zk-pok")]
pub fn build_with_proof_packed(
&self,
@@ -542,7 +830,10 @@ impl CompactCiphertextListBuilder {
self.inner
.build_with_proof_packed(crs, metadata, compute_load)
.map(|proved_list| ProvenCompactCiphertextList {
inner: proved_list,
inner:
crate::high_level_api::compact_list::zk::InnerProvenCompactCiphertextList::Cpu(
proved_list,
),
tag: self.tag.clone(),
})
}


@@ -14,7 +14,7 @@ use crate::high_level_api::booleans::InnerBoolean;
use crate::high_level_api::errors::UninitializedServerKey;
use crate::high_level_api::global_state::device_of_internal_keys;
#[cfg(feature = "gpu")]
use crate::high_level_api::global_state::with_thread_local_cuda_streams;
use crate::high_level_api::global_state::with_cuda_internal_keys;
use crate::high_level_api::integers::{FheIntId, FheUintId};
use crate::integer::ciphertext::{DataKind, Expandable};
#[cfg(feature = "gpu")]
@@ -27,7 +27,6 @@ use crate::named::Named;
use crate::prelude::{CiphertextList, Tagged};
use crate::shortint::Ciphertext;
use crate::{Device, FheBool, FheInt, FheUint, Tag};
impl<Id: FheUintId> HlCompressible for FheUint<Id> {
fn compress_into(self, messages: &mut Vec<(ToBeCompressed, DataKind)>) {
match self.ciphertext {
@@ -142,9 +141,12 @@ impl CompressedCiphertextListBuilder {
}
#[cfg(feature = "gpu")]
ToBeCompressed::Cuda(cuda_radix) => {
with_thread_local_cuda_streams(|streams| {
flat_cpu_blocks.append(&mut cuda_radix.to_cpu_blocks(streams));
});
with_thread_local_cuda_streams_for_gpu_indexes(
cuda_radix.d_blocks.0.d_vec.gpu_indexes.as_slice(),
|streams| {
flat_cpu_blocks.append(&mut cuda_radix.to_cpu_blocks(streams));
},
);
}
}
}
@@ -178,17 +180,16 @@ impl CompressedCiphertextListBuilder {
for (element, _) in &self.inner {
match element {
ToBeCompressed::Cpu(cpu_blocks) => {
with_thread_local_cuda_streams(|streams| {
cuda_radixes.push(CudaRadixCiphertext::from_cpu_blocks(
cpu_blocks, streams,
));
})
let streams = &cuda_key.streams;
cuda_radixes
.push(CudaRadixCiphertext::from_cpu_blocks(cpu_blocks, streams));
}
#[cfg(feature = "gpu")]
ToBeCompressed::Cuda(cuda_radix) => {
with_thread_local_cuda_streams(|streams| {
{
let streams = &cuda_key.streams;
cuda_radixes.push(cuda_radix.duplicate(streams));
});
};
}
}
}
@@ -201,10 +202,11 @@ impl CompressedCiphertextListBuilder {
crate::Error::new("Compression key not set in server key".to_owned())
})
.map(|compression_key| {
let packed_list = with_thread_local_cuda_streams(|streams| {
let packed_list = {
let streams = &cuda_key.streams;
compression_key
.compress_ciphertexts_into_list(cuda_radixes.as_slice(), streams)
});
};
let info = self.inner.iter().map(|(_, kind)| *kind).collect();
let compressed_list = CudaCompressedCiphertextList { packed_list, info };
@@ -273,7 +275,8 @@ impl InnerCompressedCiphertextList {
#[cfg(feature = "gpu")]
// We may not be on the correct Cuda device
if let Self::Cuda(cuda_ct) = self {
with_thread_local_cuda_streams(|streams| {
with_cuda_internal_keys(|keys| {
let streams = &keys.streams;
if cuda_ct.gpu_indexes() != streams.gpu_indexes() {
*cuda_ct = cuda_ct.duplicate(streams);
}
@@ -295,7 +298,8 @@ impl InnerCompressedCiphertextList {
}
#[cfg(feature = "gpu")]
Device::CudaGpu => {
let new_inner = with_thread_local_cuda_streams(|streams| {
let new_inner = with_cuda_internal_keys(|keys| {
let streams = &keys.streams;
cpu_ct.to_cuda_compressed_ciphertext_list(streams)
});
*self = Self::Cuda(new_inner);
@@ -353,7 +357,8 @@ impl Versionize for InnerCompressedCiphertextList {
Self::Cpu(inner) => inner.clone().versionize_owned(),
#[cfg(feature = "gpu")]
Self::Cuda(inner) => {
let cpu_data = with_thread_local_cuda_streams(|streams| {
let cpu_data = with_cuda_internal_keys(|keys| {
let streams = &keys.streams;
inner.to_compressed_ciphertext_list(streams)
});
cpu_data.versionize_owned()
@@ -371,7 +376,8 @@ impl VersionizeOwned for InnerCompressedCiphertextList {
Self::Cpu(inner) => inner.versionize_owned(),
#[cfg(feature = "gpu")]
Self::Cuda(inner) => {
let cpu_data = with_thread_local_cuda_streams(|streams| {
let cpu_data = with_cuda_internal_keys(|keys| {
let streams = &keys.streams;
inner.to_compressed_ciphertext_list(streams)
});
cpu_data.versionize_owned()
@@ -475,11 +481,12 @@ impl CiphertextList for CompressedCiphertextList {
crate::Error::new("Compression key not set in server key".to_owned())
})
.and_then(|decompression_key| {
let mut ct = with_thread_local_cuda_streams(|streams| {
let mut ct = {
let streams = &cuda_key.streams;
self.inner
.on_gpu(streams)
.get::<T>(index, decompression_key, streams)
});
};
if let Ok(Some(ct_ref)) = &mut ct {
ct_ref.tag_mut().set_data(cuda_key.tag.data())
}
@@ -501,7 +508,8 @@ impl CompressedCiphertextList {
InnerCompressedCiphertextList::Cpu(inner) => (inner, tag),
#[cfg(feature = "gpu")]
InnerCompressedCiphertextList::Cuda(inner) => (
with_thread_local_cuda_streams(|streams| {
with_cuda_internal_keys(|keys| {
let streams = &keys.streams;
inner.to_compressed_ciphertext_list(streams)
}),
tag,
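Above, the compressed-list builder and getter stop borrowing the thread-local streams and use the CUDA key's own: CPU blocks are copied onto the key's streams, packing runs on the GPU compression key, and decompression via `get` does the reverse. Usage is unchanged; a sketch, assuming `config` enables compression (otherwise `build()` fails with the "Compression key not set" error shown above):

```rust
use tfhe::prelude::*;
use tfhe::{
    set_server_key, ClientKey, CompressedCiphertextListBuilder, CompressedServerKey, Config,
    FheUint32,
};

fn roundtrip(config: Config) -> Result<(), Box<dyn std::error::Error>> {
    // Assumes `config` was built with compression parameters enabled.
    let client_key = ClientKey::generate(config);
    set_server_key(CompressedServerKey::new(&client_key).decompress_to_gpu());

    let a = FheUint32::encrypt(17u32, &client_key);
    let b = FheUint32::encrypt(4u32, &client_key);

    // CPU blocks are copied onto the key's streams, then packed on the GPU.
    let list = CompressedCiphertextListBuilder::new().push(a).push(b).build()?;

    let a_back: FheUint32 = list.get(0)?.unwrap();
    let clear: u32 = a_back.decrypt(&client_key);
    assert_eq!(clear, 17);
    Ok(())
}
```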


@@ -73,14 +73,6 @@ pub fn unset_server_key() {
fn replace_server_key(new_one: Option<impl Into<InternalServerKey>>) -> Option<InternalServerKey> {
let keys = new_one.map(Into::into);
#[cfg(feature = "gpu")]
if let Some(InternalServerKey::Cuda(cuda_key)) = &keys {
gpu::CUDA_STREAMS.with_borrow_mut(|current_streams| {
if current_streams.gpu_indexes() != cuda_key.gpu_indexes() {
*current_streams = cuda_key.build_streams();
}
});
}
INTERNAL_KEYS.replace(keys)
}
@@ -187,7 +179,7 @@ where
.unwrap_display();
let InternalServerKey::Cuda(cuda_key) = key else {
panic!(
"Cpu key requested but only the key for {:?} is available",
"CUDA key requested but only the key for {:?} is available",
key.device()
)
};
@@ -196,13 +188,15 @@ where
}
#[cfg(feature = "gpu")]
pub(in crate::high_level_api) use gpu::{
with_thread_local_cuda_streams, with_thread_local_cuda_streams_for_gpu_indexes,
};
pub(in crate::high_level_api) use gpu::with_thread_local_cuda_streams_for_gpu_indexes;
#[cfg(feature = "gpu")]
pub use gpu::CudaGpuChoice;
#[derive(Clone)]
#[cfg(feature = "gpu")]
pub struct CustomMultiGpuIndexes(Vec<GpuIndex>);
#[cfg(feature = "gpu")]
mod gpu {
use crate::core_crypto::gpu::get_number_of_gpus;
@@ -210,28 +204,15 @@ mod gpu {
use super::*;
use std::cell::LazyCell;
thread_local! {
pub(crate) static CUDA_STREAMS: RefCell<CudaStreams> = RefCell::new(CudaStreams::new_multi_gpu());
}
pub(in crate::high_level_api) fn with_thread_local_cuda_streams<
R,
F: for<'a> FnOnce(&'a CudaStreams) -> R,
>(
func: F,
) -> R {
CUDA_STREAMS.with(|cell| func(&cell.borrow()))
}
struct CudaStreamPool {
multi: LazyCell<CudaStreams>,
custom: Option<CudaStreams>,
single: Vec<LazyCell<CudaStreams, Box<dyn Fn() -> CudaStreams>>>,
}
impl CudaStreamPool {
fn new() -> Self {
Self {
multi: LazyCell::new(CudaStreams::new_multi_gpu),
custom: None,
single: (0..get_number_of_gpus())
.map(|index| {
let ctor =
@@ -243,29 +224,6 @@ mod gpu {
}
}
impl<'a> std::ops::Index<&'a [GpuIndex]> for CudaStreamPool {
type Output = CudaStreams;
fn index(&self, indexes: &'a [GpuIndex]) -> &Self::Output {
match indexes.len() {
0 => panic!("Internal error: Gpu indexes must not be empty"),
1 => &self.single[indexes[0].get() as usize],
_ => &self.multi,
}
}
}
impl std::ops::Index<CudaGpuChoice> for CudaStreamPool {
type Output = CudaStreams;
fn index(&self, choice: CudaGpuChoice) -> &Self::Output {
match choice {
CudaGpuChoice::Multi => &self.multi,
CudaGpuChoice::Single(index) => &self.single[index.get() as usize],
}
}
}
pub(in crate::high_level_api) fn with_thread_local_cuda_streams_for_gpu_indexes<
R,
F: for<'a> FnOnce(&'a CudaStreams) -> R,
@@ -276,15 +234,38 @@ mod gpu {
thread_local! {
static POOL: RefCell<CudaStreamPool> = RefCell::new(CudaStreamPool::new());
}
POOL.with_borrow(|stream_pool| {
let stream = &stream_pool[gpu_indexes];
func(stream)
})
if gpu_indexes.len() == 1 {
POOL.with_borrow(|pool| func(&pool.single[gpu_indexes[0].get() as usize]))
} else {
POOL.with_borrow_mut(|pool| match &pool.custom {
Some(streams) if streams.gpu_indexes != gpu_indexes => {
pool.custom = Some(CudaStreams::new_multi_gpu_with_indexes(gpu_indexes));
}
None => {
pool.custom = Some(CudaStreams::new_multi_gpu_with_indexes(gpu_indexes));
}
_ => {}
});
POOL.with_borrow(|pool| func(pool.custom.as_ref().unwrap()))
}
}
#[derive(Copy, Clone)]
impl CustomMultiGpuIndexes {
pub fn new(indexes: Vec<GpuIndex>) -> Self {
Self(indexes)
}
pub fn gpu_indexes(&self) -> &[GpuIndex] {
self.0.as_slice()
}
}
#[derive(Clone)]
pub enum CudaGpuChoice {
Single(GpuIndex),
Multi,
Custom(CustomMultiGpuIndexes),
}
impl From<GpuIndex> for CudaGpuChoice {
@@ -293,11 +274,24 @@ mod gpu {
}
}
impl From<Vec<GpuIndex>> for CustomMultiGpuIndexes {
fn from(value: Vec<GpuIndex>) -> Self {
Self(value)
}
}
impl From<CustomMultiGpuIndexes> for CudaGpuChoice {
fn from(values: CustomMultiGpuIndexes) -> Self {
Self::Custom(values)
}
}
impl CudaGpuChoice {
pub(in crate::high_level_api) fn build_streams(self) -> CudaStreams {
match self {
Self::Single(idx) => CudaStreams::new_single_gpu(idx),
Self::Multi => CudaStreams::new_multi_gpu(),
Self::Custom(idxs) => CudaStreams::new_multi_gpu_with_indexes(idxs.gpu_indexes()),
}
}
}
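This is the reintroduced custom GPU selection: `CudaGpuChoice::Custom` carries an explicit index list, and the thread-local pool now caches one custom `CudaStreams` per thread, rebuilding it only when the requested index set changes. A sketch of how a caller might pin computation to GPUs 0 and 2; the `decompress_to_specific_gpu` entry point and the crate-root re-exports are assumptions here, not shown in this diff:

```rust
use tfhe::prelude::*;
use tfhe::{
    set_server_key, ClientKey, CompressedServerKey, ConfigBuilder, CustomMultiGpuIndexes,
    FheUint8, GpuIndex,
};

fn main() {
    let config = ConfigBuilder::default().build();
    let client_key = ClientKey::generate(config);
    let compressed = CompressedServerKey::new(&client_key);

    // Hypothetical selection: GPUs 0 and 2 out of the visible devices.
    let choice = CustomMultiGpuIndexes::new(vec![GpuIndex::new(0), GpuIndex::new(2)]);
    set_server_key(compressed.decompress_to_specific_gpu(choice));

    let a = FheUint8::encrypt(3u8, &client_key);
    let b = FheUint8::encrypt(5u8, &client_key);
    let c = a + b; // computed on streams built from the custom index set

    let clear: u8 = c.decrypt(&client_key);
    assert_eq!(clear, 8);
}
```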


@@ -1,7 +1,5 @@
use super::{FheIntId, FheUint, FheUintId};
use crate::high_level_api::global_state;
#[cfg(feature = "gpu")]
use crate::high_level_api::global_state::with_thread_local_cuda_streams;
use crate::high_level_api::keys::InternalServerKey;
#[cfg(feature = "gpu")]
use crate::integer::gpu::ciphertext::{CudaSignedRadixCiphertext, CudaUnsignedRadixCiphertext};
@@ -39,7 +37,8 @@ impl<Id: FheUintId> FheUint<Id> {
Self::new(ct, key.tag.clone())
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
InternalServerKey::Cuda(cuda_key) => {
let streams = &cuda_key.streams;
let d_ct: CudaUnsignedRadixCiphertext = cuda_key
.key
.key
@@ -50,7 +49,7 @@ impl<Id: FheUintId> FheUint<Id> {
);
Self::new(d_ct, cuda_key.tag.clone())
}),
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
panic!("Hpu does not support this operation yet.")
@@ -92,7 +91,8 @@ impl<Id: FheUintId> FheUint<Id> {
Self::new(ct, key.tag.clone())
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
InternalServerKey::Cuda(cuda_key) => {
let streams = &cuda_key.streams;
let d_ct: CudaUnsignedRadixCiphertext = cuda_key
.key
.key
@@ -103,7 +103,7 @@ impl<Id: FheUintId> FheUint<Id> {
streams,
);
Self::new(d_ct, cuda_key.tag.clone())
}),
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
panic!("Hpu does not support this operation yet.")
@@ -145,7 +145,8 @@ impl<Id: FheIntId> FheInt<Id> {
Self::new(ct, key.tag.clone())
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
InternalServerKey::Cuda(cuda_key) => {
let streams = &cuda_key.streams;
let d_ct: CudaSignedRadixCiphertext = cuda_key
.key
.key
@@ -156,7 +157,7 @@ impl<Id: FheIntId> FheInt<Id> {
);
Self::new(d_ct, cuda_key.tag.clone())
}),
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
panic!("Hpu does not support this operation yet.")
@@ -200,7 +201,8 @@ impl<Id: FheIntId> FheInt<Id> {
Self::new(ct, key.tag.clone())
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
InternalServerKey::Cuda(cuda_key) => {
let streams = &cuda_key.streams;
let d_ct: CudaSignedRadixCiphertext = cuda_key
.key
.key
@@ -211,7 +213,7 @@ impl<Id: FheIntId> FheInt<Id> {
streams,
);
Self::new(d_ct, cuda_key.tag.clone())
}),
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
panic!("Hpu does not support this operation yet.")


@@ -17,8 +17,6 @@ use crate::shortint::PBSParameters;
use crate::{Device, FheBool, ServerKey, Tag};
use std::marker::PhantomData;
#[cfg(feature = "gpu")]
use crate::high_level_api::global_state::with_thread_local_cuda_streams;
pub trait FheIntId: IntegerId {}
/// A Generic FHE signed integer
@@ -197,13 +195,14 @@ where
Self::new(ciphertext, cpu_key.tag.clone())
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
InternalServerKey::Cuda(cuda_key) => {
let streams = &cuda_key.streams;
let result = cuda_key
.key
.key
.abs(&*self.ciphertext.on_gpu(streams), streams);
Self::new(result, cuda_key.tag.clone())
}),
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
panic!("Hpu does not support this operation yet.")
@@ -237,13 +236,14 @@ where
FheBool::new(result, cpu_key.tag.clone())
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
InternalServerKey::Cuda(cuda_key) => {
let streams = &cuda_key.streams;
let result = cuda_key
.key
.key
.is_even(&*self.ciphertext.on_gpu(streams), streams);
FheBool::new(result, cuda_key.tag.clone())
}),
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
panic!("Hpu does not support this operation yet.")
@@ -277,13 +277,14 @@ where
FheBool::new(result, cpu_key.tag.clone())
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
InternalServerKey::Cuda(cuda_key) => {
let streams = &cuda_key.streams;
let result = cuda_key
.key
.key
.is_odd(&*self.ciphertext.on_gpu(streams), streams);
FheBool::new(result, cuda_key.tag.clone())
}),
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
panic!("Hpu does not support this operation yet.")
@@ -321,7 +322,8 @@ where
crate::FheUint32::new(result, cpu_key.tag.clone())
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
InternalServerKey::Cuda(cuda_key) => {
let streams = &cuda_key.streams;
let result = cuda_key
.key
.key
@@ -332,7 +334,7 @@ where
streams,
);
crate::FheUint32::new(result, cuda_key.tag.clone())
}),
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
panic!("Hpu does not support this operation yet.")
@@ -370,7 +372,8 @@ where
crate::FheUint32::new(result, cpu_key.tag.clone())
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
InternalServerKey::Cuda(cuda_key) => {
let streams = &cuda_key.streams;
let result = cuda_key
.key
.key
@@ -381,7 +384,7 @@ where
streams,
);
crate::FheUint32::new(result, cuda_key.tag.clone())
}),
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
panic!("Hpu does not support this operation yet.")
@@ -419,7 +422,8 @@ where
crate::FheUint32::new(result, cpu_key.tag.clone())
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
InternalServerKey::Cuda(cuda_key) => {
let streams = &cuda_key.streams;
let result = cuda_key
.key
.key
@@ -430,7 +434,7 @@ where
streams,
);
crate::FheUint32::new(result, cuda_key.tag.clone())
}),
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
panic!("Hpu does not support this operation yet.")
@@ -468,7 +472,8 @@ where
crate::FheUint32::new(result, cpu_key.tag.clone())
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
InternalServerKey::Cuda(cuda_key) => {
let streams = &cuda_key.streams;
let result = cuda_key
.key
.key
@@ -479,7 +484,7 @@ where
streams,
);
crate::FheUint32::new(result, cuda_key.tag.clone())
}),
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
panic!("Hpu does not support this operation yet.")
@@ -601,7 +606,8 @@ where
crate::FheUint32::new(result, cpu_key.tag.clone())
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
InternalServerKey::Cuda(cuda_key) => {
let streams = &cuda_key.streams;
let result = cuda_key
.key
.key
@@ -612,7 +618,7 @@ where
streams,
);
crate::FheUint32::new(result, cuda_key.tag.clone())
}),
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
panic!("Hpu does not support this operation yet.")
@@ -659,7 +665,8 @@ where
)
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
InternalServerKey::Cuda(cuda_key) => {
let streams = &cuda_key.streams;
let (result, is_ok) = cuda_key
.key
.key
@@ -673,7 +680,7 @@ where
crate::FheUint32::new(result, cuda_key.tag.clone()),
FheBool::new(is_ok, cuda_key.tag.clone()),
)
}),
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
panic!("Hpu does not support this operation yet.")
@@ -859,7 +866,8 @@ where
Self::new(new_ciphertext, cpu_key.tag.clone())
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
InternalServerKey::Cuda(cuda_key) => {
let streams = &cuda_key.streams;
let target_num_blocks = IntoId::num_blocks(cuda_key.message_modulus());
let new_ciphertext = cuda_key.key.key.cast_to_signed(
input.ciphertext.into_gpu(streams),
@@ -867,7 +875,7 @@ where
streams,
);
Self::new(new_ciphertext, cuda_key.tag.clone())
}),
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
panic!("Hpu does not support this operation yet.")
@@ -908,14 +916,15 @@ where
Self::new(new_ciphertext, cpu_key.tag.clone())
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
InternalServerKey::Cuda(cuda_key) => {
let streams = &cuda_key.streams;
let new_ciphertext = cuda_key.key.key.cast_to_signed(
input.ciphertext.into_gpu(streams),
IntoId::num_blocks(cuda_key.message_modulus()),
streams,
);
Self::new(new_ciphertext, cuda_key.tag.clone())
}),
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
panic!("Hpu does not support this operation yet.")
@@ -959,14 +968,15 @@ where
Self::new(ciphertext, cpu_key.tag.clone())
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
InternalServerKey::Cuda(cuda_key) => {
let streams = &cuda_key.streams;
let inner = cuda_key.key.key.cast_to_signed(
input.ciphertext.into_gpu(streams).0,
Id::num_blocks(cuda_key.message_modulus()),
streams,
);
Self::new(inner, cuda_key.tag.clone())
}),
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
panic!("Hpu does not support this operation yet.")


@@ -1,7 +1,5 @@
use crate::core_crypto::prelude::SignedNumeric;
use crate::high_level_api::global_state;
#[cfg(feature = "gpu")]
use crate::high_level_api::global_state::with_thread_local_cuda_streams;
use crate::high_level_api::integers::FheIntId;
use crate::high_level_api::keys::InternalServerKey;
use crate::integer::block_decomposition::{DecomposableInto, RecomposableSignedInteger};
@@ -113,14 +111,15 @@ where
Ok(Self::new(ciphertext, key.tag.clone()))
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
InternalServerKey::Cuda(cuda_key) => {
let streams = &cuda_key.streams;
let inner: CudaSignedRadixCiphertext = cuda_key.key.key.create_trivial_radix(
value,
Id::num_blocks(cuda_key.key.key.message_modulus),
streams,
);
Ok(Self::new(inner, cuda_key.tag.clone()))
}),
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_) => panic!("Hpu does not currently support signed operations"),
})
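Trivial encryption for signed integers follows the same pattern: the trivial radix ciphertext is created directly on the CUDA key's streams. Call-site sketch:

```rust
use tfhe::prelude::*;
use tfhe::{set_server_key, ClientKey, CompressedServerKey, ConfigBuilder, FheInt32};

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let config = ConfigBuilder::default().build();
    let client_key = ClientKey::generate(config);
    set_server_key(CompressedServerKey::new(&client_key).decompress_to_gpu());

    // The trivial ciphertext is built directly on the CUDA key's streams.
    let a = FheInt32::try_encrypt_trivial(-56i32)?;

    let clear: i32 = a.decrypt(&client_key);
    assert_eq!(clear, -56);
    Ok(())
}
```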


@@ -4,9 +4,9 @@ use crate::core_crypto::gpu::CudaStreams;
use crate::high_level_api::details::MaybeCloned;
use crate::high_level_api::global_state;
#[cfg(feature = "gpu")]
use crate::high_level_api::global_state::{
with_thread_local_cuda_streams, with_thread_local_cuda_streams_for_gpu_indexes,
};
use crate::high_level_api::global_state::with_cuda_internal_keys;
#[cfg(feature = "gpu")]
use crate::high_level_api::global_state::with_thread_local_cuda_streams_for_gpu_indexes;
#[cfg(feature = "gpu")]
use crate::integer::gpu::ciphertext::CudaIntegerRadixCiphertext;
#[cfg(feature = "gpu")]
@@ -38,10 +38,12 @@ impl Clone for SignedRadixCiphertext {
match self {
Self::Cpu(inner) => Self::Cpu(inner.clone()),
#[cfg(feature = "gpu")]
Self::Cuda(inner) => with_thread_local_cuda_streams(|streams| {
let inner = inner.duplicate(streams);
Self::Cuda(inner)
}),
Self::Cuda(inner) => {
with_thread_local_cuda_streams_for_gpu_indexes(inner.gpu_indexes(), |streams| {
let inner = inner.duplicate(streams);
Self::Cuda(inner)
})
}
}
}
}
@@ -141,10 +143,10 @@ impl SignedRadixCiphertext {
streams: &CudaStreams,
) -> MaybeCloned<'_, CudaSignedRadixCiphertext> {
match self {
Self::Cpu(ct) => with_thread_local_cuda_streams(|streams| {
Self::Cpu(ct) => {
let ct = CudaSignedRadixCiphertext::from_signed_radix_ciphertext(ct, streams);
MaybeCloned::Cloned(ct)
}),
}
#[cfg(feature = "gpu")]
Self::Cuda(ct) => {
if ct.gpu_indexes() == streams.gpu_indexes() {
@@ -220,7 +222,8 @@ impl SignedRadixCiphertext {
#[cfg(feature = "gpu")]
(Self::Cuda(cuda_ct), Device::CudaGpu) => {
// We are on a GPU, but it may not be the correct one
let new = with_thread_local_cuda_streams(|streams| {
let new = with_cuda_internal_keys(|key| {
let streams = &key.streams;
if cuda_ct.gpu_indexes() == streams.gpu_indexes() {
None
} else {
@@ -233,14 +236,16 @@ impl SignedRadixCiphertext {
}
#[cfg(feature = "gpu")]
(Self::Cpu(ct), Device::CudaGpu) => {
let new_inner = with_thread_local_cuda_streams(|streams| {
let new_inner = with_cuda_internal_keys(|key| {
let streams = &key.streams;
CudaSignedRadixCiphertext::from_signed_radix_ciphertext(ct, streams)
});
*self = Self::Cuda(new_inner);
}
#[cfg(feature = "gpu")]
(Self::Cuda(ct), Device::Cpu) => {
let new_inner = with_thread_local_cuda_streams(|streams| {
let new_inner = with_cuda_internal_keys(|key| {
let streams = &key.streams;
ct.to_signed_radix_ciphertext(streams)
});
*self = Self::Cpu(new_inner);


@@ -1,12 +1,15 @@
#[cfg(feature = "gpu")]
use crate::high_level_api::details::MaybeCloned;
use crate::high_level_api::global_state;
#[cfg(feature = "gpu")]
use crate::high_level_api::global_state::with_cuda_internal_keys;
use crate::high_level_api::integers::{FheIntId, FheUintId};
use crate::high_level_api::keys::InternalServerKey;
#[cfg(feature = "gpu")]
use crate::high_level_api::traits::{
AddSizeOnGpu, BitAndSizeOnGpu, BitNotSizeOnGpu, BitOrSizeOnGpu, BitXorSizeOnGpu, SizeOnGpu,
SubSizeOnGpu,
AddSizeOnGpu, BitAndSizeOnGpu, BitNotSizeOnGpu, BitOrSizeOnGpu, BitXorSizeOnGpu,
FheMaxSizeOnGpu, FheMinSizeOnGpu, FheOrdSizeOnGpu, RotateLeftSizeOnGpu, RotateRightSizeOnGpu,
ShlSizeOnGpu, ShrSizeOnGpu, SizeOnGpu, SubSizeOnGpu,
};
use crate::high_level_api::traits::{
DivRem, FheEq, FheMax, FheMin, FheOrd, RotateLeft, RotateLeftAssign, RotateRight,
@@ -21,9 +24,6 @@ use std::ops::{
Mul, MulAssign, Neg, Not, Rem, RemAssign, Shl, ShlAssign, Shr, ShrAssign, Sub, SubAssign,
};
#[cfg(feature = "gpu")]
use crate::high_level_api::global_state::with_thread_local_cuda_streams;
impl<'a, Id> std::iter::Sum<&'a Self> for FheInt<Id>
where
Id: FheIntId,
@@ -77,7 +77,8 @@ where
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => {
with_thread_local_cuda_streams(|streams| {
{
let streams = &cuda_key.streams;
let cts = iter
.map(|fhe_uint| {
match fhe_uint.ciphertext.on_gpu(streams) {
@@ -107,7 +108,7 @@ where
)
});
Self::new(inner, cuda_key.tag.clone())
})
}
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
@@ -151,14 +152,15 @@ where
Self::new(inner_result, cpu_key.tag.clone())
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
InternalServerKey::Cuda(cuda_key) => {
let streams = &cuda_key.streams;
let inner_result = cuda_key.key.key.max(
&*self.ciphertext.on_gpu(streams),
&*rhs.ciphertext.on_gpu(streams),
streams,
);
Self::new(inner_result, cuda_key.tag.clone())
}),
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
panic!("Hpu does not support this operation yet.")
@@ -201,14 +203,15 @@ where
Self::new(inner_result, cpu_key.tag.clone())
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
InternalServerKey::Cuda(cuda_key) => {
let streams = &cuda_key.streams;
let inner_result = cuda_key.key.key.min(
&*self.ciphertext.on_gpu(streams),
&*rhs.ciphertext.on_gpu(streams),
streams,
);
Self::new(inner_result, cuda_key.tag.clone())
}),
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
panic!("Hpu does not support this operation yet.")
@@ -262,14 +265,15 @@ where
FheBool::new(inner_result, cpu_key.tag.clone())
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
InternalServerKey::Cuda(cuda_key) => {
let streams = &cuda_key.streams;
let inner_result = cuda_key.key.key.eq(
&*self.ciphertext.on_gpu(streams),
&*rhs.ciphertext.on_gpu(streams),
streams,
);
FheBool::new(inner_result, cuda_key.tag.clone())
}),
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
panic!("Hpu does not support this operation yet.")
@@ -305,14 +309,15 @@ where
FheBool::new(inner_result, cpu_key.tag.clone())
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
InternalServerKey::Cuda(cuda_key) => {
let streams = &cuda_key.streams;
let inner_result = cuda_key.key.key.ne(
&*self.ciphertext.on_gpu(streams),
&*rhs.ciphertext.on_gpu(streams),
streams,
);
FheBool::new(inner_result, cuda_key.tag.clone())
}),
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
panic!("Hpu does not support this operation yet.")
@@ -374,14 +379,15 @@ where
FheBool::new(inner_result, cpu_key.tag.clone())
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
InternalServerKey::Cuda(cuda_key) => {
let streams = &cuda_key.streams;
let inner_result = cuda_key.key.key.lt(
&*self.ciphertext.on_gpu(streams),
&*rhs.ciphertext.on_gpu(streams),
streams,
);
FheBool::new(inner_result, cuda_key.tag.clone())
}),
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
panic!("Hpu does not support this operation yet.")
@@ -417,14 +423,15 @@ where
FheBool::new(inner_result, cpu_key.tag.clone())
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
InternalServerKey::Cuda(cuda_key) => {
let streams = &cuda_key.streams;
let inner_result = cuda_key.key.key.le(
&*self.ciphertext.on_gpu(streams),
&*rhs.ciphertext.on_gpu(streams),
streams,
);
FheBool::new(inner_result, cuda_key.tag.clone())
}),
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
panic!("Hpu does not support this operation yet.")
@@ -460,14 +467,15 @@ where
FheBool::new(inner_result, cpu_key.tag.clone())
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
InternalServerKey::Cuda(cuda_key) => {
let streams = &cuda_key.streams;
let inner_result = cuda_key.key.key.gt(
&*self.ciphertext.on_gpu(streams),
&*rhs.ciphertext.on_gpu(streams),
streams,
);
FheBool::new(inner_result, cuda_key.tag.clone())
}),
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
panic!("Hpu does not support this operation yet.")
@@ -503,14 +511,15 @@ where
FheBool::new(inner_result, cpu_key.tag.clone())
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
InternalServerKey::Cuda(cuda_key) => {
let streams = &cuda_key.streams;
let inner_result = cuda_key.key.key.ge(
&*self.ciphertext.on_gpu(streams),
&*rhs.ciphertext.on_gpu(streams),
streams,
);
FheBool::new(inner_result, cuda_key.tag.clone())
}),
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
panic!("Hpu does not support this operation yet.")
@@ -589,7 +598,8 @@ where
)
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
InternalServerKey::Cuda(cuda_key) => {
let streams = &cuda_key.streams;
let (q, r) = cuda_key.key.key.div_rem(
&*self.ciphertext.on_gpu(streams),
&*rhs.ciphertext.on_gpu(streams),
@@ -599,7 +609,7 @@ where
FheInt::<Id>::new(q, cuda_key.tag.clone()),
FheInt::<Id>::new(r, cuda_key.tag.clone()),
)
}),
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
panic!("Hpu does not support this operation yet.")
@@ -677,11 +687,11 @@ generic_integer_impl_operation!(
},
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => {
with_thread_local_cuda_streams(|streams| {
{let streams = &cuda_key.streams;
let inner_result = cuda_key.key.key
.add(&*lhs.ciphertext.on_gpu(streams), &*rhs.ciphertext.on_gpu(streams), streams);
FheInt::new(inner_result, cuda_key.tag.clone())
})
}
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
@@ -724,11 +734,11 @@ generic_integer_impl_operation!(
},
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => {
with_thread_local_cuda_streams(|streams| {
{let streams = &cuda_key.streams;
let inner_result = cuda_key.key.key
.sub(&*lhs.ciphertext.on_gpu(streams), &*rhs.ciphertext.on_gpu(streams), streams);
FheInt::new(inner_result, cuda_key.tag.clone())
})
}
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
@@ -771,11 +781,11 @@ generic_integer_impl_operation!(
},
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => {
with_thread_local_cuda_streams(|streams| {
{let streams = &cuda_key.streams;
let inner_result = cuda_key.key.key
.mul(&*lhs.ciphertext.on_gpu(streams), &*rhs.ciphertext.on_gpu(streams), streams);
FheInt::new(inner_result, cuda_key.tag.clone())
})
}
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
@@ -816,11 +826,11 @@ generic_integer_impl_operation!(
},
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => {
with_thread_local_cuda_streams(|streams| {
{let streams = &cuda_key.streams;
let inner_result = cuda_key.key.key
.bitand(&*lhs.ciphertext.on_gpu(streams), &*rhs.ciphertext.on_gpu(streams), streams);
FheInt::new(inner_result, cuda_key.tag.clone())
})
}
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
@@ -861,11 +871,11 @@ generic_integer_impl_operation!(
},
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => {
with_thread_local_cuda_streams(|streams| {
{let streams = &cuda_key.streams;
let inner_result = cuda_key.key.key
.bitor(&*lhs.ciphertext.on_gpu(streams), &*rhs.ciphertext.on_gpu(streams), streams);
FheInt::new(inner_result, cuda_key.tag.clone())
})
}
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
@@ -906,11 +916,11 @@ generic_integer_impl_operation!(
},
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => {
with_thread_local_cuda_streams(|streams| {
{let streams = &cuda_key.streams;
let inner_result = cuda_key.key.key
.bitxor(&*lhs.ciphertext.on_gpu(streams), &*rhs.ciphertext.on_gpu(streams), streams);
FheInt::new(inner_result, cuda_key.tag.clone())
})
}
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
@@ -958,14 +968,14 @@ generic_integer_impl_operation!(
FheInt::new(inner_result, cpu_key.tag.clone())
},
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
InternalServerKey::Cuda(cuda_key) => {
    let streams = &cuda_key.streams;
let inner_result =
cuda_key
.key
.key
.div(&*lhs.ciphertext.on_gpu(streams), &*rhs.ciphertext.on_gpu(streams), streams);
FheInt::new(inner_result, cuda_key.tag.clone())
}),
},
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
panic!("Hpu does not support this operation yet.")
@@ -1013,14 +1023,14 @@ generic_integer_impl_operation!(
FheInt::new(inner_result, cpu_key.tag.clone())
},
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
InternalServerKey::Cuda(cuda_key) => {
    let streams = &cuda_key.streams;
let inner_result =
cuda_key
.key
.key
.rem(&*lhs.ciphertext.on_gpu(streams), &*rhs.ciphertext.on_gpu(streams), streams);
FheInt::new(inner_result, cuda_key.tag.clone())
}),
},
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
panic!("Hpu does not support this operation yet.")
@@ -1132,11 +1142,11 @@ generic_integer_impl_shift_rotate!(
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => {
with_thread_local_cuda_streams(|streams| {
{let streams = &cuda_key.streams;
let inner_result = cuda_key.key.key
.left_shift(&*lhs.ciphertext.on_gpu(streams), &rhs.ciphertext.on_gpu(streams), streams);
FheInt::new(inner_result, cuda_key.tag.clone())
})
}
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
@@ -1180,11 +1190,11 @@ generic_integer_impl_shift_rotate!(
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => {
with_thread_local_cuda_streams(|streams| {
{let streams = &cuda_key.streams;
let inner_result = cuda_key.key.key
.right_shift(&*lhs.ciphertext.on_gpu(streams), &rhs.ciphertext.on_gpu(streams), streams);
FheInt::new(inner_result, cuda_key.tag.clone())
})
}
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
@@ -1228,11 +1238,11 @@ generic_integer_impl_shift_rotate!(
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => {
with_thread_local_cuda_streams(|streams| {
{let streams = &cuda_key.streams;
let inner_result = cuda_key.key.key
.rotate_left(&*lhs.ciphertext.on_gpu(streams), &rhs.ciphertext.on_gpu(streams), streams);
FheInt::new(inner_result, cuda_key.tag.clone())
})
}
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
@@ -1276,11 +1286,11 @@ generic_integer_impl_shift_rotate!(
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => {
with_thread_local_cuda_streams(|streams| {
{let streams = &cuda_key.streams;
let inner_result = cuda_key.key.key
.rotate_right(&*lhs.ciphertext.on_gpu(streams), &rhs.ciphertext.on_gpu(streams), streams);
FheInt::new(inner_result, cuda_key.tag.clone())
})
}
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
@@ -1330,13 +1340,12 @@ where
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => {
crate::high_level_api::global_state::with_thread_local_cuda_streams(|streams| {
cuda_key.key.key.add_assign(
self.ciphertext.as_gpu_mut(streams),
&rhs.ciphertext.on_gpu(streams),
streams,
);
})
let streams = &cuda_key.streams;
cuda_key.key.key.add_assign(
self.ciphertext.as_gpu_mut(streams),
&rhs.ciphertext.on_gpu(streams),
streams,
);
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
@@ -1381,13 +1390,12 @@ where
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => {
crate::high_level_api::global_state::with_thread_local_cuda_streams(|streams| {
cuda_key.key.key.sub_assign(
self.ciphertext.as_gpu_mut(streams),
&rhs.ciphertext.on_gpu(streams),
streams,
);
})
let streams = &cuda_key.streams;
cuda_key.key.key.sub_assign(
self.ciphertext.as_gpu_mut(streams),
&rhs.ciphertext.on_gpu(streams),
streams,
);
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
@@ -1432,13 +1440,12 @@ where
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => {
crate::high_level_api::global_state::with_thread_local_cuda_streams(|streams| {
cuda_key.key.key.mul_assign(
self.ciphertext.as_gpu_mut(streams),
&rhs.ciphertext.on_gpu(streams),
streams,
);
})
let streams = &cuda_key.streams;
cuda_key.key.key.mul_assign(
self.ciphertext.as_gpu_mut(streams),
&rhs.ciphertext.on_gpu(streams),
streams,
);
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
@@ -1481,13 +1488,12 @@ where
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => {
crate::high_level_api::global_state::with_thread_local_cuda_streams(|streams| {
cuda_key.key.key.bitand_assign(
self.ciphertext.as_gpu_mut(streams),
&rhs.ciphertext.on_gpu(streams),
streams,
);
})
let streams = &cuda_key.streams;
cuda_key.key.key.bitand_assign(
self.ciphertext.as_gpu_mut(streams),
&rhs.ciphertext.on_gpu(streams),
streams,
);
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
@@ -1530,13 +1536,12 @@ where
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => {
crate::high_level_api::global_state::with_thread_local_cuda_streams(|streams| {
cuda_key.key.key.bitor_assign(
self.ciphertext.as_gpu_mut(streams),
&rhs.ciphertext.on_gpu(streams),
streams,
);
})
let streams = &cuda_key.streams;
cuda_key.key.key.bitor_assign(
self.ciphertext.as_gpu_mut(streams),
&rhs.ciphertext.on_gpu(streams),
streams,
);
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
@@ -1579,13 +1584,12 @@ where
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => {
crate::high_level_api::global_state::with_thread_local_cuda_streams(|streams| {
cuda_key.key.key.bitxor_assign(
self.ciphertext.as_gpu_mut(streams),
&rhs.ciphertext.on_gpu(streams),
streams,
);
})
let streams = &cuda_key.streams;
cuda_key.key.key.bitxor_assign(
self.ciphertext.as_gpu_mut(streams),
&rhs.ciphertext.on_gpu(streams),
streams,
);
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
@@ -1633,7 +1637,8 @@ where
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => {
with_thread_local_cuda_streams(|streams| {
{
let streams = &cuda_key.streams;
let cuda_lhs = self.ciphertext.as_gpu_mut(streams);
let cuda_result = cuda_key.pbs_key().div(
&*cuda_lhs,
@@ -1641,7 +1646,7 @@ where
streams,
);
*cuda_lhs = cuda_result;
});
};
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
@@ -1689,15 +1694,13 @@ where
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => {
with_thread_local_cuda_streams(|streams| {
let cuda_lhs = self.ciphertext.as_gpu_mut(streams);
let cuda_result = cuda_key.pbs_key().rem(
&*cuda_lhs,
&rhs.ciphertext.on_gpu(streams),
streams,
);
*cuda_lhs = cuda_result;
});
let streams = &cuda_key.streams;
let cuda_lhs = self.ciphertext.as_gpu_mut(streams);
let cuda_result =
cuda_key
.pbs_key()
.rem(&*cuda_lhs, &rhs.ciphertext.on_gpu(streams), streams);
*cuda_lhs = cuda_result;
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
@@ -1750,13 +1753,12 @@ where
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => {
with_thread_local_cuda_streams(|streams| {
cuda_key.key.key.left_shift_assign(
self.ciphertext.as_gpu_mut(streams),
&rhs.ciphertext.on_gpu(streams),
streams,
);
});
let streams = &cuda_key.streams;
cuda_key.key.key.left_shift_assign(
self.ciphertext.as_gpu_mut(streams),
&rhs.ciphertext.on_gpu(streams),
streams,
);
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
@@ -1808,13 +1810,12 @@ where
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => {
with_thread_local_cuda_streams(|streams| {
cuda_key.key.key.right_shift_assign(
self.ciphertext.as_gpu_mut(streams),
&rhs.ciphertext.on_gpu(streams),
streams,
);
});
let streams = &cuda_key.streams;
cuda_key.key.key.right_shift_assign(
self.ciphertext.as_gpu_mut(streams),
&rhs.ciphertext.on_gpu(streams),
streams,
);
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
@@ -1867,13 +1868,12 @@ where
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => {
with_thread_local_cuda_streams(|streams| {
cuda_key.key.key.rotate_left_assign(
self.ciphertext.as_gpu_mut(streams),
&rhs.ciphertext.on_gpu(streams),
streams,
);
});
let streams = &cuda_key.streams;
cuda_key.key.key.rotate_left_assign(
self.ciphertext.as_gpu_mut(streams),
&rhs.ciphertext.on_gpu(streams),
streams,
);
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
@@ -1926,13 +1926,12 @@ where
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => {
with_thread_local_cuda_streams(|streams| {
cuda_key.key.key.rotate_right_assign(
self.ciphertext.as_gpu_mut(streams),
&rhs.ciphertext.on_gpu(streams),
streams,
);
});
let streams = &cuda_key.streams;
cuda_key.key.key.rotate_right_assign(
self.ciphertext.as_gpu_mut(streams),
&rhs.ciphertext.on_gpu(streams),
streams,
);
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
@@ -2002,13 +2001,14 @@ where
FheInt::new(ciphertext, cpu_key.tag.clone())
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
InternalServerKey::Cuda(cuda_key) => {
let streams = &cuda_key.streams;
let inner_result = cuda_key
.key
.key
.neg(&*self.ciphertext.on_gpu(streams), streams);
FheInt::new(inner_result, cuda_key.tag.clone())
}),
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
panic!("Hpu does not support this operation yet.")
@@ -2075,13 +2075,14 @@ where
FheInt::new(ciphertext, cpu_key.tag.clone())
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
InternalServerKey::Cuda(cuda_key) => {
let streams = &cuda_key.streams;
let inner_result = cuda_key
.key
.key
.bitnot(&*self.ciphertext.on_gpu(streams), streams);
FheInt::new(inner_result, cuda_key.tag.clone())
}),
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
panic!("Hpu does not support this operation yet.")
@@ -2101,13 +2102,12 @@ where
global_state::with_internal_keys(|key| {
if let InternalServerKey::Cuda(cuda_key) = key {
with_thread_local_cuda_streams(|streams| {
cuda_key.key.key.get_add_size_on_gpu(
&*self.ciphertext.on_gpu(streams),
&rhs.ciphertext.on_gpu(streams),
streams,
)
})
let streams = &cuda_key.streams;
cuda_key.key.key.get_add_size_on_gpu(
&*self.ciphertext.on_gpu(streams),
&rhs.ciphertext.on_gpu(streams),
streams,
)
} else {
0
}
@@ -2126,13 +2126,12 @@ where
global_state::with_internal_keys(|key| {
if let InternalServerKey::Cuda(cuda_key) = key {
with_thread_local_cuda_streams(|streams| {
cuda_key.key.key.get_sub_size_on_gpu(
&*self.ciphertext.on_gpu(streams),
&rhs.ciphertext.on_gpu(streams),
streams,
)
})
let streams = &cuda_key.streams;
cuda_key.key.key.get_sub_size_on_gpu(
&*self.ciphertext.on_gpu(streams),
&rhs.ciphertext.on_gpu(streams),
streams,
)
} else {
0
}
@@ -2148,12 +2147,11 @@ where
fn get_size_on_gpu(&self) -> u64 {
global_state::with_internal_keys(|key| {
if let InternalServerKey::Cuda(cuda_key) = key {
with_thread_local_cuda_streams(|streams| {
cuda_key
.key
.key
.get_ciphertext_size_on_gpu(&*self.ciphertext.on_gpu(streams))
})
let streams = &cuda_key.streams;
cuda_key
.key
.key
.get_ciphertext_size_on_gpu(&*self.ciphertext.on_gpu(streams))
} else {
0
}
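The get_*_size_on_gpu probes above let a caller ask, before launching an operation, how much temporary GPU memory it would allocate; with a non-CUDA key they report 0. A hedged usage sketch (requires a build with the gpu feature and a CUDA device; the key setup follows the crate's documented GPU flow via CompressedServerKey::decompress_to_gpu):

    use tfhe::prelude::*;
    use tfhe::{
        generate_keys, set_server_key, CompressedServerKey, ConfigBuilder, FheInt32, GpuIndex,
    };

    fn main() {
        let config = ConfigBuilder::default().build();
        let (client_key, _server_key) = generate_keys(config);
        let compressed = CompressedServerKey::new(&client_key);
        set_server_key(compressed.decompress_to_gpu());

        let a = FheInt32::encrypt(-3i32, &client_key);
        let b = FheInt32::encrypt(7i32, &client_key);

        // Probe the scratch space an encrypted addition would need, and check
        // that the allocation would succeed on GPU 0 before computing.
        let tmp_bytes = a.get_add_size_on_gpu(&b);
        assert!(check_valid_cuda_malloc(tmp_bytes, GpuIndex::new(0)));

        let sum = &a + &b;
        let result: i32 = sum.decrypt(&client_key);
        assert_eq!(result, 4);
    }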
@@ -2171,13 +2169,12 @@ where
let rhs = rhs.borrow();
global_state::with_internal_keys(|key| {
if let InternalServerKey::Cuda(cuda_key) = key {
with_thread_local_cuda_streams(|streams| {
cuda_key.key.key.get_bitand_size_on_gpu(
&*self.ciphertext.on_gpu(streams),
&rhs.ciphertext.on_gpu(streams),
streams,
)
})
let streams = &cuda_key.streams;
cuda_key.key.key.get_bitand_size_on_gpu(
&*self.ciphertext.on_gpu(streams),
&rhs.ciphertext.on_gpu(streams),
streams,
)
} else {
0
}
@@ -2195,13 +2192,12 @@ where
let rhs = rhs.borrow();
global_state::with_internal_keys(|key| {
if let InternalServerKey::Cuda(cuda_key) = key {
with_thread_local_cuda_streams(|streams| {
cuda_key.key.key.get_bitor_size_on_gpu(
&*self.ciphertext.on_gpu(streams),
&rhs.ciphertext.on_gpu(streams),
streams,
)
})
let streams = &cuda_key.streams;
cuda_key.key.key.get_bitor_size_on_gpu(
&*self.ciphertext.on_gpu(streams),
&rhs.ciphertext.on_gpu(streams),
streams,
)
} else {
0
}
@@ -2219,13 +2215,12 @@ where
let rhs = rhs.borrow();
global_state::with_internal_keys(|key| {
if let InternalServerKey::Cuda(cuda_key) = key {
with_thread_local_cuda_streams(|streams| {
cuda_key.key.key.get_bitxor_size_on_gpu(
&*self.ciphertext.on_gpu(streams),
&rhs.ciphertext.on_gpu(streams),
streams,
)
})
let streams = &cuda_key.streams;
cuda_key.key.key.get_bitxor_size_on_gpu(
&*self.ciphertext.on_gpu(streams),
&rhs.ciphertext.on_gpu(streams),
streams,
)
} else {
0
}
@@ -2241,11 +2236,219 @@ where
fn get_bitnot_size_on_gpu(&self) -> u64 {
global_state::with_internal_keys(|key| {
if let InternalServerKey::Cuda(cuda_key) = key {
with_thread_local_cuda_streams(|streams| {
cuda_key
.key
.key
.get_bitnot_size_on_gpu(&*self.ciphertext.on_gpu(streams), streams)
let streams = &cuda_key.streams;
cuda_key
.key
.key
.get_bitnot_size_on_gpu(&*self.ciphertext.on_gpu(streams), streams)
} else {
0
}
})
}
}
#[cfg(feature = "gpu")]
impl<Id> FheOrdSizeOnGpu<&Self> for FheInt<Id>
where
Id: FheIntId,
{
fn get_gt_size_on_gpu(&self, rhs: &Self) -> u64 {
global_state::with_internal_keys(|key| {
if let InternalServerKey::Cuda(cuda_key) = key {
with_cuda_internal_keys(|keys| {
let streams = &keys.streams;
cuda_key.key.key.get_gt_size_on_gpu(
&*self.ciphertext.on_gpu(streams),
&rhs.ciphertext.on_gpu(streams),
streams,
)
})
} else {
0
}
})
}
fn get_ge_size_on_gpu(&self, rhs: &Self) -> u64 {
global_state::with_internal_keys(|key| {
if let InternalServerKey::Cuda(cuda_key) = key {
with_cuda_internal_keys(|keys| {
let streams = &keys.streams;
cuda_key.key.key.get_ge_size_on_gpu(
&*self.ciphertext.on_gpu(streams),
&rhs.ciphertext.on_gpu(streams),
streams,
)
})
} else {
0
}
})
}
fn get_lt_size_on_gpu(&self, rhs: &Self) -> u64 {
global_state::with_internal_keys(|key| {
if let InternalServerKey::Cuda(cuda_key) = key {
with_cuda_internal_keys(|keys| {
let streams = &keys.streams;
cuda_key.key.key.get_lt_size_on_gpu(
&*self.ciphertext.on_gpu(streams),
&rhs.ciphertext.on_gpu(streams),
streams,
)
})
} else {
0
}
})
}
fn get_le_size_on_gpu(&self, rhs: &Self) -> u64 {
global_state::with_internal_keys(|key| {
if let InternalServerKey::Cuda(cuda_key) = key {
with_cuda_internal_keys(|keys| {
let streams = &keys.streams;
cuda_key.key.key.get_le_size_on_gpu(
&*self.ciphertext.on_gpu(streams),
&rhs.ciphertext.on_gpu(streams),
streams,
)
})
} else {
0
}
})
}
}
#[cfg(feature = "gpu")]
impl<Id> FheMinSizeOnGpu<&Self> for FheInt<Id>
where
Id: FheIntId,
{
fn get_min_size_on_gpu(&self, rhs: &Self) -> u64 {
global_state::with_internal_keys(|key| {
if let InternalServerKey::Cuda(cuda_key) = key {
with_cuda_internal_keys(|keys| {
let streams = &keys.streams;
cuda_key.key.key.get_min_size_on_gpu(
&*self.ciphertext.on_gpu(streams),
&rhs.ciphertext.on_gpu(streams),
streams,
)
})
} else {
0
}
})
}
}
#[cfg(feature = "gpu")]
impl<Id> FheMaxSizeOnGpu<&Self> for FheInt<Id>
where
Id: FheIntId,
{
fn get_max_size_on_gpu(&self, rhs: &Self) -> u64 {
global_state::with_internal_keys(|key| {
if let InternalServerKey::Cuda(cuda_key) = key {
with_cuda_internal_keys(|keys| {
let streams = &keys.streams;
cuda_key.key.key.get_max_size_on_gpu(
&*self.ciphertext.on_gpu(streams),
&rhs.ciphertext.on_gpu(streams),
streams,
)
})
} else {
0
}
})
}
}
#[cfg(feature = "gpu")]
impl<Id, Id2> ShlSizeOnGpu<&FheUint<Id2>> for FheInt<Id>
where
Id: FheIntId,
Id2: FheUintId,
{
fn get_left_shift_size_on_gpu(&self, rhs: &FheUint<Id2>) -> u64 {
global_state::with_internal_keys(|key| {
if let InternalServerKey::Cuda(cuda_key) = key {
with_cuda_internal_keys(|keys| {
let streams = &keys.streams;
cuda_key.key.key.get_left_shift_size_on_gpu(
&*self.ciphertext.on_gpu(streams),
&rhs.ciphertext.on_gpu(streams),
streams,
)
})
} else {
0
}
})
}
}
#[cfg(feature = "gpu")]
impl<Id, Id2> ShrSizeOnGpu<&FheUint<Id2>> for FheInt<Id>
where
Id: FheIntId,
Id2: FheUintId,
{
fn get_right_shift_size_on_gpu(&self, rhs: &FheUint<Id2>) -> u64 {
global_state::with_internal_keys(|key| {
if let InternalServerKey::Cuda(cuda_key) = key {
with_cuda_internal_keys(|keys| {
let streams = &keys.streams;
cuda_key.key.key.get_right_shift_size_on_gpu(
&*self.ciphertext.on_gpu(streams),
&rhs.ciphertext.on_gpu(streams),
streams,
)
})
} else {
0
}
})
}
}
#[cfg(feature = "gpu")]
impl<Id, Id2> RotateLeftSizeOnGpu<&FheUint<Id2>> for FheInt<Id>
where
Id: FheIntId,
Id2: FheUintId,
{
fn get_rotate_left_size_on_gpu(&self, rhs: &FheUint<Id2>) -> u64 {
global_state::with_internal_keys(|key| {
if let InternalServerKey::Cuda(cuda_key) = key {
with_cuda_internal_keys(|keys| {
let streams = &keys.streams;
cuda_key.key.key.get_rotate_left_size_on_gpu(
&*self.ciphertext.on_gpu(streams),
&rhs.ciphertext.on_gpu(streams),
streams,
)
})
} else {
0
}
})
}
}
#[cfg(feature = "gpu")]
impl<Id, Id2> RotateRightSizeOnGpu<&FheUint<Id2>> for FheInt<Id>
where
Id: FheIntId,
Id2: FheUintId,
{
fn get_rotate_right_size_on_gpu(&self, rhs: &FheUint<Id2>) -> u64 {
global_state::with_internal_keys(|key| {
if let InternalServerKey::Cuda(cuda_key) = key {
with_cuda_internal_keys(|keys| {
let streams = &keys.streams;
cuda_key.key.key.get_rotate_right_size_on_gpu(
&*self.ciphertext.on_gpu(streams),
&rhs.ciphertext.on_gpu(streams),
streams,
)
})
} else {
0


@@ -1,7 +1,5 @@
use crate::core_crypto::prelude::SignedNumeric;
use crate::high_level_api::global_state;
#[cfg(feature = "gpu")]
use crate::high_level_api::global_state::with_thread_local_cuda_streams;
use crate::high_level_api::integers::FheIntId;
use crate::high_level_api::keys::InternalServerKey;
use crate::integer::block_decomposition::DecomposableInto;
@@ -53,7 +51,8 @@ where
)
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
InternalServerKey::Cuda(cuda_key) => {
let streams = &cuda_key.streams;
let (result, overflow) = cuda_key.key.key.signed_overflowing_add(
&self.ciphertext.on_gpu(streams),
&other.ciphertext.on_gpu(streams),
@@ -63,7 +62,7 @@ where
FheInt::new(result, cuda_key.tag.clone()),
FheBool::new(overflow, cuda_key.tag.clone()),
)
}),
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
panic!("Hpu does not support this operation yet.")
@@ -153,7 +152,8 @@ where
)
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
InternalServerKey::Cuda(cuda_key) => {
let streams = &cuda_key.streams;
let (result, overflow) = cuda_key.key.key.signed_overflowing_scalar_add(
&self.ciphertext.on_gpu(streams),
other,
@@ -163,7 +163,7 @@ where
FheInt::new(result, cuda_key.tag.clone()),
FheBool::new(overflow, cuda_key.tag.clone()),
)
}),
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
panic!("Hpu does not support this operation yet.")
@@ -291,7 +291,8 @@ where
)
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
InternalServerKey::Cuda(cuda_key) => {
let streams = &cuda_key.streams;
let (result, overflow) = cuda_key.key.key.signed_overflowing_sub(
&self.ciphertext.on_gpu(streams),
&other.ciphertext.on_gpu(streams),
@@ -301,7 +302,7 @@ where
FheInt::new(result, cuda_key.tag.clone()),
FheBool::new(overflow, cuda_key.tag.clone()),
)
}),
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
panic!("Hpu does not support this operation yet.")
@@ -390,7 +391,8 @@ where
)
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
InternalServerKey::Cuda(cuda_key) => {
let streams = &cuda_key.streams;
let (result, overflow) = cuda_key.key.key.signed_overflowing_scalar_sub(
&self.ciphertext.on_gpu(streams),
other,
@@ -400,7 +402,7 @@ where
FheInt::new(result, cuda_key.tag.clone()),
FheBool::new(overflow, cuda_key.tag.clone()),
)
}),
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
panic!("Hpu does not support this operation yet.")


@@ -3,13 +3,15 @@ use crate::core_crypto::commons::numeric::CastFrom;
use crate::high_level_api::errors::UnwrapResultExt;
use crate::high_level_api::global_state;
#[cfg(feature = "gpu")]
use crate::high_level_api::global_state::with_thread_local_cuda_streams;
use crate::high_level_api::global_state::with_cuda_internal_keys;
use crate::high_level_api::integers::signed::inner::SignedRadixCiphertext;
use crate::high_level_api::integers::FheIntId;
use crate::high_level_api::keys::InternalServerKey;
#[cfg(feature = "gpu")]
use crate::high_level_api::traits::{
AddSizeOnGpu, BitAndSizeOnGpu, BitOrSizeOnGpu, BitXorSizeOnGpu, SubSizeOnGpu,
AddSizeOnGpu, BitAndSizeOnGpu, BitOrSizeOnGpu, BitXorSizeOnGpu, FheMaxSizeOnGpu,
FheMinSizeOnGpu, FheOrdSizeOnGpu, RotateLeftSizeOnGpu, RotateRightSizeOnGpu, ShlSizeOnGpu,
ShrSizeOnGpu, SubSizeOnGpu,
};
use crate::high_level_api::traits::{
DivRem, FheEq, FheMax, FheMin, FheOrd, RotateLeft, RotateLeftAssign, RotateRight,
@@ -59,14 +61,13 @@ where
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => {
crate::high_level_api::global_state::with_thread_local_cuda_streams(|streams| {
let inner_result = cuda_key.key.key.scalar_max(
&*self.ciphertext.on_gpu(streams),
rhs,
streams,
);
Self::new(inner_result, cuda_key.tag.clone())
})
let streams = &cuda_key.streams;
let inner_result =
cuda_key
.key
.key
.scalar_max(&*self.ciphertext.on_gpu(streams), rhs, streams);
Self::new(inner_result, cuda_key.tag.clone())
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
@@ -111,14 +112,13 @@ where
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => {
crate::high_level_api::global_state::with_thread_local_cuda_streams(|streams| {
let inner_result = cuda_key.key.key.scalar_min(
&*self.ciphertext.on_gpu(streams),
rhs,
streams,
);
Self::new(inner_result, cuda_key.tag.clone())
})
let streams = &cuda_key.streams;
let inner_result =
cuda_key
.key
.key
.scalar_min(&*self.ciphertext.on_gpu(streams), rhs, streams);
Self::new(inner_result, cuda_key.tag.clone())
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
@@ -162,14 +162,13 @@ where
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => {
crate::high_level_api::global_state::with_thread_local_cuda_streams(|streams| {
let inner_result =
cuda_key
.key
.key
.scalar_eq(&*self.ciphertext.on_gpu(streams), rhs, streams);
FheBool::new(inner_result, cuda_key.tag.clone())
})
let streams = &cuda_key.streams;
let inner_result =
cuda_key
.key
.key
.scalar_eq(&*self.ciphertext.on_gpu(streams), rhs, streams);
FheBool::new(inner_result, cuda_key.tag.clone())
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
@@ -207,14 +206,13 @@ where
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => {
crate::high_level_api::global_state::with_thread_local_cuda_streams(|streams| {
let inner_result =
cuda_key
.key
.key
.scalar_ne(&*self.ciphertext.on_gpu(streams), rhs, streams);
FheBool::new(inner_result, cuda_key.tag.clone())
})
let streams = &cuda_key.streams;
let inner_result =
cuda_key
.key
.key
.scalar_ne(&*self.ciphertext.on_gpu(streams), rhs, streams);
FheBool::new(inner_result, cuda_key.tag.clone())
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
@@ -257,14 +255,13 @@ where
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => {
crate::high_level_api::global_state::with_thread_local_cuda_streams(|streams| {
let inner_result =
cuda_key
.key
.key
.scalar_lt(&*self.ciphertext.on_gpu(streams), rhs, streams);
FheBool::new(inner_result, cuda_key.tag.clone())
})
let streams = &cuda_key.streams;
let inner_result =
cuda_key
.key
.key
.scalar_lt(&*self.ciphertext.on_gpu(streams), rhs, streams);
FheBool::new(inner_result, cuda_key.tag.clone())
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
@@ -301,14 +298,13 @@ where
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => {
crate::high_level_api::global_state::with_thread_local_cuda_streams(|streams| {
let inner_result =
cuda_key
.key
.key
.scalar_le(&*self.ciphertext.on_gpu(streams), rhs, streams);
FheBool::new(inner_result, cuda_key.tag.clone())
})
let streams = &cuda_key.streams;
let inner_result =
cuda_key
.key
.key
.scalar_le(&*self.ciphertext.on_gpu(streams), rhs, streams);
FheBool::new(inner_result, cuda_key.tag.clone())
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
@@ -345,14 +341,13 @@ where
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => {
crate::high_level_api::global_state::with_thread_local_cuda_streams(|streams| {
let inner_result =
cuda_key
.key
.key
.scalar_gt(&*self.ciphertext.on_gpu(streams), rhs, streams);
FheBool::new(inner_result, cuda_key.tag.clone())
})
let streams = &cuda_key.streams;
let inner_result =
cuda_key
.key
.key
.scalar_gt(&*self.ciphertext.on_gpu(streams), rhs, streams);
FheBool::new(inner_result, cuda_key.tag.clone())
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
@@ -389,14 +384,13 @@ where
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => {
crate::high_level_api::global_state::with_thread_local_cuda_streams(|streams| {
let inner_result =
cuda_key
.key
.key
.scalar_ge(&*self.ciphertext.on_gpu(streams), rhs, streams);
FheBool::new(inner_result, cuda_key.tag.clone())
})
let streams = &cuda_key.streams;
let inner_result =
cuda_key
.key
.key
.scalar_ge(&*self.ciphertext.on_gpu(streams), rhs, streams);
FheBool::new(inner_result, cuda_key.tag.clone())
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
@@ -406,6 +400,118 @@ where
}
}
#[cfg(feature = "gpu")]
impl<Id, Clear> FheOrdSizeOnGpu<Clear> for FheInt<Id>
where
Id: FheIntId,
Clear: DecomposableInto<u64>,
{
fn get_gt_size_on_gpu(&self, _rhs: Clear) -> u64 {
global_state::with_internal_keys(|key| {
if let InternalServerKey::Cuda(cuda_key) = key {
with_cuda_internal_keys(|keys| {
let streams = &keys.streams;
cuda_key
.key
.key
.get_scalar_le_size_on_gpu(&*self.ciphertext.on_gpu(streams), streams)
})
} else {
0
}
})
}
fn get_ge_size_on_gpu(&self, _rhs: Clear) -> u64 {
global_state::with_internal_keys(|key| {
if let InternalServerKey::Cuda(cuda_key) = key {
with_cuda_internal_keys(|keys| {
let streams = &keys.streams;
cuda_key
.key
.key
.get_scalar_le_size_on_gpu(&*self.ciphertext.on_gpu(streams), streams)
})
} else {
0
}
})
}
fn get_lt_size_on_gpu(&self, _rhs: Clear) -> u64 {
global_state::with_internal_keys(|key| {
if let InternalServerKey::Cuda(cuda_key) = key {
with_cuda_internal_keys(|keys| {
let streams = &keys.streams;
cuda_key
.key
.key
.get_scalar_le_size_on_gpu(&*self.ciphertext.on_gpu(streams), streams)
})
} else {
0
}
})
}
fn get_le_size_on_gpu(&self, _rhs: Clear) -> u64 {
global_state::with_internal_keys(|key| {
if let InternalServerKey::Cuda(cuda_key) = key {
with_cuda_internal_keys(|keys| {
let streams = &keys.streams;
cuda_key
.key
.key
.get_scalar_le_size_on_gpu(&*self.ciphertext.on_gpu(streams), streams)
})
} else {
0
}
})
}
}
#[cfg(feature = "gpu")]
impl<Id, Clear> FheMinSizeOnGpu<Clear> for FheInt<Id>
where
Id: FheIntId,
Clear: DecomposableInto<u64>,
{
fn get_min_size_on_gpu(&self, _rhs: Clear) -> u64 {
global_state::with_internal_keys(|key| {
if let InternalServerKey::Cuda(cuda_key) = key {
with_cuda_internal_keys(|keys| {
let streams = &keys.streams;
cuda_key
.key
.key
.get_scalar_min_size_on_gpu(&*self.ciphertext.on_gpu(streams), streams)
})
} else {
0
}
})
}
}
#[cfg(feature = "gpu")]
impl<Id, Clear> FheMaxSizeOnGpu<Clear> for FheInt<Id>
where
Id: FheIntId,
Clear: DecomposableInto<u64>,
{
fn get_max_size_on_gpu(&self, _rhs: Clear) -> u64 {
global_state::with_internal_keys(|key| {
if let InternalServerKey::Cuda(cuda_key) = key {
with_cuda_internal_keys(|keys| {
let streams = &keys.streams;
cuda_key
.key
.key
.get_scalar_max_size_on_gpu(&*self.ciphertext.on_gpu(streams), streams)
})
} else {
0
}
})
}
}
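Note that the scalar overloads above ignore the clear operand (_rhs): it only selects the impl, and the buffer-size probe depends on the ciphertext alone. All four comparison probes delegate to get_scalar_le_size_on_gpu, so gt/ge/lt/le report the same size. A hedged fragment, assuming a GPU server key was already installed as in the earlier sketch:

    use tfhe::prelude::*;
    use tfhe::{ClientKey, FheInt32, GpuIndex};

    fn preflight_scalar_comparisons(client_key: &ClientKey) {
        let a = FheInt32::encrypt(-3i32, client_key);
        // Any i32 works as the probe argument; only the ciphertext matters.
        let gt_bytes = a.get_gt_size_on_gpu(7i32);
        let max_bytes = a.get_max_size_on_gpu(7i32);
        assert!(check_valid_cuda_malloc(gt_bytes, GpuIndex::new(0)));
        assert!(check_valid_cuda_malloc(max_bytes, GpuIndex::new(0)));
    }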
// DivRem is a bit special as it returns a tuple of quotient and remainder
macro_rules! generic_integer_impl_scalar_div_rem {
(
@@ -444,11 +550,11 @@ macro_rules! generic_integer_impl_scalar_div_rem {
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => {
let (inner_q, inner_r) = with_thread_local_cuda_streams(|streams| {
let (inner_q, inner_r) = {let streams = &cuda_key.streams;
cuda_key.key.key.signed_scalar_div_rem(
&*self.ciphertext.on_gpu(streams), rhs, streams
)
});
};
let (q, r) = (
SignedRadixCiphertext::Cuda(inner_q),
SignedRadixCiphertext::Cuda(inner_r),
@@ -501,11 +607,11 @@ macro_rules! define_scalar_rotate_shifts {
},
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => {
let inner_result = with_thread_local_cuda_streams(|streams| {
let inner_result = {let streams = &cuda_key.streams;
cuda_key.key.key.scalar_left_shift(
&*lhs.ciphertext.on_gpu(streams), u64::cast_from(rhs), streams
)
});
};
SignedRadixCiphertext::Cuda(inner_result)
}
#[cfg(feature = "hpu")]
@@ -521,6 +627,31 @@ macro_rules! define_scalar_rotate_shifts {
)*
);
#[cfg(feature = "gpu")]
generic_integer_impl_get_scalar_operation_size_on_gpu!(
rust_trait: ShlSizeOnGpu(get_left_shift_size_on_gpu),
implem: {
|lhs: &FheInt<_>, _rhs| {
global_state::with_internal_keys(|key|
if let InternalServerKey::Cuda(cuda_key) = key {
with_cuda_internal_keys(|keys| {
let streams = &keys.streams;
cuda_key.key.key.get_scalar_left_shift_size_on_gpu(
&*lhs.ciphertext.on_gpu(streams),
streams,
)
})
} else {
0
})
}
},
fhe_and_scalar_type:
$(
($concrete_type, $($scalar_type,)*),
)*
);
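Each generic_integer_impl_get_scalar_operation_size_on_gpu! invocation presumably expands to one trait impl per (FHE type, scalar type) pair, with the closure above as the method body. A hedged reconstruction of a single expansion, using an illustrative (FheInt32, u32) pairing rather than the macro's actual output:

    #[cfg(feature = "gpu")]
    impl ShlSizeOnGpu<u32> for FheInt32 {
        fn get_left_shift_size_on_gpu(&self, _rhs: u32) -> u64 {
            global_state::with_internal_keys(|key| {
                if let InternalServerKey::Cuda(cuda_key) = key {
                    with_cuda_internal_keys(|keys| {
                        let streams = &keys.streams;
                        cuda_key.key.key.get_scalar_left_shift_size_on_gpu(
                            &*self.ciphertext.on_gpu(streams),
                            streams,
                        )
                    })
                } else {
                    0
                }
            })
        }
    }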
generic_integer_impl_scalar_operation!(
rust_trait: Shr(shr),
implem: {
@@ -534,11 +665,11 @@ macro_rules! define_scalar_rotate_shifts {
},
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => {
let inner_result = with_thread_local_cuda_streams(|streams| {
let inner_result = {let streams = &cuda_key.streams;
cuda_key.key.key.scalar_right_shift(
&*lhs.ciphertext.on_gpu(streams), u64::cast_from(rhs), streams
)
});
};
SignedRadixCiphertext::Cuda(inner_result)
}
#[cfg(feature = "hpu")]
@@ -554,6 +685,31 @@ macro_rules! define_scalar_rotate_shifts {
)*
);
#[cfg(feature = "gpu")]
generic_integer_impl_get_scalar_operation_size_on_gpu!(
rust_trait: ShrSizeOnGpu(get_right_shift_size_on_gpu),
implem: {
|lhs: &FheInt<_>, _rhs| {
global_state::with_internal_keys(|key|
if let InternalServerKey::Cuda(cuda_key) = key {
with_cuda_internal_keys(|keys| {
let streams = &keys.streams;
cuda_key.key.key.get_scalar_right_shift_size_on_gpu(
&*lhs.ciphertext.on_gpu(streams),
streams,
)
})
} else {
0
})
}
},
fhe_and_scalar_type:
$(
($concrete_type, $($scalar_type,)*),
)*
);
generic_integer_impl_scalar_operation!(
rust_trait: RotateLeft(rotate_left),
implem: {
@@ -567,11 +723,11 @@ macro_rules! define_scalar_rotate_shifts {
},
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => {
let inner_result = with_thread_local_cuda_streams(|streams| {
let inner_result = {let streams = &cuda_key.streams;
cuda_key.key.key.scalar_rotate_left(
&*lhs.ciphertext.on_gpu(streams), u64::cast_from(rhs), streams
)
});
};
SignedRadixCiphertext::Cuda(inner_result)
}
#[cfg(feature = "hpu")]
@@ -587,6 +743,31 @@ macro_rules! define_scalar_rotate_shifts {
)*
);
#[cfg(feature = "gpu")]
generic_integer_impl_get_scalar_operation_size_on_gpu!(
rust_trait: RotateLeftSizeOnGpu(get_rotate_left_size_on_gpu),
implem: {
|lhs: &FheInt<_>, _rhs| {
global_state::with_internal_keys(|key|
if let InternalServerKey::Cuda(cuda_key) = key {
with_cuda_internal_keys(|keys| {
let streams = &keys.streams;
cuda_key.key.key.get_scalar_rotate_left_size_on_gpu(
&*lhs.ciphertext.on_gpu(streams),
streams,
)
})
} else {
0
})
}
},
fhe_and_scalar_type:
$(
($concrete_type, $($scalar_type,)*),
)*
);
generic_integer_impl_scalar_operation!(
rust_trait: RotateRight(rotate_right),
implem: {
@@ -600,11 +781,11 @@ macro_rules! define_scalar_rotate_shifts {
},
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => {
let inner_result = with_thread_local_cuda_streams(|streams| {
let inner_result = {let streams = &cuda_key.streams;
cuda_key.key.key.scalar_rotate_right(
&*lhs.ciphertext.on_gpu(streams), u64::cast_from(rhs), streams
)
});
};
SignedRadixCiphertext::Cuda(inner_result)
}
#[cfg(feature = "hpu")]
@@ -620,6 +801,31 @@ macro_rules! define_scalar_rotate_shifts {
)*
);
#[cfg(feature = "gpu")]
generic_integer_impl_get_scalar_operation_size_on_gpu!(
rust_trait: RotateRightSizeOnGpu(get_rotate_right_size_on_gpu),
implem: {
|lhs: &FheInt<_>, _rhs| {
global_state::with_internal_keys(|key|
if let InternalServerKey::Cuda(cuda_key) = key {
with_cuda_internal_keys(|keys| {
let streams = &keys.streams;
cuda_key.key.key.get_scalar_rotate_right_size_on_gpu(
&*lhs.ciphertext.on_gpu(streams),
streams,
)
})
} else {
0
})
}
},
fhe_and_scalar_type:
$(
($concrete_type, $($scalar_type,)*),
)*
);
generic_integer_impl_scalar_operation_assign!(
rust_trait: ShlAssign(shl_assign),
implem: {
@@ -631,11 +837,10 @@ macro_rules! define_scalar_rotate_shifts {
.scalar_left_shift_assign_parallelized(lhs.ciphertext.as_cpu_mut(), rhs);
},
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => {
with_thread_local_cuda_streams(|streams| {
InternalServerKey::Cuda(cuda_key) =>
{let streams = &cuda_key.streams;
cuda_key.key.key
.scalar_left_shift_assign(lhs.ciphertext.as_gpu_mut(streams), rhs, streams);
})
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
@@ -662,10 +867,9 @@ macro_rules! define_scalar_rotate_shifts {
},
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => {
with_thread_local_cuda_streams(|streams| {
let streams = &cuda_key.streams;
cuda_key.key.key
.scalar_right_shift_assign(lhs.ciphertext.as_gpu_mut(streams), rhs, streams);
})
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
@@ -691,11 +895,10 @@ macro_rules! define_scalar_rotate_shifts {
.scalar_rotate_left_assign_parallelized(lhs.ciphertext.as_cpu_mut(), rhs);
},
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => {
with_thread_local_cuda_streams(|streams| {
InternalServerKey::Cuda(cuda_key) =>
{let streams = &cuda_key.streams;
cuda_key.key.key
.scalar_rotate_left_assign(lhs.ciphertext.as_gpu_mut(streams), rhs, streams);
})
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
@@ -722,10 +925,10 @@ macro_rules! define_scalar_rotate_shifts {
},
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => {
with_thread_local_cuda_streams(|streams| {
{let streams = &cuda_key.streams;
cuda_key.key.key
.scalar_rotate_right_assign(lhs.ciphertext.as_gpu_mut(streams), rhs, streams);
})
}
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
@@ -839,11 +1042,11 @@ macro_rules! define_scalar_ops {
},
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => {
let inner_result = with_thread_local_cuda_streams(|streams| {
let inner_result = {let streams = &cuda_key.streams;
cuda_key.key.key.scalar_add(
&*lhs.ciphertext.on_gpu(streams), rhs, streams
)
});
};
SignedRadixCiphertext::Cuda(inner_result)
}
#[cfg(feature = "hpu")]
@@ -866,12 +1069,11 @@ macro_rules! define_scalar_ops {
|lhs: &FheInt<_>, _rhs| {
global_state::with_internal_keys(|key|
if let InternalServerKey::Cuda(cuda_key) = key {
with_thread_local_cuda_streams(|streams| {
let streams = &cuda_key.streams;
cuda_key.key.key.get_scalar_add_size_on_gpu(
&*lhs.ciphertext.on_gpu(streams),
streams,
)
})
} else {
0
})
@@ -896,11 +1098,11 @@ macro_rules! define_scalar_ops {
},
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => {
let inner_result = with_thread_local_cuda_streams(|streams| {
let inner_result = {let streams = &cuda_key.streams;
cuda_key.key.key.scalar_sub(
&*lhs.ciphertext.on_gpu(streams), rhs, streams
)
});
};
SignedRadixCiphertext::Cuda(inner_result)
}
#[cfg(feature = "hpu")]
@@ -923,12 +1125,11 @@ macro_rules! define_scalar_ops {
|lhs: &FheInt<_>, _rhs| {
global_state::with_internal_keys(|key|
if let InternalServerKey::Cuda(cuda_key) = key {
with_thread_local_cuda_streams(|streams| {
let streams = &cuda_key.streams;
cuda_key.key.key.get_scalar_sub_size_on_gpu(
&*lhs.ciphertext.on_gpu(streams),
streams,
)
})
} else {
0
})
@@ -953,11 +1154,11 @@ macro_rules! define_scalar_ops {
},
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => {
let inner_result = with_thread_local_cuda_streams(|streams| {
let inner_result = {let streams = &cuda_key.streams;
cuda_key.key.key.scalar_mul(
&*lhs.ciphertext.on_gpu(streams), rhs, streams
)
});
};
SignedRadixCiphertext::Cuda(inner_result)
}
#[cfg(feature = "hpu")]
@@ -987,11 +1188,11 @@ macro_rules! define_scalar_ops {
},
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => {
let inner_result = with_thread_local_cuda_streams(|streams| {
let inner_result = {let streams = &cuda_key.streams;
cuda_key.key.key.scalar_bitand(
&*lhs.ciphertext.on_gpu(streams), rhs, streams
)
});
};
SignedRadixCiphertext::Cuda(inner_result)
}
#[cfg(feature = "hpu")]
@@ -1014,12 +1215,11 @@ macro_rules! define_scalar_ops {
|lhs: &FheInt<_>, _rhs| {
global_state::with_internal_keys(|key|
if let InternalServerKey::Cuda(cuda_key) = key {
with_thread_local_cuda_streams(|streams| {
let streams = &cuda_key.streams;
cuda_key.key.key.get_scalar_bitand_size_on_gpu(
&*lhs.ciphertext.on_gpu(streams),
streams,
)
})
} else {
0
})
@@ -1044,11 +1244,11 @@ macro_rules! define_scalar_ops {
},
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => {
let inner_result = with_thread_local_cuda_streams(|streams| {
let inner_result = {let streams = &cuda_key.streams;
cuda_key.key.key.scalar_bitor(
&*lhs.ciphertext.on_gpu(streams), rhs, streams
)
});
};
SignedRadixCiphertext::Cuda(inner_result)
}
#[cfg(feature = "hpu")]
@@ -1071,12 +1271,11 @@ macro_rules! define_scalar_ops {
|lhs: &FheInt<_>, _rhs| {
global_state::with_internal_keys(|key|
if let InternalServerKey::Cuda(cuda_key) = key {
with_thread_local_cuda_streams(|streams| {
let streams = &cuda_key.streams;
cuda_key.key.key.get_scalar_bitor_size_on_gpu(
&*lhs.ciphertext.on_gpu(streams),
streams,
)
})
} else {
0
})
@@ -1102,11 +1301,11 @@ macro_rules! define_scalar_ops {
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => {
let inner_result = with_thread_local_cuda_streams(|streams| {
let inner_result = {let streams = &cuda_key.streams;
cuda_key.key.key.scalar_bitxor(
&*lhs.ciphertext.on_gpu(streams), rhs, streams
)
});
};
SignedRadixCiphertext::Cuda(inner_result)
}
#[cfg(feature = "hpu")]
@@ -1129,12 +1328,11 @@ macro_rules! define_scalar_ops {
|lhs: &FheInt<_>, _rhs| {
global_state::with_internal_keys(|key|
if let InternalServerKey::Cuda(cuda_key) = key {
with_thread_local_cuda_streams(|streams| {
let streams = &cuda_key.streams;
cuda_key.key.key.get_scalar_bitxor_size_on_gpu(
&*lhs.ciphertext.on_gpu(streams),
streams,
)
})
} else {
0
})
@@ -1159,11 +1357,11 @@ macro_rules! define_scalar_ops {
},
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => {
let inner_result = with_thread_local_cuda_streams(|streams| {
let inner_result = {let streams = &cuda_key.streams;
cuda_key.key.key.signed_scalar_div(
&*lhs.ciphertext.on_gpu(streams), rhs, streams
)
});
};
SignedRadixCiphertext::Cuda(inner_result)
}
#[cfg(feature = "hpu")]
@@ -1193,11 +1391,11 @@ macro_rules! define_scalar_ops {
},
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => {
let inner_result = with_thread_local_cuda_streams(|streams| {
let inner_result = {let streams = &cuda_key.streams;
cuda_key.key.key.signed_scalar_rem(
&*lhs.ciphertext.on_gpu(streams), rhs, streams
)
});
};
SignedRadixCiphertext::Cuda(inner_result)
}
#[cfg(feature = "hpu")]
@@ -1237,12 +1435,11 @@ macro_rules! define_scalar_ops {
|_lhs, rhs: &FheInt<_>| {
global_state::with_internal_keys(|key|
if let InternalServerKey::Cuda(cuda_key) = key {
with_thread_local_cuda_streams(|streams| {
let streams = &cuda_key.streams;
cuda_key.key.key.get_scalar_add_size_on_gpu(
&*rhs.ciphertext.on_gpu(streams),
streams,
)
})
} else {
0
})
@@ -1267,12 +1464,11 @@ macro_rules! define_scalar_ops {
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => {
use crate::integer::gpu::ciphertext::CudaSignedRadixCiphertext;
with_thread_local_cuda_streams(|streams| {
let streams = &cuda_key.streams;
let mut result: CudaSignedRadixCiphertext = cuda_key.pbs_key().create_trivial_radix(
lhs, rhs.ciphertext.on_gpu(streams).ciphertext.info.blocks.len(), streams);
cuda_key.pbs_key().sub_assign(&mut result, &*rhs.ciphertext.on_gpu(streams), streams);
SignedRadixCiphertext::Cuda(result)
})
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
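This arm implements the reversed operand order (clear minus ciphertext): the clear left-hand side is trivially encrypted into a radix with the same block count, then the encrypted value is subtracted in place. A hedged usage fragment, assuming the key setup of the earlier sketches:

    use tfhe::prelude::*;
    use tfhe::{ClientKey, FheInt32};

    fn reversed_sub(client_key: &ClientKey) {
        let a = FheInt32::encrypt(3i32, client_key);
        // 10 is trivially encrypted under the hood, then `a` is subtracted.
        let c = 10i32 - &a;
        let v: i32 = c.decrypt(client_key);
        assert_eq!(v, 7);
    }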
@@ -1294,12 +1490,11 @@ macro_rules! define_scalar_ops {
|_lhs, rhs: &FheInt<_>| {
global_state::with_internal_keys(|key|
if let InternalServerKey::Cuda(cuda_key) = key {
with_thread_local_cuda_streams(|streams| {
let streams = &cuda_key.streams;
cuda_key.key.key.get_scalar_sub_size_on_gpu(
&*rhs.ciphertext.on_gpu(streams),
streams,
)
})
} else {
0
})
@@ -1350,12 +1545,11 @@ macro_rules! define_scalar_ops {
|_lhs, rhs: &FheInt<_>| {
global_state::with_internal_keys(|key|
if let InternalServerKey::Cuda(cuda_key) = key {
with_thread_local_cuda_streams(|streams| {
let streams = &cuda_key.streams;
cuda_key.key.key.get_scalar_bitand_size_on_gpu(
&*rhs.ciphertext.on_gpu(streams),
streams,
)
})
} else {
0
})
@@ -1389,12 +1583,11 @@ macro_rules! define_scalar_ops {
|_lhs, rhs: &FheInt<_>| {
global_state::with_internal_keys(|key|
if let InternalServerKey::Cuda(cuda_key) = key {
with_thread_local_cuda_streams(|streams| {
let streams = &cuda_key.streams;
cuda_key.key.key.get_scalar_bitor_size_on_gpu(
&*rhs.ciphertext.on_gpu(streams),
streams,
)
})
} else {
0
})
@@ -1428,12 +1621,11 @@ macro_rules! define_scalar_ops {
|_lhs, rhs: &FheInt<_>| {
global_state::with_internal_keys(|key|
if let InternalServerKey::Cuda(cuda_key) = key {
with_thread_local_cuda_streams(|streams| {
let streams = &cuda_key.streams;
cuda_key.key.key.get_scalar_bitxor_size_on_gpu(
&*rhs.ciphertext.on_gpu(streams),
streams,
)
})
} else {
0
})
@@ -1459,10 +1651,9 @@ macro_rules! define_scalar_ops {
},
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => {
with_thread_local_cuda_streams(|streams| {
let streams = &cuda_key.streams;
cuda_key.key.key
.scalar_add_assign(lhs.ciphertext.as_gpu_mut(streams), rhs, streams);
})
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
@@ -1494,10 +1685,9 @@ macro_rules! define_scalar_ops {
},
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => {
with_thread_local_cuda_streams(|streams| {
let streams = &cuda_key.streams;
cuda_key.key.key
.scalar_sub_assign(lhs.ciphertext.as_gpu_mut(streams), rhs, streams);
})
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
@@ -1525,10 +1715,9 @@ macro_rules! define_scalar_ops {
},
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => {
with_thread_local_cuda_streams(|streams| {
let streams = &cuda_key.streams;
cuda_key.key.key
.scalar_mul_assign(lhs.ciphertext.as_gpu_mut(streams), rhs, streams);
})
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
@@ -1557,10 +1746,9 @@ macro_rules! define_scalar_ops {
},
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => {
with_thread_local_cuda_streams(|streams| {
let streams = &cuda_key.streams;
cuda_key.key.key
.scalar_bitand_assign(lhs.ciphertext.as_gpu_mut(streams), rhs, streams);
})
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
@@ -1588,10 +1776,9 @@ macro_rules! define_scalar_ops {
},
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => {
with_thread_local_cuda_streams(|streams| {
let streams = &cuda_key.streams;
cuda_key.key.key
.scalar_bitor_assign(lhs.ciphertext.as_gpu_mut(streams), rhs, streams);
})
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
@@ -1618,10 +1805,9 @@ macro_rules! define_scalar_ops {
},
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => {
with_thread_local_cuda_streams(|streams| {
let streams = &cuda_key.streams;
cuda_key.key.key
.scalar_bitxor_assign(lhs.ciphertext.as_gpu_mut(streams), rhs, streams);
})
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
@@ -1647,11 +1833,12 @@ macro_rules! define_scalar_ops {
.signed_scalar_div_assign_parallelized(lhs.ciphertext.as_cpu_mut(), rhs);
},
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => global_state::with_thread_local_cuda_streams(|streams| {
InternalServerKey::Cuda(cuda_key) => {
let streams = &cuda_key.streams;
let cuda_lhs = lhs.ciphertext.as_gpu_mut(streams);
let cuda_result = cuda_key.pbs_key().signed_scalar_div(&cuda_lhs, rhs, streams);
*cuda_lhs = cuda_result;
}),
},
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
panic!("Hpu does not support this operation yet.")
@@ -1676,11 +1863,11 @@ macro_rules! define_scalar_ops {
.signed_scalar_rem_assign_parallelized(lhs.ciphertext.as_cpu_mut(), rhs);
},
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => global_state::with_thread_local_cuda_streams(|streams| {
InternalServerKey::Cuda(cuda_key) => {let streams = &cuda_key.streams;
let cuda_lhs = lhs.ciphertext.as_gpu_mut(streams);
let cuda_result = cuda_key.pbs_key().signed_scalar_rem(&cuda_lhs, rhs, streams);
*cuda_lhs = cuda_result;
}),
},
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
panic!("Hpu does not support this operation yet.")


@@ -2,13 +2,13 @@ use crate::high_level_api::integers::signed::tests::{
test_case_ilog2, test_case_leading_trailing_zeros_ones,
};
use crate::high_level_api::integers::unsigned::tests::gpu::setup_gpu;
use crate::high_level_api::traits::AddSizeOnGpu;
use crate::prelude::{
check_valid_cuda_malloc, BitAndSizeOnGpu, BitNotSizeOnGpu, BitOrSizeOnGpu, BitXorSizeOnGpu,
FheTryEncrypt, SubSizeOnGpu,
check_valid_cuda_malloc, AddSizeOnGpu, BitAndSizeOnGpu, BitNotSizeOnGpu, BitOrSizeOnGpu,
BitXorSizeOnGpu, FheMaxSizeOnGpu, FheMinSizeOnGpu, FheOrdSizeOnGpu, FheTryEncrypt,
RotateLeftSizeOnGpu, RotateRightSizeOnGpu, ShlSizeOnGpu, ShrSizeOnGpu, SubSizeOnGpu,
};
use crate::shortint::parameters::PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS;
use crate::{FheInt32, GpuIndex};
use crate::{FheInt32, FheUint32, GpuIndex};
use rand::Rng;
#[test]
@@ -162,3 +162,132 @@ fn test_gpu_get_bitops_size_on_gpu() {
GpuIndex::new(0)
));
}
#[test]
fn test_gpu_get_comparisons_size_on_gpu() {
let cks = setup_gpu(Some(PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS));
let mut rng = rand::thread_rng();
let clear_a = rng.gen_range(1..=i32::MAX);
let clear_b = rng.gen_range(1..=i32::MAX);
let mut a = FheInt32::try_encrypt(clear_a, &cks).unwrap();
let mut b = FheInt32::try_encrypt(clear_b, &cks).unwrap();
a.move_to_current_device();
b.move_to_current_device();
let a = &a;
let b = &b;
let gt_tmp_buffer_size = a.get_gt_size_on_gpu(b);
let scalar_gt_tmp_buffer_size = a.get_gt_size_on_gpu(clear_b);
assert!(check_valid_cuda_malloc(
gt_tmp_buffer_size,
GpuIndex::new(0)
));
assert!(check_valid_cuda_malloc(
scalar_gt_tmp_buffer_size,
GpuIndex::new(0)
));
let ge_tmp_buffer_size = a.get_ge_size_on_gpu(b);
let scalar_ge_tmp_buffer_size = a.get_ge_size_on_gpu(clear_b);
assert!(check_valid_cuda_malloc(
ge_tmp_buffer_size,
GpuIndex::new(0)
));
assert!(check_valid_cuda_malloc(
scalar_ge_tmp_buffer_size,
GpuIndex::new(0)
));
let lt_tmp_buffer_size = a.get_lt_size_on_gpu(b);
let scalar_lt_tmp_buffer_size = a.get_lt_size_on_gpu(clear_b);
assert!(check_valid_cuda_malloc(
lt_tmp_buffer_size,
GpuIndex::new(0)
));
assert!(check_valid_cuda_malloc(
scalar_lt_tmp_buffer_size,
GpuIndex::new(0)
));
let le_tmp_buffer_size = a.get_le_size_on_gpu(b);
let scalar_le_tmp_buffer_size = a.get_le_size_on_gpu(clear_b);
assert!(check_valid_cuda_malloc(
le_tmp_buffer_size,
GpuIndex::new(0)
));
assert!(check_valid_cuda_malloc(
scalar_le_tmp_buffer_size,
GpuIndex::new(0)
));
let max_tmp_buffer_size = a.get_max_size_on_gpu(b);
let scalar_max_tmp_buffer_size = a.get_max_size_on_gpu(clear_b);
assert!(check_valid_cuda_malloc(
max_tmp_buffer_size,
GpuIndex::new(0)
));
assert!(check_valid_cuda_malloc(
scalar_max_tmp_buffer_size,
GpuIndex::new(0)
));
let min_tmp_buffer_size = a.get_min_size_on_gpu(b);
let scalar_min_tmp_buffer_size = a.get_min_size_on_gpu(clear_b);
assert!(check_valid_cuda_malloc(
min_tmp_buffer_size,
GpuIndex::new(0)
));
assert!(check_valid_cuda_malloc(
scalar_min_tmp_buffer_size,
GpuIndex::new(0)
));
}
#[test]
fn test_gpu_get_shift_rotate_size_on_gpu() {
let cks = setup_gpu(Some(PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS));
let mut rng = rand::thread_rng();
let clear_a = rng.gen_range(1..=i32::MAX);
let clear_b = rng.gen_range(1..=u32::MAX);
let mut a = FheInt32::try_encrypt(clear_a, &cks).unwrap();
let mut b = FheUint32::try_encrypt(clear_b, &cks).unwrap();
a.move_to_current_device();
b.move_to_current_device();
let a = &a;
let b = &b;
let left_shift_tmp_buffer_size = a.get_left_shift_size_on_gpu(b);
let scalar_left_shift_tmp_buffer_size = a.get_left_shift_size_on_gpu(clear_b);
assert!(check_valid_cuda_malloc(
left_shift_tmp_buffer_size,
GpuIndex::new(0)
));
assert!(check_valid_cuda_malloc(
scalar_left_shift_tmp_buffer_size,
GpuIndex::new(0)
));
let right_shift_tmp_buffer_size = a.get_right_shift_size_on_gpu(b);
let scalar_right_shift_tmp_buffer_size = a.get_right_shift_size_on_gpu(clear_b);
assert!(check_valid_cuda_malloc(
right_shift_tmp_buffer_size,
GpuIndex::new(0)
));
assert!(check_valid_cuda_malloc(
scalar_right_shift_tmp_buffer_size,
GpuIndex::new(0)
));
let rotate_left_tmp_buffer_size = a.get_rotate_left_size_on_gpu(b);
let scalar_rotate_left_tmp_buffer_size = a.get_rotate_left_size_on_gpu(clear_b);
assert!(check_valid_cuda_malloc(
rotate_left_tmp_buffer_size,
GpuIndex::new(0)
));
assert!(check_valid_cuda_malloc(
scalar_rotate_left_tmp_buffer_size,
GpuIndex::new(0)
));
let rotate_right_tmp_buffer_size = a.get_rotate_right_size_on_gpu(b);
let scalar_rotate_right_tmp_buffer_size = a.get_rotate_right_size_on_gpu(clear_b);
assert!(check_valid_cuda_malloc(
rotate_right_tmp_buffer_size,
GpuIndex::new(0)
));
assert!(check_valid_cuda_malloc(
scalar_rotate_right_tmp_buffer_size,
GpuIndex::new(0)
));
}


@@ -4,8 +4,6 @@ use super::inner::RadixCiphertext;
use crate::backward_compatibility::integers::FheUintVersions;
use crate::conformance::ParameterSetConformant;
use crate::core_crypto::prelude::{CastFrom, UnsignedInteger, UnsignedNumeric};
#[cfg(feature = "gpu")]
use crate::high_level_api::global_state::with_thread_local_cuda_streams;
use crate::high_level_api::integers::signed::{FheInt, FheIntId};
use crate::high_level_api::integers::IntegerId;
use crate::high_level_api::keys::InternalServerKey;
@@ -304,13 +302,14 @@ where
FheBool::new(result, cpu_key.tag.clone())
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
InternalServerKey::Cuda(cuda_key) => {
let streams = &cuda_key.streams;
let result = cuda_key
.key
.key
.is_even(&*self.ciphertext.on_gpu(streams), streams);
FheBool::new(result, cuda_key.tag.clone())
}),
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
panic!("Hpu does not support this operation yet.")
@@ -344,13 +343,14 @@ where
FheBool::new(result, cpu_key.tag.clone())
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
InternalServerKey::Cuda(cuda_key) => {
let streams = &cuda_key.streams;
let result = cuda_key
.key
.key
.is_odd(&*self.ciphertext.on_gpu(streams), streams);
FheBool::new(result, cuda_key.tag.clone())
}),
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
panic!("Hpu does not support this operation yet.")
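is_even and is_odd return an encrypted FheBool; with a CUDA key the parity test now runs on the key's own streams. A hedged usage fragment, assuming a server key (CPU or GPU) is installed:

    use tfhe::prelude::*;
    use tfhe::{ClientKey, FheUint32};

    fn parity(client_key: &ClientKey) {
        let a = FheUint32::encrypt(6u32, client_key);
        let even = a.is_even(); // encrypted boolean, decryptable by the client
        assert!(even.decrypt(client_key));
    }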
@@ -481,7 +481,8 @@ where
super::FheUint32::new(result, cpu_key.tag.clone())
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
InternalServerKey::Cuda(cuda_key) => {
let streams = &cuda_key.streams;
let result = cuda_key
.key
.key
@@ -492,7 +493,7 @@ where
streams,
);
super::FheUint32::new(result, cuda_key.tag.clone())
}),
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
panic!("Hpu does not support this operation yet.")
@@ -530,7 +531,8 @@ where
super::FheUint32::new(result, cpu_key.tag.clone())
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
InternalServerKey::Cuda(cuda_key) => {
let streams = &cuda_key.streams;
let result = cuda_key
.key
.key
@@ -541,7 +543,7 @@ where
streams,
);
super::FheUint32::new(result, cuda_key.tag.clone())
}),
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
panic!("Hpu does not support this operation yet.")
@@ -579,7 +581,8 @@ where
super::FheUint32::new(result, cpu_key.tag.clone())
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
InternalServerKey::Cuda(cuda_key) => {
let streams = &cuda_key.streams;
let result = cuda_key
.key
.key
@@ -590,7 +593,7 @@ where
streams,
);
super::FheUint32::new(result, cuda_key.tag.clone())
}),
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
panic!("Hpu does not support this operation yet.")
@@ -628,7 +631,8 @@ where
super::FheUint32::new(result, cpu_key.tag.clone())
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
InternalServerKey::Cuda(cuda_key) => {
let streams = &cuda_key.streams;
let result = cuda_key
.key
.key
@@ -639,7 +643,7 @@ where
streams,
);
super::FheUint32::new(result, cuda_key.tag.clone())
}),
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
panic!("Hpu does not support this operation yet.")
@@ -761,7 +765,8 @@ where
super::FheUint32::new(result, cpu_key.tag.clone())
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
InternalServerKey::Cuda(cuda_key) => {
let streams = &cuda_key.streams;
let result = cuda_key
.key
.key
@@ -772,7 +777,7 @@ where
streams,
);
super::FheUint32::new(result, cuda_key.tag.clone())
}),
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
panic!("Hpu does not support this operation yet.")
@@ -819,7 +824,8 @@ where
)
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
InternalServerKey::Cuda(cuda_key) => {
let streams = &cuda_key.streams;
let (result, is_ok) = cuda_key
.key
.key
@@ -833,7 +839,7 @@ where
super::FheUint32::new(result, cuda_key.tag.clone()),
FheBool::new(is_ok, cuda_key.tag.clone()),
)
}),
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
panic!("Hpu does not support this operation yet.")
@@ -905,7 +911,8 @@ where
}
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
InternalServerKey::Cuda(cuda_key) => {
let streams = &cuda_key.streams;
let (result, matched) = cuda_key.key.key.match_value(
&self.ciphertext.on_gpu(streams),
matches,
@@ -920,7 +927,7 @@ where
} else {
Err(crate::Error::new("Output type does not have enough bits to represent all possible output values".to_string()))
}
}),
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
panic!("Hpu does not support this operation yet.")
@@ -987,7 +994,8 @@ where
}
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
InternalServerKey::Cuda(cuda_key) => {
let streams = &cuda_key.streams;
let result = cuda_key.key.key.match_value_or(
&self.ciphertext.on_gpu(streams),
matches,
@@ -1000,7 +1008,7 @@ where
} else {
Err(crate::Error::new("Output type does not have enough bits to represent all possible output values".to_string()))
}
}),
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
panic!("Hpu does not support this operation yet.")
@@ -1216,14 +1224,15 @@ where
Self::new(casted, cpu_key.tag.clone())
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
InternalServerKey::Cuda(cuda_key) => {
let streams = &cuda_key.streams;
let casted = cuda_key.key.key.cast_to_unsigned(
input.ciphertext.into_gpu(streams),
IntoId::num_blocks(cuda_key.message_modulus()),
streams,
);
Self::new(casted, cuda_key.tag.clone())
}),
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
panic!("Hpu does not support this operation yet.")
@@ -1264,14 +1273,15 @@ where
Self::new(casted, cpu_key.tag.clone())
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
InternalServerKey::Cuda(cuda_key) => {
let streams = &cuda_key.streams;
let casted = cuda_key.key.key.cast_to_unsigned(
input.ciphertext.into_gpu(streams),
IntoId::num_blocks(cuda_key.message_modulus()),
streams,
);
Self::new(casted, cuda_key.tag.clone())
}),
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
panic!("Hpu does not support this operation yet.")
@@ -1312,14 +1322,15 @@ where
Self::new(ciphertext, cpu_key.tag.clone())
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
InternalServerKey::Cuda(cuda_key) => {
let streams = &cuda_key.streams;
let inner = cuda_key.key.key.cast_to_unsigned(
input.ciphertext.into_gpu(streams).0,
Id::num_blocks(cuda_key.message_modulus()),
streams,
);
Self::new(inner, cuda_key.tag.clone())
}),
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
panic!("Hpu does not support this operation yet.")


@@ -1,7 +1,5 @@
use crate::core_crypto::prelude::UnsignedNumeric;
use crate::high_level_api::global_state;
#[cfg(feature = "gpu")]
use crate::high_level_api::global_state::with_thread_local_cuda_streams;
use crate::high_level_api::integers::FheUintId;
use crate::high_level_api::keys::InternalServerKey;
use crate::integer::block_decomposition::{DecomposableInto, RecomposableFrom};
@@ -115,14 +113,15 @@ where
Ok(Self::new(ciphertext, key.tag.clone()))
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
InternalServerKey::Cuda(cuda_key) => {
let streams = &cuda_key.streams;
let inner: CudaUnsignedRadixCiphertext = cuda_key.key.key.create_trivial_radix(
value,
Id::num_blocks(cuda_key.key.key.message_modulus),
streams,
);
Ok(Self::new(inner, cuda_key.tag.clone()))
}),
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
panic!("Hpu does not support trivial encryption")
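Trivial encryption embeds a clear value in ciphertext form without hiding it, which is useful for constants that feed into homomorphic computations; on a CUDA key the trivial radix is now created on the key's streams. A hedged fragment, assuming a server key is installed (the trivial radix is built through it):

    use tfhe::prelude::*;
    use tfhe::FheUint32;

    // A trivial ciphertext offers no secrecy, so no client key is involved.
    fn trivial_constant() -> FheUint32 {
        FheUint32::try_encrypt_trivial(5u32).unwrap()
    }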


@@ -4,9 +4,9 @@ use crate::core_crypto::gpu::CudaStreams;
use crate::high_level_api::details::MaybeCloned;
use crate::high_level_api::global_state;
#[cfg(feature = "gpu")]
use crate::high_level_api::global_state::{
with_thread_local_cuda_streams, with_thread_local_cuda_streams_for_gpu_indexes,
};
use crate::high_level_api::global_state::with_cuda_internal_keys;
#[cfg(feature = "gpu")]
use crate::high_level_api::global_state::with_thread_local_cuda_streams_for_gpu_indexes;
#[cfg(feature = "hpu")]
use crate::high_level_api::keys::HpuTaggedDevice;
#[cfg(feature = "gpu")]
@@ -52,9 +52,10 @@ impl Clone for RadixCiphertext {
match self {
Self::Cpu(inner) => Self::Cpu(inner.clone()),
#[cfg(feature = "gpu")]
Self::Cuda(inner) => {
with_thread_local_cuda_streams(|streams| Self::Cuda(inner.duplicate(streams)))
}
Self::Cuda(inner) => with_cuda_internal_keys(|key| {
let streams = &key.streams;
Self::Cuda(inner.duplicate(streams))
}),
#[cfg(feature = "hpu")]
Self::Hpu(inner) => {
// NB: Hpu backend flavors behave differently regarding memory.
@@ -166,10 +167,13 @@ impl RadixCiphertext {
match self {
Self::Cpu(ct) => MaybeCloned::Borrowed(ct),
#[cfg(feature = "gpu")]
Self::Cuda(ct) => with_thread_local_cuda_streams(|streams| {
let cpu_ct = ct.to_radix_ciphertext(streams);
MaybeCloned::Cloned(cpu_ct)
}),
Self::Cuda(ct) => with_thread_local_cuda_streams_for_gpu_indexes(
ct.ciphertext.d_blocks.0.d_vec.gpu_indexes.as_slice(),
|streams| {
let cpu_ct = ct.to_radix_ciphertext(streams);
MaybeCloned::Cloned(cpu_ct)
},
),
#[cfg(feature = "hpu")]
Self::Hpu(hpu_ct) => {
let cpu_inner = hpu_ct.to_radix_ciphertext();
@@ -297,7 +301,8 @@ impl RadixCiphertext {
#[cfg(feature = "gpu")]
// We may not be on the correct Cuda device
if let Self::Cuda(cuda_ct) = self {
with_thread_local_cuda_streams(|streams| {
with_cuda_internal_keys(|key| {
let streams = &key.streams;
if cuda_ct.gpu_indexes() != streams.gpu_indexes() {
*cuda_ct = cuda_ct.duplicate(streams);
}
@@ -319,7 +324,8 @@ impl RadixCiphertext {
}
#[cfg(feature = "gpu")]
Device::CudaGpu => {
let new_inner = with_thread_local_cuda_streams(|streams| {
let new_inner = with_cuda_internal_keys(|key| {
let streams = &key.streams;
crate::integer::gpu::ciphertext::CudaUnsignedRadixCiphertext::from_radix_ciphertext(
&cpu_ct, streams,
)
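With the streams held by the server key, moving data between devices is driven by whichever key is installed. A hedged end-to-end sketch, following the crate's documented GPU flow:

    use tfhe::prelude::*;
    use tfhe::{generate_keys, set_server_key, CompressedServerKey, ConfigBuilder, FheUint32};

    fn main() {
        let config = ConfigBuilder::default().build();
        let (client_key, _server_key) = generate_keys(config);
        let compressed = CompressedServerKey::new(&client_key);
        set_server_key(compressed.decompress_to_gpu());

        // Encryption happens on the CPU; the upload runs on the streams owned
        // by the installed CUDA server key, as in the hunk above.
        let mut a = FheUint32::encrypt(42u32, &client_key);
        a.move_to_current_device();

        let b = &a + 1u32; // computed on the key's GPUs
        let v: u32 = b.decrypt(&client_key);
        assert_eq!(v, 43);
    }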


@@ -7,13 +7,14 @@ use super::inner::RadixCiphertext;
use crate::high_level_api::details::MaybeCloned;
use crate::high_level_api::global_state;
#[cfg(feature = "gpu")]
use crate::high_level_api::global_state::with_thread_local_cuda_streams;
use crate::high_level_api::global_state::with_cuda_internal_keys;
use crate::high_level_api::integers::FheUintId;
use crate::high_level_api::keys::InternalServerKey;
#[cfg(feature = "gpu")]
use crate::high_level_api::traits::{
AddSizeOnGpu, BitAndSizeOnGpu, BitNotSizeOnGpu, BitOrSizeOnGpu, BitXorSizeOnGpu, SizeOnGpu,
SubSizeOnGpu,
AddSizeOnGpu, BitAndSizeOnGpu, BitNotSizeOnGpu, BitOrSizeOnGpu, BitXorSizeOnGpu,
FheMaxSizeOnGpu, FheMinSizeOnGpu, FheOrdSizeOnGpu, RotateLeftSizeOnGpu, RotateRightSizeOnGpu,
ShlSizeOnGpu, ShrSizeOnGpu, SizeOnGpu, SubSizeOnGpu,
};
use crate::high_level_api::traits::{
DivRem, FheEq, FheMax, FheMin, FheOrd, RotateLeft, RotateLeftAssign, RotateRight,
@@ -83,7 +84,8 @@ where
)
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
InternalServerKey::Cuda(cuda_key) => {
let streams = &cuda_key.streams;
let cts = iter
.map(|fhe_uint| fhe_uint.ciphertext.into_gpu(streams))
.collect::<Vec<_>>();
@@ -100,7 +102,7 @@ where
)
});
Self::new(inner, cuda_key.tag.clone())
}),
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(device) => {
let mut iter = iter;
@@ -171,7 +173,8 @@ where
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => {
with_thread_local_cuda_streams(|streams| {
{
let streams = &cuda_key.streams;
let cts = iter
.map(|fhe_uint| {
match fhe_uint.ciphertext.on_gpu(streams) {
@@ -201,7 +204,7 @@ where
)
});
Self::new(inner, cuda_key.tag.clone())
})
}
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(device) => {
@@ -256,14 +259,15 @@ where
Self::new(inner_result, cpu_key.tag.clone())
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
InternalServerKey::Cuda(cuda_key) => {
let streams = &cuda_key.streams;
let inner_result = cuda_key.key.key.max(
&*self.ciphertext.on_gpu(streams),
&*rhs.ciphertext.on_gpu(streams),
streams,
);
Self::new(inner_result, cuda_key.tag.clone())
}),
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
panic!("Hpu does not support this operation yet.")
@@ -306,14 +310,15 @@ where
Self::new(inner_result, cpu_key.tag.clone())
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
InternalServerKey::Cuda(cuda_key) => {
let streams = &cuda_key.streams;
let inner_result = cuda_key.key.key.min(
&*self.ciphertext.on_gpu(streams),
&*rhs.ciphertext.on_gpu(streams),
streams,
);
Self::new(inner_result, cuda_key.tag.clone())
}),
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
panic!("Hpu does not support this operation yet.")
@@ -367,14 +372,15 @@ where
FheBool::new(inner_result, cpu_key.tag.clone())
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
InternalServerKey::Cuda(cuda_key) => {
let streams = &cuda_key.streams;
let inner_result = cuda_key.key.key.eq(
&*self.ciphertext.on_gpu(streams),
&rhs.ciphertext.on_gpu(streams),
streams,
);
FheBool::new(inner_result, cuda_key.tag.clone())
}),
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(device) => {
let hpu_lhs = self.ciphertext.on_hpu(device);
@@ -428,14 +434,15 @@ where
FheBool::new(inner_result, cpu_key.tag.clone())
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
InternalServerKey::Cuda(cuda_key) => {
let streams = &cuda_key.streams;
let inner_result = cuda_key.key.key.ne(
&*self.ciphertext.on_gpu(streams),
&rhs.ciphertext.on_gpu(streams),
streams,
);
FheBool::new(inner_result, cuda_key.tag.clone())
}),
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(device) => {
let hpu_lhs = self.ciphertext.on_hpu(device);
@@ -515,14 +522,15 @@ where
FheBool::new(inner_result, cpu_key.tag.clone())
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
InternalServerKey::Cuda(cuda_key) => {
let streams = &cuda_key.streams;
let inner_result = cuda_key.key.key.lt(
&*self.ciphertext.on_gpu(streams),
&rhs.ciphertext.on_gpu(streams),
streams,
);
FheBool::new(inner_result, cuda_key.tag.clone())
}),
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(device) => {
let hpu_lhs = self.ciphertext.on_hpu(device);
@@ -576,14 +584,15 @@ where
FheBool::new(inner_result, cpu_key.tag.clone())
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
InternalServerKey::Cuda(cuda_key) => {
let streams = &cuda_key.streams;
let inner_result = cuda_key.key.key.le(
&*self.ciphertext.on_gpu(streams),
&rhs.ciphertext.on_gpu(streams),
streams,
);
FheBool::new(inner_result, cuda_key.tag.clone())
}),
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(device) => {
let hpu_lhs = self.ciphertext.on_hpu(device);
@@ -637,14 +646,15 @@ where
FheBool::new(inner_result, cpu_key.tag.clone())
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
InternalServerKey::Cuda(cuda_key) => {
let streams = &cuda_key.streams;
let inner_result = cuda_key.key.key.gt(
&*self.ciphertext.on_gpu(streams),
&rhs.ciphertext.on_gpu(streams),
streams,
);
FheBool::new(inner_result, cuda_key.tag.clone())
}),
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(device) => {
let hpu_lhs = self.ciphertext.on_hpu(device);
@@ -698,14 +708,15 @@ where
FheBool::new(inner_result, cpu_key.tag.clone())
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
InternalServerKey::Cuda(cuda_key) => {
let streams = &cuda_key.streams;
let inner_result = cuda_key.key.key.ge(
&*self.ciphertext.on_gpu(streams),
&rhs.ciphertext.on_gpu(streams),
streams,
);
FheBool::new(inner_result, cuda_key.tag.clone())
}),
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(device) => {
let hpu_lhs = self.ciphertext.on_hpu(device);
@@ -803,7 +814,8 @@ where
)
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
InternalServerKey::Cuda(cuda_key) => {
let streams = &cuda_key.streams;
let inner_result = cuda_key.key.key.div_rem(
&*self.ciphertext.on_gpu(streams),
&*rhs.ciphertext.on_gpu(streams),
@@ -813,7 +825,7 @@ where
FheUint::<Id>::new(inner_result.0, cuda_key.tag.clone()),
FheUint::<Id>::new(inner_result.1, cuda_key.tag.clone()),
)
}),
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
panic!("Hpu does not support this operation yet.")
@@ -893,11 +905,10 @@ generic_integer_impl_operation!(
},
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => {
with_thread_local_cuda_streams(|streams| {
let streams = &cuda_key.streams;
let inner_result = cuda_key.key.key
.add(&*lhs.ciphertext.on_gpu(streams), &*rhs.ciphertext.on_gpu(streams), streams);
FheUint::new(inner_result, cuda_key.tag.clone())
})
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(device) => {
@@ -942,11 +953,10 @@ generic_integer_impl_operation!(
},
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => {
with_thread_local_cuda_streams(|streams| {
let streams = &cuda_key.streams;
let inner_result = cuda_key.key.key
.sub(&*lhs.ciphertext.on_gpu(streams), &*rhs.ciphertext.on_gpu(streams), streams);
FheUint::new(inner_result, cuda_key.tag.clone())
})
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(device) => {
@@ -991,11 +1001,10 @@ generic_integer_impl_operation!(
},
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => {
with_thread_local_cuda_streams(|streams| {
let streams = &cuda_key.streams;
let inner_result = cuda_key.key.key
.mul(&*lhs.ciphertext.on_gpu(streams), &*rhs.ciphertext.on_gpu(streams), streams);
FheUint::new(inner_result, cuda_key.tag.clone())
})
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(device) => {
@@ -1038,11 +1047,10 @@ generic_integer_impl_operation!(
},
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => {
with_thread_local_cuda_streams(|streams| {
let streams = &cuda_key.streams;
let inner_result = cuda_key.key.key
.bitand(&*lhs.ciphertext.on_gpu(streams), &*rhs.ciphertext.on_gpu(streams), streams);
FheUint::new(inner_result, cuda_key.tag.clone())
})
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(device) => {
@@ -1085,11 +1093,10 @@ generic_integer_impl_operation!(
},
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => {
with_thread_local_cuda_streams(|streams| {
let streams = &cuda_key.streams;
let inner_result = cuda_key.key.key
.bitor(&*lhs.ciphertext.on_gpu(streams), &*rhs.ciphertext.on_gpu(streams), streams);
FheUint::new(inner_result, cuda_key.tag.clone())
})
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(device) => {
@@ -1132,11 +1139,10 @@ generic_integer_impl_operation!(
},
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => {
with_thread_local_cuda_streams(|streams| {
let streams = &cuda_key.streams;
let inner_result = cuda_key.key.key
.bitxor(&*lhs.ciphertext.on_gpu(streams), &*rhs.ciphertext.on_gpu(streams), streams);
FheUint::new(inner_result, cuda_key.tag.clone())
})
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(device) => {
@@ -1186,14 +1192,15 @@ generic_integer_impl_operation!(
FheUint::new(inner_result, cpu_key.tag.clone())
},
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
InternalServerKey::Cuda(cuda_key) => {
let streams = &cuda_key.streams;
let inner_result =
cuda_key
.key
.key
.div(&*lhs.ciphertext.on_gpu(streams), &*rhs.ciphertext.on_gpu(streams), streams);
FheUint::new(inner_result, cuda_key.tag.clone())
}),
},
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
panic!("Hpu does not support this operation yet.")
@@ -1241,14 +1248,16 @@ generic_integer_impl_operation!(
FheUint::new(inner_result, cpu_key.tag.clone())
},
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
InternalServerKey::Cuda(cuda_key) =>
{
let streams = &cuda_key.streams;
let inner_result =
cuda_key
.key
.key
.rem(&*lhs.ciphertext.on_gpu(streams), &*rhs.ciphertext.on_gpu(streams), streams);
FheUint::new(inner_result, cuda_key.tag.clone())
}),
},
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
panic!("Hpu does not support this operation yet.")
@@ -1360,11 +1369,10 @@ generic_integer_impl_shift_rotate!(
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => {
with_thread_local_cuda_streams(|streams| {
let streams = &cuda_key.streams;
let inner_result = cuda_key.key.key
.left_shift(&*lhs.ciphertext.on_gpu(streams), &rhs.ciphertext.on_gpu(streams), streams);
FheUint::new(inner_result, cuda_key.tag.clone())
})
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
@@ -1408,11 +1416,10 @@ generic_integer_impl_shift_rotate!(
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => {
with_thread_local_cuda_streams(|streams| {
let streams = &cuda_key.streams;
let inner_result = cuda_key.key.key
.right_shift(&*lhs.ciphertext.on_gpu(streams), &rhs.ciphertext.on_gpu(streams), streams);
FheUint::new(inner_result, cuda_key.tag.clone())
})
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
@@ -1456,11 +1463,10 @@ generic_integer_impl_shift_rotate!(
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => {
with_thread_local_cuda_streams(|streams| {
let streams = &cuda_key.streams;
let inner_result = cuda_key.key.key
.rotate_left(&*lhs.ciphertext.on_gpu(streams), &rhs.ciphertext.on_gpu(streams), streams);
FheUint::new(inner_result, cuda_key.tag.clone())
})
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
@@ -1504,11 +1510,10 @@ generic_integer_impl_shift_rotate!(
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => {
with_thread_local_cuda_streams(|streams| {
let streams = &cuda_key.streams;
let inner_result = cuda_key.key.key
.rotate_right(&*lhs.ciphertext.on_gpu(streams), &rhs.ciphertext.on_gpu(streams), streams);
FheUint::new(inner_result, cuda_key.tag.clone())
})
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
@@ -1557,13 +1562,14 @@ where
);
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
InternalServerKey::Cuda(cuda_key) => {
let streams = &cuda_key.streams;
cuda_key.key.key.add_assign(
self.ciphertext.as_gpu_mut(streams),
&rhs.ciphertext.on_gpu(streams),
streams,
);
}),
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(device) => {
let hpu_lhs = self.ciphertext.as_hpu_mut(device);
@@ -1608,13 +1614,14 @@ where
);
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
InternalServerKey::Cuda(cuda_key) => {
let streams = &cuda_key.streams;
cuda_key.key.key.sub_assign(
self.ciphertext.as_gpu_mut(streams),
&rhs.ciphertext.on_gpu(streams),
streams,
);
}),
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(device) => {
let hpu_lhs = self.ciphertext.as_hpu_mut(device);
@@ -1659,13 +1666,14 @@ where
);
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
InternalServerKey::Cuda(cuda_key) => {
let streams = &cuda_key.streams;
cuda_key.key.key.mul_assign(
self.ciphertext.as_gpu_mut(streams),
&rhs.ciphertext.on_gpu(streams),
streams,
);
}),
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(device) => {
let hpu_lhs = self.ciphertext.as_hpu_mut(device);
@@ -1708,13 +1716,14 @@ where
);
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
InternalServerKey::Cuda(cuda_key) => {
let streams = &cuda_key.streams;
cuda_key.key.key.bitand_assign(
self.ciphertext.as_gpu_mut(streams),
&rhs.ciphertext.on_gpu(streams),
streams,
);
}),
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(device) => {
let hpu_lhs = self.ciphertext.as_hpu_mut(device);
@@ -1757,13 +1766,14 @@ where
);
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
InternalServerKey::Cuda(cuda_key) => {
let streams = &cuda_key.streams;
cuda_key.key.key.bitor_assign(
self.ciphertext.as_gpu_mut(streams),
&rhs.ciphertext.on_gpu(streams),
streams,
);
}),
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(device) => {
let hpu_lhs = self.ciphertext.as_hpu_mut(device);
@@ -1806,13 +1816,14 @@ where
);
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
InternalServerKey::Cuda(cuda_key) => {
let streams = &cuda_key.streams;
cuda_key.key.key.bitxor_assign(
self.ciphertext.as_gpu_mut(streams),
&rhs.ciphertext.on_gpu(streams),
streams,
);
}),
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(device) => {
let hpu_lhs = self.ciphertext.as_hpu_mut(device);
@@ -1860,13 +1871,14 @@ where
);
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
InternalServerKey::Cuda(cuda_key) => {
let streams = &cuda_key.streams;
cuda_key.key.key.div_assign(
self.ciphertext.as_gpu_mut(streams),
&rhs.ciphertext.on_gpu(streams),
streams,
);
}),
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
panic!("Hpu does not support this operation yet.")
@@ -1912,13 +1924,14 @@ where
);
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
InternalServerKey::Cuda(cuda_key) => {
let streams = &cuda_key.streams;
cuda_key.key.key.rem_assign(
self.ciphertext.as_gpu_mut(streams),
&rhs.ciphertext.on_gpu(streams),
streams,
);
}),
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
panic!("Hpu does not support this operation yet.")
@@ -1970,13 +1983,14 @@ where
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => {
with_thread_local_cuda_streams(|streams| {
{
let streams = &cuda_key.streams;
cuda_key.key.key.left_shift_assign(
self.ciphertext.as_gpu_mut(streams),
&rhs.ciphertext.on_gpu(streams),
streams,
);
});
};
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
@@ -2028,13 +2042,14 @@ where
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => {
with_thread_local_cuda_streams(|streams| {
{
let streams = &cuda_key.streams;
cuda_key.key.key.right_shift_assign(
self.ciphertext.as_gpu_mut(streams),
&rhs.ciphertext.on_gpu(streams),
streams,
);
});
};
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
@@ -2087,13 +2102,14 @@ where
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => {
with_thread_local_cuda_streams(|streams| {
{
let streams = &cuda_key.streams;
cuda_key.key.key.rotate_left_assign(
self.ciphertext.as_gpu_mut(streams),
&rhs.ciphertext.on_gpu(streams),
streams,
);
});
};
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
@@ -2146,13 +2162,12 @@ where
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => {
with_thread_local_cuda_streams(|streams| {
cuda_key.key.key.rotate_right_assign(
self.ciphertext.as_gpu_mut(streams),
&rhs.ciphertext.on_gpu(streams),
streams,
);
});
let streams = &cuda_key.streams;
cuda_key.key.key.rotate_right_assign(
self.ciphertext.as_gpu_mut(streams),
&rhs.ciphertext.on_gpu(streams),
streams,
);
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
@@ -2230,13 +2245,14 @@ where
FheUint::new(ciphertext, cpu_key.tag.clone())
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
InternalServerKey::Cuda(cuda_key) => {
let streams = &cuda_key.streams;
let inner_result = cuda_key
.key
.key
.neg(&*self.ciphertext.on_gpu(streams), streams);
FheUint::new(inner_result, cuda_key.tag.clone())
}),
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
panic!("Hpu does not support this operation yet.")
@@ -2303,13 +2319,14 @@ where
FheUint::new(ciphertext, cpu_key.tag.clone())
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
InternalServerKey::Cuda(cuda_key) => {
let streams = &cuda_key.streams;
let inner_result = cuda_key
.key
.key
.bitnot(&*self.ciphertext.on_gpu(streams), streams);
FheUint::new(inner_result, cuda_key.tag.clone())
}),
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
panic!("Hpu does not support bitnot (operator `!`)")
@@ -2327,13 +2344,12 @@ where
let rhs = rhs.borrow();
global_state::with_internal_keys(|key| {
if let InternalServerKey::Cuda(cuda_key) = key {
with_thread_local_cuda_streams(|streams| {
cuda_key.key.key.get_add_size_on_gpu(
&*self.ciphertext.on_gpu(streams),
&rhs.ciphertext.on_gpu(streams),
streams,
)
})
let streams = &cuda_key.streams;
cuda_key.key.key.get_add_size_on_gpu(
&*self.ciphertext.on_gpu(streams),
&rhs.ciphertext.on_gpu(streams),
streams,
)
} else {
0
}
@@ -2351,13 +2367,12 @@ where
let rhs = rhs.borrow();
global_state::with_internal_keys(|key| {
if let InternalServerKey::Cuda(cuda_key) = key {
with_thread_local_cuda_streams(|streams| {
cuda_key.key.key.get_sub_size_on_gpu(
&*self.ciphertext.on_gpu(streams),
&rhs.ciphertext.on_gpu(streams),
streams,
)
})
let streams = &cuda_key.streams;
cuda_key.key.key.get_sub_size_on_gpu(
&*self.ciphertext.on_gpu(streams),
&rhs.ciphertext.on_gpu(streams),
streams,
)
} else {
0
}
@@ -2372,12 +2387,11 @@ where
fn get_size_on_gpu(&self) -> u64 {
global_state::with_internal_keys(|key| {
if let InternalServerKey::Cuda(cuda_key) = key {
with_thread_local_cuda_streams(|streams| {
cuda_key
.key
.key
.get_ciphertext_size_on_gpu(&*self.ciphertext.on_gpu(streams))
})
let streams = &cuda_key.streams;
cuda_key
.key
.key
.get_ciphertext_size_on_gpu(&*self.ciphertext.on_gpu(streams))
} else {
0
}
@@ -2395,13 +2409,12 @@ where
global_state::with_internal_keys(|key| {
if let InternalServerKey::Cuda(cuda_key) = key {
with_thread_local_cuda_streams(|streams| {
cuda_key.key.key.get_bitand_size_on_gpu(
&*self.ciphertext.on_gpu(streams),
&rhs.ciphertext.on_gpu(streams),
streams,
)
})
let streams = &cuda_key.streams;
cuda_key.key.key.get_bitand_size_on_gpu(
&*self.ciphertext.on_gpu(streams),
&rhs.ciphertext.on_gpu(streams),
streams,
)
} else {
0
}
@@ -2419,13 +2432,12 @@ where
let rhs = rhs.borrow();
global_state::with_internal_keys(|key| {
if let InternalServerKey::Cuda(cuda_key) = key {
with_thread_local_cuda_streams(|streams| {
cuda_key.key.key.get_bitor_size_on_gpu(
&*self.ciphertext.on_gpu(streams),
&rhs.ciphertext.on_gpu(streams),
streams,
)
})
let streams = &cuda_key.streams;
cuda_key.key.key.get_bitor_size_on_gpu(
&*self.ciphertext.on_gpu(streams),
&rhs.ciphertext.on_gpu(streams),
streams,
)
} else {
0
}
@@ -2443,13 +2455,12 @@ where
let rhs = rhs.borrow();
global_state::with_internal_keys(|key| {
if let InternalServerKey::Cuda(cuda_key) = key {
with_thread_local_cuda_streams(|streams| {
cuda_key.key.key.get_bitxor_size_on_gpu(
&*self.ciphertext.on_gpu(streams),
&rhs.ciphertext.on_gpu(streams),
streams,
)
})
let streams = &cuda_key.streams;
cuda_key.key.key.get_bitxor_size_on_gpu(
&*self.ciphertext.on_gpu(streams),
&rhs.ciphertext.on_gpu(streams),
streams,
)
} else {
0
}
@@ -2465,11 +2476,214 @@ where
fn get_bitnot_size_on_gpu(&self) -> u64 {
global_state::with_internal_keys(|key| {
if let InternalServerKey::Cuda(cuda_key) = key {
with_thread_local_cuda_streams(|streams| {
cuda_key
.key
.key
.get_bitnot_size_on_gpu(&*self.ciphertext.on_gpu(streams), streams)
let streams = &cuda_key.streams;
cuda_key
.key
.key
.get_bitnot_size_on_gpu(&*self.ciphertext.on_gpu(streams), streams)
} else {
0
}
})
}
}
#[cfg(feature = "gpu")]
impl<Id> FheOrdSizeOnGpu<&Self> for FheUint<Id>
where
Id: FheUintId,
{
fn get_gt_size_on_gpu(&self, rhs: &Self) -> u64 {
global_state::with_internal_keys(|key| {
if let InternalServerKey::Cuda(cuda_key) = key {
with_cuda_internal_keys(|keys| {
let streams = &keys.streams;
cuda_key.key.key.get_gt_size_on_gpu(
&*self.ciphertext.on_gpu(streams),
&rhs.ciphertext.on_gpu(streams),
streams,
)
})
} else {
0
}
})
}
fn get_ge_size_on_gpu(&self, rhs: &Self) -> u64 {
global_state::with_internal_keys(|key| {
if let InternalServerKey::Cuda(cuda_key) = key {
with_cuda_internal_keys(|keys| {
let streams = &keys.streams;
cuda_key.key.key.get_ge_size_on_gpu(
&*self.ciphertext.on_gpu(streams),
&rhs.ciphertext.on_gpu(streams),
streams,
)
})
} else {
0
}
})
}
fn get_lt_size_on_gpu(&self, rhs: &Self) -> u64 {
global_state::with_internal_keys(|key| {
if let InternalServerKey::Cuda(cuda_key) = key {
with_cuda_internal_keys(|keys| {
let streams = &keys.streams;
cuda_key.key.key.get_lt_size_on_gpu(
&*self.ciphertext.on_gpu(streams),
&rhs.ciphertext.on_gpu(streams),
streams,
)
})
} else {
0
}
})
}
fn get_le_size_on_gpu(&self, rhs: &Self) -> u64 {
global_state::with_internal_keys(|key| {
if let InternalServerKey::Cuda(cuda_key) = key {
with_cuda_internal_keys(|keys| {
let streams = &keys.streams;
cuda_key.key.key.get_le_size_on_gpu(
&*self.ciphertext.on_gpu(streams),
&rhs.ciphertext.on_gpu(streams),
streams,
)
})
} else {
0
}
})
}
}
#[cfg(feature = "gpu")]
impl<Id> FheMinSizeOnGpu<&Self> for FheUint<Id>
where
Id: FheUintId,
{
fn get_min_size_on_gpu(&self, rhs: &Self) -> u64 {
global_state::with_internal_keys(|key| {
if let InternalServerKey::Cuda(cuda_key) = key {
with_cuda_internal_keys(|keys| {
let streams = &keys.streams;
cuda_key.key.key.get_min_size_on_gpu(
&*self.ciphertext.on_gpu(streams),
&rhs.ciphertext.on_gpu(streams),
streams,
)
})
} else {
0
}
})
}
}
#[cfg(feature = "gpu")]
impl<Id> FheMaxSizeOnGpu<&Self> for FheUint<Id>
where
Id: FheUintId,
{
fn get_max_size_on_gpu(&self, rhs: &Self) -> u64 {
global_state::with_internal_keys(|key| {
if let InternalServerKey::Cuda(cuda_key) = key {
with_cuda_internal_keys(|keys| {
let streams = &keys.streams;
cuda_key.key.key.get_max_size_on_gpu(
&*self.ciphertext.on_gpu(streams),
&rhs.ciphertext.on_gpu(streams),
streams,
)
})
} else {
0
}
})
}
}
#[cfg(feature = "gpu")]
impl<Id> ShlSizeOnGpu<&Self> for FheUint<Id>
where
Id: FheUintId,
{
fn get_left_shift_size_on_gpu(&self, rhs: &Self) -> u64 {
global_state::with_internal_keys(|key| {
if let InternalServerKey::Cuda(cuda_key) = key {
with_cuda_internal_keys(|keys| {
let streams = &keys.streams;
cuda_key.key.key.get_left_shift_size_on_gpu(
&*self.ciphertext.on_gpu(streams),
&rhs.ciphertext.on_gpu(streams),
streams,
)
})
} else {
0
}
})
}
}
#[cfg(feature = "gpu")]
impl<Id> ShrSizeOnGpu<&Self> for FheUint<Id>
where
Id: FheUintId,
{
fn get_right_shift_size_on_gpu(&self, rhs: &Self) -> u64 {
global_state::with_internal_keys(|key| {
if let InternalServerKey::Cuda(cuda_key) = key {
with_cuda_internal_keys(|keys| {
let streams = &keys.streams;
cuda_key.key.key.get_right_shift_size_on_gpu(
&*self.ciphertext.on_gpu(streams),
&rhs.ciphertext.on_gpu(streams),
streams,
)
})
} else {
0
}
})
}
}
#[cfg(feature = "gpu")]
impl<Id> RotateLeftSizeOnGpu<&Self> for FheUint<Id>
where
Id: FheUintId,
{
fn get_rotate_left_size_on_gpu(&self, rhs: &Self) -> u64 {
global_state::with_internal_keys(|key| {
if let InternalServerKey::Cuda(cuda_key) = key {
with_cuda_internal_keys(|keys| {
let streams = &keys.streams;
cuda_key.key.key.get_rotate_left_size_on_gpu(
&*self.ciphertext.on_gpu(streams),
&rhs.ciphertext.on_gpu(streams),
streams,
)
})
} else {
0
}
})
}
}
#[cfg(feature = "gpu")]
impl<Id> RotateRightSizeOnGpu<&Self> for FheUint<Id>
where
Id: FheUintId,
{
fn get_rotate_right_size_on_gpu(&self, rhs: &Self) -> u64 {
global_state::with_internal_keys(|key| {
if let InternalServerKey::Cuda(cuda_key) = key {
with_cuda_internal_keys(|keys| {
let streams = &keys.streams;
cuda_key.key.key.get_rotate_right_size_on_gpu(
&*self.ciphertext.on_gpu(streams),
&rhs.ciphertext.on_gpu(streams),
streams,
)
})
} else {
0
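
These FheOrdSizeOnGpu / FheMinSizeOnGpu / FheMaxSizeOnGpu impls follow the existing *SizeOnGpu convention: they report the scratch memory an operation would allocate on the GPU, and return 0 when no CUDA key is set. A short usage sketch, assuming a and b are FheUint32 values already moved to the GPU (as in the tests further down in this diff):

// Gate a comparison on whether its temporary buffers would fit on GPU 0.
let gt_scratch = a.get_gt_size_on_gpu(&b);
if check_valid_cuda_malloc(gt_scratch, GpuIndex::new(0)) {
    let _is_greater = a.gt(&b);
}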


@@ -1,7 +1,5 @@
use crate::core_crypto::prelude::UnsignedNumeric;
use crate::high_level_api::global_state;
#[cfg(feature = "gpu")]
use crate::high_level_api::global_state::with_thread_local_cuda_streams;
use crate::high_level_api::integers::FheUintId;
use crate::high_level_api::keys::InternalServerKey;
use crate::integer::block_decomposition::DecomposableInto;
@@ -53,7 +51,8 @@ where
)
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
InternalServerKey::Cuda(cuda_key) => {
let streams = &cuda_key.streams;
let inner_result = cuda_key.key.key.unsigned_overflowing_add(
&self.ciphertext.on_gpu(streams),
&other.ciphertext.on_gpu(streams),
@@ -63,7 +62,7 @@ where
FheUint::<Id>::new(inner_result.0, cuda_key.tag.clone()),
FheBool::new(inner_result.1, cuda_key.tag.clone()),
)
}),
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
panic!("Hpu does not support this operation yet.")
@@ -153,7 +152,8 @@ where
)
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
InternalServerKey::Cuda(cuda_key) => {
let streams = &cuda_key.streams;
let inner_result = cuda_key.key.key.unsigned_overflowing_scalar_add(
&self.ciphertext.on_gpu(streams),
other,
@@ -163,7 +163,7 @@ where
FheUint::<Id>::new(inner_result.0, cuda_key.tag.clone()),
FheBool::new(inner_result.1, cuda_key.tag.clone()),
)
}),
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
panic!("Hpu does not support this operation yet.")
@@ -293,7 +293,8 @@ where
)
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
InternalServerKey::Cuda(cuda_key) => {
let streams = &cuda_key.streams;
let inner_result = cuda_key.key.key.unsigned_overflowing_sub(
&self.ciphertext.on_gpu(streams),
&other.ciphertext.on_gpu(streams),
@@ -303,7 +304,7 @@ where
FheUint::<Id>::new(inner_result.0, cuda_key.tag.clone()),
FheBool::new(inner_result.1, cuda_key.tag.clone()),
)
}),
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
panic!("Hpu does not support this operation yet.")

View File

@@ -8,12 +8,14 @@ use crate::error::InvalidRangeError;
use crate::high_level_api::errors::UnwrapResultExt;
use crate::high_level_api::global_state;
#[cfg(feature = "gpu")]
use crate::high_level_api::global_state::with_thread_local_cuda_streams;
use crate::high_level_api::global_state::with_cuda_internal_keys;
use crate::high_level_api::integers::FheUintId;
use crate::high_level_api::keys::InternalServerKey;
#[cfg(feature = "gpu")]
use crate::high_level_api::traits::{
AddSizeOnGpu, BitAndSizeOnGpu, BitOrSizeOnGpu, BitXorSizeOnGpu, SubSizeOnGpu,
AddSizeOnGpu, BitAndSizeOnGpu, BitOrSizeOnGpu, BitXorSizeOnGpu, FheMaxSizeOnGpu,
FheMinSizeOnGpu, FheOrdSizeOnGpu, RotateLeftSizeOnGpu, RotateRightSizeOnGpu, ShlSizeOnGpu,
ShrSizeOnGpu, SubSizeOnGpu,
};
use crate::high_level_api::traits::{
BitSlice, DivRem, FheEq, FheMax, FheMin, FheOrd, RotateLeft, RotateLeftAssign, RotateRight,
@@ -64,14 +66,15 @@ where
FheBool::new(inner_result, cpu_key.tag.clone())
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
InternalServerKey::Cuda(cuda_key) => {
let streams = &cuda_key.streams;
let inner_result =
cuda_key
.key
.key
.scalar_eq(&*self.ciphertext.on_gpu(streams), rhs, streams);
FheBool::new(inner_result, cuda_key.tag.clone())
}),
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
panic!("Hpu does not support this operation yet.")
@@ -107,14 +110,15 @@ where
FheBool::new(inner_result, cpu_key.tag.clone())
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
InternalServerKey::Cuda(cuda_key) => {
let streams = &cuda_key.streams;
let inner_result =
cuda_key
.key
.key
.scalar_ne(&*self.ciphertext.on_gpu(streams), rhs, streams);
FheBool::new(inner_result, cuda_key.tag.clone())
}),
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
panic!("Hpu does not support this operation yet.")
@@ -156,14 +160,15 @@ where
FheBool::new(inner_result, cpu_key.tag.clone())
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
InternalServerKey::Cuda(cuda_key) => {
let streams = &cuda_key.streams;
let inner_result =
cuda_key
.key
.key
.scalar_lt(&*self.ciphertext.on_gpu(streams), rhs, streams);
FheBool::new(inner_result, cuda_key.tag.clone())
}),
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
panic!("Hpu does not support this operation yet.")
@@ -199,14 +204,15 @@ where
FheBool::new(inner_result, cpu_key.tag.clone())
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
InternalServerKey::Cuda(cuda_key) => {
let streams = &cuda_key.streams;
let inner_result =
cuda_key
.key
.key
.scalar_le(&*self.ciphertext.on_gpu(streams), rhs, streams);
FheBool::new(inner_result, cuda_key.tag.clone())
}),
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
panic!("Hpu does not support this operation yet.")
@@ -242,14 +248,15 @@ where
FheBool::new(inner_result, cpu_key.tag.clone())
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
InternalServerKey::Cuda(cuda_key) => {
let streams = &cuda_key.streams;
let inner_result =
cuda_key
.key
.key
.scalar_gt(&*self.ciphertext.on_gpu(streams), rhs, streams);
FheBool::new(inner_result, cuda_key.tag.clone())
}),
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
panic!("Hpu does not support this operation yet.")
@@ -285,14 +292,15 @@ where
FheBool::new(inner_result, cpu_key.tag.clone())
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
InternalServerKey::Cuda(cuda_key) => {
let streams = &cuda_key.streams;
let inner_result =
cuda_key
.key
.key
.scalar_ge(&*self.ciphertext.on_gpu(streams), rhs, streams);
FheBool::new(inner_result, cuda_key.tag.clone())
}),
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
panic!("Hpu does not support this operation yet.")
@@ -301,6 +309,119 @@ where
}
}
#[cfg(feature = "gpu")]
impl<Id, Clear> FheOrdSizeOnGpu<Clear> for FheUint<Id>
where
Id: FheUintId,
Clear: DecomposableInto<u64>,
{
fn get_gt_size_on_gpu(&self, _rhs: Clear) -> u64 {
global_state::with_internal_keys(|key| {
if let InternalServerKey::Cuda(cuda_key) = key {
with_cuda_internal_keys(|keys| {
let streams = &keys.streams;
cuda_key
.key
.key
.get_scalar_gt_size_on_gpu(&*self.ciphertext.on_gpu(streams), streams)
})
} else {
0
}
})
}
fn get_ge_size_on_gpu(&self, _rhs: Clear) -> u64 {
global_state::with_internal_keys(|key| {
if let InternalServerKey::Cuda(cuda_key) = key {
with_cuda_internal_keys(|keys| {
let streams = &keys.streams;
cuda_key
.key
.key
.get_scalar_ge_size_on_gpu(&*self.ciphertext.on_gpu(streams), streams)
})
} else {
0
}
})
}
fn get_lt_size_on_gpu(&self, _rhs: Clear) -> u64 {
global_state::with_internal_keys(|key| {
if let InternalServerKey::Cuda(cuda_key) = key {
with_cuda_internal_keys(|keys| {
let streams = &keys.streams;
cuda_key
.key
.key
.get_scalar_lt_size_on_gpu(&*self.ciphertext.on_gpu(streams), streams)
})
} else {
0
}
})
}
fn get_le_size_on_gpu(&self, _rhs: Clear) -> u64 {
global_state::with_internal_keys(|key| {
if let InternalServerKey::Cuda(cuda_key) = key {
with_cuda_internal_keys(|keys| {
let streams = &keys.streams;
cuda_key
.key
.key
.get_scalar_le_size_on_gpu(&*self.ciphertext.on_gpu(streams), streams)
})
} else {
0
}
})
}
}
#[cfg(feature = "gpu")]
impl<Id, Clear> FheMinSizeOnGpu<Clear> for FheUint<Id>
where
Id: FheUintId,
Clear: DecomposableInto<u64>,
{
fn get_min_size_on_gpu(&self, _rhs: Clear) -> u64 {
global_state::with_internal_keys(|key| {
if let InternalServerKey::Cuda(cuda_key) = key {
with_cuda_internal_keys(|keys| {
let streams = &keys.streams;
cuda_key
.key
.key
.get_scalar_min_size_on_gpu(&*self.ciphertext.on_gpu(streams), streams)
})
} else {
0
}
})
}
}
#[cfg(feature = "gpu")]
impl<Id, Clear> FheMaxSizeOnGpu<Clear> for FheUint<Id>
where
Id: FheUintId,
Clear: DecomposableInto<u64>,
{
fn get_max_size_on_gpu(&self, _rhs: Clear) -> u64 {
global_state::with_internal_keys(|key| {
if let InternalServerKey::Cuda(cuda_key) = key {
with_cuda_internal_keys(|keys| {
let streams = &keys.streams;
cuda_key
.key
.key
.get_scalar_max_size_on_gpu(&*self.ciphertext.on_gpu(streams), streams)
})
} else {
0
}
})
}
}
impl<Id, Clear> FheMax<Clear> for FheUint<Id>
where
Clear: DecomposableInto<u64>,
@@ -336,14 +457,15 @@ where
Self::new(inner_result, cpu_key.tag.clone())
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
InternalServerKey::Cuda(cuda_key) => {
let streams = &cuda_key.streams;
let inner_result =
cuda_key
.key
.key
.scalar_max(&*self.ciphertext.on_gpu(streams), rhs, streams);
Self::new(inner_result, cuda_key.tag.clone())
}),
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
panic!("Hpu does not support this operation yet.")
@@ -387,14 +509,15 @@ where
Self::new(inner_result, cpu_key.tag.clone())
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
InternalServerKey::Cuda(cuda_key) => {
let streams = &cuda_key.streams;
let inner_result =
cuda_key
.key
.key
.scalar_min(&*self.ciphertext.on_gpu(streams), rhs, streams);
Self::new(inner_result, cuda_key.tag.clone())
}),
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
panic!("Hpu does not support this operation yet.")
@@ -533,11 +656,12 @@ macro_rules! generic_integer_impl_scalar_div_rem {
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => {
let (inner_q, inner_r) = with_thread_local_cuda_streams(|streams| {
let (inner_q, inner_r) = {
let streams = &cuda_key.streams;
cuda_key.key.key.scalar_div_rem(
&*self.ciphertext.on_gpu(streams), rhs, streams
)
});
};
let (q, r) = (RadixCiphertext::Cuda(inner_q), RadixCiphertext::Cuda(inner_r));
(
<$concrete_type>::new(q, cuda_key.tag.clone()),
@@ -785,11 +909,12 @@ macro_rules! define_scalar_rotate_shifts {
},
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => {
let inner_result = with_thread_local_cuda_streams(|streams| {
let inner_result = {
let streams = &cuda_key.streams;
cuda_key.key.key.scalar_left_shift(
&*lhs.ciphertext.on_gpu(streams), u64::cast_from(rhs), streams
)
});
};
RadixCiphertext::Cuda(inner_result)
}
#[cfg(feature = "hpu")]
@@ -805,6 +930,31 @@ macro_rules! define_scalar_rotate_shifts {
)*
);
#[cfg(feature = "gpu")]
generic_integer_impl_get_scalar_operation_size_on_gpu!(
rust_trait: ShlSizeOnGpu(get_left_shift_size_on_gpu),
implem: {
|lhs: &FheUint<_>, _rhs| {
global_state::with_internal_keys(|key|
if let InternalServerKey::Cuda(cuda_key) = key {
with_cuda_internal_keys(|keys| {
let streams = &keys.streams;
cuda_key.key.key.get_scalar_left_shift_size_on_gpu(
&*lhs.ciphertext.on_gpu(streams),
streams,
)
})
} else {
0
})
}
},
fhe_and_scalar_type:
$(
($concrete_type, $($scalar_type,)*),
)*
);
generic_integer_impl_scalar_operation!(
rust_trait: Shr(shr),
implem: {
@@ -818,11 +968,12 @@ macro_rules! define_scalar_rotate_shifts {
},
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => {
let inner_result = with_thread_local_cuda_streams(|streams| {
let inner_result = {
let streams = &cuda_key.streams;
cuda_key.key.key.scalar_right_shift(
&*lhs.ciphertext.on_gpu(streams), u64::cast_from(rhs), streams
)
});
};
RadixCiphertext::Cuda(inner_result)
}
#[cfg(feature = "hpu")]
@@ -838,6 +989,31 @@ macro_rules! define_scalar_rotate_shifts {
)*
);
#[cfg(feature = "gpu")]
generic_integer_impl_get_scalar_operation_size_on_gpu!(
rust_trait: ShrSizeOnGpu(get_right_shift_size_on_gpu),
implem: {
|lhs: &FheUint<_>, _rhs| {
global_state::with_internal_keys(|key|
if let InternalServerKey::Cuda(cuda_key) = key {
with_cuda_internal_keys(|keys| {
let streams = &keys.streams;
cuda_key.key.key.get_scalar_right_shift_size_on_gpu(
&*lhs.ciphertext.on_gpu(streams),
streams,
)
})
} else {
0
})
}
},
fhe_and_scalar_type:
$(
($concrete_type, $($scalar_type,)*),
)*
);
generic_integer_impl_scalar_operation!(
rust_trait: RotateLeft(rotate_left),
implem: {
@@ -851,11 +1027,12 @@ macro_rules! define_scalar_rotate_shifts {
},
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => {
let inner_result = with_thread_local_cuda_streams(|streams| {
let inner_result = {
let streams = &cuda_key.streams;
cuda_key.key.key.scalar_rotate_left(
&*lhs.ciphertext.on_gpu(streams), u64::cast_from(rhs), streams
)
});
};
RadixCiphertext::Cuda(inner_result)
}
#[cfg(feature = "hpu")]
@@ -871,6 +1048,31 @@ macro_rules! define_scalar_rotate_shifts {
)*
);
#[cfg(feature = "gpu")]
generic_integer_impl_get_scalar_operation_size_on_gpu!(
rust_trait: RotateLeftSizeOnGpu(get_rotate_left_size_on_gpu),
implem: {
|lhs: &FheUint<_>, _rhs| {
global_state::with_internal_keys(|key|
if let InternalServerKey::Cuda(cuda_key) = key {
with_cuda_internal_keys(|keys| {
let streams = &keys.streams;
cuda_key.key.key.get_scalar_rotate_left_size_on_gpu(
&*lhs.ciphertext.on_gpu(streams),
streams,
)
})
} else {
0
})
}
},
fhe_and_scalar_type:
$(
($concrete_type, $($scalar_type,)*),
)*
);
generic_integer_impl_scalar_operation!(
rust_trait: RotateRight(rotate_right),
implem: {
@@ -884,11 +1086,12 @@ macro_rules! define_scalar_rotate_shifts {
},
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => {
let inner_result = with_thread_local_cuda_streams(|streams| {
let inner_result = {
let streams = &cuda_key.streams;
cuda_key.key.key.scalar_rotate_right(
&*lhs.ciphertext.on_gpu(streams), u64::cast_from(rhs), streams
)
});
};
RadixCiphertext::Cuda(inner_result)
}
#[cfg(feature = "hpu")]
@@ -904,6 +1107,31 @@ macro_rules! define_scalar_rotate_shifts {
)*
);
#[cfg(feature = "gpu")]
generic_integer_impl_get_scalar_operation_size_on_gpu!(
rust_trait: RotateRightSizeOnGpu(get_rotate_right_size_on_gpu),
implem: {
|lhs: &FheUint<_>, _rhs| {
global_state::with_internal_keys(|key|
if let InternalServerKey::Cuda(cuda_key) = key {
with_cuda_internal_keys(|keys| {
let streams = &keys.streams;
cuda_key.key.key.get_scalar_rotate_right_size_on_gpu(
&*lhs.ciphertext.on_gpu(streams),
streams,
)
})
} else {
0
})
}
},
fhe_and_scalar_type:
$(
($concrete_type, $($scalar_type,)*),
)*
);
generic_integer_impl_scalar_operation_assign!(
rust_trait: ShlAssign(shl_assign),
implem: {
@@ -916,10 +1144,11 @@ macro_rules! define_scalar_rotate_shifts {
},
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => {
with_thread_local_cuda_streams(|streams| {
{
let streams = &cuda_key.streams;
cuda_key.key.key
.scalar_left_shift_assign(lhs.ciphertext.as_gpu_mut(streams), rhs, streams);
})
}
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
@@ -946,10 +1175,11 @@ macro_rules! define_scalar_rotate_shifts {
},
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => {
with_thread_local_cuda_streams(|streams| {
{
let streams = &cuda_key.streams;
cuda_key.key.key
.scalar_right_shift_assign(lhs.ciphertext.as_gpu_mut(streams), rhs, streams);
})
}
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
@@ -976,10 +1206,9 @@ macro_rules! define_scalar_rotate_shifts {
},
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => {
with_thread_local_cuda_streams(|streams| {
let streams = &cuda_key.streams;
cuda_key.key.key
.scalar_rotate_left_assign(lhs.ciphertext.as_gpu_mut(streams), rhs, streams);
})
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
@@ -1006,10 +1235,9 @@ macro_rules! define_scalar_rotate_shifts {
},
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => {
with_thread_local_cuda_streams(|streams| {
let streams = &cuda_key.streams;
cuda_key.key.key
.scalar_rotate_right_assign(lhs.ciphertext.as_gpu_mut(streams), rhs, streams);
})
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
@@ -1124,11 +1352,12 @@ macro_rules! define_scalar_ops {
},
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => {
let inner_result = with_thread_local_cuda_streams(|streams| {
let inner_result = {
let streams = &cuda_key.streams;
cuda_key.key.key.scalar_add(
&*lhs.ciphertext.on_gpu(streams), rhs, streams
)
});
};
RadixCiphertext::Cuda(inner_result)
}
#[cfg(feature = "hpu")]
@@ -1154,12 +1383,11 @@ macro_rules! define_scalar_ops {
|lhs: &FheUint<_>, _rhs| {
global_state::with_internal_keys(|key|
if let InternalServerKey::Cuda(cuda_key) = key {
with_thread_local_cuda_streams(|streams| {
let streams = &cuda_key.streams;
cuda_key.key.key.get_scalar_add_size_on_gpu(
&*lhs.ciphertext.on_gpu(streams),
streams,
)
})
} else {
0
})
@@ -1184,11 +1412,12 @@ macro_rules! define_scalar_ops {
},
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => {
let inner_result = with_thread_local_cuda_streams(|streams| {
let inner_result = {
let streams = &cuda_key.streams;
cuda_key.key.key.scalar_sub(
&*lhs.ciphertext.on_gpu(streams), rhs, streams
)
});
};
RadixCiphertext::Cuda(inner_result)
}
#[cfg(feature = "hpu")]
@@ -1214,12 +1443,11 @@ macro_rules! define_scalar_ops {
|lhs: &FheUint<_>, _rhs| {
global_state::with_internal_keys(|key|
if let InternalServerKey::Cuda(cuda_key) = key {
with_thread_local_cuda_streams(|streams| {
let streams = &cuda_key.streams;
cuda_key.key.key.get_scalar_sub_size_on_gpu(
&*lhs.ciphertext.on_gpu(streams),
streams,
)
})
} else {
0
})
@@ -1244,11 +1472,12 @@ macro_rules! define_scalar_ops {
},
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => {
let inner_result = with_thread_local_cuda_streams(|streams| {
let inner_result = {
let streams = &cuda_key.streams;
cuda_key.key.key.scalar_mul(
&*lhs.ciphertext.on_gpu(streams), rhs, streams
)
});
};
RadixCiphertext::Cuda(inner_result)
}
#[cfg(feature = "hpu")]
@@ -1281,11 +1510,12 @@ macro_rules! define_scalar_ops {
},
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => {
let inner_result = with_thread_local_cuda_streams(|streams| {
let inner_result = {
let streams = &cuda_key.streams;
cuda_key.key.key.scalar_bitand(
&*lhs.ciphertext.on_gpu(streams), rhs, streams
)
});
};
RadixCiphertext::Cuda(inner_result)
}
#[cfg(feature = "hpu")]
@@ -1308,12 +1538,11 @@ macro_rules! define_scalar_ops {
|lhs: &FheUint<_>, _rhs| {
global_state::with_internal_keys(|key|
if let InternalServerKey::Cuda(cuda_key) = key {
with_thread_local_cuda_streams(|streams| {
let streams = &cuda_key.streams;
cuda_key.key.key.get_scalar_bitand_size_on_gpu(
&*lhs.ciphertext.on_gpu(streams),
streams,
)
})
} else {
0
})
@@ -1338,11 +1567,12 @@ macro_rules! define_scalar_ops {
},
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => {
let inner_result = with_thread_local_cuda_streams(|streams| {
let inner_result = {
let streams = &cuda_key.streams;
cuda_key.key.key.scalar_bitor(
&*lhs.ciphertext.on_gpu(streams), rhs, streams
)
});
};
RadixCiphertext::Cuda(inner_result)
}
#[cfg(feature = "hpu")]
@@ -1365,12 +1595,11 @@ macro_rules! define_scalar_ops {
|lhs: &FheUint<_>, _rhs| {
global_state::with_internal_keys(|key|
if let InternalServerKey::Cuda(cuda_key) = key {
with_thread_local_cuda_streams(|streams| {
let streams = &cuda_key.streams;
cuda_key.key.key.get_scalar_bitor_size_on_gpu(
&*lhs.ciphertext.on_gpu(streams),
streams,
)
})
} else {
0
})
@@ -1396,11 +1625,12 @@ macro_rules! define_scalar_ops {
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => {
let inner_result = with_thread_local_cuda_streams(|streams| {
let inner_result = {
let streams = &cuda_key.streams;
cuda_key.key.key.scalar_bitxor(
&*lhs.ciphertext.on_gpu(streams), rhs, streams
)
});
};
RadixCiphertext::Cuda(inner_result)
}
#[cfg(feature = "hpu")]
@@ -1423,12 +1653,11 @@ macro_rules! define_scalar_ops {
|lhs: &FheUint<_>, _rhs| {
global_state::with_internal_keys(|key|
if let InternalServerKey::Cuda(cuda_key) = key {
with_thread_local_cuda_streams(|streams| {
let streams = &cuda_key.streams;
cuda_key.key.key.get_scalar_bitxor_size_on_gpu(
&*lhs.ciphertext.on_gpu(streams),
streams,
)
})
} else {
0
})
@@ -1453,11 +1682,12 @@ macro_rules! define_scalar_ops {
},
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => {
let inner_result = with_thread_local_cuda_streams(|streams| {
let inner_result = {
let streams = &cuda_key.streams;
cuda_key.key.key.scalar_div(
&lhs.ciphertext.on_gpu(streams), rhs, streams
)
});
};
RadixCiphertext::Cuda(inner_result)
}
#[cfg(feature = "hpu")]
@@ -1487,11 +1717,12 @@ macro_rules! define_scalar_ops {
},
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => {
let inner_result = with_thread_local_cuda_streams(|streams| {
let inner_result = {
let streams = &cuda_key.streams;
cuda_key.key.key.scalar_rem(
&lhs.ciphertext.on_gpu(streams), rhs, streams
)
});
};
RadixCiphertext::Cuda(inner_result)
}
#[cfg(feature = "hpu")]
@@ -1531,12 +1762,11 @@ macro_rules! define_scalar_ops {
|_lhs, rhs: &FheUint<_>| {
global_state::with_internal_keys(|key|
if let InternalServerKey::Cuda(cuda_key) = key {
with_thread_local_cuda_streams(|streams| {
let streams = &cuda_key.streams;
cuda_key.key.key.get_scalar_add_size_on_gpu(
&*rhs.ciphertext.on_gpu(streams),
streams,
)
})
} else {
0
})
@@ -1560,12 +1790,11 @@ macro_rules! define_scalar_ops {
},
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => {
with_thread_local_cuda_streams(|streams| {
let streams = &cuda_key.streams;
let mut result: CudaUnsignedRadixCiphertext = cuda_key.pbs_key().create_trivial_radix(
lhs, rhs.ciphertext.on_gpu(streams).ciphertext.info.blocks.len(), streams);
cuda_key.pbs_key().sub_assign(&mut result, &rhs.ciphertext.on_gpu(streams), streams);
RadixCiphertext::Cuda(result)
})
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
@@ -1587,12 +1816,11 @@ macro_rules! define_scalar_ops {
|_lhs, rhs: &FheUint<_>| {
global_state::with_internal_keys(|key|
if let InternalServerKey::Cuda(cuda_key) = key {
with_thread_local_cuda_streams(|streams| {
let streams = &cuda_key.streams;
cuda_key.key.key.get_scalar_sub_size_on_gpu(
&*rhs.ciphertext.on_gpu(streams),
streams,
)
})
} else {
0
})
@@ -1643,12 +1871,11 @@ macro_rules! define_scalar_ops {
|_lhs, rhs: &FheUint<_>| {
global_state::with_internal_keys(|key|
if let InternalServerKey::Cuda(cuda_key) = key {
with_thread_local_cuda_streams(|streams| {
let streams = &cuda_key.streams;
cuda_key.key.key.get_scalar_bitand_size_on_gpu(
&*rhs.ciphertext.on_gpu(streams),
streams,
)
})
} else {
0
})
@@ -1682,12 +1909,11 @@ macro_rules! define_scalar_ops {
|_lhs, rhs: &FheUint<_>| {
global_state::with_internal_keys(|key|
if let InternalServerKey::Cuda(cuda_key) = key {
with_thread_local_cuda_streams(|streams| {
let streams = &cuda_key.streams;
cuda_key.key.key.get_scalar_bitor_size_on_gpu(
&*rhs.ciphertext.on_gpu(streams),
streams,
)
})
} else {
0
})
@@ -1721,12 +1947,11 @@ macro_rules! define_scalar_ops {
|_lhs, rhs: &FheUint<_>| {
global_state::with_internal_keys(|key|
if let InternalServerKey::Cuda(cuda_key) = key {
with_thread_local_cuda_streams(|streams| {
let streams = &cuda_key.streams;
cuda_key.key.key.get_scalar_bitxor_size_on_gpu(
&*rhs.ciphertext.on_gpu(streams),
streams,
)
})
} else {
0
})
@@ -1752,10 +1977,9 @@ macro_rules! define_scalar_ops {
},
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => {
with_thread_local_cuda_streams(|streams| {
let streams = &cuda_key.streams;
cuda_key.key.key
.scalar_add_assign(lhs.ciphertext.as_gpu_mut(streams), rhs, streams);
})
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(device) => {
@@ -1790,10 +2014,9 @@ macro_rules! define_scalar_ops {
},
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => {
with_thread_local_cuda_streams(|streams| {
let streams = &cuda_key.streams;
cuda_key.key.key
.scalar_sub_assign(lhs.ciphertext.as_gpu_mut(streams), rhs, streams);
})
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(device) => {
@@ -1824,10 +2047,9 @@ macro_rules! define_scalar_ops {
},
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => {
with_thread_local_cuda_streams(|streams| {
let streams = &cuda_key.streams;
cuda_key.key.key
.scalar_mul_assign(lhs.ciphertext.as_gpu_mut(streams), rhs, streams);
})
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(device) => {
@@ -1859,10 +2081,9 @@ macro_rules! define_scalar_ops {
},
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => {
with_thread_local_cuda_streams(|streams| {
let streams = &cuda_key.streams;
cuda_key.key.key
.scalar_bitand_assign(lhs.ciphertext.as_gpu_mut(streams), rhs, streams);
})
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
@@ -1890,10 +2111,9 @@ macro_rules! define_scalar_ops {
},
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => {
with_thread_local_cuda_streams(|streams| {
let streams = &cuda_key.streams;
cuda_key.key.key
.scalar_bitor_assign(lhs.ciphertext.as_gpu_mut(streams), rhs, streams);
})
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
@@ -1920,10 +2140,9 @@ macro_rules! define_scalar_ops {
},
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => {
with_thread_local_cuda_streams(|streams| {
let streams = &cuda_key.streams;
cuda_key.key.key
.scalar_bitxor_assign(lhs.ciphertext.as_gpu_mut(streams), rhs, streams);
})
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
@@ -1949,11 +2168,12 @@ macro_rules! define_scalar_ops {
.scalar_div_assign_parallelized(lhs.ciphertext.as_cpu_mut(), rhs);
},
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => global_state::with_thread_local_cuda_streams(|streams| {
InternalServerKey::Cuda(cuda_key) => {
let streams = &cuda_key.streams;
let cuda_lhs = lhs.ciphertext.as_gpu_mut(streams);
let cuda_result = cuda_key.pbs_key().scalar_div(&cuda_lhs, rhs, streams);
*cuda_lhs = cuda_result;
}),
},
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
panic!("Hpu does not support this operation yet.")
@@ -1978,11 +2198,12 @@ macro_rules! define_scalar_ops {
.scalar_rem_assign_parallelized(lhs.ciphertext.as_cpu_mut(), rhs);
},
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => global_state::with_thread_local_cuda_streams(|streams| {
InternalServerKey::Cuda(cuda_key) => {
let streams = &cuda_key.streams;
let cuda_lhs = lhs.ciphertext.as_gpu_mut(streams);
let cuda_result = cuda_key.pbs_key().scalar_rem(&cuda_lhs, rhs, streams);
*cuda_lhs = cuda_result;
}),
},
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
panic!("Hpu does not support this operation yet.")


@@ -1,7 +1,8 @@
use crate::high_level_api::traits::AddSizeOnGpu;
use crate::prelude::{
check_valid_cuda_malloc, BitAndSizeOnGpu, BitNotSizeOnGpu, BitOrSizeOnGpu, BitXorSizeOnGpu,
FheTryEncrypt, SubSizeOnGpu,
FheMaxSizeOnGpu, FheMinSizeOnGpu, FheOrdSizeOnGpu, FheTryEncrypt, RotateLeftSizeOnGpu,
RotateRightSizeOnGpu, ShlSizeOnGpu, ShrSizeOnGpu, SubSizeOnGpu,
};
use crate::shortint::parameters::{
TestParameters, PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS,
@@ -254,3 +255,132 @@ fn test_gpu_get_bitops_size_on_gpu() {
GpuIndex::new(0)
));
}
#[test]
fn test_gpu_get_comparisons_size_on_gpu() {
let cks = setup_gpu(Some(PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS));
let mut rng = rand::thread_rng();
let clear_a = rng.gen_range(1..=u32::MAX);
let clear_b = rng.gen_range(1..=u32::MAX);
let mut a = FheUint32::try_encrypt(clear_a, &cks).unwrap();
let mut b = FheUint32::try_encrypt(clear_b, &cks).unwrap();
a.move_to_current_device();
b.move_to_current_device();
let a = &a;
let b = &b;
let gt_tmp_buffer_size = a.get_gt_size_on_gpu(b);
let scalar_gt_tmp_buffer_size = a.get_gt_size_on_gpu(clear_b);
assert!(check_valid_cuda_malloc(
gt_tmp_buffer_size,
GpuIndex::new(0)
));
assert!(check_valid_cuda_malloc(
scalar_gt_tmp_buffer_size,
GpuIndex::new(0)
));
let ge_tmp_buffer_size = a.get_ge_size_on_gpu(b);
let scalar_ge_tmp_buffer_size = a.get_ge_size_on_gpu(clear_b);
assert!(check_valid_cuda_malloc(
ge_tmp_buffer_size,
GpuIndex::new(0)
));
assert!(check_valid_cuda_malloc(
scalar_ge_tmp_buffer_size,
GpuIndex::new(0)
));
let lt_tmp_buffer_size = a.get_lt_size_on_gpu(b);
let scalar_lt_tmp_buffer_size = a.get_lt_size_on_gpu(clear_b);
assert!(check_valid_cuda_malloc(
lt_tmp_buffer_size,
GpuIndex::new(0)
));
assert!(check_valid_cuda_malloc(
scalar_lt_tmp_buffer_size,
GpuIndex::new(0)
));
let le_tmp_buffer_size = a.get_le_size_on_gpu(b);
let scalar_le_tmp_buffer_size = a.get_le_size_on_gpu(clear_b);
assert!(check_valid_cuda_malloc(
le_tmp_buffer_size,
GpuIndex::new(0)
));
assert!(check_valid_cuda_malloc(
scalar_le_tmp_buffer_size,
GpuIndex::new(0)
));
let max_tmp_buffer_size = a.get_max_size_on_gpu(b);
let scalar_max_tmp_buffer_size = a.get_max_size_on_gpu(clear_b);
assert!(check_valid_cuda_malloc(
max_tmp_buffer_size,
GpuIndex::new(0)
));
assert!(check_valid_cuda_malloc(
scalar_max_tmp_buffer_size,
GpuIndex::new(0)
));
let min_tmp_buffer_size = a.get_min_size_on_gpu(b);
let scalar_min_tmp_buffer_size = a.get_min_size_on_gpu(clear_b);
assert!(check_valid_cuda_malloc(
min_tmp_buffer_size,
GpuIndex::new(0)
));
assert!(check_valid_cuda_malloc(
scalar_min_tmp_buffer_size,
GpuIndex::new(0)
));
}
#[test]
fn test_gpu_get_shift_rotate_size_on_gpu() {
let cks = setup_gpu(Some(PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS));
let mut rng = rand::thread_rng();
let clear_a = rng.gen_range(1..=u32::MAX);
let clear_b = rng.gen_range(1..=u32::MAX);
let mut a = FheUint32::try_encrypt(clear_a, &cks).unwrap();
let mut b = FheUint32::try_encrypt(clear_b, &cks).unwrap();
a.move_to_current_device();
b.move_to_current_device();
let a = &a;
let b = &b;
let left_shift_tmp_buffer_size = a.get_left_shift_size_on_gpu(b);
let scalar_left_shift_tmp_buffer_size = a.get_left_shift_size_on_gpu(clear_b);
assert!(check_valid_cuda_malloc(
left_shift_tmp_buffer_size,
GpuIndex::new(0)
));
assert!(check_valid_cuda_malloc(
scalar_left_shift_tmp_buffer_size,
GpuIndex::new(0)
));
let right_shift_tmp_buffer_size = a.get_right_shift_size_on_gpu(b);
let scalar_right_shift_tmp_buffer_size = a.get_right_shift_size_on_gpu(clear_b);
assert!(check_valid_cuda_malloc(
right_shift_tmp_buffer_size,
GpuIndex::new(0)
));
assert!(check_valid_cuda_malloc(
scalar_right_shift_tmp_buffer_size,
GpuIndex::new(0)
));
let rotate_left_tmp_buffer_size = a.get_rotate_left_size_on_gpu(b);
let scalar_rotate_left_tmp_buffer_size = a.get_rotate_left_size_on_gpu(clear_b);
assert!(check_valid_cuda_malloc(
rotate_left_tmp_buffer_size,
GpuIndex::new(0)
));
assert!(check_valid_cuda_malloc(
scalar_rotate_left_tmp_buffer_size,
GpuIndex::new(0)
));
let rotate_right_tmp_buffer_size = a.get_rotate_right_size_on_gpu(b);
let scalar_rotate_right_tmp_buffer_size = a.get_rotate_right_size_on_gpu(clear_b);
assert!(check_valid_cuda_malloc(
rotate_right_tmp_buffer_size,
GpuIndex::new(0)
));
assert!(check_valid_cuda_malloc(
scalar_rotate_right_tmp_buffer_size,
GpuIndex::new(0)
));
}


@@ -110,7 +110,7 @@ impl IntegerClientKey {
);
let mut seeder = DeterministicSeeder::<DefaultRandomGenerator>::new(seed);
let cks = crate::shortint::engine::ShortintEngine::new_from_seeder(&mut seeder)
.new_client_key(config.block_parameters.into());
.new_client_key(config.block_parameters);
let key = crate::integer::ClientKey::from(cks);
@@ -172,7 +172,7 @@ impl IntegerClientKey {
if let Some(dedicated_compact_private_key) = dedicated_compact_private_key.as_ref() {
assert_eq!(
shortint_cks.parameters.message_modulus(),
shortint_cks.parameters().message_modulus(),
dedicated_compact_private_key
.0
.key
@@ -180,7 +180,7 @@ impl IntegerClientKey {
.message_modulus,
);
assert_eq!(
shortint_cks.parameters.carry_modulus(),
shortint_cks.parameters().carry_modulus(),
dedicated_compact_private_key
.0
.key
@@ -315,6 +315,9 @@ impl IntegerServerKey {
#[cfg(feature = "gpu")]
pub struct IntegerCudaServerKey {
pub(crate) key: crate::integer::gpu::CudaServerKey,
#[allow(dead_code)]
pub(crate) cpk_key_switching_key_material:
Option<crate::integer::gpu::key_switching_key::CudaKeySwitchingKeyMaterial>,
pub(crate) compression_key:
Option<crate::integer::gpu::list_compression::server_keys::CudaCompressionKey>,
pub(crate) decompression_key:


@@ -8,6 +8,8 @@ use super::ClientKey;
use crate::backward_compatibility::keys::{CompressedServerKeyVersions, ServerKeyVersions};
use crate::conformance::ParameterSetConformant;
#[cfg(feature = "gpu")]
use crate::core_crypto::gpu::lwe_keyswitch_key::CudaLweKeyswitchKey;
#[cfg(feature = "gpu")]
use crate::core_crypto::gpu::{synchronize_devices, CudaStreams};
#[cfg(feature = "gpu")]
use crate::high_level_api::keys::inner::IntegerCudaServerKey;
@@ -292,6 +294,21 @@ impl CompressedServerKey {
&self.integer_key.key,
&streams,
);
let cpk_key_switching_key_material = self
.integer_key
.cpk_key_switching_key_material
.as_ref()
.map(|cpk_ksk_material| {
let ksk_material = cpk_ksk_material.decompress();
let d_ksk = CudaLweKeyswitchKey::from_lwe_keyswitch_key(
&ksk_material.material.key_switching_key,
&streams,
);
CudaKeySwitchingKeyMaterial {
lwe_keyswitch_key: d_ksk,
destination_key: ksk_material.material.destination_key,
}
});
let compression_key: Option<
crate::integer::gpu::list_compression::server_keys::CudaCompressionKey,
> = self
@@ -328,10 +345,12 @@ impl CompressedServerKey {
CudaServerKey {
key: Arc::new(IntegerCudaServerKey {
key,
cpk_key_switching_key_material,
compression_key,
decompression_key,
}),
tag: self.tag.clone(),
streams,
}
}
}
@@ -355,6 +374,7 @@ impl Named for CompressedServerKey {
pub struct CudaServerKey {
pub(crate) key: Arc<IntegerCudaServerKey>,
pub(crate) tag: Tag,
pub(crate) streams: CudaStreams,
}
#[cfg(feature = "gpu")]
@@ -370,14 +390,6 @@ impl CudaServerKey {
pub fn gpu_indexes(&self) -> &[GpuIndex] {
&self.key.key.key_switching_key.d_vec.gpu_indexes
}
pub(crate) fn build_streams(&self) -> CudaStreams {
if self.gpu_indexes().len() == 1 {
CudaStreams::new_single_gpu(self.gpu_indexes()[0])
} else {
CudaStreams::new_multi_gpu()
}
}
}
#[cfg(feature = "gpu")]
@@ -460,6 +472,9 @@ mod hpu {
use crate::high_level_api::keys::inner::IntegerServerKeyConformanceParams;
#[cfg(feature = "gpu")]
use crate::integer::gpu::key_switching_key::CudaKeySwitchingKeyMaterial;
impl ParameterSetConformant for ServerKey {
type ParameterSet = IntegerServerKeyConformanceParams;

View File

@@ -26,6 +26,7 @@ pub use crate::high_level_api::gpu_utils::*;
pub use crate::high_level_api::strings::traits::*;
#[cfg(feature = "gpu")]
pub use crate::high_level_api::traits::{
- AddSizeOnGpu, BitAndSizeOnGpu, BitNotSizeOnGpu, BitOrSizeOnGpu, BitXorSizeOnGpu, SizeOnGpu,
- SubSizeOnGpu,
+ AddSizeOnGpu, BitAndSizeOnGpu, BitNotSizeOnGpu, BitOrSizeOnGpu, BitXorSizeOnGpu,
+ FheMaxSizeOnGpu, FheMinSizeOnGpu, FheOrdSizeOnGpu, RotateLeftSizeOnGpu, RotateRightSizeOnGpu,
+ ShlSizeOnGpu, ShrSizeOnGpu, SizeOnGpu, SubSizeOnGpu,
};

View File

@@ -1,9 +1,11 @@
use rand::Rng;
use crate::core_crypto::gpu::get_number_of_gpus;
use crate::high_level_api::global_state::CustomMultiGpuIndexes;
use crate::prelude::*;
use crate::{
- set_server_key, ClientKey, CompressedServerKey, ConfigBuilder, Device, FheUint32, GpuIndex,
+ set_server_key, unset_server_key, ClientKey, CompressedServerKey, ConfigBuilder, Device,
+ FheUint32, GpuIndex,
};
#[test]
@@ -117,3 +119,74 @@ fn test_gpu_selection_2() {
assert_eq!(c.gpu_indexes(), &[first_gpu]);
assert_eq!(decrypted, clear_a.wrapping_add(clear_b));
}
#[test]
fn test_specific_gpu_selection() {
let config = ConfigBuilder::default().build();
let keys = ClientKey::generate(config);
let compressed_server_keys = CompressedServerKey::new(&keys);
let mut rng = rand::thread_rng();
let total_gpus = get_number_of_gpus() as usize;
// Iterate over all 2^total_gpus - 1 non-empty subsets of the available GPUs
for num_gpus_to_use in 1..(1 << total_gpus) {
let mut selected_indices = Vec::new();
for j in 0..total_gpus {
if (num_gpus_to_use & (1 << j)) != 0 {
selected_indices.push(j);
}
}
// Convert the selected indices to GpuIndex objects
let gpus_to_be_used = CustomMultiGpuIndexes::new(
selected_indices
.iter()
.map(|idx| GpuIndex::new(*idx as u32))
.collect(),
);
let cuda_key = compressed_server_keys.decompress_to_specific_gpu(gpus_to_be_used);
let first_gpu = GpuIndex::new(selected_indices[0] as u32);
let clear_a: u32 = rng.gen();
let clear_b: u32 = rng.gen();
let mut a = FheUint32::try_encrypt(clear_a, &keys).unwrap();
let mut b = FheUint32::try_encrypt(clear_b, &keys).unwrap();
assert_eq!(a.current_device(), Device::Cpu);
assert_eq!(b.current_device(), Device::Cpu);
assert_eq!(a.gpu_indexes(), &[]);
assert_eq!(b.gpu_indexes(), &[]);
set_server_key(cuda_key);
let c = &a + &b;
let decrypted: u32 = c.decrypt(&keys);
assert_eq!(c.current_device(), Device::CudaGpu);
assert_eq!(c.gpu_indexes(), &[first_gpu]);
assert_eq!(decrypted, clear_a.wrapping_add(clear_b));
// Check the explicit move, but first make sure the inputs are still on the Cpu
assert_eq!(a.current_device(), Device::Cpu);
assert_eq!(b.current_device(), Device::Cpu);
assert_eq!(a.gpu_indexes(), &[]);
assert_eq!(b.gpu_indexes(), &[]);
a.move_to_current_device();
b.move_to_current_device();
assert_eq!(a.current_device(), Device::CudaGpu);
assert_eq!(b.current_device(), Device::CudaGpu);
assert_eq!(a.gpu_indexes(), &[first_gpu]);
assert_eq!(b.gpu_indexes(), &[first_gpu]);
let c = &a + &b;
let decrypted: u32 = c.decrypt(&keys);
assert_eq!(c.current_device(), Device::CudaGpu);
assert_eq!(c.gpu_indexes(), &[first_gpu]);
assert_eq!(decrypted, clear_a.wrapping_add(clear_b));
unset_server_key();
}
}
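
The loop above drives every pass from a bitmask: bit `j` of the mask set means GPU `j` joins the subset, so masks `1..2^total_gpus` cover each non-empty GPU combination exactly once. A standalone sketch of that enumeration in plain Rust (hypothetical helper, not part of the changeset):

```rust
// Treat each mask bit as "device j selected"; masks 1..2^n yield every
// non-empty subset of n device indices exactly once.
fn non_empty_subsets(n: usize) -> Vec<Vec<usize>> {
    (1usize..(1usize << n))
        .map(|mask| (0..n).filter(|j| (mask & (1 << j)) != 0).collect())
        .collect()
}

fn main() {
    // With 2 GPUs the test body runs three times: [0], [1] and [0, 1].
    assert_eq!(non_empty_subsets(2), vec![vec![0], vec![1], vec![0, 1]]);
}
```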

View File

@@ -143,6 +143,115 @@ fn test_tag_propagation_zk_pok() {
}
}
#[test]
#[cfg(feature = "zk-pok")]
#[cfg(feature = "gpu")]
fn test_tag_propagation_zk_pok_gpu() {
use crate::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128;
let config =
ConfigBuilder::with_custom_parameters(PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128)
.use_dedicated_compact_public_key_parameters((
PARAM_PKE_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128,
PARAM_KEYSWITCH_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128,
))
.build();
let crs = crate::zk::CompactPkeCrs::from_config(config, (2 * 32) + (2 * 64) + 2).unwrap();
let metadata = [b'h', b'l', b'a', b'p', b'i'];
let mut cks = ClientKey::generate(config);
let tag_value = random();
cks.tag_mut().set_u64(tag_value);
let cks = serialize_then_deserialize(&cks);
assert_eq!(cks.tag().as_u64(), tag_value);
let compressed_server_key = CompressedServerKey::new(&cks);
let gpu_sks = compressed_server_key.decompress_to_gpu();
assert_eq!(gpu_sks.tag(), cks.tag());
set_server_key(gpu_sks);
let cpk = CompactPublicKey::new(&cks);
assert_eq!(cpk.tag(), cks.tag());
let mut builder = CompactCiphertextList::builder(&cpk);
let list_packed = builder
.push(32u32)
.push(1u32)
.push(-1i64)
.push(i64::MIN)
.push(false)
.push(true)
.build_with_proof_packed(&crs, &metadata, crate::zk::ZkComputeLoad::Proof)
.unwrap();
let expander = list_packed
.verify_and_expand(&crs, &cpk, &metadata)
.unwrap();
{
let au32: FheUint32 = expander.get(0).unwrap().unwrap();
let bu32: FheUint32 = expander.get(1).unwrap().unwrap();
assert_eq!(au32.tag(), cks.tag());
assert_eq!(bu32.tag(), cks.tag());
let cu32 = au32 + bu32;
assert_eq!(cu32.tag(), cks.tag());
}
{
let ai64: FheInt64 = expander.get(2).unwrap().unwrap();
let bi64: FheInt64 = expander.get(3).unwrap().unwrap();
assert_eq!(ai64.tag(), cks.tag());
assert_eq!(bi64.tag(), cks.tag());
let ci64 = ai64 + bi64;
assert_eq!(ci64.tag(), cks.tag());
}
{
let abool: FheBool = expander.get(4).unwrap().unwrap();
let bbool: FheBool = expander.get(5).unwrap().unwrap();
assert_eq!(abool.tag(), cks.tag());
assert_eq!(bbool.tag(), cks.tag());
let cbool = abool & bbool;
assert_eq!(cbool.tag(), cks.tag());
}
let unverified_expander = list_packed.expand_without_verification().unwrap();
{
let au32: FheUint32 = unverified_expander.get(0).unwrap().unwrap();
let bu32: FheUint32 = unverified_expander.get(1).unwrap().unwrap();
assert_eq!(au32.tag(), cks.tag());
assert_eq!(bu32.tag(), cks.tag());
let cu32 = au32 + bu32;
assert_eq!(cu32.tag(), cks.tag());
}
{
let ai64: FheInt64 = unverified_expander.get(2).unwrap().unwrap();
let bi64: FheInt64 = unverified_expander.get(3).unwrap().unwrap();
assert_eq!(ai64.tag(), cks.tag());
assert_eq!(bi64.tag(), cks.tag());
let ci64 = ai64 + bi64;
assert_eq!(ci64.tag(), cks.tag());
}
{
let abool: FheBool = unverified_expander.get(4).unwrap().unwrap();
let bbool: FheBool = unverified_expander.get(5).unwrap().unwrap();
assert_eq!(abool.tag(), cks.tag());
assert_eq!(bbool.tag(), cks.tag());
let cbool = abool & bbool;
assert_eq!(cbool.tag(), cks.tag());
}
}
#[test]
#[cfg(feature = "gpu")]
fn test_tag_propagation_gpu() {

View File

@@ -265,3 +265,40 @@ pub trait BitXorSizeOnGpu<Rhs = Self> {
pub trait BitNotSizeOnGpu {
fn get_bitnot_size_on_gpu(&self) -> u64;
}
#[cfg(feature = "gpu")]
pub trait FheOrdSizeOnGpu<Rhs = Self> {
fn get_gt_size_on_gpu(&self, amount: Rhs) -> u64;
fn get_lt_size_on_gpu(&self, amount: Rhs) -> u64;
fn get_ge_size_on_gpu(&self, amount: Rhs) -> u64;
fn get_le_size_on_gpu(&self, amount: Rhs) -> u64;
}
#[cfg(feature = "gpu")]
pub trait FheMinSizeOnGpu<Rhs = Self> {
fn get_min_size_on_gpu(&self, other: Rhs) -> u64;
}
#[cfg(feature = "gpu")]
pub trait FheMaxSizeOnGpu<Rhs = Self> {
fn get_max_size_on_gpu(&self, other: Rhs) -> u64;
}
#[cfg(feature = "gpu")]
pub trait ShlSizeOnGpu<Rhs = Self> {
fn get_left_shift_size_on_gpu(&self, other: Rhs) -> u64;
}
#[cfg(feature = "gpu")]
pub trait ShrSizeOnGpu<Rhs = Self> {
fn get_right_shift_size_on_gpu(&self, other: Rhs) -> u64;
}
#[cfg(feature = "gpu")]
pub trait RotateLeftSizeOnGpu<Rhs = Self> {
fn get_rotate_left_size_on_gpu(&self, other: Rhs) -> u64;
}
#[cfg(feature = "gpu")]
pub trait RotateRightSizeOnGpu<Rhs = Self> {
fn get_rotate_right_size_on_gpu(&self, other: Rhs) -> u64;
}
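
These traits follow the existing `AddSizeOnGpu` pattern: each query reports the temporary device memory, in bytes, that the matching operation would need. A minimal usage sketch against the HL API, assuming a `gpu`-enabled build (the method names come from this changeset; the exact `Rhs` impls are an assumption):

```rust
use tfhe::prelude::*;
use tfhe::{set_server_key, ClientKey, CompressedServerKey, ConfigBuilder, FheUint32};

fn main() {
    let config = ConfigBuilder::default().build();
    let keys = ClientKey::generate(config);
    set_server_key(CompressedServerKey::new(&keys).decompress_to_gpu());

    let a = FheUint32::encrypt(42u32, &keys);
    let b = FheUint32::encrypt(3u32, &keys);

    // New in this changeset: query the scratch-space cost before running ops.
    let shl_bytes = a.get_left_shift_size_on_gpu(&b);
    let gt_bytes = a.get_gt_size_on_gpu(&b);
    let min_bytes = a.get_min_size_on_gpu(&b);

    // The sizes can be fed to check_valid_cuda_malloc(size, GpuIndex::new(0)),
    // as the tests in this changeset do, before launching the operation.
    println!("shl: {shl_bytes} B, gt: {gt_bytes} B, min: {min_bytes} B");
}
```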

View File

@@ -310,7 +310,7 @@ pub struct CompactCiphertextListExpander {
}
impl CompactCiphertextListExpander {
- fn new(expanded_blocks: Vec<Ciphertext>, info: Vec<DataKind>) -> Self {
+ pub(crate) fn new(expanded_blocks: Vec<Ciphertext>, info: Vec<DataKind>) -> Self {
Self {
expanded_blocks,
info,

View File

@@ -143,7 +143,7 @@ impl ClientKey {
}
pub fn parameters(&self) -> crate::shortint::AtomicPatternParameters {
- self.key.parameters.ap_parameters().unwrap()
+ self.key.parameters().ap_parameters().unwrap()
}
#[cfg(test)]
@@ -333,7 +333,7 @@ impl ClientKey {
return T::ZERO;
}
- let bits_in_block = self.key.parameters.message_modulus().0.ilog2();
+ let bits_in_block = self.key.parameters().message_modulus().0.ilog2();
let decrypted_block_iter = blocks.iter().map(|block| decrypt_block(&self.key, block));
BlockRecomposer::recompose_unsigned(decrypted_block_iter, bits_in_block)
}
@@ -417,7 +417,7 @@ impl ClientKey {
return T::ZERO;
}
- let bits_in_block = self.key.parameters.message_modulus().0.ilog2();
+ let bits_in_block = self.key.parameters().message_modulus().0.ilog2();
let decrypted_block_iter = ctxt
.blocks
.iter()

View File

@@ -7,7 +7,7 @@ pub(crate) trait KnowsMessageModulus {
impl KnowsMessageModulus for crate::shortint::ClientKey {
fn message_modulus(&self) -> MessageModulus {
- self.parameters.message_modulus()
+ self.parameters().message_modulus()
}
}

View File

@@ -1,13 +1,16 @@
use crate::core_crypto::commons::traits::contiguous_entity_container::ContiguousEntityContainer;
use crate::core_crypto::gpu::lwe_ciphertext_list::CudaLweCiphertextList;
use crate::core_crypto::gpu::lwe_compact_ciphertext_list::CudaLweCompactCiphertextList;
use crate::core_crypto::gpu::CudaStreams;
- use crate::integer::ciphertext::DataKind;
+ use crate::core_crypto::prelude::LweCiphertext;
+ use crate::integer::ciphertext::{CompactCiphertextListExpander, DataKind};
use crate::integer::gpu::ciphertext::compressed_ciphertext_list::CudaExpandable;
use crate::integer::gpu::ciphertext::info::{CudaBlockInfo, CudaRadixCiphertextInfo};
use crate::integer::gpu::ciphertext::CudaRadixCiphertext;
use crate::shortint::ciphertext::CompactCiphertextList;
use crate::shortint::parameters::{CompactCiphertextListExpansionKind, Degree};
- use crate::shortint::{CarryModulus, MessageModulus};
+ use crate::shortint::{CarryModulus, Ciphertext, MessageModulus};
use itertools::Itertools;
pub struct CudaCompactCiphertextList {
pub(crate) d_ct_list: CudaLweCompactCiphertextList<u64>,
@@ -17,12 +20,25 @@ pub struct CudaCompactCiphertextList {
pub(crate) expansion_kind: CompactCiphertextListExpansionKind,
}
impl CudaCompactCiphertextList {
pub fn duplicate(&self, streams: &CudaStreams) -> Self {
Self {
d_ct_list: self.d_ct_list.duplicate(streams),
degree: self.degree,
message_modulus: self.message_modulus,
carry_modulus: self.carry_modulus,
expansion_kind: self.expansion_kind,
}
}
}
#[derive(Clone)]
pub struct CudaCompactCiphertextListInfo {
pub info: CudaBlockInfo,
pub data_kind: DataKind,
}
#[derive(Clone)]
pub struct CudaCompactCiphertextListExpander {
pub(crate) expanded_blocks: CudaLweCiphertextList<u64>,
pub(crate) blocks_info: Vec<CudaCompactCiphertextListInfo>,
@@ -39,6 +55,29 @@ impl CudaCompactCiphertextListExpander {
}
}
pub fn len(&self) -> usize {
self.expanded_blocks.lwe_ciphertext_count().0
}
pub fn is_empty(&self) -> bool {
self.len() == 0
}
pub fn get_kind_of(&self, index: usize) -> Option<DataKind> {
let blocks = self.blocks_info.get(index)?;
Some(blocks.data_kind)
}
pub fn message_modulus(&self, index: usize) -> Option<MessageModulus> {
let blocks = self.blocks_info.get(index)?;
Some(blocks.info.message_modulus)
}
pub fn carry_modulus(&self, index: usize) -> Option<CarryModulus> {
let blocks = self.blocks_info.get(index)?;
Some(blocks.info.carry_modulus)
}
fn blocks_of(
&self,
index: usize,
@@ -92,6 +131,44 @@ impl CudaCompactCiphertextListExpander {
.map(|(blocks, kind)| T::from_expanded_blocks(blocks, kind))
.transpose()
}
pub fn to_compact_ciphertext_list_expander(
&self,
streams: &CudaStreams,
) -> CompactCiphertextListExpander {
let lwe_ciphertext_list = self.expanded_blocks.to_lwe_ciphertext_list(streams);
let ciphertext_modulus = self.expanded_blocks.ciphertext_modulus();
let expanded_blocks = lwe_ciphertext_list
.iter()
.zip(self.blocks_info.clone())
.map(|(ct, info)| {
let lwe = LweCiphertext::from_container(ct.as_ref().to_vec(), ciphertext_modulus);
Ciphertext::new(
lwe,
info.info.degree,
info.info.noise_level,
info.info.message_modulus,
info.info.carry_modulus,
info.info.atomic_pattern,
)
})
.collect_vec();
let info = self
.blocks_info
.iter()
.map(|ct_info| ct_info.data_kind)
.collect_vec();
CompactCiphertextListExpander::new(expanded_blocks, info)
}
pub fn duplicate(&self) -> Self {
Self {
expanded_blocks: self.expanded_blocks.clone(),
blocks_info: self.blocks_info.iter().cloned().collect_vec(),
}
}
}
impl CudaCompactCiphertextList {
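
`to_compact_ciphertext_list_expander` is the bridge back to the CPU path after a GPU-side expand; each block is rebuilt as a shortint `Ciphertext` with its degree, noise level and moduli metadata preserved. A small in-crate sketch of the round trip (illustrative only, since the expander types are `pub(crate)`):

```rust
use crate::core_crypto::gpu::CudaStreams;
use crate::integer::ciphertext::CompactCiphertextListExpander;

// Illustrative helper in the same module (not part of the diff): the
// device-to-host copy happens inside, ordered on `streams`.
fn expander_to_cpu(
    gpu_expander: &CudaCompactCiphertextListExpander,
    streams: &CudaStreams,
) -> CompactCiphertextListExpander {
    gpu_expander.to_compact_ciphertext_list_expander(streams)
}
```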

View File

@@ -10,6 +10,7 @@ use crate::integer::gpu::list_compression::server_keys::{
};
use crate::integer::gpu::server_key::CudaBootstrappingKey;
use crate::integer::RadixClientKey;
use crate::shortint::client_key::atomic_pattern::AtomicPatternClientKey;
use crate::shortint::engine::ShortintEngine;
use crate::shortint::EncryptionKeyChoice;
@@ -21,7 +22,11 @@ impl RadixClientKey {
) -> (CudaCompressionKey, CudaDecompressionKey) {
let private_compression_key = &private_compression_key.key;
- let params = &private_compression_key.params;
+ let compression_params = &private_compression_key.params;
let AtomicPatternClientKey::Standard(std_cks) = &self.as_ref().key.atomic_pattern else {
panic!("Only the standard atomic pattern is supported on GPU")
};
assert_eq!(
self.parameters().encryption_key_choice(),
@@ -32,11 +37,11 @@ impl RadixClientKey {
// Compression key
let packing_key_switching_key = ShortintEngine::with_thread_local_mut(|engine| {
allocate_and_generate_new_lwe_packing_keyswitch_key(
- &self.as_ref().key.large_lwe_secret_key(),
+ &std_cks.large_lwe_secret_key(),
&private_compression_key.post_packing_ks_key,
- params.packing_ks_base_log,
- params.packing_ks_level,
- params.packing_ks_key_noise_distribution,
+ compression_params.packing_ks_base_log,
+ compression_params.packing_ks_level,
+ compression_params.packing_ks_key_noise_distribution,
self.parameters().ciphertext_modulus(),
&mut engine.encryption_generator,
)
@@ -45,7 +50,7 @@ impl RadixClientKey {
let glwe_compression_key = CompressionKey {
key: crate::shortint::list_compression::CompressionKey {
packing_key_switching_key,
- lwe_per_glwe: params.lwe_per_glwe,
+ lwe_per_glwe: compression_params.lwe_per_glwe,
storage_log_modulus: private_compression_key.params.storage_log_modulus,
},
};
@@ -60,9 +65,9 @@ impl RadixClientKey {
self.parameters().polynomial_size(),
private_compression_key.params.br_base_log,
private_compression_key.params.br_level,
- params
+ compression_params
.packing_ks_glwe_dimension
- .to_equivalent_lwe_dimension(params.packing_ks_polynomial_size),
+ .to_equivalent_lwe_dimension(compression_params.packing_ks_polynomial_size),
self.parameters().ciphertext_modulus(),
);
@@ -71,7 +76,7 @@ impl RadixClientKey {
&private_compression_key
.post_packing_ks_key
.as_lwe_secret_key(),
- &self.as_ref().key.glwe_secret_key,
+ &std_cks.glwe_secret_key,
&mut bsk,
self.parameters().glwe_noise_distribution(),
&mut engine.encryption_generator,
@@ -84,7 +89,7 @@ impl RadixClientKey {
let cuda_decompression_key = CudaDecompressionKey {
blind_rotate_key,
- lwe_per_glwe: params.lwe_per_glwe,
+ lwe_per_glwe: compression_params.lwe_per_glwe,
glwe_dimension: self.parameters().glwe_dimension(),
polynomial_size: self.parameters().polynomial_size(),
message_modulus: self.parameters().message_modulus(),

View File

@@ -1,60 +1,47 @@
use crate::core_crypto::gpu::lwe_keyswitch_key::CudaLweKeyswitchKey;
use crate::core_crypto::gpu::CudaStreams;
- use crate::integer::client_key::secret_encryption_key::SecretEncryptionKeyView;
use crate::integer::gpu::CudaServerKey;
- use crate::integer::ClientKey;
- use crate::shortint::engine::ShortintEngine;
- use crate::shortint::parameters::ShortintKeySwitchingParameters;
+ use crate::integer::key_switching_key::KeySwitchingKey;
use crate::shortint::EncryptionKeyChoice;
+ #[derive(Clone)]
#[allow(dead_code)]
- pub struct CudaKeySwitchingKey<'keys> {
- pub(crate) key_switching_key: CudaLweKeyswitchKey<u64>,
- pub(crate) dest_server_key: &'keys CudaServerKey,
+ pub struct CudaKeySwitchingKeyMaterial {
+ pub(crate) lwe_keyswitch_key: CudaLweKeyswitchKey<u64>,
pub(crate) destination_key: EncryptionKeyChoice,
}
- impl<'keys> CudaKeySwitchingKey<'keys> {
- pub fn new<'input_key, InputEncryptionKey>(
- input_key_pair: (InputEncryptionKey, Option<&'keys CudaServerKey>),
- output_key_pair: (&'keys ClientKey, &'keys CudaServerKey),
- params: ShortintKeySwitchingParameters,
+ #[allow(dead_code)]
+ pub struct CudaKeySwitchingKey<'key> {
+ pub(crate) key_switching_key_material: &'key CudaKeySwitchingKeyMaterial,
+ pub(crate) dest_server_key: &'key CudaServerKey,
+ }
+ impl CudaKeySwitchingKeyMaterial {
+ pub fn from_key_switching_key(
+ key_switching_key: &KeySwitchingKey,
streams: &CudaStreams,
- ) -> Self
- where
- InputEncryptionKey: Into<SecretEncryptionKeyView<'input_key>>,
- {
- let input_secret_key: SecretEncryptionKeyView<'_> = input_key_pair.0.into();
- // Creation of the key switching key
- let key_switching_key = ShortintEngine::with_thread_local_mut(|engine| {
- engine.new_key_switching_key(&input_secret_key.key, output_key_pair.0.as_ref(), params)
- });
- let d_key_switching_key =
- CudaLweKeyswitchKey::from_lwe_keyswitch_key(&key_switching_key, streams);
- let full_message_modulus_input =
- input_secret_key.key.carry_modulus.0 * input_secret_key.key.message_modulus.0;
- let full_message_modulus_output = output_key_pair.0.key.parameters.carry_modulus().0
- * output_key_pair.0.key.parameters.message_modulus().0;
- assert!(
- full_message_modulus_input.is_power_of_two()
- && full_message_modulus_output.is_power_of_two(),
- "Cannot create casting key if the full messages moduli are not a power of 2"
+ ) -> Self {
+ let key_switching_key_material = &key_switching_key.key.key_switching_key_material;
+ let d_lwe_keyswich_key = CudaLweKeyswitchKey::from_lwe_keyswitch_key(
+ &key_switching_key_material.key_switching_key,
+ streams,
);
- if full_message_modulus_input > full_message_modulus_output {
- assert!(
- input_key_pair.1.is_some(),
- "Trying to build a integer::gpu::KeySwitchingKey \
- going from a large modulus {full_message_modulus_input} \
- to a smaller modulus {full_message_modulus_output} \
- without providing a source CudaServerKey, this is not supported"
- );
- }
- CudaKeySwitchingKey {
- key_switching_key: d_key_switching_key,
- dest_server_key: output_key_pair.1,
- destination_key: params.destination_key,
+ Self {
+ lwe_keyswitch_key: d_lwe_keyswich_key,
+ destination_key: key_switching_key_material.destination_key,
}
}
}
+ impl<'key> CudaKeySwitchingKey<'key> {
+ pub fn from_cuda_key_switching_key_material(
+ key_switching_key_material: &'key CudaKeySwitchingKeyMaterial,
+ dest_server_key: &'key CudaServerKey,
+ ) -> Self {
+ Self {
+ key_switching_key_material,
+ dest_server_key,
+ }
+ }
+ }
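
The removed `CudaKeySwitchingKey::new` generated the casting key from secret keys on the fly; the replacement splits the one-time upload (`CudaKeySwitchingKeyMaterial`) from a cheap borrowed per-use view, which is what lets `IntegerCudaServerKey` keep decompressed material around (see the `cpk_key_switching_key_material` field earlier in this changeset). A minimal in-crate sketch of the new two-step flow (the helper name is hypothetical):

```rust
use crate::core_crypto::gpu::CudaStreams;
use crate::integer::gpu::key_switching_key::{CudaKeySwitchingKey, CudaKeySwitchingKeyMaterial};
use crate::integer::gpu::CudaServerKey;
use crate::integer::key_switching_key::KeySwitchingKey;

// Hypothetical helper: upload the CPU casting key once into `slot`, then
// build the borrowed view tying it to a destination server key.
fn cuda_casting_key<'k>(
    cpu_ksk: &KeySwitchingKey,
    dest_sks: &'k CudaServerKey,
    slot: &'k mut Option<CudaKeySwitchingKeyMaterial>,
    streams: &CudaStreams,
) -> CudaKeySwitchingKey<'k> {
    let material =
        slot.insert(CudaKeySwitchingKeyMaterial::from_key_switching_key(cpu_ksk, streams));
    CudaKeySwitchingKey::from_cuda_key_switching_key_material(material, dest_sks)
}
```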

View File

@@ -1338,6 +1338,64 @@ pub unsafe fn unchecked_comparison_integer_radix_kb_async<T: UnsignedInteger, B:
update_noise_degree(radix_lwe_out, &cuda_ffi_radix_lwe_out);
}
#[allow(clippy::too_many_arguments)]
pub fn get_comparison_integer_radix_kb_size_on_gpu(
streams: &CudaStreams,
message_modulus: MessageModulus,
carry_modulus: CarryModulus,
glwe_dimension: GlweDimension,
polynomial_size: PolynomialSize,
big_lwe_dimension: LweDimension,
small_lwe_dimension: LweDimension,
ks_level: DecompositionLevelCount,
ks_base_log: DecompositionBaseLog,
pbs_level: DecompositionLevelCount,
pbs_base_log: DecompositionBaseLog,
num_blocks: u32,
op: ComparisonType,
is_signed: bool,
pbs_type: PBSType,
grouping_factor: LweBskGroupingFactor,
noise_reduction_key: Option<&CudaModulusSwitchNoiseReductionKey>,
) -> u64 {
let allocate_ms_noise_array = noise_reduction_key.is_some();
let mut mem_ptr: *mut i8 = std::ptr::null_mut();
let size_tracker = unsafe {
scratch_cuda_integer_radix_comparison_kb_64(
streams.ptr.as_ptr(),
streams.gpu_indexes_ptr(),
streams.len() as u32,
std::ptr::addr_of_mut!(mem_ptr),
glwe_dimension.0 as u32,
polynomial_size.0 as u32,
big_lwe_dimension.0 as u32,
small_lwe_dimension.0 as u32,
ks_level.0 as u32,
ks_base_log.0 as u32,
pbs_level.0 as u32,
pbs_base_log.0 as u32,
grouping_factor.0 as u32,
num_blocks,
message_modulus.0 as u32,
carry_modulus.0 as u32,
pbs_type as u32,
op as u32,
is_signed,
false,
allocate_ms_noise_array,
)
};
unsafe {
cleanup_cuda_integer_comparison(
streams.ptr.as_ptr(),
streams.gpu_indexes_ptr(),
streams.len() as u32,
std::ptr::addr_of_mut!(mem_ptr),
);
}
size_tracker
}
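
All the `get_*_size_on_gpu` entry points added in this changeset share this shape: call the scratch FFI with the allocation flag off (the trailing `false` above, on this reading) so it only accumulates a size tracker, then call the matching cleanup to drop the bookkeeping. A self-contained toy model of that control flow, with stubs standing in for the CUDA FFI:

```rust
// Stub stand-ins for the CUDA FFI; only the control flow is modeled.
fn scratch_stub(mem_ptr: &mut *mut i8, allocate_gpu_memory: bool) -> u64 {
    let size_tracker: u64 = 4096; // the backend would sum its planned allocations
    if allocate_gpu_memory {
        // a real scratch call would allocate device memory behind *mem_ptr
    }
    let _ = mem_ptr;
    size_tracker
}

fn cleanup_stub(mem_ptr: &mut *mut i8) {
    *mem_ptr = std::ptr::null_mut(); // releases bookkeeping; no data was allocated
}

fn main() {
    let mut mem_ptr: *mut i8 = std::ptr::null_mut();
    let size = scratch_stub(&mut mem_ptr, false); // dry run: nothing allocated
    cleanup_stub(&mut mem_ptr);
    assert_eq!(size, 4096);
}
```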
#[allow(clippy::too_many_arguments)]
/// # Safety
///
@@ -2950,6 +3008,399 @@ pub unsafe fn unchecked_rotate_left_integer_radix_kb_assign_async<
update_noise_degree(radix_input, &cuda_ffi_radix_lwe_left);
}
#[allow(clippy::too_many_arguments)]
pub fn get_scalar_left_shift_integer_radix_kb_size_on_gpu(
streams: &CudaStreams,
message_modulus: MessageModulus,
carry_modulus: CarryModulus,
glwe_dimension: GlweDimension,
polynomial_size: PolynomialSize,
big_lwe_dimension: LweDimension,
small_lwe_dimension: LweDimension,
ks_level: DecompositionLevelCount,
ks_base_log: DecompositionBaseLog,
pbs_level: DecompositionLevelCount,
pbs_base_log: DecompositionBaseLog,
num_blocks: u32,
pbs_type: PBSType,
grouping_factor: LweBskGroupingFactor,
noise_reduction_key: Option<&CudaModulusSwitchNoiseReductionKey>,
) -> u64 {
let allocate_ms_noise_array = noise_reduction_key.is_some();
let mut mem_ptr: *mut i8 = std::ptr::null_mut();
let size_tracker = unsafe {
scratch_cuda_integer_radix_logical_scalar_shift_kb_64(
streams.ptr.as_ptr(),
streams.gpu_indexes_ptr(),
streams.len() as u32,
std::ptr::addr_of_mut!(mem_ptr),
glwe_dimension.0 as u32,
polynomial_size.0 as u32,
big_lwe_dimension.0 as u32,
small_lwe_dimension.0 as u32,
ks_level.0 as u32,
ks_base_log.0 as u32,
pbs_level.0 as u32,
pbs_base_log.0 as u32,
grouping_factor.0 as u32,
num_blocks,
message_modulus.0 as u32,
carry_modulus.0 as u32,
pbs_type as u32,
ShiftRotateType::LeftShift as u32,
false,
allocate_ms_noise_array,
)
};
unsafe {
cleanup_cuda_integer_radix_logical_scalar_shift(
streams.ptr.as_ptr(),
streams.gpu_indexes_ptr(),
streams.len() as u32,
std::ptr::addr_of_mut!(mem_ptr),
);
}
size_tracker
}
#[allow(clippy::too_many_arguments)]
pub fn get_scalar_logical_right_shift_integer_radix_kb_size_on_gpu(
streams: &CudaStreams,
message_modulus: MessageModulus,
carry_modulus: CarryModulus,
glwe_dimension: GlweDimension,
polynomial_size: PolynomialSize,
big_lwe_dimension: LweDimension,
small_lwe_dimension: LweDimension,
ks_level: DecompositionLevelCount,
ks_base_log: DecompositionBaseLog,
pbs_level: DecompositionLevelCount,
pbs_base_log: DecompositionBaseLog,
num_blocks: u32,
pbs_type: PBSType,
grouping_factor: LweBskGroupingFactor,
noise_reduction_key: Option<&CudaModulusSwitchNoiseReductionKey>,
) -> u64 {
let allocate_ms_noise_array = noise_reduction_key.is_some();
let mut mem_ptr: *mut i8 = std::ptr::null_mut();
let size_tracker = unsafe {
scratch_cuda_integer_radix_logical_scalar_shift_kb_64(
streams.ptr.as_ptr(),
streams.gpu_indexes_ptr(),
streams.len() as u32,
std::ptr::addr_of_mut!(mem_ptr),
glwe_dimension.0 as u32,
polynomial_size.0 as u32,
big_lwe_dimension.0 as u32,
small_lwe_dimension.0 as u32,
ks_level.0 as u32,
ks_base_log.0 as u32,
pbs_level.0 as u32,
pbs_base_log.0 as u32,
grouping_factor.0 as u32,
num_blocks,
message_modulus.0 as u32,
carry_modulus.0 as u32,
pbs_type as u32,
ShiftRotateType::RightShift as u32,
false,
allocate_ms_noise_array,
)
};
unsafe {
cleanup_cuda_integer_radix_logical_scalar_shift(
streams.ptr.as_ptr(),
streams.gpu_indexes_ptr(),
streams.len() as u32,
std::ptr::addr_of_mut!(mem_ptr),
);
}
size_tracker
}
#[allow(clippy::too_many_arguments)]
pub fn get_scalar_arithmetic_right_shift_integer_radix_kb_size_on_gpu(
streams: &CudaStreams,
message_modulus: MessageModulus,
carry_modulus: CarryModulus,
glwe_dimension: GlweDimension,
polynomial_size: PolynomialSize,
big_lwe_dimension: LweDimension,
small_lwe_dimension: LweDimension,
ks_level: DecompositionLevelCount,
ks_base_log: DecompositionBaseLog,
pbs_level: DecompositionLevelCount,
pbs_base_log: DecompositionBaseLog,
num_blocks: u32,
pbs_type: PBSType,
grouping_factor: LweBskGroupingFactor,
noise_reduction_key: Option<&CudaModulusSwitchNoiseReductionKey>,
) -> u64 {
let allocate_ms_noise_array = noise_reduction_key.is_some();
let mut mem_ptr: *mut i8 = std::ptr::null_mut();
let size_tracker = unsafe {
scratch_cuda_integer_radix_arithmetic_scalar_shift_kb_64(
streams.ptr.as_ptr(),
streams.gpu_indexes_ptr(),
streams.len() as u32,
std::ptr::addr_of_mut!(mem_ptr),
glwe_dimension.0 as u32,
polynomial_size.0 as u32,
big_lwe_dimension.0 as u32,
small_lwe_dimension.0 as u32,
ks_level.0 as u32,
ks_base_log.0 as u32,
pbs_level.0 as u32,
pbs_base_log.0 as u32,
grouping_factor.0 as u32,
num_blocks,
message_modulus.0 as u32,
carry_modulus.0 as u32,
pbs_type as u32,
ShiftRotateType::RightShift as u32,
false,
allocate_ms_noise_array,
)
};
unsafe {
cleanup_cuda_integer_radix_arithmetic_scalar_shift(
streams.ptr.as_ptr(),
streams.gpu_indexes_ptr(),
streams.len() as u32,
std::ptr::addr_of_mut!(mem_ptr),
);
}
size_tracker
}
#[allow(clippy::too_many_arguments)]
pub fn get_right_shift_integer_radix_kb_size_on_gpu(
streams: &CudaStreams,
message_modulus: MessageModulus,
carry_modulus: CarryModulus,
glwe_dimension: GlweDimension,
polynomial_size: PolynomialSize,
big_lwe_dimension: LweDimension,
small_lwe_dimension: LweDimension,
ks_level: DecompositionLevelCount,
ks_base_log: DecompositionBaseLog,
pbs_level: DecompositionLevelCount,
pbs_base_log: DecompositionBaseLog,
num_blocks: u32,
is_signed: bool,
pbs_type: PBSType,
grouping_factor: LweBskGroupingFactor,
noise_reduction_key: Option<&CudaModulusSwitchNoiseReductionKey>,
) -> u64 {
let allocate_ms_noise_array = noise_reduction_key.is_some();
let mut mem_ptr: *mut i8 = std::ptr::null_mut();
let size_tracker = unsafe {
scratch_cuda_integer_radix_shift_and_rotate_kb_64(
streams.ptr.as_ptr(),
streams.gpu_indexes_ptr(),
streams.len() as u32,
std::ptr::addr_of_mut!(mem_ptr),
glwe_dimension.0 as u32,
polynomial_size.0 as u32,
big_lwe_dimension.0 as u32,
small_lwe_dimension.0 as u32,
ks_level.0 as u32,
ks_base_log.0 as u32,
pbs_level.0 as u32,
pbs_base_log.0 as u32,
grouping_factor.0 as u32,
num_blocks,
message_modulus.0 as u32,
carry_modulus.0 as u32,
pbs_type as u32,
ShiftRotateType::RightShift as u32,
is_signed,
false,
allocate_ms_noise_array,
)
};
unsafe {
cleanup_cuda_integer_radix_shift_and_rotate(
streams.ptr.as_ptr(),
streams.gpu_indexes_ptr(),
streams.len() as u32,
std::ptr::addr_of_mut!(mem_ptr),
);
}
size_tracker
}
#[allow(clippy::too_many_arguments)]
pub fn get_left_shift_integer_radix_kb_size_on_gpu(
streams: &CudaStreams,
message_modulus: MessageModulus,
carry_modulus: CarryModulus,
glwe_dimension: GlweDimension,
polynomial_size: PolynomialSize,
big_lwe_dimension: LweDimension,
small_lwe_dimension: LweDimension,
ks_level: DecompositionLevelCount,
ks_base_log: DecompositionBaseLog,
pbs_level: DecompositionLevelCount,
pbs_base_log: DecompositionBaseLog,
num_blocks: u32,
is_signed: bool,
pbs_type: PBSType,
grouping_factor: LweBskGroupingFactor,
noise_reduction_key: Option<&CudaModulusSwitchNoiseReductionKey>,
) -> u64 {
let allocate_ms_noise_array = noise_reduction_key.is_some();
let mut mem_ptr: *mut i8 = std::ptr::null_mut();
let size_tracker = unsafe {
scratch_cuda_integer_radix_shift_and_rotate_kb_64(
streams.ptr.as_ptr(),
streams.gpu_indexes_ptr(),
streams.len() as u32,
std::ptr::addr_of_mut!(mem_ptr),
glwe_dimension.0 as u32,
polynomial_size.0 as u32,
big_lwe_dimension.0 as u32,
small_lwe_dimension.0 as u32,
ks_level.0 as u32,
ks_base_log.0 as u32,
pbs_level.0 as u32,
pbs_base_log.0 as u32,
grouping_factor.0 as u32,
num_blocks,
message_modulus.0 as u32,
carry_modulus.0 as u32,
pbs_type as u32,
ShiftRotateType::LeftShift as u32,
is_signed,
false,
allocate_ms_noise_array,
)
};
unsafe {
cleanup_cuda_integer_radix_shift_and_rotate(
streams.ptr.as_ptr(),
streams.gpu_indexes_ptr(),
streams.len() as u32,
std::ptr::addr_of_mut!(mem_ptr),
);
}
size_tracker
}
#[allow(clippy::too_many_arguments)]
pub fn get_rotate_right_integer_radix_kb_size_on_gpu(
streams: &CudaStreams,
message_modulus: MessageModulus,
carry_modulus: CarryModulus,
glwe_dimension: GlweDimension,
polynomial_size: PolynomialSize,
big_lwe_dimension: LweDimension,
small_lwe_dimension: LweDimension,
ks_level: DecompositionLevelCount,
ks_base_log: DecompositionBaseLog,
pbs_level: DecompositionLevelCount,
pbs_base_log: DecompositionBaseLog,
num_blocks: u32,
is_signed: bool,
pbs_type: PBSType,
grouping_factor: LweBskGroupingFactor,
noise_reduction_key: Option<&CudaModulusSwitchNoiseReductionKey>,
) -> u64 {
let allocate_ms_noise_array = noise_reduction_key.is_some();
let mut mem_ptr: *mut i8 = std::ptr::null_mut();
let size_tracker = unsafe {
scratch_cuda_integer_radix_shift_and_rotate_kb_64(
streams.ptr.as_ptr(),
streams.gpu_indexes_ptr(),
streams.len() as u32,
std::ptr::addr_of_mut!(mem_ptr),
glwe_dimension.0 as u32,
polynomial_size.0 as u32,
big_lwe_dimension.0 as u32,
small_lwe_dimension.0 as u32,
ks_level.0 as u32,
ks_base_log.0 as u32,
pbs_level.0 as u32,
pbs_base_log.0 as u32,
grouping_factor.0 as u32,
num_blocks,
message_modulus.0 as u32,
carry_modulus.0 as u32,
pbs_type as u32,
ShiftRotateType::RightRotate as u32,
is_signed,
false,
allocate_ms_noise_array,
)
};
unsafe {
cleanup_cuda_integer_radix_shift_and_rotate(
streams.ptr.as_ptr(),
streams.gpu_indexes_ptr(),
streams.len() as u32,
std::ptr::addr_of_mut!(mem_ptr),
);
}
size_tracker
}
#[allow(clippy::too_many_arguments)]
pub fn get_rotate_left_integer_radix_kb_size_on_gpu(
streams: &CudaStreams,
message_modulus: MessageModulus,
carry_modulus: CarryModulus,
glwe_dimension: GlweDimension,
polynomial_size: PolynomialSize,
big_lwe_dimension: LweDimension,
small_lwe_dimension: LweDimension,
ks_level: DecompositionLevelCount,
ks_base_log: DecompositionBaseLog,
pbs_level: DecompositionLevelCount,
pbs_base_log: DecompositionBaseLog,
num_blocks: u32,
is_signed: bool,
pbs_type: PBSType,
grouping_factor: LweBskGroupingFactor,
noise_reduction_key: Option<&CudaModulusSwitchNoiseReductionKey>,
) -> u64 {
let allocate_ms_noise_array = noise_reduction_key.is_some();
let mut mem_ptr: *mut i8 = std::ptr::null_mut();
let size_tracker = unsafe {
scratch_cuda_integer_radix_shift_and_rotate_kb_64(
streams.ptr.as_ptr(),
streams.gpu_indexes_ptr(),
streams.len() as u32,
std::ptr::addr_of_mut!(mem_ptr),
glwe_dimension.0 as u32,
polynomial_size.0 as u32,
big_lwe_dimension.0 as u32,
small_lwe_dimension.0 as u32,
ks_level.0 as u32,
ks_base_log.0 as u32,
pbs_level.0 as u32,
pbs_base_log.0 as u32,
grouping_factor.0 as u32,
num_blocks,
message_modulus.0 as u32,
carry_modulus.0 as u32,
pbs_type as u32,
ShiftRotateType::LeftRotate as u32,
is_signed,
false,
allocate_ms_noise_array,
)
};
unsafe {
cleanup_cuda_integer_radix_shift_and_rotate(
streams.ptr.as_ptr(),
streams.gpu_indexes_ptr(),
streams.len() as u32,
std::ptr::addr_of_mut!(mem_ptr),
);
}
size_tracker
}
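
For orientation across the near-identical wrappers above and below: logical scalar shifts size through `scratch_cuda_integer_radix_logical_scalar_shift_kb_64`, the arithmetic scalar right shift through `scratch_cuda_integer_radix_arithmetic_scalar_shift_kb_64`, ciphertext-driven shifts and rotates through `scratch_cuda_integer_radix_shift_and_rotate_kb_64`, and scalar rotates through `scratch_cuda_integer_radix_scalar_rotate_kb_64`; within each family the variants differ only in the `ShiftRotateType` discriminant and, where applicable, the `is_signed` flag.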
#[allow(clippy::too_many_arguments)]
/// # Safety
///
@@ -3381,6 +3832,116 @@ pub unsafe fn unchecked_scalar_rotate_right_integer_radix_kb_assign_async<
update_noise_degree(radix_input, &cuda_ffi_radix_lwe_left);
}
#[allow(clippy::too_many_arguments)]
pub fn get_scalar_rotate_left_integer_radix_kb_size_on_gpu(
streams: &CudaStreams,
message_modulus: MessageModulus,
carry_modulus: CarryModulus,
glwe_dimension: GlweDimension,
polynomial_size: PolynomialSize,
big_lwe_dimension: LweDimension,
small_lwe_dimension: LweDimension,
ks_level: DecompositionLevelCount,
ks_base_log: DecompositionBaseLog,
pbs_level: DecompositionLevelCount,
pbs_base_log: DecompositionBaseLog,
num_blocks: u32,
pbs_type: PBSType,
grouping_factor: LweBskGroupingFactor,
noise_reduction_key: Option<&CudaModulusSwitchNoiseReductionKey>,
) -> u64 {
let allocate_ms_noise_array = noise_reduction_key.is_some();
let mut mem_ptr: *mut i8 = std::ptr::null_mut();
let size_tracker = unsafe {
scratch_cuda_integer_radix_scalar_rotate_kb_64(
streams.ptr.as_ptr(),
streams.gpu_indexes_ptr(),
streams.len() as u32,
std::ptr::addr_of_mut!(mem_ptr),
glwe_dimension.0 as u32,
polynomial_size.0 as u32,
big_lwe_dimension.0 as u32,
small_lwe_dimension.0 as u32,
ks_level.0 as u32,
ks_base_log.0 as u32,
pbs_level.0 as u32,
pbs_base_log.0 as u32,
grouping_factor.0 as u32,
num_blocks,
message_modulus.0 as u32,
carry_modulus.0 as u32,
pbs_type as u32,
ShiftRotateType::LeftShift as u32,
false,
allocate_ms_noise_array,
)
};
unsafe {
cleanup_cuda_integer_radix_scalar_rotate(
streams.ptr.as_ptr(),
streams.gpu_indexes_ptr(),
streams.len() as u32,
std::ptr::addr_of_mut!(mem_ptr),
);
}
size_tracker
}
#[allow(clippy::too_many_arguments)]
pub fn get_scalar_rotate_right_integer_radix_kb_size_on_gpu(
streams: &CudaStreams,
message_modulus: MessageModulus,
carry_modulus: CarryModulus,
glwe_dimension: GlweDimension,
polynomial_size: PolynomialSize,
big_lwe_dimension: LweDimension,
small_lwe_dimension: LweDimension,
ks_level: DecompositionLevelCount,
ks_base_log: DecompositionBaseLog,
pbs_level: DecompositionLevelCount,
pbs_base_log: DecompositionBaseLog,
num_blocks: u32,
pbs_type: PBSType,
grouping_factor: LweBskGroupingFactor,
noise_reduction_key: Option<&CudaModulusSwitchNoiseReductionKey>,
) -> u64 {
let allocate_ms_noise_array = noise_reduction_key.is_some();
let mut mem_ptr: *mut i8 = std::ptr::null_mut();
let size_tracker = unsafe {
scratch_cuda_integer_radix_scalar_rotate_kb_64(
streams.ptr.as_ptr(),
streams.gpu_indexes_ptr(),
streams.len() as u32,
std::ptr::addr_of_mut!(mem_ptr),
glwe_dimension.0 as u32,
polynomial_size.0 as u32,
big_lwe_dimension.0 as u32,
small_lwe_dimension.0 as u32,
ks_level.0 as u32,
ks_base_log.0 as u32,
pbs_level.0 as u32,
pbs_base_log.0 as u32,
grouping_factor.0 as u32,
num_blocks,
message_modulus.0 as u32,
carry_modulus.0 as u32,
pbs_type as u32,
ShiftRotateType::RightShift as u32,
false,
allocate_ms_noise_array,
)
};
unsafe {
cleanup_cuda_integer_radix_scalar_rotate(
streams.ptr.as_ptr(),
streams.gpu_indexes_ptr(),
streams.len() as u32,
std::ptr::addr_of_mut!(mem_ptr),
);
}
size_tracker
}
#[allow(clippy::too_many_arguments)]
/// # Safety
///

View File

@@ -12,11 +12,10 @@ use crate::integer::server_key::num_bits_to_represent_unsigned_value;
use crate::integer::ClientKey;
use crate::shortint::atomic_pattern::compressed::CompressedAtomicPatternServerKey;
use crate::shortint::ciphertext::{MaxDegree, MaxNoiseLevel};
+ use crate::shortint::client_key::atomic_pattern::AtomicPatternClientKey;
use crate::shortint::engine::ShortintEngine;
use crate::shortint::server_key::ModulusSwitchNoiseReductionKey;
- use crate::shortint::{
- AtomicPatternParameters, CarryModulus, CiphertextModulus, MessageModulus, PBSOrder,
- };
+ use crate::shortint::{CarryModulus, CiphertextModulus, MessageModulus, PBSOrder};
mod radix;
pub enum CudaBootstrappingKey {
@@ -72,8 +71,8 @@ impl CudaServerKey {
// It should remain just enough space to add a carry
let client_key = cks.as_ref();
let max_degree = MaxDegree::integer_radix_server_key(
- client_key.key.parameters.message_modulus(),
- client_key.key.parameters.carry_modulus(),
+ client_key.key.parameters().message_modulus(),
+ client_key.key.parameters().carry_modulus(),
);
Self::new_server_key_with_max_degree(client_key, max_degree, streams)
}
@@ -86,16 +85,18 @@ impl CudaServerKey {
let mut engine = ShortintEngine::new();
// Generate a regular keyset and convert to the GPU
- let AtomicPatternParameters::Standard(pbs_params_base) = &cks.parameters() else {
+ let AtomicPatternClientKey::Standard(std_cks) = &cks.key.atomic_pattern else {
panic!("Only the standard atomic pattern is supported on GPU")
};
+ let pbs_params_base = std_cks.parameters;
let d_bootstrapping_key = match pbs_params_base {
crate::shortint::PBSParameters::PBS(pbs_params) => {
let h_bootstrap_key: LweBootstrapKeyOwned<u64> =
par_allocate_and_generate_new_lwe_bootstrap_key(
- &cks.key.small_lwe_secret_key(),
- &cks.key.glwe_secret_key,
+ &std_cks.lwe_secret_key,
+ &std_cks.glwe_secret_key,
pbs_params.pbs_base_log,
pbs_params.pbs_level,
pbs_params.glwe_noise_distribution,
@@ -107,7 +108,7 @@ impl CudaServerKey {
.map(|modulus_switch_noise_reduction_params| {
ModulusSwitchNoiseReductionKey::new(
modulus_switch_noise_reduction_params,
- &cks.key.small_lwe_secret_key(),
+ &std_cks.lwe_secret_key,
&mut engine,
pbs_params.ciphertext_modulus,
pbs_params.lwe_noise_distribution,
@@ -124,8 +125,8 @@ impl CudaServerKey {
crate::shortint::PBSParameters::MultiBitPBS(pbs_params) => {
let h_bootstrap_key: LweMultiBitBootstrapKeyOwned<u64> =
par_allocate_and_generate_new_lwe_multi_bit_bootstrap_key(
- &cks.key.small_lwe_secret_key(),
- &cks.key.glwe_secret_key,
+ &std_cks.lwe_secret_key,
+ &std_cks.glwe_secret_key,
pbs_params.pbs_base_log,
pbs_params.pbs_level,
pbs_params.grouping_factor,
@@ -145,12 +146,12 @@ impl CudaServerKey {
// Creation of the key switching key
let h_key_switching_key = allocate_and_generate_new_lwe_keyswitch_key(
- &cks.key.large_lwe_secret_key(),
- &cks.key.small_lwe_secret_key(),
- cks.parameters().ks_base_log(),
- cks.parameters().ks_level(),
- cks.parameters().lwe_noise_distribution(),
- cks.parameters().ciphertext_modulus(),
+ &std_cks.large_lwe_secret_key(),
+ &std_cks.small_lwe_secret_key(),
+ std_cks.parameters.ks_base_log(),
+ std_cks.parameters.ks_level(),
+ std_cks.parameters.lwe_noise_distribution(),
+ std_cks.parameters.ciphertext_modulus(),
&mut engine.encryption_generator,
);
@@ -158,7 +159,7 @@ impl CudaServerKey {
CudaLweKeyswitchKey::from_lwe_keyswitch_key(&h_key_switching_key, streams);
assert!(matches!(
- cks.parameters().encryption_key_choice().into(),
+ std_cks.parameters.encryption_key_choice().into(),
PBSOrder::KeyswitchBootstrap
));
@@ -166,12 +167,12 @@ impl CudaServerKey {
Self {
key_switching_key: d_key_switching_key,
bootstrapping_key: d_bootstrapping_key,
- message_modulus: cks.parameters().message_modulus(),
- carry_modulus: cks.parameters().carry_modulus(),
+ message_modulus: std_cks.parameters.message_modulus(),
+ carry_modulus: std_cks.parameters.carry_modulus(),
max_degree,
- max_noise_level: cks.parameters().max_noise_level(),
- ciphertext_modulus: cks.parameters().ciphertext_modulus(),
- pbs_order: cks.parameters().encryption_key_choice().into(),
+ max_noise_level: std_cks.parameters.max_noise_level(),
+ ciphertext_modulus: std_cks.parameters.ciphertext_modulus(),
+ pbs_order: std_cks.parameters.encryption_key_choice().into(),
}
}

View File

@@ -6,6 +6,7 @@ use crate::integer::gpu::ciphertext::info::CudaRadixCiphertextInfo;
use crate::integer::gpu::ciphertext::{CudaIntegerRadixCiphertext, CudaRadixCiphertext};
use crate::integer::gpu::server_key::CudaBootstrappingKey;
use crate::integer::gpu::{
get_comparison_integer_radix_kb_size_on_gpu, get_full_propagate_assign_size_on_gpu,
unchecked_comparison_integer_radix_kb_async, ComparisonType, CudaServerKey, PBSType,
};
use crate::shortint::ciphertext::Degree;
@@ -348,6 +349,120 @@ impl CudaServerKey {
result
}
pub(crate) fn get_comparison_size_on_gpu<T: CudaIntegerRadixCiphertext>(
&self,
ct_left: &T,
ct_right: &T,
op: ComparisonType,
streams: &CudaStreams,
) -> u64 {
assert_eq!(
ct_left.as_ref().d_blocks.lwe_dimension(),
ct_right.as_ref().d_blocks.lwe_dimension()
);
assert_eq!(
ct_left.as_ref().d_blocks.lwe_ciphertext_count(),
ct_right.as_ref().d_blocks.lwe_ciphertext_count()
);
let full_prop_mem = match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => get_full_propagate_assign_size_on_gpu(
streams,
d_bsk.input_lwe_dimension(),
d_bsk.glwe_dimension(),
d_bsk.polynomial_size(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
d_bsk.decomp_level_count(),
d_bsk.decomp_base_log(),
self.message_modulus,
self.carry_modulus,
PBSType::Classical,
LweBskGroupingFactor(0),
d_bsk.d_ms_noise_reduction_key.as_ref(),
),
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
get_full_propagate_assign_size_on_gpu(
streams,
d_multibit_bsk.input_lwe_dimension(),
d_multibit_bsk.glwe_dimension(),
d_multibit_bsk.polynomial_size(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count(),
d_multibit_bsk.decomp_base_log(),
self.message_modulus,
self.carry_modulus,
PBSType::MultiBit,
d_multibit_bsk.grouping_factor,
None,
)
}
};
let actual_full_prop_mem = match (
ct_left.block_carries_are_empty(),
ct_right.block_carries_are_empty(),
) {
(true, true) => 0,
(true, false) => self.get_ciphertext_size_on_gpu(ct_right) + full_prop_mem,
(false, true) => full_prop_mem,
(false, false) => self.get_ciphertext_size_on_gpu(ct_right) + full_prop_mem,
};
let lwe_ciphertext_count = ct_left.as_ref().d_blocks.lwe_ciphertext_count();
let comparison_mem = match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => get_comparison_integer_radix_kb_size_on_gpu(
streams,
self.message_modulus,
self.carry_modulus,
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
d_bsk.decomp_level_count,
d_bsk.decomp_base_log,
lwe_ciphertext_count.0 as u32,
op,
T::IS_SIGNED,
PBSType::Classical,
LweBskGroupingFactor(0),
d_bsk.d_ms_noise_reduction_key.as_ref(),
),
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
get_comparison_integer_radix_kb_size_on_gpu(
streams,
self.message_modulus,
self.carry_modulus,
d_multibit_bsk.glwe_dimension,
d_multibit_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count,
d_multibit_bsk.decomp_base_log,
lwe_ciphertext_count.0 as u32,
op,
T::IS_SIGNED,
PBSType::MultiBit,
d_multibit_bsk.grouping_factor,
None,
)
}
};
actual_full_prop_mem.max(comparison_mem)
}
/// # Safety
///
/// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
@@ -936,6 +1051,60 @@ impl CudaServerKey {
result
}
pub fn get_eq_size_on_gpu<T: CudaIntegerRadixCiphertext>(
&self,
ct_left: &T,
ct_right: &T,
streams: &CudaStreams,
) -> u64 {
self.get_comparison_size_on_gpu(ct_left, ct_right, ComparisonType::EQ, streams)
}
pub fn get_ne_size_on_gpu<T: CudaIntegerRadixCiphertext>(
&self,
ct_left: &T,
ct_right: &T,
streams: &CudaStreams,
) -> u64 {
self.get_comparison_size_on_gpu(ct_left, ct_right, ComparisonType::NE, streams)
}
pub fn get_gt_size_on_gpu<T: CudaIntegerRadixCiphertext>(
&self,
ct_left: &T,
ct_right: &T,
streams: &CudaStreams,
) -> u64 {
self.get_comparison_size_on_gpu(ct_left, ct_right, ComparisonType::GT, streams)
}
pub fn get_ge_size_on_gpu<T: CudaIntegerRadixCiphertext>(
&self,
ct_left: &T,
ct_right: &T,
streams: &CudaStreams,
) -> u64 {
self.get_comparison_size_on_gpu(ct_left, ct_right, ComparisonType::GE, streams)
}
pub fn get_lt_size_on_gpu<T: CudaIntegerRadixCiphertext>(
&self,
ct_left: &T,
ct_right: &T,
streams: &CudaStreams,
) -> u64 {
self.get_comparison_size_on_gpu(ct_left, ct_right, ComparisonType::LT, streams)
}
pub fn get_le_size_on_gpu<T: CudaIntegerRadixCiphertext>(
&self,
ct_left: &T,
ct_right: &T,
streams: &CudaStreams,
) -> u64 {
self.get_comparison_size_on_gpu(ct_left, ct_right, ComparisonType::LE, streams)
}
/// # Safety
///
/// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
@@ -1221,4 +1390,21 @@ impl CudaServerKey {
streams.synchronize();
result
}
pub fn get_max_size_on_gpu<T: CudaIntegerRadixCiphertext>(
&self,
ct_left: &T,
ct_right: &T,
streams: &CudaStreams,
) -> u64 {
self.get_comparison_size_on_gpu(ct_left, ct_right, ComparisonType::MAX, streams)
}
pub fn get_min_size_on_gpu<T: CudaIntegerRadixCiphertext>(
&self,
ct_left: &T,
ct_right: &T,
streams: &CudaStreams,
) -> u64 {
self.get_comparison_size_on_gpu(ct_left, ct_right, ComparisonType::MIN, streams)
}
}
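
Note how the sizes compose: the result is the `max` of the carry-propagation scratch (plus a temporary copy of the right operand when its carries need flushing) and the comparison scratch, not their sum, because the two phases run back to back and the first phase's memory is released before the second allocates. A hedged in-crate sketch of gating work on these queries (import paths are assumptions; `check_valid_cuda_malloc` is used the same way in the HL tests of this changeset):

```rust
use crate::core_crypto::gpu::{check_valid_cuda_malloc, CudaStreams};
use crate::integer::gpu::ciphertext::CudaUnsignedRadixCiphertext;
use crate::integer::gpu::CudaServerKey;
use crate::GpuIndex;

// Illustrative helper (not part of the diff): probe the worst case across a
// few of the new queries before dispatching any of them on this pair.
fn comparison_family_fits(
    sks: &CudaServerKey,
    a: &CudaUnsignedRadixCiphertext,
    b: &CudaUnsignedRadixCiphertext,
    streams: &CudaStreams,
) -> bool {
    let worst = sks
        .get_eq_size_on_gpu(a, b, streams)
        .max(sks.get_ge_size_on_gpu(a, b, streams))
        .max(sks.get_max_size_on_gpu(a, b, streams));
    check_valid_cuda_malloc(worst, GpuIndex::new(0))
}
```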

View File

@@ -568,6 +568,7 @@ pub(crate) mod test {
use crate::integer::gpu::server_key::CudaBootstrappingKey;
use crate::integer::gpu::{gen_keys_gpu, CudaServerKey};
use crate::integer::{ClientKey, RadixCiphertext};
use crate::shortint::client_key::atomic_pattern::AtomicPatternClientKey;
use crate::shortint::oprf::create_random_from_seed_modulus_switched;
use crate::shortint::parameters::PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128;
use rayon::prelude::*;
@@ -635,7 +636,11 @@ pub(crate) mod test {
sk.ciphertext_modulus,
);
- let sk = ck.key.small_lwe_secret_key();
+ let AtomicPatternClientKey::Standard(std_ck) = &ck.key.atomic_pattern else {
+ panic!("Only std AP is supported on GPU")
+ };
+ let sk = std_ck.small_lwe_secret_key();
let plain_prf_input = decrypt_lwe_ciphertext(&sk, &ct)
.0
.wrapping_add(1 << (64 - log_input_p - 1))

View File

@@ -3,6 +3,8 @@ use crate::core_crypto::prelude::LweBskGroupingFactor;
use crate::integer::gpu::ciphertext::{CudaIntegerRadixCiphertext, CudaUnsignedRadixCiphertext};
use crate::integer::gpu::server_key::CudaBootstrappingKey;
use crate::integer::gpu::{
get_full_propagate_assign_size_on_gpu, get_rotate_left_integer_radix_kb_size_on_gpu,
get_rotate_right_integer_radix_kb_size_on_gpu,
unchecked_rotate_left_integer_radix_kb_assign_async,
unchecked_rotate_right_integer_radix_kb_assign_async, CudaServerKey, PBSType,
};
@@ -556,4 +558,226 @@ impl CudaServerKey {
unsafe { self.rotate_left_assign_async(ct, rotate, streams) };
streams.synchronize();
}
pub fn get_rotate_left_size_on_gpu<T: CudaIntegerRadixCiphertext>(
&self,
ct_left: &T,
ct_right: &CudaUnsignedRadixCiphertext,
streams: &CudaStreams,
) -> u64 {
assert_eq!(
ct_left.as_ref().d_blocks.lwe_dimension(),
ct_right.as_ref().d_blocks.lwe_dimension()
);
assert_eq!(
ct_left.as_ref().d_blocks.lwe_ciphertext_count(),
ct_right.as_ref().d_blocks.lwe_ciphertext_count()
);
let full_prop_mem = match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => get_full_propagate_assign_size_on_gpu(
streams,
d_bsk.input_lwe_dimension(),
d_bsk.glwe_dimension(),
d_bsk.polynomial_size(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
d_bsk.decomp_level_count(),
d_bsk.decomp_base_log(),
self.message_modulus,
self.carry_modulus,
PBSType::Classical,
LweBskGroupingFactor(0),
d_bsk.d_ms_noise_reduction_key.as_ref(),
),
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
get_full_propagate_assign_size_on_gpu(
streams,
d_multibit_bsk.input_lwe_dimension(),
d_multibit_bsk.glwe_dimension(),
d_multibit_bsk.polynomial_size(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count(),
d_multibit_bsk.decomp_base_log(),
self.message_modulus,
self.carry_modulus,
PBSType::MultiBit,
d_multibit_bsk.grouping_factor,
None,
)
}
};
let actual_full_prop_mem = match (
ct_left.block_carries_are_empty(),
ct_right.block_carries_are_empty(),
) {
(true, true) => 0,
(true, false) => self.get_ciphertext_size_on_gpu(ct_right) + full_prop_mem,
(false, true) => full_prop_mem,
(false, false) => self.get_ciphertext_size_on_gpu(ct_right) + full_prop_mem,
};
let lwe_ciphertext_count = ct_left.as_ref().d_blocks.lwe_ciphertext_count();
let rotate_mem = match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => get_rotate_left_integer_radix_kb_size_on_gpu(
streams,
self.message_modulus,
self.carry_modulus,
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
d_bsk.decomp_level_count,
d_bsk.decomp_base_log,
lwe_ciphertext_count.0 as u32,
T::IS_SIGNED,
PBSType::Classical,
LweBskGroupingFactor(0),
d_bsk.d_ms_noise_reduction_key.as_ref(),
),
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
get_rotate_left_integer_radix_kb_size_on_gpu(
streams,
self.message_modulus,
self.carry_modulus,
d_multibit_bsk.glwe_dimension,
d_multibit_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count,
d_multibit_bsk.decomp_base_log,
lwe_ciphertext_count.0 as u32,
T::IS_SIGNED,
PBSType::MultiBit,
d_multibit_bsk.grouping_factor,
None,
)
}
};
actual_full_prop_mem.max(rotate_mem)
}
pub fn get_rotate_right_size_on_gpu<T: CudaIntegerRadixCiphertext>(
&self,
ct_left: &T,
ct_right: &CudaUnsignedRadixCiphertext,
streams: &CudaStreams,
) -> u64 {
assert_eq!(
ct_left.as_ref().d_blocks.lwe_dimension(),
ct_right.as_ref().d_blocks.lwe_dimension()
);
assert_eq!(
ct_left.as_ref().d_blocks.lwe_ciphertext_count(),
ct_right.as_ref().d_blocks.lwe_ciphertext_count()
);
let full_prop_mem = match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => get_full_propagate_assign_size_on_gpu(
streams,
d_bsk.input_lwe_dimension(),
d_bsk.glwe_dimension(),
d_bsk.polynomial_size(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
d_bsk.decomp_level_count(),
d_bsk.decomp_base_log(),
self.message_modulus,
self.carry_modulus,
PBSType::Classical,
LweBskGroupingFactor(0),
d_bsk.d_ms_noise_reduction_key.as_ref(),
),
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
get_full_propagate_assign_size_on_gpu(
streams,
d_multibit_bsk.input_lwe_dimension(),
d_multibit_bsk.glwe_dimension(),
d_multibit_bsk.polynomial_size(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count(),
d_multibit_bsk.decomp_base_log(),
self.message_modulus,
self.carry_modulus,
PBSType::MultiBit,
d_multibit_bsk.grouping_factor,
None,
)
}
};
let actual_full_prop_mem = match (
ct_left.block_carries_are_empty(),
ct_right.block_carries_are_empty(),
) {
(true, true) => 0,
(true, false) => self.get_ciphertext_size_on_gpu(ct_right) + full_prop_mem,
(false, true) => full_prop_mem,
(false, false) => self.get_ciphertext_size_on_gpu(ct_right) + full_prop_mem,
};
let lwe_ciphertext_count = ct_left.as_ref().d_blocks.lwe_ciphertext_count();
let rotate_mem = match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => get_rotate_right_integer_radix_kb_size_on_gpu(
streams,
self.message_modulus,
self.carry_modulus,
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
d_bsk.decomp_level_count,
d_bsk.decomp_base_log,
lwe_ciphertext_count.0 as u32,
T::IS_SIGNED,
PBSType::Classical,
LweBskGroupingFactor(0),
d_bsk.d_ms_noise_reduction_key.as_ref(),
),
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
get_rotate_right_integer_radix_kb_size_on_gpu(
streams,
self.message_modulus,
self.carry_modulus,
d_multibit_bsk.glwe_dimension,
d_multibit_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count,
d_multibit_bsk.decomp_base_log,
lwe_ciphertext_count.0 as u32,
T::IS_SIGNED,
PBSType::MultiBit,
d_multibit_bsk.grouping_factor,
None,
)
}
};
actual_full_prop_mem.max(rotate_mem)
}
}

View File

@@ -1057,6 +1057,54 @@ impl CudaServerKey {
result
}
pub fn get_scalar_eq_size_on_gpu<T: CudaIntegerRadixCiphertext>(
&self,
ct_left: &T,
streams: &CudaStreams,
) -> u64 {
self.get_comparison_size_on_gpu(ct_left, ct_left, ComparisonType::EQ, streams)
}
pub fn get_scalar_ne_size_on_gpu<T: CudaIntegerRadixCiphertext>(
&self,
ct_left: &T,
streams: &CudaStreams,
) -> u64 {
self.get_comparison_size_on_gpu(ct_left, ct_left, ComparisonType::NE, streams)
}
pub fn get_scalar_gt_size_on_gpu<T: CudaIntegerRadixCiphertext>(
&self,
ct_left: &T,
streams: &CudaStreams,
) -> u64 {
self.get_comparison_size_on_gpu(ct_left, ct_left, ComparisonType::GT, streams)
}
pub fn get_scalar_ge_size_on_gpu<T: CudaIntegerRadixCiphertext>(
&self,
ct_left: &T,
streams: &CudaStreams,
) -> u64 {
self.get_comparison_size_on_gpu(ct_left, ct_left, ComparisonType::GE, streams)
}
pub fn get_scalar_lt_size_on_gpu<T: CudaIntegerRadixCiphertext>(
&self,
ct_left: &T,
streams: &CudaStreams,
) -> u64 {
self.get_comparison_size_on_gpu(ct_left, ct_left, ComparisonType::LT, streams)
}
pub fn get_scalar_le_size_on_gpu<T: CudaIntegerRadixCiphertext>(
&self,
ct_left: &T,
streams: &CudaStreams,
) -> u64 {
self.get_comparison_size_on_gpu(ct_left, ct_left, ComparisonType::LE, streams)
}
/// # Safety
///
/// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
@@ -1192,4 +1240,19 @@ impl CudaServerKey {
streams.synchronize();
result
}
pub fn get_scalar_max_size_on_gpu<T: CudaIntegerRadixCiphertext>(
&self,
ct_left: &T,
streams: &CudaStreams,
) -> u64 {
self.get_comparison_size_on_gpu(ct_left, ct_left, ComparisonType::MAX, streams)
}
pub fn get_scalar_min_size_on_gpu<T: CudaIntegerRadixCiphertext>(
&self,
ct_left: &T,
streams: &CudaStreams,
) -> u64 {
self.get_comparison_size_on_gpu(ct_left, ct_left, ComparisonType::MIN, streams)
}
}

View File

@@ -3,6 +3,8 @@ use crate::core_crypto::prelude::{CastFrom, LweBskGroupingFactor};
use crate::integer::gpu::ciphertext::CudaIntegerRadixCiphertext;
use crate::integer::gpu::server_key::CudaBootstrappingKey;
use crate::integer::gpu::{
get_full_propagate_assign_size_on_gpu, get_scalar_rotate_left_integer_radix_kb_size_on_gpu,
get_scalar_rotate_right_integer_radix_kb_size_on_gpu,
unchecked_scalar_rotate_left_integer_radix_kb_assign_async,
unchecked_scalar_rotate_right_integer_radix_kb_assign_async, CudaServerKey, PBSType,
};
@@ -275,4 +277,193 @@ impl CudaServerKey {
self.scalar_rotate_right_assign(&mut result, shift, stream);
result
}
pub fn get_scalar_rotate_left_size_on_gpu<T>(&self, ct: &T, streams: &CudaStreams) -> u64
where
T: CudaIntegerRadixCiphertext,
{
let lwe_ciphertext_count = ct.as_ref().d_blocks.lwe_ciphertext_count();
let full_prop_mem = if ct.block_carries_are_empty() {
0
} else {
match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => get_full_propagate_assign_size_on_gpu(
streams,
d_bsk.input_lwe_dimension(),
d_bsk.glwe_dimension(),
d_bsk.polynomial_size(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
d_bsk.decomp_level_count(),
d_bsk.decomp_base_log(),
self.message_modulus,
self.carry_modulus,
PBSType::Classical,
LweBskGroupingFactor(0),
d_bsk.d_ms_noise_reduction_key.as_ref(),
),
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
get_full_propagate_assign_size_on_gpu(
streams,
d_multibit_bsk.input_lwe_dimension(),
d_multibit_bsk.glwe_dimension(),
d_multibit_bsk.polynomial_size(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count(),
d_multibit_bsk.decomp_base_log(),
self.message_modulus,
self.carry_modulus,
PBSType::MultiBit,
d_multibit_bsk.grouping_factor,
None,
)
}
}
};
let scalar_shift_mem = match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => {
get_scalar_rotate_left_integer_radix_kb_size_on_gpu(
streams,
self.message_modulus,
self.carry_modulus,
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
d_bsk.decomp_level_count,
d_bsk.decomp_base_log,
lwe_ciphertext_count.0 as u32,
PBSType::Classical,
LweBskGroupingFactor(0),
d_bsk.d_ms_noise_reduction_key.as_ref(),
)
}
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
get_scalar_rotate_left_integer_radix_kb_size_on_gpu(
streams,
self.message_modulus,
self.carry_modulus,
d_multibit_bsk.glwe_dimension,
d_multibit_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count,
d_multibit_bsk.decomp_base_log,
lwe_ciphertext_count.0 as u32,
PBSType::MultiBit,
d_multibit_bsk.grouping_factor,
None,
)
}
};
full_prop_mem.max(scalar_shift_mem)
}
pub fn get_scalar_rotate_right_size_on_gpu<T>(&self, ct: &T, streams: &CudaStreams) -> u64
where
T: CudaIntegerRadixCiphertext,
{
let lwe_ciphertext_count = ct.as_ref().d_blocks.lwe_ciphertext_count();
let full_prop_mem = if ct.block_carries_are_empty() {
0
} else {
match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => get_full_propagate_assign_size_on_gpu(
streams,
d_bsk.input_lwe_dimension(),
d_bsk.glwe_dimension(),
d_bsk.polynomial_size(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
d_bsk.decomp_level_count(),
d_bsk.decomp_base_log(),
self.message_modulus,
self.carry_modulus,
PBSType::Classical,
LweBskGroupingFactor(0),
d_bsk.d_ms_noise_reduction_key.as_ref(),
),
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
get_full_propagate_assign_size_on_gpu(
streams,
d_multibit_bsk.input_lwe_dimension(),
d_multibit_bsk.glwe_dimension(),
d_multibit_bsk.polynomial_size(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count(),
d_multibit_bsk.decomp_base_log(),
self.message_modulus,
self.carry_modulus,
PBSType::MultiBit,
d_multibit_bsk.grouping_factor,
None,
)
}
}
};
let scalar_shift_mem = match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => {
get_scalar_rotate_right_integer_radix_kb_size_on_gpu(
streams,
self.message_modulus,
self.carry_modulus,
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
d_bsk.decomp_level_count,
d_bsk.decomp_base_log,
lwe_ciphertext_count.0 as u32,
PBSType::Classical,
LweBskGroupingFactor(0),
d_bsk.d_ms_noise_reduction_key.as_ref(),
)
}
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
get_scalar_rotate_right_integer_radix_kb_size_on_gpu(
streams,
self.message_modulus,
self.carry_modulus,
d_multibit_bsk.glwe_dimension,
d_multibit_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count,
d_multibit_bsk.decomp_base_log,
lwe_ciphertext_count.0 as u32,
PBSType::MultiBit,
d_multibit_bsk.grouping_factor,
None,
)
}
};
full_prop_mem.max(scalar_shift_mem)
}
}
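The two trackers above share one pattern: the returned byte count is the peak of the carry-propagation cost (zero when block carries are already empty) and the rotate kernel's own scratch cost. A minimal usage sketch, assuming an initialized `sks: CudaServerKey`, `ct: CudaUnsignedRadixCiphertext` and `streams: CudaStreams`; the budget constant and the `scalar_rotate_left` call (the non-assign counterpart of the `scalar_rotate_right_assign` shown above) are illustrative:

// Sketch: pre-flight device-memory check before a scalar rotation.
// `sks`, `ct` and `streams` are assumed to already exist; the 2 GiB
// budget is illustrative, not an API constant.
const MEM_BUDGET_BYTES: u64 = 2 << 30;

let required = sks.get_scalar_rotate_left_size_on_gpu(&ct, &streams);
if required <= MEM_BUDGET_BYTES {
    let _rotated = sks.scalar_rotate_left(&ct, 3u32, &streams);
} else {
    eprintln!("scalar rotate left needs {required} bytes on the GPU");
}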

View File

@@ -3,6 +3,10 @@ use crate::core_crypto::prelude::{CastFrom, LweBskGroupingFactor};
use crate::integer::gpu::ciphertext::CudaIntegerRadixCiphertext;
use crate::integer::gpu::server_key::CudaBootstrappingKey;
use crate::integer::gpu::{
get_full_propagate_assign_size_on_gpu,
get_scalar_arithmetic_right_shift_integer_radix_kb_size_on_gpu,
get_scalar_left_shift_integer_radix_kb_size_on_gpu,
get_scalar_logical_right_shift_integer_radix_kb_size_on_gpu,
unchecked_scalar_arithmetic_right_shift_integer_radix_kb_assign_async,
unchecked_scalar_left_shift_integer_radix_kb_assign_async,
unchecked_scalar_logical_right_shift_integer_radix_kb_assign_async, CudaServerKey, PBSType,
@@ -647,4 +651,245 @@ impl CudaServerKey {
}
}
}
pub fn get_scalar_left_shift_size_on_gpu<T>(&self, ct: &T, streams: &CudaStreams) -> u64
where
T: CudaIntegerRadixCiphertext,
{
let lwe_ciphertext_count = ct.as_ref().d_blocks.lwe_ciphertext_count();
let full_prop_mem = if ct.block_carries_are_empty() {
0
} else {
match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => get_full_propagate_assign_size_on_gpu(
streams,
d_bsk.input_lwe_dimension(),
d_bsk.glwe_dimension(),
d_bsk.polynomial_size(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
d_bsk.decomp_level_count(),
d_bsk.decomp_base_log(),
self.message_modulus,
self.carry_modulus,
PBSType::Classical,
LweBskGroupingFactor(0),
d_bsk.d_ms_noise_reduction_key.as_ref(),
),
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
get_full_propagate_assign_size_on_gpu(
streams,
d_multibit_bsk.input_lwe_dimension(),
d_multibit_bsk.glwe_dimension(),
d_multibit_bsk.polynomial_size(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count(),
d_multibit_bsk.decomp_base_log(),
self.message_modulus,
self.carry_modulus,
PBSType::MultiBit,
d_multibit_bsk.grouping_factor,
None,
)
}
}
};
let scalar_shift_mem = match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => {
get_scalar_left_shift_integer_radix_kb_size_on_gpu(
streams,
self.message_modulus,
self.carry_modulus,
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
d_bsk.decomp_level_count,
d_bsk.decomp_base_log,
lwe_ciphertext_count.0 as u32,
PBSType::Classical,
LweBskGroupingFactor(0),
d_bsk.d_ms_noise_reduction_key.as_ref(),
)
}
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
get_scalar_left_shift_integer_radix_kb_size_on_gpu(
streams,
self.message_modulus,
self.carry_modulus,
d_multibit_bsk.glwe_dimension,
d_multibit_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count,
d_multibit_bsk.decomp_base_log,
lwe_ciphertext_count.0 as u32,
PBSType::MultiBit,
d_multibit_bsk.grouping_factor,
None,
)
}
};
full_prop_mem.max(scalar_shift_mem)
}
pub fn get_scalar_right_shift_size_on_gpu<T>(&self, ct: &T, streams: &CudaStreams) -> u64
where
T: CudaIntegerRadixCiphertext,
{
let lwe_ciphertext_count = ct.as_ref().d_blocks.lwe_ciphertext_count();
let full_prop_mem = if ct.block_carries_are_empty() {
0
} else {
match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => get_full_propagate_assign_size_on_gpu(
streams,
d_bsk.input_lwe_dimension(),
d_bsk.glwe_dimension(),
d_bsk.polynomial_size(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
d_bsk.decomp_level_count(),
d_bsk.decomp_base_log(),
self.message_modulus,
self.carry_modulus,
PBSType::Classical,
LweBskGroupingFactor(0),
d_bsk.d_ms_noise_reduction_key.as_ref(),
),
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
get_full_propagate_assign_size_on_gpu(
streams,
d_multibit_bsk.input_lwe_dimension(),
d_multibit_bsk.glwe_dimension(),
d_multibit_bsk.polynomial_size(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count(),
d_multibit_bsk.decomp_base_log(),
self.message_modulus,
self.carry_modulus,
PBSType::MultiBit,
d_multibit_bsk.grouping_factor,
None,
)
}
}
};
let scalar_shift_mem = if T::IS_SIGNED {
match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => {
get_scalar_arithmetic_right_shift_integer_radix_kb_size_on_gpu(
streams,
self.message_modulus,
self.carry_modulus,
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
d_bsk.decomp_level_count,
d_bsk.decomp_base_log,
lwe_ciphertext_count.0 as u32,
PBSType::Classical,
LweBskGroupingFactor(0),
d_bsk.d_ms_noise_reduction_key.as_ref(),
)
}
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
get_scalar_arithmetic_right_shift_integer_radix_kb_size_on_gpu(
streams,
self.message_modulus,
self.carry_modulus,
d_multibit_bsk.glwe_dimension,
d_multibit_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count,
d_multibit_bsk.decomp_base_log,
lwe_ciphertext_count.0 as u32,
PBSType::MultiBit,
d_multibit_bsk.grouping_factor,
None,
)
}
}
} else {
match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => {
get_scalar_logical_right_shift_integer_radix_kb_size_on_gpu(
streams,
self.message_modulus,
self.carry_modulus,
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
d_bsk.decomp_level_count,
d_bsk.decomp_base_log,
lwe_ciphertext_count.0 as u32,
PBSType::Classical,
LweBskGroupingFactor(0),
d_bsk.d_ms_noise_reduction_key.as_ref(),
)
}
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
get_scalar_logical_right_shift_integer_radix_kb_size_on_gpu(
streams,
self.message_modulus,
self.carry_modulus,
d_multibit_bsk.glwe_dimension,
d_multibit_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count,
d_multibit_bsk.decomp_base_log,
lwe_ciphertext_count.0 as u32,
PBSType::MultiBit,
d_multibit_bsk.grouping_factor,
None,
)
}
}
};
full_prop_mem.max(scalar_shift_mem)
}
}
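As with the rotate trackers, `get_scalar_right_shift_size_on_gpu` folds the signedness decision in for the caller: it prices the arithmetic right shift when `T::IS_SIGNED` and the logical one otherwise, so a single call site covers both. A hedged sketch, assuming `ct_u: CudaUnsignedRadixCiphertext` and `ct_s: CudaSignedRadixCiphertext` already exist:

// Sketch: one query, two costs, depending only on the ciphertext type.
let logical_cost = sks.get_scalar_right_shift_size_on_gpu(&ct_u, &streams);
let arithmetic_cost = sks.get_scalar_right_shift_size_on_gpu(&ct_s, &streams);
// Both estimates already include full carry propagation when the input
// blocks are not clean, since the function takes the max of the two
// contributions.
assert!(logical_cost > 0 && arithmetic_cost > 0);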

View File

@@ -3,6 +3,8 @@ use crate::core_crypto::prelude::LweBskGroupingFactor;
use crate::integer::gpu::ciphertext::{CudaIntegerRadixCiphertext, CudaUnsignedRadixCiphertext};
use crate::integer::gpu::server_key::CudaBootstrappingKey;
use crate::integer::gpu::{
get_full_propagate_assign_size_on_gpu, get_left_shift_integer_radix_kb_size_on_gpu,
get_right_shift_integer_radix_kb_size_on_gpu,
unchecked_left_shift_integer_radix_kb_assign_async,
unchecked_right_shift_integer_radix_kb_assign_async, CudaServerKey, PBSType,
};
@@ -551,4 +553,226 @@ impl CudaServerKey {
unsafe { self.left_shift_assign_async(ct, shift, streams) };
streams.synchronize();
}
pub fn get_left_shift_size_on_gpu<T: CudaIntegerRadixCiphertext>(
&self,
ct_left: &T,
ct_right: &CudaUnsignedRadixCiphertext,
streams: &CudaStreams,
) -> u64 {
assert_eq!(
ct_left.as_ref().d_blocks.lwe_dimension(),
ct_right.as_ref().d_blocks.lwe_dimension()
);
assert_eq!(
ct_left.as_ref().d_blocks.lwe_ciphertext_count(),
ct_right.as_ref().d_blocks.lwe_ciphertext_count()
);
let full_prop_mem = match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => get_full_propagate_assign_size_on_gpu(
streams,
d_bsk.input_lwe_dimension(),
d_bsk.glwe_dimension(),
d_bsk.polynomial_size(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
d_bsk.decomp_level_count(),
d_bsk.decomp_base_log(),
self.message_modulus,
self.carry_modulus,
PBSType::Classical,
LweBskGroupingFactor(0),
d_bsk.d_ms_noise_reduction_key.as_ref(),
),
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
get_full_propagate_assign_size_on_gpu(
streams,
d_multibit_bsk.input_lwe_dimension(),
d_multibit_bsk.glwe_dimension(),
d_multibit_bsk.polynomial_size(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count(),
d_multibit_bsk.decomp_base_log(),
self.message_modulus,
self.carry_modulus,
PBSType::MultiBit,
d_multibit_bsk.grouping_factor,
None,
)
}
};
let actual_full_prop_mem = match (
ct_left.block_carries_are_empty(),
ct_right.block_carries_are_empty(),
) {
(true, true) => 0,
(true, false) => self.get_ciphertext_size_on_gpu(ct_right) + full_prop_mem,
(false, true) => full_prop_mem,
(false, false) => self.get_ciphertext_size_on_gpu(ct_right) + full_prop_mem,
};
let lwe_ciphertext_count = ct_left.as_ref().d_blocks.lwe_ciphertext_count();
let shift_mem = match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => get_left_shift_integer_radix_kb_size_on_gpu(
streams,
self.message_modulus,
self.carry_modulus,
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
d_bsk.decomp_level_count,
d_bsk.decomp_base_log,
lwe_ciphertext_count.0 as u32,
T::IS_SIGNED,
PBSType::Classical,
LweBskGroupingFactor(0),
d_bsk.d_ms_noise_reduction_key.as_ref(),
),
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
get_left_shift_integer_radix_kb_size_on_gpu(
streams,
self.message_modulus,
self.carry_modulus,
d_multibit_bsk.glwe_dimension,
d_multibit_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count,
d_multibit_bsk.decomp_base_log,
lwe_ciphertext_count.0 as u32,
T::IS_SIGNED,
PBSType::MultiBit,
d_multibit_bsk.grouping_factor,
None,
)
}
};
actual_full_prop_mem.max(shift_mem)
}
pub fn get_right_shift_size_on_gpu<T: CudaIntegerRadixCiphertext>(
&self,
ct_left: &T,
ct_right: &CudaUnsignedRadixCiphertext,
streams: &CudaStreams,
) -> u64 {
assert_eq!(
ct_left.as_ref().d_blocks.lwe_dimension(),
ct_right.as_ref().d_blocks.lwe_dimension()
);
assert_eq!(
ct_left.as_ref().d_blocks.lwe_ciphertext_count(),
ct_right.as_ref().d_blocks.lwe_ciphertext_count()
);
let full_prop_mem = match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => get_full_propagate_assign_size_on_gpu(
streams,
d_bsk.input_lwe_dimension(),
d_bsk.glwe_dimension(),
d_bsk.polynomial_size(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
d_bsk.decomp_level_count(),
d_bsk.decomp_base_log(),
self.message_modulus,
self.carry_modulus,
PBSType::Classical,
LweBskGroupingFactor(0),
d_bsk.d_ms_noise_reduction_key.as_ref(),
),
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
get_full_propagate_assign_size_on_gpu(
streams,
d_multibit_bsk.input_lwe_dimension(),
d_multibit_bsk.glwe_dimension(),
d_multibit_bsk.polynomial_size(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count(),
d_multibit_bsk.decomp_base_log(),
self.message_modulus,
self.carry_modulus,
PBSType::MultiBit,
d_multibit_bsk.grouping_factor,
None,
)
}
};
let actual_full_prop_mem = match (
ct_left.block_carries_are_empty(),
ct_right.block_carries_are_empty(),
) {
(true, true) => 0,
(true, false) => self.get_ciphertext_size_on_gpu(ct_right) + full_prop_mem,
(false, true) => full_prop_mem,
(false, false) => self.get_ciphertext_size_on_gpu(ct_right) + full_prop_mem,
};
let lwe_ciphertext_count = ct_left.as_ref().d_blocks.lwe_ciphertext_count();
let shift_mem = match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => get_right_shift_integer_radix_kb_size_on_gpu(
streams,
self.message_modulus,
self.carry_modulus,
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
d_bsk.decomp_level_count,
d_bsk.decomp_base_log,
lwe_ciphertext_count.0 as u32,
T::IS_SIGNED,
PBSType::Classical,
LweBskGroupingFactor(0),
d_bsk.d_ms_noise_reduction_key.as_ref(),
),
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
get_right_shift_integer_radix_kb_size_on_gpu(
streams,
self.message_modulus,
self.carry_modulus,
d_multibit_bsk.glwe_dimension,
d_multibit_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count,
d_multibit_bsk.decomp_base_log,
lwe_ciphertext_count.0 as u32,
T::IS_SIGNED,
PBSType::MultiBit,
d_multibit_bsk.grouping_factor,
None,
)
}
};
actual_full_prop_mem.max(shift_mem)
}
}
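For shifts by an encrypted amount, the estimate also has to cover the second operand: when `ct_right` carries non-empty block carries, its propagation needs a temporary copy, hence the `get_ciphertext_size_on_gpu(ct_right) + full_prop_mem` arms above. A hedged sketch, with `amount` standing for a `CudaUnsignedRadixCiphertext` that satisfies the dimension and block-count asserts:

// Sketch: size, then run, a ciphertext-by-ciphertext left shift.
// `left_shift` is the high-level counterpart of the
// `left_shift_assign` used earlier in this file.
let required = sks.get_left_shift_size_on_gpu(&ct, &amount, &streams);
println!("left shift by encrypted amount needs {required} bytes");
let _shifted = sks.left_shift(&ct, &amount, &streams);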

View File

@@ -17,39 +17,39 @@ where
P: Into<TestParameters> + Clone,
{
// Binary Ops Executors
let add_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::add);
let sub_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::sub);
//let add_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::add);
//let sub_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::sub);
let bitwise_and_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::bitand);
let bitwise_or_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::bitor);
let bitwise_xor_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::bitxor);
let mul_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::mul);
let rotate_left_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::rotate_left);
let left_shift_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::left_shift);
let rotate_right_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::rotate_right);
let right_shift_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::right_shift);
let max_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::max);
let min_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::min);
//let rotate_left_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::rotate_left);
//let left_shift_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::left_shift);
//let rotate_right_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::rotate_right);
//let right_shift_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::right_shift);
//let max_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::max);
//let min_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::min);
// Binary Ops Clear functions
let clear_add = |x, y| x + y;
let clear_sub = |x, y| x - y;
//let clear_add = |x, y| x + y;
//let clear_sub = |x, y| x - y;
let clear_bitwise_and = |x, y| x & y;
let clear_bitwise_or = |x, y| x | y;
let clear_bitwise_xor = |x, y| x ^ y;
let clear_mul = |x, y| x * y;
// Warning this rotate definition only works with 64-bit ciphertexts
let clear_rotate_left = |x: u64, y: u64| x.rotate_left(y as u32);
let clear_left_shift = |x, y| x << y;
// Warning this rotate definition only works with 64-bit ciphertexts
let clear_rotate_right = |x: u64, y: u64| x.rotate_right(y as u32);
let clear_right_shift = |x, y| x >> y;
let clear_max = |x: u64, y: u64| max(x, y);
let clear_min = |x: u64, y: u64| min(x, y);
//let clear_rotate_left = |x: u64, y: u64| x.rotate_left(y as u32);
//let clear_left_shift = |x, y| x << y;
//// Warning this rotate definition only works with 64-bit ciphertexts
//let clear_rotate_right = |x: u64, y: u64| x.rotate_right(y as u32);
//let clear_right_shift = |x, y| x >> y;
//let clear_max = |x: u64, y: u64| max(x, y);
//let clear_min = |x: u64, y: u64| min(x, y);
#[allow(clippy::type_complexity)]
let mut binary_ops: Vec<(BinaryOpExecutor, &dyn Fn(u64, u64) -> u64, String)> = vec![
(Box::new(add_executor), &clear_add, "add".to_string()),
(Box::new(sub_executor), &clear_sub, "sub".to_string()),
//(Box::new(add_executor), &clear_add, "add".to_string()),
//(Box::new(sub_executor), &clear_sub, "sub".to_string()),
(
Box::new(bitwise_and_executor),
&clear_bitwise_and,
@@ -66,28 +66,28 @@ where
"bitxor".to_string(),
),
(Box::new(mul_executor), &clear_mul, "mul".to_string()),
(
Box::new(rotate_left_executor),
&clear_rotate_left,
"rotate left".to_string(),
),
(
Box::new(left_shift_executor),
&clear_left_shift,
"left shift".to_string(),
),
(
Box::new(rotate_right_executor),
&clear_rotate_right,
"rotate right".to_string(),
),
(
Box::new(right_shift_executor),
&clear_right_shift,
"right shift".to_string(),
),
(Box::new(max_executor), &clear_max, "max".to_string()),
(Box::new(min_executor), &clear_min, "min".to_string()),
//(
// Box::new(rotate_left_executor),
// &clear_rotate_left,
// "rotate left".to_string(),
//),
//(
// Box::new(left_shift_executor),
// &clear_left_shift,
// "left shift".to_string(),
//),
//(
// Box::new(rotate_right_executor),
// &clear_rotate_right,
// "rotate right".to_string(),
//),
//(
// Box::new(right_shift_executor),
// &clear_right_shift,
// "right shift".to_string(),
//),
//(Box::new(max_executor), &clear_max, "max".to_string()),
//(Box::new(min_executor), &clear_min, "min".to_string()),
];
// Unary Ops Executors
@@ -115,8 +115,8 @@ where
];
// Scalar binary Ops Executors
let scalar_add_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_add);
let scalar_sub_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_sub);
//let scalar_add_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_add);
//let scalar_sub_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_sub);
let scalar_bitwise_and_executor =
GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_bitand);
let scalar_bitwise_or_executor =
@@ -124,27 +124,27 @@ where
let scalar_bitwise_xor_executor =
GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_bitxor);
let scalar_mul_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_mul);
let scalar_rotate_left_executor =
GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_rotate_left);
let scalar_left_shift_executor =
GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_left_shift);
let scalar_rotate_right_executor =
GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_rotate_right);
let scalar_right_shift_executor =
GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_right_shift);
//let scalar_rotate_left_executor =
// GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_rotate_left);
//let scalar_left_shift_executor =
// GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_left_shift);
//let scalar_rotate_right_executor =
// GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_rotate_right);
//let scalar_right_shift_executor =
// GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_right_shift);
#[allow(clippy::type_complexity)]
let mut scalar_binary_ops: Vec<(ScalarBinaryOpExecutor, &dyn Fn(u64, u64) -> u64, String)> = vec![
(
Box::new(scalar_add_executor),
&clear_add,
"scalar add".to_string(),
),
(
Box::new(scalar_sub_executor),
&clear_sub,
"scalar sub".to_string(),
),
//(
// Box::new(scalar_add_executor),
// &clear_add,
// "scalar add".to_string(),
//),
//(
// Box::new(scalar_sub_executor),
// &clear_sub,
// "scalar sub".to_string(),
//),
(
Box::new(scalar_bitwise_and_executor),
&clear_bitwise_and,
@@ -165,26 +165,26 @@ where
&clear_mul,
"scalar mul".to_string(),
),
(
Box::new(scalar_rotate_left_executor),
&clear_rotate_left,
"scalar rotate left".to_string(),
),
(
Box::new(scalar_left_shift_executor),
&clear_left_shift,
"scalar left shift".to_string(),
),
(
Box::new(scalar_rotate_right_executor),
&clear_rotate_right,
"scalar rotate right".to_string(),
),
(
Box::new(scalar_right_shift_executor),
&clear_right_shift,
"scalar right shift".to_string(),
),
//(
// Box::new(scalar_rotate_left_executor),
// &clear_rotate_left,
// "scalar rotate left".to_string(),
//),
//(
// Box::new(scalar_left_shift_executor),
// &clear_left_shift,
// "scalar left shift".to_string(),
//),
//(
// Box::new(scalar_rotate_right_executor),
// &clear_rotate_right,
// "scalar rotate right".to_string(),
//),
//(
// Box::new(scalar_right_shift_executor),
// &clear_right_shift,
// "scalar right shift".to_string(),
//),
];
// Overflowing Ops Executors
@@ -249,37 +249,37 @@ where
// Comparison Ops Executors
let gt_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::gt);
let ge_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::ge);
let lt_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::lt);
let le_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::le);
let eq_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::eq);
let ne_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::ne);
//let ge_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::ge);
//let lt_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::lt);
//let le_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::le);
//let eq_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::eq);
//let ne_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::ne);
// Comparison Ops Clear functions
let clear_gt = |x: u64, y: u64| -> bool { x > y };
let clear_ge = |x: u64, y: u64| -> bool { x >= y };
let clear_lt = |x: u64, y: u64| -> bool { x < y };
let clear_le = |x: u64, y: u64| -> bool { x <= y };
let clear_eq = |x: u64, y: u64| -> bool { x == y };
let clear_ne = |x: u64, y: u64| -> bool { x != y };
//let clear_ge = |x: u64, y: u64| -> bool { x >= y };
//let clear_lt = |x: u64, y: u64| -> bool { x < y };
//let clear_le = |x: u64, y: u64| -> bool { x <= y };
//let clear_eq = |x: u64, y: u64| -> bool { x == y };
//let clear_ne = |x: u64, y: u64| -> bool { x != y };
#[allow(clippy::type_complexity)]
let mut comparison_ops: Vec<(ComparisonOpExecutor, &dyn Fn(u64, u64) -> bool, String)> = vec![
(Box::new(gt_executor), &clear_gt, "gt".to_string()),
(Box::new(ge_executor), &clear_ge, "ge".to_string()),
(Box::new(lt_executor), &clear_lt, "lt".to_string()),
(Box::new(le_executor), &clear_le, "le".to_string()),
(Box::new(eq_executor), &clear_eq, "eq".to_string()),
(Box::new(ne_executor), &clear_ne, "ne".to_string()),
//(Box::new(ge_executor), &clear_ge, "ge".to_string()),
//(Box::new(lt_executor), &clear_lt, "lt".to_string()),
//(Box::new(le_executor), &clear_le, "le".to_string()),
//(Box::new(eq_executor), &clear_eq, "eq".to_string()),
//(Box::new(ne_executor), &clear_ne, "ne".to_string()),
];
// Scalar Comparison Ops Executors
let scalar_gt_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_gt);
let scalar_ge_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_ge);
let scalar_lt_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_lt);
let scalar_le_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_le);
let scalar_eq_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_eq);
let scalar_ne_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_ne);
//let scalar_ge_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_ge);
//let scalar_lt_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_lt);
//let scalar_le_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_le);
//let scalar_eq_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_eq);
//let scalar_ne_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_ne);
#[allow(clippy::type_complexity)]
let mut scalar_comparison_ops: Vec<(
@@ -292,31 +292,31 @@ where
&clear_gt,
"scalar gt".to_string(),
),
(
Box::new(scalar_ge_executor),
&clear_ge,
"scalar ge".to_string(),
),
(
Box::new(scalar_lt_executor),
&clear_lt,
"scalar lt".to_string(),
),
(
Box::new(scalar_le_executor),
&clear_le,
"scalar le".to_string(),
),
(
Box::new(scalar_eq_executor),
&clear_eq,
"scalar eq".to_string(),
),
(
Box::new(scalar_ne_executor),
&clear_ne,
"scalar ne".to_string(),
),
//(
// Box::new(scalar_ge_executor),
// &clear_ge,
// "scalar ge".to_string(),
//),
//(
// Box::new(scalar_lt_executor),
// &clear_lt,
// "scalar lt".to_string(),
//),
//(
// Box::new(scalar_le_executor),
// &clear_le,
// "scalar le".to_string(),
//),
//(
// Box::new(scalar_eq_executor),
// &clear_eq,
// "scalar eq".to_string(),
//),
//(
// Box::new(scalar_ne_executor),
// &clear_ne,
// "scalar ne".to_string(),
//),
];
// Select Executor

View File

@@ -19,29 +19,29 @@ where
P: Into<TestParameters> + Clone,
{
// Binary Ops Executors
let add_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::add);
let sub_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::sub);
//let add_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::add);
//let sub_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::sub);
let bitwise_and_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::bitand);
let bitwise_or_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::bitor);
let bitwise_xor_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::bitxor);
let mul_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::mul);
let max_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::max);
let min_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::min);
//let max_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::max);
//let min_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::min);
// Binary Ops Clear functions
let clear_add = |x, y| x + y;
let clear_sub = |x, y| x - y;
//let clear_add = |x, y| x + y;
//let clear_sub = |x, y| x - y;
let clear_bitwise_and = |x, y| x & y;
let clear_bitwise_or = |x, y| x | y;
let clear_bitwise_xor = |x, y| x ^ y;
let clear_mul = |x, y| x * y;
let clear_max = |x: i64, y: i64| max(x, y);
let clear_min = |x: i64, y: i64| min(x, y);
//let clear_max = |x: i64, y: i64| max(x, y);
//let clear_min = |x: i64, y: i64| min(x, y);
#[allow(clippy::type_complexity)]
let mut binary_ops: Vec<(SignedBinaryOpExecutor, &dyn Fn(i64, i64) -> i64, String)> = vec![
(Box::new(add_executor), &clear_add, "add".to_string()),
(Box::new(sub_executor), &clear_sub, "sub".to_string()),
//(Box::new(add_executor), &clear_add, "add".to_string()),
//(Box::new(sub_executor), &clear_sub, "sub".to_string()),
(
Box::new(bitwise_and_executor),
&clear_bitwise_and,
@@ -58,14 +58,14 @@ where
"bitxor".to_string(),
),
(Box::new(mul_executor), &clear_mul, "mul".to_string()),
(Box::new(max_executor), &clear_max, "max".to_string()),
(Box::new(min_executor), &clear_min, "min".to_string()),
//(Box::new(max_executor), &clear_max, "max".to_string()),
//(Box::new(min_executor), &clear_min, "min".to_string()),
];
let rotate_left_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::rotate_left);
let left_shift_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::left_shift);
let rotate_right_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::rotate_right);
let right_shift_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::right_shift);
//let left_shift_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::left_shift);
//let rotate_right_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::rotate_right);
//let right_shift_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::right_shift);
// Warning this rotate definition only works with 64-bit ciphertexts
let clear_rotate_left = |x: i64, y: u64| x.rotate_left(y as u32);
let clear_left_shift = |x: i64, y: u64| x << y;
@@ -83,21 +83,21 @@ where
&clear_rotate_left,
"rotate left".to_string(),
),
(
Box::new(left_shift_executor),
&clear_left_shift,
"left shift".to_string(),
),
(
Box::new(rotate_right_executor),
&clear_rotate_right,
"rotate right".to_string(),
),
(
Box::new(right_shift_executor),
&clear_right_shift,
"right shift".to_string(),
),
//(
// Box::new(left_shift_executor),
// &clear_left_shift,
// "left shift".to_string(),
//),
//(
// Box::new(rotate_right_executor),
// &clear_rotate_right,
// "rotate right".to_string(),
//),
//(
// Box::new(right_shift_executor),
// &clear_right_shift,
// "right shift".to_string(),
//),
];
// Unary Ops Executors
@@ -125,8 +125,8 @@ where
];
// Scalar binary Ops Executors
let scalar_add_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_add);
let scalar_sub_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_sub);
//let scalar_add_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_add);
//let scalar_sub_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_sub);
let scalar_bitwise_and_executor =
GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_bitand);
let scalar_bitwise_or_executor =
@@ -141,16 +141,16 @@ where
&dyn Fn(i64, i64) -> i64,
String,
)> = vec![
(
Box::new(scalar_add_executor),
&clear_add,
"scalar add".to_string(),
),
(
Box::new(scalar_sub_executor),
&clear_sub,
"scalar sub".to_string(),
),
//(
// Box::new(scalar_add_executor),
// &clear_add,
// "scalar add".to_string(),
//),
//(
// Box::new(scalar_sub_executor),
// &clear_sub,
// "scalar sub".to_string(),
//),
(
Box::new(scalar_bitwise_and_executor),
&clear_bitwise_and,
@@ -175,12 +175,12 @@ where
let scalar_rotate_left_executor =
GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_rotate_left);
let scalar_left_shift_executor =
GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_left_shift);
let scalar_rotate_right_executor =
GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_rotate_right);
let scalar_right_shift_executor =
GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_right_shift);
//let scalar_left_shift_executor =
// GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_left_shift);
//let scalar_rotate_right_executor =
// GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_rotate_right);
//let scalar_right_shift_executor =
// GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_right_shift);
#[allow(clippy::type_complexity)]
let mut scalar_shift_rotate_ops: Vec<(
SignedScalarShiftRotateExecutor,
@@ -192,21 +192,21 @@ where
&clear_rotate_left,
"scalar rotate left".to_string(),
),
(
Box::new(scalar_left_shift_executor),
&clear_left_shift,
"scalar left shift".to_string(),
),
(
Box::new(scalar_rotate_right_executor),
&clear_rotate_right,
"scalar rotate right".to_string(),
),
(
Box::new(scalar_right_shift_executor),
&clear_right_shift,
"scalar right shift".to_string(),
),
//(
// Box::new(scalar_left_shift_executor),
// &clear_left_shift,
// "scalar left shift".to_string(),
//),
//(
// Box::new(scalar_rotate_right_executor),
// &clear_rotate_right,
// "scalar rotate right".to_string(),
//),
//(
// Box::new(scalar_right_shift_executor),
// &clear_right_shift,
// "scalar right shift".to_string(),
//),
];
// Overflowing Ops Executors
@@ -271,11 +271,11 @@ where
// Comparison Ops Executors
let gt_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::gt);
let ge_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::ge);
let lt_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::lt);
let le_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::le);
let eq_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::eq);
let ne_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::ne);
//let ge_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::ge);
//let lt_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::lt);
//let le_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::le);
//let eq_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::eq);
//let ne_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::ne);
// Comparison Ops Clear functions
let clear_gt = |x: i64, y: i64| -> bool { x > y };
@@ -292,20 +292,20 @@ where
String,
)> = vec![
(Box::new(gt_executor), &clear_gt, "gt".to_string()),
(Box::new(ge_executor), &clear_ge, "ge".to_string()),
(Box::new(lt_executor), &clear_lt, "lt".to_string()),
(Box::new(le_executor), &clear_le, "le".to_string()),
(Box::new(eq_executor), &clear_eq, "eq".to_string()),
(Box::new(ne_executor), &clear_ne, "ne".to_string()),
//(Box::new(ge_executor), &clear_ge, "ge".to_string()),
//(Box::new(lt_executor), &clear_lt, "lt".to_string()),
//(Box::new(le_executor), &clear_le, "le".to_string()),
//(Box::new(eq_executor), &clear_eq, "eq".to_string()),
//(Box::new(ne_executor), &clear_ne, "ne".to_string()),
];
// Scalar Comparison Ops Executors
let scalar_gt_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_gt);
let scalar_ge_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_ge);
let scalar_lt_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_lt);
let scalar_le_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_le);
let scalar_eq_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_eq);
let scalar_ne_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_ne);
//let scalar_ge_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_ge);
//let scalar_lt_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_lt);
//let scalar_le_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_le);
//let scalar_eq_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_eq);
//let scalar_ne_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_ne);
#[allow(clippy::type_complexity)]
let mut scalar_comparison_ops: Vec<(
@@ -318,31 +318,31 @@ where
&clear_gt,
"scalar gt".to_string(),
),
(
Box::new(scalar_ge_executor),
&clear_ge,
"scalar ge".to_string(),
),
(
Box::new(scalar_lt_executor),
&clear_lt,
"scalar lt".to_string(),
),
(
Box::new(scalar_le_executor),
&clear_le,
"scalar le".to_string(),
),
(
Box::new(scalar_eq_executor),
&clear_eq,
"scalar eq".to_string(),
),
(
Box::new(scalar_ne_executor),
&clear_ne,
"scalar ne".to_string(),
),
//(
// Box::new(scalar_ge_executor),
// &clear_ge,
// "scalar ge".to_string(),
//),
//(
// Box::new(scalar_lt_executor),
// &clear_lt,
// "scalar lt".to_string(),
//),
//(
// Box::new(scalar_le_executor),
// &clear_le,
// "scalar le".to_string(),
//),
//(
// Box::new(scalar_eq_executor),
// &clear_eq,
// "scalar eq".to_string(),
//),
//(
// Box::new(scalar_ne_executor),
// &clear_ne,
// "scalar ne".to_string(),
//),
];
// Select Executor

View File

@@ -18,6 +18,7 @@ use crate::integer::{CompactPublicKey, ProvenCompactCiphertextList};
use crate::shortint::ciphertext::{Degree, NoiseLevel};
use crate::shortint::AtomicPatternKind;
use crate::zk::CompactPkeCrs;
use crate::GpuIndex;
use itertools::Itertools;
use rayon::iter::{IntoParallelRefIterator, ParallelIterator};
use tfhe_cuda_backend::cuda_bind::cuda_memcpy_async_gpu_to_gpu;
@@ -28,6 +29,28 @@ pub struct CudaProvenCompactCiphertextList {
}
impl CudaProvenCompactCiphertextList {
pub fn duplicate(&self, streams: &CudaStreams) -> Self {
Self {
h_proved_lists: self.h_proved_lists.clone(),
d_compact_lists: self
.d_compact_lists
.iter()
.map(|ct_list| ct_list.duplicate(streams))
.collect_vec(),
}
}
pub fn gpu_indexes(&self) -> &[GpuIndex] {
self.d_compact_lists
.first()
.unwrap()
.d_ct_list
.0
.d_vec
.gpu_indexes
.as_slice()
}
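A note on the two new accessors: `duplicate` deep-copies each per-GPU compact list on the provided streams, while `gpu_indexes` reports the devices backing the first list and panics if no compact list is present. Illustrative use, with `proven` standing for an already-built `CudaProvenCompactCiphertextList`:

// Sketch: clone the device data, then inspect where it lives.
let copy = proven.duplicate(&streams);
for idx in copy.gpu_indexes() {
    println!("resident on GPU {idx:?}");
}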
unsafe fn flatten_async(
slice_ciphertext_list: &[CudaCompactCiphertextList],
streams: &CudaStreams,
@@ -152,13 +175,9 @@ impl CudaProvenCompactCiphertextList {
DataKind::Boolean => 1,
DataKind::Signed(x) => *x,
DataKind::Unsigned(x) => *x,
_ => panic!("DataKind not supported on GPUs"),
DataKind::String { .. } => panic!("DataKind not supported on GPUs"),
};
std::iter::repeat(match data_kind {
DataKind::Boolean => true,
_ => false,
})
.take(repetitions)
std::iter::repeat_n(matches!(data_kind, DataKind::Boolean), repetitions)
})
.collect_vec();
@@ -194,25 +213,22 @@ impl CudaProvenCompactCiphertextList {
streams,
);
let d_input = &CudaProvenCompactCiphertextList::flatten_async(
self.d_compact_lists.as_slice(),
streams,
);
let casting_key = &key.key_switching_key;
let sks = key.dest_server_key;
let d_input = &Self::flatten_async(self.d_compact_lists.as_slice(), streams);
let casting_key = &key.key_switching_key_material;
let sks = &key.dest_server_key;
let computing_ks_key = &key.dest_server_key.key_switching_key;
let casting_key_type: KsType = key.destination_key.into();
let casting_key_type: KsType = casting_key.destination_key.into();
match &sks.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => {
expand_async(
streams,
&mut d_output,
&d_input,
d_input,
&d_bsk.d_vec,
&computing_ks_key.d_vec,
&casting_key.d_vec,
&casting_key.lwe_keyswitch_key.d_vec,
sks.message_modulus,
sks.carry_modulus,
d_bsk.glwe_dimension(),
@@ -220,10 +236,16 @@ impl CudaProvenCompactCiphertextList {
d_bsk.input_lwe_dimension(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
casting_key.input_key_lwe_size().to_lwe_dimension(),
casting_key.output_key_lwe_size().to_lwe_dimension(),
casting_key.decomposition_level_count(),
casting_key.decomposition_base_log(),
casting_key
.lwe_keyswitch_key
.input_key_lwe_size()
.to_lwe_dimension(),
casting_key
.lwe_keyswitch_key
.output_key_lwe_size()
.to_lwe_dimension(),
casting_key.lwe_keyswitch_key.decomposition_level_count(),
casting_key.lwe_keyswitch_key.decomposition_base_log(),
d_bsk.decomp_level_count,
d_bsk.decomp_base_log,
PBSType::Classical,
@@ -238,13 +260,10 @@ impl CudaProvenCompactCiphertextList {
expand_async(
streams,
&mut d_output,
&CudaProvenCompactCiphertextList::flatten_async(
self.d_compact_lists.as_slice(),
streams,
),
&Self::flatten_async(self.d_compact_lists.as_slice(), streams),
&d_multibit_bsk.d_vec,
&computing_ks_key.d_vec,
&casting_key.d_vec,
&casting_key.lwe_keyswitch_key.d_vec,
sks.message_modulus,
sks.carry_modulus,
d_multibit_bsk.glwe_dimension(),
@@ -252,10 +271,16 @@ impl CudaProvenCompactCiphertextList {
d_multibit_bsk.input_lwe_dimension(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
casting_key.input_key_lwe_size().to_lwe_dimension(),
casting_key.output_key_lwe_size().to_lwe_dimension(),
casting_key.decomposition_level_count(),
casting_key.decomposition_base_log(),
casting_key
.lwe_keyswitch_key
.input_key_lwe_size()
.to_lwe_dimension(),
casting_key
.lwe_keyswitch_key
.output_key_lwe_size()
.to_lwe_dimension(),
casting_key.lwe_keyswitch_key.decomposition_level_count(),
casting_key.lwe_keyswitch_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count,
d_multibit_bsk.decomp_base_log,
PBSType::MultiBit,
@@ -294,7 +319,7 @@ impl CudaProvenCompactCiphertextList {
})
.collect();
CudaProvenCompactCiphertextList {
Self {
h_proved_lists: h_proved_lists.clone(),
d_compact_lists,
}
@@ -328,6 +353,18 @@ impl CudaProvenCompactCiphertextList {
}
}
impl<'de> serde::Deserialize<'de> for CudaProvenCompactCiphertextList {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
let cpu_ct = ProvenCompactCiphertextList::deserialize(deserializer)?;
let streams = CudaStreams::new_multi_gpu();
Ok(Self::from_proven_compact_ciphertext_list(&cpu_ct, &streams))
}
}
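Deserialization is delegated entirely to the CPU type: the bytes decode to a `ProvenCompactCiphertextList`, which is then uploaded through fresh multi-GPU streams. A round-trip sketch, assuming `bincode` as the serde backend and an existing `cpu_list: ProvenCompactCiphertextList` (only `Deserialize` is added here, so serialization goes through the CPU list):

// Sketch: CPU-side serialize, GPU-side deserialize.
let bytes = bincode::serialize(&cpu_list).unwrap();
let gpu_list: CudaProvenCompactCiphertextList =
    bincode::deserialize(&bytes).unwrap();
// The impl above allocates CudaStreams::new_multi_gpu() internally,
// so `gpu_list` is resident on every visible GPU.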
#[cfg(feature = "zk-pok")]
#[cfg(test)]
mod tests {
@@ -344,9 +381,12 @@ mod tests {
use crate::integer::ciphertext::{CompactCiphertextList, DataKind};
use crate::integer::gpu::ciphertext::boolean_value::CudaBooleanBlock;
use crate::integer::gpu::ciphertext::CudaUnsignedRadixCiphertext;
use crate::integer::gpu::key_switching_key::CudaKeySwitchingKey;
use crate::integer::gpu::key_switching_key::{
CudaKeySwitchingKey, CudaKeySwitchingKeyMaterial,
};
use crate::integer::gpu::zk::CudaProvenCompactCiphertextList;
use crate::integer::gpu::CudaServerKey;
use crate::integer::key_switching_key::KeySwitchingKey;
use crate::integer::{
ClientKey, CompactPrivateKey, CompactPublicKey, CompressedServerKey,
ProvenCompactCiphertextList,
@@ -405,15 +445,16 @@ mod tests {
let compressed_server_key = CompressedServerKey::new_radix_compressed_server_key(&cks);
let streams = CudaStreams::new_multi_gpu();
let sk = compressed_server_key.decompress();
let gpu_sk = CudaServerKey::decompress_from_cpu(&compressed_server_key, &streams);
let compact_private_key = CompactPrivateKey::new(pke_params);
let d_ksk = CudaKeySwitchingKey::new(
(&compact_private_key, None),
(&cks, &gpu_sk),
ksk_params,
&streams,
);
let ksk = KeySwitchingKey::new((&compact_private_key, None), (&cks, &sk), ksk_params);
let d_ksk_material =
CudaKeySwitchingKeyMaterial::from_key_switching_key(&ksk, &streams);
let d_ksk =
CudaKeySwitchingKey::from_cuda_key_switching_key_material(&d_ksk_material, &gpu_sk);
let pk = CompactPublicKey::new(&compact_private_key);
let msgs = (0..512)
@@ -493,17 +534,17 @@ mod tests {
.unwrap();
let cks = ClientKey::new(fhe_params);
let compressed_server_key = CompressedServerKey::new_radix_compressed_server_key(&cks);
let sk = compressed_server_key.decompress();
let streams = CudaStreams::new_multi_gpu();
let gpu_sk = CudaServerKey::decompress_from_cpu(&compressed_server_key, &streams);
let compact_private_key = CompactPrivateKey::new(pke_params);
let d_ksk = CudaKeySwitchingKey::new(
(&compact_private_key, None),
(&cks, &gpu_sk),
ksk_params,
&streams,
);
let ksk = KeySwitchingKey::new((&compact_private_key, None), (&cks, &sk), ksk_params);
let d_ksk_material =
CudaKeySwitchingKeyMaterial::from_key_switching_key(&ksk, &streams);
let d_ksk =
CudaKeySwitchingKey::from_cuda_key_switching_key_material(&d_ksk_material, &gpu_sk);
let pk = CompactPublicKey::new(&compact_private_key);
let msgs = (0..2).map(|_| random::<u64>()).collect::<Vec<_>>();
@@ -583,17 +624,17 @@ mod tests {
.unwrap();
let cks = ClientKey::new(fhe_params);
let compressed_server_key = CompressedServerKey::new_radix_compressed_server_key(&cks);
let sk = compressed_server_key.decompress();
let streams = CudaStreams::new_multi_gpu();
let gpu_sk = CudaServerKey::decompress_from_cpu(&compressed_server_key, &streams);
let compact_private_key = CompactPrivateKey::new(pke_params);
let d_ksk = CudaKeySwitchingKey::new(
(&compact_private_key, None),
(&cks, &gpu_sk),
ksk_params,
&streams,
);
let ksk = KeySwitchingKey::new((&compact_private_key, None), (&cks, &sk), ksk_params);
let d_ksk_material =
CudaKeySwitchingKeyMaterial::from_key_switching_key(&ksk, &streams);
let d_ksk =
CudaKeySwitchingKey::from_cuda_key_switching_key_material(&d_ksk_material, &gpu_sk);
let pk = CompactPublicKey::new(&compact_private_key);
let msgs = (0..2).map(|_| random::<u64>()).collect::<Vec<_>>();
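All three updated tests build the GPU key-switching key through the same three steps: construct the key on the CPU, upload its material once, then bind that material to the GPU server key. Condensed from the diff above:

let ksk = KeySwitchingKey::new((&compact_private_key, None), (&cks, &sk), ksk_params);
let d_ksk_material =
    CudaKeySwitchingKeyMaterial::from_key_switching_key(&ksk, &streams);
let d_ksk =
    CudaKeySwitchingKey::from_cuda_key_switching_key_material(&d_ksk_material, &gpu_sk);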

View File

@@ -97,8 +97,8 @@ impl ServerKey {
// It should remain just enough space to add a carry
let client_key = cks.as_ref();
let max_degree = MaxDegree::integer_radix_server_key(
client_key.key.parameters.message_modulus(),
client_key.key.parameters.carry_modulus(),
client_key.key.parameters().message_modulus(),
client_key.key.parameters().carry_modulus(),
);
let sks = crate::shortint::server_key::ServerKey::new_with_max_degree(
@@ -115,8 +115,8 @@ impl ServerKey {
{
let client_key = cks.as_ref();
let max_degree = MaxDegree::integer_crt_server_key(
client_key.key.parameters.message_modulus(),
client_key.key.parameters.carry_modulus(),
client_key.key.parameters().message_modulus(),
client_key.key.parameters().carry_modulus(),
);
let sks = crate::shortint::server_key::ServerKey::new_with_max_degree(
@@ -255,8 +255,8 @@ pub struct CompressedServerKey {
impl CompressedServerKey {
pub fn new_radix_compressed_server_key(client_key: &ClientKey) -> Self {
let max_degree = MaxDegree::integer_radix_server_key(
client_key.key.parameters.message_modulus(),
client_key.key.parameters.carry_modulus(),
client_key.key.parameters().message_modulus(),
client_key.key.parameters().carry_modulus(),
);
let key =

View File

@@ -3,4 +3,4 @@ pub(crate) mod test_random_op_sequence;
pub(crate) mod test_signed_erc20;
pub(crate) mod test_signed_random_op_sequence;
pub(crate) const NB_CTXT_LONG_RUN: usize = 32;
pub(crate) const NB_TESTS_LONG_RUN: usize = 20000;
pub(crate) const NB_TESTS_LONG_RUN: usize = 200;

View File

@@ -588,6 +588,9 @@ pub(crate) fn random_op_sequence_test<P>(
left_vec[i].blocks.iter().map(|b| b.degree.0).collect();
let input_degrees_right: Vec<u64> =
right_vec[i].blocks.iter().map(|b| b.degree.0).collect();
let output_degrees: Vec<u64> =
res.blocks.iter().map(|b| b.degree.0).collect();
println!("Input degrees left: {input_degrees_left:?}, right {input_degrees_right:?}, Output degrees {:?}", output_degrees);
let decrypted_res: u64 = cks.decrypt(&res);
let expected_res: u64 = clear_fn(clear_left, clear_right);

View File

@@ -428,7 +428,7 @@ where
// Div executor
let div_rem_executor = CpuFunctionExecutor::new(&ServerKey::div_rem_parallelized);
// Div Rem Clear functions
let clear_div_rem = |x: i64, y: i64| -> (i64, i64) { (x.wrapping_div(y), x % y) };
let clear_div_rem = |x: i64, y: i64| -> (i64, i64) { (x.wrapping_div(y), x.wrapping_rem(y)) };
#[allow(clippy::type_complexity)]
let mut div_rem_op: Vec<(
SignedDivRemOpExecutor,
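The switch to `wrapping_rem` closes the one remainder case the clear reference could still panic on: in Rust, `i64::MIN % -1`, like `i64::MIN / -1`, aborts with an overflow error, while the wrapping variants return the values the reference needs:

// The only i64 remainder that can overflow: i64::MIN % -1.
assert_eq!(i64::MIN.wrapping_div(-1), i64::MIN); // quotient wraps
assert_eq!(i64::MIN.wrapping_rem(-1), 0); // remainder is 0, no panic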
@@ -680,6 +680,9 @@ pub(crate) fn signed_random_op_sequence_test<P>(
left_vec[i].blocks.iter().map(|b| b.degree.0).collect();
let input_degrees_right: Vec<u64> =
right_vec[i].blocks.iter().map(|b| b.degree.0).collect();
let output_degrees: Vec<u64> =
res.blocks.iter().map(|b| b.degree.0).collect();
println!("Input degrees left: {input_degrees_left:?}, right {input_degrees_right:?}, Output degrees {:?}", output_degrees);
let decrypt_signed_res: i64 = cks.decrypt_signed(&res);
let expected_res: i64 = clear_fn(clear_left, clear_right);
@@ -731,6 +734,9 @@ pub(crate) fn signed_random_op_sequence_test<P>(
"Determinism check failed on unary op {fn_name} with clear input {clear_input}.",
);
let input_degrees: Vec<u64> = input.blocks.iter().map(|b| b.degree.0).collect();
let output_degrees: Vec<u64> =
res.blocks.iter().map(|b| b.degree.0).collect();
println!("Output degrees {:?}", output_degrees);
let decrypt_signed_res: i64 = cks.decrypt_signed(&res);
let expected_res: i64 = clear_fn(clear_input);
if i % 2 == 0 {
@@ -774,6 +780,9 @@ pub(crate) fn signed_random_op_sequence_test<P>(
);
let input_degrees_left: Vec<u64> =
left_vec[i].blocks.iter().map(|b| b.degree.0).collect();
let output_degrees: Vec<u64> =
res.blocks.iter().map(|b| b.degree.0).collect();
println!("Output degrees {:?}", output_degrees);
let decrypt_signed_res: i64 = cks.decrypt_signed(&res);
let expected_res: i64 = clear_fn(clear_left, clear_right);
@@ -829,6 +838,9 @@ pub(crate) fn signed_random_op_sequence_test<P>(
left_vec[i].blocks.iter().map(|b| b.degree.0).collect();
let input_degrees_right: Vec<u64> =
right_vec[i].blocks.iter().map(|b| b.degree.0).collect();
let output_degrees: Vec<u64> =
res.blocks.iter().map(|b| b.degree.0).collect();
println!("Output degrees {:?}", output_degrees);
let decrypt_signed_res: i64 = cks.decrypt_signed(&res);
let decrypt_signed_overflow = cks.decrypt_bool(&overflow);
let (expected_res, expected_overflow) = clear_fn(clear_left, clear_right);
@@ -889,6 +901,9 @@ pub(crate) fn signed_random_op_sequence_test<P>(
);
let input_degrees_left: Vec<u64> =
left_vec[i].blocks.iter().map(|b| b.degree.0).collect();
let output_degrees: Vec<u64> =
res.blocks.iter().map(|b| b.degree.0).collect();
println!("Output degrees {:?}", output_degrees);
let decrypt_signed_res: i64 = cks.decrypt_signed(&res);
let decrypt_signed_overflow = cks.decrypt_bool(&overflow);
let (expected_res, expected_overflow) = clear_fn(clear_left, clear_right);
@@ -1020,6 +1035,9 @@ pub(crate) fn signed_random_op_sequence_test<P>(
left_vec[i].blocks.iter().map(|b| b.degree.0).collect();
let input_degrees_right: Vec<u64> =
right_vec[i].blocks.iter().map(|b| b.degree.0).collect();
let output_degrees: Vec<u64> =
res.blocks.iter().map(|b| b.degree.0).collect();
println!("Output degrees {:?}", output_degrees);
let decrypt_signed_res: i64 = cks.decrypt_signed(&res);
let expected_res = clear_fn(clear_bool, clear_left, clear_right);
@@ -1081,6 +1099,12 @@ pub(crate) fn signed_random_op_sequence_test<P>(
left_vec[i].blocks.iter().map(|b| b.degree.0).collect();
let input_degrees_right: Vec<u64> =
right_vec[i].blocks.iter().map(|b| b.degree.0).collect();
let output_degrees_q: Vec<u64> =
res_q.blocks.iter().map(|b| b.degree.0).collect();
let output_degrees_r: Vec<u64> =
res_r.blocks.iter().map(|b| b.degree.0).collect();
println!("Output degrees {:?}", output_degrees_q);
println!("Output degrees {:?}", output_degrees_r);
let decrypt_signed_res_q: i64 = cks.decrypt_signed(&res_q);
let decrypt_signed_res_r: i64 = cks.decrypt_signed(&res_r);
let (expected_res_q, expected_res_r) = clear_fn(clear_left, clear_right);
@@ -1147,6 +1171,12 @@ pub(crate) fn signed_random_op_sequence_test<P>(
);
let input_degrees_left: Vec<u64> =
left_vec[i].blocks.iter().map(|b| b.degree.0).collect();
let output_q_degrees: Vec<u64> =
res_q.blocks.iter().map(|b| b.degree.0).collect();
let output_r_degrees: Vec<u64> =
res_r.blocks.iter().map(|b| b.degree.0).collect();
println!("Output r degrees {:?}", output_r_degrees);
println!("Output q degrees {:?}", output_q_degrees);
let decrypt_signed_res_q: i64 = cks.decrypt_signed(&res_q);
let decrypt_signed_res_r: i64 = cks.decrypt_signed(&res_r);
let (expected_res_q, expected_res_r) = clear_fn(clear_left, clear_right);
@@ -1205,6 +1235,9 @@ pub(crate) fn signed_random_op_sequence_test<P>(
"Determinism check failed on op {fn_name} with clear input {clear_input}.",
);
let input_degrees: Vec<u64> = input.blocks.iter().map(|b| b.degree.0).collect();
let output_degrees: Vec<u64> =
res.blocks.iter().map(|b| b.degree.0).collect();
println!("Output degrees {:?}", output_degrees);
let cast_res = sks.cast_to_signed(res, NB_CTXT_LONG_RUN);
let decrypt_signed_res: i64 = cks.decrypt_signed(&cast_res);
let expected_res = clear_fn(clear_input) as i64;
@@ -1252,6 +1285,9 @@ pub(crate) fn signed_random_op_sequence_test<P>(
left_vec[i].blocks.iter().map(|b| b.degree.0).collect();
let input_degrees_right: Vec<u64> =
unsigned_right.blocks.iter().map(|b| b.degree.0).collect();
let output_degrees: Vec<u64> =
res.blocks.iter().map(|b| b.degree.0).collect();
println!("Output degrees {:?}", output_degrees);
let decrypt_signed_res: i64 = cks.decrypt_signed(&res);
let expected_res: i64 = clear_fn(clear_left, clear_right as u64);
@@ -1297,6 +1333,9 @@ pub(crate) fn signed_random_op_sequence_test<P>(
);
let input_degrees_left: Vec<u64> =
left_vec[i].blocks.iter().map(|b| b.degree.0).collect();
let output_degrees: Vec<u64> =
res.blocks.iter().map(|b| b.degree.0).collect();
println!("Output degrees {:?}", output_degrees);
let decrypt_signed_res: i64 = cks.decrypt_signed(&res);
let expected_res: i64 = clear_fn(clear_left, clear_right as u64);

View File

@@ -349,7 +349,7 @@ where
{
let cks = cks.as_ref();
let max_degree_acceptable = cks.key.parameters.message_modulus().0 - 1;
let max_degree_acceptable = cks.key.parameters().message_modulus().0 - 1;
let num_blocks = ct.blocks.len();
for (i, block) in ct.blocks.iter().enumerate() {
@@ -385,7 +385,7 @@ where
{
let cks = cks.as_ref();
let max_degree_acceptable = cks.key.parameters.message_modulus().0 - 1;
let max_degree_acceptable = cks.key.parameters().message_modulus().0 - 1;
for (i, block) in ct.blocks.iter().enumerate() {
if block.is_trivial() {

View File

@@ -24,6 +24,8 @@ mod experimental {
use crate::integer::{ClientKey, CrtCiphertext, IntegerCiphertext, RadixCiphertext, ServerKey};
use crate::shortint::atomic_pattern::AtomicPattern;
use crate::shortint::ciphertext::{Degree, NoiseLevel};
use crate::shortint::client_key::atomic_pattern::EncryptionAtomicPattern;
use crate::shortint::client_key::StandardClientKeyView;
use crate::shortint::server_key::StandardServerKeyView;
use crate::shortint::WopbsParameters;
@@ -238,12 +240,17 @@ mod experimental {
sks.key.atomic_pattern.kind()
)
});
let cks = cks.as_ref();
let ck = StandardClientKeyView::try_from(cks.key.as_view()).unwrap_or_else(|_| {
panic!(
"Wopbs is not supported by the chosen encryption atomic pattern: {:?}",
cks.key.atomic_pattern.kind()
)
});
Self {
wopbs_key: crate::shortint::wopbs::WopbsKey::new_wopbs_key(
&cks.as_ref().key,
sk,
parameters,
),
wopbs_key: crate::shortint::wopbs::WopbsKey::new_wopbs_key(ck, sk, parameters),
}
}
@@ -268,11 +275,16 @@ mod experimental {
)
});
let cks = cks.as_ref();
let ck = StandardClientKeyView::try_from(cks.key.as_view()).unwrap_or_else(|_| {
panic!(
"Wopbs is not supported by the chosen encryption atomic pattern: {:?}",
cks.key.atomic_pattern.kind()
)
});
Self {
wopbs_key: crate::shortint::wopbs::WopbsKey::new_wopbs_key_only_for_wopbs(
&cks.as_ref().key,
sk,
),
wopbs_key: crate::shortint::wopbs::WopbsKey::new_wopbs_key_only_for_wopbs(ck, sk),
}
}

View File

@@ -563,7 +563,7 @@ impl Shortint {
let mut seeder = DeterministicSeeder::<DefaultRandomGenerator>::new(Seed(seed));
ShortintClientKey(
crate::shortint::engine::ShortintEngine::new_from_seeder(&mut seeder)
.new_client_key(parameters.0.into()),
.new_client_key(parameters.0),
)
}

View File

@@ -1,10 +1,9 @@
use crate::conformance::ParameterSetConformant;
use crate::core_crypto::algorithms::lwe_keyswitch_key_generation::allocate_and_generate_new_seeded_lwe_keyswitch_key;
use crate::core_crypto::entities::lwe_secret_key::LweSecretKey;
use crate::core_crypto::entities::seeded_lwe_keyswitch_key::SeededLweKeyswitchKeyOwned;
use crate::shortint::atomic_pattern::ks32::KS32AtomicPatternServerKey;
use crate::shortint::backward_compatibility::atomic_pattern::CompressedKS32AtomicPatternServerKeyVersions;
use crate::shortint::client_key::ClientKey;
use crate::shortint::client_key::atomic_pattern::KS32AtomicPatternClientKey;
use crate::shortint::engine::ShortintEngine;
use crate::shortint::parameters::{KeySwitch32PBSParameters, LweDimension};
use crate::shortint::server_key::ShortintCompressedBootstrappingKey;
@@ -22,24 +21,15 @@ pub struct CompressedKS32AtomicPatternServerKey {
}
impl CompressedKS32AtomicPatternServerKey {
pub fn new(cks: &ClientKey, engine: &mut ShortintEngine) -> Self {
pub fn new(cks: &KS32AtomicPatternClientKey, engine: &mut ShortintEngine) -> Self {
let params = &cks.parameters;
let pbs_params = params.ks32_parameters().unwrap();
let in_key = LweSecretKey::from_container(
cks.small_lwe_secret_key()
.as_ref()
.iter()
.copied()
.map(|x| x as u32)
.collect::<Vec<_>>(),
);
let in_key = cks.small_lwe_secret_key();
let out_key = &cks.glwe_secret_key;
let bootstrapping_key_base =
engine.new_compressed_bootstrapping_key_ks32(pbs_params, &in_key, out_key);
engine.new_compressed_bootstrapping_key_ks32(*params, &in_key, out_key);
// Creation of the key switching key
let key_switching_key = allocate_and_generate_new_seeded_lwe_keyswitch_key(
@@ -47,8 +37,8 @@ impl CompressedKS32AtomicPatternServerKey {
&in_key,
params.ks_base_log(),
params.ks_level(),
pbs_params.lwe_noise_distribution(),
pbs_params.post_keyswitch_ciphertext_modulus(),
params.lwe_noise_distribution(),
params.post_keyswitch_ciphertext_modulus(),
&mut engine.seeder,
);

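With the switch from `&ClientKey` to `&KS32AtomicPatternClientKey`, the element-wise narrowing of the secret key disappears. Since the rendered diff does not mark removed lines, the before/after contrast as a sketch (`cks` being the respective key type in each case):

// Old shape: narrow each u64 coefficient of the small key down to u32
// into a freshly allocated container.
let in_key = LweSecretKey::from_container(
    cks.small_lwe_secret_key()
        .as_ref()
        .iter()
        .copied()
        .map(|x| x as u32)
        .collect::<Vec<_>>(),
);
// New shape: KS32AtomicPatternClientKey stores the small key as u32
// natively, so a borrow-only view is enough.
let in_key = cks.small_lwe_secret_key(); // LweSecretKeyView<'_, u32>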

@@ -6,6 +6,7 @@ pub use standard::*;
use super::AtomicPatternServerKey;
use crate::conformance::ParameterSetConformant;
use crate::shortint::backward_compatibility::atomic_pattern::CompressedAtomicPatternServerKeyVersions;
use crate::shortint::client_key::atomic_pattern::AtomicPatternClientKey;
use crate::shortint::client_key::ClientKey;
use crate::shortint::engine::ShortintEngine;
use crate::shortint::parameters::{AtomicPatternParameters, CiphertextModulus, LweDimension};
@@ -24,14 +25,12 @@ pub enum CompressedAtomicPatternServerKey {
impl CompressedAtomicPatternServerKey {
pub fn new(cks: &ClientKey, engine: &mut ShortintEngine) -> Self {
let params = &cks.parameters;
match params.ap_parameters().unwrap() {
AtomicPatternParameters::Standard(_) => {
Self::Standard(CompressedStandardAtomicPatternServerKey::new(cks, engine))
}
AtomicPatternParameters::KeySwitch32(_) => {
Self::KeySwitch32(CompressedKS32AtomicPatternServerKey::new(cks, engine))
match &cks.atomic_pattern {
AtomicPatternClientKey::Standard(ap_cks) => Self::Standard(
CompressedStandardAtomicPatternServerKey::new(ap_cks, engine),
),
AtomicPatternClientKey::KeySwitch32(ap_cks) => {
Self::KeySwitch32(CompressedKS32AtomicPatternServerKey::new(ap_cks, engine))
}
}
}


@@ -3,7 +3,7 @@ use crate::core_crypto::algorithms::lwe_keyswitch_key_generation::allocate_and_g
use crate::core_crypto::entities::seeded_lwe_keyswitch_key::SeededLweKeyswitchKeyOwned;
use crate::shortint::atomic_pattern::standard::StandardAtomicPatternServerKey;
use crate::shortint::backward_compatibility::atomic_pattern::CompressedStandardAtomicPatternServerKeyVersions;
use crate::shortint::client_key::ClientKey;
use crate::shortint::client_key::atomic_pattern::StandardAtomicPatternClientKey;
use crate::shortint::engine::ShortintEngine;
use crate::shortint::parameters::{CiphertextModulus, LweDimension, PBSOrder, PBSParameters};
use crate::shortint::server_key::ShortintCompressedBootstrappingKey;
@@ -22,17 +22,15 @@ pub struct CompressedStandardAtomicPatternServerKey {
}
impl CompressedStandardAtomicPatternServerKey {
pub fn new(cks: &ClientKey, engine: &mut ShortintEngine) -> Self {
pub fn new(cks: &StandardAtomicPatternClientKey, engine: &mut ShortintEngine) -> Self {
let params = &cks.parameters;
let pbs_params_base = params.pbs_parameters().unwrap();
let in_key = &cks.small_lwe_secret_key();
let out_key = &cks.glwe_secret_key;
let bootstrapping_key_base =
engine.new_compressed_bootstrapping_key(pbs_params_base, in_key, out_key);
engine.new_compressed_bootstrapping_key(*params, in_key, out_key);
// Creation of the key switching key
let key_switching_key = allocate_and_generate_new_seeded_lwe_keyswitch_key(
@@ -48,7 +46,7 @@ impl CompressedStandardAtomicPatternServerKey {
Self::from_raw_parts(
key_switching_key,
bootstrapping_key_base,
pbs_params_base.encryption_key_choice().into(),
params.encryption_key_choice().into(),
)
}


@@ -10,11 +10,12 @@ use crate::conformance::ParameterSetConformant;
use crate::core_crypto::prelude::{
allocate_and_generate_new_lwe_keyswitch_key, extract_lwe_sample_from_glwe_ciphertext,
keyswitch_lwe_ciphertext_with_scalar_change, CiphertextModulus as CoreCiphertextModulus,
LweCiphertext, LweCiphertextOwned, LweDimension, LweKeyswitchKeyOwned, LweSecretKey,
MonomialDegree, MsDecompressionType,
LweCiphertext, LweCiphertextOwned, LweDimension, LweKeyswitchKeyOwned, MonomialDegree,
MsDecompressionType,
};
use crate::shortint::backward_compatibility::atomic_pattern::KS32AtomicPatternServerKeyVersions;
use crate::shortint::ciphertext::{CompressedModulusSwitchedCiphertext, Degree, NoiseLevel};
use crate::shortint::client_key::atomic_pattern::KS32AtomicPatternClientKey;
use crate::shortint::engine::ShortintEngine;
use crate::shortint::oprf::generate_pseudo_random_from_pbs;
use crate::shortint::parameters::KeySwitch32PBSParameters;
@@ -22,7 +23,7 @@ use crate::shortint::server_key::{
decompress_and_apply_lookup_table, switch_modulus_and_compress, LookupTableOwned,
LookupTableSize, ManyLookupTableOwned, ShortintBootstrappingKey,
};
use crate::shortint::{Ciphertext, CiphertextModulus, ClientKey};
use crate::shortint::{Ciphertext, CiphertextModulus};
/// The definition of the server key elements used in the
/// [`KeySwitch32`](AtomicPatternKind::KeySwitch32) atomic pattern
@@ -57,24 +58,14 @@ impl ParameterSetConformant for KS32AtomicPatternServerKey {
}
impl KS32AtomicPatternServerKey {
pub fn new(cks: &ClientKey, engine: &mut ShortintEngine) -> Self {
pub fn new(cks: &KS32AtomicPatternClientKey, engine: &mut ShortintEngine) -> Self {
let params = &cks.parameters;
let pbs_params = params.ks32_parameters().unwrap();
let in_key = LweSecretKey::from_container(
cks.small_lwe_secret_key()
.as_ref()
.iter()
.copied()
.map(|x| x as u32)
.collect::<Vec<_>>(),
);
let in_key = cks.small_lwe_secret_key();
let out_key = &cks.glwe_secret_key;
let bootstrapping_key_base =
engine.new_bootstrapping_key_ks32(pbs_params, &in_key, out_key);
let bootstrapping_key_base = engine.new_bootstrapping_key_ks32(*params, &in_key, out_key);
// Creation of the key switching key
let key_switching_key = allocate_and_generate_new_lwe_keyswitch_key(
@@ -82,8 +73,8 @@ impl KS32AtomicPatternServerKey {
&in_key,
params.ks_base_log(),
params.ks_level(),
pbs_params.lwe_noise_distribution(),
pbs_params.post_keyswitch_ciphertext_modulus(),
params.lwe_noise_distribution(),
params.post_keyswitch_ciphertext_modulus(),
&mut engine.encryption_generator,
);


@@ -21,6 +21,7 @@ use crate::core_crypto::prelude::{
use super::backward_compatibility::atomic_pattern::*;
use super::ciphertext::{CompressedModulusSwitchedCiphertext, Degree};
use super::client_key::atomic_pattern::AtomicPatternClientKey;
use super::engine::ShortintEngine;
use super::parameters::{DynamicDistribution, KeySwitch32PBSParameters};
use super::prelude::{DecompositionBaseLog, DecompositionLevelCount};
@@ -259,14 +260,12 @@ pub enum AtomicPatternServerKey {
impl AtomicPatternServerKey {
pub fn new(cks: &ClientKey, engine: &mut ShortintEngine) -> Self {
let params = &cks.parameters;
match params.ap_parameters().unwrap() {
AtomicPatternParameters::Standard(_) => {
Self::Standard(StandardAtomicPatternServerKey::new(cks, engine))
match &cks.atomic_pattern {
AtomicPatternClientKey::Standard(ap_cks) => {
Self::Standard(StandardAtomicPatternServerKey::new(ap_cks, engine))
}
AtomicPatternParameters::KeySwitch32(_) => {
Self::KeySwitch32(KS32AtomicPatternServerKey::new(cks, engine))
AtomicPatternClientKey::KeySwitch32(ap_cks) => {
Self::KeySwitch32(KS32AtomicPatternServerKey::new(ap_cks, engine))
}
}
}


@@ -14,13 +14,14 @@ use crate::core_crypto::prelude::{
};
use crate::shortint::backward_compatibility::atomic_pattern::StandardAtomicPatternServerKeyVersions;
use crate::shortint::ciphertext::{CompressedModulusSwitchedCiphertext, Degree, NoiseLevel};
use crate::shortint::client_key::atomic_pattern::StandardAtomicPatternClientKey;
use crate::shortint::engine::ShortintEngine;
use crate::shortint::oprf::generate_pseudo_random_from_pbs;
use crate::shortint::server_key::{
decompress_and_apply_lookup_table, switch_modulus_and_compress, LookupTableOwned,
LookupTableSize, ManyLookupTableOwned, ShortintBootstrappingKey,
};
use crate::shortint::{Ciphertext, CiphertextModulus, ClientKey, PBSOrder, PBSParameters};
use crate::shortint::{Ciphertext, CiphertextModulus, PBSOrder, PBSParameters};
/// The definition of the server key elements used in the [`Standard`](AtomicPatternKind::Standard)
/// atomic pattern
@@ -58,16 +59,14 @@ impl ParameterSetConformant for StandardAtomicPatternServerKey {
}
impl StandardAtomicPatternServerKey {
pub fn new(cks: &ClientKey, engine: &mut ShortintEngine) -> Self {
pub fn new(cks: &StandardAtomicPatternClientKey, engine: &mut ShortintEngine) -> Self {
let params = &cks.parameters;
let pbs_params_base = params.pbs_parameters().unwrap();
let in_key = &cks.small_lwe_secret_key();
let out_key = &cks.glwe_secret_key;
let bootstrapping_key_base = engine.new_bootstrapping_key(pbs_params_base, in_key, out_key);
let bootstrapping_key_base = engine.new_bootstrapping_key(*params, in_key, out_key);
// Creation of the key switching key
let key_switching_key = allocate_and_generate_new_lwe_keyswitch_key(
@@ -83,7 +82,7 @@ impl StandardAtomicPatternServerKey {
Self::from_raw_parts(
key_switching_key,
bootstrapping_key_base,
pbs_params_base.encryption_key_choice().into(),
params.encryption_key_choice().into(),
)
}


@@ -0,0 +1,20 @@
use tfhe_versionable::VersionsDispatch;
use crate::shortint::client_key::atomic_pattern::{
AtomicPatternClientKey, KS32AtomicPatternClientKey, StandardAtomicPatternClientKey,
};
#[derive(VersionsDispatch)]
pub enum AtomicPatternClientKeyVersions {
V0(AtomicPatternClientKey),
}
#[derive(VersionsDispatch)]
pub enum StandardAtomicPatternClientKeyVersions {
V0(StandardAtomicPatternClientKey),
}
#[derive(VersionsDispatch)]
pub enum KS32AtomicPatternClientKeyVersions {
V0(KS32AtomicPatternClientKey),
}


@@ -1,8 +1,67 @@
use tfhe_versionable::VersionsDispatch;
pub mod atomic_pattern;
use crate::shortint::ClientKey;
use std::any::{Any, TypeId};
use tfhe_versionable::{Upgrade, Version, VersionsDispatch};
use crate::core_crypto::prelude::{GlweSecretKeyOwned, LweSecretKeyOwned};
use crate::shortint::client_key::atomic_pattern::{
AtomicPatternClientKey, StandardAtomicPatternClientKey,
};
use crate::shortint::client_key::GenericClientKey;
use crate::shortint::ShortintParameterSet;
use crate::Error;
#[derive(Version)]
pub struct ClientKeyV0 {
glwe_secret_key: GlweSecretKeyOwned<u64>,
lwe_secret_key: LweSecretKeyOwned<u64>,
parameters: ShortintParameterSet,
}
impl<AP: 'static> Upgrade<GenericClientKey<AP>> for ClientKeyV0 {
type Error = Error;
fn upgrade(self) -> Result<GenericClientKey<AP>, Self::Error> {
let ap_params = self.parameters.pbs_parameters().ok_or_else(|| {
Error::new(
"ClientKey from TFHE-rs 1.2 and before needs PBS parameters to be upgraded to the latest version"
.to_string(),
)
})?;
let std_ap = StandardAtomicPatternClientKey::from_raw_parts(
self.glwe_secret_key,
self.lwe_secret_key,
ap_params,
self.parameters.wopbs_parameters(),
);
if TypeId::of::<AP>() == TypeId::of::<AtomicPatternClientKey>() {
let atomic_pattern = AtomicPatternClientKey::Standard(std_ap);
let ck: Box<dyn Any + 'static> = Box::new(GenericClientKey { atomic_pattern });
Ok(*ck.downcast::<GenericClientKey<AP>>().unwrap()) // We know from the TypeId that
// AP is of the right type so we
// can unwrap
} else if TypeId::of::<AP>() == TypeId::of::<StandardAtomicPatternClientKey>() {
let ck: Box<dyn Any + 'static> = Box::new(GenericClientKey {
atomic_pattern: std_ap,
});
Ok(*ck.downcast::<GenericClientKey<AP>>().unwrap()) // We know from the TypeId that
// AP is of the right type so we
// can unwrap
} else {
Err(Error::new(
"ClientKey from TFHE-rs 1.2 and before can only be deserialized to the standard \
Atomic Pattern"
.to_string(),
))
}
}
}
#[derive(VersionsDispatch)]
pub enum ClientKeyVersions {
V0(ClientKey),
pub enum ClientKeyVersions<AP> {
V0(ClientKeyV0),
V1(GenericClientKey<AP>),
}

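The `ClientKeyV0` upgrade above (and the `ServerKeyV1` one below) both rely on the same trick: after a `TypeId` equality check, move the value into a `Box<dyn Any>` and downcast it by value, so no `Clone` bound is needed. A self-contained sketch of just that mechanism:

use std::any::{Any, TypeId};

// Convert a concrete value into the generic target type U, by value,
// when and only when T and U are the same type.
fn into_if_same<T: 'static, U: 'static>(value: T) -> Option<U> {
    if TypeId::of::<T>() == TypeId::of::<U>() {
        let boxed: Box<dyn Any> = Box::new(value);
        // The TypeId check above guarantees this downcast succeeds.
        Some(*boxed.downcast::<U>().unwrap())
    } else {
        None
    }
}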

@@ -82,7 +82,7 @@ pub struct ServerKeyV1 {
pbs_order: PBSOrder,
}
impl<AP: Clone + 'static> Upgrade<GenericServerKey<AP>> for ServerKeyV1 {
impl<AP: 'static> Upgrade<GenericServerKey<AP>> for ServerKeyV1 {
type Error = Error;
fn upgrade(self) -> Result<GenericServerKey<AP>, Self::Error> {
@@ -94,32 +94,30 @@ impl<AP: Clone + 'static> Upgrade<GenericServerKey<AP>> for ServerKeyV1 {
if TypeId::of::<AP>() == TypeId::of::<AtomicPatternServerKey>() {
let ap = AtomicPatternServerKey::Standard(std_ap);
let sk = ServerKey::from_raw_parts(
let sk: Box<dyn Any + 'static> = Box::new(ServerKey::from_raw_parts(
ap,
self.message_modulus,
self.carry_modulus,
self.max_degree,
self.max_noise_level,
);
Ok((&sk as &dyn Any)
.downcast_ref::<GenericServerKey<AP>>()
.unwrap() // We know from the TypeId that AP is of the right type so we can unwrap
.clone())
));
Ok(*sk.downcast::<GenericServerKey<AP>>().unwrap()) // We know from the TypeId that
// AP is of the right type so we
// can unwrap
} else if TypeId::of::<AP>() == TypeId::of::<StandardAtomicPatternServerKey>() {
let sk = StandardServerKey::from_raw_parts(
let sk: Box<dyn Any + 'static> = Box::new(StandardServerKey::from_raw_parts(
std_ap,
self.message_modulus,
self.carry_modulus,
self.max_degree,
self.max_noise_level,
);
Ok((&sk as &dyn Any)
.downcast_ref::<GenericServerKey<AP>>()
.unwrap() // We know from the TypeId that AP is of the right type so we can unwrap
.clone())
));
Ok(*sk.downcast::<GenericServerKey<AP>>().unwrap()) // We know from the TypeId that
// AP is of the right type so we
// can unwrap
} else {
Err(Error::new(
"ServerKey from TFHE-rs 1.0 and before can only be deserialized to the classical \
"ServerKey from TFHE-rs 1.0 and before can only be deserialized to the standard \
Atomic Pattern"
.to_string(),
))


@@ -26,21 +26,20 @@ where
type Error = Error;
fn upgrade(self) -> Result<ModulusSwitchNoiseReductionKey<InputScalar>, Self::Error> {
let modulus_switch_zeros = &self.modulus_switch_zeros as &dyn Any;
let modulus_switch_zeros: Box<dyn Any> = Box::new(self.modulus_switch_zeros);
// Keys from previous versions were only stored as u64, we check if the destination
// key is also u64 or we return an error
Ok(ModulusSwitchNoiseReductionKey {
modulus_switch_zeros: modulus_switch_zeros
.downcast_ref::<LweCiphertextListOwned<InputScalar>>()
.ok_or_else(|| {
modulus_switch_zeros: *modulus_switch_zeros
.downcast::<LweCiphertextListOwned<InputScalar>>()
.map_err(|_| {
Error::new(format!(
"Expected u64 as InputScalar while upgrading \
ModulusSwitchNoiseReductionKey, got {}",
std::any::type_name::<InputScalar>(),
))
})?
.clone(),
})?,
ms_bound: self.ms_bound,
ms_r_sigma_factor: self.ms_r_sigma_factor,
ms_input_variance: self.ms_input_variance,
@@ -73,21 +72,20 @@ where
type Error = Error;
fn upgrade(self) -> Result<CompressedModulusSwitchNoiseReductionKey<InputScalar>, Self::Error> {
let modulus_switch_zeros = &self.modulus_switch_zeros as &dyn Any;
let modulus_switch_zeros: Box<dyn Any> = Box::new(self.modulus_switch_zeros);
// Keys from previous versions were only stored as u64, we check if the destination
// key is also u64 or we return an error
Ok(CompressedModulusSwitchNoiseReductionKey {
modulus_switch_zeros: modulus_switch_zeros
.downcast_ref::<SeededLweCiphertextListOwned<InputScalar>>()
.ok_or_else(|| {
modulus_switch_zeros: *modulus_switch_zeros
.downcast::<SeededLweCiphertextListOwned<InputScalar>>()
.map_err(|_| {
Error::new(format!(
"Expected u64 as InputScalar while upgrading \
CompressedModulusSwitchNoiseReductionKey, got {}",
std::any::type_name::<InputScalar>(),
))
})?
.clone(),
})?,
ms_bound: self.ms_bound,
ms_r_sigma_factor: self.ms_r_sigma_factor,
ms_input_variance: self.ms_input_variance,

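Both hunks in this file switch from `downcast_ref` plus `clone` to a by-value `Box::downcast`, which returns a `Result` and therefore drives the error path without copying the ciphertext list. A minimal sketch of that shape, with a plain `String` standing in for the crate's error type:

use std::any::Any;

// By-value downcast: on success the boxed value is moved out, never cloned.
fn downcast_or_describe<U: 'static>(value: Box<dyn Any>) -> Result<U, String> {
    value
        .downcast::<U>()
        .map(|boxed| *boxed)
        .map_err(|_| format!("expected {}", std::any::type_name::<U>()))
}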

@@ -0,0 +1,169 @@
use serde::{Deserialize, Serialize};
use tfhe_versionable::Versionize;
use crate::core_crypto::prelude::{
allocate_and_generate_new_binary_glwe_secret_key,
allocate_and_generate_new_binary_lwe_secret_key,
};
use crate::shortint::backward_compatibility::client_key::atomic_pattern::KS32AtomicPatternClientKeyVersions;
use crate::shortint::client_key::{GlweSecretKeyOwned, LweSecretKeyOwned, LweSecretKeyView};
use crate::shortint::engine::ShortintEngine;
use crate::shortint::parameters::{DynamicDistribution, KeySwitch32PBSParameters};
use crate::shortint::{AtomicPatternKind, ShortintParameterSet};
use super::EncryptionAtomicPattern;
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, Versionize)]
#[versionize(KS32AtomicPatternClientKeyVersions)]
pub struct KS32AtomicPatternClientKey {
pub(crate) glwe_secret_key: GlweSecretKeyOwned<u64>,
/// Key used as the output of the keyswitch operation
pub(crate) lwe_secret_key: LweSecretKeyOwned<u32>,
pub parameters: KeySwitch32PBSParameters,
}
impl KS32AtomicPatternClientKey {
pub(crate) fn new_with_engine(
parameters: KeySwitch32PBSParameters,
engine: &mut ShortintEngine,
) -> Self {
// generate the lwe secret key
let lwe_secret_key = allocate_and_generate_new_binary_lwe_secret_key(
parameters.lwe_dimension(),
&mut engine.secret_generator,
);
// generate the rlwe secret key
let glwe_secret_key = allocate_and_generate_new_binary_glwe_secret_key(
parameters.glwe_dimension(),
parameters.polynomial_size(),
&mut engine.secret_generator,
);
// pack the keys in the client key set
Self {
glwe_secret_key,
lwe_secret_key,
parameters,
}
}
pub fn new(parameters: KeySwitch32PBSParameters) -> Self {
ShortintEngine::with_thread_local_mut(|engine| Self::new_with_engine(parameters, engine))
}
pub fn into_raw_parts(
self,
) -> (
GlweSecretKeyOwned<u64>,
LweSecretKeyOwned<u32>,
KeySwitch32PBSParameters,
) {
let Self {
glwe_secret_key,
lwe_secret_key,
parameters,
} = self;
(glwe_secret_key, lwe_secret_key, parameters)
}
pub fn from_raw_parts(
glwe_secret_key: GlweSecretKeyOwned<u64>,
lwe_secret_key: LweSecretKeyOwned<u32>,
parameters: KeySwitch32PBSParameters,
) -> Self {
assert_eq!(
lwe_secret_key.lwe_dimension(),
parameters.lwe_dimension(),
"Mismatch between the LweSecretKey LweDimension ({:?}) \
and the parameters LweDimension ({:?})",
lwe_secret_key.lwe_dimension(),
parameters.lwe_dimension()
);
assert_eq!(
glwe_secret_key.glwe_dimension(),
parameters.glwe_dimension(),
"Mismatch between the GlweSecretKey GlweDimension ({:?}) \
and the parameters GlweDimension ({:?})",
glwe_secret_key.glwe_dimension(),
parameters.glwe_dimension()
);
assert_eq!(
glwe_secret_key.polynomial_size(),
parameters.polynomial_size(),
"Mismatch between the GlweSecretKey PolynomialSize ({:?}) \
and the parameters PolynomialSize ({:?})",
glwe_secret_key.polynomial_size(),
parameters.polynomial_size()
);
Self {
glwe_secret_key,
lwe_secret_key,
parameters,
}
}
pub fn try_from_lwe_encryption_key(
encryption_key: LweSecretKeyOwned<u64>,
parameters: KeySwitch32PBSParameters,
) -> crate::Result<Self> {
let expected_lwe_dimension = parameters.encryption_lwe_dimension();
if encryption_key.lwe_dimension() != expected_lwe_dimension {
return Err(crate::Error::new(format!(
"The given encryption key does not have the correct LweDimension, expected: {:?}, got: {:?}",
expected_lwe_dimension,
encryption_key.lwe_dimension()
)));
}
// The key we got is the one used to encrypt,
// we have to generate the other key. The KS32 ap only supports the KS-PBS order, so we need to
// generate the small key.
let small_key = ShortintEngine::with_thread_local_mut(|engine| {
allocate_and_generate_new_binary_lwe_secret_key(
parameters.lwe_dimension(),
&mut engine.secret_generator,
)
});
Ok(Self {
glwe_secret_key: GlweSecretKeyOwned::from_container(
encryption_key.into_container(),
parameters.polynomial_size(),
),
lwe_secret_key: small_key,
parameters,
})
}
pub fn large_lwe_secret_key(&self) -> LweSecretKeyView<'_, u64> {
self.glwe_secret_key.as_lwe_secret_key()
}
pub fn small_lwe_secret_key(&self) -> LweSecretKeyView<'_, u32> {
self.lwe_secret_key.as_view()
}
}
impl EncryptionAtomicPattern for KS32AtomicPatternClientKey {
fn parameters(&self) -> ShortintParameterSet {
self.parameters.into()
}
fn encryption_key(&self) -> LweSecretKeyView<'_, u64> {
// The KS32 atomic pattern is only supported with the KsPbs order
self.glwe_secret_key.as_lwe_secret_key()
}
fn encryption_noise(&self) -> DynamicDistribution<u64> {
// The KS32 atomic pattern is only supported with the KsPbs order
self.parameters.glwe_noise_distribution()
}
fn kind(&self) -> AtomicPatternKind {
AtomicPatternKind::KeySwitch32
}
}

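A hedged usage sketch for the new key type, assuming `params` is some valid `KeySwitch32PBSParameters` value; the round trip below only exercises constructors and accessors defined in the file above:

let cks = KS32AtomicPatternClientKey::new(params);
// The small key is stored natively as u32 in this atomic pattern.
let small = cks.small_lwe_secret_key(); // LweSecretKeyView<'_, u32>
assert_eq!(small.lwe_dimension(), params.lwe_dimension());
// Destructuring and rebuilding is lossless.
let (glwe, lwe, parameters) = cks.clone().into_raw_parts();
let rebuilt = KS32AtomicPatternClientKey::from_raw_parts(glwe, lwe, parameters);
assert_eq!(rebuilt, cks);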

@@ -0,0 +1,135 @@
pub mod ks32;
pub mod standard;
use serde::{Deserialize, Serialize};
use tfhe_versionable::Versionize;
use crate::shortint::backward_compatibility::client_key::atomic_pattern::AtomicPatternClientKeyVersions;
use crate::shortint::engine::ShortintEngine;
use crate::shortint::parameters::DynamicDistribution;
use crate::shortint::{AtomicPatternKind, AtomicPatternParameters, ShortintParameterSet};
use super::{LweSecretKeyOwned, LweSecretKeyView};
pub use ks32::*;
pub use standard::*;
/// An atomic pattern used for encryption
///
/// This category of atomic patterns will be used by the [`ClientKey`](super::ClientKey) to encrypt
/// ciphertexts, and to generate a [`ServerKey`](crate::shortint::ServerKey) that can be used for
/// evaluation.
pub trait EncryptionAtomicPattern {
/// The parameters associated with this client key
fn parameters(&self) -> ShortintParameterSet;
/// The secret key used for encryption
fn encryption_key(&self) -> LweSecretKeyView<'_, u64>;
/// The noise distribution used for encryption
fn encryption_noise(&self) -> DynamicDistribution<u64>;
/// The kind of atomic pattern that will be used by the generated
/// [`ServerKey`](crate::shortint::ServerKey)
fn kind(&self) -> AtomicPatternKind;
fn encryption_key_and_noise(&self) -> (LweSecretKeyView<'_, u64>, DynamicDistribution<u64>) {
(self.encryption_key(), self.encryption_noise())
}
}
// This blanket impl is used to allow "views" of client keys, without having to re-implement the
// trait
impl<T: EncryptionAtomicPattern> EncryptionAtomicPattern for &T {
fn parameters(&self) -> ShortintParameterSet {
(*self).parameters()
}
fn encryption_key(&self) -> LweSecretKeyView<'_, u64> {
(*self).encryption_key()
}
fn encryption_noise(&self) -> DynamicDistribution<u64> {
(*self).encryption_noise()
}
fn kind(&self) -> AtomicPatternKind {
(*self).kind()
}
}
/// The client key materials for all the supported Atomic Patterns
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, Versionize)]
#[versionize(AtomicPatternClientKeyVersions)]
#[allow(clippy::large_enum_variant)] // The difference in size is just because of the wopbs params in std key
pub enum AtomicPatternClientKey {
Standard(StandardAtomicPatternClientKey),
KeySwitch32(KS32AtomicPatternClientKey),
}
impl AtomicPatternClientKey {
pub(crate) fn new_with_engine(
parameters: AtomicPatternParameters,
engine: &mut ShortintEngine,
) -> Self {
match parameters {
AtomicPatternParameters::Standard(ap_params) => Self::Standard(
StandardAtomicPatternClientKey::new_with_engine(ap_params, None, engine),
),
AtomicPatternParameters::KeySwitch32(ap_params) => Self::KeySwitch32(
KS32AtomicPatternClientKey::new_with_engine(ap_params, engine),
),
}
}
pub fn new(parameters: AtomicPatternParameters) -> Self {
ShortintEngine::with_thread_local_mut(|engine| Self::new_with_engine(parameters, engine))
}
pub fn try_from_lwe_encryption_key(
encryption_key: LweSecretKeyOwned<u64>,
parameters: AtomicPatternParameters,
) -> crate::Result<Self> {
match parameters {
AtomicPatternParameters::Standard(ap_params) => Ok(Self::Standard(
StandardAtomicPatternClientKey::try_from_lwe_encryption_key(
encryption_key,
ap_params,
)?,
)),
AtomicPatternParameters::KeySwitch32(ap_params) => Ok(Self::KeySwitch32(
KS32AtomicPatternClientKey::try_from_lwe_encryption_key(encryption_key, ap_params)?,
)),
}
}
}
impl EncryptionAtomicPattern for AtomicPatternClientKey {
fn parameters(&self) -> ShortintParameterSet {
match self {
Self::Standard(ap) => ap.parameters.into(),
Self::KeySwitch32(ap) => ap.parameters.into(),
}
}
fn encryption_key(&self) -> LweSecretKeyView<'_, u64> {
match self {
Self::Standard(ap) => ap.encryption_key(),
Self::KeySwitch32(ap) => ap.encryption_key(),
}
}
fn encryption_noise(&self) -> DynamicDistribution<u64> {
match self {
Self::Standard(ap) => ap.encryption_noise(),
Self::KeySwitch32(ap) => ap.encryption_noise(),
}
}
fn kind(&self) -> AtomicPatternKind {
match self {
Self::Standard(ap_cks) => ap_cks.kind(),
Self::KeySwitch32(ap_cks) => ap_cks.kind(),
}
}
}
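
Thanks to the blanket impl above, code that only needs encryption material can stay generic over atomic patterns and accept an owned key, one of its variants, or a `&T` view. A small sketch (the function name is an assumption):

fn describe_encryption<AP: EncryptionAtomicPattern>(key: &AP) -> AtomicPatternKind {
    // Works for AtomicPatternClientKey, either variant key type, or &T views.
    let (_encryption_key, _noise) = key.encryption_key_and_noise();
    key.kind()
}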

Some files were not shown because too many files have changed in this diff.