mirror of
https://github.com/zama-ai/tfhe-rs.git
synced 2026-01-11 07:38:08 -05:00
Compare commits
27 Commits
backup/sum
...
al/debug_l
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
de710cb2fb | ||
|
|
59a78c76a9 | ||
|
|
1025246b17 | ||
|
|
338e9eaeef | ||
|
|
0bec4d2ba1 | ||
|
|
c5fab98900 | ||
|
|
14e1ee5bd3 | ||
|
|
52bc778629 | ||
|
|
10405c9836 | ||
|
|
5eaf6cec55 | ||
|
|
3bfacc1e9d | ||
|
|
a47a418d41 | ||
|
|
75b3141e19 | ||
|
|
d01328e0fe | ||
|
|
6e102b5fa1 | ||
|
|
8aa6fa514e | ||
|
|
21a19cd3c5 | ||
|
|
f51c70d536 | ||
|
|
66e3c02838 | ||
|
|
408e81c45a | ||
|
|
4152906c5d | ||
|
|
9fc8a0b5bc | ||
|
|
5dc3e59d13 | ||
|
|
b40996a7e5 | ||
|
|
b066ef19fa | ||
|
|
25d008bae8 | ||
|
|
2749c1088c |
2
.github/workflows/benchmark_gpu_common.yml
vendored
2
.github/workflows/benchmark_gpu_common.yml
vendored
@@ -84,7 +84,7 @@ jobs:
|
||||
run: |
|
||||
# Use Sed to extract a value from a string, this cannot be done with the ${variable//search/replace} pattern.
|
||||
# shellcheck disable=SC2001
|
||||
PARSED_COMMAND=$(echo "${INPUTS_COMMAND}" | sed 's/[[:space:]]*,[[:space:]]*/\\", \\"/g')
|
||||
PARSED_COMMAND=$(echo "${INPUTS_COMMAND}" | sed 's/[[:space:]]*,[[:space:]]*/\", \"/g')
|
||||
echo "COMMAND=[\"${PARSED_COMMAND}\"]" >> "${GITHUB_ENV}"
|
||||
|
||||
- name: Set single operations flavor
|
||||
|
||||
4
.github/workflows/benchmark_hpu_integer.yml
vendored
4
.github/workflows/benchmark_hpu_integer.yml
vendored
@@ -44,7 +44,7 @@ jobs:
|
||||
} >> "${GITHUB_ENV}"
|
||||
|
||||
- name: Install rust
|
||||
uses: dtolnay/rust-toolchain@a54c7afa936fefeb4456b2dd8068152669aa8203
|
||||
uses: dtolnay/rust-toolchain@888c2e1ea69ab0d4330cbf0af1ecc7b68f368cc1
|
||||
with:
|
||||
toolchain: nightly
|
||||
|
||||
@@ -76,7 +76,7 @@ jobs:
|
||||
REF_NAME: ${{ github.ref_name }}
|
||||
|
||||
- name: Upload parsed results artifact
|
||||
uses: actions/upload-artifact@65c4c4a1ddee5b72f698fdd19549f0f0fb45cf08
|
||||
uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02
|
||||
with:
|
||||
name: ${{ github.sha }}_integer_benchmarks
|
||||
path: ${{ env.RESULTS_FILENAME }}
|
||||
|
||||
4
.github/workflows/ci_lint.yml
vendored
4
.github/workflows/ci_lint.yml
vendored
@@ -38,9 +38,11 @@ jobs:
|
||||
- name: Check workflows security
|
||||
run: |
|
||||
make check_workflow_security
|
||||
env:
|
||||
GH_TOKEN: ${{ env.CHECKOUT_TOKEN }}
|
||||
|
||||
- name: Ensure SHA pinned actions
|
||||
uses: zgosalvez/github-actions-ensure-sha-pinned-actions@4830be28ce81da52ec70d65c552a7403821d98d4 # v3.0.23
|
||||
uses: zgosalvez/github-actions-ensure-sha-pinned-actions@fc87bb5b5a97953d987372e74478de634726b3e5 # v3.0.25
|
||||
with:
|
||||
allowlist: |
|
||||
slsa-framework/slsa-github-generator
|
||||
|
||||
4
.github/workflows/code_coverage.yml
vendored
4
.github/workflows/code_coverage.yml
vendored
@@ -90,7 +90,7 @@ jobs:
|
||||
make test_shortint_cov
|
||||
|
||||
- name: Upload tfhe coverage to Codecov
|
||||
uses: codecov/codecov-action@ad3126e916f78f00edff4ed0317cf185271ccc2d
|
||||
uses: codecov/codecov-action@18283e04ce6e62d37312384ff67231eb8fd56d24
|
||||
if: steps.changed-files.outputs.tfhe_any_changed == 'true'
|
||||
with:
|
||||
token: ${{ secrets.CODECOV_TOKEN }}
|
||||
@@ -104,7 +104,7 @@ jobs:
|
||||
make test_integer_cov
|
||||
|
||||
- name: Upload tfhe coverage to Codecov
|
||||
uses: codecov/codecov-action@ad3126e916f78f00edff4ed0317cf185271ccc2d
|
||||
uses: codecov/codecov-action@18283e04ce6e62d37312384ff67231eb8fd56d24
|
||||
if: steps.changed-files.outputs.tfhe_any_changed == 'true'
|
||||
with:
|
||||
token: ${{ secrets.CODECOV_TOKEN }}
|
||||
|
||||
@@ -149,7 +149,7 @@ jobs:
|
||||
|
||||
- name: Run High Level API Tests
|
||||
run: |
|
||||
BIG_TESTS_INSTANCE=FALSE make test_high_level_api_gpu
|
||||
make test_high_level_api_gpu
|
||||
|
||||
slack-notify:
|
||||
name: Slack Notification
|
||||
|
||||
8
Makefile
8
Makefile
@@ -294,7 +294,7 @@ check_typos: install_typos_checker
|
||||
.PHONY: clippy_gpu # Run clippy lints on tfhe with "gpu" enabled
|
||||
clippy_gpu: install_rs_check_toolchain
|
||||
RUSTFLAGS="$(RUSTFLAGS)" cargo "$(CARGO_RS_CHECK_TOOLCHAIN)" clippy \
|
||||
--features=boolean,shortint,integer,internal-keycache,gpu,pbs-stats,extended-types \
|
||||
--features=boolean,shortint,integer,internal-keycache,gpu,pbs-stats,extended-types,zk-pok \
|
||||
--all-targets \
|
||||
-p $(TFHE_SPEC) -- --no-deps -D warnings
|
||||
|
||||
@@ -315,7 +315,7 @@ clippy_hpu: install_rs_check_toolchain
|
||||
.PHONY: clippy_gpu_hpu # Run clippy lints on tfhe with "gpu" and "hpu" enabled
|
||||
clippy_gpu_hpu: install_rs_check_toolchain
|
||||
RUSTFLAGS="$(RUSTFLAGS)" cargo "$(CARGO_RS_CHECK_TOOLCHAIN)" clippy \
|
||||
--features=boolean,shortint,integer,internal-keycache,gpu,hpu,pbs-stats,extended-types \
|
||||
--features=boolean,shortint,integer,internal-keycache,gpu,hpu,pbs-stats,extended-types,zk-pok \
|
||||
--all-targets \
|
||||
-p $(TFHE_SPEC) -- --no-deps -D warnings
|
||||
|
||||
@@ -899,7 +899,7 @@ test_high_level_api: install_rs_build_toolchain
|
||||
|
||||
test_high_level_api_gpu: install_rs_build_toolchain install_cargo_nextest
|
||||
RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_BUILD_TOOLCHAIN) nextest run --cargo-profile $(CARGO_PROFILE) \
|
||||
--features=integer,internal-keycache,gpu -p $(TFHE_SPEC) \
|
||||
--test-threads=4 --features=integer,internal-keycache,gpu,zk-pok -p $(TFHE_SPEC) \
|
||||
-E "test(/high_level_api::.*gpu.*/)"
|
||||
|
||||
test_high_level_api_hpu: install_rs_build_toolchain install_cargo_nextest
|
||||
@@ -1073,7 +1073,7 @@ check_compile_tests: install_rs_build_toolchain
|
||||
.PHONY: check_compile_tests_benches_gpu # Build tests in debug without running them
|
||||
check_compile_tests_benches_gpu: install_rs_build_toolchain
|
||||
RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_BUILD_TOOLCHAIN) test --no-run \
|
||||
--features=experimental,boolean,shortint,integer,internal-keycache,gpu \
|
||||
--features=experimental,boolean,shortint,integer,internal-keycache,gpu,zk-pok \
|
||||
-p $(TFHE_SPEC)
|
||||
mkdir -p "$(TFHECUDA_BUILD)" && \
|
||||
cd "$(TFHECUDA_BUILD)" && \
|
||||
|
||||
@@ -8,7 +8,7 @@ extern std::mutex m;
|
||||
extern bool p2p_enabled;
|
||||
|
||||
extern "C" {
|
||||
int32_t cuda_setup_multi_gpu();
|
||||
int32_t cuda_setup_multi_gpu(int device_0_id);
|
||||
}
|
||||
|
||||
// Define a variant type that can be either a vector or a single pointer
|
||||
|
||||
@@ -522,7 +522,8 @@ template <typename Torus>
|
||||
bool has_support_to_cuda_programmable_bootstrap_tbc(uint32_t num_samples,
|
||||
uint32_t glwe_dimension,
|
||||
uint32_t polynomial_size,
|
||||
uint32_t level_count);
|
||||
uint32_t level_count,
|
||||
uint32_t max_shared_memory);
|
||||
|
||||
#ifdef __CUDACC__
|
||||
__device__ inline int get_start_ith_ggsw(int i, uint32_t polynomial_size,
|
||||
|
||||
@@ -492,6 +492,7 @@ __host__ void host_fourier_transform_forward_as_integer_f128(
|
||||
batch_convert_u128_to_f128_as_integer<params>
|
||||
<<<grid_size, block_size, 0, stream>>>(d_re0, d_re1, d_im0, d_im1,
|
||||
d_standard);
|
||||
check_cuda_error(cudaGetLastError());
|
||||
|
||||
// call negacyclic 128 bit forward fft.
|
||||
if (full_sm) {
|
||||
@@ -503,6 +504,7 @@ __host__ void host_fourier_transform_forward_as_integer_f128(
|
||||
<<<grid_size, block_size, shared_memory_size, stream>>>(
|
||||
d_re0, d_re1, d_im0, d_im1, d_re0, d_re1, d_im0, d_im1, buffer);
|
||||
}
|
||||
check_cuda_error(cudaGetLastError());
|
||||
|
||||
cuda_memcpy_async_to_cpu(re0, d_re0, N / 2 * sizeof(double), stream,
|
||||
gpu_index);
|
||||
|
||||
@@ -261,6 +261,8 @@ void cuda_fourier_polynomial_mul(void *stream_v, uint32_t gpu_index,
|
||||
default:
|
||||
break;
|
||||
}
|
||||
check_cuda_error(cudaGetLastError());
|
||||
|
||||
cuda_drop_async(buffer, stream, gpu_index);
|
||||
}
|
||||
|
||||
|
||||
@@ -279,6 +279,7 @@ void cuda_convert_lwe_programmable_bootstrap_key(cudaStream_t stream,
|
||||
PANIC("Cuda error (convert KSK): unsupported polynomial size. Supported "
|
||||
"N's are powers of two in the interval [256..16384].")
|
||||
}
|
||||
check_cuda_error(cudaGetLastError());
|
||||
|
||||
cuda_drop_async(d_bsk, stream, gpu_index);
|
||||
cuda_drop_async(buffer, stream, gpu_index);
|
||||
@@ -315,6 +316,7 @@ void convert_u128_to_f128_and_forward_fft_128(cudaStream_t stream,
|
||||
// convert u128 into 4 x double
|
||||
batch_convert_u128_to_f128_strided_as_torus<params>
|
||||
<<<grid_size, block_size, 0, stream>>>(d_bsk, d_standard);
|
||||
check_cuda_error(cudaGetLastError());
|
||||
|
||||
// call negacyclic 128 bit forward fft.
|
||||
if (full_sm) {
|
||||
@@ -326,6 +328,7 @@ void convert_u128_to_f128_and_forward_fft_128(cudaStream_t stream,
|
||||
<<<grid_size, block_size, shared_memory_size, stream>>>(d_bsk, d_bsk,
|
||||
buffer);
|
||||
}
|
||||
check_cuda_error(cudaGetLastError());
|
||||
cuda_drop_async(buffer, stream, gpu_index);
|
||||
}
|
||||
|
||||
|
||||
@@ -848,4 +848,7 @@ template uint64_t scratch_cuda_programmable_bootstrap_tbc<uint64_t>(
|
||||
uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t level_count,
|
||||
uint32_t input_lwe_ciphertext_count, bool allocate_gpu_memory,
|
||||
bool allocate_ms_array);
|
||||
template bool
|
||||
supports_distributed_shared_memory_on_classic_programmable_bootstrap<
|
||||
__uint128_t>(uint32_t polynomial_size, uint32_t max_shared_memory);
|
||||
#endif
|
||||
|
||||
@@ -6,7 +6,8 @@
|
||||
std::mutex m;
|
||||
bool p2p_enabled = false;
|
||||
|
||||
int32_t cuda_setup_multi_gpu() {
|
||||
// Enable bidirectional p2p access between all available GPUs and device_0_id
|
||||
int32_t cuda_setup_multi_gpu(int device_0_id) {
|
||||
int num_gpus = cuda_get_number_of_gpus();
|
||||
if (num_gpus == 0)
|
||||
PANIC("GPU error: the number of GPUs should be > 0.")
|
||||
@@ -18,11 +19,13 @@ int32_t cuda_setup_multi_gpu() {
|
||||
omp_set_nested(1);
|
||||
int has_peer_access_to_device_0;
|
||||
for (int i = 1; i < num_gpus; i++) {
|
||||
check_cuda_error(
|
||||
cudaDeviceCanAccessPeer(&has_peer_access_to_device_0, i, 0));
|
||||
check_cuda_error(cudaDeviceCanAccessPeer(&has_peer_access_to_device_0,
|
||||
i, device_0_id));
|
||||
if (has_peer_access_to_device_0) {
|
||||
cuda_set_device(i);
|
||||
check_cuda_error(cudaDeviceEnablePeerAccess(0, 0));
|
||||
check_cuda_error(cudaDeviceEnablePeerAccess(device_0_id, 0));
|
||||
cuda_set_device(device_0_id);
|
||||
check_cuda_error(cudaDeviceEnablePeerAccess(i, 0));
|
||||
}
|
||||
num_used_gpus += 1;
|
||||
}
|
||||
|
||||
@@ -168,7 +168,7 @@ BENCHMARK_DEFINE_F(MultiBitBootstrap_u64, TbcMultiBit)
|
||||
(benchmark::State &st) {
|
||||
if (!has_support_to_cuda_programmable_bootstrap_tbc_multi_bit<uint64_t>(
|
||||
input_lwe_ciphertext_count, glwe_dimension, polynomial_size,
|
||||
pbs_level)) {
|
||||
pbs_level, cuda_get_max_shared_memory(0))) {
|
||||
st.SkipWithError("Configuration not supported for tbc operation");
|
||||
return;
|
||||
}
|
||||
@@ -256,7 +256,7 @@ BENCHMARK_DEFINE_F(ClassicalBootstrap_u64, TbcPBC)
|
||||
(benchmark::State &st) {
|
||||
if (!has_support_to_cuda_programmable_bootstrap_tbc<uint64_t>(
|
||||
input_lwe_ciphertext_count, glwe_dimension, polynomial_size,
|
||||
pbs_level)) {
|
||||
pbs_level, cuda_get_max_shared_memory(0))) {
|
||||
st.SkipWithError("Configuration not supported for tbc operation");
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -65,7 +65,7 @@ public:
|
||||
number_of_inputs = (int)GetParam().number_of_inputs;
|
||||
|
||||
// Enable Multi-GPU logic
|
||||
gpu_count = cuda_setup_multi_gpu();
|
||||
gpu_count = cuda_setup_multi_gpu(0);
|
||||
active_gpu_count = std::min((uint)number_of_inputs, gpu_count);
|
||||
for (uint gpu_i = 0; gpu_i < active_gpu_count; gpu_i++) {
|
||||
streams.push_back(cuda_create_stream(gpu_i));
|
||||
|
||||
@@ -101,6 +101,6 @@ extern "C" {
|
||||
|
||||
pub fn cuda_drop_async(ptr: *mut c_void, stream: *mut c_void, gpu_index: u32);
|
||||
|
||||
pub fn cuda_setup_multi_gpu() -> i32;
|
||||
pub fn cuda_setup_multi_gpu(gpu_index: u32) -> i32;
|
||||
|
||||
} // extern "C"
|
||||
|
||||
@@ -14,6 +14,8 @@ import sys
|
||||
|
||||
ONE_HOUR_IN_SECONDS = 3600
|
||||
ONE_SECOND_IN_NANOSECONDS = 1e9
|
||||
# These are directories where crypto parameters records can be stored.
|
||||
BENCHMARK_DIRS = ["tfhe-benchmark", "tfhe-zk-pok"]
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
@@ -348,8 +350,18 @@ def get_parameters(bench_id, directory):
|
||||
|
||||
:return: :class:`tuple` as ``(benchmark parameters, display name, operator type)``
|
||||
"""
|
||||
params_dir = pathlib.Path("tfhe-benchmark", "benchmarks_parameters", bench_id)
|
||||
params = _parse_file_to_json(params_dir, "parameters.json")
|
||||
for dirname in BENCHMARK_DIRS:
|
||||
params_dir = pathlib.Path(dirname, "benchmarks_parameters", bench_id)
|
||||
try:
|
||||
params = _parse_file_to_json(params_dir, "parameters.json")
|
||||
except FileNotFoundError:
|
||||
continue
|
||||
else:
|
||||
break
|
||||
else:
|
||||
raise FileNotFoundError(
|
||||
f"file not found: '[...]/benchmarks_parameters/{bench_id}/parameters.json'"
|
||||
)
|
||||
|
||||
display_name = params.pop("display_name")
|
||||
operator = params.pop("operator_type")
|
||||
|
||||
@@ -140,8 +140,8 @@ if [[ "${backend}" == "gpu" ]]; then
|
||||
test_threads=8
|
||||
doctest_threads=8
|
||||
else
|
||||
test_threads=1
|
||||
doctest_threads=1
|
||||
test_threads=4
|
||||
doctest_threads=4
|
||||
fi
|
||||
fi
|
||||
|
||||
|
||||
@@ -138,11 +138,13 @@ pub fn test_shortint_clientkey(
|
||||
|
||||
let key: ClientKey = load_and_unversionize(dir, test, format)?;
|
||||
|
||||
if test_params != key.parameters {
|
||||
if test_params != key.parameters() {
|
||||
Err(test.failure(
|
||||
format!(
|
||||
"Invalid {} parameters:\n Expected :\n{:?}\nGot:\n{:?}",
|
||||
format, test_params, key.parameters
|
||||
format,
|
||||
test_params,
|
||||
key.parameters()
|
||||
),
|
||||
format,
|
||||
))
|
||||
|
||||
@@ -136,7 +136,7 @@ required-features = ["shortint"]
|
||||
name = "pbs128-bench"
|
||||
path = "benches/core_crypto/pbs128_bench.rs"
|
||||
harness = false
|
||||
required-features = ["shortint"]
|
||||
required-features = ["shortint", "internal-keycache"]
|
||||
|
||||
[[bin]]
|
||||
name = "boolean_key_sizes"
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -418,11 +418,11 @@ fn cpu_pke_zk_verify(c: &mut Criterion, results_file: &Path) {
|
||||
#[cfg(all(feature = "gpu", feature = "zk-pok"))]
|
||||
mod cuda {
|
||||
use super::*;
|
||||
use benchmark::utilities::{cuda_local_keys, cuda_local_streams};
|
||||
use benchmark::utilities::cuda_local_streams;
|
||||
use criterion::BatchSize;
|
||||
use itertools::Itertools;
|
||||
use tfhe::core_crypto::gpu::{get_number_of_gpus, CudaStreams};
|
||||
use tfhe::integer::gpu::key_switching_key::CudaKeySwitchingKey;
|
||||
use tfhe::integer::gpu::key_switching_key::{CudaKeySwitchingKey, CudaKeySwitchingKeyMaterial};
|
||||
use tfhe::integer::gpu::zk::CudaProvenCompactCiphertextList;
|
||||
use tfhe::integer::gpu::CudaServerKey;
|
||||
use tfhe::integer::CompressedServerKey;
|
||||
@@ -451,14 +451,17 @@ mod cuda {
|
||||
let param_name = param_name.as_str();
|
||||
let cks = ClientKey::new(param_fhe);
|
||||
let compressed_server_key = CompressedServerKey::new_radix_compressed_server_key(&cks);
|
||||
let sk = compressed_server_key.decompress();
|
||||
let gpu_sks = CudaServerKey::decompress_from_cpu(&compressed_server_key, &streams);
|
||||
|
||||
let compact_private_key = CompactPrivateKey::new(param_pke);
|
||||
let pk = CompactPublicKey::new(&compact_private_key);
|
||||
let d_ksk = CudaKeySwitchingKey::new(
|
||||
(&compact_private_key, None),
|
||||
(&cks, &gpu_sks),
|
||||
param_ksk,
|
||||
&streams,
|
||||
let ksk = KeySwitchingKey::new((&compact_private_key, None), (&cks, &sk), param_ksk);
|
||||
let d_ksk_material =
|
||||
CudaKeySwitchingKeyMaterial::from_key_switching_key(&ksk, &streams);
|
||||
let d_ksk = CudaKeySwitchingKey::from_cuda_key_switching_key_material(
|
||||
&d_ksk_material,
|
||||
&gpu_sks,
|
||||
);
|
||||
|
||||
// We have a use case with 320 bits of metadata
|
||||
@@ -609,7 +612,6 @@ mod cuda {
|
||||
});
|
||||
}
|
||||
BenchmarkType::Throughput => {
|
||||
let gpu_sks_vec = cuda_local_keys(&cks);
|
||||
let gpu_count = get_number_of_gpus() as usize;
|
||||
|
||||
let elements = zk_throughput_num_elements();
|
||||
@@ -637,20 +639,17 @@ mod cuda {
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let local_streams = cuda_local_streams(num_block, elements as usize);
|
||||
let d_ksk_vec = gpu_sks_vec
|
||||
let d_ksk_material_vec = local_streams
|
||||
.par_iter()
|
||||
.zip(local_streams.par_iter())
|
||||
.map(|(gpu_sks, local_stream)| {
|
||||
CudaKeySwitchingKey::new(
|
||||
(&compact_private_key, None),
|
||||
(&cks, gpu_sks),
|
||||
param_ksk,
|
||||
.map(|local_stream| {
|
||||
CudaKeySwitchingKeyMaterial::from_key_switching_key(
|
||||
&ksk,
|
||||
local_stream,
|
||||
)
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
assert_eq!(d_ksk_vec.len(), gpu_count);
|
||||
assert_eq!(d_ksk_material_vec.len(), gpu_count);
|
||||
|
||||
bench_group.bench_function(&bench_id_verify, |b| {
|
||||
b.iter(|| {
|
||||
@@ -673,14 +672,16 @@ mod cuda {
|
||||
(gpu_cts, local_streams)
|
||||
};
|
||||
|
||||
b.iter_batched(setup_encrypted_values, |(gpu_cts, local_streams)| {
|
||||
gpu_cts.par_iter()
|
||||
.zip(local_streams.par_iter())
|
||||
.enumerate()
|
||||
.for_each(|(i, (gpu_ct, local_stream))| {
|
||||
gpu_ct
|
||||
.expand_without_verification(&d_ksk_vec[i % gpu_count], local_stream)
|
||||
.unwrap();
|
||||
b.iter_batched(setup_encrypted_values,
|
||||
|(gpu_cts, local_streams)| {
|
||||
gpu_cts.par_iter().zip(local_streams.par_iter()).enumerate().for_each
|
||||
(|(i, (gpu_ct, local_stream))| {
|
||||
let d_ksk =
|
||||
CudaKeySwitchingKey::from_cuda_key_switching_key_material(&d_ksk_material_vec[i % gpu_count], &gpu_sks);
|
||||
|
||||
gpu_ct
|
||||
.expand_without_verification(&d_ksk, local_stream)
|
||||
.unwrap();
|
||||
});
|
||||
}, BatchSize::SmallInput);
|
||||
});
|
||||
@@ -698,16 +699,15 @@ mod cuda {
|
||||
(gpu_cts, local_streams)
|
||||
};
|
||||
|
||||
b.iter_batched(setup_encrypted_values, |(gpu_cts, local_streams)| {
|
||||
gpu_cts
|
||||
.par_iter()
|
||||
.zip(local_streams.par_iter())
|
||||
.for_each(|(gpu_ct, local_stream)| {
|
||||
gpu_ct
|
||||
.verify_and_expand(
|
||||
&crs, &pk, &metadata, &d_ksk, local_stream
|
||||
)
|
||||
.unwrap();
|
||||
b.iter_batched(setup_encrypted_values,
|
||||
|(gpu_cts, local_streams)| {
|
||||
gpu_cts.par_iter().zip(local_streams.par_iter()).for_each
|
||||
(|(gpu_ct, local_stream)| {
|
||||
gpu_ct
|
||||
.verify_and_expand(
|
||||
&crs, &pk, &metadata, &d_ksk, local_stream,
|
||||
)
|
||||
.unwrap();
|
||||
});
|
||||
}, BatchSize::SmallInput);
|
||||
});
|
||||
|
||||
@@ -27,7 +27,7 @@ fn bench_server_key_unary_function<F>(
|
||||
|
||||
let mut rng = rand::thread_rng();
|
||||
|
||||
let modulus = cks.parameters.message_modulus().0;
|
||||
let modulus = cks.parameters().message_modulus().0;
|
||||
|
||||
let clear_text = rng.gen::<u64>() % modulus;
|
||||
|
||||
@@ -70,7 +70,7 @@ fn bench_server_key_binary_function<F>(
|
||||
|
||||
let mut rng = rand::thread_rng();
|
||||
|
||||
let modulus = cks.parameters.message_modulus().0;
|
||||
let modulus = cks.parameters().message_modulus().0;
|
||||
|
||||
let clear_0 = rng.gen::<u64>() % modulus;
|
||||
let clear_1 = rng.gen::<u64>() % modulus;
|
||||
@@ -115,7 +115,7 @@ fn bench_server_key_binary_scalar_function<F>(
|
||||
|
||||
let mut rng = rand::thread_rng();
|
||||
|
||||
let modulus = cks.parameters.message_modulus().0;
|
||||
let modulus = cks.parameters().message_modulus().0;
|
||||
|
||||
let clear_0 = rng.gen::<u64>() % modulus;
|
||||
let clear_1 = rng.gen::<u64>() % modulus;
|
||||
@@ -159,7 +159,7 @@ fn bench_server_key_binary_scalar_division_function<F>(
|
||||
|
||||
let mut rng = rand::thread_rng();
|
||||
|
||||
let modulus = cks.parameters.message_modulus().0;
|
||||
let modulus = cks.parameters().message_modulus().0;
|
||||
assert_ne!(modulus, 1);
|
||||
|
||||
let clear_0 = rng.gen::<u64>() % modulus;
|
||||
@@ -200,7 +200,7 @@ fn carry_extract_bench(c: &mut Criterion) {
|
||||
|
||||
let mut rng = rand::thread_rng();
|
||||
|
||||
let modulus = cks.parameters.message_modulus().0;
|
||||
let modulus = cks.parameters().message_modulus().0;
|
||||
|
||||
let clear_0 = rng.gen::<u64>() % modulus;
|
||||
|
||||
@@ -236,7 +236,7 @@ fn programmable_bootstrapping_bench(c: &mut Criterion) {
|
||||
|
||||
let mut rng = rand::thread_rng();
|
||||
|
||||
let modulus = cks.parameters.message_modulus().0;
|
||||
let modulus = cks.parameters().message_modulus().0;
|
||||
|
||||
let acc = sks.generate_lookup_table(|x| x);
|
||||
|
||||
|
||||
82
tfhe/docs/configuration/gpu_acceleration/zk-pok.md
Normal file
82
tfhe/docs/configuration/gpu_acceleration/zk-pok.md
Normal file
@@ -0,0 +1,82 @@
|
||||
# Zero-knowledge proofs
|
||||
|
||||
Zero-knowledge proofs (ZK) are a powerful tool to assert that the encryption of a message is correct, as discussed in [advanced features](../../fhe-computation/advanced-features/zk-pok.md).
|
||||
However, computation is not possible on the resulting type of ciphertexts (i.e. `ProvenCompactCiphertextList`). This document explains how to use the GPU to accelerate the
|
||||
preprocessing step needed to convert ciphertexts formatted for ZK to ciphertexts in the right format for computation purposes on GPU. This
|
||||
operation is called "expansion".
|
||||
|
||||
## Proven compact ciphertext list
|
||||
|
||||
A proven compact list of ciphertexts can be seen as a compacted collection of ciphertexts for which encryption can be verified.
|
||||
This verification is currently only supported on the CPU, but the expansion can be accelerated using the GPU.
|
||||
This way, verification and expansion can be performed in parallel, efficiently using all the available computational resources.
|
||||
|
||||
## Supported types
|
||||
Encrypted messages can be integers (like `FheUint64`) or booleans. The GPU backend does not currently support encrypted strings.
|
||||
|
||||
{% hint style="info" %}
|
||||
You can enable this feature using the flag: `--features=zk-pok,gpu` when building **TFHE-rs**.
|
||||
{% endhint %}
|
||||
|
||||
|
||||
## Example
|
||||
|
||||
The following example shows how a client can encrypt and prove a ciphertext, and how a server can verify the proof, preprocess the ciphertext and run a computation on it on GPU:
|
||||
|
||||
```rust
|
||||
use rand::random;
|
||||
use tfhe::CompressedServerKey;
|
||||
use tfhe::prelude::*;
|
||||
use tfhe::set_server_key;
|
||||
use tfhe::zk::{CompactPkeCrs, ZkComputeLoad};
|
||||
|
||||
pub fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
let params = tfhe::shortint::parameters::PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128;
|
||||
// Indicate which parameters to use for the Compact Public Key encryption
|
||||
let cpk_params = tfhe::shortint::parameters::PARAM_PKE_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128;
|
||||
// And the parameters used to keyswitch/cast to the computation parameters.
|
||||
let casting_params = tfhe::shortint::parameters::PARAM_KEYSWITCH_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128;
|
||||
// Enable the dedicated parameters on the config
|
||||
let config = tfhe::ConfigBuilder::with_custom_parameters(params)
|
||||
.use_dedicated_compact_public_key_parameters((cpk_params, casting_params)).build();
|
||||
|
||||
// The CRS should be generated in an offline phase, then shared with all clients and the server
|
||||
let crs = CompactPkeCrs::from_config(config, 64).unwrap();
|
||||
|
||||
// Then use TFHE-rs as usual
|
||||
let client_key = tfhe::ClientKey::generate(config);
|
||||
let compressed_server_key = CompressedServerKey::new(&client_key);
|
||||
let gpu_server_key = compressed_server_key.decompress_to_gpu();
|
||||
|
||||
let public_key = tfhe::CompactPublicKey::try_new(&client_key).unwrap();
|
||||
// This can be left empty, but if provided, it allows tying the proof to arbitrary data
|
||||
let metadata = [b'T', b'F', b'H', b'E', b'-', b'r', b's'];
|
||||
|
||||
let clear_a = random::<u64>();
|
||||
let clear_b = random::<u64>();
|
||||
|
||||
let proven_compact_list = tfhe::ProvenCompactCiphertextList::builder(&public_key)
|
||||
.push(clear_a)
|
||||
.push(clear_b)
|
||||
.build_with_proof_packed(&crs, &metadata, ZkComputeLoad::Verify)?;
|
||||
|
||||
// Server side
|
||||
let result = {
|
||||
set_server_key(gpu_server_key);
|
||||
|
||||
// Verify the ciphertexts
|
||||
let expander =
|
||||
proven_compact_list.verify_and_expand(&crs, &public_key, &metadata)?;
|
||||
let a: tfhe::FheUint64 = expander.get(0)?.unwrap();
|
||||
let b: tfhe::FheUint64 = expander.get(1)?.unwrap();
|
||||
|
||||
a + b
|
||||
};
|
||||
|
||||
// Back on the client side
|
||||
let a_plus_b: u64 = result.decrypt(&client_key);
|
||||
assert_eq!(a_plus_b, clear_a.wrapping_add(clear_b));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
```
|
||||
@@ -114,7 +114,7 @@ fn main() {
|
||||
let msg1 = 1;
|
||||
let msg2 = 0;
|
||||
|
||||
let modulus = client_key.parameters.message_modulus().0;
|
||||
let modulus = client_key.parameters().message_modulus().0;
|
||||
|
||||
// We use the client key to encrypt two messages:
|
||||
let ct_1 = client_key.encrypt(msg1);
|
||||
|
||||
@@ -86,7 +86,7 @@ fn main() {
|
||||
let msg1 = 1;
|
||||
let msg2 = 0;
|
||||
|
||||
let modulus = client_key.parameters.message_modulus().0;
|
||||
let modulus = client_key.parameters().message_modulus().0;
|
||||
|
||||
// We use the client key to encrypt two messages:
|
||||
let ct_1 = client_key.encrypt(msg1);
|
||||
|
||||
@@ -59,7 +59,7 @@ fn main() {
|
||||
let msg2 = 3;
|
||||
let scalar = 4;
|
||||
|
||||
let modulus = client_key.parameters.message_modulus().0;
|
||||
let modulus = client_key.parameters().message_modulus().0;
|
||||
|
||||
// We use the client key to encrypt two messages:
|
||||
let mut ct_1 = client_key.encrypt(msg1);
|
||||
@@ -91,7 +91,7 @@ fn main() {
|
||||
let msg2 = 3;
|
||||
let scalar = 4;
|
||||
|
||||
let modulus = client_key.parameters.message_modulus().0;
|
||||
let modulus = client_key.parameters().message_modulus().0;
|
||||
|
||||
// We use the client key to encrypt two messages:
|
||||
let mut ct_1 = client_key.encrypt(msg1);
|
||||
@@ -134,7 +134,7 @@ fn main() {
|
||||
let msg2 = 3;
|
||||
let scalar = 4;
|
||||
|
||||
let modulus = client_key.parameters.message_modulus().0;
|
||||
let modulus = client_key.parameters().message_modulus().0;
|
||||
|
||||
// We use the client key to encrypt two messages:
|
||||
let mut ct_1 = client_key.encrypt(msg1);
|
||||
@@ -168,7 +168,7 @@ fn main() {
|
||||
let msg2 = 3;
|
||||
let scalar = 4;
|
||||
|
||||
let modulus = client_key.parameters.message_modulus().0;
|
||||
let modulus = client_key.parameters().message_modulus().0;
|
||||
|
||||
// We use the client key to encrypt two messages:
|
||||
let mut ct_1 = client_key.encrypt(msg1);
|
||||
@@ -244,7 +244,7 @@ fn main() {
|
||||
let msg1 = 2;
|
||||
let msg2 = 1;
|
||||
|
||||
let modulus = client_key.parameters.message_modulus().0;
|
||||
let modulus = client_key.parameters().message_modulus().0;
|
||||
|
||||
// We use the private client key to encrypt two messages:
|
||||
let ct_1 = client_key.encrypt(msg1);
|
||||
@@ -275,7 +275,7 @@ fn main() {
|
||||
let msg1 = 2;
|
||||
let msg2 = 1;
|
||||
|
||||
let modulus = client_key.parameters.message_modulus().0;
|
||||
let modulus = client_key.parameters().message_modulus().0;
|
||||
|
||||
// We use the private client key to encrypt two messages:
|
||||
let ct_1 = client_key.encrypt(msg1);
|
||||
@@ -306,7 +306,7 @@ fn main() {
|
||||
let msg1 = 2;
|
||||
let msg2 = 1;
|
||||
|
||||
let modulus = client_key.parameters.message_modulus().0;
|
||||
let modulus = client_key.parameters().message_modulus().0;
|
||||
|
||||
// We use the private client key to encrypt two messages:
|
||||
let ct_1 = client_key.encrypt(msg1);
|
||||
@@ -365,7 +365,7 @@ fn main() {
|
||||
let msg1 = 3;
|
||||
let msg2 = 2;
|
||||
|
||||
let modulus = client_key.parameters.message_modulus().0;
|
||||
let modulus = client_key.parameters().message_modulus().0;
|
||||
|
||||
// We use the private client key to encrypt two messages:
|
||||
let ct_1 = client_key.encrypt(msg1);
|
||||
|
||||
@@ -7,7 +7,7 @@ use crate::core_crypto::prelude::{
|
||||
use tfhe_cuda_backend::cuda_bind::cuda_memcpy_async_gpu_to_gpu;
|
||||
|
||||
/// A structure representing a vector of LWE ciphertexts with 64 bits of precision on the GPU.
|
||||
#[derive(Debug)]
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct CudaLweCiphertextList<T: UnsignedInteger>(pub(crate) CudaLweList<T>);
|
||||
|
||||
#[allow(dead_code)]
|
||||
|
||||
@@ -10,6 +10,10 @@ use crate::core_crypto::prelude::{
|
||||
pub struct CudaLweCompactCiphertextList<T: UnsignedInteger>(pub CudaLweList<T>);
|
||||
|
||||
impl<T: UnsignedInteger> CudaLweCompactCiphertextList<T> {
|
||||
pub fn duplicate(&self, streams: &CudaStreams) -> Self {
|
||||
Self(self.0.duplicate(streams))
|
||||
}
|
||||
|
||||
pub fn from_lwe_compact_ciphertext_list<C: Container<Element = T>>(
|
||||
h_ct: &LweCompactCiphertextList<C>,
|
||||
streams: &CudaStreams,
|
||||
|
||||
@@ -10,6 +10,7 @@ use crate::core_crypto::prelude::{
|
||||
UnsignedInteger,
|
||||
};
|
||||
|
||||
#[derive(Clone)]
|
||||
#[allow(dead_code)]
|
||||
pub struct CudaLweKeyswitchKey<T: UnsignedInteger> {
|
||||
pub(crate) d_vec: CudaVec<T>,
|
||||
|
||||
@@ -17,7 +17,6 @@ pub use entities::*;
|
||||
use std::ffi::c_void;
|
||||
use tfhe_cuda_backend::bindings::*;
|
||||
use tfhe_cuda_backend::cuda_bind::*;
|
||||
|
||||
pub struct CudaStreams {
|
||||
pub ptr: Vec<*mut c_void>,
|
||||
pub gpu_indexes: Vec<GpuIndex>,
|
||||
@@ -30,7 +29,7 @@ unsafe impl Sync for CudaStreams {}
|
||||
impl CudaStreams {
|
||||
/// Create a new `CudaStreams` structure with as many GPUs as there are on the machine
|
||||
pub fn new_multi_gpu() -> Self {
|
||||
let gpu_count = setup_multi_gpu();
|
||||
let gpu_count = setup_multi_gpu(GpuIndex::new(0));
|
||||
let mut gpu_indexes = Vec::with_capacity(gpu_count as usize);
|
||||
let mut ptr_array = Vec::with_capacity(gpu_count as usize);
|
||||
|
||||
@@ -43,6 +42,22 @@ impl CudaStreams {
|
||||
gpu_indexes,
|
||||
}
|
||||
}
|
||||
/// Create a new `CudaStreams` structure with the GPUs with id provided in a list
|
||||
pub fn new_multi_gpu_with_indexes(indexes: &[GpuIndex]) -> Self {
|
||||
let _gpu_count = setup_multi_gpu(indexes[0]);
|
||||
|
||||
let mut gpu_indexes = Vec::with_capacity(indexes.len());
|
||||
let mut ptr_array = Vec::with_capacity(indexes.len());
|
||||
|
||||
for &i in indexes {
|
||||
ptr_array.push(unsafe { cuda_create_stream(i.get()) });
|
||||
gpu_indexes.push(i);
|
||||
}
|
||||
Self {
|
||||
ptr: ptr_array,
|
||||
gpu_indexes,
|
||||
}
|
||||
}
|
||||
/// Create a new `CudaStreams` structure with one GPU, whose index corresponds to the one given
|
||||
/// as input
|
||||
pub fn new_single_gpu(gpu_index: GpuIndex) -> Self {
|
||||
@@ -88,6 +103,14 @@ impl CudaStreams {
|
||||
}
|
||||
}
|
||||
|
||||
impl Clone for CudaStreams {
|
||||
fn clone(&self) -> Self {
|
||||
// The `new_multi_gpu_with_indexes()` function is used here to adapt to any specific type of
|
||||
// streams being cloned (single, multi, or custom)
|
||||
Self::new_multi_gpu_with_indexes(self.gpu_indexes.as_slice())
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for CudaStreams {
|
||||
fn drop(&mut self) {
|
||||
for (i, &s) in self.ptr.iter().enumerate() {
|
||||
@@ -959,7 +982,7 @@ pub unsafe fn fourier_transform_backward_as_torus_f128_async<T: UnsignedInteger>
|
||||
);
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct CudaLweList<T: UnsignedInteger> {
|
||||
// Pointer to GPU data
|
||||
pub d_vec: CudaVec<T>,
|
||||
@@ -971,6 +994,17 @@ pub struct CudaLweList<T: UnsignedInteger> {
|
||||
pub ciphertext_modulus: CiphertextModulus<T>,
|
||||
}
|
||||
|
||||
impl<T: UnsignedInteger> CudaLweList<T> {
|
||||
pub fn duplicate(&self, streams: &CudaStreams) -> Self {
|
||||
Self {
|
||||
d_vec: self.d_vec.duplicate(streams),
|
||||
lwe_ciphertext_count: self.lwe_ciphertext_count,
|
||||
lwe_dimension: self.lwe_dimension,
|
||||
ciphertext_modulus: self.ciphertext_modulus,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct CudaGlweList<T: UnsignedInteger> {
|
||||
// Pointer to GPU data
|
||||
@@ -1002,8 +1036,8 @@ pub fn get_number_of_gpus() -> u32 {
|
||||
}
|
||||
|
||||
/// Setup multi-GPU and return the number of GPUs used
|
||||
pub fn setup_multi_gpu() -> u32 {
|
||||
unsafe { cuda_setup_multi_gpu() as u32 }
|
||||
pub fn setup_multi_gpu(device_0_id: GpuIndex) -> u32 {
|
||||
unsafe { cuda_setup_multi_gpu(device_0_id.get()) as u32 }
|
||||
}
|
||||
|
||||
/// Synchronize device
|
||||
|
||||
@@ -7,7 +7,6 @@ use crate::array::stride::{ParStridedIter, ParStridedIterMut, StridedIter};
|
||||
use crate::array::traits::TensorSlice;
|
||||
use crate::high_level_api::array::{ArrayBackend, BackendDataContainer, BackendDataContainerMut};
|
||||
use crate::high_level_api::global_state;
|
||||
use crate::high_level_api::global_state::with_thread_local_cuda_streams;
|
||||
use crate::integer::gpu::ciphertext::boolean_value::CudaBooleanBlock;
|
||||
use crate::prelude::{FheDecrypt, FheTryEncrypt};
|
||||
use crate::{ClientKey, FheBoolId};
|
||||
@@ -25,10 +24,15 @@ pub type GpuFheBoolSliceMut<'a> =
|
||||
pub struct GpuBooleanSlice<'a>(pub(crate) &'a [CudaBooleanBlock]);
|
||||
pub struct GpuBooleanSliceMut<'a>(pub(crate) &'a mut [CudaBooleanBlock]);
|
||||
pub struct GpuBooleanOwned(pub(crate) Vec<CudaBooleanBlock>);
|
||||
use crate::high_level_api::global_state::with_cuda_internal_keys;
|
||||
|
||||
impl Clone for GpuBooleanOwned {
|
||||
fn clone(&self) -> Self {
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
// When cloning, we assume that the intention is to return a ciphertext that lies in the GPU
|
||||
// 0 defined in the set server key. Hence, we use the server key to get the streams instead
|
||||
// of those inside the ciphertext itself
|
||||
with_cuda_internal_keys(|key| {
|
||||
let streams = &key.streams;
|
||||
Self(self.0.iter().map(|elem| elem.duplicate(streams)).collect())
|
||||
})
|
||||
}
|
||||
@@ -83,7 +87,8 @@ impl BackendDataContainer for GpuBooleanSlice<'_> {
|
||||
}
|
||||
|
||||
fn into_owned(self) -> <Self::Backend as ArrayBackend>::Owned {
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
with_cuda_internal_keys(|key| {
|
||||
let streams = &key.streams;
|
||||
GpuBooleanOwned(self.0.iter().map(|elem| elem.duplicate(streams)).collect())
|
||||
})
|
||||
}
|
||||
@@ -104,7 +109,8 @@ impl BackendDataContainer for GpuBooleanSliceMut<'_> {
|
||||
}
|
||||
|
||||
fn into_owned(self) -> <Self::Backend as ArrayBackend>::Owned {
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
with_cuda_internal_keys(|key| {
|
||||
let streams = &key.streams;
|
||||
GpuBooleanOwned(self.0.iter().map(|elem| elem.duplicate(streams)).collect())
|
||||
})
|
||||
}
|
||||
@@ -156,14 +162,13 @@ impl BitwiseArrayBackend for GpuFheBoolArrayBackend {
|
||||
rhs: TensorSlice<'_, Self::Slice<'a>>,
|
||||
) -> Self::Owned {
|
||||
GpuBooleanOwned(global_state::with_cuda_internal_keys(|cuda_key| {
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
lhs.par_iter()
|
||||
.zip(rhs.par_iter())
|
||||
.map(|(lhs, rhs)| {
|
||||
CudaBooleanBlock(cuda_key.pbs_key().bitand(&lhs.0, &rhs.0, streams))
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
})
|
||||
let streams = &cuda_key.streams;
|
||||
lhs.par_iter()
|
||||
.zip(rhs.par_iter())
|
||||
.map(|(lhs, rhs)| {
|
||||
CudaBooleanBlock(cuda_key.pbs_key().bitand(&lhs.0, &rhs.0, streams))
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
}))
|
||||
}
|
||||
|
||||
@@ -172,14 +177,13 @@ impl BitwiseArrayBackend for GpuFheBoolArrayBackend {
|
||||
rhs: TensorSlice<'_, Self::Slice<'a>>,
|
||||
) -> Self::Owned {
|
||||
GpuBooleanOwned(global_state::with_cuda_internal_keys(|cuda_key| {
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
lhs.par_iter()
|
||||
.zip(rhs.par_iter())
|
||||
.map(|(lhs, rhs)| {
|
||||
CudaBooleanBlock(cuda_key.pbs_key().bitor(&lhs.0, &rhs.0, streams))
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
})
|
||||
let streams = &cuda_key.streams;
|
||||
lhs.par_iter()
|
||||
.zip(rhs.par_iter())
|
||||
.map(|(lhs, rhs)| {
|
||||
CudaBooleanBlock(cuda_key.pbs_key().bitor(&lhs.0, &rhs.0, streams))
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
}))
|
||||
}
|
||||
|
||||
@@ -188,24 +192,22 @@ impl BitwiseArrayBackend for GpuFheBoolArrayBackend {
|
||||
rhs: TensorSlice<'_, Self::Slice<'a>>,
|
||||
) -> Self::Owned {
|
||||
GpuBooleanOwned(global_state::with_cuda_internal_keys(|cuda_key| {
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
lhs.par_iter()
|
||||
.zip(rhs.par_iter())
|
||||
.map(|(lhs, rhs)| {
|
||||
CudaBooleanBlock(cuda_key.pbs_key().bitxor(&lhs.0, &rhs.0, streams))
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
})
|
||||
let streams = &cuda_key.streams;
|
||||
lhs.par_iter()
|
||||
.zip(rhs.par_iter())
|
||||
.map(|(lhs, rhs)| {
|
||||
CudaBooleanBlock(cuda_key.pbs_key().bitxor(&lhs.0, &rhs.0, streams))
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
}))
|
||||
}
|
||||
|
||||
fn bitnot(lhs: TensorSlice<'_, Self::Slice<'_>>) -> Self::Owned {
|
||||
GpuBooleanOwned(global_state::with_cuda_internal_keys(|cuda_key| {
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
lhs.par_iter()
|
||||
.map(|lhs| CudaBooleanBlock(cuda_key.pbs_key().bitnot(&lhs.0, streams)))
|
||||
.collect::<Vec<_>>()
|
||||
})
|
||||
let streams = &cuda_key.streams;
|
||||
lhs.par_iter()
|
||||
.map(|lhs| CudaBooleanBlock(cuda_key.pbs_key().bitnot(&lhs.0, streams)))
|
||||
.collect::<Vec<_>>()
|
||||
}))
|
||||
}
|
||||
}
|
||||
@@ -216,16 +218,13 @@ impl ClearBitwiseArrayBackend<bool> for GpuFheBoolArrayBackend {
|
||||
rhs: TensorSlice<'_, &'_ [bool]>,
|
||||
) -> Self::Owned {
|
||||
GpuBooleanOwned(global_state::with_cuda_internal_keys(|cuda_key| {
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
lhs.par_iter()
|
||||
.zip(rhs.par_iter().copied())
|
||||
.map(|(lhs, rhs)| {
|
||||
CudaBooleanBlock(
|
||||
cuda_key.pbs_key().scalar_bitand(&lhs.0, rhs as u8, streams),
|
||||
)
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
})
|
||||
let streams = &cuda_key.streams;
|
||||
lhs.par_iter()
|
||||
.zip(rhs.par_iter().copied())
|
||||
.map(|(lhs, rhs)| {
|
||||
CudaBooleanBlock(cuda_key.pbs_key().scalar_bitand(&lhs.0, rhs as u8, streams))
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
}))
|
||||
}
|
||||
|
||||
@@ -234,16 +233,13 @@ impl ClearBitwiseArrayBackend<bool> for GpuFheBoolArrayBackend {
|
||||
rhs: TensorSlice<'_, &'_ [bool]>,
|
||||
) -> Self::Owned {
|
||||
GpuBooleanOwned(global_state::with_cuda_internal_keys(|cuda_key| {
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
lhs.par_iter()
|
||||
.zip(rhs.par_iter().copied())
|
||||
.map(|(lhs, rhs)| {
|
||||
CudaBooleanBlock(
|
||||
cuda_key.pbs_key().scalar_bitor(&lhs.0, rhs as u8, streams),
|
||||
)
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
})
|
||||
let streams = &cuda_key.streams;
|
||||
lhs.par_iter()
|
||||
.zip(rhs.par_iter().copied())
|
||||
.map(|(lhs, rhs)| {
|
||||
CudaBooleanBlock(cuda_key.pbs_key().scalar_bitor(&lhs.0, rhs as u8, streams))
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
}))
|
||||
}
|
||||
|
||||
@@ -252,16 +248,13 @@ impl ClearBitwiseArrayBackend<bool> for GpuFheBoolArrayBackend {
|
||||
rhs: TensorSlice<'_, &'_ [bool]>,
|
||||
) -> Self::Owned {
|
||||
GpuBooleanOwned(global_state::with_cuda_internal_keys(|cuda_key| {
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
lhs.par_iter()
|
||||
.zip(rhs.par_iter().copied())
|
||||
.map(|(lhs, rhs)| {
|
||||
CudaBooleanBlock(
|
||||
cuda_key.pbs_key().scalar_bitxor(&lhs.0, rhs as u8, streams),
|
||||
)
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
})
|
||||
let streams = &cuda_key.streams;
|
||||
lhs.par_iter()
|
||||
.zip(rhs.par_iter().copied())
|
||||
.map(|(lhs, rhs)| {
|
||||
CudaBooleanBlock(cuda_key.pbs_key().scalar_bitxor(&lhs.0, rhs as u8, streams))
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
}))
|
||||
}
|
||||
}
|
||||
@@ -270,7 +263,8 @@ impl FheTryEncrypt<&[bool], ClientKey> for GpuFheBoolArray {
|
||||
type Error = crate::Error;
|
||||
|
||||
fn try_encrypt(values: &[bool], cks: &ClientKey) -> Result<Self, Self::Error> {
|
||||
let encrypted = with_thread_local_cuda_streams(|streams| {
|
||||
let encrypted = with_cuda_internal_keys(|key| {
|
||||
let streams = &key.streams;
|
||||
values
|
||||
.iter()
|
||||
.copied()
|
||||
@@ -285,7 +279,8 @@ impl FheTryEncrypt<&[bool], ClientKey> for GpuFheBoolArray {
|
||||
|
||||
impl FheDecrypt<Vec<bool>> for GpuFheBoolSlice<'_> {
|
||||
fn decrypt(&self, key: &ClientKey) -> Vec<bool> {
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
with_cuda_internal_keys(|cuda_key| {
|
||||
let streams = &cuda_key.streams;
|
||||
self.elems
|
||||
.0
|
||||
.iter()
|
||||
|
||||
@@ -13,7 +13,7 @@ use crate::array::traits::{
|
||||
};
|
||||
use crate::core_crypto::gpu::CudaStreams;
|
||||
use crate::high_level_api::global_state;
|
||||
use crate::high_level_api::global_state::with_thread_local_cuda_streams;
|
||||
use crate::high_level_api::global_state::with_cuda_internal_keys;
|
||||
use crate::high_level_api::integers::{FheIntId, FheUintId};
|
||||
use crate::integer::block_decomposition::{
|
||||
DecomposableInto, RecomposableFrom, RecomposableSignedInteger,
|
||||
@@ -53,7 +53,8 @@ where
|
||||
T: CudaIntegerRadixCiphertext,
|
||||
{
|
||||
fn clone(&self) -> Self {
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
with_cuda_internal_keys(|key| {
|
||||
let streams = &key.streams;
|
||||
Self(self.0.iter().map(|elem| elem.duplicate(streams)).collect())
|
||||
})
|
||||
}
|
||||
@@ -108,12 +109,11 @@ where
|
||||
F: Send + Sync + Fn(&crate::integer::gpu::CudaServerKey, &T, &T, &CudaStreams) -> T,
|
||||
{
|
||||
GpuOwned(global_state::with_cuda_internal_keys(|cuda_key| {
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
lhs.par_iter()
|
||||
.zip(rhs.par_iter())
|
||||
.map(|(lhs, rhs)| op(cuda_key.pbs_key(), lhs, rhs, streams))
|
||||
.collect::<Vec<_>>()
|
||||
})
|
||||
let streams = &cuda_key.streams;
|
||||
lhs.par_iter()
|
||||
.zip(rhs.par_iter())
|
||||
.map(|(lhs, rhs)| op(cuda_key.pbs_key(), lhs, rhs, streams))
|
||||
.collect::<Vec<_>>()
|
||||
}))
|
||||
}
|
||||
|
||||
@@ -170,12 +170,11 @@ where
|
||||
F: Send + Sync + Fn(&crate::integer::gpu::CudaServerKey, &T, Clear, &CudaStreams) -> T,
|
||||
{
|
||||
GpuOwned(global_state::with_cuda_internal_keys(|cuda_key| {
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
lhs.par_iter()
|
||||
.zip(rhs.par_iter())
|
||||
.map(|(lhs, rhs)| op(cuda_key.pbs_key(), lhs, *rhs, streams))
|
||||
.collect::<Vec<_>>()
|
||||
})
|
||||
let streams = &cuda_key.streams;
|
||||
lhs.par_iter()
|
||||
.zip(rhs.par_iter())
|
||||
.map(|(lhs, rhs)| op(cuda_key.pbs_key(), lhs, *rhs, streams))
|
||||
.collect::<Vec<_>>()
|
||||
}))
|
||||
}
|
||||
|
||||
@@ -336,11 +335,10 @@ where
|
||||
|
||||
fn bitnot(lhs: TensorSlice<'_, Self::Slice<'_>>) -> Self::Owned {
|
||||
GpuOwned(global_state::with_cuda_internal_keys(|cuda_key| {
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
lhs.par_iter()
|
||||
.map(|lhs| cuda_key.pbs_key().bitnot(lhs, streams))
|
||||
.collect::<Vec<_>>()
|
||||
})
|
||||
let streams = &cuda_key.streams;
|
||||
lhs.par_iter()
|
||||
.map(|lhs| cuda_key.pbs_key().bitnot(lhs, streams))
|
||||
.collect::<Vec<_>>()
|
||||
}))
|
||||
}
|
||||
}
|
||||
@@ -439,7 +437,8 @@ where
|
||||
}
|
||||
|
||||
fn into_owned(self) -> <Self::Backend as ArrayBackend>::Owned {
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
with_cuda_internal_keys(|key| {
|
||||
let streams = &key.streams;
|
||||
GpuOwned(self.0.iter().map(|elem| elem.duplicate(streams)).collect())
|
||||
})
|
||||
}
|
||||
@@ -463,7 +462,8 @@ where
|
||||
}
|
||||
|
||||
fn into_owned(self) -> <Self::Backend as ArrayBackend>::Owned {
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
with_cuda_internal_keys(|key| {
|
||||
let streams = &key.streams;
|
||||
GpuOwned(self.0.iter().map(|elem| elem.duplicate(streams)).collect())
|
||||
})
|
||||
}
|
||||
@@ -492,7 +492,8 @@ where
|
||||
fn try_encrypt(clears: &'a [Clear], key: &ClientKey) -> Result<Self, Self::Error> {
|
||||
let num_blocks = Id::num_blocks(key.message_modulus());
|
||||
Ok(Self::new(
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
with_cuda_internal_keys(|cuda_key| {
|
||||
let streams = &cuda_key.streams;
|
||||
clears
|
||||
.iter()
|
||||
.copied()
|
||||
@@ -527,7 +528,8 @@ where
|
||||
));
|
||||
}
|
||||
let num_blocks = Id::num_blocks(key.message_modulus());
|
||||
let elems = with_thread_local_cuda_streams(|streams| {
|
||||
let elems = with_cuda_internal_keys(|cuda_key| {
|
||||
let streams = &cuda_key.streams;
|
||||
clears
|
||||
.iter()
|
||||
.copied()
|
||||
@@ -570,7 +572,8 @@ where
|
||||
Clear: RecomposableFrom<u64> + UnsignedNumeric,
|
||||
{
|
||||
fn decrypt(&self, key: &ClientKey) -> Vec<Clear> {
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
with_cuda_internal_keys(|cuda_key| {
|
||||
let streams = &cuda_key.streams;
|
||||
self.as_tensor_slice()
|
||||
.iter()
|
||||
.map(|ct: &CudaUnsignedRadixCiphertext| {
|
||||
@@ -591,7 +594,8 @@ where
|
||||
fn try_encrypt(clears: &'a [Clear], key: &ClientKey) -> Result<Self, Self::Error> {
|
||||
let num_blocks = Id::num_blocks(key.message_modulus());
|
||||
Ok(Self::new(
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
with_cuda_internal_keys(|cuda_key| {
|
||||
let streams = &cuda_key.streams;
|
||||
clears
|
||||
.iter()
|
||||
.copied()
|
||||
@@ -634,7 +638,8 @@ where
|
||||
Clear: RecomposableSignedInteger,
|
||||
{
|
||||
fn decrypt(&self, key: &ClientKey) -> Vec<Clear> {
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
with_cuda_internal_keys(|cuda_key| {
|
||||
let streams = &cuda_key.streams;
|
||||
self.elems
|
||||
.0
|
||||
.iter()
|
||||
|
||||
@@ -13,8 +13,6 @@ pub(in crate::high_level_api) mod traits;
|
||||
use crate::array::traits::TensorSlice;
|
||||
use crate::high_level_api::array::traits::HasClear;
|
||||
use crate::high_level_api::global_state;
|
||||
#[cfg(feature = "gpu")]
|
||||
use crate::high_level_api::global_state::with_thread_local_cuda_streams;
|
||||
use crate::high_level_api::integers::{FheIntId, FheUintId};
|
||||
use crate::high_level_api::keys::InternalServerKey;
|
||||
use crate::{FheBool, FheId, FheInt, FheUint, Tag};
|
||||
@@ -369,7 +367,8 @@ pub fn fhe_uint_array_eq<Id: FheUintId>(lhs: &[FheUint<Id>], rhs: &[FheUint<Id>]
|
||||
FheBool::new(result, cpu_key.tag.clone())
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(gpu_key) => with_thread_local_cuda_streams(|streams| {
|
||||
InternalServerKey::Cuda(gpu_key) => {
|
||||
let streams = &gpu_key.streams;
|
||||
let tmp_lhs = lhs
|
||||
.iter()
|
||||
.map(|fhe_uint| fhe_uint.clone().ciphertext.into_gpu(streams))
|
||||
@@ -381,7 +380,7 @@ pub fn fhe_uint_array_eq<Id: FheUintId>(lhs: &[FheUint<Id>], rhs: &[FheUint<Id>]
|
||||
|
||||
let result = gpu_key.key.key.all_eq_slices(&tmp_lhs, &tmp_rhs, streams);
|
||||
FheBool::new(result, gpu_key.tag.clone())
|
||||
}),
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
panic!("Hpu does not support Array yet.")
|
||||
@@ -410,7 +409,8 @@ pub fn fhe_uint_array_contains_sub_slice<Id: FheUintId>(
|
||||
FheBool::new(result, cpu_key.tag.clone())
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(gpu_key) => with_thread_local_cuda_streams(|streams| {
|
||||
InternalServerKey::Cuda(gpu_key) => {
|
||||
let streams = &gpu_key.streams;
|
||||
let tmp_lhs = lhs
|
||||
.iter()
|
||||
.map(|fhe_uint| fhe_uint.clone().ciphertext.into_gpu(streams))
|
||||
@@ -425,7 +425,7 @@ pub fn fhe_uint_array_contains_sub_slice<Id: FheUintId>(
|
||||
.key
|
||||
.contains_sub_slice(&tmp_lhs, &tmp_pattern, streams);
|
||||
FheBool::new(result, gpu_key.tag.clone())
|
||||
}),
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
panic!("Hpu does not support Array yet.")
|
||||
|
||||
@@ -3,6 +3,9 @@ use tfhe_versionable::{Upgrade, Version, VersionsDispatch};
|
||||
|
||||
use crate::{CompactCiphertextList, Tag};
|
||||
|
||||
#[cfg(feature = "zk-pok")]
|
||||
use crate::ProvenCompactCiphertextList;
|
||||
|
||||
#[derive(Version)]
|
||||
pub struct CompactCiphertextListV0(crate::integer::ciphertext::CompactCiphertextList);
|
||||
|
||||
@@ -17,9 +20,6 @@ impl Upgrade<CompactCiphertextList> for CompactCiphertextListV0 {
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "zk-pok")]
|
||||
use crate::ProvenCompactCiphertextList;
|
||||
|
||||
#[derive(VersionsDispatch)]
|
||||
pub enum CompactCiphertextListVersions {
|
||||
V0(CompactCiphertextListV0),
|
||||
|
||||
@@ -3,8 +3,6 @@ use crate::backward_compatibility::booleans::FheBoolVersions;
|
||||
use crate::conformance::ParameterSetConformant;
|
||||
use crate::core_crypto::prelude::{SignedNumeric, UnsignedNumeric};
|
||||
use crate::high_level_api::global_state;
|
||||
#[cfg(feature = "gpu")]
|
||||
use crate::high_level_api::global_state::with_thread_local_cuda_streams;
|
||||
use crate::high_level_api::integers::{FheInt, FheIntId, FheUint, FheUintId};
|
||||
use crate::high_level_api::keys::InternalServerKey;
|
||||
use crate::high_level_api::traits::{FheEq, IfThenElse, ScalarIfThenElse, Tagged};
|
||||
@@ -408,7 +406,8 @@ impl ScalarIfThenElse<&Self, &Self> for FheBool {
|
||||
(InnerBoolean::Cpu(new_ct), key.tag.clone())
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
let inner = cuda_key.key.key.if_then_else(
|
||||
&CudaBooleanBlock(self.ciphertext.on_gpu(streams).duplicate(streams)),
|
||||
&*ct_then.ciphertext.on_gpu(streams),
|
||||
@@ -417,7 +416,7 @@ impl ScalarIfThenElse<&Self, &Self> for FheBool {
|
||||
);
|
||||
let boolean_inner = CudaBooleanBlock(inner);
|
||||
(InnerBoolean::Cuda(boolean_inner), cuda_key.tag.clone())
|
||||
}),
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
panic!("Hpu does not support if_then_else with clear input")
|
||||
@@ -449,7 +448,8 @@ where
|
||||
FheUint::new(inner, cpu_sks.tag.clone())
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
let inner = cuda_key.key.key.if_then_else(
|
||||
&CudaBooleanBlock(self.ciphertext.on_gpu(streams).duplicate(streams)),
|
||||
&*ct_then.ciphertext.on_gpu(streams),
|
||||
@@ -458,7 +458,7 @@ where
|
||||
);
|
||||
|
||||
FheUint::new(inner, cuda_key.tag.clone())
|
||||
}),
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(device) => {
|
||||
let hpu_then = ct_then.ciphertext.on_hpu(device);
|
||||
@@ -506,7 +506,8 @@ impl<Id: FheIntId> IfThenElse<FheInt<Id>> for FheBool {
|
||||
FheInt::new(new_ct, key.tag.clone())
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
let inner = cuda_key.key.key.if_then_else(
|
||||
&CudaBooleanBlock(self.ciphertext.on_gpu(streams).duplicate(streams)),
|
||||
&*ct_then.ciphertext.on_gpu(streams),
|
||||
@@ -515,7 +516,7 @@ impl<Id: FheIntId> IfThenElse<FheInt<Id>> for FheBool {
|
||||
);
|
||||
|
||||
FheInt::new(inner, cuda_key.tag.clone())
|
||||
}),
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
panic!("Hpu does not support signed integers")
|
||||
@@ -537,7 +538,8 @@ impl IfThenElse<Self> for FheBool {
|
||||
(InnerBoolean::Cpu(new_ct), key.tag.clone())
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
let inner = cuda_key.key.key.if_then_else(
|
||||
&CudaBooleanBlock(self.ciphertext.on_gpu(streams).duplicate(streams)),
|
||||
&*ct_then.ciphertext.on_gpu(streams),
|
||||
@@ -546,7 +548,7 @@ impl IfThenElse<Self> for FheBool {
|
||||
);
|
||||
let boolean_inner = CudaBooleanBlock(inner);
|
||||
(InnerBoolean::Cuda(boolean_inner), cuda_key.tag.clone())
|
||||
}),
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
panic!("Hpu does not support bool if then else")
|
||||
@@ -600,7 +602,8 @@ where
|
||||
Self::new(ciphertext, key.tag.clone())
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
let inner = cuda_key.key.key.eq(
|
||||
&*self.ciphertext.on_gpu(streams),
|
||||
&other.borrow().ciphertext.on_gpu(streams),
|
||||
@@ -608,7 +611,7 @@ where
|
||||
);
|
||||
let ciphertext = InnerBoolean::Cuda(inner);
|
||||
Self::new(ciphertext, cuda_key.tag.clone())
|
||||
}),
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
panic!("Hpu does not support FheBool::eq")
|
||||
@@ -646,7 +649,8 @@ where
|
||||
Self::new(ciphertext, key.tag.clone())
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
let inner = cuda_key.key.key.ne(
|
||||
&*self.ciphertext.on_gpu(streams),
|
||||
&other.borrow().ciphertext.on_gpu(streams),
|
||||
@@ -654,7 +658,7 @@ where
|
||||
);
|
||||
let ciphertext = InnerBoolean::Cuda(inner);
|
||||
Self::new(ciphertext, cuda_key.tag.clone())
|
||||
}),
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
panic!("Hpu does not support FheBool::ne")
|
||||
@@ -695,14 +699,15 @@ impl FheEq<bool> for FheBool {
|
||||
)
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
let inner = cuda_key.key.key.scalar_eq(
|
||||
&*self.ciphertext.on_gpu(streams),
|
||||
u8::from(other),
|
||||
streams,
|
||||
);
|
||||
(InnerBoolean::Cuda(inner), cuda_key.tag.clone())
|
||||
}),
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
panic!("Hpu does not support FheBool::eq with a bool")
|
||||
@@ -742,14 +747,15 @@ impl FheEq<bool> for FheBool {
|
||||
)
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
let inner = cuda_key.key.key.scalar_ne(
|
||||
&*self.ciphertext.on_gpu(streams),
|
||||
u8::from(other),
|
||||
streams,
|
||||
);
|
||||
(InnerBoolean::Cuda(inner), cuda_key.tag.clone())
|
||||
}),
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
panic!("Hpu does not support FheBool::ne with a bool")
|
||||
@@ -820,7 +826,8 @@ where
|
||||
(InnerBoolean::Cpu(inner_ct), key.tag.clone())
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
let inner_ct = cuda_key.key.key.bitand(
|
||||
&*self.ciphertext.on_gpu(streams),
|
||||
&rhs.borrow().ciphertext.on_gpu(streams),
|
||||
@@ -833,7 +840,7 @@ where
|
||||
)),
|
||||
cuda_key.tag.clone(),
|
||||
)
|
||||
}),
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
panic!("Hpu does not support bitand (&)")
|
||||
@@ -909,7 +916,8 @@ where
|
||||
)
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
let inner_ct = cuda_key.key.key.bitor(
|
||||
&*self.ciphertext.on_gpu(streams),
|
||||
&rhs.borrow().ciphertext.on_gpu(streams),
|
||||
@@ -921,7 +929,7 @@ where
|
||||
)),
|
||||
cuda_key.tag.clone(),
|
||||
)
|
||||
}),
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
panic!("Hpu does not support bitor (|)")
|
||||
@@ -997,7 +1005,8 @@ where
|
||||
)
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
let inner_ct = cuda_key.key.key.bitxor(
|
||||
&*self.ciphertext.on_gpu(streams),
|
||||
&rhs.borrow().ciphertext.on_gpu(streams),
|
||||
@@ -1009,7 +1018,7 @@ where
|
||||
)),
|
||||
cuda_key.tag.clone(),
|
||||
)
|
||||
}),
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
panic!("Hpu does not support bitxor (^)")
|
||||
@@ -1077,7 +1086,8 @@ impl BitAnd<bool> for &FheBool {
|
||||
)
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
let inner_ct = cuda_key.key.key.scalar_bitand(
|
||||
&*self.ciphertext.on_gpu(streams),
|
||||
u8::from(rhs),
|
||||
@@ -1089,7 +1099,7 @@ impl BitAnd<bool> for &FheBool {
|
||||
)),
|
||||
cuda_key.tag.clone(),
|
||||
)
|
||||
}),
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
panic!("hpu does not bitand (&) with a bool")
|
||||
@@ -1157,7 +1167,8 @@ impl BitOr<bool> for &FheBool {
|
||||
)
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
let inner_ct = cuda_key.key.key.scalar_bitor(
|
||||
&*self.ciphertext.on_gpu(streams),
|
||||
u8::from(rhs),
|
||||
@@ -1169,7 +1180,7 @@ impl BitOr<bool> for &FheBool {
|
||||
)),
|
||||
cuda_key.tag.clone(),
|
||||
)
|
||||
}),
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
panic!("hpu does not bitor (|) with a bool")
|
||||
@@ -1237,7 +1248,8 @@ impl BitXor<bool> for &FheBool {
|
||||
)
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
let inner_ct = cuda_key.key.key.scalar_bitxor(
|
||||
&*self.ciphertext.on_gpu(streams),
|
||||
u8::from(rhs),
|
||||
@@ -1249,7 +1261,7 @@ impl BitXor<bool> for &FheBool {
|
||||
)),
|
||||
cuda_key.tag.clone(),
|
||||
)
|
||||
}),
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
panic!("hpu does not bitxor (^) with a bool")
|
||||
@@ -1445,13 +1457,14 @@ where
|
||||
);
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
cuda_key.key.key.bitand_assign(
|
||||
self.ciphertext.as_gpu_mut(streams),
|
||||
&*rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
);
|
||||
}),
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
panic!("Hpu does not support bitand assign (&=)")
|
||||
@@ -1492,13 +1505,14 @@ where
|
||||
);
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
cuda_key.key.key.bitor_assign(
|
||||
self.ciphertext.as_gpu_mut(streams),
|
||||
&rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
);
|
||||
}),
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
panic!("Hpu does not support bitor assign (|=)")
|
||||
@@ -1539,13 +1553,14 @@ where
|
||||
);
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
cuda_key.key.key.bitxor_assign(
|
||||
self.ciphertext.as_gpu_mut(streams),
|
||||
&rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
);
|
||||
}),
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
panic!("Hpu does not support bitxor assign (^=)")
|
||||
@@ -1580,13 +1595,14 @@ impl BitAndAssign<bool> for FheBool {
|
||||
.scalar_bitand_assign(&mut self.ciphertext.as_cpu_mut().0, u8::from(rhs));
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
cuda_key.key.key.scalar_bitand_assign(
|
||||
self.ciphertext.as_gpu_mut(streams),
|
||||
u8::from(rhs),
|
||||
streams,
|
||||
);
|
||||
}),
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
panic!("Hpu does not support bitand assign (&=) with a bool")
|
||||
@@ -1621,13 +1637,14 @@ impl BitOrAssign<bool> for FheBool {
|
||||
.scalar_bitor_assign(&mut self.ciphertext.as_cpu_mut().0, u8::from(rhs));
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
cuda_key.key.key.scalar_bitor_assign(
|
||||
self.ciphertext.as_gpu_mut(streams),
|
||||
u8::from(rhs),
|
||||
streams,
|
||||
);
|
||||
}),
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
panic!("Hpu does not support bitor assign (|=) with a bool")
|
||||
@@ -1662,13 +1679,14 @@ impl BitXorAssign<bool> for FheBool {
|
||||
.scalar_bitxor_assign(&mut self.ciphertext.as_cpu_mut().0, u8::from(rhs));
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
cuda_key.key.key.scalar_bitxor_assign(
|
||||
self.ciphertext.as_gpu_mut(streams),
|
||||
u8::from(rhs),
|
||||
streams,
|
||||
);
|
||||
}),
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
panic!("Hpu does not support bitor assign (^=) with a bool")
|
||||
@@ -1729,7 +1747,8 @@ impl std::ops::Not for &FheBool {
|
||||
(InnerBoolean::Cpu(inner), key.tag.clone())
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
let inner =
|
||||
cuda_key
|
||||
.key
|
||||
@@ -1741,7 +1760,7 @@ impl std::ops::Not for &FheBool {
|
||||
)),
|
||||
cuda_key.tag.clone(),
|
||||
)
|
||||
}),
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
panic!("Hpu does not support bitnot (!)")
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
use super::base::FheBool;
|
||||
use crate::high_level_api::booleans::inner::InnerBoolean;
|
||||
use crate::high_level_api::global_state;
|
||||
#[cfg(feature = "gpu")]
|
||||
use crate::high_level_api::global_state::with_thread_local_cuda_streams;
|
||||
use crate::high_level_api::keys::InternalServerKey;
|
||||
#[cfg(feature = "gpu")]
|
||||
use crate::integer::gpu::ciphertext::boolean_value::CudaBooleanBlock;
|
||||
@@ -90,7 +88,8 @@ impl FheTryTrivialEncrypt<bool> for FheBool {
|
||||
(ct, key.tag.clone())
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
let inner: CudaUnsignedRadixCiphertext =
|
||||
cuda_key
|
||||
.key
|
||||
@@ -100,7 +99,7 @@ impl FheTryTrivialEncrypt<bool> for FheBool {
|
||||
inner.into_inner(),
|
||||
));
|
||||
(ct, cuda_key.tag.clone())
|
||||
}),
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
panic!("Hpu does not support trivial encryption")
|
||||
|
||||
@@ -4,9 +4,9 @@ use crate::core_crypto::gpu::CudaStreams;
|
||||
use crate::high_level_api::details::MaybeCloned;
|
||||
use crate::high_level_api::global_state;
|
||||
#[cfg(feature = "gpu")]
|
||||
use crate::high_level_api::global_state::{
|
||||
with_thread_local_cuda_streams, with_thread_local_cuda_streams_for_gpu_indexes,
|
||||
};
|
||||
use crate::high_level_api::global_state::with_cuda_internal_keys;
|
||||
#[cfg(feature = "gpu")]
|
||||
use crate::high_level_api::global_state::with_thread_local_cuda_streams_for_gpu_indexes;
|
||||
use crate::integer::BooleanBlock;
|
||||
use crate::Device;
|
||||
use serde::{Deserializer, Serializer};
|
||||
@@ -36,7 +36,9 @@ impl Clone for InnerBoolean {
|
||||
Self::Cpu(inner) => Self::Cpu(inner.clone()),
|
||||
#[cfg(feature = "gpu")]
|
||||
Self::Cuda(inner) => {
|
||||
with_thread_local_cuda_streams(|streams| Self::Cuda(inner.duplicate(streams)))
|
||||
with_thread_local_cuda_streams_for_gpu_indexes(inner.gpu_indexes(), |streams| {
|
||||
Self::Cuda(inner.duplicate(streams))
|
||||
})
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
Self::Hpu(inner) => Self::Hpu(inner.clone()),
|
||||
@@ -251,7 +253,8 @@ impl InnerBoolean {
|
||||
#[cfg(feature = "gpu")]
|
||||
// We may not be on the correct Cuda device
|
||||
if let Self::Cuda(cuda_ct) = self {
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
with_cuda_internal_keys(|key| {
|
||||
let streams = &key.streams;
|
||||
if cuda_ct.gpu_indexes() != streams.gpu_indexes() {
|
||||
*cuda_ct = cuda_ct.duplicate(streams);
|
||||
}
|
||||
@@ -273,7 +276,8 @@ impl InnerBoolean {
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
Device::CudaGpu => {
|
||||
let new_inner = with_thread_local_cuda_streams(|streams| {
|
||||
let new_inner = with_cuda_internal_keys(|key| {
|
||||
let streams = &key.streams;
|
||||
crate::integer::gpu::ciphertext::boolean_value::CudaBooleanBlock::from_boolean_block(
|
||||
&cpu_ct, streams,
|
||||
)
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
use super::{FheBool, InnerBoolean};
|
||||
use crate::high_level_api::global_state;
|
||||
#[cfg(feature = "gpu")]
|
||||
use crate::high_level_api::global_state::with_thread_local_cuda_streams;
|
||||
use crate::high_level_api::keys::InternalServerKey;
|
||||
#[cfg(feature = "gpu")]
|
||||
use crate::integer::gpu::ciphertext::boolean_value::CudaBooleanBlock;
|
||||
@@ -41,7 +39,8 @@ impl FheBool {
|
||||
)
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
let d_ct: CudaUnsignedRadixCiphertext = cuda_key
|
||||
.key
|
||||
.key
|
||||
@@ -52,7 +51,7 @@ impl FheBool {
|
||||
)),
|
||||
cuda_key.tag.clone(),
|
||||
)
|
||||
}),
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
panic!("Hpu does not support random bool generation")
|
||||
|
||||
@@ -1,8 +1,4 @@
|
||||
use tfhe_versionable::Versionize;
|
||||
|
||||
use crate::backward_compatibility::compact_list::CompactCiphertextListVersions;
|
||||
#[cfg(feature = "zk-pok")]
|
||||
use crate::backward_compatibility::compact_list::ProvenCompactCiphertextListVersions;
|
||||
use crate::conformance::ParameterSetConformant;
|
||||
use crate::core_crypto::commons::math::random::{Deserialize, Serialize};
|
||||
use crate::core_crypto::prelude::Numeric;
|
||||
@@ -10,7 +6,7 @@ use crate::high_level_api::global_state;
|
||||
use crate::high_level_api::keys::InternalServerKey;
|
||||
use crate::high_level_api::traits::Tagged;
|
||||
use crate::integer::block_decomposition::DecomposableInto;
|
||||
use crate::integer::ciphertext::{Compactable, DataKind, Expandable};
|
||||
use crate::integer::ciphertext::{Compactable, DataKind};
|
||||
use crate::integer::encryption::KnowsMessageModulus;
|
||||
use crate::integer::parameters::{
|
||||
CompactCiphertextListConformanceParams, IntegerCompactCiphertextListExpansionMode,
|
||||
@@ -18,9 +14,14 @@ use crate::integer::parameters::{
|
||||
use crate::named::Named;
|
||||
use crate::prelude::CiphertextList;
|
||||
use crate::shortint::MessageModulus;
|
||||
use crate::HlExpandable;
|
||||
use tfhe_versionable::Versionize;
|
||||
#[cfg(feature = "zk-pok")]
|
||||
pub use zk::ProvenCompactCiphertextList;
|
||||
|
||||
#[cfg(feature = "gpu")]
|
||||
use crate::high_level_api::global_state::with_cuda_internal_keys;
|
||||
|
||||
#[cfg(feature = "zk-pok")]
|
||||
use crate::zk::{CompactPkeCrs, ZkComputeLoad};
|
||||
use crate::{CompactPublicKey, Tag};
|
||||
@@ -173,7 +174,7 @@ impl CompactCiphertextList {
|
||||
self.inner
|
||||
.expand(sks.integer_compact_ciphertext_list_expansion_mode())
|
||||
.map(|inner| CompactCiphertextListExpander {
|
||||
inner,
|
||||
inner: InnerCompactCiphertextListExpander::Cpu(inner),
|
||||
tag: self.tag.clone(),
|
||||
})
|
||||
}
|
||||
@@ -183,9 +184,11 @@ impl CompactCiphertextList {
|
||||
if !self.inner.is_packed() && !self.inner.needs_casting() {
|
||||
// No ServerKey required, short-circuit to avoid the global state call
|
||||
return Ok(CompactCiphertextListExpander {
|
||||
inner: self
|
||||
.inner
|
||||
.expand(IntegerCompactCiphertextListExpansionMode::NoCastingAndNoUnpacking)?,
|
||||
inner: InnerCompactCiphertextListExpander::Cpu(
|
||||
self.inner.expand(
|
||||
IntegerCompactCiphertextListExpansionMode::NoCastingAndNoUnpacking,
|
||||
)?,
|
||||
),
|
||||
tag: self.tag.clone(),
|
||||
});
|
||||
}
|
||||
@@ -196,7 +199,7 @@ impl CompactCiphertextList {
|
||||
.inner
|
||||
.expand(cpu_key.integer_compact_ciphertext_list_expansion_mode())
|
||||
.map(|inner| CompactCiphertextListExpander {
|
||||
inner,
|
||||
inner: InnerCompactCiphertextListExpander::Cpu(inner),
|
||||
tag: self.tag.clone(),
|
||||
}),
|
||||
#[cfg(any(feature = "gpu", feature = "hpu"))]
|
||||
@@ -228,17 +231,144 @@ impl ParameterSetConformant for CompactCiphertextList {
|
||||
#[cfg(feature = "zk-pok")]
|
||||
mod zk {
|
||||
use super::*;
|
||||
use crate::backward_compatibility::compact_list::ProvenCompactCiphertextListVersions;
|
||||
use crate::conformance::ParameterSetConformant;
|
||||
use crate::high_level_api::global_state::device_of_internal_keys;
|
||||
use crate::integer::ciphertext::IntegerProvenCompactCiphertextListConformanceParams;
|
||||
use crate::zk::CompactPkeCrs;
|
||||
#[cfg(feature = "gpu")]
|
||||
use crate::integer::gpu::key_switching_key::CudaKeySwitchingKey;
|
||||
#[cfg(feature = "gpu")]
|
||||
use crate::integer::gpu::zk::CudaProvenCompactCiphertextList;
|
||||
use serde::Serializer;
|
||||
|
||||
pub enum InnerProvenCompactCiphertextList {
|
||||
Cpu(crate::integer::ciphertext::ProvenCompactCiphertextList),
|
||||
#[cfg(feature = "gpu")]
|
||||
Cuda(crate::integer::gpu::zk::CudaProvenCompactCiphertextList),
|
||||
}
|
||||
|
||||
impl Clone for InnerProvenCompactCiphertextList {
|
||||
fn clone(&self) -> Self {
|
||||
match self {
|
||||
Self::Cpu(inner) => Self::Cpu(inner.clone()),
|
||||
#[cfg(feature = "gpu")]
|
||||
Self::Cuda(inner) => with_cuda_internal_keys(|keys| {
|
||||
let streams = &keys.streams;
|
||||
Self::Cuda(inner.duplicate(streams))
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Serialize, Deserialize, Versionize)]
|
||||
#[versionize(ProvenCompactCiphertextListVersions)]
|
||||
pub struct ProvenCompactCiphertextList {
|
||||
pub(crate) inner: crate::integer::ciphertext::ProvenCompactCiphertextList,
|
||||
pub(crate) inner: InnerProvenCompactCiphertextList,
|
||||
pub(crate) tag: Tag,
|
||||
}
|
||||
|
||||
impl InnerProvenCompactCiphertextList {
|
||||
pub(crate) fn on_cpu(&self) -> &crate::integer::ciphertext::ProvenCompactCiphertextList {
|
||||
match self {
|
||||
Self::Cpu(inner) => inner,
|
||||
#[cfg(feature = "gpu")]
|
||||
Self::Cuda(inner) => &inner.h_proved_lists,
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::unnecessary_wraps)] // Method can return an error if hpu is enabled
|
||||
fn move_to_device(&mut self, device: crate::Device) -> Result<(), crate::Error> {
|
||||
let new_value = match (&self, device) {
|
||||
(Self::Cpu(_), crate::Device::Cpu) => None,
|
||||
#[cfg(feature = "gpu")]
|
||||
(Self::Cuda(cuda_ct), crate::Device::CudaGpu) => with_cuda_internal_keys(|keys| {
|
||||
let streams = &keys.streams;
|
||||
if cuda_ct.gpu_indexes() == streams.gpu_indexes() {
|
||||
None
|
||||
} else {
|
||||
Some(Self::Cuda(cuda_ct.duplicate(streams)))
|
||||
}
|
||||
}),
|
||||
#[cfg(feature = "gpu")]
|
||||
(Self::Cuda(cuda_ct), crate::Device::Cpu) => {
|
||||
let cpu_ct = cuda_ct.h_proved_lists.clone();
|
||||
Some(Self::Cpu(cpu_ct))
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
(Self::Cpu(cpu_ct), crate::Device::CudaGpu) => {
|
||||
let cuda_ct = with_cuda_internal_keys(|keys| {
|
||||
let streams = &keys.streams;
|
||||
CudaProvenCompactCiphertextList::from_proven_compact_ciphertext_list(
|
||||
cpu_ct, streams,
|
||||
)
|
||||
});
|
||||
Some(Self::Cuda(cuda_ct))
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
(_, crate::Device::Hpu) => {
|
||||
return Err(crate::error!(
|
||||
"Hpu does not support ProvenCompactCiphertextList"
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
if let Some(v) = new_value {
|
||||
*self = v;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl serde::Serialize for InnerProvenCompactCiphertextList {
|
||||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: Serializer,
|
||||
{
|
||||
self.on_cpu().serialize(serializer)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'de> serde::Deserialize<'de> for InnerProvenCompactCiphertextList {
|
||||
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
||||
where
|
||||
D: serde::Deserializer<'de>,
|
||||
{
|
||||
let mut new =
|
||||
crate::integer::ciphertext::ProvenCompactCiphertextList::deserialize(deserializer)
|
||||
.map(Self::Cpu)?;
|
||||
|
||||
if let Some(device) = device_of_internal_keys() {
|
||||
new.move_to_device(device)
|
||||
.map_err(serde::de::Error::custom)?;
|
||||
}
|
||||
|
||||
Ok(new)
|
||||
}
|
||||
}
|
||||
use tfhe_versionable::{Unversionize, UnversionizeError, VersionizeOwned};
|
||||
impl Versionize for InnerProvenCompactCiphertextList {
|
||||
type Versioned<'vers> =
|
||||
<crate::integer::ciphertext::ProvenCompactCiphertextList as VersionizeOwned>::VersionedOwned;
|
||||
fn versionize(&self) -> Self::Versioned<'_> {
|
||||
self.on_cpu().clone().versionize_owned()
|
||||
}
|
||||
}
|
||||
impl VersionizeOwned for InnerProvenCompactCiphertextList {
|
||||
type VersionedOwned =
|
||||
<crate::integer::ciphertext::ProvenCompactCiphertextList as VersionizeOwned>::VersionedOwned;
|
||||
fn versionize_owned(self) -> Self::VersionedOwned {
|
||||
self.on_cpu().clone().versionize_owned()
|
||||
}
|
||||
}
|
||||
|
||||
impl Unversionize for InnerProvenCompactCiphertextList {
|
||||
fn unversionize(versioned: Self::VersionedOwned) -> Result<Self, UnversionizeError> {
|
||||
Ok(Self::Cpu(
|
||||
crate::integer::ciphertext::ProvenCompactCiphertextList::unversionize(versioned)?,
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
impl Tagged for ProvenCompactCiphertextList {
|
||||
fn tag(&self) -> &Tag {
|
||||
&self.tag
|
||||
@@ -258,7 +388,7 @@ mod zk {
|
||||
}
|
||||
|
||||
pub fn len(&self) -> usize {
|
||||
self.inner.len()
|
||||
self.inner.on_cpu().len()
|
||||
}
|
||||
|
||||
pub fn is_empty(&self) -> bool {
|
||||
@@ -266,8 +396,9 @@ mod zk {
|
||||
}
|
||||
|
||||
pub fn get_kind_of(&self, index: usize) -> Option<crate::FheTypes> {
|
||||
self.inner.get_kind_of(index).and_then(|data_kind| {
|
||||
crate::FheTypes::from_data_kind(data_kind, self.inner.ct_list.message_modulus())
|
||||
let inner_cpu = self.inner.on_cpu();
|
||||
inner_cpu.get_kind_of(index).and_then(|data_kind| {
|
||||
crate::FheTypes::from_data_kind(data_kind, inner_cpu.ct_list.message_modulus())
|
||||
})
|
||||
}
|
||||
|
||||
@@ -277,7 +408,7 @@ mod zk {
|
||||
pk: &CompactPublicKey,
|
||||
metadata: &[u8],
|
||||
) -> crate::zk::ZkVerificationOutcome {
|
||||
self.inner.verify(crs, &pk.key.key, metadata)
|
||||
self.inner.on_cpu().verify(crs, &pk.key.key, metadata)
|
||||
}
|
||||
|
||||
pub fn verify_and_expand(
|
||||
@@ -286,36 +417,112 @@ mod zk {
|
||||
pk: &CompactPublicKey,
|
||||
metadata: &[u8],
|
||||
) -> crate::Result<CompactCiphertextListExpander> {
|
||||
// For WASM
|
||||
if !self.inner.is_packed() && !self.inner.needs_casting() {
|
||||
// No ServerKey required, short circuit to avoid the global state call
|
||||
return Ok(CompactCiphertextListExpander {
|
||||
inner: self.inner.verify_and_expand(
|
||||
#[allow(irrefutable_let_patterns)]
|
||||
if let InnerProvenCompactCiphertextList::Cpu(inner) = &self.inner {
|
||||
// For WASM
|
||||
if !inner.is_packed() && !inner.needs_casting() {
|
||||
let expander = inner.verify_and_expand(
|
||||
crs,
|
||||
&pk.key.key,
|
||||
metadata,
|
||||
IntegerCompactCiphertextListExpansionMode::NoCastingAndNoUnpacking,
|
||||
)?,
|
||||
tag: self.tag.clone(),
|
||||
});
|
||||
)?;
|
||||
// No ServerKey required, short circuit to avoid the global state call
|
||||
return Ok(CompactCiphertextListExpander {
|
||||
inner: InnerCompactCiphertextListExpander::Cpu(expander),
|
||||
tag: self.tag.clone(),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
global_state::try_with_internal_keys(|maybe_keys| match maybe_keys {
|
||||
None => Err(crate::high_level_api::errors::UninitializedServerKey.into()),
|
||||
Some(InternalServerKey::Cpu(cpu_key)) => self
|
||||
.inner
|
||||
.verify_and_expand(
|
||||
crs,
|
||||
&pk.key.key,
|
||||
metadata,
|
||||
cpu_key.integer_compact_ciphertext_list_expansion_mode(),
|
||||
)
|
||||
.map(|expander| CompactCiphertextListExpander {
|
||||
inner: expander,
|
||||
tag: self.tag.clone(),
|
||||
}),
|
||||
#[cfg(any(feature = "gpu", feature = "hpu"))]
|
||||
Some(_) => Err(crate::Error::new("Expected a CPU server key".to_string())),
|
||||
Some(InternalServerKey::Cpu(cpu_key)) => match &self.inner {
|
||||
InnerProvenCompactCiphertextList::Cpu(inner) => inner
|
||||
.verify_and_expand(
|
||||
crs,
|
||||
&pk.key.key,
|
||||
metadata,
|
||||
cpu_key.integer_compact_ciphertext_list_expansion_mode(),
|
||||
)
|
||||
.map(|expander| CompactCiphertextListExpander {
|
||||
inner: InnerCompactCiphertextListExpander::Cpu(expander),
|
||||
tag: self.tag.clone(),
|
||||
}),
|
||||
#[cfg(feature = "gpu")]
|
||||
InnerProvenCompactCiphertextList::Cuda(inner) => inner
|
||||
.h_proved_lists
|
||||
.verify_and_expand(
|
||||
crs,
|
||||
&pk.key.key,
|
||||
metadata,
|
||||
cpu_key.integer_compact_ciphertext_list_expansion_mode(),
|
||||
)
|
||||
.map(|expander| CompactCiphertextListExpander {
|
||||
inner: InnerCompactCiphertextListExpander::Cpu(expander),
|
||||
tag: self.tag.clone(),
|
||||
}),
|
||||
},
|
||||
#[cfg(feature = "gpu")]
|
||||
Some(InternalServerKey::Cuda(gpu_key)) => match &self.inner {
|
||||
InnerProvenCompactCiphertextList::Cuda(inner) => {
|
||||
with_cuda_internal_keys(|keys| {
|
||||
let streams = &keys.streams;
|
||||
let ksk = CudaKeySwitchingKey {
|
||||
key_switching_key_material: gpu_key
|
||||
.key
|
||||
.cpk_key_switching_key_material
|
||||
.as_ref()
|
||||
.unwrap(),
|
||||
dest_server_key: &gpu_key.key.key,
|
||||
};
|
||||
let expander = inner.verify_and_expand(
|
||||
crs,
|
||||
&pk.key.key,
|
||||
metadata,
|
||||
&ksk,
|
||||
streams,
|
||||
)?;
|
||||
|
||||
Ok(CompactCiphertextListExpander {
|
||||
inner: InnerCompactCiphertextListExpander::Cuda(expander),
|
||||
tag: self.tag.clone(),
|
||||
})
|
||||
})
|
||||
}
|
||||
InnerProvenCompactCiphertextList::Cpu(cpu_inner) => {
|
||||
with_cuda_internal_keys(|keys| {
|
||||
let streams = &keys.streams;
|
||||
let gpu_proven_ct = CudaProvenCompactCiphertextList::from_proven_compact_ciphertext_list(
|
||||
cpu_inner, streams,
|
||||
);
|
||||
let ksk = CudaKeySwitchingKey {
|
||||
key_switching_key_material: gpu_key
|
||||
.key
|
||||
.cpk_key_switching_key_material
|
||||
.as_ref()
|
||||
.unwrap(),
|
||||
dest_server_key: &gpu_key.key.key,
|
||||
};
|
||||
let expander = gpu_proven_ct.verify_and_expand(
|
||||
crs,
|
||||
&pk.key.key,
|
||||
metadata,
|
||||
&ksk,
|
||||
streams,
|
||||
)?;
|
||||
|
||||
Ok(CompactCiphertextListExpander {
|
||||
inner: InnerCompactCiphertextListExpander::Cuda(expander),
|
||||
tag: self.tag.clone(),
|
||||
})
|
||||
})
|
||||
}
|
||||
},
|
||||
#[cfg(feature = "hpu")]
|
||||
Some(InternalServerKey::Hpu(_)) => Err(crate::error!(
|
||||
"Hpu does not support ProvenCompactCiphertextList"
|
||||
)),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -324,30 +531,86 @@ mod zk {
|
||||
///
|
||||
/// If you are here you were probably looking for it: use at your own risks.
|
||||
pub fn expand_without_verification(&self) -> crate::Result<CompactCiphertextListExpander> {
|
||||
// For WASM
|
||||
if !self.inner.is_packed() && !self.inner.needs_casting() {
|
||||
// No ServerKey required, short circuit to avoid the global state call
|
||||
return Ok(CompactCiphertextListExpander {
|
||||
inner: self.inner.expand_without_verification(
|
||||
IntegerCompactCiphertextListExpansionMode::NoCastingAndNoUnpacking,
|
||||
)?,
|
||||
tag: self.tag.clone(),
|
||||
});
|
||||
#[allow(irrefutable_let_patterns)]
|
||||
if let InnerProvenCompactCiphertextList::Cpu(inner) = &self.inner {
|
||||
// For WASM
|
||||
if !inner.is_packed() && !inner.needs_casting() {
|
||||
// No ServerKey required, short circuit to avoid the global state call
|
||||
return Ok(CompactCiphertextListExpander {
|
||||
inner: InnerCompactCiphertextListExpander::Cpu(
|
||||
inner.expand_without_verification(
|
||||
IntegerCompactCiphertextListExpansionMode::NoCastingAndNoUnpacking,
|
||||
)?,
|
||||
),
|
||||
tag: self.tag.clone(),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
global_state::try_with_internal_keys(|maybe_keys| match maybe_keys {
|
||||
global_state::try_with_internal_keys(|maybe_keys| {
|
||||
match maybe_keys {
|
||||
None => Err(crate::high_level_api::errors::UninitializedServerKey.into()),
|
||||
Some(InternalServerKey::Cpu(cpu_key)) => self
|
||||
.inner
|
||||
.expand_without_verification(
|
||||
cpu_key.integer_compact_ciphertext_list_expansion_mode(),
|
||||
)
|
||||
.map(|expander| CompactCiphertextListExpander {
|
||||
inner: expander,
|
||||
tag: self.tag.clone(),
|
||||
}),
|
||||
#[cfg(any(feature = "gpu", feature = "hpu"))]
|
||||
Some(_) => Err(crate::Error::new("Expected a CPU server key".to_string())),
|
||||
Some(InternalServerKey::Cpu(cpu_key)) => match &self.inner {
|
||||
InnerProvenCompactCiphertextList::Cpu(inner) => inner
|
||||
.expand_without_verification(
|
||||
cpu_key.integer_compact_ciphertext_list_expansion_mode(),
|
||||
)
|
||||
.map(|expander| CompactCiphertextListExpander {
|
||||
inner: InnerCompactCiphertextListExpander::Cpu(expander),
|
||||
tag: self.tag.clone(),
|
||||
}),
|
||||
#[cfg(feature = "gpu")]
|
||||
InnerProvenCompactCiphertextList::Cuda(_) => {
|
||||
Err(crate::Error::new("Tried expanding a ProvenCompactCiphertextList on the GPU, but the set ServerKey is a ServerKey".to_string()))
|
||||
}
|
||||
},
|
||||
#[cfg(feature = "gpu")]
|
||||
Some(InternalServerKey::Cuda(gpu_key)) => match &self.inner {
|
||||
InnerProvenCompactCiphertextList::Cuda(inner) => {
|
||||
with_cuda_internal_keys(|keys| {
|
||||
let streams = &keys.streams;
|
||||
let ksk = CudaKeySwitchingKey {
|
||||
key_switching_key_material: gpu_key
|
||||
.key
|
||||
.cpk_key_switching_key_material
|
||||
.as_ref()
|
||||
.unwrap(),
|
||||
dest_server_key: &gpu_key.key.key,
|
||||
};
|
||||
let expander = inner.expand_without_verification(&ksk, streams)?;
|
||||
|
||||
Ok(CompactCiphertextListExpander {
|
||||
inner: InnerCompactCiphertextListExpander::Cuda(expander),
|
||||
tag: self.tag.clone(),
|
||||
})
|
||||
})
|
||||
}
|
||||
InnerProvenCompactCiphertextList::Cpu(inner) => {
|
||||
with_cuda_internal_keys(|keys| {
|
||||
let streams = &keys.streams;
|
||||
let gpu_proven_ct = CudaProvenCompactCiphertextList::from_proven_compact_ciphertext_list(
|
||||
inner, streams,
|
||||
);
|
||||
let ksk = CudaKeySwitchingKey {
|
||||
key_switching_key_material: gpu_key
|
||||
.key
|
||||
.cpk_key_switching_key_material
|
||||
.as_ref()
|
||||
.unwrap(),
|
||||
dest_server_key: &gpu_key.key.key,
|
||||
};
|
||||
let expander = gpu_proven_ct.expand_without_verification(&ksk, streams)?;
|
||||
|
||||
Ok(CompactCiphertextListExpander {
|
||||
inner: InnerCompactCiphertextListExpander::Cuda(expander),
|
||||
tag: self.tag.clone(),
|
||||
})
|
||||
})
|
||||
}
|
||||
},
|
||||
#[cfg(feature = "hpu")]
|
||||
Some(InternalServerKey::Hpu(_)) => Err(crate::error!("Hpu does not support ProvenCompactCiphertextList")),
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -356,9 +619,7 @@ mod zk {
|
||||
type ParameterSet = IntegerProvenCompactCiphertextListConformanceParams;
|
||||
|
||||
fn is_conformant(&self, parameter_set: &Self::ParameterSet) -> bool {
|
||||
let Self { inner, tag: _ } = self;
|
||||
|
||||
inner.is_conformant(parameter_set)
|
||||
self.inner.on_cpu().is_conformant(parameter_set)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -367,7 +628,7 @@ mod zk {
|
||||
use super::*;
|
||||
use crate::integer::ciphertext::IntegerProvenCompactCiphertextListConformanceParams;
|
||||
use crate::shortint::parameters::*;
|
||||
use crate::zk::CompactPkeCrs;
|
||||
|
||||
use rand::{thread_rng, Rng};
|
||||
|
||||
#[test]
|
||||
@@ -409,14 +670,24 @@ mod zk {
|
||||
}
|
||||
}
|
||||
|
||||
pub enum InnerCompactCiphertextListExpander {
|
||||
Cpu(crate::integer::ciphertext::CompactCiphertextListExpander),
|
||||
#[cfg(feature = "gpu")]
|
||||
Cuda(crate::integer::gpu::ciphertext::compact_list::CudaCompactCiphertextListExpander),
|
||||
}
|
||||
|
||||
pub struct CompactCiphertextListExpander {
|
||||
pub(in crate::high_level_api) inner: crate::integer::ciphertext::CompactCiphertextListExpander,
|
||||
pub inner: InnerCompactCiphertextListExpander,
|
||||
tag: Tag,
|
||||
}
|
||||
|
||||
impl CiphertextList for CompactCiphertextListExpander {
|
||||
fn len(&self) -> usize {
|
||||
self.inner.len()
|
||||
match &self.inner {
|
||||
InnerCompactCiphertextListExpander::Cpu(inner) => inner.len(),
|
||||
#[cfg(feature = "gpu")]
|
||||
InnerCompactCiphertextListExpander::Cuda(inner) => inner.len(),
|
||||
}
|
||||
}
|
||||
|
||||
fn is_empty(&self) -> bool {
|
||||
@@ -424,16 +695,34 @@ impl CiphertextList for CompactCiphertextListExpander {
|
||||
}
|
||||
|
||||
fn get_kind_of(&self, index: usize) -> Option<crate::FheTypes> {
|
||||
self.inner.get_kind_of(index).and_then(|data_kind| {
|
||||
crate::FheTypes::from_data_kind(data_kind, self.inner.message_modulus())
|
||||
})
|
||||
match &self.inner {
|
||||
InnerCompactCiphertextListExpander::Cpu(inner) => {
|
||||
inner.get_kind_of(index).and_then(|data_kind| {
|
||||
crate::FheTypes::from_data_kind(data_kind, inner.message_modulus())
|
||||
})
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InnerCompactCiphertextListExpander::Cuda(inner) => {
|
||||
inner.get_kind_of(index).and_then(|data_kind| {
|
||||
crate::FheTypes::from_data_kind(data_kind, inner.message_modulus(index)?)
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn get<T>(&self, index: usize) -> crate::Result<Option<T>>
|
||||
where
|
||||
T: Expandable + Tagged,
|
||||
T: HlExpandable + Tagged,
|
||||
{
|
||||
let mut expanded = self.inner.get::<T>(index);
|
||||
let mut expanded = match &self.inner {
|
||||
InnerCompactCiphertextListExpander::Cpu(inner) => inner.get::<T>(index),
|
||||
#[cfg(feature = "gpu")]
|
||||
InnerCompactCiphertextListExpander::Cuda(inner) => with_cuda_internal_keys(|keys| {
|
||||
let streams = &keys.streams;
|
||||
inner.get::<T>(index, streams)
|
||||
}),
|
||||
};
|
||||
|
||||
if let Ok(Some(inner)) = &mut expanded {
|
||||
inner.tag_mut().set_data(self.tag.data());
|
||||
}
|
||||
@@ -531,7 +820,6 @@ impl CompactCiphertextListBuilder {
|
||||
})
|
||||
.expect("Internal error, invalid parameters should not have been allowed")
|
||||
}
|
||||
|
||||
#[cfg(feature = "zk-pok")]
|
||||
pub fn build_with_proof_packed(
|
||||
&self,
|
||||
@@ -542,7 +830,10 @@ impl CompactCiphertextListBuilder {
|
||||
self.inner
|
||||
.build_with_proof_packed(crs, metadata, compute_load)
|
||||
.map(|proved_list| ProvenCompactCiphertextList {
|
||||
inner: proved_list,
|
||||
inner:
|
||||
crate::high_level_api::compact_list::zk::InnerProvenCompactCiphertextList::Cpu(
|
||||
proved_list,
|
||||
),
|
||||
tag: self.tag.clone(),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -14,7 +14,7 @@ use crate::high_level_api::booleans::InnerBoolean;
|
||||
use crate::high_level_api::errors::UninitializedServerKey;
|
||||
use crate::high_level_api::global_state::device_of_internal_keys;
|
||||
#[cfg(feature = "gpu")]
|
||||
use crate::high_level_api::global_state::with_thread_local_cuda_streams;
|
||||
use crate::high_level_api::global_state::with_cuda_internal_keys;
|
||||
use crate::high_level_api::integers::{FheIntId, FheUintId};
|
||||
use crate::integer::ciphertext::{DataKind, Expandable};
|
||||
#[cfg(feature = "gpu")]
|
||||
@@ -27,7 +27,6 @@ use crate::named::Named;
|
||||
use crate::prelude::{CiphertextList, Tagged};
|
||||
use crate::shortint::Ciphertext;
|
||||
use crate::{Device, FheBool, FheInt, FheUint, Tag};
|
||||
|
||||
impl<Id: FheUintId> HlCompressible for FheUint<Id> {
|
||||
fn compress_into(self, messages: &mut Vec<(ToBeCompressed, DataKind)>) {
|
||||
match self.ciphertext {
|
||||
@@ -142,9 +141,12 @@ impl CompressedCiphertextListBuilder {
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
ToBeCompressed::Cuda(cuda_radix) => {
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
flat_cpu_blocks.append(&mut cuda_radix.to_cpu_blocks(streams));
|
||||
});
|
||||
with_thread_local_cuda_streams_for_gpu_indexes(
|
||||
cuda_radix.d_blocks.0.d_vec.gpu_indexes.as_slice(),
|
||||
|streams| {
|
||||
flat_cpu_blocks.append(&mut cuda_radix.to_cpu_blocks(streams));
|
||||
},
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -178,17 +180,16 @@ impl CompressedCiphertextListBuilder {
|
||||
for (element, _) in &self.inner {
|
||||
match element {
|
||||
ToBeCompressed::Cpu(cpu_blocks) => {
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
cuda_radixes.push(CudaRadixCiphertext::from_cpu_blocks(
|
||||
cpu_blocks, streams,
|
||||
));
|
||||
})
|
||||
let streams = &cuda_key.streams;
|
||||
cuda_radixes
|
||||
.push(CudaRadixCiphertext::from_cpu_blocks(cpu_blocks, streams));
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
ToBeCompressed::Cuda(cuda_radix) => {
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
{
|
||||
let streams = &cuda_key.streams;
|
||||
cuda_radixes.push(cuda_radix.duplicate(streams));
|
||||
});
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -201,10 +202,11 @@ impl CompressedCiphertextListBuilder {
|
||||
crate::Error::new("Compression key not set in server key".to_owned())
|
||||
})
|
||||
.map(|compression_key| {
|
||||
let packed_list = with_thread_local_cuda_streams(|streams| {
|
||||
let packed_list = {
|
||||
let streams = &cuda_key.streams;
|
||||
compression_key
|
||||
.compress_ciphertexts_into_list(cuda_radixes.as_slice(), streams)
|
||||
});
|
||||
};
|
||||
let info = self.inner.iter().map(|(_, kind)| *kind).collect();
|
||||
|
||||
let compressed_list = CudaCompressedCiphertextList { packed_list, info };
|
||||
@@ -273,7 +275,8 @@ impl InnerCompressedCiphertextList {
|
||||
#[cfg(feature = "gpu")]
|
||||
// We may not be on the correct Cuda device
|
||||
if let Self::Cuda(cuda_ct) = self {
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
with_cuda_internal_keys(|keys| {
|
||||
let streams = &keys.streams;
|
||||
if cuda_ct.gpu_indexes() != streams.gpu_indexes() {
|
||||
*cuda_ct = cuda_ct.duplicate(streams);
|
||||
}
|
||||
@@ -295,7 +298,8 @@ impl InnerCompressedCiphertextList {
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
Device::CudaGpu => {
|
||||
let new_inner = with_thread_local_cuda_streams(|streams| {
|
||||
let new_inner = with_cuda_internal_keys(|keys| {
|
||||
let streams = &keys.streams;
|
||||
cpu_ct.to_cuda_compressed_ciphertext_list(streams)
|
||||
});
|
||||
*self = Self::Cuda(new_inner);
|
||||
@@ -353,7 +357,8 @@ impl Versionize for InnerCompressedCiphertextList {
|
||||
Self::Cpu(inner) => inner.clone().versionize_owned(),
|
||||
#[cfg(feature = "gpu")]
|
||||
Self::Cuda(inner) => {
|
||||
let cpu_data = with_thread_local_cuda_streams(|streams| {
|
||||
let cpu_data = with_cuda_internal_keys(|keys| {
|
||||
let streams = &keys.streams;
|
||||
inner.to_compressed_ciphertext_list(streams)
|
||||
});
|
||||
cpu_data.versionize_owned()
|
||||
@@ -371,7 +376,8 @@ impl VersionizeOwned for InnerCompressedCiphertextList {
|
||||
Self::Cpu(inner) => inner.versionize_owned(),
|
||||
#[cfg(feature = "gpu")]
|
||||
Self::Cuda(inner) => {
|
||||
let cpu_data = with_thread_local_cuda_streams(|streams| {
|
||||
let cpu_data = with_cuda_internal_keys(|keys| {
|
||||
let streams = &keys.streams;
|
||||
inner.to_compressed_ciphertext_list(streams)
|
||||
});
|
||||
cpu_data.versionize_owned()
|
||||
@@ -475,11 +481,12 @@ impl CiphertextList for CompressedCiphertextList {
|
||||
crate::Error::new("Compression key not set in server key".to_owned())
|
||||
})
|
||||
.and_then(|decompression_key| {
|
||||
let mut ct = with_thread_local_cuda_streams(|streams| {
|
||||
let mut ct = {
|
||||
let streams = &cuda_key.streams;
|
||||
self.inner
|
||||
.on_gpu(streams)
|
||||
.get::<T>(index, decompression_key, streams)
|
||||
});
|
||||
};
|
||||
if let Ok(Some(ct_ref)) = &mut ct {
|
||||
ct_ref.tag_mut().set_data(cuda_key.tag.data())
|
||||
}
|
||||
@@ -501,7 +508,8 @@ impl CompressedCiphertextList {
|
||||
InnerCompressedCiphertextList::Cpu(inner) => (inner, tag),
|
||||
#[cfg(feature = "gpu")]
|
||||
InnerCompressedCiphertextList::Cuda(inner) => (
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
with_cuda_internal_keys(|keys| {
|
||||
let streams = &keys.streams;
|
||||
inner.to_compressed_ciphertext_list(streams)
|
||||
}),
|
||||
tag,
|
||||
|
||||
@@ -73,14 +73,6 @@ pub fn unset_server_key() {
|
||||
|
||||
fn replace_server_key(new_one: Option<impl Into<InternalServerKey>>) -> Option<InternalServerKey> {
|
||||
let keys = new_one.map(Into::into);
|
||||
#[cfg(feature = "gpu")]
|
||||
if let Some(InternalServerKey::Cuda(cuda_key)) = &keys {
|
||||
gpu::CUDA_STREAMS.with_borrow_mut(|current_streams| {
|
||||
if current_streams.gpu_indexes() != cuda_key.gpu_indexes() {
|
||||
*current_streams = cuda_key.build_streams();
|
||||
}
|
||||
});
|
||||
}
|
||||
INTERNAL_KEYS.replace(keys)
|
||||
}
|
||||
|
||||
@@ -187,7 +179,7 @@ where
|
||||
.unwrap_display();
|
||||
let InternalServerKey::Cuda(cuda_key) = key else {
|
||||
panic!(
|
||||
"Cpu key requested but only the key for {:?} is available",
|
||||
"CUDA key requested but only the key for {:?} is available",
|
||||
key.device()
|
||||
)
|
||||
};
|
||||
@@ -196,13 +188,15 @@ where
|
||||
}
|
||||
|
||||
#[cfg(feature = "gpu")]
|
||||
pub(in crate::high_level_api) use gpu::{
|
||||
with_thread_local_cuda_streams, with_thread_local_cuda_streams_for_gpu_indexes,
|
||||
};
|
||||
pub(in crate::high_level_api) use gpu::with_thread_local_cuda_streams_for_gpu_indexes;
|
||||
|
||||
#[cfg(feature = "gpu")]
|
||||
pub use gpu::CudaGpuChoice;
|
||||
|
||||
#[derive(Clone)]
|
||||
#[cfg(feature = "gpu")]
|
||||
pub struct CustomMultiGpuIndexes(Vec<GpuIndex>);
|
||||
|
||||
#[cfg(feature = "gpu")]
|
||||
mod gpu {
|
||||
use crate::core_crypto::gpu::get_number_of_gpus;
|
||||
@@ -210,28 +204,15 @@ mod gpu {
|
||||
use super::*;
|
||||
use std::cell::LazyCell;
|
||||
|
||||
thread_local! {
|
||||
pub(crate) static CUDA_STREAMS: RefCell<CudaStreams> = RefCell::new(CudaStreams::new_multi_gpu());
|
||||
}
|
||||
|
||||
pub(in crate::high_level_api) fn with_thread_local_cuda_streams<
|
||||
R,
|
||||
F: for<'a> FnOnce(&'a CudaStreams) -> R,
|
||||
>(
|
||||
func: F,
|
||||
) -> R {
|
||||
CUDA_STREAMS.with(|cell| func(&cell.borrow()))
|
||||
}
|
||||
|
||||
struct CudaStreamPool {
|
||||
multi: LazyCell<CudaStreams>,
|
||||
custom: Option<CudaStreams>,
|
||||
single: Vec<LazyCell<CudaStreams, Box<dyn Fn() -> CudaStreams>>>,
|
||||
}
|
||||
|
||||
impl CudaStreamPool {
|
||||
fn new() -> Self {
|
||||
Self {
|
||||
multi: LazyCell::new(CudaStreams::new_multi_gpu),
|
||||
custom: None,
|
||||
single: (0..get_number_of_gpus())
|
||||
.map(|index| {
|
||||
let ctor =
|
||||
@@ -243,29 +224,6 @@ mod gpu {
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> std::ops::Index<&'a [GpuIndex]> for CudaStreamPool {
|
||||
type Output = CudaStreams;
|
||||
|
||||
fn index(&self, indexes: &'a [GpuIndex]) -> &Self::Output {
|
||||
match indexes.len() {
|
||||
0 => panic!("Internal error: Gpu indexes must not be empty"),
|
||||
1 => &self.single[indexes[0].get() as usize],
|
||||
_ => &self.multi,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::ops::Index<CudaGpuChoice> for CudaStreamPool {
|
||||
type Output = CudaStreams;
|
||||
|
||||
fn index(&self, choice: CudaGpuChoice) -> &Self::Output {
|
||||
match choice {
|
||||
CudaGpuChoice::Multi => &self.multi,
|
||||
CudaGpuChoice::Single(index) => &self.single[index.get() as usize],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(in crate::high_level_api) fn with_thread_local_cuda_streams_for_gpu_indexes<
|
||||
R,
|
||||
F: for<'a> FnOnce(&'a CudaStreams) -> R,
|
||||
@@ -276,15 +234,38 @@ mod gpu {
|
||||
thread_local! {
|
||||
static POOL: RefCell<CudaStreamPool> = RefCell::new(CudaStreamPool::new());
|
||||
}
|
||||
POOL.with_borrow(|stream_pool| {
|
||||
let stream = &stream_pool[gpu_indexes];
|
||||
func(stream)
|
||||
})
|
||||
|
||||
if gpu_indexes.len() == 1 {
|
||||
POOL.with_borrow(|pool| func(&pool.single[gpu_indexes[0].get() as usize]))
|
||||
} else {
|
||||
POOL.with_borrow_mut(|pool| match &pool.custom {
|
||||
Some(streams) if streams.gpu_indexes != gpu_indexes => {
|
||||
pool.custom = Some(CudaStreams::new_multi_gpu_with_indexes(gpu_indexes));
|
||||
}
|
||||
None => {
|
||||
pool.custom = Some(CudaStreams::new_multi_gpu_with_indexes(gpu_indexes));
|
||||
}
|
||||
_ => {}
|
||||
});
|
||||
|
||||
POOL.with_borrow(|pool| func(pool.custom.as_ref().unwrap()))
|
||||
}
|
||||
}
|
||||
#[derive(Copy, Clone)]
|
||||
|
||||
impl CustomMultiGpuIndexes {
|
||||
pub fn new(indexes: Vec<GpuIndex>) -> Self {
|
||||
Self(indexes)
|
||||
}
|
||||
pub fn gpu_indexes(&self) -> &[GpuIndex] {
|
||||
self.0.as_slice()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub enum CudaGpuChoice {
|
||||
Single(GpuIndex),
|
||||
Multi,
|
||||
Custom(CustomMultiGpuIndexes),
|
||||
}
|
||||
|
||||
impl From<GpuIndex> for CudaGpuChoice {
|
||||
@@ -293,11 +274,24 @@ mod gpu {
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Vec<GpuIndex>> for CustomMultiGpuIndexes {
|
||||
fn from(value: Vec<GpuIndex>) -> Self {
|
||||
Self(value)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<CustomMultiGpuIndexes> for CudaGpuChoice {
|
||||
fn from(values: CustomMultiGpuIndexes) -> Self {
|
||||
Self::Custom(values)
|
||||
}
|
||||
}
|
||||
|
||||
impl CudaGpuChoice {
|
||||
pub(in crate::high_level_api) fn build_streams(self) -> CudaStreams {
|
||||
match self {
|
||||
Self::Single(idx) => CudaStreams::new_single_gpu(idx),
|
||||
Self::Multi => CudaStreams::new_multi_gpu(),
|
||||
Self::Custom(idxs) => CudaStreams::new_multi_gpu_with_indexes(idxs.gpu_indexes()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
use super::{FheIntId, FheUint, FheUintId};
|
||||
use crate::high_level_api::global_state;
|
||||
#[cfg(feature = "gpu")]
|
||||
use crate::high_level_api::global_state::with_thread_local_cuda_streams;
|
||||
use crate::high_level_api::keys::InternalServerKey;
|
||||
#[cfg(feature = "gpu")]
|
||||
use crate::integer::gpu::ciphertext::{CudaSignedRadixCiphertext, CudaUnsignedRadixCiphertext};
|
||||
@@ -39,7 +37,8 @@ impl<Id: FheUintId> FheUint<Id> {
|
||||
Self::new(ct, key.tag.clone())
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
let d_ct: CudaUnsignedRadixCiphertext = cuda_key
|
||||
.key
|
||||
.key
|
||||
@@ -50,7 +49,7 @@ impl<Id: FheUintId> FheUint<Id> {
|
||||
);
|
||||
|
||||
Self::new(d_ct, cuda_key.tag.clone())
|
||||
}),
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
panic!("Hpu does not support this operation yet.")
|
||||
@@ -92,7 +91,8 @@ impl<Id: FheUintId> FheUint<Id> {
|
||||
Self::new(ct, key.tag.clone())
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
let d_ct: CudaUnsignedRadixCiphertext = cuda_key
|
||||
.key
|
||||
.key
|
||||
@@ -103,7 +103,7 @@ impl<Id: FheUintId> FheUint<Id> {
|
||||
streams,
|
||||
);
|
||||
Self::new(d_ct, cuda_key.tag.clone())
|
||||
}),
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
panic!("Hpu does not support this operation yet.")
|
||||
@@ -145,7 +145,8 @@ impl<Id: FheIntId> FheInt<Id> {
|
||||
Self::new(ct, key.tag.clone())
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
let d_ct: CudaSignedRadixCiphertext = cuda_key
|
||||
.key
|
||||
.key
|
||||
@@ -156,7 +157,7 @@ impl<Id: FheIntId> FheInt<Id> {
|
||||
);
|
||||
|
||||
Self::new(d_ct, cuda_key.tag.clone())
|
||||
}),
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
panic!("Hpu does not support this operation yet.")
|
||||
@@ -200,7 +201,8 @@ impl<Id: FheIntId> FheInt<Id> {
|
||||
Self::new(ct, key.tag.clone())
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
let d_ct: CudaSignedRadixCiphertext = cuda_key
|
||||
.key
|
||||
.key
|
||||
@@ -211,7 +213,7 @@ impl<Id: FheIntId> FheInt<Id> {
|
||||
streams,
|
||||
);
|
||||
Self::new(d_ct, cuda_key.tag.clone())
|
||||
}),
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
panic!("Hpu does not support this operation yet.")
|
||||
|
||||
@@ -17,8 +17,6 @@ use crate::shortint::PBSParameters;
|
||||
use crate::{Device, FheBool, ServerKey, Tag};
|
||||
use std::marker::PhantomData;
|
||||
|
||||
#[cfg(feature = "gpu")]
|
||||
use crate::high_level_api::global_state::with_thread_local_cuda_streams;
|
||||
pub trait FheIntId: IntegerId {}
|
||||
|
||||
/// A Generic FHE signed integer
|
||||
@@ -197,13 +195,14 @@ where
|
||||
Self::new(ciphertext, cpu_key.tag.clone())
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
let result = cuda_key
|
||||
.key
|
||||
.key
|
||||
.abs(&*self.ciphertext.on_gpu(streams), streams);
|
||||
Self::new(result, cuda_key.tag.clone())
|
||||
}),
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
panic!("Hpu does not support this operation yet.")
|
||||
@@ -237,13 +236,14 @@ where
|
||||
FheBool::new(result, cpu_key.tag.clone())
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
let result = cuda_key
|
||||
.key
|
||||
.key
|
||||
.is_even(&*self.ciphertext.on_gpu(streams), streams);
|
||||
FheBool::new(result, cuda_key.tag.clone())
|
||||
}),
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
panic!("Hpu does not support this operation yet.")
|
||||
@@ -277,13 +277,14 @@ where
|
||||
FheBool::new(result, cpu_key.tag.clone())
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
let result = cuda_key
|
||||
.key
|
||||
.key
|
||||
.is_odd(&*self.ciphertext.on_gpu(streams), streams);
|
||||
FheBool::new(result, cuda_key.tag.clone())
|
||||
}),
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
panic!("Hpu does not support this operation yet.")
|
||||
@@ -321,7 +322,8 @@ where
|
||||
crate::FheUint32::new(result, cpu_key.tag.clone())
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
let result = cuda_key
|
||||
.key
|
||||
.key
|
||||
@@ -332,7 +334,7 @@ where
|
||||
streams,
|
||||
);
|
||||
crate::FheUint32::new(result, cuda_key.tag.clone())
|
||||
}),
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
panic!("Hpu does not support this operation yet.")
|
||||
@@ -370,7 +372,8 @@ where
|
||||
crate::FheUint32::new(result, cpu_key.tag.clone())
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
let result = cuda_key
|
||||
.key
|
||||
.key
|
||||
@@ -381,7 +384,7 @@ where
|
||||
streams,
|
||||
);
|
||||
crate::FheUint32::new(result, cuda_key.tag.clone())
|
||||
}),
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
panic!("Hpu does not support this operation yet.")
|
||||
@@ -419,7 +422,8 @@ where
|
||||
crate::FheUint32::new(result, cpu_key.tag.clone())
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
let result = cuda_key
|
||||
.key
|
||||
.key
|
||||
@@ -430,7 +434,7 @@ where
|
||||
streams,
|
||||
);
|
||||
crate::FheUint32::new(result, cuda_key.tag.clone())
|
||||
}),
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
panic!("Hpu does not support this operation yet.")
|
||||
@@ -468,7 +472,8 @@ where
|
||||
crate::FheUint32::new(result, cpu_key.tag.clone())
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
let result = cuda_key
|
||||
.key
|
||||
.key
|
||||
@@ -479,7 +484,7 @@ where
|
||||
streams,
|
||||
);
|
||||
crate::FheUint32::new(result, cuda_key.tag.clone())
|
||||
}),
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
panic!("Hpu does not support this operation yet.")
|
||||
@@ -601,7 +606,8 @@ where
|
||||
crate::FheUint32::new(result, cpu_key.tag.clone())
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
let result = cuda_key
|
||||
.key
|
||||
.key
|
||||
@@ -612,7 +618,7 @@ where
|
||||
streams,
|
||||
);
|
||||
crate::FheUint32::new(result, cuda_key.tag.clone())
|
||||
}),
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
panic!("Hpu does not support this operation yet.")
|
||||
@@ -659,7 +665,8 @@ where
|
||||
)
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
let (result, is_ok) = cuda_key
|
||||
.key
|
||||
.key
|
||||
@@ -673,7 +680,7 @@ where
|
||||
crate::FheUint32::new(result, cuda_key.tag.clone()),
|
||||
FheBool::new(is_ok, cuda_key.tag.clone()),
|
||||
)
|
||||
}),
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
panic!("Hpu does not support this operation yet.")
|
||||
@@ -859,7 +866,8 @@ where
|
||||
Self::new(new_ciphertext, cpu_key.tag.clone())
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
let target_num_blocks = IntoId::num_blocks(cuda_key.message_modulus());
|
||||
let new_ciphertext = cuda_key.key.key.cast_to_signed(
|
||||
input.ciphertext.into_gpu(streams),
|
||||
@@ -867,7 +875,7 @@ where
|
||||
streams,
|
||||
);
|
||||
Self::new(new_ciphertext, cuda_key.tag.clone())
|
||||
}),
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
panic!("Hpu does not support this operation yet.")
|
||||
@@ -908,14 +916,15 @@ where
|
||||
Self::new(new_ciphertext, cpu_key.tag.clone())
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
let new_ciphertext = cuda_key.key.key.cast_to_signed(
|
||||
input.ciphertext.into_gpu(streams),
|
||||
IntoId::num_blocks(cuda_key.message_modulus()),
|
||||
streams,
|
||||
);
|
||||
Self::new(new_ciphertext, cuda_key.tag.clone())
|
||||
}),
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
panic!("Hpu does not support this operation yet.")
|
||||
@@ -959,14 +968,15 @@ where
|
||||
Self::new(ciphertext, cpu_key.tag.clone())
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
let inner = cuda_key.key.key.cast_to_signed(
|
||||
input.ciphertext.into_gpu(streams).0,
|
||||
Id::num_blocks(cuda_key.message_modulus()),
|
||||
streams,
|
||||
);
|
||||
Self::new(inner, cuda_key.tag.clone())
|
||||
}),
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
panic!("Hpu does not support this operation yet.")
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
use crate::core_crypto::prelude::SignedNumeric;
|
||||
use crate::high_level_api::global_state;
|
||||
#[cfg(feature = "gpu")]
|
||||
use crate::high_level_api::global_state::with_thread_local_cuda_streams;
|
||||
use crate::high_level_api::integers::FheIntId;
|
||||
use crate::high_level_api::keys::InternalServerKey;
|
||||
use crate::integer::block_decomposition::{DecomposableInto, RecomposableSignedInteger};
|
||||
@@ -113,14 +111,15 @@ where
|
||||
Ok(Self::new(ciphertext, key.tag.clone()))
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
let inner: CudaSignedRadixCiphertext = cuda_key.key.key.create_trivial_radix(
|
||||
value,
|
||||
Id::num_blocks(cuda_key.key.key.message_modulus),
|
||||
streams,
|
||||
);
|
||||
Ok(Self::new(inner, cuda_key.tag.clone()))
|
||||
}),
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_) => panic!("Hpu does not currently support signed operation"),
|
||||
})
|
||||
|
||||
@@ -4,9 +4,9 @@ use crate::core_crypto::gpu::CudaStreams;
|
||||
use crate::high_level_api::details::MaybeCloned;
|
||||
use crate::high_level_api::global_state;
|
||||
#[cfg(feature = "gpu")]
|
||||
use crate::high_level_api::global_state::{
|
||||
with_thread_local_cuda_streams, with_thread_local_cuda_streams_for_gpu_indexes,
|
||||
};
|
||||
use crate::high_level_api::global_state::with_cuda_internal_keys;
|
||||
#[cfg(feature = "gpu")]
|
||||
use crate::high_level_api::global_state::with_thread_local_cuda_streams_for_gpu_indexes;
|
||||
#[cfg(feature = "gpu")]
|
||||
use crate::integer::gpu::ciphertext::CudaIntegerRadixCiphertext;
|
||||
#[cfg(feature = "gpu")]
|
||||
@@ -38,10 +38,12 @@ impl Clone for SignedRadixCiphertext {
|
||||
match self {
|
||||
Self::Cpu(inner) => Self::Cpu(inner.clone()),
|
||||
#[cfg(feature = "gpu")]
|
||||
Self::Cuda(inner) => with_thread_local_cuda_streams(|streams| {
|
||||
let inner = inner.duplicate(streams);
|
||||
Self::Cuda(inner)
|
||||
}),
|
||||
Self::Cuda(inner) => {
|
||||
with_thread_local_cuda_streams_for_gpu_indexes(inner.gpu_indexes(), |streams| {
|
||||
let inner = inner.duplicate(streams);
|
||||
Self::Cuda(inner)
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -141,10 +143,10 @@ impl SignedRadixCiphertext {
|
||||
streams: &CudaStreams,
|
||||
) -> MaybeCloned<'_, CudaSignedRadixCiphertext> {
|
||||
match self {
|
||||
Self::Cpu(ct) => with_thread_local_cuda_streams(|streams| {
|
||||
Self::Cpu(ct) => {
|
||||
let ct = CudaSignedRadixCiphertext::from_signed_radix_ciphertext(ct, streams);
|
||||
MaybeCloned::Cloned(ct)
|
||||
}),
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
Self::Cuda(ct) => {
|
||||
if ct.gpu_indexes() == streams.gpu_indexes() {
|
||||
@@ -220,7 +222,8 @@ impl SignedRadixCiphertext {
|
||||
#[cfg(feature = "gpu")]
|
||||
(Self::Cuda(cuda_ct), Device::CudaGpu) => {
|
||||
// We are on a GPU, but it may not be the correct one
|
||||
let new = with_thread_local_cuda_streams(|streams| {
|
||||
let new = with_cuda_internal_keys(|key| {
|
||||
let streams = &key.streams;
|
||||
if cuda_ct.gpu_indexes() == streams.gpu_indexes() {
|
||||
None
|
||||
} else {
|
||||
@@ -233,14 +236,16 @@ impl SignedRadixCiphertext {
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
(Self::Cpu(ct), Device::CudaGpu) => {
|
||||
let new_inner = with_thread_local_cuda_streams(|streams| {
|
||||
let new_inner = with_cuda_internal_keys(|key| {
|
||||
let streams = &key.streams;
|
||||
CudaSignedRadixCiphertext::from_signed_radix_ciphertext(ct, streams)
|
||||
});
|
||||
*self = Self::Cuda(new_inner);
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
(Self::Cuda(ct), Device::Cpu) => {
|
||||
let new_inner = with_thread_local_cuda_streams(|streams| {
|
||||
let new_inner = with_cuda_internal_keys(|key| {
|
||||
let streams = &key.streams;
|
||||
ct.to_signed_radix_ciphertext(streams)
|
||||
});
|
||||
*self = Self::Cpu(new_inner);
|
||||
|
||||
@@ -1,12 +1,15 @@
|
||||
#[cfg(feature = "gpu")]
|
||||
use crate::high_level_api::details::MaybeCloned;
|
||||
use crate::high_level_api::global_state;
|
||||
#[cfg(feature = "gpu")]
|
||||
use crate::high_level_api::global_state::with_cuda_internal_keys;
|
||||
use crate::high_level_api::integers::{FheIntId, FheUintId};
|
||||
use crate::high_level_api::keys::InternalServerKey;
|
||||
#[cfg(feature = "gpu")]
|
||||
use crate::high_level_api::traits::{
|
||||
AddSizeOnGpu, BitAndSizeOnGpu, BitNotSizeOnGpu, BitOrSizeOnGpu, BitXorSizeOnGpu, SizeOnGpu,
|
||||
SubSizeOnGpu,
|
||||
AddSizeOnGpu, BitAndSizeOnGpu, BitNotSizeOnGpu, BitOrSizeOnGpu, BitXorSizeOnGpu,
|
||||
FheMaxSizeOnGpu, FheMinSizeOnGpu, FheOrdSizeOnGpu, RotateLeftSizeOnGpu, RotateRightSizeOnGpu,
|
||||
ShlSizeOnGpu, ShrSizeOnGpu, SizeOnGpu, SubSizeOnGpu,
|
||||
};
|
||||
use crate::high_level_api::traits::{
|
||||
DivRem, FheEq, FheMax, FheMin, FheOrd, RotateLeft, RotateLeftAssign, RotateRight,
|
||||
@@ -21,9 +24,6 @@ use std::ops::{
|
||||
Mul, MulAssign, Neg, Not, Rem, RemAssign, Shl, ShlAssign, Shr, ShrAssign, Sub, SubAssign,
|
||||
};
|
||||
|
||||
#[cfg(feature = "gpu")]
|
||||
use crate::high_level_api::global_state::with_thread_local_cuda_streams;
|
||||
|
||||
impl<'a, Id> std::iter::Sum<&'a Self> for FheInt<Id>
|
||||
where
|
||||
Id: FheIntId,
|
||||
@@ -77,7 +77,8 @@ where
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
{
|
||||
let streams = &cuda_key.streams;
|
||||
let cts = iter
|
||||
.map(|fhe_uint| {
|
||||
match fhe_uint.ciphertext.on_gpu(streams) {
|
||||
@@ -107,7 +108,7 @@ where
|
||||
)
|
||||
});
|
||||
Self::new(inner, cuda_key.tag.clone())
|
||||
})
|
||||
}
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
@@ -151,14 +152,15 @@ where
|
||||
Self::new(inner_result, cpu_key.tag.clone())
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
let inner_result = cuda_key.key.key.max(
|
||||
&*self.ciphertext.on_gpu(streams),
|
||||
&*rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
);
|
||||
Self::new(inner_result, cuda_key.tag.clone())
|
||||
}),
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
panic!("Hpu does not support this operation yet.")
|
||||
@@ -201,14 +203,15 @@ where
|
||||
Self::new(inner_result, cpu_key.tag.clone())
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
let inner_result = cuda_key.key.key.min(
|
||||
&*self.ciphertext.on_gpu(streams),
|
||||
&*rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
);
|
||||
Self::new(inner_result, cuda_key.tag.clone())
|
||||
}),
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
panic!("Hpu does not support this operation yet.")
|
||||
@@ -262,14 +265,15 @@ where
|
||||
FheBool::new(inner_result, cpu_key.tag.clone())
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
let inner_result = cuda_key.key.key.eq(
|
||||
&*self.ciphertext.on_gpu(streams),
|
||||
&*rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
);
|
||||
FheBool::new(inner_result, cuda_key.tag.clone())
|
||||
}),
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
panic!("Hpu does not support this operation yet.")
|
||||
@@ -305,14 +309,15 @@ where
|
||||
FheBool::new(inner_result, cpu_key.tag.clone())
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
let inner_result = cuda_key.key.key.ne(
|
||||
&*self.ciphertext.on_gpu(streams),
|
||||
&*rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
);
|
||||
FheBool::new(inner_result, cuda_key.tag.clone())
|
||||
}),
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
panic!("Hpu does not support this operation yet.")
|
||||
@@ -374,14 +379,15 @@ where
|
||||
FheBool::new(inner_result, cpu_key.tag.clone())
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
let inner_result = cuda_key.key.key.lt(
|
||||
&*self.ciphertext.on_gpu(streams),
|
||||
&*rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
);
|
||||
FheBool::new(inner_result, cuda_key.tag.clone())
|
||||
}),
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
panic!("Hpu does not support this operation yet.")
|
||||
@@ -417,14 +423,15 @@ where
|
||||
FheBool::new(inner_result, cpu_key.tag.clone())
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
let inner_result = cuda_key.key.key.le(
|
||||
&*self.ciphertext.on_gpu(streams),
|
||||
&*rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
);
|
||||
FheBool::new(inner_result, cuda_key.tag.clone())
|
||||
}),
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
panic!("Hpu does not support this operation yet.")
|
||||
@@ -460,14 +467,15 @@ where
|
||||
FheBool::new(inner_result, cpu_key.tag.clone())
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
let inner_result = cuda_key.key.key.gt(
|
||||
&*self.ciphertext.on_gpu(streams),
|
||||
&*rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
);
|
||||
FheBool::new(inner_result, cuda_key.tag.clone())
|
||||
}),
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
panic!("Hpu does not support this operation yet.")
|
||||
@@ -503,14 +511,15 @@ where
|
||||
FheBool::new(inner_result, cpu_key.tag.clone())
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
let inner_result = cuda_key.key.key.ge(
|
||||
&*self.ciphertext.on_gpu(streams),
|
||||
&*rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
);
|
||||
FheBool::new(inner_result, cuda_key.tag.clone())
|
||||
}),
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
panic!("Hpu does not support this operation yet.")
|
||||
@@ -589,7 +598,8 @@ where
|
||||
)
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
let (q, r) = cuda_key.key.key.div_rem(
|
||||
&*self.ciphertext.on_gpu(streams),
|
||||
&*rhs.ciphertext.on_gpu(streams),
|
||||
@@ -599,7 +609,7 @@ where
|
||||
FheInt::<Id>::new(q, cuda_key.tag.clone()),
|
||||
FheInt::<Id>::new(r, cuda_key.tag.clone()),
|
||||
)
|
||||
}),
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
panic!("Hpu does not support this operation yet.")
|
||||
@@ -677,11 +687,11 @@ generic_integer_impl_operation!(
|
||||
},
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
{let streams = &cuda_key.streams;
|
||||
let inner_result = cuda_key.key.key
|
||||
.add(&*lhs.ciphertext.on_gpu(streams), &*rhs.ciphertext.on_gpu(streams), streams);
|
||||
FheInt::new(inner_result, cuda_key.tag.clone())
|
||||
})
|
||||
}
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
@@ -724,11 +734,11 @@ generic_integer_impl_operation!(
|
||||
},
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
{let streams = &cuda_key.streams;
|
||||
let inner_result = cuda_key.key.key
|
||||
.sub(&*lhs.ciphertext.on_gpu(streams), &*rhs.ciphertext.on_gpu(streams), streams);
|
||||
FheInt::new(inner_result, cuda_key.tag.clone())
|
||||
})
|
||||
}
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
@@ -771,11 +781,11 @@ generic_integer_impl_operation!(
|
||||
},
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
{let streams = &cuda_key.streams;
|
||||
let inner_result = cuda_key.key.key
|
||||
.mul(&*lhs.ciphertext.on_gpu(streams), &*rhs.ciphertext.on_gpu(streams), streams);
|
||||
FheInt::new(inner_result, cuda_key.tag.clone())
|
||||
})
|
||||
}
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
@@ -816,11 +826,11 @@ generic_integer_impl_operation!(
|
||||
},
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
{let streams = &cuda_key.streams;
|
||||
let inner_result = cuda_key.key.key
|
||||
.bitand(&*lhs.ciphertext.on_gpu(streams), &*rhs.ciphertext.on_gpu(streams), streams);
|
||||
FheInt::new(inner_result, cuda_key.tag.clone())
|
||||
})
|
||||
}
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
@@ -861,11 +871,11 @@ generic_integer_impl_operation!(
|
||||
},
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
{let streams = &cuda_key.streams;
|
||||
let inner_result = cuda_key.key.key
|
||||
.bitor(&*lhs.ciphertext.on_gpu(streams), &*rhs.ciphertext.on_gpu(streams), streams);
|
||||
FheInt::new(inner_result, cuda_key.tag.clone())
|
||||
})
|
||||
}
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
@@ -906,11 +916,11 @@ generic_integer_impl_operation!(
|
||||
},
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
{let streams = &cuda_key.streams;
|
||||
let inner_result = cuda_key.key.key
|
||||
.bitxor(&*lhs.ciphertext.on_gpu(streams), &*rhs.ciphertext.on_gpu(streams), streams);
|
||||
FheInt::new(inner_result, cuda_key.tag.clone())
|
||||
})
|
||||
}
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
@@ -958,14 +968,14 @@ generic_integer_impl_operation!(
|
||||
FheInt::new(inner_result, cpu_key.tag.clone())
|
||||
},
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
InternalServerKey::Cuda(cuda_key) => {let streams = &cuda_key.streams;
|
||||
let inner_result =
|
||||
cuda_key
|
||||
.key
|
||||
.key
|
||||
.div(&*lhs.ciphertext.on_gpu(streams), &*rhs.ciphertext.on_gpu(streams), streams);
|
||||
FheInt::new(inner_result, cuda_key.tag.clone())
|
||||
}),
|
||||
},
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
panic!("Hpu does not support this operation yet.")
|
||||
@@ -1013,14 +1023,14 @@ generic_integer_impl_operation!(
|
||||
FheInt::new(inner_result, cpu_key.tag.clone())
|
||||
},
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
InternalServerKey::Cuda(cuda_key) => {let streams = &cuda_key.streams;
|
||||
let inner_result =
|
||||
cuda_key
|
||||
.key
|
||||
.key
|
||||
.rem(&*lhs.ciphertext.on_gpu(streams), &*rhs.ciphertext.on_gpu(streams), streams);
|
||||
FheInt::new(inner_result, cuda_key.tag.clone())
|
||||
}),
|
||||
},
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
panic!("Hpu does not support this operation yet.")
|
||||
@@ -1132,11 +1142,11 @@ generic_integer_impl_shift_rotate!(
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
{let streams = &cuda_key.streams;
|
||||
let inner_result = cuda_key.key.key
|
||||
.left_shift(&*lhs.ciphertext.on_gpu(streams), &rhs.ciphertext.on_gpu(streams), streams);
|
||||
FheInt::new(inner_result, cuda_key.tag.clone())
|
||||
})
|
||||
}
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
@@ -1180,11 +1190,11 @@ generic_integer_impl_shift_rotate!(
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
{let streams = &cuda_key.streams;
|
||||
let inner_result = cuda_key.key.key
|
||||
.right_shift(&*lhs.ciphertext.on_gpu(streams), &rhs.ciphertext.on_gpu(streams), streams);
|
||||
FheInt::new(inner_result, cuda_key.tag.clone())
|
||||
})
|
||||
}
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
@@ -1228,11 +1238,11 @@ generic_integer_impl_shift_rotate!(
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
{let streams = &cuda_key.streams;
|
||||
let inner_result = cuda_key.key.key
|
||||
.rotate_left(&*lhs.ciphertext.on_gpu(streams), &rhs.ciphertext.on_gpu(streams), streams);
|
||||
FheInt::new(inner_result, cuda_key.tag.clone())
|
||||
})
|
||||
}
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
@@ -1276,11 +1286,11 @@ generic_integer_impl_shift_rotate!(
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
{let streams = &cuda_key.streams;
|
||||
let inner_result = cuda_key.key.key
|
||||
.rotate_right(&*lhs.ciphertext.on_gpu(streams), &rhs.ciphertext.on_gpu(streams), streams);
|
||||
FheInt::new(inner_result, cuda_key.tag.clone())
|
||||
})
|
||||
}
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
@@ -1330,13 +1340,12 @@ where
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
crate::high_level_api::global_state::with_thread_local_cuda_streams(|streams| {
|
||||
cuda_key.key.key.add_assign(
|
||||
self.ciphertext.as_gpu_mut(streams),
|
||||
&rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
);
|
||||
})
|
||||
let streams = &cuda_key.streams;
|
||||
cuda_key.key.key.add_assign(
|
||||
self.ciphertext.as_gpu_mut(streams),
|
||||
&rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
);
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
@@ -1381,13 +1390,12 @@ where
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
crate::high_level_api::global_state::with_thread_local_cuda_streams(|streams| {
|
||||
cuda_key.key.key.sub_assign(
|
||||
self.ciphertext.as_gpu_mut(streams),
|
||||
&rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
);
|
||||
})
|
||||
let streams = &cuda_key.streams;
|
||||
cuda_key.key.key.sub_assign(
|
||||
self.ciphertext.as_gpu_mut(streams),
|
||||
&rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
);
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
@@ -1432,13 +1440,12 @@ where
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
crate::high_level_api::global_state::with_thread_local_cuda_streams(|streams| {
|
||||
cuda_key.key.key.mul_assign(
|
||||
self.ciphertext.as_gpu_mut(streams),
|
||||
&rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
);
|
||||
})
|
||||
let streams = &cuda_key.streams;
|
||||
cuda_key.key.key.mul_assign(
|
||||
self.ciphertext.as_gpu_mut(streams),
|
||||
&rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
);
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
@@ -1481,13 +1488,12 @@ where
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
crate::high_level_api::global_state::with_thread_local_cuda_streams(|streams| {
|
||||
cuda_key.key.key.bitand_assign(
|
||||
self.ciphertext.as_gpu_mut(streams),
|
||||
&rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
);
|
||||
})
|
||||
let streams = &cuda_key.streams;
|
||||
cuda_key.key.key.bitand_assign(
|
||||
self.ciphertext.as_gpu_mut(streams),
|
||||
&rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
);
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
@@ -1530,13 +1536,12 @@ where
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
crate::high_level_api::global_state::with_thread_local_cuda_streams(|streams| {
|
||||
cuda_key.key.key.bitor_assign(
|
||||
self.ciphertext.as_gpu_mut(streams),
|
||||
&rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
);
|
||||
})
|
||||
let streams = &cuda_key.streams;
|
||||
cuda_key.key.key.bitor_assign(
|
||||
self.ciphertext.as_gpu_mut(streams),
|
||||
&rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
);
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
@@ -1579,13 +1584,12 @@ where
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
crate::high_level_api::global_state::with_thread_local_cuda_streams(|streams| {
|
||||
cuda_key.key.key.bitxor_assign(
|
||||
self.ciphertext.as_gpu_mut(streams),
|
||||
&rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
);
|
||||
})
|
||||
let streams = &cuda_key.streams;
|
||||
cuda_key.key.key.bitxor_assign(
|
||||
self.ciphertext.as_gpu_mut(streams),
|
||||
&rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
);
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
@@ -1633,7 +1637,8 @@ where
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
{
|
||||
let streams = &cuda_key.streams;
|
||||
let cuda_lhs = self.ciphertext.as_gpu_mut(streams);
|
||||
let cuda_result = cuda_key.pbs_key().div(
|
||||
&*cuda_lhs,
|
||||
@@ -1641,7 +1646,7 @@ where
|
||||
streams,
|
||||
);
|
||||
*cuda_lhs = cuda_result;
|
||||
});
|
||||
};
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
@@ -1689,15 +1694,13 @@ where
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
let cuda_lhs = self.ciphertext.as_gpu_mut(streams);
|
||||
let cuda_result = cuda_key.pbs_key().rem(
|
||||
&*cuda_lhs,
|
||||
&rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
);
|
||||
*cuda_lhs = cuda_result;
|
||||
});
|
||||
let streams = &cuda_key.streams;
|
||||
let cuda_lhs = self.ciphertext.as_gpu_mut(streams);
|
||||
let cuda_result =
|
||||
cuda_key
|
||||
.pbs_key()
|
||||
.rem(&*cuda_lhs, &rhs.ciphertext.on_gpu(streams), streams);
|
||||
*cuda_lhs = cuda_result;
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
@@ -1750,13 +1753,12 @@ where
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
cuda_key.key.key.left_shift_assign(
|
||||
self.ciphertext.as_gpu_mut(streams),
|
||||
&rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
);
|
||||
});
|
||||
let streams = &cuda_key.streams;
|
||||
cuda_key.key.key.left_shift_assign(
|
||||
self.ciphertext.as_gpu_mut(streams),
|
||||
&rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
);
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
@@ -1808,13 +1810,12 @@ where
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
cuda_key.key.key.right_shift_assign(
|
||||
self.ciphertext.as_gpu_mut(streams),
|
||||
&rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
);
|
||||
});
|
||||
let streams = &cuda_key.streams;
|
||||
cuda_key.key.key.right_shift_assign(
|
||||
self.ciphertext.as_gpu_mut(streams),
|
||||
&rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
);
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
@@ -1867,13 +1868,12 @@ where
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
cuda_key.key.key.rotate_left_assign(
|
||||
self.ciphertext.as_gpu_mut(streams),
|
||||
&rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
);
|
||||
});
|
||||
let streams = &cuda_key.streams;
|
||||
cuda_key.key.key.rotate_left_assign(
|
||||
self.ciphertext.as_gpu_mut(streams),
|
||||
&rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
);
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
@@ -1926,13 +1926,12 @@ where
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
cuda_key.key.key.rotate_right_assign(
|
||||
self.ciphertext.as_gpu_mut(streams),
|
||||
&rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
);
|
||||
});
|
||||
let streams = &cuda_key.streams;
|
||||
cuda_key.key.key.rotate_right_assign(
|
||||
self.ciphertext.as_gpu_mut(streams),
|
||||
&rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
);
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
@@ -2002,13 +2001,14 @@ where
|
||||
FheInt::new(ciphertext, cpu_key.tag.clone())
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
let inner_result = cuda_key
|
||||
.key
|
||||
.key
|
||||
.neg(&*self.ciphertext.on_gpu(streams), streams);
|
||||
FheInt::new(inner_result, cuda_key.tag.clone())
|
||||
}),
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
panic!("Hpu does not support this operation yet.")
|
||||
@@ -2075,13 +2075,14 @@ where
|
||||
FheInt::new(ciphertext, cpu_key.tag.clone())
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
let inner_result = cuda_key
|
||||
.key
|
||||
.key
|
||||
.bitnot(&*self.ciphertext.on_gpu(streams), streams);
|
||||
FheInt::new(inner_result, cuda_key.tag.clone())
|
||||
}),
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
panic!("Hpu does not support this operation yet.")
|
||||
@@ -2101,13 +2102,12 @@ where
|
||||
|
||||
global_state::with_internal_keys(|key| {
|
||||
if let InternalServerKey::Cuda(cuda_key) = key {
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
cuda_key.key.key.get_add_size_on_gpu(
|
||||
&*self.ciphertext.on_gpu(streams),
|
||||
&rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
)
|
||||
})
|
||||
let streams = &cuda_key.streams;
|
||||
cuda_key.key.key.get_add_size_on_gpu(
|
||||
&*self.ciphertext.on_gpu(streams),
|
||||
&rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
)
|
||||
} else {
|
||||
0
|
||||
}
|
||||
@@ -2126,13 +2126,12 @@ where
|
||||
|
||||
global_state::with_internal_keys(|key| {
|
||||
if let InternalServerKey::Cuda(cuda_key) = key {
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
cuda_key.key.key.get_sub_size_on_gpu(
|
||||
&*self.ciphertext.on_gpu(streams),
|
||||
&rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
)
|
||||
})
|
||||
let streams = &cuda_key.streams;
|
||||
cuda_key.key.key.get_sub_size_on_gpu(
|
||||
&*self.ciphertext.on_gpu(streams),
|
||||
&rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
)
|
||||
} else {
|
||||
0
|
||||
}
|
||||
@@ -2148,12 +2147,11 @@ where
|
||||
fn get_size_on_gpu(&self) -> u64 {
|
||||
global_state::with_internal_keys(|key| {
|
||||
if let InternalServerKey::Cuda(cuda_key) = key {
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
cuda_key
|
||||
.key
|
||||
.key
|
||||
.get_ciphertext_size_on_gpu(&*self.ciphertext.on_gpu(streams))
|
||||
})
|
||||
let streams = &cuda_key.streams;
|
||||
cuda_key
|
||||
.key
|
||||
.key
|
||||
.get_ciphertext_size_on_gpu(&*self.ciphertext.on_gpu(streams))
|
||||
} else {
|
||||
0
|
||||
}
|
||||
@@ -2171,13 +2169,12 @@ where
|
||||
let rhs = rhs.borrow();
|
||||
global_state::with_internal_keys(|key| {
|
||||
if let InternalServerKey::Cuda(cuda_key) = key {
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
cuda_key.key.key.get_bitand_size_on_gpu(
|
||||
&*self.ciphertext.on_gpu(streams),
|
||||
&rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
)
|
||||
})
|
||||
let streams = &cuda_key.streams;
|
||||
cuda_key.key.key.get_bitand_size_on_gpu(
|
||||
&*self.ciphertext.on_gpu(streams),
|
||||
&rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
)
|
||||
} else {
|
||||
0
|
||||
}
|
||||
@@ -2195,13 +2192,12 @@ where
|
||||
let rhs = rhs.borrow();
|
||||
global_state::with_internal_keys(|key| {
|
||||
if let InternalServerKey::Cuda(cuda_key) = key {
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
cuda_key.key.key.get_bitor_size_on_gpu(
|
||||
&*self.ciphertext.on_gpu(streams),
|
||||
&rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
)
|
||||
})
|
||||
let streams = &cuda_key.streams;
|
||||
cuda_key.key.key.get_bitor_size_on_gpu(
|
||||
&*self.ciphertext.on_gpu(streams),
|
||||
&rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
)
|
||||
} else {
|
||||
0
|
||||
}
|
||||
@@ -2219,13 +2215,12 @@ where
|
||||
let rhs = rhs.borrow();
|
||||
global_state::with_internal_keys(|key| {
|
||||
if let InternalServerKey::Cuda(cuda_key) = key {
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
cuda_key.key.key.get_bitxor_size_on_gpu(
|
||||
&*self.ciphertext.on_gpu(streams),
|
||||
&rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
)
|
||||
})
|
||||
let streams = &cuda_key.streams;
|
||||
cuda_key.key.key.get_bitxor_size_on_gpu(
|
||||
&*self.ciphertext.on_gpu(streams),
|
||||
&rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
)
|
||||
} else {
|
||||
0
|
||||
}
|
||||
@@ -2241,11 +2236,219 @@ where
|
||||
fn get_bitnot_size_on_gpu(&self) -> u64 {
|
||||
global_state::with_internal_keys(|key| {
|
||||
if let InternalServerKey::Cuda(cuda_key) = key {
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
cuda_key
|
||||
.key
|
||||
.key
|
||||
.get_bitnot_size_on_gpu(&*self.ciphertext.on_gpu(streams), streams)
|
||||
let streams = &cuda_key.streams;
|
||||
cuda_key
|
||||
.key
|
||||
.key
|
||||
.get_bitnot_size_on_gpu(&*self.ciphertext.on_gpu(streams), streams)
|
||||
} else {
|
||||
0
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// GPU size queries for the ordering comparisons (`gt`, `ge`, `lt`, `le`) between two
// encrypted signed integers. Each method forwards to the same-named method on the CUDA
// integer key and yields 0 when the active server key is not the CUDA variant.
// NOTE(review): the unit of the returned `u64` is not visible here (presumably bytes) — confirm.
// NOTE(review): the streams come from a nested `with_cuda_internal_keys` lookup rather than
// from `cuda_key.streams`, which other CUDA arms in this diff use — confirm this is intended.
#[cfg(feature = "gpu")]
|
||||
impl<Id> FheOrdSizeOnGpu<&Self> for FheInt<Id>
|
||||
where
|
||||
Id: FheIntId,
|
||||
{
|
||||
/// Returns what the CUDA key's `get_gt_size_on_gpu` reports for `self` and `rhs`,
/// or 0 when the current server key is not a CUDA key.
fn get_gt_size_on_gpu(&self, rhs: &Self) -> u64 {
|
||||
global_state::with_internal_keys(|key| {
|
||||
if let InternalServerKey::Cuda(cuda_key) = key {
|
||||
with_cuda_internal_keys(|keys| {
|
||||
let streams = &keys.streams;
|
||||
cuda_key.key.key.get_gt_size_on_gpu(
|
||||
&*self.ciphertext.on_gpu(streams),
|
||||
&rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
)
|
||||
})
|
||||
} else {
|
||||
0
|
||||
}
|
||||
})
|
||||
}
|
||||
/// Returns what the CUDA key's `get_ge_size_on_gpu` reports for `self` and `rhs`,
/// or 0 when the current server key is not a CUDA key.
fn get_ge_size_on_gpu(&self, rhs: &Self) -> u64 {
|
||||
global_state::with_internal_keys(|key| {
|
||||
if let InternalServerKey::Cuda(cuda_key) = key {
|
||||
with_cuda_internal_keys(|keys| {
|
||||
let streams = &keys.streams;
|
||||
cuda_key.key.key.get_ge_size_on_gpu(
|
||||
&*self.ciphertext.on_gpu(streams),
|
||||
&rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
)
|
||||
})
|
||||
} else {
|
||||
0
|
||||
}
|
||||
})
|
||||
}
|
||||
/// Returns what the CUDA key's `get_lt_size_on_gpu` reports for `self` and `rhs`,
/// or 0 when the current server key is not a CUDA key.
fn get_lt_size_on_gpu(&self, rhs: &Self) -> u64 {
|
||||
global_state::with_internal_keys(|key| {
|
||||
if let InternalServerKey::Cuda(cuda_key) = key {
|
||||
with_cuda_internal_keys(|keys| {
|
||||
let streams = &keys.streams;
|
||||
cuda_key.key.key.get_lt_size_on_gpu(
|
||||
&*self.ciphertext.on_gpu(streams),
|
||||
&rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
)
|
||||
})
|
||||
} else {
|
||||
0
|
||||
}
|
||||
})
|
||||
}
|
||||
/// Returns what the CUDA key's `get_le_size_on_gpu` reports for `self` and `rhs`,
/// or 0 when the current server key is not a CUDA key.
fn get_le_size_on_gpu(&self, rhs: &Self) -> u64 {
|
||||
global_state::with_internal_keys(|key| {
|
||||
if let InternalServerKey::Cuda(cuda_key) = key {
|
||||
with_cuda_internal_keys(|keys| {
|
||||
let streams = &keys.streams;
|
||||
cuda_key.key.key.get_le_size_on_gpu(
|
||||
&*self.ciphertext.on_gpu(streams),
|
||||
&rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
)
|
||||
})
|
||||
} else {
|
||||
0
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
// GPU size query for the encrypted `min` of two signed integers.
// NOTE(review): the unit of the returned `u64` is not visible here (presumably bytes) — confirm.
#[cfg(feature = "gpu")]
|
||||
impl<Id> FheMinSizeOnGpu<&Self> for FheInt<Id>
|
||||
where
|
||||
Id: FheIntId,
|
||||
{
|
||||
/// Returns what the CUDA key's `get_min_size_on_gpu` reports for `self` and `rhs`,
/// or 0 when the current server key is not a CUDA key.
fn get_min_size_on_gpu(&self, rhs: &Self) -> u64 {
|
||||
global_state::with_internal_keys(|key| {
|
||||
if let InternalServerKey::Cuda(cuda_key) = key {
|
||||
// Streams are fetched through a second, nested key lookup.
with_cuda_internal_keys(|keys| {
|
||||
let streams = &keys.streams;
|
||||
cuda_key.key.key.get_min_size_on_gpu(
|
||||
&*self.ciphertext.on_gpu(streams),
|
||||
&rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
)
|
||||
})
|
||||
} else {
|
||||
0
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// GPU size query for the encrypted `max` of two signed integers.
// NOTE(review): the unit of the returned `u64` is not visible here (presumably bytes) — confirm.
#[cfg(feature = "gpu")]
|
||||
impl<Id> FheMaxSizeOnGpu<&Self> for FheInt<Id>
|
||||
where
|
||||
Id: FheIntId,
|
||||
{
|
||||
/// Returns what the CUDA key's `get_max_size_on_gpu` reports for `self` and `rhs`,
/// or 0 when the current server key is not a CUDA key.
fn get_max_size_on_gpu(&self, rhs: &Self) -> u64 {
|
||||
global_state::with_internal_keys(|key| {
|
||||
if let InternalServerKey::Cuda(cuda_key) = key {
|
||||
// Streams are fetched through a second, nested key lookup.
with_cuda_internal_keys(|keys| {
|
||||
let streams = &keys.streams;
|
||||
cuda_key.key.key.get_max_size_on_gpu(
|
||||
&*self.ciphertext.on_gpu(streams),
|
||||
&rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
)
|
||||
})
|
||||
} else {
|
||||
0
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// GPU size query for a left shift of a signed integer by an encrypted unsigned amount.
// NOTE(review): the unit of the returned `u64` is not visible here (presumably bytes) — confirm.
#[cfg(feature = "gpu")]
|
||||
impl<Id, Id2> ShlSizeOnGpu<&FheUint<Id2>> for FheInt<Id>
|
||||
where
|
||||
Id: FheIntId,
|
||||
Id2: FheUintId,
|
||||
{
|
||||
/// Returns what the CUDA key's `get_left_shift_size_on_gpu` reports for `self` shifted
/// by `rhs`, or 0 when the current server key is not a CUDA key.
fn get_left_shift_size_on_gpu(&self, rhs: &FheUint<Id2>) -> u64 {
|
||||
global_state::with_internal_keys(|key| {
|
||||
if let InternalServerKey::Cuda(cuda_key) = key {
|
||||
// Streams are fetched through a second, nested key lookup.
with_cuda_internal_keys(|keys| {
|
||||
let streams = &keys.streams;
|
||||
cuda_key.key.key.get_left_shift_size_on_gpu(
|
||||
&*self.ciphertext.on_gpu(streams),
|
||||
&rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
)
|
||||
})
|
||||
} else {
|
||||
0
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
// GPU size query for a right shift of a signed integer by an encrypted unsigned amount.
// NOTE(review): the unit of the returned `u64` is not visible here (presumably bytes) — confirm.
#[cfg(feature = "gpu")]
|
||||
impl<Id, Id2> ShrSizeOnGpu<&FheUint<Id2>> for FheInt<Id>
|
||||
where
|
||||
Id: FheIntId,
|
||||
Id2: FheUintId,
|
||||
{
|
||||
/// Returns what the CUDA key's `get_right_shift_size_on_gpu` reports for `self` shifted
/// by `rhs`, or 0 when the current server key is not a CUDA key.
fn get_right_shift_size_on_gpu(&self, rhs: &FheUint<Id2>) -> u64 {
|
||||
global_state::with_internal_keys(|key| {
|
||||
if let InternalServerKey::Cuda(cuda_key) = key {
|
||||
// Streams are fetched through a second, nested key lookup.
with_cuda_internal_keys(|keys| {
|
||||
let streams = &keys.streams;
|
||||
cuda_key.key.key.get_right_shift_size_on_gpu(
|
||||
&*self.ciphertext.on_gpu(streams),
|
||||
&rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
)
|
||||
})
|
||||
} else {
|
||||
0
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
// GPU size query for a left rotation of a signed integer by an encrypted unsigned amount.
// NOTE(review): the unit of the returned `u64` is not visible here (presumably bytes) — confirm.
#[cfg(feature = "gpu")]
|
||||
impl<Id, Id2> RotateLeftSizeOnGpu<&FheUint<Id2>> for FheInt<Id>
|
||||
where
|
||||
Id: FheIntId,
|
||||
Id2: FheUintId,
|
||||
{
|
||||
/// Returns what the CUDA key's `get_rotate_left_size_on_gpu` reports for `self` rotated
/// by `rhs`, or 0 when the current server key is not a CUDA key.
fn get_rotate_left_size_on_gpu(&self, rhs: &FheUint<Id2>) -> u64 {
|
||||
global_state::with_internal_keys(|key| {
|
||||
if let InternalServerKey::Cuda(cuda_key) = key {
|
||||
// Streams are fetched through a second, nested key lookup.
with_cuda_internal_keys(|keys| {
|
||||
let streams = &keys.streams;
|
||||
cuda_key.key.key.get_rotate_left_size_on_gpu(
|
||||
&*self.ciphertext.on_gpu(streams),
|
||||
&rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
)
|
||||
})
|
||||
} else {
|
||||
0
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
impl<Id, Id2> RotateRightSizeOnGpu<&FheUint<Id2>> for FheInt<Id>
|
||||
where
|
||||
Id: FheIntId,
|
||||
Id2: FheUintId,
|
||||
{
|
||||
fn get_rotate_right_size_on_gpu(&self, rhs: &FheUint<Id2>) -> u64 {
|
||||
global_state::with_internal_keys(|key| {
|
||||
if let InternalServerKey::Cuda(cuda_key) = key {
|
||||
with_cuda_internal_keys(|keys| {
|
||||
let streams = &keys.streams;
|
||||
cuda_key.key.key.get_rotate_right_size_on_gpu(
|
||||
&*self.ciphertext.on_gpu(streams),
|
||||
&rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
)
|
||||
})
|
||||
} else {
|
||||
0
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
use crate::core_crypto::prelude::SignedNumeric;
|
||||
use crate::high_level_api::global_state;
|
||||
#[cfg(feature = "gpu")]
|
||||
use crate::high_level_api::global_state::with_thread_local_cuda_streams;
|
||||
use crate::high_level_api::integers::FheIntId;
|
||||
use crate::high_level_api::keys::InternalServerKey;
|
||||
use crate::integer::block_decomposition::DecomposableInto;
|
||||
@@ -53,7 +51,8 @@ where
|
||||
)
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
let (result, overflow) = cuda_key.key.key.signed_overflowing_add(
|
||||
&self.ciphertext.on_gpu(streams),
|
||||
&other.ciphertext.on_gpu(streams),
|
||||
@@ -63,7 +62,7 @@ where
|
||||
FheInt::new(result, cuda_key.tag.clone()),
|
||||
FheBool::new(overflow, cuda_key.tag.clone()),
|
||||
)
|
||||
}),
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
panic!("Hpu does not support this operation yet.")
|
||||
@@ -153,7 +152,8 @@ where
|
||||
)
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
let (result, overflow) = cuda_key.key.key.signed_overflowing_scalar_add(
|
||||
&self.ciphertext.on_gpu(streams),
|
||||
other,
|
||||
@@ -163,7 +163,7 @@ where
|
||||
FheInt::new(result, cuda_key.tag.clone()),
|
||||
FheBool::new(overflow, cuda_key.tag.clone()),
|
||||
)
|
||||
}),
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
panic!("Hpu does not support this operation yet.")
|
||||
@@ -291,7 +291,8 @@ where
|
||||
)
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
let (result, overflow) = cuda_key.key.key.signed_overflowing_sub(
|
||||
&self.ciphertext.on_gpu(streams),
|
||||
&other.ciphertext.on_gpu(streams),
|
||||
@@ -301,7 +302,7 @@ where
|
||||
FheInt::new(result, cuda_key.tag.clone()),
|
||||
FheBool::new(overflow, cuda_key.tag.clone()),
|
||||
)
|
||||
}),
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
panic!("Hpu does not support this operation yet.")
|
||||
@@ -390,7 +391,8 @@ where
|
||||
)
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
let (result, overflow) = cuda_key.key.key.signed_overflowing_scalar_sub(
|
||||
&self.ciphertext.on_gpu(streams),
|
||||
other,
|
||||
@@ -400,7 +402,7 @@ where
|
||||
FheInt::new(result, cuda_key.tag.clone()),
|
||||
FheBool::new(overflow, cuda_key.tag.clone()),
|
||||
)
|
||||
}),
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
panic!("Hpu does not support this operation yet.")
|
||||
|
||||
@@ -3,13 +3,15 @@ use crate::core_crypto::commons::numeric::CastFrom;
|
||||
use crate::high_level_api::errors::UnwrapResultExt;
|
||||
use crate::high_level_api::global_state;
|
||||
#[cfg(feature = "gpu")]
|
||||
use crate::high_level_api::global_state::with_thread_local_cuda_streams;
|
||||
use crate::high_level_api::global_state::with_cuda_internal_keys;
|
||||
use crate::high_level_api::integers::signed::inner::SignedRadixCiphertext;
|
||||
use crate::high_level_api::integers::FheIntId;
|
||||
use crate::high_level_api::keys::InternalServerKey;
|
||||
#[cfg(feature = "gpu")]
|
||||
use crate::high_level_api::traits::{
|
||||
AddSizeOnGpu, BitAndSizeOnGpu, BitOrSizeOnGpu, BitXorSizeOnGpu, SubSizeOnGpu,
|
||||
AddSizeOnGpu, BitAndSizeOnGpu, BitOrSizeOnGpu, BitXorSizeOnGpu, FheMaxSizeOnGpu,
|
||||
FheMinSizeOnGpu, FheOrdSizeOnGpu, RotateLeftSizeOnGpu, RotateRightSizeOnGpu, ShlSizeOnGpu,
|
||||
ShrSizeOnGpu, SubSizeOnGpu,
|
||||
};
|
||||
use crate::high_level_api::traits::{
|
||||
DivRem, FheEq, FheMax, FheMin, FheOrd, RotateLeft, RotateLeftAssign, RotateRight,
|
||||
@@ -59,14 +61,13 @@ where
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
crate::high_level_api::global_state::with_thread_local_cuda_streams(|streams| {
|
||||
let inner_result = cuda_key.key.key.scalar_max(
|
||||
&*self.ciphertext.on_gpu(streams),
|
||||
rhs,
|
||||
streams,
|
||||
);
|
||||
Self::new(inner_result, cuda_key.tag.clone())
|
||||
})
|
||||
let streams = &cuda_key.streams;
|
||||
let inner_result =
|
||||
cuda_key
|
||||
.key
|
||||
.key
|
||||
.scalar_max(&*self.ciphertext.on_gpu(streams), rhs, streams);
|
||||
Self::new(inner_result, cuda_key.tag.clone())
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
@@ -111,14 +112,13 @@ where
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
crate::high_level_api::global_state::with_thread_local_cuda_streams(|streams| {
|
||||
let inner_result = cuda_key.key.key.scalar_min(
|
||||
&*self.ciphertext.on_gpu(streams),
|
||||
rhs,
|
||||
streams,
|
||||
);
|
||||
Self::new(inner_result, cuda_key.tag.clone())
|
||||
})
|
||||
let streams = &cuda_key.streams;
|
||||
let inner_result =
|
||||
cuda_key
|
||||
.key
|
||||
.key
|
||||
.scalar_min(&*self.ciphertext.on_gpu(streams), rhs, streams);
|
||||
Self::new(inner_result, cuda_key.tag.clone())
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
@@ -162,14 +162,13 @@ where
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
crate::high_level_api::global_state::with_thread_local_cuda_streams(|streams| {
|
||||
let inner_result =
|
||||
cuda_key
|
||||
.key
|
||||
.key
|
||||
.scalar_eq(&*self.ciphertext.on_gpu(streams), rhs, streams);
|
||||
FheBool::new(inner_result, cuda_key.tag.clone())
|
||||
})
|
||||
let streams = &cuda_key.streams;
|
||||
let inner_result =
|
||||
cuda_key
|
||||
.key
|
||||
.key
|
||||
.scalar_eq(&*self.ciphertext.on_gpu(streams), rhs, streams);
|
||||
FheBool::new(inner_result, cuda_key.tag.clone())
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
@@ -207,14 +206,13 @@ where
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
crate::high_level_api::global_state::with_thread_local_cuda_streams(|streams| {
|
||||
let inner_result =
|
||||
cuda_key
|
||||
.key
|
||||
.key
|
||||
.scalar_ne(&*self.ciphertext.on_gpu(streams), rhs, streams);
|
||||
FheBool::new(inner_result, cuda_key.tag.clone())
|
||||
})
|
||||
let streams = &cuda_key.streams;
|
||||
let inner_result =
|
||||
cuda_key
|
||||
.key
|
||||
.key
|
||||
.scalar_ne(&*self.ciphertext.on_gpu(streams), rhs, streams);
|
||||
FheBool::new(inner_result, cuda_key.tag.clone())
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
@@ -257,14 +255,13 @@ where
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
crate::high_level_api::global_state::with_thread_local_cuda_streams(|streams| {
|
||||
let inner_result =
|
||||
cuda_key
|
||||
.key
|
||||
.key
|
||||
.scalar_lt(&*self.ciphertext.on_gpu(streams), rhs, streams);
|
||||
FheBool::new(inner_result, cuda_key.tag.clone())
|
||||
})
|
||||
let streams = &cuda_key.streams;
|
||||
let inner_result =
|
||||
cuda_key
|
||||
.key
|
||||
.key
|
||||
.scalar_lt(&*self.ciphertext.on_gpu(streams), rhs, streams);
|
||||
FheBool::new(inner_result, cuda_key.tag.clone())
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
@@ -301,14 +298,13 @@ where
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
crate::high_level_api::global_state::with_thread_local_cuda_streams(|streams| {
|
||||
let inner_result =
|
||||
cuda_key
|
||||
.key
|
||||
.key
|
||||
.scalar_le(&*self.ciphertext.on_gpu(streams), rhs, streams);
|
||||
FheBool::new(inner_result, cuda_key.tag.clone())
|
||||
})
|
||||
let streams = &cuda_key.streams;
|
||||
let inner_result =
|
||||
cuda_key
|
||||
.key
|
||||
.key
|
||||
.scalar_le(&*self.ciphertext.on_gpu(streams), rhs, streams);
|
||||
FheBool::new(inner_result, cuda_key.tag.clone())
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
@@ -345,14 +341,13 @@ where
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
crate::high_level_api::global_state::with_thread_local_cuda_streams(|streams| {
|
||||
let inner_result =
|
||||
cuda_key
|
||||
.key
|
||||
.key
|
||||
.scalar_gt(&*self.ciphertext.on_gpu(streams), rhs, streams);
|
||||
FheBool::new(inner_result, cuda_key.tag.clone())
|
||||
})
|
||||
let streams = &cuda_key.streams;
|
||||
let inner_result =
|
||||
cuda_key
|
||||
.key
|
||||
.key
|
||||
.scalar_gt(&*self.ciphertext.on_gpu(streams), rhs, streams);
|
||||
FheBool::new(inner_result, cuda_key.tag.clone())
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
@@ -389,14 +384,13 @@ where
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
crate::high_level_api::global_state::with_thread_local_cuda_streams(|streams| {
|
||||
let inner_result =
|
||||
cuda_key
|
||||
.key
|
||||
.key
|
||||
.scalar_ge(&*self.ciphertext.on_gpu(streams), rhs, streams);
|
||||
FheBool::new(inner_result, cuda_key.tag.clone())
|
||||
})
|
||||
let streams = &cuda_key.streams;
|
||||
let inner_result =
|
||||
cuda_key
|
||||
.key
|
||||
.key
|
||||
.scalar_ge(&*self.ciphertext.on_gpu(streams), rhs, streams);
|
||||
FheBool::new(inner_result, cuda_key.tag.clone())
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
@@ -406,6 +400,118 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
// GPU size queries for ordering comparisons between an encrypted signed integer and a
// clear scalar. The scalar argument (`_rhs`) is never read; only the ciphertext and the
// streams are passed to the CUDA key. Each method yields 0 when the active server key is
// not the CUDA variant.
// NOTE(review): all four methods (gt, ge, lt, le) call `get_scalar_le_size_on_gpu`.
// Confirm whether the size is the same for every comparison or whether this is a
// copy-paste slip and each method should call its own counterpart.
// NOTE(review): the unit of the returned `u64` is not visible here (presumably bytes) — confirm.
#[cfg(feature = "gpu")]
|
||||
impl<Id, Clear> FheOrdSizeOnGpu<Clear> for FheInt<Id>
|
||||
where
|
||||
Id: FheIntId,
|
||||
Clear: DecomposableInto<u64>,
|
||||
{
|
||||
/// Size query for `self > scalar`; returns 0 when the current server key is not a CUDA key.
fn get_gt_size_on_gpu(&self, _rhs: Clear) -> u64 {
|
||||
global_state::with_internal_keys(|key| {
|
||||
if let InternalServerKey::Cuda(cuda_key) = key {
|
||||
with_cuda_internal_keys(|keys| {
|
||||
let streams = &keys.streams;
|
||||
cuda_key
|
||||
.key
|
||||
.key
|
||||
// NOTE(review): `le` variant used for the `gt` query — verify.
.get_scalar_le_size_on_gpu(&*self.ciphertext.on_gpu(streams), streams)
|
||||
})
|
||||
} else {
|
||||
0
|
||||
}
|
||||
})
|
||||
}
|
||||
/// Size query for `self >= scalar`; returns 0 when the current server key is not a CUDA key.
fn get_ge_size_on_gpu(&self, _rhs: Clear) -> u64 {
|
||||
global_state::with_internal_keys(|key| {
|
||||
if let InternalServerKey::Cuda(cuda_key) = key {
|
||||
with_cuda_internal_keys(|keys| {
|
||||
let streams = &keys.streams;
|
||||
cuda_key
|
||||
.key
|
||||
.key
|
||||
// NOTE(review): `le` variant used for the `ge` query — verify.
.get_scalar_le_size_on_gpu(&*self.ciphertext.on_gpu(streams), streams)
|
||||
})
|
||||
} else {
|
||||
0
|
||||
}
|
||||
})
|
||||
}
|
||||
/// Size query for `self < scalar`; returns 0 when the current server key is not a CUDA key.
fn get_lt_size_on_gpu(&self, _rhs: Clear) -> u64 {
|
||||
global_state::with_internal_keys(|key| {
|
||||
if let InternalServerKey::Cuda(cuda_key) = key {
|
||||
with_cuda_internal_keys(|keys| {
|
||||
let streams = &keys.streams;
|
||||
cuda_key
|
||||
.key
|
||||
.key
|
||||
// NOTE(review): `le` variant used for the `lt` query — verify.
.get_scalar_le_size_on_gpu(&*self.ciphertext.on_gpu(streams), streams)
|
||||
})
|
||||
} else {
|
||||
0
|
||||
}
|
||||
})
|
||||
}
|
||||
/// Size query for `self <= scalar`; returns 0 when the current server key is not a CUDA key.
fn get_le_size_on_gpu(&self, _rhs: Clear) -> u64 {
|
||||
global_state::with_internal_keys(|key| {
|
||||
if let InternalServerKey::Cuda(cuda_key) = key {
|
||||
with_cuda_internal_keys(|keys| {
|
||||
let streams = &keys.streams;
|
||||
cuda_key
|
||||
.key
|
||||
.key
|
||||
.get_scalar_le_size_on_gpu(&*self.ciphertext.on_gpu(streams), streams)
|
||||
})
|
||||
} else {
|
||||
0
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "gpu")]
|
||||
impl<Id, Clear> FheMinSizeOnGpu<Clear> for FheInt<Id>
|
||||
where
|
||||
Id: FheIntId,
|
||||
Clear: DecomposableInto<u64>,
|
||||
{
|
||||
fn get_min_size_on_gpu(&self, _rhs: Clear) -> u64 {
|
||||
global_state::with_internal_keys(|key| {
|
||||
if let InternalServerKey::Cuda(cuda_key) = key {
|
||||
with_cuda_internal_keys(|keys| {
|
||||
let streams = &keys.streams;
|
||||
cuda_key
|
||||
.key
|
||||
.key
|
||||
.get_scalar_min_size_on_gpu(&*self.ciphertext.on_gpu(streams), streams)
|
||||
})
|
||||
} else {
|
||||
0
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
impl<Id, Clear> FheMaxSizeOnGpu<Clear> for FheInt<Id>
|
||||
where
|
||||
Id: FheIntId,
|
||||
Clear: DecomposableInto<u64>,
|
||||
{
|
||||
fn get_max_size_on_gpu(&self, _rhs: Clear) -> u64 {
|
||||
global_state::with_internal_keys(|key| {
|
||||
if let InternalServerKey::Cuda(cuda_key) = key {
|
||||
with_cuda_internal_keys(|keys| {
|
||||
let streams = &keys.streams;
|
||||
cuda_key
|
||||
.key
|
||||
.key
|
||||
.get_scalar_max_size_on_gpu(&*self.ciphertext.on_gpu(streams), streams)
|
||||
})
|
||||
} else {
|
||||
0
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
// DivRem is a bit special as it returns a tuple of quotient and remainder
|
||||
macro_rules! generic_integer_impl_scalar_div_rem {
|
||||
(
|
||||
@@ -444,11 +550,11 @@ macro_rules! generic_integer_impl_scalar_div_rem {
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let (inner_q, inner_r) = with_thread_local_cuda_streams(|streams| {
|
||||
let (inner_q, inner_r) = {let streams = &cuda_key.streams;
|
||||
cuda_key.key.key.signed_scalar_div_rem(
|
||||
&*self.ciphertext.on_gpu(streams), rhs, streams
|
||||
)
|
||||
});
|
||||
};
|
||||
let (q, r) = (
|
||||
SignedRadixCiphertext::Cuda(inner_q),
|
||||
SignedRadixCiphertext::Cuda(inner_r),
|
||||
@@ -501,11 +607,11 @@ macro_rules! define_scalar_rotate_shifts {
|
||||
},
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let inner_result = with_thread_local_cuda_streams(|streams| {
|
||||
let inner_result = {let streams = &cuda_key.streams;
|
||||
cuda_key.key.key.scalar_left_shift(
|
||||
&*lhs.ciphertext.on_gpu(streams), u64::cast_from(rhs), streams
|
||||
)
|
||||
});
|
||||
};
|
||||
SignedRadixCiphertext::Cuda(inner_result)
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
@@ -521,6 +627,31 @@ macro_rules! define_scalar_rotate_shifts {
|
||||
)*
|
||||
);
|
||||
|
||||
#[cfg(feature = "gpu")]
|
||||
generic_integer_impl_get_scalar_operation_size_on_gpu!(
|
||||
rust_trait: ShlSizeOnGpu(get_left_shift_size_on_gpu),
|
||||
implem: {
|
||||
|lhs: &FheInt<_>, _rhs| {
|
||||
global_state::with_internal_keys(|key|
|
||||
if let InternalServerKey::Cuda(cuda_key) = key {
|
||||
with_cuda_internal_keys(|keys| {
|
||||
let streams = &keys.streams;
|
||||
cuda_key.key.key.get_scalar_left_shift_size_on_gpu(
|
||||
&*lhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
)
|
||||
})
|
||||
} else {
|
||||
0
|
||||
})
|
||||
}
|
||||
},
|
||||
fhe_and_scalar_type:
|
||||
$(
|
||||
($concrete_type, $($scalar_type,)*),
|
||||
)*
|
||||
);
|
||||
|
||||
generic_integer_impl_scalar_operation!(
|
||||
rust_trait: Shr(shr),
|
||||
implem: {
|
||||
@@ -534,11 +665,11 @@ macro_rules! define_scalar_rotate_shifts {
|
||||
},
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let inner_result = with_thread_local_cuda_streams(|streams| {
|
||||
let inner_result = {let streams = &cuda_key.streams;
|
||||
cuda_key.key.key.scalar_right_shift(
|
||||
&*lhs.ciphertext.on_gpu(streams), u64::cast_from(rhs), streams
|
||||
)
|
||||
});
|
||||
};
|
||||
SignedRadixCiphertext::Cuda(inner_result)
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
@@ -554,6 +685,31 @@ macro_rules! define_scalar_rotate_shifts {
|
||||
)*
|
||||
);
|
||||
|
||||
#[cfg(feature = "gpu")]
|
||||
generic_integer_impl_get_scalar_operation_size_on_gpu!(
|
||||
rust_trait: ShrSizeOnGpu(get_right_shift_size_on_gpu),
|
||||
implem: {
|
||||
|lhs: &FheInt<_>, _rhs| {
|
||||
global_state::with_internal_keys(|key|
|
||||
if let InternalServerKey::Cuda(cuda_key) = key {
|
||||
with_cuda_internal_keys(|keys| {
|
||||
let streams = &keys.streams;
|
||||
cuda_key.key.key.get_scalar_right_shift_size_on_gpu(
|
||||
&*lhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
)
|
||||
})
|
||||
} else {
|
||||
0
|
||||
})
|
||||
}
|
||||
},
|
||||
fhe_and_scalar_type:
|
||||
$(
|
||||
($concrete_type, $($scalar_type,)*),
|
||||
)*
|
||||
);
|
||||
|
||||
generic_integer_impl_scalar_operation!(
|
||||
rust_trait: RotateLeft(rotate_left),
|
||||
implem: {
|
||||
@@ -567,11 +723,11 @@ macro_rules! define_scalar_rotate_shifts {
|
||||
},
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let inner_result = with_thread_local_cuda_streams(|streams| {
|
||||
let inner_result = {let streams = &cuda_key.streams;
|
||||
cuda_key.key.key.scalar_rotate_left(
|
||||
&*lhs.ciphertext.on_gpu(streams), u64::cast_from(rhs), streams
|
||||
)
|
||||
});
|
||||
};
|
||||
SignedRadixCiphertext::Cuda(inner_result)
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
@@ -587,6 +743,31 @@ macro_rules! define_scalar_rotate_shifts {
|
||||
)*
|
||||
);
|
||||
|
||||
#[cfg(feature = "gpu")]
|
||||
generic_integer_impl_get_scalar_operation_size_on_gpu!(
|
||||
rust_trait: RotateLeftSizeOnGpu(get_rotate_left_size_on_gpu),
|
||||
implem: {
|
||||
|lhs: &FheInt<_>, _rhs| {
|
||||
global_state::with_internal_keys(|key|
|
||||
if let InternalServerKey::Cuda(cuda_key) = key {
|
||||
with_cuda_internal_keys(|keys| {
|
||||
let streams = &keys.streams;
|
||||
cuda_key.key.key.get_scalar_rotate_left_size_on_gpu(
|
||||
&*lhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
)
|
||||
})
|
||||
} else {
|
||||
0
|
||||
})
|
||||
}
|
||||
},
|
||||
fhe_and_scalar_type:
|
||||
$(
|
||||
($concrete_type, $($scalar_type,)*),
|
||||
)*
|
||||
);
|
||||
|
||||
generic_integer_impl_scalar_operation!(
|
||||
rust_trait: RotateRight(rotate_right),
|
||||
implem: {
|
||||
@@ -600,11 +781,11 @@ macro_rules! define_scalar_rotate_shifts {
|
||||
},
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let inner_result = with_thread_local_cuda_streams(|streams| {
|
||||
let inner_result = {let streams = &cuda_key.streams;
|
||||
cuda_key.key.key.scalar_rotate_right(
|
||||
&*lhs.ciphertext.on_gpu(streams), u64::cast_from(rhs), streams
|
||||
)
|
||||
});
|
||||
};
|
||||
SignedRadixCiphertext::Cuda(inner_result)
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
@@ -620,6 +801,31 @@ macro_rules! define_scalar_rotate_shifts {
|
||||
)*
|
||||
);
|
||||
|
||||
#[cfg(feature = "gpu")]
|
||||
generic_integer_impl_get_scalar_operation_size_on_gpu!(
|
||||
rust_trait: RotateRightSizeOnGpu(get_rotate_right_size_on_gpu),
|
||||
implem: {
|
||||
|lhs: &FheInt<_>, _rhs| {
|
||||
global_state::with_internal_keys(|key|
|
||||
if let InternalServerKey::Cuda(cuda_key) = key {
|
||||
with_cuda_internal_keys(|keys| {
|
||||
let streams = &keys.streams;
|
||||
cuda_key.key.key.get_scalar_rotate_right_size_on_gpu(
|
||||
&*lhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
)
|
||||
})
|
||||
} else {
|
||||
0
|
||||
})
|
||||
}
|
||||
},
|
||||
fhe_and_scalar_type:
|
||||
$(
|
||||
($concrete_type, $($scalar_type,)*),
|
||||
)*
|
||||
);
|
||||
|
||||
generic_integer_impl_scalar_operation_assign!(
|
||||
rust_trait: ShlAssign(shl_assign),
|
||||
implem: {
|
||||
@@ -631,11 +837,10 @@ macro_rules! define_scalar_rotate_shifts {
|
||||
.scalar_left_shift_assign_parallelized(lhs.ciphertext.as_cpu_mut(), rhs);
|
||||
},
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
InternalServerKey::Cuda(cuda_key) =>
|
||||
{let streams = &cuda_key.streams;
|
||||
cuda_key.key.key
|
||||
.scalar_left_shift_assign(lhs.ciphertext.as_gpu_mut(streams), rhs, streams);
|
||||
})
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
@@ -662,10 +867,9 @@ macro_rules! define_scalar_rotate_shifts {
|
||||
},
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
let streams = &cuda_key.streams;
|
||||
cuda_key.key.key
|
||||
.scalar_right_shift_assign(lhs.ciphertext.as_gpu_mut(streams), rhs, streams);
|
||||
})
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
@@ -691,11 +895,10 @@ macro_rules! define_scalar_rotate_shifts {
|
||||
.scalar_rotate_left_assign_parallelized(lhs.ciphertext.as_cpu_mut(), rhs);
|
||||
},
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
InternalServerKey::Cuda(cuda_key) =>
|
||||
{let streams = &cuda_key.streams;
|
||||
cuda_key.key.key
|
||||
.scalar_rotate_left_assign(lhs.ciphertext.as_gpu_mut(streams), rhs, streams);
|
||||
})
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
@@ -722,10 +925,10 @@ macro_rules! define_scalar_rotate_shifts {
|
||||
},
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
{let streams = &cuda_key.streams;
|
||||
cuda_key.key.key
|
||||
.scalar_rotate_right_assign(lhs.ciphertext.as_gpu_mut(streams), rhs, streams);
|
||||
})
|
||||
}
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
@@ -839,11 +1042,11 @@ macro_rules! define_scalar_ops {
|
||||
},
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let inner_result = with_thread_local_cuda_streams(|streams| {
|
||||
let inner_result = {let streams = &cuda_key.streams;
|
||||
cuda_key.key.key.scalar_add(
|
||||
&*lhs.ciphertext.on_gpu(streams), rhs, streams
|
||||
)
|
||||
});
|
||||
};
|
||||
SignedRadixCiphertext::Cuda(inner_result)
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
@@ -866,12 +1069,11 @@ macro_rules! define_scalar_ops {
|
||||
|lhs: &FheInt<_>, _rhs| {
|
||||
global_state::with_internal_keys(|key|
|
||||
if let InternalServerKey::Cuda(cuda_key) = key {
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
let streams = &cuda_key.streams;
|
||||
cuda_key.key.key.get_scalar_add_size_on_gpu(
|
||||
&*lhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
)
|
||||
})
|
||||
} else {
|
||||
0
|
||||
})
|
||||
@@ -896,11 +1098,11 @@ macro_rules! define_scalar_ops {
|
||||
},
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let inner_result = with_thread_local_cuda_streams(|streams| {
|
||||
let inner_result = {let streams = &cuda_key.streams;
|
||||
cuda_key.key.key.scalar_sub(
|
||||
&*lhs.ciphertext.on_gpu(streams), rhs, streams
|
||||
)
|
||||
});
|
||||
};
|
||||
SignedRadixCiphertext::Cuda(inner_result)
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
@@ -923,12 +1125,11 @@ macro_rules! define_scalar_ops {
|
||||
|lhs: &FheInt<_>, _rhs| {
|
||||
global_state::with_internal_keys(|key|
|
||||
if let InternalServerKey::Cuda(cuda_key) = key {
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
let streams = &cuda_key.streams;
|
||||
cuda_key.key.key.get_scalar_sub_size_on_gpu(
|
||||
&*lhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
)
|
||||
})
|
||||
} else {
|
||||
0
|
||||
})
|
||||
@@ -953,11 +1154,11 @@ macro_rules! define_scalar_ops {
|
||||
},
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let inner_result = with_thread_local_cuda_streams(|streams| {
|
||||
let inner_result = {let streams = &cuda_key.streams;
|
||||
cuda_key.key.key.scalar_mul(
|
||||
&*lhs.ciphertext.on_gpu(streams), rhs, streams
|
||||
)
|
||||
});
|
||||
};
|
||||
SignedRadixCiphertext::Cuda(inner_result)
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
@@ -987,11 +1188,11 @@ macro_rules! define_scalar_ops {
|
||||
},
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let inner_result = with_thread_local_cuda_streams(|streams| {
|
||||
let inner_result = {let streams = &cuda_key.streams;
|
||||
cuda_key.key.key.scalar_bitand(
|
||||
&*lhs.ciphertext.on_gpu(streams), rhs, streams
|
||||
)
|
||||
});
|
||||
};
|
||||
SignedRadixCiphertext::Cuda(inner_result)
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
@@ -1014,12 +1215,11 @@ macro_rules! define_scalar_ops {
|
||||
|lhs: &FheInt<_>, _rhs| {
|
||||
global_state::with_internal_keys(|key|
|
||||
if let InternalServerKey::Cuda(cuda_key) = key {
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
let streams = &cuda_key.streams;
|
||||
cuda_key.key.key.get_scalar_bitand_size_on_gpu(
|
||||
&*lhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
)
|
||||
})
|
||||
} else {
|
||||
0
|
||||
})
|
||||
@@ -1044,11 +1244,11 @@ macro_rules! define_scalar_ops {
|
||||
},
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let inner_result = with_thread_local_cuda_streams(|streams| {
|
||||
let inner_result = {let streams = &cuda_key.streams;
|
||||
cuda_key.key.key.scalar_bitor(
|
||||
&*lhs.ciphertext.on_gpu(streams), rhs, streams
|
||||
)
|
||||
});
|
||||
};
|
||||
SignedRadixCiphertext::Cuda(inner_result)
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
@@ -1071,12 +1271,11 @@ macro_rules! define_scalar_ops {
|
||||
|lhs: &FheInt<_>, _rhs| {
|
||||
global_state::with_internal_keys(|key|
|
||||
if let InternalServerKey::Cuda(cuda_key) = key {
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
let streams = &cuda_key.streams;
|
||||
cuda_key.key.key.get_scalar_bitor_size_on_gpu(
|
||||
&*lhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
)
|
||||
})
|
||||
} else {
|
||||
0
|
||||
})
|
||||
@@ -1102,11 +1301,11 @@ macro_rules! define_scalar_ops {
|
||||
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let inner_result = with_thread_local_cuda_streams(|streams| {
|
||||
let inner_result = {let streams = &cuda_key.streams;
|
||||
cuda_key.key.key.scalar_bitxor(
|
||||
&*lhs.ciphertext.on_gpu(streams), rhs, streams
|
||||
)
|
||||
});
|
||||
};
|
||||
SignedRadixCiphertext::Cuda(inner_result)
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
@@ -1129,12 +1328,11 @@ macro_rules! define_scalar_ops {
|
||||
|lhs: &FheInt<_>, _rhs| {
|
||||
global_state::with_internal_keys(|key|
|
||||
if let InternalServerKey::Cuda(cuda_key) = key {
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
let streams = &cuda_key.streams;
|
||||
cuda_key.key.key.get_scalar_bitxor_size_on_gpu(
|
||||
&*lhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
)
|
||||
})
|
||||
} else {
|
||||
0
|
||||
})
|
||||
@@ -1159,11 +1357,11 @@ macro_rules! define_scalar_ops {
|
||||
},
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let inner_result = with_thread_local_cuda_streams(|streams| {
|
||||
let inner_result = {let streams = &cuda_key.streams;
|
||||
cuda_key.key.key.signed_scalar_div(
|
||||
&*lhs.ciphertext.on_gpu(streams), rhs, streams
|
||||
)
|
||||
});
|
||||
};
|
||||
SignedRadixCiphertext::Cuda(inner_result)
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
@@ -1193,11 +1391,11 @@ macro_rules! define_scalar_ops {
|
||||
},
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let inner_result = with_thread_local_cuda_streams(|streams| {
|
||||
let inner_result = {let streams = &cuda_key.streams;
|
||||
cuda_key.key.key.signed_scalar_rem(
|
||||
&*lhs.ciphertext.on_gpu(streams), rhs, streams
|
||||
)
|
||||
});
|
||||
};
|
||||
SignedRadixCiphertext::Cuda(inner_result)
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
@@ -1237,12 +1435,11 @@ macro_rules! define_scalar_ops {
|
||||
|_lhs, rhs: &FheInt<_>| {
|
||||
global_state::with_internal_keys(|key|
|
||||
if let InternalServerKey::Cuda(cuda_key) = key {
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
let streams = &cuda_key.streams;
|
||||
cuda_key.key.key.get_scalar_add_size_on_gpu(
|
||||
&*rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
)
|
||||
})
|
||||
} else {
|
||||
0
|
||||
})
|
||||
@@ -1267,12 +1464,11 @@ macro_rules! define_scalar_ops {
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
use crate::integer::gpu::ciphertext::CudaSignedRadixCiphertext;
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
let streams = &cuda_key.streams;
|
||||
let mut result: CudaSignedRadixCiphertext = cuda_key.pbs_key().create_trivial_radix(
|
||||
lhs, rhs.ciphertext.on_gpu(streams).ciphertext.info.blocks.len(), streams);
|
||||
cuda_key.pbs_key().sub_assign(&mut result, &*rhs.ciphertext.on_gpu(streams), streams);
|
||||
SignedRadixCiphertext::Cuda(result)
|
||||
})
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
@@ -1294,12 +1490,11 @@ macro_rules! define_scalar_ops {
|
||||
|_lhs, rhs: &FheInt<_>| {
|
||||
global_state::with_internal_keys(|key|
|
||||
if let InternalServerKey::Cuda(cuda_key) = key {
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
let streams = &cuda_key.streams;
|
||||
cuda_key.key.key.get_scalar_sub_size_on_gpu(
|
||||
&*rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
)
|
||||
})
|
||||
} else {
|
||||
0
|
||||
})
|
||||
@@ -1350,12 +1545,11 @@ macro_rules! define_scalar_ops {
|
||||
|_lhs, rhs: &FheInt<_>| {
|
||||
global_state::with_internal_keys(|key|
|
||||
if let InternalServerKey::Cuda(cuda_key) = key {
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
let streams = &cuda_key.streams;
|
||||
cuda_key.key.key.get_scalar_bitand_size_on_gpu(
|
||||
&*rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
)
|
||||
})
|
||||
} else {
|
||||
0
|
||||
})
|
||||
@@ -1389,12 +1583,11 @@ macro_rules! define_scalar_ops {
|
||||
|_lhs, rhs: &FheInt<_>| {
|
||||
global_state::with_internal_keys(|key|
|
||||
if let InternalServerKey::Cuda(cuda_key) = key {
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
let streams = &cuda_key.streams;
|
||||
cuda_key.key.key.get_scalar_bitor_size_on_gpu(
|
||||
&*rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
)
|
||||
})
|
||||
} else {
|
||||
0
|
||||
})
|
||||
@@ -1428,12 +1621,11 @@ macro_rules! define_scalar_ops {
|
||||
|_lhs, rhs: &FheInt<_>| {
|
||||
global_state::with_internal_keys(|key|
|
||||
if let InternalServerKey::Cuda(cuda_key) = key {
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
let streams = &cuda_key.streams;
|
||||
cuda_key.key.key.get_scalar_bitxor_size_on_gpu(
|
||||
&*rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
)
|
||||
})
|
||||
} else {
|
||||
0
|
||||
})
|
||||
@@ -1459,10 +1651,9 @@ macro_rules! define_scalar_ops {
|
||||
},
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
let streams = &cuda_key.streams;
|
||||
cuda_key.key.key
|
||||
.scalar_add_assign(lhs.ciphertext.as_gpu_mut(streams), rhs, streams);
|
||||
})
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
@@ -1494,10 +1685,9 @@ macro_rules! define_scalar_ops {
|
||||
},
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
let streams = &cuda_key.streams;
|
||||
cuda_key.key.key
|
||||
.scalar_sub_assign(lhs.ciphertext.as_gpu_mut(streams), rhs, streams);
|
||||
})
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
@@ -1525,10 +1715,9 @@ macro_rules! define_scalar_ops {
|
||||
},
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
let streams = &cuda_key.streams;
|
||||
cuda_key.key.key
|
||||
.scalar_mul_assign(lhs.ciphertext.as_gpu_mut(streams), rhs, streams);
|
||||
})
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
@@ -1557,10 +1746,9 @@ macro_rules! define_scalar_ops {
|
||||
},
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
let streams = &cuda_key.streams;
|
||||
cuda_key.key.key
|
||||
.scalar_bitand_assign(lhs.ciphertext.as_gpu_mut(streams), rhs, streams);
|
||||
})
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
@@ -1588,10 +1776,9 @@ macro_rules! define_scalar_ops {
|
||||
},
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
let streams = &cuda_key.streams;
|
||||
cuda_key.key.key
|
||||
.scalar_bitor_assign(lhs.ciphertext.as_gpu_mut(streams), rhs, streams);
|
||||
})
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
@@ -1618,10 +1805,9 @@ macro_rules! define_scalar_ops {
|
||||
},
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
let streams = &cuda_key.streams;
|
||||
cuda_key.key.key
|
||||
.scalar_bitxor_assign(lhs.ciphertext.as_gpu_mut(streams), rhs, streams);
|
||||
})
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
@@ -1647,11 +1833,12 @@ macro_rules! define_scalar_ops {
|
||||
.signed_scalar_div_assign_parallelized(lhs.ciphertext.as_cpu_mut(), rhs);
|
||||
},
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => global_state::with_thread_local_cuda_streams(|streams| {
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
let cuda_lhs = lhs.ciphertext.as_gpu_mut(streams);
|
||||
let cuda_result = cuda_key.pbs_key().signed_scalar_div(&cuda_lhs, rhs, streams);
|
||||
*cuda_lhs = cuda_result;
|
||||
}),
|
||||
},
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
panic!("Hpu does not support this operation yet.")
|
||||
@@ -1676,11 +1863,11 @@ macro_rules! define_scalar_ops {
|
||||
.signed_scalar_rem_assign_parallelized(lhs.ciphertext.as_cpu_mut(), rhs);
|
||||
},
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => global_state::with_thread_local_cuda_streams(|streams| {
|
||||
InternalServerKey::Cuda(cuda_key) => {let streams = &cuda_key.streams;
|
||||
let cuda_lhs = lhs.ciphertext.as_gpu_mut(streams);
|
||||
let cuda_result = cuda_key.pbs_key().signed_scalar_rem(&cuda_lhs, rhs, streams);
|
||||
*cuda_lhs = cuda_result;
|
||||
}),
|
||||
},
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
panic!("Hpu does not support this operation yet.")
|
||||
|
||||
@@ -2,13 +2,13 @@ use crate::high_level_api::integers::signed::tests::{
|
||||
test_case_ilog2, test_case_leading_trailing_zeros_ones,
|
||||
};
|
||||
use crate::high_level_api::integers::unsigned::tests::gpu::setup_gpu;
|
||||
use crate::high_level_api::traits::AddSizeOnGpu;
|
||||
use crate::prelude::{
|
||||
check_valid_cuda_malloc, BitAndSizeOnGpu, BitNotSizeOnGpu, BitOrSizeOnGpu, BitXorSizeOnGpu,
|
||||
FheTryEncrypt, SubSizeOnGpu,
|
||||
check_valid_cuda_malloc, AddSizeOnGpu, BitAndSizeOnGpu, BitNotSizeOnGpu, BitOrSizeOnGpu,
|
||||
BitXorSizeOnGpu, FheMaxSizeOnGpu, FheMinSizeOnGpu, FheOrdSizeOnGpu, FheTryEncrypt,
|
||||
RotateLeftSizeOnGpu, RotateRightSizeOnGpu, ShlSizeOnGpu, ShrSizeOnGpu, SubSizeOnGpu,
|
||||
};
|
||||
use crate::shortint::parameters::PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS;
|
||||
use crate::{FheInt32, GpuIndex};
|
||||
use crate::{FheInt32, FheUint32, GpuIndex};
|
||||
use rand::Rng;
|
||||
|
||||
#[test]
|
||||
@@ -162,3 +162,132 @@ fn test_gpu_get_bitops_size_on_gpu() {
|
||||
GpuIndex::new(0)
|
||||
));
|
||||
}
|
||||
#[test]
|
||||
fn test_gpu_get_comparisons_size_on_gpu() {
|
||||
let cks = setup_gpu(Some(PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS));
|
||||
let mut rng = rand::thread_rng();
|
||||
let clear_a = rng.gen_range(1..=i32::MAX);
|
||||
let clear_b = rng.gen_range(1..=i32::MAX);
|
||||
let mut a = FheInt32::try_encrypt(clear_a, &cks).unwrap();
|
||||
let mut b = FheInt32::try_encrypt(clear_b, &cks).unwrap();
|
||||
a.move_to_current_device();
|
||||
b.move_to_current_device();
|
||||
let a = &a;
|
||||
let b = &b;
|
||||
|
||||
let gt_tmp_buffer_size = a.get_gt_size_on_gpu(b);
|
||||
let scalar_gt_tmp_buffer_size = a.get_gt_size_on_gpu(clear_b);
|
||||
assert!(check_valid_cuda_malloc(
|
||||
gt_tmp_buffer_size,
|
||||
GpuIndex::new(0)
|
||||
));
|
||||
assert!(check_valid_cuda_malloc(
|
||||
scalar_gt_tmp_buffer_size,
|
||||
GpuIndex::new(0)
|
||||
));
|
||||
let ge_tmp_buffer_size = a.get_ge_size_on_gpu(b);
|
||||
let scalar_ge_tmp_buffer_size = a.get_ge_size_on_gpu(clear_b);
|
||||
assert!(check_valid_cuda_malloc(
|
||||
ge_tmp_buffer_size,
|
||||
GpuIndex::new(0)
|
||||
));
|
||||
assert!(check_valid_cuda_malloc(
|
||||
scalar_ge_tmp_buffer_size,
|
||||
GpuIndex::new(0)
|
||||
));
|
||||
let lt_tmp_buffer_size = a.get_lt_size_on_gpu(b);
|
||||
let scalar_lt_tmp_buffer_size = a.get_lt_size_on_gpu(clear_b);
|
||||
assert!(check_valid_cuda_malloc(
|
||||
lt_tmp_buffer_size,
|
||||
GpuIndex::new(0)
|
||||
));
|
||||
assert!(check_valid_cuda_malloc(
|
||||
scalar_lt_tmp_buffer_size,
|
||||
GpuIndex::new(0)
|
||||
));
|
||||
let le_tmp_buffer_size = a.get_le_size_on_gpu(b);
|
||||
let scalar_le_tmp_buffer_size = a.get_le_size_on_gpu(clear_b);
|
||||
assert!(check_valid_cuda_malloc(
|
||||
le_tmp_buffer_size,
|
||||
GpuIndex::new(0)
|
||||
));
|
||||
assert!(check_valid_cuda_malloc(
|
||||
scalar_le_tmp_buffer_size,
|
||||
GpuIndex::new(0)
|
||||
));
|
||||
let max_tmp_buffer_size = a.get_max_size_on_gpu(b);
|
||||
let scalar_max_tmp_buffer_size = a.get_max_size_on_gpu(clear_b);
|
||||
assert!(check_valid_cuda_malloc(
|
||||
max_tmp_buffer_size,
|
||||
GpuIndex::new(0)
|
||||
));
|
||||
assert!(check_valid_cuda_malloc(
|
||||
scalar_max_tmp_buffer_size,
|
||||
GpuIndex::new(0)
|
||||
));
|
||||
let min_tmp_buffer_size = a.get_min_size_on_gpu(b);
|
||||
let scalar_min_tmp_buffer_size = a.get_min_size_on_gpu(clear_b);
|
||||
assert!(check_valid_cuda_malloc(
|
||||
min_tmp_buffer_size,
|
||||
GpuIndex::new(0)
|
||||
));
|
||||
assert!(check_valid_cuda_malloc(
|
||||
scalar_min_tmp_buffer_size,
|
||||
GpuIndex::new(0)
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_gpu_get_shift_rotate_size_on_gpu() {
|
||||
let cks = setup_gpu(Some(PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS));
|
||||
let mut rng = rand::thread_rng();
|
||||
let clear_a = rng.gen_range(1..=i32::MAX);
|
||||
let clear_b = rng.gen_range(1..=u32::MAX);
|
||||
let mut a = FheInt32::try_encrypt(clear_a, &cks).unwrap();
|
||||
let mut b = FheUint32::try_encrypt(clear_b, &cks).unwrap();
|
||||
a.move_to_current_device();
|
||||
b.move_to_current_device();
|
||||
let a = &a;
|
||||
let b = &b;
|
||||
|
||||
let left_shift_tmp_buffer_size = a.get_left_shift_size_on_gpu(b);
|
||||
let scalar_left_shift_tmp_buffer_size = a.get_left_shift_size_on_gpu(clear_b);
|
||||
assert!(check_valid_cuda_malloc(
|
||||
left_shift_tmp_buffer_size,
|
||||
GpuIndex::new(0)
|
||||
));
|
||||
assert!(check_valid_cuda_malloc(
|
||||
scalar_left_shift_tmp_buffer_size,
|
||||
GpuIndex::new(0)
|
||||
));
|
||||
let right_shift_tmp_buffer_size = a.get_right_shift_size_on_gpu(b);
|
||||
let scalar_right_shift_tmp_buffer_size = a.get_right_shift_size_on_gpu(clear_b);
|
||||
assert!(check_valid_cuda_malloc(
|
||||
right_shift_tmp_buffer_size,
|
||||
GpuIndex::new(0)
|
||||
));
|
||||
assert!(check_valid_cuda_malloc(
|
||||
scalar_right_shift_tmp_buffer_size,
|
||||
GpuIndex::new(0)
|
||||
));
|
||||
let rotate_left_tmp_buffer_size = a.get_rotate_left_size_on_gpu(b);
|
||||
let scalar_rotate_left_tmp_buffer_size = a.get_rotate_left_size_on_gpu(clear_b);
|
||||
assert!(check_valid_cuda_malloc(
|
||||
rotate_left_tmp_buffer_size,
|
||||
GpuIndex::new(0)
|
||||
));
|
||||
assert!(check_valid_cuda_malloc(
|
||||
scalar_rotate_left_tmp_buffer_size,
|
||||
GpuIndex::new(0)
|
||||
));
|
||||
let rotate_right_tmp_buffer_size = a.get_rotate_right_size_on_gpu(b);
|
||||
let scalar_rotate_right_tmp_buffer_size = a.get_rotate_right_size_on_gpu(clear_b);
|
||||
assert!(check_valid_cuda_malloc(
|
||||
rotate_right_tmp_buffer_size,
|
||||
GpuIndex::new(0)
|
||||
));
|
||||
assert!(check_valid_cuda_malloc(
|
||||
scalar_rotate_right_tmp_buffer_size,
|
||||
GpuIndex::new(0)
|
||||
));
|
||||
}
|
||||
|
||||
@@ -4,8 +4,6 @@ use super::inner::RadixCiphertext;
|
||||
use crate::backward_compatibility::integers::FheUintVersions;
|
||||
use crate::conformance::ParameterSetConformant;
|
||||
use crate::core_crypto::prelude::{CastFrom, UnsignedInteger, UnsignedNumeric};
|
||||
#[cfg(feature = "gpu")]
|
||||
use crate::high_level_api::global_state::with_thread_local_cuda_streams;
|
||||
use crate::high_level_api::integers::signed::{FheInt, FheIntId};
|
||||
use crate::high_level_api::integers::IntegerId;
|
||||
use crate::high_level_api::keys::InternalServerKey;
|
||||
@@ -304,13 +302,14 @@ where
|
||||
FheBool::new(result, cpu_key.tag.clone())
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
let result = cuda_key
|
||||
.key
|
||||
.key
|
||||
.is_even(&*self.ciphertext.on_gpu(streams), streams);
|
||||
FheBool::new(result, cuda_key.tag.clone())
|
||||
}),
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
panic!("Hpu does not support this operation yet.")
|
||||
@@ -344,13 +343,14 @@ where
|
||||
FheBool::new(result, cpu_key.tag.clone())
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
let result = cuda_key
|
||||
.key
|
||||
.key
|
||||
.is_odd(&*self.ciphertext.on_gpu(streams), streams);
|
||||
FheBool::new(result, cuda_key.tag.clone())
|
||||
}),
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
panic!("Hpu does not support this operation yet.")
|
||||
@@ -481,7 +481,8 @@ where
|
||||
super::FheUint32::new(result, cpu_key.tag.clone())
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
let result = cuda_key
|
||||
.key
|
||||
.key
|
||||
@@ -492,7 +493,7 @@ where
|
||||
streams,
|
||||
);
|
||||
super::FheUint32::new(result, cuda_key.tag.clone())
|
||||
}),
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
panic!("Hpu does not support this operation yet.")
|
||||
@@ -530,7 +531,8 @@ where
|
||||
super::FheUint32::new(result, cpu_key.tag.clone())
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
let result = cuda_key
|
||||
.key
|
||||
.key
|
||||
@@ -541,7 +543,7 @@ where
|
||||
streams,
|
||||
);
|
||||
super::FheUint32::new(result, cuda_key.tag.clone())
|
||||
}),
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
panic!("Hpu does not support this operation yet.")
|
||||
@@ -579,7 +581,8 @@ where
|
||||
super::FheUint32::new(result, cpu_key.tag.clone())
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
let result = cuda_key
|
||||
.key
|
||||
.key
|
||||
@@ -590,7 +593,7 @@ where
|
||||
streams,
|
||||
);
|
||||
super::FheUint32::new(result, cuda_key.tag.clone())
|
||||
}),
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
panic!("Hpu does not support this operation yet.")
|
||||
@@ -628,7 +631,8 @@ where
|
||||
super::FheUint32::new(result, cpu_key.tag.clone())
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
let result = cuda_key
|
||||
.key
|
||||
.key
|
||||
@@ -639,7 +643,7 @@ where
|
||||
streams,
|
||||
);
|
||||
super::FheUint32::new(result, cuda_key.tag.clone())
|
||||
}),
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
panic!("Hpu does not support this operation yet.")
|
||||
@@ -761,7 +765,8 @@ where
|
||||
super::FheUint32::new(result, cpu_key.tag.clone())
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
let result = cuda_key
|
||||
.key
|
||||
.key
|
||||
@@ -772,7 +777,7 @@ where
|
||||
streams,
|
||||
);
|
||||
super::FheUint32::new(result, cuda_key.tag.clone())
|
||||
}),
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
panic!("Hpu does not support this operation yet.")
|
||||
@@ -819,7 +824,8 @@ where
|
||||
)
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
let (result, is_ok) = cuda_key
|
||||
.key
|
||||
.key
|
||||
@@ -833,7 +839,7 @@ where
|
||||
super::FheUint32::new(result, cuda_key.tag.clone()),
|
||||
FheBool::new(is_ok, cuda_key.tag.clone()),
|
||||
)
|
||||
}),
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
panic!("Hpu does not support this operation yet.")
|
||||
@@ -905,7 +911,8 @@ where
|
||||
}
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
let (result, matched) = cuda_key.key.key.match_value(
|
||||
&self.ciphertext.on_gpu(streams),
|
||||
matches,
|
||||
@@ -920,7 +927,7 @@ where
|
||||
} else {
|
||||
Err(crate::Error::new("Output type does not have enough bits to represent all possible output values".to_string()))
|
||||
}
|
||||
}),
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
panic!("Hpu does not support this operation yet.")
|
||||
@@ -987,7 +994,8 @@ where
|
||||
}
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
let result = cuda_key.key.key.match_value_or(
|
||||
&self.ciphertext.on_gpu(streams),
|
||||
matches,
|
||||
@@ -1000,7 +1008,7 @@ where
|
||||
} else {
|
||||
Err(crate::Error::new("Output type does not have enough bits to represent all possible output values".to_string()))
|
||||
}
|
||||
}),
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
panic!("Hpu does not support this operation yet.")
|
||||
@@ -1216,14 +1224,15 @@ where
|
||||
Self::new(casted, cpu_key.tag.clone())
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
let casted = cuda_key.key.key.cast_to_unsigned(
|
||||
input.ciphertext.into_gpu(streams),
|
||||
IntoId::num_blocks(cuda_key.message_modulus()),
|
||||
streams,
|
||||
);
|
||||
Self::new(casted, cuda_key.tag.clone())
|
||||
}),
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
panic!("Hpu does not support this operation yet.")
|
||||
@@ -1264,14 +1273,15 @@ where
|
||||
Self::new(casted, cpu_key.tag.clone())
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
let casted = cuda_key.key.key.cast_to_unsigned(
|
||||
input.ciphertext.into_gpu(streams),
|
||||
IntoId::num_blocks(cuda_key.message_modulus()),
|
||||
streams,
|
||||
);
|
||||
Self::new(casted, cuda_key.tag.clone())
|
||||
}),
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
panic!("Hpu does not support this operation yet.")
|
||||
@@ -1312,14 +1322,15 @@ where
|
||||
Self::new(ciphertext, cpu_key.tag.clone())
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
let inner = cuda_key.key.key.cast_to_unsigned(
|
||||
input.ciphertext.into_gpu(streams).0,
|
||||
Id::num_blocks(cuda_key.message_modulus()),
|
||||
streams,
|
||||
);
|
||||
Self::new(inner, cuda_key.tag.clone())
|
||||
}),
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
panic!("Hpu does not support this operation yet.")
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
use crate::core_crypto::prelude::UnsignedNumeric;
|
||||
use crate::high_level_api::global_state;
|
||||
#[cfg(feature = "gpu")]
|
||||
use crate::high_level_api::global_state::with_thread_local_cuda_streams;
|
||||
use crate::high_level_api::integers::FheUintId;
|
||||
use crate::high_level_api::keys::InternalServerKey;
|
||||
use crate::integer::block_decomposition::{DecomposableInto, RecomposableFrom};
|
||||
@@ -115,14 +113,15 @@ where
|
||||
Ok(Self::new(ciphertext, key.tag.clone()))
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
let inner: CudaUnsignedRadixCiphertext = cuda_key.key.key.create_trivial_radix(
|
||||
value,
|
||||
Id::num_blocks(cuda_key.key.key.message_modulus),
|
||||
streams,
|
||||
);
|
||||
Ok(Self::new(inner, cuda_key.tag.clone()))
|
||||
}),
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
panic!("Hpu does not support trivial encryption")
|
||||
|
||||
@@ -4,9 +4,9 @@ use crate::core_crypto::gpu::CudaStreams;
|
||||
use crate::high_level_api::details::MaybeCloned;
|
||||
use crate::high_level_api::global_state;
|
||||
#[cfg(feature = "gpu")]
|
||||
use crate::high_level_api::global_state::{
|
||||
with_thread_local_cuda_streams, with_thread_local_cuda_streams_for_gpu_indexes,
|
||||
};
|
||||
use crate::high_level_api::global_state::with_cuda_internal_keys;
|
||||
#[cfg(feature = "gpu")]
|
||||
use crate::high_level_api::global_state::with_thread_local_cuda_streams_for_gpu_indexes;
|
||||
#[cfg(feature = "hpu")]
|
||||
use crate::high_level_api::keys::HpuTaggedDevice;
|
||||
#[cfg(feature = "gpu")]
|
||||
@@ -52,9 +52,10 @@ impl Clone for RadixCiphertext {
|
||||
match self {
|
||||
Self::Cpu(inner) => Self::Cpu(inner.clone()),
|
||||
#[cfg(feature = "gpu")]
|
||||
Self::Cuda(inner) => {
|
||||
with_thread_local_cuda_streams(|streams| Self::Cuda(inner.duplicate(streams)))
|
||||
}
|
||||
Self::Cuda(inner) => with_cuda_internal_keys(|key| {
|
||||
let streams = &key.streams;
|
||||
Self::Cuda(inner.duplicate(streams))
|
||||
}),
|
||||
#[cfg(feature = "hpu")]
|
||||
Self::Hpu(inner) => {
|
||||
// NB: The Hpu backend flavor behaves differently regarding memory.
|
||||
@@ -166,10 +167,13 @@ impl RadixCiphertext {
|
||||
match self {
|
||||
Self::Cpu(ct) => MaybeCloned::Borrowed(ct),
|
||||
#[cfg(feature = "gpu")]
|
||||
Self::Cuda(ct) => with_thread_local_cuda_streams(|streams| {
|
||||
let cpu_ct = ct.to_radix_ciphertext(streams);
|
||||
MaybeCloned::Cloned(cpu_ct)
|
||||
}),
|
||||
Self::Cuda(ct) => with_thread_local_cuda_streams_for_gpu_indexes(
|
||||
ct.ciphertext.d_blocks.0.d_vec.gpu_indexes.as_slice(),
|
||||
|streams| {
|
||||
let cpu_ct = ct.to_radix_ciphertext(streams);
|
||||
MaybeCloned::Cloned(cpu_ct)
|
||||
},
|
||||
),
|
||||
#[cfg(feature = "hpu")]
|
||||
Self::Hpu(hpu_ct) => {
|
||||
let cpu_inner = hpu_ct.to_radix_ciphertext();
|
||||
@@ -297,7 +301,8 @@ impl RadixCiphertext {
|
||||
#[cfg(feature = "gpu")]
|
||||
// We may not be on the correct Cuda device
|
||||
if let Self::Cuda(cuda_ct) = self {
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
with_cuda_internal_keys(|key| {
|
||||
let streams = &key.streams;
|
||||
if cuda_ct.gpu_indexes() != streams.gpu_indexes() {
|
||||
*cuda_ct = cuda_ct.duplicate(streams);
|
||||
}
|
||||
@@ -319,7 +324,8 @@ impl RadixCiphertext {
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
Device::CudaGpu => {
|
||||
let new_inner = with_thread_local_cuda_streams(|streams| {
|
||||
let new_inner = with_cuda_internal_keys(|key| {
|
||||
let streams = &key.streams;
|
||||
crate::integer::gpu::ciphertext::CudaUnsignedRadixCiphertext::from_radix_ciphertext(
|
||||
&cpu_ct, streams,
|
||||
)
|
||||
|
||||
@@ -7,13 +7,14 @@ use super::inner::RadixCiphertext;
|
||||
use crate::high_level_api::details::MaybeCloned;
|
||||
use crate::high_level_api::global_state;
|
||||
#[cfg(feature = "gpu")]
|
||||
use crate::high_level_api::global_state::with_thread_local_cuda_streams;
|
||||
use crate::high_level_api::global_state::with_cuda_internal_keys;
|
||||
use crate::high_level_api::integers::FheUintId;
|
||||
use crate::high_level_api::keys::InternalServerKey;
|
||||
#[cfg(feature = "gpu")]
|
||||
use crate::high_level_api::traits::{
|
||||
AddSizeOnGpu, BitAndSizeOnGpu, BitNotSizeOnGpu, BitOrSizeOnGpu, BitXorSizeOnGpu, SizeOnGpu,
|
||||
SubSizeOnGpu,
|
||||
AddSizeOnGpu, BitAndSizeOnGpu, BitNotSizeOnGpu, BitOrSizeOnGpu, BitXorSizeOnGpu,
|
||||
FheMaxSizeOnGpu, FheMinSizeOnGpu, FheOrdSizeOnGpu, RotateLeftSizeOnGpu, RotateRightSizeOnGpu,
|
||||
ShlSizeOnGpu, ShrSizeOnGpu, SizeOnGpu, SubSizeOnGpu,
|
||||
};
|
||||
use crate::high_level_api::traits::{
|
||||
DivRem, FheEq, FheMax, FheMin, FheOrd, RotateLeft, RotateLeftAssign, RotateRight,
|
||||
@@ -83,7 +84,8 @@ where
|
||||
)
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
let cts = iter
|
||||
.map(|fhe_uint| fhe_uint.ciphertext.into_gpu(streams))
|
||||
.collect::<Vec<_>>();
|
||||
@@ -100,7 +102,7 @@ where
|
||||
)
|
||||
});
|
||||
Self::new(inner, cuda_key.tag.clone())
|
||||
}),
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(device) => {
|
||||
let mut iter = iter;
|
||||
@@ -171,7 +173,8 @@ where
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
{
|
||||
let streams = &cuda_key.streams;
|
||||
let cts = iter
|
||||
.map(|fhe_uint| {
|
||||
match fhe_uint.ciphertext.on_gpu(streams) {
|
||||
@@ -201,7 +204,7 @@ where
|
||||
)
|
||||
});
|
||||
Self::new(inner, cuda_key.tag.clone())
|
||||
})
|
||||
}
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(device) => {
|
||||
@@ -256,14 +259,15 @@ where
|
||||
Self::new(inner_result, cpu_key.tag.clone())
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
let inner_result = cuda_key.key.key.max(
|
||||
&*self.ciphertext.on_gpu(streams),
|
||||
&*rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
);
|
||||
Self::new(inner_result, cuda_key.tag.clone())
|
||||
}),
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
panic!("Hpu does not support this operation yet.")
|
||||
@@ -306,14 +310,15 @@ where
|
||||
Self::new(inner_result, cpu_key.tag.clone())
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
let inner_result = cuda_key.key.key.min(
|
||||
&*self.ciphertext.on_gpu(streams),
|
||||
&*rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
);
|
||||
Self::new(inner_result, cuda_key.tag.clone())
|
||||
}),
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
panic!("Hpu does not support this operation yet.")
|
||||
@@ -367,14 +372,15 @@ where
|
||||
FheBool::new(inner_result, cpu_key.tag.clone())
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
let inner_result = cuda_key.key.key.eq(
|
||||
&*self.ciphertext.on_gpu(streams),
|
||||
&rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
);
|
||||
FheBool::new(inner_result, cuda_key.tag.clone())
|
||||
}),
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(device) => {
|
||||
let hpu_lhs = self.ciphertext.on_hpu(device);
|
||||
@@ -428,14 +434,15 @@ where
|
||||
FheBool::new(inner_result, cpu_key.tag.clone())
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
let inner_result = cuda_key.key.key.ne(
|
||||
&*self.ciphertext.on_gpu(streams),
|
||||
&rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
);
|
||||
FheBool::new(inner_result, cuda_key.tag.clone())
|
||||
}),
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(device) => {
|
||||
let hpu_lhs = self.ciphertext.on_hpu(device);
|
||||
@@ -515,14 +522,15 @@ where
|
||||
FheBool::new(inner_result, cpu_key.tag.clone())
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
let inner_result = cuda_key.key.key.lt(
|
||||
&*self.ciphertext.on_gpu(streams),
|
||||
&rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
);
|
||||
FheBool::new(inner_result, cuda_key.tag.clone())
|
||||
}),
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(device) => {
|
||||
let hpu_lhs = self.ciphertext.on_hpu(device);
|
||||
@@ -576,14 +584,15 @@ where
|
||||
FheBool::new(inner_result, cpu_key.tag.clone())
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
let inner_result = cuda_key.key.key.le(
|
||||
&*self.ciphertext.on_gpu(streams),
|
||||
&rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
);
|
||||
FheBool::new(inner_result, cuda_key.tag.clone())
|
||||
}),
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(device) => {
|
||||
let hpu_lhs = self.ciphertext.on_hpu(device);
|
||||
@@ -637,14 +646,15 @@ where
|
||||
FheBool::new(inner_result, cpu_key.tag.clone())
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
let inner_result = cuda_key.key.key.gt(
|
||||
&*self.ciphertext.on_gpu(streams),
|
||||
&rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
);
|
||||
FheBool::new(inner_result, cuda_key.tag.clone())
|
||||
}),
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(device) => {
|
||||
let hpu_lhs = self.ciphertext.on_hpu(device);
|
||||
@@ -698,14 +708,15 @@ where
|
||||
FheBool::new(inner_result, cpu_key.tag.clone())
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
let inner_result = cuda_key.key.key.ge(
|
||||
&*self.ciphertext.on_gpu(streams),
|
||||
&rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
);
|
||||
FheBool::new(inner_result, cuda_key.tag.clone())
|
||||
}),
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(device) => {
|
||||
let hpu_lhs = self.ciphertext.on_hpu(device);
|
||||
@@ -803,7 +814,8 @@ where
|
||||
)
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
let inner_result = cuda_key.key.key.div_rem(
|
||||
&*self.ciphertext.on_gpu(streams),
|
||||
&*rhs.ciphertext.on_gpu(streams),
|
||||
@@ -813,7 +825,7 @@ where
|
||||
FheUint::<Id>::new(inner_result.0, cuda_key.tag.clone()),
|
||||
FheUint::<Id>::new(inner_result.1, cuda_key.tag.clone()),
|
||||
)
|
||||
}),
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
panic!("Hpu does not support this operation yet.")
|
||||
@@ -893,11 +905,10 @@ generic_integer_impl_operation!(
|
||||
},
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
let streams = &cuda_key.streams;
|
||||
let inner_result = cuda_key.key.key
|
||||
.add(&*lhs.ciphertext.on_gpu(streams), &*rhs.ciphertext.on_gpu(streams), streams);
|
||||
FheUint::new(inner_result, cuda_key.tag.clone())
|
||||
})
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(device) => {
|
||||
@@ -942,11 +953,10 @@ generic_integer_impl_operation!(
|
||||
},
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
let streams = &cuda_key.streams;
|
||||
let inner_result = cuda_key.key.key
|
||||
.sub(&*lhs.ciphertext.on_gpu(streams), &*rhs.ciphertext.on_gpu(streams), streams);
|
||||
FheUint::new(inner_result, cuda_key.tag.clone())
|
||||
})
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(device) => {
|
||||
@@ -991,11 +1001,10 @@ generic_integer_impl_operation!(
|
||||
},
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
let streams = &cuda_key.streams;
|
||||
let inner_result = cuda_key.key.key
|
||||
.mul(&*lhs.ciphertext.on_gpu(streams), &*rhs.ciphertext.on_gpu(streams), streams);
|
||||
FheUint::new(inner_result, cuda_key.tag.clone())
|
||||
})
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(device) => {
|
||||
@@ -1038,11 +1047,10 @@ generic_integer_impl_operation!(
|
||||
},
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
let streams = &cuda_key.streams;
|
||||
let inner_result = cuda_key.key.key
|
||||
.bitand(&*lhs.ciphertext.on_gpu(streams), &*rhs.ciphertext.on_gpu(streams), streams);
|
||||
FheUint::new(inner_result, cuda_key.tag.clone())
|
||||
})
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(device) => {
|
||||
@@ -1085,11 +1093,10 @@ generic_integer_impl_operation!(
|
||||
},
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
let streams = &cuda_key.streams;
|
||||
let inner_result = cuda_key.key.key
|
||||
.bitor(&*lhs.ciphertext.on_gpu(streams), &*rhs.ciphertext.on_gpu(streams), streams);
|
||||
FheUint::new(inner_result, cuda_key.tag.clone())
|
||||
})
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(device) => {
|
||||
@@ -1132,11 +1139,10 @@ generic_integer_impl_operation!(
|
||||
},
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
let streams = &cuda_key.streams;
|
||||
let inner_result = cuda_key.key.key
|
||||
.bitxor(&*lhs.ciphertext.on_gpu(streams), &*rhs.ciphertext.on_gpu(streams), streams);
|
||||
FheUint::new(inner_result, cuda_key.tag.clone())
|
||||
})
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(device) => {
|
||||
@@ -1186,14 +1192,15 @@ generic_integer_impl_operation!(
|
||||
FheUint::new(inner_result, cpu_key.tag.clone())
|
||||
},
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
let inner_result =
|
||||
cuda_key
|
||||
.key
|
||||
.key
|
||||
.div(&*lhs.ciphertext.on_gpu(streams), &*rhs.ciphertext.on_gpu(streams), streams);
|
||||
FheUint::new(inner_result, cuda_key.tag.clone())
|
||||
}),
|
||||
},
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
panic!("Hpu does not support this operation yet.")
|
||||
@@ -1241,14 +1248,16 @@ generic_integer_impl_operation!(
|
||||
FheUint::new(inner_result, cpu_key.tag.clone())
|
||||
},
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
InternalServerKey::Cuda(cuda_key) =>
|
||||
{
|
||||
let streams = &cuda_key.streams;
|
||||
let inner_result =
|
||||
cuda_key
|
||||
.key
|
||||
.key
|
||||
.rem(&*lhs.ciphertext.on_gpu(streams), &*rhs.ciphertext.on_gpu(streams), streams);
|
||||
FheUint::new(inner_result, cuda_key.tag.clone())
|
||||
}),
|
||||
},
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
panic!("Hpu does not support this operation yet.")
|
||||
@@ -1360,11 +1369,10 @@ generic_integer_impl_shift_rotate!(
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
let streams = &cuda_key.streams;
|
||||
let inner_result = cuda_key.key.key
|
||||
.left_shift(&*lhs.ciphertext.on_gpu(streams), &rhs.ciphertext.on_gpu(streams), streams);
|
||||
FheUint::new(inner_result, cuda_key.tag.clone())
|
||||
})
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
@@ -1408,11 +1416,10 @@ generic_integer_impl_shift_rotate!(
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
let streams = &cuda_key.streams;
|
||||
let inner_result = cuda_key.key.key
|
||||
.right_shift(&*lhs.ciphertext.on_gpu(streams), &rhs.ciphertext.on_gpu(streams), streams);
|
||||
FheUint::new(inner_result, cuda_key.tag.clone())
|
||||
})
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
@@ -1456,11 +1463,10 @@ generic_integer_impl_shift_rotate!(
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
let streams = &cuda_key.streams;
|
||||
let inner_result = cuda_key.key.key
|
||||
.rotate_left(&*lhs.ciphertext.on_gpu(streams), &rhs.ciphertext.on_gpu(streams), streams);
|
||||
FheUint::new(inner_result, cuda_key.tag.clone())
|
||||
})
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
@@ -1504,11 +1510,10 @@ generic_integer_impl_shift_rotate!(
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
let streams = &cuda_key.streams;
|
||||
let inner_result = cuda_key.key.key
|
||||
.rotate_right(&*lhs.ciphertext.on_gpu(streams), &rhs.ciphertext.on_gpu(streams), streams);
|
||||
FheUint::new(inner_result, cuda_key.tag.clone())
|
||||
})
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
@@ -1557,13 +1562,14 @@ where
|
||||
);
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
cuda_key.key.key.add_assign(
|
||||
self.ciphertext.as_gpu_mut(streams),
|
||||
&rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
);
|
||||
}),
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(device) => {
|
||||
let hpu_lhs = self.ciphertext.as_hpu_mut(device);
|
||||
@@ -1608,13 +1614,14 @@ where
|
||||
);
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
cuda_key.key.key.sub_assign(
|
||||
self.ciphertext.as_gpu_mut(streams),
|
||||
&rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
);
|
||||
}),
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(device) => {
|
||||
let hpu_lhs = self.ciphertext.as_hpu_mut(device);
|
||||
@@ -1659,13 +1666,14 @@ where
|
||||
);
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
cuda_key.key.key.mul_assign(
|
||||
self.ciphertext.as_gpu_mut(streams),
|
||||
&rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
);
|
||||
}),
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(device) => {
|
||||
let hpu_lhs = self.ciphertext.as_hpu_mut(device);
|
||||
@@ -1708,13 +1716,14 @@ where
|
||||
);
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
cuda_key.key.key.bitand_assign(
|
||||
self.ciphertext.as_gpu_mut(streams),
|
||||
&rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
);
|
||||
}),
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(device) => {
|
||||
let hpu_lhs = self.ciphertext.as_hpu_mut(device);
|
||||
@@ -1757,13 +1766,14 @@ where
|
||||
);
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
cuda_key.key.key.bitor_assign(
|
||||
self.ciphertext.as_gpu_mut(streams),
|
||||
&rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
);
|
||||
}),
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(device) => {
|
||||
let hpu_lhs = self.ciphertext.as_hpu_mut(device);
|
||||
@@ -1806,13 +1816,14 @@ where
|
||||
);
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
cuda_key.key.key.bitxor_assign(
|
||||
self.ciphertext.as_gpu_mut(streams),
|
||||
&rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
);
|
||||
}),
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(device) => {
|
||||
let hpu_lhs = self.ciphertext.as_hpu_mut(device);
|
||||
@@ -1860,13 +1871,14 @@ where
|
||||
);
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
cuda_key.key.key.div_assign(
|
||||
self.ciphertext.as_gpu_mut(streams),
|
||||
&rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
);
|
||||
}),
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
panic!("Hpu does not support this operation yet.")
|
||||
@@ -1912,13 +1924,14 @@ where
|
||||
);
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
cuda_key.key.key.rem_assign(
|
||||
self.ciphertext.as_gpu_mut(streams),
|
||||
&rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
);
|
||||
}),
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
panic!("Hpu does not support this operation yet.")
|
||||
@@ -1970,13 +1983,14 @@ where
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
{
|
||||
let streams = &cuda_key.streams;
|
||||
cuda_key.key.key.left_shift_assign(
|
||||
self.ciphertext.as_gpu_mut(streams),
|
||||
&rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
);
|
||||
});
|
||||
};
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
@@ -2028,13 +2042,14 @@ where
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
{
|
||||
let streams = &cuda_key.streams;
|
||||
cuda_key.key.key.right_shift_assign(
|
||||
self.ciphertext.as_gpu_mut(streams),
|
||||
&rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
);
|
||||
});
|
||||
};
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
@@ -2087,13 +2102,14 @@ where
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
{
|
||||
let streams = &cuda_key.streams;
|
||||
cuda_key.key.key.rotate_left_assign(
|
||||
self.ciphertext.as_gpu_mut(streams),
|
||||
&rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
);
|
||||
});
|
||||
};
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
@@ -2146,13 +2162,12 @@ where
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
cuda_key.key.key.rotate_right_assign(
|
||||
self.ciphertext.as_gpu_mut(streams),
|
||||
&rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
);
|
||||
});
|
||||
let streams = &cuda_key.streams;
|
||||
cuda_key.key.key.rotate_right_assign(
|
||||
self.ciphertext.as_gpu_mut(streams),
|
||||
&rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
);
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
@@ -2230,13 +2245,14 @@ where
|
||||
FheUint::new(ciphertext, cpu_key.tag.clone())
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
let inner_result = cuda_key
|
||||
.key
|
||||
.key
|
||||
.neg(&*self.ciphertext.on_gpu(streams), streams);
|
||||
FheUint::new(inner_result, cuda_key.tag.clone())
|
||||
}),
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
panic!("Hpu does not support this operation yet.")
|
||||
@@ -2303,13 +2319,14 @@ where
|
||||
FheUint::new(ciphertext, cpu_key.tag.clone())
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
let inner_result = cuda_key
|
||||
.key
|
||||
.key
|
||||
.bitnot(&*self.ciphertext.on_gpu(streams), streams);
|
||||
FheUint::new(inner_result, cuda_key.tag.clone())
|
||||
}),
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
panic!("Hpu does not support bitnot (operator `!`)")
|
||||
@@ -2327,13 +2344,12 @@ where
|
||||
let rhs = rhs.borrow();
|
||||
global_state::with_internal_keys(|key| {
|
||||
if let InternalServerKey::Cuda(cuda_key) = key {
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
cuda_key.key.key.get_add_size_on_gpu(
|
||||
&*self.ciphertext.on_gpu(streams),
|
||||
&rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
)
|
||||
})
|
||||
let streams = &cuda_key.streams;
|
||||
cuda_key.key.key.get_add_size_on_gpu(
|
||||
&*self.ciphertext.on_gpu(streams),
|
||||
&rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
)
|
||||
} else {
|
||||
0
|
||||
}
|
||||
@@ -2351,13 +2367,12 @@ where
|
||||
let rhs = rhs.borrow();
|
||||
global_state::with_internal_keys(|key| {
|
||||
if let InternalServerKey::Cuda(cuda_key) = key {
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
cuda_key.key.key.get_sub_size_on_gpu(
|
||||
&*self.ciphertext.on_gpu(streams),
|
||||
&rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
)
|
||||
})
|
||||
let streams = &cuda_key.streams;
|
||||
cuda_key.key.key.get_sub_size_on_gpu(
|
||||
&*self.ciphertext.on_gpu(streams),
|
||||
&rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
)
|
||||
} else {
|
||||
0
|
||||
}
|
||||
@@ -2372,12 +2387,11 @@ where
|
||||
fn get_size_on_gpu(&self) -> u64 {
|
||||
global_state::with_internal_keys(|key| {
|
||||
if let InternalServerKey::Cuda(cuda_key) = key {
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
cuda_key
|
||||
.key
|
||||
.key
|
||||
.get_ciphertext_size_on_gpu(&*self.ciphertext.on_gpu(streams))
|
||||
})
|
||||
let streams = &cuda_key.streams;
|
||||
cuda_key
|
||||
.key
|
||||
.key
|
||||
.get_ciphertext_size_on_gpu(&*self.ciphertext.on_gpu(streams))
|
||||
} else {
|
||||
0
|
||||
}
|
||||
@@ -2395,13 +2409,12 @@ where
|
||||
|
||||
global_state::with_internal_keys(|key| {
|
||||
if let InternalServerKey::Cuda(cuda_key) = key {
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
cuda_key.key.key.get_bitand_size_on_gpu(
|
||||
&*self.ciphertext.on_gpu(streams),
|
||||
&rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
)
|
||||
})
|
||||
let streams = &cuda_key.streams;
|
||||
cuda_key.key.key.get_bitand_size_on_gpu(
|
||||
&*self.ciphertext.on_gpu(streams),
|
||||
&rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
)
|
||||
} else {
|
||||
0
|
||||
}
|
||||
@@ -2419,13 +2432,12 @@ where
|
||||
let rhs = rhs.borrow();
|
||||
global_state::with_internal_keys(|key| {
|
||||
if let InternalServerKey::Cuda(cuda_key) = key {
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
cuda_key.key.key.get_bitor_size_on_gpu(
|
||||
&*self.ciphertext.on_gpu(streams),
|
||||
&rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
)
|
||||
})
|
||||
let streams = &cuda_key.streams;
|
||||
cuda_key.key.key.get_bitor_size_on_gpu(
|
||||
&*self.ciphertext.on_gpu(streams),
|
||||
&rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
)
|
||||
} else {
|
||||
0
|
||||
}
|
||||
@@ -2443,13 +2455,12 @@ where
|
||||
let rhs = rhs.borrow();
|
||||
global_state::with_internal_keys(|key| {
|
||||
if let InternalServerKey::Cuda(cuda_key) = key {
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
cuda_key.key.key.get_bitxor_size_on_gpu(
|
||||
&*self.ciphertext.on_gpu(streams),
|
||||
&rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
)
|
||||
})
|
||||
let streams = &cuda_key.streams;
|
||||
cuda_key.key.key.get_bitxor_size_on_gpu(
|
||||
&*self.ciphertext.on_gpu(streams),
|
||||
&rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
)
|
||||
} else {
|
||||
0
|
||||
}
|
||||
@@ -2465,11 +2476,214 @@ where
|
||||
fn get_bitnot_size_on_gpu(&self) -> u64 {
|
||||
global_state::with_internal_keys(|key| {
|
||||
if let InternalServerKey::Cuda(cuda_key) = key {
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
cuda_key
|
||||
.key
|
||||
.key
|
||||
.get_bitnot_size_on_gpu(&*self.ciphertext.on_gpu(streams), streams)
|
||||
let streams = &cuda_key.streams;
|
||||
cuda_key
|
||||
.key
|
||||
.key
|
||||
.get_bitnot_size_on_gpu(&*self.ciphertext.on_gpu(streams), streams)
|
||||
} else {
|
||||
0
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "gpu")]
|
||||
impl<Id> FheOrdSizeOnGpu<&Self> for FheUint<Id>
|
||||
where
|
||||
Id: FheUintId,
|
||||
{
|
||||
fn get_gt_size_on_gpu(&self, rhs: &Self) -> u64 {
|
||||
global_state::with_internal_keys(|key| {
|
||||
if let InternalServerKey::Cuda(cuda_key) = key {
|
||||
with_cuda_internal_keys(|keys| {
|
||||
let streams = &keys.streams;
|
||||
cuda_key.key.key.get_gt_size_on_gpu(
|
||||
&*self.ciphertext.on_gpu(streams),
|
||||
&rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
)
|
||||
})
|
||||
} else {
|
||||
0
|
||||
}
|
||||
})
|
||||
}
|
||||
fn get_ge_size_on_gpu(&self, rhs: &Self) -> u64 {
|
||||
global_state::with_internal_keys(|key| {
|
||||
if let InternalServerKey::Cuda(cuda_key) = key {
|
||||
with_cuda_internal_keys(|keys| {
|
||||
let streams = &keys.streams;
|
||||
cuda_key.key.key.get_ge_size_on_gpu(
|
||||
&*self.ciphertext.on_gpu(streams),
|
||||
&rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
)
|
||||
})
|
||||
} else {
|
||||
0
|
||||
}
|
||||
})
|
||||
}
|
||||
fn get_lt_size_on_gpu(&self, rhs: &Self) -> u64 {
|
||||
global_state::with_internal_keys(|key| {
|
||||
if let InternalServerKey::Cuda(cuda_key) = key {
|
||||
with_cuda_internal_keys(|keys| {
|
||||
let streams = &keys.streams;
|
||||
cuda_key.key.key.get_lt_size_on_gpu(
|
||||
&*self.ciphertext.on_gpu(streams),
|
||||
&rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
)
|
||||
})
|
||||
} else {
|
||||
0
|
||||
}
|
||||
})
|
||||
}
|
||||
fn get_le_size_on_gpu(&self, rhs: &Self) -> u64 {
|
||||
global_state::with_internal_keys(|key| {
|
||||
if let InternalServerKey::Cuda(cuda_key) = key {
|
||||
with_cuda_internal_keys(|keys| {
|
||||
let streams = &keys.streams;
|
||||
cuda_key.key.key.get_le_size_on_gpu(
|
||||
&*self.ciphertext.on_gpu(streams),
|
||||
&rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
)
|
||||
})
|
||||
} else {
|
||||
0
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
impl<Id> FheMinSizeOnGpu<&Self> for FheUint<Id>
|
||||
where
|
||||
Id: FheUintId,
|
||||
{
|
||||
fn get_min_size_on_gpu(&self, rhs: &Self) -> u64 {
|
||||
global_state::with_internal_keys(|key| {
|
||||
if let InternalServerKey::Cuda(cuda_key) = key {
|
||||
with_cuda_internal_keys(|keys| {
|
||||
let streams = &keys.streams;
|
||||
cuda_key.key.key.get_min_size_on_gpu(
|
||||
&*self.ciphertext.on_gpu(streams),
|
||||
&rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
)
|
||||
})
|
||||
} else {
|
||||
0
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "gpu")]
|
||||
impl<Id> FheMaxSizeOnGpu<&Self> for FheUint<Id>
|
||||
where
|
||||
Id: FheUintId,
|
||||
{
|
||||
fn get_max_size_on_gpu(&self, rhs: &Self) -> u64 {
|
||||
global_state::with_internal_keys(|key| {
|
||||
if let InternalServerKey::Cuda(cuda_key) = key {
|
||||
with_cuda_internal_keys(|keys| {
|
||||
let streams = &keys.streams;
|
||||
cuda_key.key.key.get_max_size_on_gpu(
|
||||
&*self.ciphertext.on_gpu(streams),
|
||||
&rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
)
|
||||
})
|
||||
} else {
|
||||
0
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
impl<Id> ShlSizeOnGpu<&Self> for FheUint<Id>
|
||||
where
|
||||
Id: FheUintId,
|
||||
{
|
||||
fn get_left_shift_size_on_gpu(&self, rhs: &Self) -> u64 {
|
||||
global_state::with_internal_keys(|key| {
|
||||
if let InternalServerKey::Cuda(cuda_key) = key {
|
||||
with_cuda_internal_keys(|keys| {
|
||||
let streams = &keys.streams;
|
||||
cuda_key.key.key.get_left_shift_size_on_gpu(
|
||||
&*self.ciphertext.on_gpu(streams),
|
||||
&rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
)
|
||||
})
|
||||
} else {
|
||||
0
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
impl<Id> ShrSizeOnGpu<&Self> for FheUint<Id>
|
||||
where
|
||||
Id: FheUintId,
|
||||
{
|
||||
fn get_right_shift_size_on_gpu(&self, rhs: &Self) -> u64 {
|
||||
global_state::with_internal_keys(|key| {
|
||||
if let InternalServerKey::Cuda(cuda_key) = key {
|
||||
with_cuda_internal_keys(|keys| {
|
||||
let streams = &keys.streams;
|
||||
cuda_key.key.key.get_right_shift_size_on_gpu(
|
||||
&*self.ciphertext.on_gpu(streams),
|
||||
&rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
)
|
||||
})
|
||||
} else {
|
||||
0
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
impl<Id> RotateLeftSizeOnGpu<&Self> for FheUint<Id>
|
||||
where
|
||||
Id: FheUintId,
|
||||
{
|
||||
fn get_rotate_left_size_on_gpu(&self, rhs: &Self) -> u64 {
|
||||
global_state::with_internal_keys(|key| {
|
||||
if let InternalServerKey::Cuda(cuda_key) = key {
|
||||
with_cuda_internal_keys(|keys| {
|
||||
let streams = &keys.streams;
|
||||
cuda_key.key.key.get_rotate_left_size_on_gpu(
|
||||
&*self.ciphertext.on_gpu(streams),
|
||||
&rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
)
|
||||
})
|
||||
} else {
|
||||
0
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
impl<Id> RotateRightSizeOnGpu<&Self> for FheUint<Id>
|
||||
where
|
||||
Id: FheUintId,
|
||||
{
|
||||
fn get_rotate_right_size_on_gpu(&self, rhs: &Self) -> u64 {
|
||||
global_state::with_internal_keys(|key| {
|
||||
if let InternalServerKey::Cuda(cuda_key) = key {
|
||||
with_cuda_internal_keys(|keys| {
|
||||
let streams = &keys.streams;
|
||||
cuda_key.key.key.get_rotate_right_size_on_gpu(
|
||||
&*self.ciphertext.on_gpu(streams),
|
||||
&rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
)
|
||||
})
|
||||
} else {
|
||||
0
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
use crate::core_crypto::prelude::UnsignedNumeric;
|
||||
use crate::high_level_api::global_state;
|
||||
#[cfg(feature = "gpu")]
|
||||
use crate::high_level_api::global_state::with_thread_local_cuda_streams;
|
||||
use crate::high_level_api::integers::FheUintId;
|
||||
use crate::high_level_api::keys::InternalServerKey;
|
||||
use crate::integer::block_decomposition::DecomposableInto;
|
||||
@@ -53,7 +51,8 @@ where
|
||||
)
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
let inner_result = cuda_key.key.key.unsigned_overflowing_add(
|
||||
&self.ciphertext.on_gpu(streams),
|
||||
&other.ciphertext.on_gpu(streams),
|
||||
@@ -63,7 +62,7 @@ where
|
||||
FheUint::<Id>::new(inner_result.0, cuda_key.tag.clone()),
|
||||
FheBool::new(inner_result.1, cuda_key.tag.clone()),
|
||||
)
|
||||
}),
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
panic!("Hpu does not support this operation yet.")
|
||||
@@ -153,7 +152,8 @@ where
|
||||
)
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
let inner_result = cuda_key.key.key.unsigned_overflowing_scalar_add(
|
||||
&self.ciphertext.on_gpu(streams),
|
||||
other,
|
||||
@@ -163,7 +163,7 @@ where
|
||||
FheUint::<Id>::new(inner_result.0, cuda_key.tag.clone()),
|
||||
FheBool::new(inner_result.1, cuda_key.tag.clone()),
|
||||
)
|
||||
}),
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
panic!("Hpu does not support this operation yet.")
|
||||
@@ -293,7 +293,8 @@ where
|
||||
)
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
let inner_result = cuda_key.key.key.unsigned_overflowing_sub(
|
||||
&self.ciphertext.on_gpu(streams),
|
||||
&other.ciphertext.on_gpu(streams),
|
||||
@@ -303,7 +304,7 @@ where
|
||||
FheUint::<Id>::new(inner_result.0, cuda_key.tag.clone()),
|
||||
FheBool::new(inner_result.1, cuda_key.tag.clone()),
|
||||
)
|
||||
}),
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
panic!("Hpu does not support this operation yet.")
|
||||
|
||||
@@ -8,12 +8,14 @@ use crate::error::InvalidRangeError;
|
||||
use crate::high_level_api::errors::UnwrapResultExt;
|
||||
use crate::high_level_api::global_state;
|
||||
#[cfg(feature = "gpu")]
|
||||
use crate::high_level_api::global_state::with_thread_local_cuda_streams;
|
||||
use crate::high_level_api::global_state::with_cuda_internal_keys;
|
||||
use crate::high_level_api::integers::FheUintId;
|
||||
use crate::high_level_api::keys::InternalServerKey;
|
||||
#[cfg(feature = "gpu")]
|
||||
use crate::high_level_api::traits::{
|
||||
AddSizeOnGpu, BitAndSizeOnGpu, BitOrSizeOnGpu, BitXorSizeOnGpu, SubSizeOnGpu,
|
||||
AddSizeOnGpu, BitAndSizeOnGpu, BitOrSizeOnGpu, BitXorSizeOnGpu, FheMaxSizeOnGpu,
|
||||
FheMinSizeOnGpu, FheOrdSizeOnGpu, RotateLeftSizeOnGpu, RotateRightSizeOnGpu, ShlSizeOnGpu,
|
||||
ShrSizeOnGpu, SubSizeOnGpu,
|
||||
};
|
||||
use crate::high_level_api::traits::{
|
||||
BitSlice, DivRem, FheEq, FheMax, FheMin, FheOrd, RotateLeft, RotateLeftAssign, RotateRight,
|
||||
@@ -64,14 +66,15 @@ where
|
||||
FheBool::new(inner_result, cpu_key.tag.clone())
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
let inner_result =
|
||||
cuda_key
|
||||
.key
|
||||
.key
|
||||
.scalar_eq(&*self.ciphertext.on_gpu(streams), rhs, streams);
|
||||
FheBool::new(inner_result, cuda_key.tag.clone())
|
||||
}),
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
panic!("Hpu does not support this operation yet.")
|
||||
@@ -107,14 +110,15 @@ where
|
||||
FheBool::new(inner_result, cpu_key.tag.clone())
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
let inner_result =
|
||||
cuda_key
|
||||
.key
|
||||
.key
|
||||
.scalar_ne(&*self.ciphertext.on_gpu(streams), rhs, streams);
|
||||
FheBool::new(inner_result, cuda_key.tag.clone())
|
||||
}),
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
panic!("Hpu does not support this operation yet.")
|
||||
@@ -156,14 +160,15 @@ where
|
||||
FheBool::new(inner_result, cpu_key.tag.clone())
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
let inner_result =
|
||||
cuda_key
|
||||
.key
|
||||
.key
|
||||
.scalar_lt(&*self.ciphertext.on_gpu(streams), rhs, streams);
|
||||
FheBool::new(inner_result, cuda_key.tag.clone())
|
||||
}),
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
panic!("Hpu does not support this operation yet.")
|
||||
@@ -199,14 +204,15 @@ where
|
||||
FheBool::new(inner_result, cpu_key.tag.clone())
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
let inner_result =
|
||||
cuda_key
|
||||
.key
|
||||
.key
|
||||
.scalar_le(&*self.ciphertext.on_gpu(streams), rhs, streams);
|
||||
FheBool::new(inner_result, cuda_key.tag.clone())
|
||||
}),
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
panic!("Hpu does not support this operation yet.")
|
||||
@@ -242,14 +248,15 @@ where
|
||||
FheBool::new(inner_result, cpu_key.tag.clone())
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
let inner_result =
|
||||
cuda_key
|
||||
.key
|
||||
.key
|
||||
.scalar_gt(&*self.ciphertext.on_gpu(streams), rhs, streams);
|
||||
FheBool::new(inner_result, cuda_key.tag.clone())
|
||||
}),
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
panic!("Hpu does not support this operation yet.")
|
||||
@@ -285,14 +292,15 @@ where
|
||||
FheBool::new(inner_result, cpu_key.tag.clone())
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
let inner_result =
|
||||
cuda_key
|
||||
.key
|
||||
.key
|
||||
.scalar_ge(&*self.ciphertext.on_gpu(streams), rhs, streams);
|
||||
FheBool::new(inner_result, cuda_key.tag.clone())
|
||||
}),
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
panic!("Hpu does not support this operation yet.")
|
||||
@@ -301,6 +309,119 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "gpu")]
|
||||
impl<Id, Clear> FheOrdSizeOnGpu<Clear> for FheUint<Id>
|
||||
where
|
||||
Id: FheUintId,
|
||||
Clear: DecomposableInto<u64>,
|
||||
{
|
||||
fn get_gt_size_on_gpu(&self, _rhs: Clear) -> u64 {
|
||||
global_state::with_internal_keys(|key| {
|
||||
if let InternalServerKey::Cuda(cuda_key) = key {
|
||||
with_cuda_internal_keys(|keys| {
|
||||
let streams = &keys.streams;
|
||||
cuda_key
|
||||
.key
|
||||
.key
|
||||
.get_scalar_gt_size_on_gpu(&*self.ciphertext.on_gpu(streams), streams)
|
||||
})
|
||||
} else {
|
||||
0
|
||||
}
|
||||
})
|
||||
}
|
||||
fn get_ge_size_on_gpu(&self, _rhs: Clear) -> u64 {
|
||||
global_state::with_internal_keys(|key| {
|
||||
if let InternalServerKey::Cuda(cuda_key) = key {
|
||||
with_cuda_internal_keys(|keys| {
|
||||
let streams = &keys.streams;
|
||||
cuda_key
|
||||
.key
|
||||
.key
|
||||
.get_scalar_ge_size_on_gpu(&*self.ciphertext.on_gpu(streams), streams)
|
||||
})
|
||||
} else {
|
||||
0
|
||||
}
|
||||
})
|
||||
}
|
||||
fn get_lt_size_on_gpu(&self, _rhs: Clear) -> u64 {
|
||||
global_state::with_internal_keys(|key| {
|
||||
if let InternalServerKey::Cuda(cuda_key) = key {
|
||||
with_cuda_internal_keys(|keys| {
|
||||
let streams = &keys.streams;
|
||||
cuda_key
|
||||
.key
|
||||
.key
|
||||
.get_scalar_lt_size_on_gpu(&*self.ciphertext.on_gpu(streams), streams)
|
||||
})
|
||||
} else {
|
||||
0
|
||||
}
|
||||
})
|
||||
}
|
||||
fn get_le_size_on_gpu(&self, _rhs: Clear) -> u64 {
|
||||
global_state::with_internal_keys(|key| {
|
||||
if let InternalServerKey::Cuda(cuda_key) = key {
|
||||
with_cuda_internal_keys(|keys| {
|
||||
let streams = &keys.streams;
|
||||
cuda_key
|
||||
.key
|
||||
.key
|
||||
.get_scalar_le_size_on_gpu(&*self.ciphertext.on_gpu(streams), streams)
|
||||
})
|
||||
} else {
|
||||
0
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "gpu")]
|
||||
impl<Id, Clear> FheMinSizeOnGpu<Clear> for FheUint<Id>
|
||||
where
|
||||
Id: FheUintId,
|
||||
Clear: DecomposableInto<u64>,
|
||||
{
|
||||
fn get_min_size_on_gpu(&self, _rhs: Clear) -> u64 {
|
||||
global_state::with_internal_keys(|key| {
|
||||
if let InternalServerKey::Cuda(cuda_key) = key {
|
||||
with_cuda_internal_keys(|keys| {
|
||||
let streams = &keys.streams;
|
||||
cuda_key
|
||||
.key
|
||||
.key
|
||||
.get_scalar_min_size_on_gpu(&*self.ciphertext.on_gpu(streams), streams)
|
||||
})
|
||||
} else {
|
||||
0
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
impl<Id, Clear> FheMaxSizeOnGpu<Clear> for FheUint<Id>
|
||||
where
|
||||
Id: FheUintId,
|
||||
Clear: DecomposableInto<u64>,
|
||||
{
|
||||
fn get_max_size_on_gpu(&self, _rhs: Clear) -> u64 {
|
||||
global_state::with_internal_keys(|key| {
|
||||
if let InternalServerKey::Cuda(cuda_key) = key {
|
||||
with_cuda_internal_keys(|keys| {
|
||||
let streams = &keys.streams;
|
||||
cuda_key
|
||||
.key
|
||||
.key
|
||||
.get_scalar_max_size_on_gpu(&*self.ciphertext.on_gpu(streams), streams)
|
||||
})
|
||||
} else {
|
||||
0
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl<Id, Clear> FheMax<Clear> for FheUint<Id>
|
||||
where
|
||||
Clear: DecomposableInto<u64>,
|
||||
@@ -336,14 +457,15 @@ where
|
||||
Self::new(inner_result, cpu_key.tag.clone())
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
let inner_result =
|
||||
cuda_key
|
||||
.key
|
||||
.key
|
||||
.scalar_max(&*self.ciphertext.on_gpu(streams), rhs, streams);
|
||||
Self::new(inner_result, cuda_key.tag.clone())
|
||||
}),
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
panic!("Hpu does not support this operation yet.")
|
||||
@@ -387,14 +509,15 @@ where
|
||||
Self::new(inner_result, cpu_key.tag.clone())
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
let inner_result =
|
||||
cuda_key
|
||||
.key
|
||||
.key
|
||||
.scalar_min(&*self.ciphertext.on_gpu(streams), rhs, streams);
|
||||
Self::new(inner_result, cuda_key.tag.clone())
|
||||
}),
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
panic!("Hpu does not support this operation yet.")
|
||||
@@ -533,11 +656,12 @@ macro_rules! generic_integer_impl_scalar_div_rem {
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let (inner_q, inner_r) = with_thread_local_cuda_streams(|streams| {
|
||||
let (inner_q, inner_r) = {
|
||||
let streams = &cuda_key.streams;
|
||||
cuda_key.key.key.scalar_div_rem(
|
||||
&*self.ciphertext.on_gpu(streams), rhs, streams
|
||||
)
|
||||
});
|
||||
};
|
||||
let (q, r) = (RadixCiphertext::Cuda(inner_q), RadixCiphertext::Cuda(inner_r));
|
||||
(
|
||||
<$concrete_type>::new(q, cuda_key.tag.clone()),
|
||||
@@ -785,11 +909,12 @@ macro_rules! define_scalar_rotate_shifts {
|
||||
},
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let inner_result = with_thread_local_cuda_streams(|streams| {
|
||||
let inner_result = {
|
||||
let streams = &cuda_key.streams;
|
||||
cuda_key.key.key.scalar_left_shift(
|
||||
&*lhs.ciphertext.on_gpu(streams), u64::cast_from(rhs), streams
|
||||
)
|
||||
});
|
||||
};
|
||||
RadixCiphertext::Cuda(inner_result)
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
@@ -805,6 +930,31 @@ macro_rules! define_scalar_rotate_shifts {
|
||||
)*
|
||||
);
|
||||
|
||||
#[cfg(feature = "gpu")]
|
||||
generic_integer_impl_get_scalar_operation_size_on_gpu!(
|
||||
rust_trait: ShlSizeOnGpu(get_left_shift_size_on_gpu),
|
||||
implem: {
|
||||
|lhs: &FheUint<_>, _rhs| {
|
||||
global_state::with_internal_keys(|key|
|
||||
if let InternalServerKey::Cuda(cuda_key) = key {
|
||||
with_cuda_internal_keys(|keys| {
|
||||
let streams = &keys.streams;
|
||||
cuda_key.key.key.get_scalar_left_shift_size_on_gpu(
|
||||
&*lhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
)
|
||||
})
|
||||
} else {
|
||||
0
|
||||
})
|
||||
}
|
||||
},
|
||||
fhe_and_scalar_type:
|
||||
$(
|
||||
($concrete_type, $($scalar_type,)*),
|
||||
)*
|
||||
);
|
||||
|
||||
generic_integer_impl_scalar_operation!(
|
||||
rust_trait: Shr(shr),
|
||||
implem: {
|
||||
@@ -818,11 +968,12 @@ macro_rules! define_scalar_rotate_shifts {
|
||||
},
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let inner_result = with_thread_local_cuda_streams(|streams| {
|
||||
let inner_result = {
|
||||
let streams = &cuda_key.streams;
|
||||
cuda_key.key.key.scalar_right_shift(
|
||||
&*lhs.ciphertext.on_gpu(streams), u64::cast_from(rhs), streams
|
||||
)
|
||||
});
|
||||
};
|
||||
RadixCiphertext::Cuda(inner_result)
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
@@ -838,6 +989,31 @@ macro_rules! define_scalar_rotate_shifts {
|
||||
)*
|
||||
);
|
||||
|
||||
#[cfg(feature = "gpu")]
|
||||
generic_integer_impl_get_scalar_operation_size_on_gpu!(
|
||||
rust_trait: ShrSizeOnGpu(get_right_shift_size_on_gpu),
|
||||
implem: {
|
||||
|lhs: &FheUint<_>, _rhs| {
|
||||
global_state::with_internal_keys(|key|
|
||||
if let InternalServerKey::Cuda(cuda_key) = key {
|
||||
with_cuda_internal_keys(|keys| {
|
||||
let streams = &keys.streams;
|
||||
cuda_key.key.key.get_scalar_right_shift_size_on_gpu(
|
||||
&*lhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
)
|
||||
})
|
||||
} else {
|
||||
0
|
||||
})
|
||||
}
|
||||
},
|
||||
fhe_and_scalar_type:
|
||||
$(
|
||||
($concrete_type, $($scalar_type,)*),
|
||||
)*
|
||||
);
|
||||
|
||||
generic_integer_impl_scalar_operation!(
|
||||
rust_trait: RotateLeft(rotate_left),
|
||||
implem: {
|
||||
@@ -851,11 +1027,12 @@ macro_rules! define_scalar_rotate_shifts {
|
||||
},
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let inner_result = with_thread_local_cuda_streams(|streams| {
|
||||
let inner_result = {
|
||||
let streams = &cuda_key.streams;
|
||||
cuda_key.key.key.scalar_rotate_left(
|
||||
&*lhs.ciphertext.on_gpu(streams), u64::cast_from(rhs), streams
|
||||
)
|
||||
});
|
||||
};
|
||||
RadixCiphertext::Cuda(inner_result)
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
@@ -871,6 +1048,31 @@ macro_rules! define_scalar_rotate_shifts {
|
||||
)*
|
||||
);
|
||||
|
||||
#[cfg(feature = "gpu")]
|
||||
generic_integer_impl_get_scalar_operation_size_on_gpu!(
|
||||
rust_trait: RotateLeftSizeOnGpu(get_rotate_left_size_on_gpu),
|
||||
implem: {
|
||||
|lhs: &FheUint<_>, _rhs| {
|
||||
global_state::with_internal_keys(|key|
|
||||
if let InternalServerKey::Cuda(cuda_key) = key {
|
||||
with_cuda_internal_keys(|keys| {
|
||||
let streams = &keys.streams;
|
||||
cuda_key.key.key.get_scalar_rotate_left_size_on_gpu(
|
||||
&*lhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
)
|
||||
})
|
||||
} else {
|
||||
0
|
||||
})
|
||||
}
|
||||
},
|
||||
fhe_and_scalar_type:
|
||||
$(
|
||||
($concrete_type, $($scalar_type,)*),
|
||||
)*
|
||||
);
|
||||
|
||||
generic_integer_impl_scalar_operation!(
|
||||
rust_trait: RotateRight(rotate_right),
|
||||
implem: {
|
||||
@@ -884,11 +1086,12 @@ macro_rules! define_scalar_rotate_shifts {
|
||||
},
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let inner_result = with_thread_local_cuda_streams(|streams| {
|
||||
let inner_result = {
|
||||
let streams = &cuda_key.streams;
|
||||
cuda_key.key.key.scalar_rotate_right(
|
||||
&*lhs.ciphertext.on_gpu(streams), u64::cast_from(rhs), streams
|
||||
)
|
||||
});
|
||||
};
|
||||
RadixCiphertext::Cuda(inner_result)
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
@@ -904,6 +1107,31 @@ macro_rules! define_scalar_rotate_shifts {
|
||||
)*
|
||||
);
|
||||
|
||||
#[cfg(feature = "gpu")]
|
||||
generic_integer_impl_get_scalar_operation_size_on_gpu!(
|
||||
rust_trait: RotateRightSizeOnGpu(get_rotate_right_size_on_gpu),
|
||||
implem: {
|
||||
|lhs: &FheUint<_>, _rhs| {
|
||||
global_state::with_internal_keys(|key|
|
||||
if let InternalServerKey::Cuda(cuda_key) = key {
|
||||
with_cuda_internal_keys(|keys| {
|
||||
let streams = &keys.streams;
|
||||
cuda_key.key.key.get_scalar_rotate_right_size_on_gpu(
|
||||
&*lhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
)
|
||||
})
|
||||
} else {
|
||||
0
|
||||
})
|
||||
}
|
||||
},
|
||||
fhe_and_scalar_type:
|
||||
$(
|
||||
($concrete_type, $($scalar_type,)*),
|
||||
)*
|
||||
);
|
||||
|
||||
generic_integer_impl_scalar_operation_assign!(
|
||||
rust_trait: ShlAssign(shl_assign),
|
||||
implem: {
|
||||
@@ -916,10 +1144,11 @@ macro_rules! define_scalar_rotate_shifts {
|
||||
},
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
{
|
||||
let streams = &cuda_key.streams;
|
||||
cuda_key.key.key
|
||||
.scalar_left_shift_assign(lhs.ciphertext.as_gpu_mut(streams), rhs, streams);
|
||||
})
|
||||
}
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
@@ -946,10 +1175,11 @@ macro_rules! define_scalar_rotate_shifts {
|
||||
},
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
{
|
||||
let streams = &cuda_key.streams;
|
||||
cuda_key.key.key
|
||||
.scalar_right_shift_assign(lhs.ciphertext.as_gpu_mut(streams), rhs, streams);
|
||||
})
|
||||
}
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
@@ -976,10 +1206,9 @@ macro_rules! define_scalar_rotate_shifts {
|
||||
},
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
let streams = &cuda_key.streams;
|
||||
cuda_key.key.key
|
||||
.scalar_rotate_left_assign(lhs.ciphertext.as_gpu_mut(streams), rhs, streams);
|
||||
})
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
@@ -1006,10 +1235,9 @@ macro_rules! define_scalar_rotate_shifts {
|
||||
},
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
let streams = &cuda_key.streams;
|
||||
cuda_key.key.key
|
||||
.scalar_rotate_right_assign(lhs.ciphertext.as_gpu_mut(streams), rhs, streams);
|
||||
})
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
@@ -1124,11 +1352,12 @@ macro_rules! define_scalar_ops {
|
||||
},
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let inner_result = with_thread_local_cuda_streams(|streams| {
|
||||
let inner_result = {
|
||||
let streams = &cuda_key.streams;
|
||||
cuda_key.key.key.scalar_add(
|
||||
&*lhs.ciphertext.on_gpu(streams), rhs, streams
|
||||
)
|
||||
});
|
||||
};
|
||||
RadixCiphertext::Cuda(inner_result)
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
@@ -1154,12 +1383,11 @@ macro_rules! define_scalar_ops {
|
||||
|lhs: &FheUint<_>, _rhs| {
|
||||
global_state::with_internal_keys(|key|
|
||||
if let InternalServerKey::Cuda(cuda_key) = key {
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
let streams = &cuda_key.streams;
|
||||
cuda_key.key.key.get_scalar_add_size_on_gpu(
|
||||
&*lhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
)
|
||||
})
|
||||
} else {
|
||||
0
|
||||
})
|
||||
@@ -1184,11 +1412,12 @@ macro_rules! define_scalar_ops {
|
||||
},
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let inner_result = with_thread_local_cuda_streams(|streams| {
|
||||
let inner_result = {
|
||||
let streams = &cuda_key.streams;
|
||||
cuda_key.key.key.scalar_sub(
|
||||
&*lhs.ciphertext.on_gpu(streams), rhs, streams
|
||||
)
|
||||
});
|
||||
};
|
||||
RadixCiphertext::Cuda(inner_result)
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
@@ -1214,12 +1443,11 @@ macro_rules! define_scalar_ops {
|
||||
|lhs: &FheUint<_>, _rhs| {
|
||||
global_state::with_internal_keys(|key|
|
||||
if let InternalServerKey::Cuda(cuda_key) = key {
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
let streams = &cuda_key.streams;
|
||||
cuda_key.key.key.get_scalar_sub_size_on_gpu(
|
||||
&*lhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
)
|
||||
})
|
||||
} else {
|
||||
0
|
||||
})
|
||||
@@ -1244,11 +1472,12 @@ macro_rules! define_scalar_ops {
|
||||
},
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let inner_result = with_thread_local_cuda_streams(|streams| {
|
||||
let inner_result = {
|
||||
let streams = &cuda_key.streams;
|
||||
cuda_key.key.key.scalar_mul(
|
||||
&*lhs.ciphertext.on_gpu(streams), rhs, streams
|
||||
)
|
||||
});
|
||||
};
|
||||
RadixCiphertext::Cuda(inner_result)
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
@@ -1281,11 +1510,12 @@ macro_rules! define_scalar_ops {
|
||||
},
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let inner_result = with_thread_local_cuda_streams(|streams| {
|
||||
let inner_result = {
|
||||
let streams = &cuda_key.streams;
|
||||
cuda_key.key.key.scalar_bitand(
|
||||
&*lhs.ciphertext.on_gpu(streams), rhs, streams
|
||||
)
|
||||
});
|
||||
};
|
||||
RadixCiphertext::Cuda(inner_result)
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
@@ -1308,12 +1538,11 @@ macro_rules! define_scalar_ops {
|
||||
|lhs: &FheUint<_>, _rhs| {
|
||||
global_state::with_internal_keys(|key|
|
||||
if let InternalServerKey::Cuda(cuda_key) = key {
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
let streams = &cuda_key.streams;
|
||||
cuda_key.key.key.get_scalar_bitand_size_on_gpu(
|
||||
&*lhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
)
|
||||
})
|
||||
} else {
|
||||
0
|
||||
})
|
||||
@@ -1338,11 +1567,12 @@ macro_rules! define_scalar_ops {
|
||||
},
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let inner_result = with_thread_local_cuda_streams(|streams| {
|
||||
let inner_result = {
|
||||
let streams = &cuda_key.streams;
|
||||
cuda_key.key.key.scalar_bitor(
|
||||
&*lhs.ciphertext.on_gpu(streams), rhs, streams
|
||||
)
|
||||
});
|
||||
};
|
||||
RadixCiphertext::Cuda(inner_result)
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
@@ -1365,12 +1595,11 @@ macro_rules! define_scalar_ops {
|
||||
|lhs: &FheUint<_>, _rhs| {
|
||||
global_state::with_internal_keys(|key|
|
||||
if let InternalServerKey::Cuda(cuda_key) = key {
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
let streams = &cuda_key.streams;
|
||||
cuda_key.key.key.get_scalar_bitor_size_on_gpu(
|
||||
&*lhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
)
|
||||
})
|
||||
} else {
|
||||
0
|
||||
})
|
||||
@@ -1396,11 +1625,12 @@ macro_rules! define_scalar_ops {
|
||||
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let inner_result = with_thread_local_cuda_streams(|streams| {
|
||||
let inner_result = {
|
||||
let streams = &cuda_key.streams;
|
||||
cuda_key.key.key.scalar_bitxor(
|
||||
&*lhs.ciphertext.on_gpu(streams), rhs, streams
|
||||
)
|
||||
});
|
||||
};
|
||||
RadixCiphertext::Cuda(inner_result)
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
@@ -1423,12 +1653,11 @@ macro_rules! define_scalar_ops {
|
||||
|lhs: &FheUint<_>, _rhs| {
|
||||
global_state::with_internal_keys(|key|
|
||||
if let InternalServerKey::Cuda(cuda_key) = key {
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
let streams = &cuda_key.streams;
|
||||
cuda_key.key.key.get_scalar_bitxor_size_on_gpu(
|
||||
&*lhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
)
|
||||
})
|
||||
} else {
|
||||
0
|
||||
})
|
||||
@@ -1453,11 +1682,12 @@ macro_rules! define_scalar_ops {
|
||||
},
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let inner_result = with_thread_local_cuda_streams(|streams| {
|
||||
let inner_result = {
|
||||
let streams = &cuda_key.streams;
|
||||
cuda_key.key.key.scalar_div(
|
||||
&lhs.ciphertext.on_gpu(streams), rhs, streams
|
||||
)
|
||||
});
|
||||
};
|
||||
RadixCiphertext::Cuda(inner_result)
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
@@ -1487,11 +1717,12 @@ macro_rules! define_scalar_ops {
|
||||
},
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let inner_result = with_thread_local_cuda_streams(|streams| {
|
||||
let inner_result = {
|
||||
let streams = &cuda_key.streams;
|
||||
cuda_key.key.key.scalar_rem(
|
||||
&lhs.ciphertext.on_gpu(streams), rhs, streams
|
||||
)
|
||||
});
|
||||
};
|
||||
RadixCiphertext::Cuda(inner_result)
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
@@ -1531,12 +1762,11 @@ macro_rules! define_scalar_ops {
|
||||
|_lhs, rhs: &FheUint<_>| {
|
||||
global_state::with_internal_keys(|key|
|
||||
if let InternalServerKey::Cuda(cuda_key) = key {
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
let streams = &cuda_key.streams;
|
||||
cuda_key.key.key.get_scalar_add_size_on_gpu(
|
||||
&*rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
)
|
||||
})
|
||||
} else {
|
||||
0
|
||||
})
|
||||
@@ -1560,12 +1790,11 @@ macro_rules! define_scalar_ops {
|
||||
},
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
let streams = &cuda_key.streams;
|
||||
let mut result: CudaUnsignedRadixCiphertext = cuda_key.pbs_key().create_trivial_radix(
|
||||
lhs, rhs.ciphertext.on_gpu(streams).ciphertext.info.blocks.len(), streams);
|
||||
cuda_key.pbs_key().sub_assign(&mut result, &rhs.ciphertext.on_gpu(streams), streams);
|
||||
RadixCiphertext::Cuda(result)
|
||||
})
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
@@ -1587,12 +1816,11 @@ macro_rules! define_scalar_ops {
|
||||
|_lhs, rhs: &FheUint<_>| {
|
||||
global_state::with_internal_keys(|key|
|
||||
if let InternalServerKey::Cuda(cuda_key) = key {
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
let streams = &cuda_key.streams;
|
||||
cuda_key.key.key.get_scalar_sub_size_on_gpu(
|
||||
&*rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
)
|
||||
})
|
||||
} else {
|
||||
0
|
||||
})
|
||||
@@ -1643,12 +1871,11 @@ macro_rules! define_scalar_ops {
|
||||
|_lhs, rhs: &FheUint<_>| {
|
||||
global_state::with_internal_keys(|key|
|
||||
if let InternalServerKey::Cuda(cuda_key) = key {
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
let streams = &cuda_key.streams;
|
||||
cuda_key.key.key.get_scalar_bitand_size_on_gpu(
|
||||
&*rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
)
|
||||
})
|
||||
} else {
|
||||
0
|
||||
})
|
||||
@@ -1682,12 +1909,11 @@ macro_rules! define_scalar_ops {
|
||||
|_lhs, rhs: &FheUint<_>| {
|
||||
global_state::with_internal_keys(|key|
|
||||
if let InternalServerKey::Cuda(cuda_key) = key {
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
let streams = &cuda_key.streams;
|
||||
cuda_key.key.key.get_scalar_bitor_size_on_gpu(
|
||||
&*rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
)
|
||||
})
|
||||
} else {
|
||||
0
|
||||
})
|
||||
@@ -1721,12 +1947,11 @@ macro_rules! define_scalar_ops {
|
||||
|_lhs, rhs: &FheUint<_>| {
|
||||
global_state::with_internal_keys(|key|
|
||||
if let InternalServerKey::Cuda(cuda_key) = key {
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
let streams = &cuda_key.streams;
|
||||
cuda_key.key.key.get_scalar_bitxor_size_on_gpu(
|
||||
&*rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
)
|
||||
})
|
||||
} else {
|
||||
0
|
||||
})
|
||||
@@ -1752,10 +1977,9 @@ macro_rules! define_scalar_ops {
|
||||
},
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
let streams = &cuda_key.streams;
|
||||
cuda_key.key.key
|
||||
.scalar_add_assign(lhs.ciphertext.as_gpu_mut(streams), rhs, streams);
|
||||
})
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(device) => {
|
||||
@@ -1790,10 +2014,9 @@ macro_rules! define_scalar_ops {
|
||||
},
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
let streams = &cuda_key.streams;
|
||||
cuda_key.key.key
|
||||
.scalar_sub_assign(lhs.ciphertext.as_gpu_mut(streams), rhs, streams);
|
||||
})
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(device) => {
|
||||
@@ -1824,10 +2047,9 @@ macro_rules! define_scalar_ops {
|
||||
},
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
let streams = &cuda_key.streams;
|
||||
cuda_key.key.key
|
||||
.scalar_mul_assign(lhs.ciphertext.as_gpu_mut(streams), rhs, streams);
|
||||
})
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(device) => {
|
||||
@@ -1859,10 +2081,9 @@ macro_rules! define_scalar_ops {
|
||||
},
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
let streams = &cuda_key.streams;
|
||||
cuda_key.key.key
|
||||
.scalar_bitand_assign(lhs.ciphertext.as_gpu_mut(streams), rhs, streams);
|
||||
})
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
@@ -1890,10 +2111,9 @@ macro_rules! define_scalar_ops {
|
||||
},
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
let streams = &cuda_key.streams;
|
||||
cuda_key.key.key
|
||||
.scalar_bitor_assign(lhs.ciphertext.as_gpu_mut(streams), rhs, streams);
|
||||
})
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
@@ -1920,10 +2140,9 @@ macro_rules! define_scalar_ops {
|
||||
},
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
let streams = &cuda_key.streams;
|
||||
cuda_key.key.key
|
||||
.scalar_bitxor_assign(lhs.ciphertext.as_gpu_mut(streams), rhs, streams);
|
||||
})
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
@@ -1949,11 +2168,12 @@ macro_rules! define_scalar_ops {
|
||||
.scalar_div_assign_parallelized(lhs.ciphertext.as_cpu_mut(), rhs);
|
||||
},
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => global_state::with_thread_local_cuda_streams(|streams| {
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
let cuda_lhs = lhs.ciphertext.as_gpu_mut(streams);
|
||||
let cuda_result = cuda_key.pbs_key().scalar_div(&cuda_lhs, rhs, streams);
|
||||
*cuda_lhs = cuda_result;
|
||||
}),
|
||||
},
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
panic!("Hpu does not support this operation yet.")
|
||||
@@ -1978,11 +2198,12 @@ macro_rules! define_scalar_ops {
|
||||
.scalar_rem_assign_parallelized(lhs.ciphertext.as_cpu_mut(), rhs);
|
||||
},
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => global_state::with_thread_local_cuda_streams(|streams| {
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
let cuda_lhs = lhs.ciphertext.as_gpu_mut(streams);
|
||||
let cuda_result = cuda_key.pbs_key().scalar_rem(&cuda_lhs, rhs, streams);
|
||||
*cuda_lhs = cuda_result;
|
||||
}),
|
||||
},
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
panic!("Hpu does not support this operation yet.")
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
use crate::high_level_api::traits::AddSizeOnGpu;
|
||||
use crate::prelude::{
|
||||
check_valid_cuda_malloc, BitAndSizeOnGpu, BitNotSizeOnGpu, BitOrSizeOnGpu, BitXorSizeOnGpu,
|
||||
FheTryEncrypt, SubSizeOnGpu,
|
||||
FheMaxSizeOnGpu, FheMinSizeOnGpu, FheOrdSizeOnGpu, FheTryEncrypt, RotateLeftSizeOnGpu,
|
||||
RotateRightSizeOnGpu, ShlSizeOnGpu, ShrSizeOnGpu, SubSizeOnGpu,
|
||||
};
|
||||
use crate::shortint::parameters::{
|
||||
TestParameters, PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS,
|
||||
@@ -254,3 +255,132 @@ fn test_gpu_get_bitops_size_on_gpu() {
|
||||
GpuIndex::new(0)
|
||||
));
|
||||
}
|
||||
/// Checks that every size reported by the comparison, `max` and `min` size queries
/// (encrypted and clear right-hand side) is accepted by `check_valid_cuda_malloc` on GPU 0.
#[test]
fn test_gpu_get_comparisons_size_on_gpu() {
    // Shorthand for the single check applied to every reported size.
    fn fits_on_gpu_0(size: u64) -> bool {
        check_valid_cuda_malloc(size, GpuIndex::new(0))
    }

    let cks = setup_gpu(Some(PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS));
    let mut rng = rand::thread_rng();
    let clear_lhs = rng.gen_range(1..=u32::MAX);
    let clear_rhs = rng.gen_range(1..=u32::MAX);
    let mut lhs = FheUint32::try_encrypt(clear_lhs, &cks).unwrap();
    let mut rhs = FheUint32::try_encrypt(clear_rhs, &cks).unwrap();
    lhs.move_to_current_device();
    rhs.move_to_current_device();
    // From here on the ciphertexts are only read, so work through shared references.
    let (lhs, rhs) = (&lhs, &rhs);

    // Greater-than.
    let gt_size = lhs.get_gt_size_on_gpu(rhs);
    let scalar_gt_size = lhs.get_gt_size_on_gpu(clear_rhs);
    assert!(fits_on_gpu_0(gt_size));
    assert!(fits_on_gpu_0(scalar_gt_size));

    // Greater-or-equal.
    let ge_size = lhs.get_ge_size_on_gpu(rhs);
    let scalar_ge_size = lhs.get_ge_size_on_gpu(clear_rhs);
    assert!(fits_on_gpu_0(ge_size));
    assert!(fits_on_gpu_0(scalar_ge_size));

    // Less-than.
    let lt_size = lhs.get_lt_size_on_gpu(rhs);
    let scalar_lt_size = lhs.get_lt_size_on_gpu(clear_rhs);
    assert!(fits_on_gpu_0(lt_size));
    assert!(fits_on_gpu_0(scalar_lt_size));

    // Less-or-equal.
    let le_size = lhs.get_le_size_on_gpu(rhs);
    let scalar_le_size = lhs.get_le_size_on_gpu(clear_rhs);
    assert!(fits_on_gpu_0(le_size));
    assert!(fits_on_gpu_0(scalar_le_size));

    // Maximum.
    let max_size = lhs.get_max_size_on_gpu(rhs);
    let scalar_max_size = lhs.get_max_size_on_gpu(clear_rhs);
    assert!(fits_on_gpu_0(max_size));
    assert!(fits_on_gpu_0(scalar_max_size));

    // Minimum.
    let min_size = lhs.get_min_size_on_gpu(rhs);
    let scalar_min_size = lhs.get_min_size_on_gpu(clear_rhs);
    assert!(fits_on_gpu_0(min_size));
    assert!(fits_on_gpu_0(scalar_min_size));
}
|
||||
|
||||
/// Checks that every size reported by the shift and rotate size queries (encrypted and
/// clear right-hand side) is accepted by `check_valid_cuda_malloc` on GPU 0.
#[test]
fn test_gpu_get_shift_rotate_size_on_gpu() {
    // Shorthand for the single check applied to every reported size.
    fn fits_on_gpu_0(size: u64) -> bool {
        check_valid_cuda_malloc(size, GpuIndex::new(0))
    }

    let cks = setup_gpu(Some(PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS));
    let mut rng = rand::thread_rng();
    let clear_lhs = rng.gen_range(1..=u32::MAX);
    let clear_rhs = rng.gen_range(1..=u32::MAX);
    let mut lhs = FheUint32::try_encrypt(clear_lhs, &cks).unwrap();
    let mut rhs = FheUint32::try_encrypt(clear_rhs, &cks).unwrap();
    lhs.move_to_current_device();
    rhs.move_to_current_device();
    // From here on the ciphertexts are only read, so work through shared references.
    let (lhs, rhs) = (&lhs, &rhs);

    // Left shift.
    let left_shift_size = lhs.get_left_shift_size_on_gpu(rhs);
    let scalar_left_shift_size = lhs.get_left_shift_size_on_gpu(clear_rhs);
    assert!(fits_on_gpu_0(left_shift_size));
    assert!(fits_on_gpu_0(scalar_left_shift_size));

    // Right shift.
    let right_shift_size = lhs.get_right_shift_size_on_gpu(rhs);
    let scalar_right_shift_size = lhs.get_right_shift_size_on_gpu(clear_rhs);
    assert!(fits_on_gpu_0(right_shift_size));
    assert!(fits_on_gpu_0(scalar_right_shift_size));

    // Rotate left.
    let rotate_left_size = lhs.get_rotate_left_size_on_gpu(rhs);
    let scalar_rotate_left_size = lhs.get_rotate_left_size_on_gpu(clear_rhs);
    assert!(fits_on_gpu_0(rotate_left_size));
    assert!(fits_on_gpu_0(scalar_rotate_left_size));

    // Rotate right.
    let rotate_right_size = lhs.get_rotate_right_size_on_gpu(rhs);
    let scalar_rotate_right_size = lhs.get_rotate_right_size_on_gpu(clear_rhs);
    assert!(fits_on_gpu_0(rotate_right_size));
    assert!(fits_on_gpu_0(scalar_rotate_right_size));
}
|
||||
|
||||
@@ -110,7 +110,7 @@ impl IntegerClientKey {
|
||||
);
|
||||
let mut seeder = DeterministicSeeder::<DefaultRandomGenerator>::new(seed);
|
||||
let cks = crate::shortint::engine::ShortintEngine::new_from_seeder(&mut seeder)
|
||||
.new_client_key(config.block_parameters.into());
|
||||
.new_client_key(config.block_parameters);
|
||||
|
||||
let key = crate::integer::ClientKey::from(cks);
|
||||
|
||||
@@ -172,7 +172,7 @@ impl IntegerClientKey {
|
||||
|
||||
if let Some(dedicated_compact_private_key) = dedicated_compact_private_key.as_ref() {
|
||||
assert_eq!(
|
||||
shortint_cks.parameters.message_modulus(),
|
||||
shortint_cks.parameters().message_modulus(),
|
||||
dedicated_compact_private_key
|
||||
.0
|
||||
.key
|
||||
@@ -180,7 +180,7 @@ impl IntegerClientKey {
|
||||
.message_modulus,
|
||||
);
|
||||
assert_eq!(
|
||||
shortint_cks.parameters.carry_modulus(),
|
||||
shortint_cks.parameters().carry_modulus(),
|
||||
dedicated_compact_private_key
|
||||
.0
|
||||
.key
|
||||
@@ -315,6 +315,9 @@ impl IntegerServerKey {
|
||||
#[cfg(feature = "gpu")]
|
||||
pub struct IntegerCudaServerKey {
|
||||
pub(crate) key: crate::integer::gpu::CudaServerKey,
|
||||
#[allow(dead_code)]
|
||||
pub(crate) cpk_key_switching_key_material:
|
||||
Option<crate::integer::gpu::key_switching_key::CudaKeySwitchingKeyMaterial>,
|
||||
pub(crate) compression_key:
|
||||
Option<crate::integer::gpu::list_compression::server_keys::CudaCompressionKey>,
|
||||
pub(crate) decompression_key:
|
||||
|
||||
@@ -8,6 +8,8 @@ use super::ClientKey;
|
||||
use crate::backward_compatibility::keys::{CompressedServerKeyVersions, ServerKeyVersions};
|
||||
use crate::conformance::ParameterSetConformant;
|
||||
#[cfg(feature = "gpu")]
|
||||
use crate::core_crypto::gpu::lwe_keyswitch_key::CudaLweKeyswitchKey;
|
||||
#[cfg(feature = "gpu")]
|
||||
use crate::core_crypto::gpu::{synchronize_devices, CudaStreams};
|
||||
#[cfg(feature = "gpu")]
|
||||
use crate::high_level_api::keys::inner::IntegerCudaServerKey;
|
||||
@@ -292,6 +294,21 @@ impl CompressedServerKey {
|
||||
&self.integer_key.key,
|
||||
&streams,
|
||||
);
|
||||
let cpk_key_switching_key_material = self
|
||||
.integer_key
|
||||
.cpk_key_switching_key_material
|
||||
.as_ref()
|
||||
.map(|cpk_ksk_material| {
|
||||
let ksk_material = cpk_ksk_material.decompress();
|
||||
let d_ksk = CudaLweKeyswitchKey::from_lwe_keyswitch_key(
|
||||
&ksk_material.material.key_switching_key,
|
||||
&streams,
|
||||
);
|
||||
CudaKeySwitchingKeyMaterial {
|
||||
lwe_keyswitch_key: d_ksk,
|
||||
destination_key: ksk_material.material.destination_key,
|
||||
}
|
||||
});
|
||||
let compression_key: Option<
|
||||
crate::integer::gpu::list_compression::server_keys::CudaCompressionKey,
|
||||
> = self
|
||||
@@ -328,10 +345,12 @@ impl CompressedServerKey {
|
||||
CudaServerKey {
|
||||
key: Arc::new(IntegerCudaServerKey {
|
||||
key,
|
||||
cpk_key_switching_key_material,
|
||||
compression_key,
|
||||
decompression_key,
|
||||
}),
|
||||
tag: self.tag.clone(),
|
||||
streams,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -355,6 +374,7 @@ impl Named for CompressedServerKey {
|
||||
/// Server key used by the high-level API when computations run on CUDA GPUs.
pub struct CudaServerKey {
|
||||
/// Shared (`Arc`) handle to the integer-level GPU key material.
pub(crate) key: Arc<IntegerCudaServerKey>,
|
||||
/// Tag carried by this key; when the key is produced by decompressing a
/// `CompressedServerKey`, it is a clone of that key's tag.
pub(crate) tag: Tag,
|
||||
/// CUDA streams stored alongside the key; the GPU operation arms borrow them
/// (`&cuda_key.streams`) to run their kernels.
pub(crate) streams: CudaStreams,
|
||||
}
|
||||
|
||||
#[cfg(feature = "gpu")]
|
||||
@@ -370,14 +390,6 @@ impl CudaServerKey {
|
||||
/// Returns the `gpu_indexes` slice stored in the `d_vec` of the integer key's
/// `key_switching_key`.
///
/// NOTE(review): presumably these are the GPUs the key material resides on — confirm.
pub fn gpu_indexes(&self) -> &[GpuIndex] {
|
||||
&self.key.key.key_switching_key.d_vec.gpu_indexes
|
||||
}
|
||||
|
||||
pub(crate) fn build_streams(&self) -> CudaStreams {
|
||||
if self.gpu_indexes().len() == 1 {
|
||||
CudaStreams::new_single_gpu(self.gpu_indexes()[0])
|
||||
} else {
|
||||
CudaStreams::new_multi_gpu()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "gpu")]
|
||||
@@ -460,6 +472,9 @@ mod hpu {
|
||||
|
||||
use crate::high_level_api::keys::inner::IntegerServerKeyConformanceParams;
|
||||
|
||||
#[cfg(feature = "gpu")]
|
||||
use crate::integer::gpu::key_switching_key::CudaKeySwitchingKeyMaterial;
|
||||
|
||||
impl ParameterSetConformant for ServerKey {
|
||||
type ParameterSet = IntegerServerKeyConformanceParams;
|
||||
|
||||
|
||||
@@ -26,6 +26,7 @@ pub use crate::high_level_api::gpu_utils::*;
|
||||
pub use crate::high_level_api::strings::traits::*;
|
||||
#[cfg(feature = "gpu")]
|
||||
pub use crate::high_level_api::traits::{
|
||||
AddSizeOnGpu, BitAndSizeOnGpu, BitNotSizeOnGpu, BitOrSizeOnGpu, BitXorSizeOnGpu, SizeOnGpu,
|
||||
SubSizeOnGpu,
|
||||
AddSizeOnGpu, BitAndSizeOnGpu, BitNotSizeOnGpu, BitOrSizeOnGpu, BitXorSizeOnGpu,
|
||||
FheMaxSizeOnGpu, FheMinSizeOnGpu, FheOrdSizeOnGpu, RotateLeftSizeOnGpu, RotateRightSizeOnGpu,
|
||||
ShlSizeOnGpu, ShrSizeOnGpu, SizeOnGpu, SubSizeOnGpu,
|
||||
};
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
use rand::Rng;
|
||||
|
||||
use crate::core_crypto::gpu::get_number_of_gpus;
|
||||
use crate::high_level_api::global_state::CustomMultiGpuIndexes;
|
||||
use crate::prelude::*;
|
||||
use crate::{
|
||||
set_server_key, ClientKey, CompressedServerKey, ConfigBuilder, Device, FheUint32, GpuIndex,
|
||||
set_server_key, unset_server_key, ClientKey, CompressedServerKey, ConfigBuilder, Device,
|
||||
FheUint32, GpuIndex,
|
||||
};
|
||||
|
||||
#[test]
|
||||
@@ -117,3 +119,74 @@ fn test_gpu_selection_2() {
|
||||
assert_eq!(c.gpu_indexes(), &[first_gpu]);
|
||||
assert_eq!(decrypted, clear_a.wrapping_add(clear_b));
|
||||
}
|
||||
|
||||
/// Decompresses the server key onto every non-empty subset of the available GPUs and, for
/// each subset, checks device placement and the decrypted result of an encrypted addition,
/// first with CPU-resident inputs and then after moving the inputs explicitly.
#[test]
fn test_specific_gpu_selection() {
    let config = ConfigBuilder::default().build();
    let keys = ClientKey::generate(config);
    let compressed_server_keys = CompressedServerKey::new(&keys);

    let mut rng = rand::thread_rng();

    let total_gpus = get_number_of_gpus() as usize;
    // Every value in 1..2^total_gpus is a bitmask describing one non-empty subset of GPUs
    // (2^total_gpus - 1 subsets in total): bit `j` set means GPU `j` belongs to the subset.
    for subset_mask in 1..(1 << total_gpus) {
        let selected_indices: Vec<usize> = (0..total_gpus)
            .filter(|&j| (subset_mask & (1 << j)) != 0)
            .collect();

        // Wrap the selected raw indices into the GpuIndex list expected by the key.
        let gpus_to_be_used = CustomMultiGpuIndexes::new(
            selected_indices
                .iter()
                .copied()
                .map(|idx| GpuIndex::new(idx as u32))
                .collect(),
        );

        let cuda_key = compressed_server_keys.decompress_to_specific_gpu(gpus_to_be_used);

        // The mask is never zero, so at least one index was selected.
        let first_gpu = GpuIndex::new(selected_indices[0] as u32);

        let clear_a: u32 = rng.gen();
        let clear_b: u32 = rng.gen();

        let mut a = FheUint32::try_encrypt(clear_a, &keys).unwrap();
        let mut b = FheUint32::try_encrypt(clear_b, &keys).unwrap();

        // Freshly encrypted inputs live on the CPU.
        assert_eq!(a.current_device(), Device::Cpu);
        assert_eq!(b.current_device(), Device::Cpu);
        assert_eq!(a.gpu_indexes(), &[]);
        assert_eq!(b.gpu_indexes(), &[]);

        set_server_key(cuda_key);
        let sum = &a + &b;
        let decrypted: u32 = sum.decrypt(&keys);
        assert_eq!(sum.current_device(), Device::CudaGpu);
        assert_eq!(sum.gpu_indexes(), &[first_gpu]);
        assert_eq!(decrypted, clear_a.wrapping_add(clear_b));

        // Before moving them explicitly, the inputs must still be on the CPU.
        assert_eq!(a.current_device(), Device::Cpu);
        assert_eq!(b.current_device(), Device::Cpu);
        assert_eq!(a.gpu_indexes(), &[]);
        assert_eq!(b.gpu_indexes(), &[]);

        a.move_to_current_device();
        b.move_to_current_device();

        assert_eq!(a.current_device(), Device::CudaGpu);
        assert_eq!(b.current_device(), Device::CudaGpu);
        assert_eq!(a.gpu_indexes(), &[first_gpu]);
        assert_eq!(b.gpu_indexes(), &[first_gpu]);

        let sum = &a + &b;
        let decrypted: u32 = sum.decrypt(&keys);
        assert_eq!(sum.current_device(), Device::CudaGpu);
        assert_eq!(sum.gpu_indexes(), &[first_gpu]);
        assert_eq!(decrypted, clear_a.wrapping_add(clear_b));

        // Clear the server key so the next subset starts from a clean state.
        unset_server_key();
    }
}
|
||||
|
||||
@@ -143,6 +143,115 @@ fn test_tag_propagation_zk_pok() {
|
||||
}
|
||||
}
|
||||
|
||||
/// Checks that the client key's tag is carried over to the GPU server key, the compact
/// public key, every ciphertext expanded from a proven compact list (with and without
/// proof verification), and the results of operations on those ciphertexts.
#[test]
#[cfg(feature = "zk-pok")]
#[cfg(feature = "gpu")]
fn test_tag_propagation_zk_pok_gpu() {
    use crate::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128;

    let config =
        ConfigBuilder::with_custom_parameters(PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128)
            .use_dedicated_compact_public_key_parameters((
                PARAM_PKE_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128,
                PARAM_KEYSWITCH_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128,
            ))
            .build();
    // Sized for the list built below: two 32-bit values, two 64-bit values, two booleans.
    let crs = crate::zk::CompactPkeCrs::from_config(config, (2 * 32) + (2 * 64) + 2).unwrap();

    let metadata = *b"hlapi";

    // Tag the client key, then round-trip it through serialization.
    let mut cks = ClientKey::generate(config);
    let tag_value = random();
    cks.tag_mut().set_u64(tag_value);
    let cks = serialize_then_deserialize(&cks);
    assert_eq!(cks.tag().as_u64(), tag_value);
    let expected_tag = cks.tag();

    let compressed_server_key = CompressedServerKey::new(&cks);
    let gpu_sks = compressed_server_key.decompress_to_gpu();
    assert_eq!(gpu_sks.tag(), expected_tag);
    set_server_key(gpu_sks);

    let cpk = CompactPublicKey::new(&cks);
    assert_eq!(cpk.tag(), expected_tag);

    let mut builder = CompactCiphertextList::builder(&cpk);
    builder
        .push(32u32)
        .push(1u32)
        .push(-1i64)
        .push(i64::MIN)
        .push(false)
        .push(true);
    let list_packed = builder
        .build_with_proof_packed(&crs, &metadata, crate::zk::ZkComputeLoad::Proof)
        .unwrap();

    // Expansion after verifying the proof.
    let expander = list_packed
        .verify_and_expand(&crs, &cpk, &metadata)
        .unwrap();

    {
        let lhs: FheUint32 = expander.get(0).unwrap().unwrap();
        let rhs: FheUint32 = expander.get(1).unwrap().unwrap();
        assert_eq!(lhs.tag(), expected_tag);
        assert_eq!(rhs.tag(), expected_tag);

        let sum = lhs + rhs;
        assert_eq!(sum.tag(), expected_tag);
    }

    {
        let lhs: FheInt64 = expander.get(2).unwrap().unwrap();
        let rhs: FheInt64 = expander.get(3).unwrap().unwrap();
        assert_eq!(lhs.tag(), expected_tag);
        assert_eq!(rhs.tag(), expected_tag);

        let sum = lhs + rhs;
        assert_eq!(sum.tag(), expected_tag);
    }

    {
        let lhs: FheBool = expander.get(4).unwrap().unwrap();
        let rhs: FheBool = expander.get(5).unwrap().unwrap();
        assert_eq!(lhs.tag(), expected_tag);
        assert_eq!(rhs.tag(), expected_tag);

        let conjunction = lhs & rhs;
        assert_eq!(conjunction.tag(), expected_tag);
    }

    // Expansion that skips proof verification.
    let unverified_expander = list_packed.expand_without_verification().unwrap();

    {
        let lhs: FheUint32 = unverified_expander.get(0).unwrap().unwrap();
        let rhs: FheUint32 = unverified_expander.get(1).unwrap().unwrap();
        assert_eq!(lhs.tag(), expected_tag);
        assert_eq!(rhs.tag(), expected_tag);

        let sum = lhs + rhs;
        assert_eq!(sum.tag(), expected_tag);
    }

    {
        let lhs: FheInt64 = unverified_expander.get(2).unwrap().unwrap();
        let rhs: FheInt64 = unverified_expander.get(3).unwrap().unwrap();
        assert_eq!(lhs.tag(), expected_tag);
        assert_eq!(rhs.tag(), expected_tag);

        let sum = lhs + rhs;
        assert_eq!(sum.tag(), expected_tag);
    }

    {
        let lhs: FheBool = unverified_expander.get(4).unwrap().unwrap();
        let rhs: FheBool = unverified_expander.get(5).unwrap().unwrap();
        assert_eq!(lhs.tag(), expected_tag);
        assert_eq!(rhs.tag(), expected_tag);

        let conjunction = lhs & rhs;
        assert_eq!(conjunction.tag(), expected_tag);
    }
}
|
||||
|
||||
#[test]
|
||||
#[cfg(feature = "gpu")]
|
||||
fn test_tag_propagation_gpu() {
|
||||
|
||||
@@ -265,3 +265,40 @@ pub trait BitXorSizeOnGpu<Rhs = Self> {
|
||||
pub trait BitNotSizeOnGpu {
|
||||
fn get_bitnot_size_on_gpu(&self) -> u64;
|
||||
}
|
||||
|
||||
#[cfg(feature = "gpu")]
|
||||
/// Size queries for the ordering comparisons (`gt`, `lt`, `ge`, `le`) on GPU.
///
/// Each method takes the right-hand side of the comparison and returns a size as a `u64`.
/// The accompanying tests pass these values to `check_valid_cuda_malloc`.
///
/// NOTE(review): the unit is presumably bytes of GPU memory — confirm against the
/// implementations.
pub trait FheOrdSizeOnGpu<Rhs = Self> {
    /// Size reported for evaluating `self > other`.
    fn get_gt_size_on_gpu(&self, other: Rhs) -> u64;
    /// Size reported for evaluating `self < other`.
    fn get_lt_size_on_gpu(&self, other: Rhs) -> u64;
    /// Size reported for evaluating `self >= other`.
    fn get_ge_size_on_gpu(&self, other: Rhs) -> u64;
    /// Size reported for evaluating `self <= other`.
    fn get_le_size_on_gpu(&self, other: Rhs) -> u64;
}
|
||||
#[cfg(feature = "gpu")]
|
||||
/// Size query for the `min` operation on GPU.
///
/// NOTE(review): the returned value is presumably a GPU memory size in bytes (the tests pass
/// it to `check_valid_cuda_malloc`) — confirm against the implementations.
pub trait FheMinSizeOnGpu<Rhs = Self> {
|
||||
/// Returns the size reported for computing the minimum of `self` and `other`.
fn get_min_size_on_gpu(&self, other: Rhs) -> u64;
|
||||
}
|
||||
|
||||
#[cfg(feature = "gpu")]
|
||||
/// Size query for the `max` operation on GPU.
///
/// NOTE(review): the returned value is presumably a GPU memory size in bytes (the tests pass
/// it to `check_valid_cuda_malloc`) — confirm against the implementations.
pub trait FheMaxSizeOnGpu<Rhs = Self> {
|
||||
/// Returns the size reported for computing the maximum of `self` and `other`.
fn get_max_size_on_gpu(&self, other: Rhs) -> u64;
|
||||
}
|
||||
|
||||
#[cfg(feature = "gpu")]
|
||||
/// Size query for the left-shift operation on GPU.
///
/// NOTE(review): the returned value is presumably a GPU memory size in bytes (the tests pass
/// it to `check_valid_cuda_malloc`) — confirm against the implementations.
pub trait ShlSizeOnGpu<Rhs = Self> {
|
||||
/// Returns the size reported for shifting `self` left by `other`.
fn get_left_shift_size_on_gpu(&self, other: Rhs) -> u64;
|
||||
}
|
||||
|
||||
#[cfg(feature = "gpu")]
|
||||
/// Size query for the right-shift operation on GPU.
///
/// NOTE(review): the returned value is presumably a GPU memory size in bytes (the tests pass
/// it to `check_valid_cuda_malloc`) — confirm against the implementations.
pub trait ShrSizeOnGpu<Rhs = Self> {
|
||||
/// Returns the size reported for shifting `self` right by `other`.
fn get_right_shift_size_on_gpu(&self, other: Rhs) -> u64;
|
||||
}
|
||||
|
||||
#[cfg(feature = "gpu")]
|
||||
/// Size query for the rotate-left operation on GPU.
///
/// NOTE(review): the returned value is presumably a GPU memory size in bytes (the tests pass
/// it to `check_valid_cuda_malloc`) — confirm against the implementations.
pub trait RotateLeftSizeOnGpu<Rhs = Self> {
|
||||
/// Returns the size reported for rotating `self` left by `other`.
fn get_rotate_left_size_on_gpu(&self, other: Rhs) -> u64;
|
||||
}
|
||||
|
||||
#[cfg(feature = "gpu")]
|
||||
/// Size query for the rotate-right operation on GPU.
///
/// NOTE(review): the returned value is presumably a GPU memory size in bytes (the tests pass
/// it to `check_valid_cuda_malloc`) — confirm against the implementations.
pub trait RotateRightSizeOnGpu<Rhs = Self> {
|
||||
/// Returns the size reported for rotating `self` right by `other`.
fn get_rotate_right_size_on_gpu(&self, other: Rhs) -> u64;
|
||||
}
|
||||
|
||||
@@ -310,7 +310,7 @@ pub struct CompactCiphertextListExpander {
|
||||
}
|
||||
|
||||
impl CompactCiphertextListExpander {
|
||||
fn new(expanded_blocks: Vec<Ciphertext>, info: Vec<DataKind>) -> Self {
|
||||
pub(crate) fn new(expanded_blocks: Vec<Ciphertext>, info: Vec<DataKind>) -> Self {
|
||||
Self {
|
||||
expanded_blocks,
|
||||
info,
|
||||
|
||||
@@ -143,7 +143,7 @@ impl ClientKey {
|
||||
}
|
||||
|
||||
pub fn parameters(&self) -> crate::shortint::AtomicPatternParameters {
|
||||
self.key.parameters.ap_parameters().unwrap()
|
||||
self.key.parameters().ap_parameters().unwrap()
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
@@ -333,7 +333,7 @@ impl ClientKey {
|
||||
return T::ZERO;
|
||||
}
|
||||
|
||||
let bits_in_block = self.key.parameters.message_modulus().0.ilog2();
|
||||
let bits_in_block = self.key.parameters().message_modulus().0.ilog2();
|
||||
let decrypted_block_iter = blocks.iter().map(|block| decrypt_block(&self.key, block));
|
||||
BlockRecomposer::recompose_unsigned(decrypted_block_iter, bits_in_block)
|
||||
}
|
||||
@@ -417,7 +417,7 @@ impl ClientKey {
|
||||
return T::ZERO;
|
||||
}
|
||||
|
||||
let bits_in_block = self.key.parameters.message_modulus().0.ilog2();
|
||||
let bits_in_block = self.key.parameters().message_modulus().0.ilog2();
|
||||
let decrypted_block_iter = ctxt
|
||||
.blocks
|
||||
.iter()
|
||||
|
||||
@@ -7,7 +7,7 @@ pub(crate) trait KnowsMessageModulus {
|
||||
|
||||
impl KnowsMessageModulus for crate::shortint::ClientKey {
|
||||
fn message_modulus(&self) -> MessageModulus {
|
||||
self.parameters.message_modulus()
|
||||
self.parameters().message_modulus()
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,13 +1,16 @@
|
||||
use crate::core_crypto::commons::traits::contiguous_entity_container::ContiguousEntityContainer;
|
||||
use crate::core_crypto::gpu::lwe_ciphertext_list::CudaLweCiphertextList;
|
||||
use crate::core_crypto::gpu::lwe_compact_ciphertext_list::CudaLweCompactCiphertextList;
|
||||
use crate::core_crypto::gpu::CudaStreams;
|
||||
use crate::integer::ciphertext::DataKind;
|
||||
use crate::core_crypto::prelude::LweCiphertext;
|
||||
use crate::integer::ciphertext::{CompactCiphertextListExpander, DataKind};
|
||||
use crate::integer::gpu::ciphertext::compressed_ciphertext_list::CudaExpandable;
|
||||
use crate::integer::gpu::ciphertext::info::{CudaBlockInfo, CudaRadixCiphertextInfo};
|
||||
use crate::integer::gpu::ciphertext::CudaRadixCiphertext;
|
||||
use crate::shortint::ciphertext::CompactCiphertextList;
|
||||
use crate::shortint::parameters::{CompactCiphertextListExpansionKind, Degree};
|
||||
use crate::shortint::{CarryModulus, MessageModulus};
|
||||
use crate::shortint::{CarryModulus, Ciphertext, MessageModulus};
|
||||
use itertools::Itertools;
|
||||
|
||||
pub struct CudaCompactCiphertextList {
|
||||
pub(crate) d_ct_list: CudaLweCompactCiphertextList<u64>,
|
||||
@@ -17,12 +20,25 @@ pub struct CudaCompactCiphertextList {
|
||||
pub(crate) expansion_kind: CompactCiphertextListExpansionKind,
|
||||
}
|
||||
|
||||
impl CudaCompactCiphertextList {
|
||||
pub fn duplicate(&self, streams: &CudaStreams) -> Self {
|
||||
Self {
|
||||
d_ct_list: self.d_ct_list.duplicate(streams),
|
||||
degree: self.degree,
|
||||
message_modulus: self.message_modulus,
|
||||
carry_modulus: self.carry_modulus,
|
||||
expansion_kind: self.expansion_kind,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Metadata attached to one entry of a `CudaCompactCiphertextListExpander`.
#[derive(Clone)]
|
||||
pub struct CudaCompactCiphertextListInfo {
|
||||
/// Block-level information (degree, noise level, moduli, atomic pattern).
pub info: CudaBlockInfo,
|
||||
/// Kind of data the entry belongs to.
pub data_kind: DataKind,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct CudaCompactCiphertextListExpander {
|
||||
pub(crate) expanded_blocks: CudaLweCiphertextList<u64>,
|
||||
pub(crate) blocks_info: Vec<CudaCompactCiphertextListInfo>,
|
||||
@@ -39,6 +55,29 @@ impl CudaCompactCiphertextListExpander {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn len(&self) -> usize {
|
||||
self.expanded_blocks.lwe_ciphertext_count().0
|
||||
}
|
||||
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.len() == 0
|
||||
}
|
||||
|
||||
/// Data kind recorded at position `index` of `blocks_info`, or `None` when out of range.
pub fn get_kind_of(&self, index: usize) -> Option<DataKind> {
    self.blocks_info.get(index).map(|entry| entry.data_kind)
}
|
||||
|
||||
/// Message modulus recorded at position `index` of `blocks_info`, or `None` when out of range.
pub fn message_modulus(&self, index: usize) -> Option<MessageModulus> {
    self.blocks_info
        .get(index)
        .map(|entry| entry.info.message_modulus)
}
|
||||
|
||||
/// Carry modulus recorded at position `index` of `blocks_info`, or `None` when out of range.
pub fn carry_modulus(&self, index: usize) -> Option<CarryModulus> {
    self.blocks_info
        .get(index)
        .map(|entry| entry.info.carry_modulus)
}
|
||||
|
||||
fn blocks_of(
|
||||
&self,
|
||||
index: usize,
|
||||
@@ -92,6 +131,44 @@ impl CudaCompactCiphertextListExpander {
|
||||
.map(|(blocks, kind)| T::from_expanded_blocks(blocks, kind))
|
||||
.transpose()
|
||||
}
|
||||
|
||||
pub fn to_compact_ciphertext_list_expander(
|
||||
&self,
|
||||
streams: &CudaStreams,
|
||||
) -> CompactCiphertextListExpander {
|
||||
let lwe_ciphertext_list = self.expanded_blocks.to_lwe_ciphertext_list(streams);
|
||||
let ciphertext_modulus = self.expanded_blocks.ciphertext_modulus();
|
||||
|
||||
let expanded_blocks = lwe_ciphertext_list
|
||||
.iter()
|
||||
.zip(self.blocks_info.clone())
|
||||
.map(|(ct, info)| {
|
||||
let lwe = LweCiphertext::from_container(ct.as_ref().to_vec(), ciphertext_modulus);
|
||||
Ciphertext::new(
|
||||
lwe,
|
||||
info.info.degree,
|
||||
info.info.noise_level,
|
||||
info.info.message_modulus,
|
||||
info.info.carry_modulus,
|
||||
info.info.atomic_pattern,
|
||||
)
|
||||
})
|
||||
.collect_vec();
|
||||
let info = self
|
||||
.blocks_info
|
||||
.iter()
|
||||
.map(|ct_info| ct_info.data_kind)
|
||||
.collect_vec();
|
||||
|
||||
CompactCiphertextListExpander::new(expanded_blocks, info)
|
||||
}
|
||||
|
||||
pub fn duplicate(&self) -> Self {
|
||||
Self {
|
||||
expanded_blocks: self.expanded_blocks.clone(),
|
||||
blocks_info: self.blocks_info.iter().cloned().collect_vec(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl CudaCompactCiphertextList {
|
||||
|
||||
@@ -10,6 +10,7 @@ use crate::integer::gpu::list_compression::server_keys::{
|
||||
};
|
||||
use crate::integer::gpu::server_key::CudaBootstrappingKey;
|
||||
use crate::integer::RadixClientKey;
|
||||
use crate::shortint::client_key::atomic_pattern::AtomicPatternClientKey;
|
||||
use crate::shortint::engine::ShortintEngine;
|
||||
use crate::shortint::EncryptionKeyChoice;
|
||||
|
||||
@@ -21,7 +22,11 @@ impl RadixClientKey {
|
||||
) -> (CudaCompressionKey, CudaDecompressionKey) {
|
||||
let private_compression_key = &private_compression_key.key;
|
||||
|
||||
let params = &private_compression_key.params;
|
||||
let compression_params = &private_compression_key.params;
|
||||
|
||||
let AtomicPatternClientKey::Standard(std_cks) = &self.as_ref().key.atomic_pattern else {
|
||||
panic!("Only the standard atomic pattern is supported on GPU")
|
||||
};
|
||||
|
||||
assert_eq!(
|
||||
self.parameters().encryption_key_choice(),
|
||||
@@ -32,11 +37,11 @@ impl RadixClientKey {
|
||||
// Compression key
|
||||
let packing_key_switching_key = ShortintEngine::with_thread_local_mut(|engine| {
|
||||
allocate_and_generate_new_lwe_packing_keyswitch_key(
|
||||
&self.as_ref().key.large_lwe_secret_key(),
|
||||
&std_cks.large_lwe_secret_key(),
|
||||
&private_compression_key.post_packing_ks_key,
|
||||
params.packing_ks_base_log,
|
||||
params.packing_ks_level,
|
||||
params.packing_ks_key_noise_distribution,
|
||||
compression_params.packing_ks_base_log,
|
||||
compression_params.packing_ks_level,
|
||||
compression_params.packing_ks_key_noise_distribution,
|
||||
self.parameters().ciphertext_modulus(),
|
||||
&mut engine.encryption_generator,
|
||||
)
|
||||
@@ -45,7 +50,7 @@ impl RadixClientKey {
|
||||
let glwe_compression_key = CompressionKey {
|
||||
key: crate::shortint::list_compression::CompressionKey {
|
||||
packing_key_switching_key,
|
||||
lwe_per_glwe: params.lwe_per_glwe,
|
||||
lwe_per_glwe: compression_params.lwe_per_glwe,
|
||||
storage_log_modulus: private_compression_key.params.storage_log_modulus,
|
||||
},
|
||||
};
|
||||
@@ -60,9 +65,9 @@ impl RadixClientKey {
|
||||
self.parameters().polynomial_size(),
|
||||
private_compression_key.params.br_base_log,
|
||||
private_compression_key.params.br_level,
|
||||
params
|
||||
compression_params
|
||||
.packing_ks_glwe_dimension
|
||||
.to_equivalent_lwe_dimension(params.packing_ks_polynomial_size),
|
||||
.to_equivalent_lwe_dimension(compression_params.packing_ks_polynomial_size),
|
||||
self.parameters().ciphertext_modulus(),
|
||||
);
|
||||
|
||||
@@ -71,7 +76,7 @@ impl RadixClientKey {
|
||||
&private_compression_key
|
||||
.post_packing_ks_key
|
||||
.as_lwe_secret_key(),
|
||||
&self.as_ref().key.glwe_secret_key,
|
||||
&std_cks.glwe_secret_key,
|
||||
&mut bsk,
|
||||
self.parameters().glwe_noise_distribution(),
|
||||
&mut engine.encryption_generator,
|
||||
@@ -84,7 +89,7 @@ impl RadixClientKey {
|
||||
|
||||
let cuda_decompression_key = CudaDecompressionKey {
|
||||
blind_rotate_key,
|
||||
lwe_per_glwe: params.lwe_per_glwe,
|
||||
lwe_per_glwe: compression_params.lwe_per_glwe,
|
||||
glwe_dimension: self.parameters().glwe_dimension(),
|
||||
polynomial_size: self.parameters().polynomial_size(),
|
||||
message_modulus: self.parameters().message_modulus(),
|
||||
|
||||
@@ -1,60 +1,47 @@
|
||||
use crate::core_crypto::gpu::lwe_keyswitch_key::CudaLweKeyswitchKey;
|
||||
use crate::core_crypto::gpu::CudaStreams;
|
||||
use crate::integer::client_key::secret_encryption_key::SecretEncryptionKeyView;
|
||||
use crate::integer::gpu::CudaServerKey;
|
||||
use crate::integer::ClientKey;
|
||||
use crate::shortint::engine::ShortintEngine;
|
||||
use crate::shortint::parameters::ShortintKeySwitchingParameters;
|
||||
use crate::integer::key_switching_key::KeySwitchingKey;
|
||||
use crate::shortint::EncryptionKeyChoice;
|
||||
|
||||
#[derive(Clone)]
|
||||
#[allow(dead_code)]
|
||||
pub struct CudaKeySwitchingKey<'keys> {
|
||||
pub(crate) key_switching_key: CudaLweKeyswitchKey<u64>,
|
||||
pub(crate) dest_server_key: &'keys CudaServerKey,
|
||||
pub struct CudaKeySwitchingKeyMaterial {
|
||||
pub(crate) lwe_keyswitch_key: CudaLweKeyswitchKey<u64>,
|
||||
pub(crate) destination_key: EncryptionKeyChoice,
|
||||
}
|
||||
|
||||
impl<'keys> CudaKeySwitchingKey<'keys> {
|
||||
pub fn new<'input_key, InputEncryptionKey>(
|
||||
input_key_pair: (InputEncryptionKey, Option<&'keys CudaServerKey>),
|
||||
output_key_pair: (&'keys ClientKey, &'keys CudaServerKey),
|
||||
params: ShortintKeySwitchingParameters,
|
||||
#[allow(dead_code)]
|
||||
pub struct CudaKeySwitchingKey<'key> {
|
||||
pub(crate) key_switching_key_material: &'key CudaKeySwitchingKeyMaterial,
|
||||
pub(crate) dest_server_key: &'key CudaServerKey,
|
||||
}
|
||||
|
||||
impl CudaKeySwitchingKeyMaterial {
|
||||
pub fn from_key_switching_key(
|
||||
key_switching_key: &KeySwitchingKey,
|
||||
streams: &CudaStreams,
|
||||
) -> Self
|
||||
where
|
||||
InputEncryptionKey: Into<SecretEncryptionKeyView<'input_key>>,
|
||||
{
|
||||
let input_secret_key: SecretEncryptionKeyView<'_> = input_key_pair.0.into();
|
||||
|
||||
// Creation of the key switching key
|
||||
let key_switching_key = ShortintEngine::with_thread_local_mut(|engine| {
|
||||
engine.new_key_switching_key(&input_secret_key.key, output_key_pair.0.as_ref(), params)
|
||||
});
|
||||
let d_key_switching_key =
|
||||
CudaLweKeyswitchKey::from_lwe_keyswitch_key(&key_switching_key, streams);
|
||||
let full_message_modulus_input =
|
||||
input_secret_key.key.carry_modulus.0 * input_secret_key.key.message_modulus.0;
|
||||
let full_message_modulus_output = output_key_pair.0.key.parameters.carry_modulus().0
|
||||
* output_key_pair.0.key.parameters.message_modulus().0;
|
||||
assert!(
|
||||
full_message_modulus_input.is_power_of_two()
|
||||
&& full_message_modulus_output.is_power_of_two(),
|
||||
"Cannot create casting key if the full messages moduli are not a power of 2"
|
||||
) -> Self {
|
||||
let key_switching_key_material = &key_switching_key.key.key_switching_key_material;
|
||||
let d_lwe_keyswich_key = CudaLweKeyswitchKey::from_lwe_keyswitch_key(
|
||||
&key_switching_key_material.key_switching_key,
|
||||
streams,
|
||||
);
|
||||
if full_message_modulus_input > full_message_modulus_output {
|
||||
assert!(
|
||||
input_key_pair.1.is_some(),
|
||||
"Trying to build a integer::gpu::KeySwitchingKey \
|
||||
going from a large modulus {full_message_modulus_input} \
|
||||
to a smaller modulus {full_message_modulus_output} \
|
||||
without providing a source CudaServerKey, this is not supported"
|
||||
);
|
||||
}
|
||||
|
||||
CudaKeySwitchingKey {
|
||||
key_switching_key: d_key_switching_key,
|
||||
dest_server_key: output_key_pair.1,
|
||||
destination_key: params.destination_key,
|
||||
Self {
|
||||
lwe_keyswitch_key: d_lwe_keyswich_key,
|
||||
destination_key: key_switching_key_material.destination_key,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'key> CudaKeySwitchingKey<'key> {
|
||||
pub fn from_cuda_key_switching_key_material(
|
||||
key_switching_key_material: &'key CudaKeySwitchingKeyMaterial,
|
||||
dest_server_key: &'key CudaServerKey,
|
||||
) -> Self {
|
||||
Self {
|
||||
key_switching_key_material,
|
||||
dest_server_key,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1338,6 +1338,64 @@ pub unsafe fn unchecked_comparison_integer_radix_kb_async<T: UnsignedInteger, B:
|
||||
update_noise_degree(radix_lwe_out, &cuda_ffi_radix_lwe_out);
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn get_comparison_integer_radix_kb_size_on_gpu(
|
||||
streams: &CudaStreams,
|
||||
message_modulus: MessageModulus,
|
||||
carry_modulus: CarryModulus,
|
||||
glwe_dimension: GlweDimension,
|
||||
polynomial_size: PolynomialSize,
|
||||
big_lwe_dimension: LweDimension,
|
||||
small_lwe_dimension: LweDimension,
|
||||
ks_level: DecompositionLevelCount,
|
||||
ks_base_log: DecompositionBaseLog,
|
||||
pbs_level: DecompositionLevelCount,
|
||||
pbs_base_log: DecompositionBaseLog,
|
||||
num_blocks: u32,
|
||||
op: ComparisonType,
|
||||
is_signed: bool,
|
||||
pbs_type: PBSType,
|
||||
grouping_factor: LweBskGroupingFactor,
|
||||
noise_reduction_key: Option<&CudaModulusSwitchNoiseReductionKey>,
|
||||
) -> u64 {
|
||||
let allocate_ms_noise_array = noise_reduction_key.is_some();
|
||||
let mut mem_ptr: *mut i8 = std::ptr::null_mut();
|
||||
let size_tracker = unsafe {
|
||||
scratch_cuda_integer_radix_comparison_kb_64(
|
||||
streams.ptr.as_ptr(),
|
||||
streams.gpu_indexes_ptr(),
|
||||
streams.len() as u32,
|
||||
std::ptr::addr_of_mut!(mem_ptr),
|
||||
glwe_dimension.0 as u32,
|
||||
polynomial_size.0 as u32,
|
||||
big_lwe_dimension.0 as u32,
|
||||
small_lwe_dimension.0 as u32,
|
||||
ks_level.0 as u32,
|
||||
ks_base_log.0 as u32,
|
||||
pbs_level.0 as u32,
|
||||
pbs_base_log.0 as u32,
|
||||
grouping_factor.0 as u32,
|
||||
num_blocks,
|
||||
message_modulus.0 as u32,
|
||||
carry_modulus.0 as u32,
|
||||
pbs_type as u32,
|
||||
op as u32,
|
||||
is_signed,
|
||||
false,
|
||||
allocate_ms_noise_array,
|
||||
)
|
||||
};
|
||||
unsafe {
|
||||
cleanup_cuda_integer_comparison(
|
||||
streams.ptr.as_ptr(),
|
||||
streams.gpu_indexes_ptr(),
|
||||
streams.len() as u32,
|
||||
std::ptr::addr_of_mut!(mem_ptr),
|
||||
);
|
||||
}
|
||||
size_tracker
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
/// # Safety
|
||||
///
|
||||
@@ -2950,6 +3008,399 @@ pub unsafe fn unchecked_rotate_left_integer_radix_kb_assign_async<
|
||||
update_noise_degree(radix_input, &cuda_ffi_radix_lwe_left);
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn get_scalar_left_shift_integer_radix_kb_size_on_gpu(
|
||||
streams: &CudaStreams,
|
||||
message_modulus: MessageModulus,
|
||||
carry_modulus: CarryModulus,
|
||||
glwe_dimension: GlweDimension,
|
||||
polynomial_size: PolynomialSize,
|
||||
big_lwe_dimension: LweDimension,
|
||||
small_lwe_dimension: LweDimension,
|
||||
ks_level: DecompositionLevelCount,
|
||||
ks_base_log: DecompositionBaseLog,
|
||||
pbs_level: DecompositionLevelCount,
|
||||
pbs_base_log: DecompositionBaseLog,
|
||||
num_blocks: u32,
|
||||
pbs_type: PBSType,
|
||||
grouping_factor: LweBskGroupingFactor,
|
||||
noise_reduction_key: Option<&CudaModulusSwitchNoiseReductionKey>,
|
||||
) -> u64 {
|
||||
let allocate_ms_noise_array = noise_reduction_key.is_some();
|
||||
let mut mem_ptr: *mut i8 = std::ptr::null_mut();
|
||||
let size_tracker = unsafe {
|
||||
scratch_cuda_integer_radix_logical_scalar_shift_kb_64(
|
||||
streams.ptr.as_ptr(),
|
||||
streams.gpu_indexes_ptr(),
|
||||
streams.len() as u32,
|
||||
std::ptr::addr_of_mut!(mem_ptr),
|
||||
glwe_dimension.0 as u32,
|
||||
polynomial_size.0 as u32,
|
||||
big_lwe_dimension.0 as u32,
|
||||
small_lwe_dimension.0 as u32,
|
||||
ks_level.0 as u32,
|
||||
ks_base_log.0 as u32,
|
||||
pbs_level.0 as u32,
|
||||
pbs_base_log.0 as u32,
|
||||
grouping_factor.0 as u32,
|
||||
num_blocks,
|
||||
message_modulus.0 as u32,
|
||||
carry_modulus.0 as u32,
|
||||
pbs_type as u32,
|
||||
ShiftRotateType::LeftShift as u32,
|
||||
false,
|
||||
allocate_ms_noise_array,
|
||||
)
|
||||
};
|
||||
unsafe {
|
||||
cleanup_cuda_integer_radix_logical_scalar_shift(
|
||||
streams.ptr.as_ptr(),
|
||||
streams.gpu_indexes_ptr(),
|
||||
streams.len() as u32,
|
||||
std::ptr::addr_of_mut!(mem_ptr),
|
||||
);
|
||||
}
|
||||
size_tracker
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn get_scalar_logical_right_shift_integer_radix_kb_size_on_gpu(
|
||||
streams: &CudaStreams,
|
||||
message_modulus: MessageModulus,
|
||||
carry_modulus: CarryModulus,
|
||||
glwe_dimension: GlweDimension,
|
||||
polynomial_size: PolynomialSize,
|
||||
big_lwe_dimension: LweDimension,
|
||||
small_lwe_dimension: LweDimension,
|
||||
ks_level: DecompositionLevelCount,
|
||||
ks_base_log: DecompositionBaseLog,
|
||||
pbs_level: DecompositionLevelCount,
|
||||
pbs_base_log: DecompositionBaseLog,
|
||||
num_blocks: u32,
|
||||
pbs_type: PBSType,
|
||||
grouping_factor: LweBskGroupingFactor,
|
||||
noise_reduction_key: Option<&CudaModulusSwitchNoiseReductionKey>,
|
||||
) -> u64 {
|
||||
let allocate_ms_noise_array = noise_reduction_key.is_some();
|
||||
let mut mem_ptr: *mut i8 = std::ptr::null_mut();
|
||||
let size_tracker = unsafe {
|
||||
scratch_cuda_integer_radix_logical_scalar_shift_kb_64(
|
||||
streams.ptr.as_ptr(),
|
||||
streams.gpu_indexes_ptr(),
|
||||
streams.len() as u32,
|
||||
std::ptr::addr_of_mut!(mem_ptr),
|
||||
glwe_dimension.0 as u32,
|
||||
polynomial_size.0 as u32,
|
||||
big_lwe_dimension.0 as u32,
|
||||
small_lwe_dimension.0 as u32,
|
||||
ks_level.0 as u32,
|
||||
ks_base_log.0 as u32,
|
||||
pbs_level.0 as u32,
|
||||
pbs_base_log.0 as u32,
|
||||
grouping_factor.0 as u32,
|
||||
num_blocks,
|
||||
message_modulus.0 as u32,
|
||||
carry_modulus.0 as u32,
|
||||
pbs_type as u32,
|
||||
ShiftRotateType::RightShift as u32,
|
||||
false,
|
||||
allocate_ms_noise_array,
|
||||
)
|
||||
};
|
||||
unsafe {
|
||||
cleanup_cuda_integer_radix_logical_scalar_shift(
|
||||
streams.ptr.as_ptr(),
|
||||
streams.gpu_indexes_ptr(),
|
||||
streams.len() as u32,
|
||||
std::ptr::addr_of_mut!(mem_ptr),
|
||||
);
|
||||
}
|
||||
size_tracker
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn get_scalar_arithmetic_right_shift_integer_radix_kb_size_on_gpu(
|
||||
streams: &CudaStreams,
|
||||
message_modulus: MessageModulus,
|
||||
carry_modulus: CarryModulus,
|
||||
glwe_dimension: GlweDimension,
|
||||
polynomial_size: PolynomialSize,
|
||||
big_lwe_dimension: LweDimension,
|
||||
small_lwe_dimension: LweDimension,
|
||||
ks_level: DecompositionLevelCount,
|
||||
ks_base_log: DecompositionBaseLog,
|
||||
pbs_level: DecompositionLevelCount,
|
||||
pbs_base_log: DecompositionBaseLog,
|
||||
num_blocks: u32,
|
||||
pbs_type: PBSType,
|
||||
grouping_factor: LweBskGroupingFactor,
|
||||
noise_reduction_key: Option<&CudaModulusSwitchNoiseReductionKey>,
|
||||
) -> u64 {
|
||||
let allocate_ms_noise_array = noise_reduction_key.is_some();
|
||||
let mut mem_ptr: *mut i8 = std::ptr::null_mut();
|
||||
let size_tracker = unsafe {
|
||||
scratch_cuda_integer_radix_arithmetic_scalar_shift_kb_64(
|
||||
streams.ptr.as_ptr(),
|
||||
streams.gpu_indexes_ptr(),
|
||||
streams.len() as u32,
|
||||
std::ptr::addr_of_mut!(mem_ptr),
|
||||
glwe_dimension.0 as u32,
|
||||
polynomial_size.0 as u32,
|
||||
big_lwe_dimension.0 as u32,
|
||||
small_lwe_dimension.0 as u32,
|
||||
ks_level.0 as u32,
|
||||
ks_base_log.0 as u32,
|
||||
pbs_level.0 as u32,
|
||||
pbs_base_log.0 as u32,
|
||||
grouping_factor.0 as u32,
|
||||
num_blocks,
|
||||
message_modulus.0 as u32,
|
||||
carry_modulus.0 as u32,
|
||||
pbs_type as u32,
|
||||
ShiftRotateType::RightShift as u32,
|
||||
false,
|
||||
allocate_ms_noise_array,
|
||||
)
|
||||
};
|
||||
unsafe {
|
||||
cleanup_cuda_integer_radix_arithmetic_scalar_shift(
|
||||
streams.ptr.as_ptr(),
|
||||
streams.gpu_indexes_ptr(),
|
||||
streams.len() as u32,
|
||||
std::ptr::addr_of_mut!(mem_ptr),
|
||||
);
|
||||
}
|
||||
size_tracker
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn get_right_shift_integer_radix_kb_size_on_gpu(
|
||||
streams: &CudaStreams,
|
||||
message_modulus: MessageModulus,
|
||||
carry_modulus: CarryModulus,
|
||||
glwe_dimension: GlweDimension,
|
||||
polynomial_size: PolynomialSize,
|
||||
big_lwe_dimension: LweDimension,
|
||||
small_lwe_dimension: LweDimension,
|
||||
ks_level: DecompositionLevelCount,
|
||||
ks_base_log: DecompositionBaseLog,
|
||||
pbs_level: DecompositionLevelCount,
|
||||
pbs_base_log: DecompositionBaseLog,
|
||||
num_blocks: u32,
|
||||
is_signed: bool,
|
||||
pbs_type: PBSType,
|
||||
grouping_factor: LweBskGroupingFactor,
|
||||
noise_reduction_key: Option<&CudaModulusSwitchNoiseReductionKey>,
|
||||
) -> u64 {
|
||||
let allocate_ms_noise_array = noise_reduction_key.is_some();
|
||||
let mut mem_ptr: *mut i8 = std::ptr::null_mut();
|
||||
let size_tracker = unsafe {
|
||||
scratch_cuda_integer_radix_shift_and_rotate_kb_64(
|
||||
streams.ptr.as_ptr(),
|
||||
streams.gpu_indexes_ptr(),
|
||||
streams.len() as u32,
|
||||
std::ptr::addr_of_mut!(mem_ptr),
|
||||
glwe_dimension.0 as u32,
|
||||
polynomial_size.0 as u32,
|
||||
big_lwe_dimension.0 as u32,
|
||||
small_lwe_dimension.0 as u32,
|
||||
ks_level.0 as u32,
|
||||
ks_base_log.0 as u32,
|
||||
pbs_level.0 as u32,
|
||||
pbs_base_log.0 as u32,
|
||||
grouping_factor.0 as u32,
|
||||
num_blocks,
|
||||
message_modulus.0 as u32,
|
||||
carry_modulus.0 as u32,
|
||||
pbs_type as u32,
|
||||
ShiftRotateType::RightShift as u32,
|
||||
is_signed,
|
||||
false,
|
||||
allocate_ms_noise_array,
|
||||
)
|
||||
};
|
||||
unsafe {
|
||||
cleanup_cuda_integer_radix_shift_and_rotate(
|
||||
streams.ptr.as_ptr(),
|
||||
streams.gpu_indexes_ptr(),
|
||||
streams.len() as u32,
|
||||
std::ptr::addr_of_mut!(mem_ptr),
|
||||
);
|
||||
}
|
||||
size_tracker
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn get_left_shift_integer_radix_kb_size_on_gpu(
|
||||
streams: &CudaStreams,
|
||||
message_modulus: MessageModulus,
|
||||
carry_modulus: CarryModulus,
|
||||
glwe_dimension: GlweDimension,
|
||||
polynomial_size: PolynomialSize,
|
||||
big_lwe_dimension: LweDimension,
|
||||
small_lwe_dimension: LweDimension,
|
||||
ks_level: DecompositionLevelCount,
|
||||
ks_base_log: DecompositionBaseLog,
|
||||
pbs_level: DecompositionLevelCount,
|
||||
pbs_base_log: DecompositionBaseLog,
|
||||
num_blocks: u32,
|
||||
is_signed: bool,
|
||||
pbs_type: PBSType,
|
||||
grouping_factor: LweBskGroupingFactor,
|
||||
noise_reduction_key: Option<&CudaModulusSwitchNoiseReductionKey>,
|
||||
) -> u64 {
|
||||
let allocate_ms_noise_array = noise_reduction_key.is_some();
|
||||
let mut mem_ptr: *mut i8 = std::ptr::null_mut();
|
||||
let size_tracker = unsafe {
|
||||
scratch_cuda_integer_radix_shift_and_rotate_kb_64(
|
||||
streams.ptr.as_ptr(),
|
||||
streams.gpu_indexes_ptr(),
|
||||
streams.len() as u32,
|
||||
std::ptr::addr_of_mut!(mem_ptr),
|
||||
glwe_dimension.0 as u32,
|
||||
polynomial_size.0 as u32,
|
||||
big_lwe_dimension.0 as u32,
|
||||
small_lwe_dimension.0 as u32,
|
||||
ks_level.0 as u32,
|
||||
ks_base_log.0 as u32,
|
||||
pbs_level.0 as u32,
|
||||
pbs_base_log.0 as u32,
|
||||
grouping_factor.0 as u32,
|
||||
num_blocks,
|
||||
message_modulus.0 as u32,
|
||||
carry_modulus.0 as u32,
|
||||
pbs_type as u32,
|
||||
ShiftRotateType::LeftShift as u32,
|
||||
is_signed,
|
||||
false,
|
||||
allocate_ms_noise_array,
|
||||
)
|
||||
};
|
||||
unsafe {
|
||||
cleanup_cuda_integer_radix_shift_and_rotate(
|
||||
streams.ptr.as_ptr(),
|
||||
streams.gpu_indexes_ptr(),
|
||||
streams.len() as u32,
|
||||
std::ptr::addr_of_mut!(mem_ptr),
|
||||
);
|
||||
}
|
||||
size_tracker
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn get_rotate_right_integer_radix_kb_size_on_gpu(
|
||||
streams: &CudaStreams,
|
||||
message_modulus: MessageModulus,
|
||||
carry_modulus: CarryModulus,
|
||||
glwe_dimension: GlweDimension,
|
||||
polynomial_size: PolynomialSize,
|
||||
big_lwe_dimension: LweDimension,
|
||||
small_lwe_dimension: LweDimension,
|
||||
ks_level: DecompositionLevelCount,
|
||||
ks_base_log: DecompositionBaseLog,
|
||||
pbs_level: DecompositionLevelCount,
|
||||
pbs_base_log: DecompositionBaseLog,
|
||||
num_blocks: u32,
|
||||
is_signed: bool,
|
||||
pbs_type: PBSType,
|
||||
grouping_factor: LweBskGroupingFactor,
|
||||
noise_reduction_key: Option<&CudaModulusSwitchNoiseReductionKey>,
|
||||
) -> u64 {
|
||||
let allocate_ms_noise_array = noise_reduction_key.is_some();
|
||||
let mut mem_ptr: *mut i8 = std::ptr::null_mut();
|
||||
let size_tracker = unsafe {
|
||||
scratch_cuda_integer_radix_shift_and_rotate_kb_64(
|
||||
streams.ptr.as_ptr(),
|
||||
streams.gpu_indexes_ptr(),
|
||||
streams.len() as u32,
|
||||
std::ptr::addr_of_mut!(mem_ptr),
|
||||
glwe_dimension.0 as u32,
|
||||
polynomial_size.0 as u32,
|
||||
big_lwe_dimension.0 as u32,
|
||||
small_lwe_dimension.0 as u32,
|
||||
ks_level.0 as u32,
|
||||
ks_base_log.0 as u32,
|
||||
pbs_level.0 as u32,
|
||||
pbs_base_log.0 as u32,
|
||||
grouping_factor.0 as u32,
|
||||
num_blocks,
|
||||
message_modulus.0 as u32,
|
||||
carry_modulus.0 as u32,
|
||||
pbs_type as u32,
|
||||
ShiftRotateType::RightRotate as u32,
|
||||
is_signed,
|
||||
false,
|
||||
allocate_ms_noise_array,
|
||||
)
|
||||
};
|
||||
unsafe {
|
||||
cleanup_cuda_integer_radix_shift_and_rotate(
|
||||
streams.ptr.as_ptr(),
|
||||
streams.gpu_indexes_ptr(),
|
||||
streams.len() as u32,
|
||||
std::ptr::addr_of_mut!(mem_ptr),
|
||||
);
|
||||
}
|
||||
size_tracker
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn get_rotate_left_integer_radix_kb_size_on_gpu(
|
||||
streams: &CudaStreams,
|
||||
message_modulus: MessageModulus,
|
||||
carry_modulus: CarryModulus,
|
||||
glwe_dimension: GlweDimension,
|
||||
polynomial_size: PolynomialSize,
|
||||
big_lwe_dimension: LweDimension,
|
||||
small_lwe_dimension: LweDimension,
|
||||
ks_level: DecompositionLevelCount,
|
||||
ks_base_log: DecompositionBaseLog,
|
||||
pbs_level: DecompositionLevelCount,
|
||||
pbs_base_log: DecompositionBaseLog,
|
||||
num_blocks: u32,
|
||||
is_signed: bool,
|
||||
pbs_type: PBSType,
|
||||
grouping_factor: LweBskGroupingFactor,
|
||||
noise_reduction_key: Option<&CudaModulusSwitchNoiseReductionKey>,
|
||||
) -> u64 {
|
||||
let allocate_ms_noise_array = noise_reduction_key.is_some();
|
||||
let mut mem_ptr: *mut i8 = std::ptr::null_mut();
|
||||
let size_tracker = unsafe {
|
||||
scratch_cuda_integer_radix_shift_and_rotate_kb_64(
|
||||
streams.ptr.as_ptr(),
|
||||
streams.gpu_indexes_ptr(),
|
||||
streams.len() as u32,
|
||||
std::ptr::addr_of_mut!(mem_ptr),
|
||||
glwe_dimension.0 as u32,
|
||||
polynomial_size.0 as u32,
|
||||
big_lwe_dimension.0 as u32,
|
||||
small_lwe_dimension.0 as u32,
|
||||
ks_level.0 as u32,
|
||||
ks_base_log.0 as u32,
|
||||
pbs_level.0 as u32,
|
||||
pbs_base_log.0 as u32,
|
||||
grouping_factor.0 as u32,
|
||||
num_blocks,
|
||||
message_modulus.0 as u32,
|
||||
carry_modulus.0 as u32,
|
||||
pbs_type as u32,
|
||||
ShiftRotateType::LeftRotate as u32,
|
||||
is_signed,
|
||||
false,
|
||||
allocate_ms_noise_array,
|
||||
)
|
||||
};
|
||||
unsafe {
|
||||
cleanup_cuda_integer_radix_shift_and_rotate(
|
||||
streams.ptr.as_ptr(),
|
||||
streams.gpu_indexes_ptr(),
|
||||
streams.len() as u32,
|
||||
std::ptr::addr_of_mut!(mem_ptr),
|
||||
);
|
||||
}
|
||||
size_tracker
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
/// # Safety
|
||||
///
|
||||
@@ -3381,6 +3832,116 @@ pub unsafe fn unchecked_scalar_rotate_right_integer_radix_kb_assign_async<
|
||||
update_noise_degree(radix_input, &cuda_ffi_radix_lwe_left);
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn get_scalar_rotate_left_integer_radix_kb_size_on_gpu(
|
||||
streams: &CudaStreams,
|
||||
message_modulus: MessageModulus,
|
||||
carry_modulus: CarryModulus,
|
||||
glwe_dimension: GlweDimension,
|
||||
polynomial_size: PolynomialSize,
|
||||
big_lwe_dimension: LweDimension,
|
||||
small_lwe_dimension: LweDimension,
|
||||
ks_level: DecompositionLevelCount,
|
||||
ks_base_log: DecompositionBaseLog,
|
||||
pbs_level: DecompositionLevelCount,
|
||||
pbs_base_log: DecompositionBaseLog,
|
||||
num_blocks: u32,
|
||||
pbs_type: PBSType,
|
||||
grouping_factor: LweBskGroupingFactor,
|
||||
noise_reduction_key: Option<&CudaModulusSwitchNoiseReductionKey>,
|
||||
) -> u64 {
|
||||
let allocate_ms_noise_array = noise_reduction_key.is_some();
|
||||
let mut mem_ptr: *mut i8 = std::ptr::null_mut();
|
||||
let size_tracker = unsafe {
|
||||
scratch_cuda_integer_radix_scalar_rotate_kb_64(
|
||||
streams.ptr.as_ptr(),
|
||||
streams.gpu_indexes_ptr(),
|
||||
streams.len() as u32,
|
||||
std::ptr::addr_of_mut!(mem_ptr),
|
||||
glwe_dimension.0 as u32,
|
||||
polynomial_size.0 as u32,
|
||||
big_lwe_dimension.0 as u32,
|
||||
small_lwe_dimension.0 as u32,
|
||||
ks_level.0 as u32,
|
||||
ks_base_log.0 as u32,
|
||||
pbs_level.0 as u32,
|
||||
pbs_base_log.0 as u32,
|
||||
grouping_factor.0 as u32,
|
||||
num_blocks,
|
||||
message_modulus.0 as u32,
|
||||
carry_modulus.0 as u32,
|
||||
pbs_type as u32,
|
||||
ShiftRotateType::LeftShift as u32,
|
||||
false,
|
||||
allocate_ms_noise_array,
|
||||
)
|
||||
};
|
||||
unsafe {
|
||||
cleanup_cuda_integer_radix_scalar_rotate(
|
||||
streams.ptr.as_ptr(),
|
||||
streams.gpu_indexes_ptr(),
|
||||
streams.len() as u32,
|
||||
std::ptr::addr_of_mut!(mem_ptr),
|
||||
);
|
||||
}
|
||||
size_tracker
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn get_scalar_rotate_right_integer_radix_kb_size_on_gpu(
|
||||
streams: &CudaStreams,
|
||||
message_modulus: MessageModulus,
|
||||
carry_modulus: CarryModulus,
|
||||
glwe_dimension: GlweDimension,
|
||||
polynomial_size: PolynomialSize,
|
||||
big_lwe_dimension: LweDimension,
|
||||
small_lwe_dimension: LweDimension,
|
||||
ks_level: DecompositionLevelCount,
|
||||
ks_base_log: DecompositionBaseLog,
|
||||
pbs_level: DecompositionLevelCount,
|
||||
pbs_base_log: DecompositionBaseLog,
|
||||
num_blocks: u32,
|
||||
pbs_type: PBSType,
|
||||
grouping_factor: LweBskGroupingFactor,
|
||||
noise_reduction_key: Option<&CudaModulusSwitchNoiseReductionKey>,
|
||||
) -> u64 {
|
||||
let allocate_ms_noise_array = noise_reduction_key.is_some();
|
||||
let mut mem_ptr: *mut i8 = std::ptr::null_mut();
|
||||
let size_tracker = unsafe {
|
||||
scratch_cuda_integer_radix_scalar_rotate_kb_64(
|
||||
streams.ptr.as_ptr(),
|
||||
streams.gpu_indexes_ptr(),
|
||||
streams.len() as u32,
|
||||
std::ptr::addr_of_mut!(mem_ptr),
|
||||
glwe_dimension.0 as u32,
|
||||
polynomial_size.0 as u32,
|
||||
big_lwe_dimension.0 as u32,
|
||||
small_lwe_dimension.0 as u32,
|
||||
ks_level.0 as u32,
|
||||
ks_base_log.0 as u32,
|
||||
pbs_level.0 as u32,
|
||||
pbs_base_log.0 as u32,
|
||||
grouping_factor.0 as u32,
|
||||
num_blocks,
|
||||
message_modulus.0 as u32,
|
||||
carry_modulus.0 as u32,
|
||||
pbs_type as u32,
|
||||
ShiftRotateType::RightShift as u32,
|
||||
false,
|
||||
allocate_ms_noise_array,
|
||||
)
|
||||
};
|
||||
unsafe {
|
||||
cleanup_cuda_integer_radix_scalar_rotate(
|
||||
streams.ptr.as_ptr(),
|
||||
streams.gpu_indexes_ptr(),
|
||||
streams.len() as u32,
|
||||
std::ptr::addr_of_mut!(mem_ptr),
|
||||
);
|
||||
}
|
||||
size_tracker
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
/// # Safety
|
||||
///
|
||||
|
||||
@@ -12,11 +12,10 @@ use crate::integer::server_key::num_bits_to_represent_unsigned_value;
|
||||
use crate::integer::ClientKey;
|
||||
use crate::shortint::atomic_pattern::compressed::CompressedAtomicPatternServerKey;
|
||||
use crate::shortint::ciphertext::{MaxDegree, MaxNoiseLevel};
|
||||
use crate::shortint::client_key::atomic_pattern::AtomicPatternClientKey;
|
||||
use crate::shortint::engine::ShortintEngine;
|
||||
use crate::shortint::server_key::ModulusSwitchNoiseReductionKey;
|
||||
use crate::shortint::{
|
||||
AtomicPatternParameters, CarryModulus, CiphertextModulus, MessageModulus, PBSOrder,
|
||||
};
|
||||
use crate::shortint::{CarryModulus, CiphertextModulus, MessageModulus, PBSOrder};
|
||||
mod radix;
|
||||
|
||||
pub enum CudaBootstrappingKey {
|
||||
@@ -72,8 +71,8 @@ impl CudaServerKey {
|
||||
// It should remain just enough space to add a carry
|
||||
let client_key = cks.as_ref();
|
||||
let max_degree = MaxDegree::integer_radix_server_key(
|
||||
client_key.key.parameters.message_modulus(),
|
||||
client_key.key.parameters.carry_modulus(),
|
||||
client_key.key.parameters().message_modulus(),
|
||||
client_key.key.parameters().carry_modulus(),
|
||||
);
|
||||
Self::new_server_key_with_max_degree(client_key, max_degree, streams)
|
||||
}
|
||||
@@ -86,16 +85,18 @@ impl CudaServerKey {
|
||||
let mut engine = ShortintEngine::new();
|
||||
|
||||
// Generate a regular keyset and convert to the GPU
|
||||
let AtomicPatternParameters::Standard(pbs_params_base) = &cks.parameters() else {
|
||||
let AtomicPatternClientKey::Standard(std_cks) = &cks.key.atomic_pattern else {
|
||||
panic!("Only the standard atomic pattern is supported on GPU")
|
||||
};
|
||||
|
||||
let pbs_params_base = std_cks.parameters;
|
||||
|
||||
let d_bootstrapping_key = match pbs_params_base {
|
||||
crate::shortint::PBSParameters::PBS(pbs_params) => {
|
||||
let h_bootstrap_key: LweBootstrapKeyOwned<u64> =
|
||||
par_allocate_and_generate_new_lwe_bootstrap_key(
|
||||
&cks.key.small_lwe_secret_key(),
|
||||
&cks.key.glwe_secret_key,
|
||||
&std_cks.lwe_secret_key,
|
||||
&std_cks.glwe_secret_key,
|
||||
pbs_params.pbs_base_log,
|
||||
pbs_params.pbs_level,
|
||||
pbs_params.glwe_noise_distribution,
|
||||
@@ -107,7 +108,7 @@ impl CudaServerKey {
|
||||
.map(|modulus_switch_noise_reduction_params| {
|
||||
ModulusSwitchNoiseReductionKey::new(
|
||||
modulus_switch_noise_reduction_params,
|
||||
&cks.key.small_lwe_secret_key(),
|
||||
&std_cks.lwe_secret_key,
|
||||
&mut engine,
|
||||
pbs_params.ciphertext_modulus,
|
||||
pbs_params.lwe_noise_distribution,
|
||||
@@ -124,8 +125,8 @@ impl CudaServerKey {
|
||||
crate::shortint::PBSParameters::MultiBitPBS(pbs_params) => {
|
||||
let h_bootstrap_key: LweMultiBitBootstrapKeyOwned<u64> =
|
||||
par_allocate_and_generate_new_lwe_multi_bit_bootstrap_key(
|
||||
&cks.key.small_lwe_secret_key(),
|
||||
&cks.key.glwe_secret_key,
|
||||
&std_cks.lwe_secret_key,
|
||||
&std_cks.glwe_secret_key,
|
||||
pbs_params.pbs_base_log,
|
||||
pbs_params.pbs_level,
|
||||
pbs_params.grouping_factor,
|
||||
@@ -145,12 +146,12 @@ impl CudaServerKey {
|
||||
|
||||
// Creation of the key switching key
|
||||
let h_key_switching_key = allocate_and_generate_new_lwe_keyswitch_key(
|
||||
&cks.key.large_lwe_secret_key(),
|
||||
&cks.key.small_lwe_secret_key(),
|
||||
cks.parameters().ks_base_log(),
|
||||
cks.parameters().ks_level(),
|
||||
cks.parameters().lwe_noise_distribution(),
|
||||
cks.parameters().ciphertext_modulus(),
|
||||
&std_cks.large_lwe_secret_key(),
|
||||
&std_cks.small_lwe_secret_key(),
|
||||
std_cks.parameters.ks_base_log(),
|
||||
std_cks.parameters.ks_level(),
|
||||
std_cks.parameters.lwe_noise_distribution(),
|
||||
std_cks.parameters.ciphertext_modulus(),
|
||||
&mut engine.encryption_generator,
|
||||
);
|
||||
|
||||
@@ -158,7 +159,7 @@ impl CudaServerKey {
|
||||
CudaLweKeyswitchKey::from_lwe_keyswitch_key(&h_key_switching_key, streams);
|
||||
|
||||
assert!(matches!(
|
||||
cks.parameters().encryption_key_choice().into(),
|
||||
std_cks.parameters.encryption_key_choice().into(),
|
||||
PBSOrder::KeyswitchBootstrap
|
||||
));
|
||||
|
||||
@@ -166,12 +167,12 @@ impl CudaServerKey {
|
||||
Self {
|
||||
key_switching_key: d_key_switching_key,
|
||||
bootstrapping_key: d_bootstrapping_key,
|
||||
message_modulus: cks.parameters().message_modulus(),
|
||||
carry_modulus: cks.parameters().carry_modulus(),
|
||||
message_modulus: std_cks.parameters.message_modulus(),
|
||||
carry_modulus: std_cks.parameters.carry_modulus(),
|
||||
max_degree,
|
||||
max_noise_level: cks.parameters().max_noise_level(),
|
||||
ciphertext_modulus: cks.parameters().ciphertext_modulus(),
|
||||
pbs_order: cks.parameters().encryption_key_choice().into(),
|
||||
max_noise_level: std_cks.parameters.max_noise_level(),
|
||||
ciphertext_modulus: std_cks.parameters.ciphertext_modulus(),
|
||||
pbs_order: std_cks.parameters.encryption_key_choice().into(),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ use crate::integer::gpu::ciphertext::info::CudaRadixCiphertextInfo;
|
||||
use crate::integer::gpu::ciphertext::{CudaIntegerRadixCiphertext, CudaRadixCiphertext};
|
||||
use crate::integer::gpu::server_key::CudaBootstrappingKey;
|
||||
use crate::integer::gpu::{
|
||||
get_comparison_integer_radix_kb_size_on_gpu, get_full_propagate_assign_size_on_gpu,
|
||||
unchecked_comparison_integer_radix_kb_async, ComparisonType, CudaServerKey, PBSType,
|
||||
};
|
||||
use crate::shortint::ciphertext::Degree;
|
||||
@@ -348,6 +349,120 @@ impl CudaServerKey {
|
||||
result
|
||||
}
|
||||
|
||||
pub(crate) fn get_comparison_size_on_gpu<T: CudaIntegerRadixCiphertext>(
|
||||
&self,
|
||||
ct_left: &T,
|
||||
ct_right: &T,
|
||||
op: ComparisonType,
|
||||
streams: &CudaStreams,
|
||||
) -> u64 {
|
||||
assert_eq!(
|
||||
ct_left.as_ref().d_blocks.lwe_dimension(),
|
||||
ct_right.as_ref().d_blocks.lwe_dimension()
|
||||
);
|
||||
assert_eq!(
|
||||
ct_left.as_ref().d_blocks.lwe_ciphertext_count(),
|
||||
ct_right.as_ref().d_blocks.lwe_ciphertext_count()
|
||||
);
|
||||
let full_prop_mem = match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => get_full_propagate_assign_size_on_gpu(
|
||||
streams,
|
||||
d_bsk.input_lwe_dimension(),
|
||||
d_bsk.glwe_dimension(),
|
||||
d_bsk.polynomial_size(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count(),
|
||||
d_bsk.decomp_base_log(),
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
PBSType::Classical,
|
||||
LweBskGroupingFactor(0),
|
||||
d_bsk.d_ms_noise_reduction_key.as_ref(),
|
||||
),
|
||||
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
|
||||
get_full_propagate_assign_size_on_gpu(
|
||||
streams,
|
||||
d_multibit_bsk.input_lwe_dimension(),
|
||||
d_multibit_bsk.glwe_dimension(),
|
||||
d_multibit_bsk.polynomial_size(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count(),
|
||||
d_multibit_bsk.decomp_base_log(),
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
PBSType::MultiBit,
|
||||
d_multibit_bsk.grouping_factor,
|
||||
None,
|
||||
)
|
||||
}
|
||||
};
|
||||
let actual_full_prop_mem = match (
|
||||
ct_left.block_carries_are_empty(),
|
||||
ct_right.block_carries_are_empty(),
|
||||
) {
|
||||
(true, true) => 0,
|
||||
(true, false) => self.get_ciphertext_size_on_gpu(ct_right) + full_prop_mem,
|
||||
(false, true) => full_prop_mem,
|
||||
(false, false) => self.get_ciphertext_size_on_gpu(ct_right) + full_prop_mem,
|
||||
};
|
||||
|
||||
let lwe_ciphertext_count = ct_left.as_ref().d_blocks.lwe_ciphertext_count();
|
||||
|
||||
let comparison_mem = match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => get_comparison_integer_radix_kb_size_on_gpu(
|
||||
streams,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_bsk.glwe_dimension,
|
||||
d_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count,
|
||||
d_bsk.decomp_base_log,
|
||||
lwe_ciphertext_count.0 as u32,
|
||||
op,
|
||||
T::IS_SIGNED,
|
||||
PBSType::Classical,
|
||||
LweBskGroupingFactor(0),
|
||||
d_bsk.d_ms_noise_reduction_key.as_ref(),
|
||||
),
|
||||
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
|
||||
get_comparison_integer_radix_kb_size_on_gpu(
|
||||
streams,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_multibit_bsk.glwe_dimension,
|
||||
d_multibit_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count,
|
||||
d_multibit_bsk.decomp_base_log,
|
||||
lwe_ciphertext_count.0 as u32,
|
||||
op,
|
||||
T::IS_SIGNED,
|
||||
PBSType::MultiBit,
|
||||
d_multibit_bsk.grouping_factor,
|
||||
None,
|
||||
)
|
||||
}
|
||||
};
|
||||
actual_full_prop_mem.max(comparison_mem)
|
||||
}
|
||||
|
||||
/// # Safety
|
||||
///
|
||||
/// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
|
||||
@@ -936,6 +1051,60 @@ impl CudaServerKey {
|
||||
result
|
||||
}
|
||||
|
||||
pub fn get_eq_size_on_gpu<T: CudaIntegerRadixCiphertext>(
|
||||
&self,
|
||||
ct_left: &T,
|
||||
ct_right: &T,
|
||||
streams: &CudaStreams,
|
||||
) -> u64 {
|
||||
self.get_comparison_size_on_gpu(ct_left, ct_right, ComparisonType::EQ, streams)
|
||||
}
|
||||
|
||||
pub fn get_ne_size_on_gpu<T: CudaIntegerRadixCiphertext>(
|
||||
&self,
|
||||
ct_left: &T,
|
||||
ct_right: &T,
|
||||
streams: &CudaStreams,
|
||||
) -> u64 {
|
||||
self.get_comparison_size_on_gpu(ct_left, ct_right, ComparisonType::NE, streams)
|
||||
}
|
||||
|
||||
pub fn get_gt_size_on_gpu<T: CudaIntegerRadixCiphertext>(
|
||||
&self,
|
||||
ct_left: &T,
|
||||
ct_right: &T,
|
||||
streams: &CudaStreams,
|
||||
) -> u64 {
|
||||
self.get_comparison_size_on_gpu(ct_left, ct_right, ComparisonType::GT, streams)
|
||||
}
|
||||
|
||||
pub fn get_ge_size_on_gpu<T: CudaIntegerRadixCiphertext>(
|
||||
&self,
|
||||
ct_left: &T,
|
||||
ct_right: &T,
|
||||
streams: &CudaStreams,
|
||||
) -> u64 {
|
||||
self.get_comparison_size_on_gpu(ct_left, ct_right, ComparisonType::GE, streams)
|
||||
}
|
||||
|
||||
pub fn get_lt_size_on_gpu<T: CudaIntegerRadixCiphertext>(
|
||||
&self,
|
||||
ct_left: &T,
|
||||
ct_right: &T,
|
||||
streams: &CudaStreams,
|
||||
) -> u64 {
|
||||
self.get_comparison_size_on_gpu(ct_left, ct_right, ComparisonType::LT, streams)
|
||||
}
|
||||
|
||||
pub fn get_le_size_on_gpu<T: CudaIntegerRadixCiphertext>(
|
||||
&self,
|
||||
ct_left: &T,
|
||||
ct_right: &T,
|
||||
streams: &CudaStreams,
|
||||
) -> u64 {
|
||||
self.get_comparison_size_on_gpu(ct_left, ct_right, ComparisonType::LE, streams)
|
||||
}
|
||||
|
||||
/// # Safety
|
||||
///
|
||||
/// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
|
||||
@@ -1221,4 +1390,21 @@ impl CudaServerKey {
|
||||
streams.synchronize();
|
||||
result
|
||||
}
|
||||
pub fn get_max_size_on_gpu<T: CudaIntegerRadixCiphertext>(
|
||||
&self,
|
||||
ct_left: &T,
|
||||
ct_right: &T,
|
||||
streams: &CudaStreams,
|
||||
) -> u64 {
|
||||
self.get_comparison_size_on_gpu(ct_left, ct_right, ComparisonType::MAX, streams)
|
||||
}
|
||||
|
||||
pub fn get_min_size_on_gpu<T: CudaIntegerRadixCiphertext>(
|
||||
&self,
|
||||
ct_left: &T,
|
||||
ct_right: &T,
|
||||
streams: &CudaStreams,
|
||||
) -> u64 {
|
||||
self.get_comparison_size_on_gpu(ct_left, ct_right, ComparisonType::MIN, streams)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -568,6 +568,7 @@ pub(crate) mod test {
|
||||
use crate::integer::gpu::server_key::CudaBootstrappingKey;
|
||||
use crate::integer::gpu::{gen_keys_gpu, CudaServerKey};
|
||||
use crate::integer::{ClientKey, RadixCiphertext};
|
||||
use crate::shortint::client_key::atomic_pattern::AtomicPatternClientKey;
|
||||
use crate::shortint::oprf::create_random_from_seed_modulus_switched;
|
||||
use crate::shortint::parameters::PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128;
|
||||
use rayon::prelude::*;
|
||||
@@ -635,7 +636,11 @@ pub(crate) mod test {
|
||||
sk.ciphertext_modulus,
|
||||
);
|
||||
|
||||
let sk = ck.key.small_lwe_secret_key();
|
||||
let AtomicPatternClientKey::Standard(std_ck) = &ck.key.atomic_pattern else {
|
||||
panic!("Only std AP is supported on GPU")
|
||||
};
|
||||
|
||||
let sk = std_ck.small_lwe_secret_key();
|
||||
let plain_prf_input = decrypt_lwe_ciphertext(&sk, &ct)
|
||||
.0
|
||||
.wrapping_add(1 << (64 - log_input_p - 1))
|
||||
|
||||
@@ -3,6 +3,8 @@ use crate::core_crypto::prelude::LweBskGroupingFactor;
|
||||
use crate::integer::gpu::ciphertext::{CudaIntegerRadixCiphertext, CudaUnsignedRadixCiphertext};
|
||||
use crate::integer::gpu::server_key::CudaBootstrappingKey;
|
||||
use crate::integer::gpu::{
|
||||
get_full_propagate_assign_size_on_gpu, get_rotate_left_integer_radix_kb_size_on_gpu,
|
||||
get_rotate_right_integer_radix_kb_size_on_gpu,
|
||||
unchecked_rotate_left_integer_radix_kb_assign_async,
|
||||
unchecked_rotate_right_integer_radix_kb_assign_async, CudaServerKey, PBSType,
|
||||
};
|
||||
@@ -556,4 +558,226 @@ impl CudaServerKey {
|
||||
unsafe { self.rotate_left_assign_async(ct, rotate, streams) };
|
||||
streams.synchronize();
|
||||
}
|
||||
|
||||
pub fn get_rotate_left_size_on_gpu<T: CudaIntegerRadixCiphertext>(
|
||||
&self,
|
||||
ct_left: &T,
|
||||
ct_right: &CudaUnsignedRadixCiphertext,
|
||||
streams: &CudaStreams,
|
||||
) -> u64 {
|
||||
assert_eq!(
|
||||
ct_left.as_ref().d_blocks.lwe_dimension(),
|
||||
ct_right.as_ref().d_blocks.lwe_dimension()
|
||||
);
|
||||
assert_eq!(
|
||||
ct_left.as_ref().d_blocks.lwe_ciphertext_count(),
|
||||
ct_right.as_ref().d_blocks.lwe_ciphertext_count()
|
||||
);
|
||||
let full_prop_mem = match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => get_full_propagate_assign_size_on_gpu(
|
||||
streams,
|
||||
d_bsk.input_lwe_dimension(),
|
||||
d_bsk.glwe_dimension(),
|
||||
d_bsk.polynomial_size(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count(),
|
||||
d_bsk.decomp_base_log(),
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
PBSType::Classical,
|
||||
LweBskGroupingFactor(0),
|
||||
d_bsk.d_ms_noise_reduction_key.as_ref(),
|
||||
),
|
||||
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
|
||||
get_full_propagate_assign_size_on_gpu(
|
||||
streams,
|
||||
d_multibit_bsk.input_lwe_dimension(),
|
||||
d_multibit_bsk.glwe_dimension(),
|
||||
d_multibit_bsk.polynomial_size(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count(),
|
||||
d_multibit_bsk.decomp_base_log(),
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
PBSType::MultiBit,
|
||||
d_multibit_bsk.grouping_factor,
|
||||
None,
|
||||
)
|
||||
}
|
||||
};
|
||||
let actual_full_prop_mem = match (
|
||||
ct_left.block_carries_are_empty(),
|
||||
ct_right.block_carries_are_empty(),
|
||||
) {
|
||||
(true, true) => 0,
|
||||
(true, false) => self.get_ciphertext_size_on_gpu(ct_right) + full_prop_mem,
|
||||
(false, true) => full_prop_mem,
|
||||
(false, false) => self.get_ciphertext_size_on_gpu(ct_right) + full_prop_mem,
|
||||
};
|
||||
|
||||
let lwe_ciphertext_count = ct_left.as_ref().d_blocks.lwe_ciphertext_count();
|
||||
|
||||
let rotate_mem = match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => get_rotate_left_integer_radix_kb_size_on_gpu(
|
||||
streams,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_bsk.glwe_dimension,
|
||||
d_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count,
|
||||
d_bsk.decomp_base_log,
|
||||
lwe_ciphertext_count.0 as u32,
|
||||
T::IS_SIGNED,
|
||||
PBSType::Classical,
|
||||
LweBskGroupingFactor(0),
|
||||
d_bsk.d_ms_noise_reduction_key.as_ref(),
|
||||
),
|
||||
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
|
||||
get_rotate_left_integer_radix_kb_size_on_gpu(
|
||||
streams,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_multibit_bsk.glwe_dimension,
|
||||
d_multibit_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count,
|
||||
d_multibit_bsk.decomp_base_log,
|
||||
lwe_ciphertext_count.0 as u32,
|
||||
T::IS_SIGNED,
|
||||
PBSType::MultiBit,
|
||||
d_multibit_bsk.grouping_factor,
|
||||
None,
|
||||
)
|
||||
}
|
||||
};
|
||||
actual_full_prop_mem.max(rotate_mem)
|
||||
}
|
||||
|
||||
pub fn get_rotate_right_size_on_gpu<T: CudaIntegerRadixCiphertext>(
|
||||
&self,
|
||||
ct_left: &T,
|
||||
ct_right: &CudaUnsignedRadixCiphertext,
|
||||
streams: &CudaStreams,
|
||||
) -> u64 {
|
||||
assert_eq!(
|
||||
ct_left.as_ref().d_blocks.lwe_dimension(),
|
||||
ct_right.as_ref().d_blocks.lwe_dimension()
|
||||
);
|
||||
assert_eq!(
|
||||
ct_left.as_ref().d_blocks.lwe_ciphertext_count(),
|
||||
ct_right.as_ref().d_blocks.lwe_ciphertext_count()
|
||||
);
|
||||
let full_prop_mem = match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => get_full_propagate_assign_size_on_gpu(
|
||||
streams,
|
||||
d_bsk.input_lwe_dimension(),
|
||||
d_bsk.glwe_dimension(),
|
||||
d_bsk.polynomial_size(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count(),
|
||||
d_bsk.decomp_base_log(),
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
PBSType::Classical,
|
||||
LweBskGroupingFactor(0),
|
||||
d_bsk.d_ms_noise_reduction_key.as_ref(),
|
||||
),
|
||||
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
|
||||
get_full_propagate_assign_size_on_gpu(
|
||||
streams,
|
||||
d_multibit_bsk.input_lwe_dimension(),
|
||||
d_multibit_bsk.glwe_dimension(),
|
||||
d_multibit_bsk.polynomial_size(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count(),
|
||||
d_multibit_bsk.decomp_base_log(),
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
PBSType::MultiBit,
|
||||
d_multibit_bsk.grouping_factor,
|
||||
None,
|
||||
)
|
||||
}
|
||||
};
|
||||
let actual_full_prop_mem = match (
|
||||
ct_left.block_carries_are_empty(),
|
||||
ct_right.block_carries_are_empty(),
|
||||
) {
|
||||
(true, true) => 0,
|
||||
(true, false) => self.get_ciphertext_size_on_gpu(ct_right) + full_prop_mem,
|
||||
(false, true) => full_prop_mem,
|
||||
(false, false) => self.get_ciphertext_size_on_gpu(ct_right) + full_prop_mem,
|
||||
};
|
||||
|
||||
let lwe_ciphertext_count = ct_left.as_ref().d_blocks.lwe_ciphertext_count();
|
||||
|
||||
let rotate_mem = match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => get_rotate_right_integer_radix_kb_size_on_gpu(
|
||||
streams,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_bsk.glwe_dimension,
|
||||
d_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count,
|
||||
d_bsk.decomp_base_log,
|
||||
lwe_ciphertext_count.0 as u32,
|
||||
T::IS_SIGNED,
|
||||
PBSType::Classical,
|
||||
LweBskGroupingFactor(0),
|
||||
d_bsk.d_ms_noise_reduction_key.as_ref(),
|
||||
),
|
||||
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
|
||||
get_rotate_right_integer_radix_kb_size_on_gpu(
|
||||
streams,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_multibit_bsk.glwe_dimension,
|
||||
d_multibit_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count,
|
||||
d_multibit_bsk.decomp_base_log,
|
||||
lwe_ciphertext_count.0 as u32,
|
||||
T::IS_SIGNED,
|
||||
PBSType::MultiBit,
|
||||
d_multibit_bsk.grouping_factor,
|
||||
None,
|
||||
)
|
||||
}
|
||||
};
|
||||
actual_full_prop_mem.max(rotate_mem)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1057,6 +1057,54 @@ impl CudaServerKey {
|
||||
result
|
||||
}
|
||||
|
||||
pub fn get_scalar_eq_size_on_gpu<T: CudaIntegerRadixCiphertext>(
|
||||
&self,
|
||||
ct_left: &T,
|
||||
streams: &CudaStreams,
|
||||
) -> u64 {
|
||||
self.get_comparison_size_on_gpu(ct_left, ct_left, ComparisonType::EQ, streams)
|
||||
}
|
||||
|
||||
pub fn get_scalar_ne_size_on_gpu<T: CudaIntegerRadixCiphertext>(
|
||||
&self,
|
||||
ct_left: &T,
|
||||
streams: &CudaStreams,
|
||||
) -> u64 {
|
||||
self.get_comparison_size_on_gpu(ct_left, ct_left, ComparisonType::NE, streams)
|
||||
}
|
||||
|
||||
pub fn get_scalar_gt_size_on_gpu<T: CudaIntegerRadixCiphertext>(
|
||||
&self,
|
||||
ct_left: &T,
|
||||
streams: &CudaStreams,
|
||||
) -> u64 {
|
||||
self.get_comparison_size_on_gpu(ct_left, ct_left, ComparisonType::GT, streams)
|
||||
}
|
||||
|
||||
pub fn get_scalar_ge_size_on_gpu<T: CudaIntegerRadixCiphertext>(
|
||||
&self,
|
||||
ct_left: &T,
|
||||
streams: &CudaStreams,
|
||||
) -> u64 {
|
||||
self.get_comparison_size_on_gpu(ct_left, ct_left, ComparisonType::GE, streams)
|
||||
}
|
||||
|
||||
pub fn get_scalar_lt_size_on_gpu<T: CudaIntegerRadixCiphertext>(
|
||||
&self,
|
||||
ct_left: &T,
|
||||
streams: &CudaStreams,
|
||||
) -> u64 {
|
||||
self.get_comparison_size_on_gpu(ct_left, ct_left, ComparisonType::LT, streams)
|
||||
}
|
||||
|
||||
pub fn get_scalar_le_size_on_gpu<T: CudaIntegerRadixCiphertext>(
|
||||
&self,
|
||||
ct_left: &T,
|
||||
streams: &CudaStreams,
|
||||
) -> u64 {
|
||||
self.get_comparison_size_on_gpu(ct_left, ct_left, ComparisonType::LE, streams)
|
||||
}
|
||||
|
||||
/// # Safety
|
||||
///
|
||||
/// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
|
||||
@@ -1192,4 +1240,19 @@ impl CudaServerKey {
|
||||
streams.synchronize();
|
||||
result
|
||||
}
|
||||
pub fn get_scalar_max_size_on_gpu<T: CudaIntegerRadixCiphertext>(
|
||||
&self,
|
||||
ct_left: &T,
|
||||
streams: &CudaStreams,
|
||||
) -> u64 {
|
||||
self.get_comparison_size_on_gpu(ct_left, ct_left, ComparisonType::MAX, streams)
|
||||
}
|
||||
|
||||
pub fn get_scalar_min_size_on_gpu<T: CudaIntegerRadixCiphertext>(
|
||||
&self,
|
||||
ct_left: &T,
|
||||
streams: &CudaStreams,
|
||||
) -> u64 {
|
||||
self.get_comparison_size_on_gpu(ct_left, ct_left, ComparisonType::MIN, streams)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,6 +3,8 @@ use crate::core_crypto::prelude::{CastFrom, LweBskGroupingFactor};
|
||||
use crate::integer::gpu::ciphertext::CudaIntegerRadixCiphertext;
|
||||
use crate::integer::gpu::server_key::CudaBootstrappingKey;
|
||||
use crate::integer::gpu::{
|
||||
get_full_propagate_assign_size_on_gpu, get_scalar_rotate_left_integer_radix_kb_size_on_gpu,
|
||||
get_scalar_rotate_right_integer_radix_kb_size_on_gpu,
|
||||
unchecked_scalar_rotate_left_integer_radix_kb_assign_async,
|
||||
unchecked_scalar_rotate_right_integer_radix_kb_assign_async, CudaServerKey, PBSType,
|
||||
};
|
||||
@@ -275,4 +277,193 @@ impl CudaServerKey {
|
||||
self.scalar_rotate_right_assign(&mut result, shift, stream);
|
||||
result
|
||||
}
|
||||
pub fn get_scalar_rotate_left_size_on_gpu<T>(&self, ct: &T, streams: &CudaStreams) -> u64
|
||||
where
|
||||
T: CudaIntegerRadixCiphertext,
|
||||
{
|
||||
let lwe_ciphertext_count = ct.as_ref().d_blocks.lwe_ciphertext_count();
|
||||
|
||||
let full_prop_mem = if ct.block_carries_are_empty() {
|
||||
0
|
||||
} else {
|
||||
match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => get_full_propagate_assign_size_on_gpu(
|
||||
streams,
|
||||
d_bsk.input_lwe_dimension(),
|
||||
d_bsk.glwe_dimension(),
|
||||
d_bsk.polynomial_size(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count(),
|
||||
d_bsk.decomp_base_log(),
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
PBSType::Classical,
|
||||
LweBskGroupingFactor(0),
|
||||
d_bsk.d_ms_noise_reduction_key.as_ref(),
|
||||
),
|
||||
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
|
||||
get_full_propagate_assign_size_on_gpu(
|
||||
streams,
|
||||
d_multibit_bsk.input_lwe_dimension(),
|
||||
d_multibit_bsk.glwe_dimension(),
|
||||
d_multibit_bsk.polynomial_size(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count(),
|
||||
d_multibit_bsk.decomp_base_log(),
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
PBSType::MultiBit,
|
||||
d_multibit_bsk.grouping_factor,
|
||||
None,
|
||||
)
|
||||
}
|
||||
}
|
||||
};
|
||||
let scalar_shift_mem = match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => {
|
||||
get_scalar_rotate_left_integer_radix_kb_size_on_gpu(
|
||||
streams,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_bsk.glwe_dimension,
|
||||
d_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count,
|
||||
d_bsk.decomp_base_log,
|
||||
lwe_ciphertext_count.0 as u32,
|
||||
PBSType::Classical,
|
||||
LweBskGroupingFactor(0),
|
||||
d_bsk.d_ms_noise_reduction_key.as_ref(),
|
||||
)
|
||||
}
|
||||
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
|
||||
get_scalar_rotate_left_integer_radix_kb_size_on_gpu(
|
||||
streams,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_multibit_bsk.glwe_dimension,
|
||||
d_multibit_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count,
|
||||
d_multibit_bsk.decomp_base_log,
|
||||
lwe_ciphertext_count.0 as u32,
|
||||
PBSType::MultiBit,
|
||||
d_multibit_bsk.grouping_factor,
|
||||
None,
|
||||
)
|
||||
}
|
||||
};
|
||||
full_prop_mem.max(scalar_shift_mem)
|
||||
}
|
||||
|
||||
pub fn get_scalar_rotate_right_size_on_gpu<T>(&self, ct: &T, streams: &CudaStreams) -> u64
|
||||
where
|
||||
T: CudaIntegerRadixCiphertext,
|
||||
{
|
||||
let lwe_ciphertext_count = ct.as_ref().d_blocks.lwe_ciphertext_count();
|
||||
|
||||
let full_prop_mem = if ct.block_carries_are_empty() {
|
||||
0
|
||||
} else {
|
||||
match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => get_full_propagate_assign_size_on_gpu(
|
||||
streams,
|
||||
d_bsk.input_lwe_dimension(),
|
||||
d_bsk.glwe_dimension(),
|
||||
d_bsk.polynomial_size(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count(),
|
||||
d_bsk.decomp_base_log(),
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
PBSType::Classical,
|
||||
LweBskGroupingFactor(0),
|
||||
d_bsk.d_ms_noise_reduction_key.as_ref(),
|
||||
),
|
||||
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
|
||||
get_full_propagate_assign_size_on_gpu(
|
||||
streams,
|
||||
d_multibit_bsk.input_lwe_dimension(),
|
||||
d_multibit_bsk.glwe_dimension(),
|
||||
d_multibit_bsk.polynomial_size(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count(),
|
||||
d_multibit_bsk.decomp_base_log(),
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
PBSType::MultiBit,
|
||||
d_multibit_bsk.grouping_factor,
|
||||
None,
|
||||
)
|
||||
}
|
||||
}
|
||||
};
|
||||
let scalar_shift_mem = match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => {
|
||||
get_scalar_rotate_right_integer_radix_kb_size_on_gpu(
|
||||
streams,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_bsk.glwe_dimension,
|
||||
d_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count,
|
||||
d_bsk.decomp_base_log,
|
||||
lwe_ciphertext_count.0 as u32,
|
||||
PBSType::Classical,
|
||||
LweBskGroupingFactor(0),
|
||||
d_bsk.d_ms_noise_reduction_key.as_ref(),
|
||||
)
|
||||
}
|
||||
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
|
||||
get_scalar_rotate_right_integer_radix_kb_size_on_gpu(
|
||||
streams,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_multibit_bsk.glwe_dimension,
|
||||
d_multibit_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count,
|
||||
d_multibit_bsk.decomp_base_log,
|
||||
lwe_ciphertext_count.0 as u32,
|
||||
PBSType::MultiBit,
|
||||
d_multibit_bsk.grouping_factor,
|
||||
None,
|
||||
)
|
||||
}
|
||||
};
|
||||
full_prop_mem.max(scalar_shift_mem)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,6 +3,10 @@ use crate::core_crypto::prelude::{CastFrom, LweBskGroupingFactor};
|
||||
use crate::integer::gpu::ciphertext::CudaIntegerRadixCiphertext;
|
||||
use crate::integer::gpu::server_key::CudaBootstrappingKey;
|
||||
use crate::integer::gpu::{
|
||||
get_full_propagate_assign_size_on_gpu,
|
||||
get_scalar_arithmetic_right_shift_integer_radix_kb_size_on_gpu,
|
||||
get_scalar_left_shift_integer_radix_kb_size_on_gpu,
|
||||
get_scalar_logical_right_shift_integer_radix_kb_size_on_gpu,
|
||||
unchecked_scalar_arithmetic_right_shift_integer_radix_kb_assign_async,
|
||||
unchecked_scalar_left_shift_integer_radix_kb_assign_async,
|
||||
unchecked_scalar_logical_right_shift_integer_radix_kb_assign_async, CudaServerKey, PBSType,
|
||||
@@ -647,4 +651,245 @@ impl CudaServerKey {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_scalar_left_shift_size_on_gpu<T>(&self, ct: &T, streams: &CudaStreams) -> u64
|
||||
where
|
||||
T: CudaIntegerRadixCiphertext,
|
||||
{
|
||||
let lwe_ciphertext_count = ct.as_ref().d_blocks.lwe_ciphertext_count();
|
||||
|
||||
let full_prop_mem = if ct.block_carries_are_empty() {
|
||||
0
|
||||
} else {
|
||||
match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => get_full_propagate_assign_size_on_gpu(
|
||||
streams,
|
||||
d_bsk.input_lwe_dimension(),
|
||||
d_bsk.glwe_dimension(),
|
||||
d_bsk.polynomial_size(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count(),
|
||||
d_bsk.decomp_base_log(),
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
PBSType::Classical,
|
||||
LweBskGroupingFactor(0),
|
||||
d_bsk.d_ms_noise_reduction_key.as_ref(),
|
||||
),
|
||||
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
|
||||
get_full_propagate_assign_size_on_gpu(
|
||||
streams,
|
||||
d_multibit_bsk.input_lwe_dimension(),
|
||||
d_multibit_bsk.glwe_dimension(),
|
||||
d_multibit_bsk.polynomial_size(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count(),
|
||||
d_multibit_bsk.decomp_base_log(),
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
PBSType::MultiBit,
|
||||
d_multibit_bsk.grouping_factor,
|
||||
None,
|
||||
)
|
||||
}
|
||||
}
|
||||
};
|
||||
let scalar_shift_mem = match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => {
|
||||
get_scalar_left_shift_integer_radix_kb_size_on_gpu(
|
||||
streams,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_bsk.glwe_dimension,
|
||||
d_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count,
|
||||
d_bsk.decomp_base_log,
|
||||
lwe_ciphertext_count.0 as u32,
|
||||
PBSType::Classical,
|
||||
LweBskGroupingFactor(0),
|
||||
d_bsk.d_ms_noise_reduction_key.as_ref(),
|
||||
)
|
||||
}
|
||||
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
|
||||
get_scalar_left_shift_integer_radix_kb_size_on_gpu(
|
||||
streams,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_multibit_bsk.glwe_dimension,
|
||||
d_multibit_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count,
|
||||
d_multibit_bsk.decomp_base_log,
|
||||
lwe_ciphertext_count.0 as u32,
|
||||
PBSType::MultiBit,
|
||||
d_multibit_bsk.grouping_factor,
|
||||
None,
|
||||
)
|
||||
}
|
||||
};
|
||||
full_prop_mem.max(scalar_shift_mem)
|
||||
}
|
||||
|
||||
pub fn get_scalar_right_shift_size_on_gpu<T>(&self, ct: &T, streams: &CudaStreams) -> u64
|
||||
where
|
||||
T: CudaIntegerRadixCiphertext,
|
||||
{
|
||||
let lwe_ciphertext_count = ct.as_ref().d_blocks.lwe_ciphertext_count();
|
||||
|
||||
let full_prop_mem = if ct.block_carries_are_empty() {
|
||||
0
|
||||
} else {
|
||||
match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => get_full_propagate_assign_size_on_gpu(
|
||||
streams,
|
||||
d_bsk.input_lwe_dimension(),
|
||||
d_bsk.glwe_dimension(),
|
||||
d_bsk.polynomial_size(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count(),
|
||||
d_bsk.decomp_base_log(),
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
PBSType::Classical,
|
||||
LweBskGroupingFactor(0),
|
||||
d_bsk.d_ms_noise_reduction_key.as_ref(),
|
||||
),
|
||||
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
|
||||
get_full_propagate_assign_size_on_gpu(
|
||||
streams,
|
||||
d_multibit_bsk.input_lwe_dimension(),
|
||||
d_multibit_bsk.glwe_dimension(),
|
||||
d_multibit_bsk.polynomial_size(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count(),
|
||||
d_multibit_bsk.decomp_base_log(),
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
PBSType::MultiBit,
|
||||
d_multibit_bsk.grouping_factor,
|
||||
None,
|
||||
)
|
||||
}
|
||||
}
|
||||
};
|
||||
let scalar_shift_mem = if T::IS_SIGNED {
|
||||
match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => {
|
||||
get_scalar_arithmetic_right_shift_integer_radix_kb_size_on_gpu(
|
||||
streams,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_bsk.glwe_dimension,
|
||||
d_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count,
|
||||
d_bsk.decomp_base_log,
|
||||
lwe_ciphertext_count.0 as u32,
|
||||
PBSType::Classical,
|
||||
LweBskGroupingFactor(0),
|
||||
d_bsk.d_ms_noise_reduction_key.as_ref(),
|
||||
)
|
||||
}
|
||||
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
|
||||
get_scalar_arithmetic_right_shift_integer_radix_kb_size_on_gpu(
|
||||
streams,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_multibit_bsk.glwe_dimension,
|
||||
d_multibit_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count,
|
||||
d_multibit_bsk.decomp_base_log,
|
||||
lwe_ciphertext_count.0 as u32,
|
||||
PBSType::MultiBit,
|
||||
d_multibit_bsk.grouping_factor,
|
||||
None,
|
||||
)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => {
|
||||
get_scalar_logical_right_shift_integer_radix_kb_size_on_gpu(
|
||||
streams,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_bsk.glwe_dimension,
|
||||
d_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count,
|
||||
d_bsk.decomp_base_log,
|
||||
lwe_ciphertext_count.0 as u32,
|
||||
PBSType::Classical,
|
||||
LweBskGroupingFactor(0),
|
||||
d_bsk.d_ms_noise_reduction_key.as_ref(),
|
||||
)
|
||||
}
|
||||
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
|
||||
get_scalar_logical_right_shift_integer_radix_kb_size_on_gpu(
|
||||
streams,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_multibit_bsk.glwe_dimension,
|
||||
d_multibit_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count,
|
||||
d_multibit_bsk.decomp_base_log,
|
||||
lwe_ciphertext_count.0 as u32,
|
||||
PBSType::MultiBit,
|
||||
d_multibit_bsk.grouping_factor,
|
||||
None,
|
||||
)
|
||||
}
|
||||
}
|
||||
};
|
||||
full_prop_mem.max(scalar_shift_mem)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,6 +3,8 @@ use crate::core_crypto::prelude::LweBskGroupingFactor;
|
||||
use crate::integer::gpu::ciphertext::{CudaIntegerRadixCiphertext, CudaUnsignedRadixCiphertext};
|
||||
use crate::integer::gpu::server_key::CudaBootstrappingKey;
|
||||
use crate::integer::gpu::{
|
||||
get_full_propagate_assign_size_on_gpu, get_left_shift_integer_radix_kb_size_on_gpu,
|
||||
get_right_shift_integer_radix_kb_size_on_gpu,
|
||||
unchecked_left_shift_integer_radix_kb_assign_async,
|
||||
unchecked_right_shift_integer_radix_kb_assign_async, CudaServerKey, PBSType,
|
||||
};
|
||||
@@ -551,4 +553,226 @@ impl CudaServerKey {
|
||||
unsafe { self.left_shift_assign_async(ct, shift, streams) };
|
||||
streams.synchronize();
|
||||
}
|
||||
|
||||
pub fn get_left_shift_size_on_gpu<T: CudaIntegerRadixCiphertext>(
|
||||
&self,
|
||||
ct_left: &T,
|
||||
ct_right: &CudaUnsignedRadixCiphertext,
|
||||
streams: &CudaStreams,
|
||||
) -> u64 {
|
||||
assert_eq!(
|
||||
ct_left.as_ref().d_blocks.lwe_dimension(),
|
||||
ct_right.as_ref().d_blocks.lwe_dimension()
|
||||
);
|
||||
assert_eq!(
|
||||
ct_left.as_ref().d_blocks.lwe_ciphertext_count(),
|
||||
ct_right.as_ref().d_blocks.lwe_ciphertext_count()
|
||||
);
|
||||
let full_prop_mem = match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => get_full_propagate_assign_size_on_gpu(
|
||||
streams,
|
||||
d_bsk.input_lwe_dimension(),
|
||||
d_bsk.glwe_dimension(),
|
||||
d_bsk.polynomial_size(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count(),
|
||||
d_bsk.decomp_base_log(),
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
PBSType::Classical,
|
||||
LweBskGroupingFactor(0),
|
||||
d_bsk.d_ms_noise_reduction_key.as_ref(),
|
||||
),
|
||||
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
|
||||
get_full_propagate_assign_size_on_gpu(
|
||||
streams,
|
||||
d_multibit_bsk.input_lwe_dimension(),
|
||||
d_multibit_bsk.glwe_dimension(),
|
||||
d_multibit_bsk.polynomial_size(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count(),
|
||||
d_multibit_bsk.decomp_base_log(),
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
PBSType::MultiBit,
|
||||
d_multibit_bsk.grouping_factor,
|
||||
None,
|
||||
)
|
||||
}
|
||||
};
|
||||
let actual_full_prop_mem = match (
|
||||
ct_left.block_carries_are_empty(),
|
||||
ct_right.block_carries_are_empty(),
|
||||
) {
|
||||
(true, true) => 0,
|
||||
(true, false) => self.get_ciphertext_size_on_gpu(ct_right) + full_prop_mem,
|
||||
(false, true) => full_prop_mem,
|
||||
(false, false) => self.get_ciphertext_size_on_gpu(ct_right) + full_prop_mem,
|
||||
};
|
||||
|
||||
let lwe_ciphertext_count = ct_left.as_ref().d_blocks.lwe_ciphertext_count();
|
||||
|
||||
let shift_mem = match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => get_left_shift_integer_radix_kb_size_on_gpu(
|
||||
streams,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_bsk.glwe_dimension,
|
||||
d_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count,
|
||||
d_bsk.decomp_base_log,
|
||||
lwe_ciphertext_count.0 as u32,
|
||||
T::IS_SIGNED,
|
||||
PBSType::Classical,
|
||||
LweBskGroupingFactor(0),
|
||||
d_bsk.d_ms_noise_reduction_key.as_ref(),
|
||||
),
|
||||
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
|
||||
get_left_shift_integer_radix_kb_size_on_gpu(
|
||||
streams,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_multibit_bsk.glwe_dimension,
|
||||
d_multibit_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count,
|
||||
d_multibit_bsk.decomp_base_log,
|
||||
lwe_ciphertext_count.0 as u32,
|
||||
T::IS_SIGNED,
|
||||
PBSType::MultiBit,
|
||||
d_multibit_bsk.grouping_factor,
|
||||
None,
|
||||
)
|
||||
}
|
||||
};
|
||||
actual_full_prop_mem.max(shift_mem)
|
||||
}
|
||||
|
||||
pub fn get_right_shift_size_on_gpu<T: CudaIntegerRadixCiphertext>(
|
||||
&self,
|
||||
ct_left: &T,
|
||||
ct_right: &CudaUnsignedRadixCiphertext,
|
||||
streams: &CudaStreams,
|
||||
) -> u64 {
|
||||
assert_eq!(
|
||||
ct_left.as_ref().d_blocks.lwe_dimension(),
|
||||
ct_right.as_ref().d_blocks.lwe_dimension()
|
||||
);
|
||||
assert_eq!(
|
||||
ct_left.as_ref().d_blocks.lwe_ciphertext_count(),
|
||||
ct_right.as_ref().d_blocks.lwe_ciphertext_count()
|
||||
);
|
||||
let full_prop_mem = match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => get_full_propagate_assign_size_on_gpu(
|
||||
streams,
|
||||
d_bsk.input_lwe_dimension(),
|
||||
d_bsk.glwe_dimension(),
|
||||
d_bsk.polynomial_size(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count(),
|
||||
d_bsk.decomp_base_log(),
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
PBSType::Classical,
|
||||
LweBskGroupingFactor(0),
|
||||
d_bsk.d_ms_noise_reduction_key.as_ref(),
|
||||
),
|
||||
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
|
||||
get_full_propagate_assign_size_on_gpu(
|
||||
streams,
|
||||
d_multibit_bsk.input_lwe_dimension(),
|
||||
d_multibit_bsk.glwe_dimension(),
|
||||
d_multibit_bsk.polynomial_size(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count(),
|
||||
d_multibit_bsk.decomp_base_log(),
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
PBSType::MultiBit,
|
||||
d_multibit_bsk.grouping_factor,
|
||||
None,
|
||||
)
|
||||
}
|
||||
};
|
||||
let actual_full_prop_mem = match (
|
||||
ct_left.block_carries_are_empty(),
|
||||
ct_right.block_carries_are_empty(),
|
||||
) {
|
||||
(true, true) => 0,
|
||||
(true, false) => self.get_ciphertext_size_on_gpu(ct_right) + full_prop_mem,
|
||||
(false, true) => full_prop_mem,
|
||||
(false, false) => self.get_ciphertext_size_on_gpu(ct_right) + full_prop_mem,
|
||||
};
|
||||
|
||||
let lwe_ciphertext_count = ct_left.as_ref().d_blocks.lwe_ciphertext_count();
|
||||
|
||||
let shift_mem = match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => get_right_shift_integer_radix_kb_size_on_gpu(
|
||||
streams,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_bsk.glwe_dimension,
|
||||
d_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count,
|
||||
d_bsk.decomp_base_log,
|
||||
lwe_ciphertext_count.0 as u32,
|
||||
T::IS_SIGNED,
|
||||
PBSType::Classical,
|
||||
LweBskGroupingFactor(0),
|
||||
d_bsk.d_ms_noise_reduction_key.as_ref(),
|
||||
),
|
||||
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
|
||||
get_right_shift_integer_radix_kb_size_on_gpu(
|
||||
streams,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_multibit_bsk.glwe_dimension,
|
||||
d_multibit_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count,
|
||||
d_multibit_bsk.decomp_base_log,
|
||||
lwe_ciphertext_count.0 as u32,
|
||||
T::IS_SIGNED,
|
||||
PBSType::MultiBit,
|
||||
d_multibit_bsk.grouping_factor,
|
||||
None,
|
||||
)
|
||||
}
|
||||
};
|
||||
actual_full_prop_mem.max(shift_mem)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -17,39 +17,39 @@ where
|
||||
P: Into<TestParameters> + Clone,
|
||||
{
|
||||
// Binary Ops Executors
|
||||
let add_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::add);
|
||||
let sub_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::sub);
|
||||
//let add_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::add);
|
||||
//let sub_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::sub);
|
||||
let bitwise_and_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::bitand);
|
||||
let bitwise_or_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::bitor);
|
||||
let bitwise_xor_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::bitxor);
|
||||
let mul_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::mul);
|
||||
let rotate_left_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::rotate_left);
|
||||
let left_shift_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::left_shift);
|
||||
let rotate_right_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::rotate_right);
|
||||
let right_shift_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::right_shift);
|
||||
let max_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::max);
|
||||
let min_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::min);
|
||||
//let rotate_left_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::rotate_left);
|
||||
//let left_shift_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::left_shift);
|
||||
//let rotate_right_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::rotate_right);
|
||||
//let right_shift_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::right_shift);
|
||||
//let max_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::max);
|
||||
//let min_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::min);
|
||||
|
||||
// Binary Ops Clear functions
|
||||
let clear_add = |x, y| x + y;
|
||||
let clear_sub = |x, y| x - y;
|
||||
//let clear_add = |x, y| x + y;
|
||||
//let clear_sub = |x, y| x - y;
|
||||
let clear_bitwise_and = |x, y| x & y;
|
||||
let clear_bitwise_or = |x, y| x | y;
|
||||
let clear_bitwise_xor = |x, y| x ^ y;
|
||||
let clear_mul = |x, y| x * y;
|
||||
// Warning this rotate definition only works with 64-bit ciphertexts
|
||||
let clear_rotate_left = |x: u64, y: u64| x.rotate_left(y as u32);
|
||||
let clear_left_shift = |x, y| x << y;
|
||||
// Warning this rotate definition only works with 64-bit ciphertexts
|
||||
let clear_rotate_right = |x: u64, y: u64| x.rotate_right(y as u32);
|
||||
let clear_right_shift = |x, y| x >> y;
|
||||
let clear_max = |x: u64, y: u64| max(x, y);
|
||||
let clear_min = |x: u64, y: u64| min(x, y);
|
||||
//let clear_rotate_left = |x: u64, y: u64| x.rotate_left(y as u32);
|
||||
//let clear_left_shift = |x, y| x << y;
|
||||
//// Warning this rotate definition only works with 64-bit ciphertexts
|
||||
//let clear_rotate_right = |x: u64, y: u64| x.rotate_right(y as u32);
|
||||
//let clear_right_shift = |x, y| x >> y;
|
||||
//let clear_max = |x: u64, y: u64| max(x, y);
|
||||
//let clear_min = |x: u64, y: u64| min(x, y);
|
||||
|
||||
#[allow(clippy::type_complexity)]
|
||||
let mut binary_ops: Vec<(BinaryOpExecutor, &dyn Fn(u64, u64) -> u64, String)> = vec![
|
||||
(Box::new(add_executor), &clear_add, "add".to_string()),
|
||||
(Box::new(sub_executor), &clear_sub, "sub".to_string()),
|
||||
//(Box::new(add_executor), &clear_add, "add".to_string()),
|
||||
//(Box::new(sub_executor), &clear_sub, "sub".to_string()),
|
||||
(
|
||||
Box::new(bitwise_and_executor),
|
||||
&clear_bitwise_and,
|
||||
@@ -66,28 +66,28 @@ where
|
||||
"bitxor".to_string(),
|
||||
),
|
||||
(Box::new(mul_executor), &clear_mul, "mul".to_string()),
|
||||
(
|
||||
Box::new(rotate_left_executor),
|
||||
&clear_rotate_left,
|
||||
"rotate left".to_string(),
|
||||
),
|
||||
(
|
||||
Box::new(left_shift_executor),
|
||||
&clear_left_shift,
|
||||
"left shift".to_string(),
|
||||
),
|
||||
(
|
||||
Box::new(rotate_right_executor),
|
||||
&clear_rotate_right,
|
||||
"rotate right".to_string(),
|
||||
),
|
||||
(
|
||||
Box::new(right_shift_executor),
|
||||
&clear_right_shift,
|
||||
"right shift".to_string(),
|
||||
),
|
||||
(Box::new(max_executor), &clear_max, "max".to_string()),
|
||||
(Box::new(min_executor), &clear_min, "min".to_string()),
|
||||
//(
|
||||
// Box::new(rotate_left_executor),
|
||||
// &clear_rotate_left,
|
||||
// "rotate left".to_string(),
|
||||
//),
|
||||
//(
|
||||
// Box::new(left_shift_executor),
|
||||
// &clear_left_shift,
|
||||
// "left shift".to_string(),
|
||||
//),
|
||||
//(
|
||||
// Box::new(rotate_right_executor),
|
||||
// &clear_rotate_right,
|
||||
// "rotate right".to_string(),
|
||||
//),
|
||||
//(
|
||||
// Box::new(right_shift_executor),
|
||||
// &clear_right_shift,
|
||||
// "right shift".to_string(),
|
||||
//),
|
||||
//(Box::new(max_executor), &clear_max, "max".to_string()),
|
||||
//(Box::new(min_executor), &clear_min, "min".to_string()),
|
||||
];
|
||||
|
||||
// Unary Ops Executors
|
||||
@@ -115,8 +115,8 @@ where
|
||||
];
|
||||
|
||||
// Scalar binary Ops Executors
|
||||
let scalar_add_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_add);
|
||||
let scalar_sub_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_sub);
|
||||
//let scalar_add_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_add);
|
||||
//let scalar_sub_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_sub);
|
||||
let scalar_bitwise_and_executor =
|
||||
GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_bitand);
|
||||
let scalar_bitwise_or_executor =
|
||||
@@ -124,27 +124,27 @@ where
|
||||
let scalar_bitwise_xor_executor =
|
||||
GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_bitxor);
|
||||
let scalar_mul_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_mul);
|
||||
let scalar_rotate_left_executor =
|
||||
GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_rotate_left);
|
||||
let scalar_left_shift_executor =
|
||||
GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_left_shift);
|
||||
let scalar_rotate_right_executor =
|
||||
GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_rotate_right);
|
||||
let scalar_right_shift_executor =
|
||||
GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_right_shift);
|
||||
//let scalar_rotate_left_executor =
|
||||
// GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_rotate_left);
|
||||
//let scalar_left_shift_executor =
|
||||
// GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_left_shift);
|
||||
//let scalar_rotate_right_executor =
|
||||
// GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_rotate_right);
|
||||
//let scalar_right_shift_executor =
|
||||
// GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_right_shift);
|
||||
|
||||
#[allow(clippy::type_complexity)]
|
||||
let mut scalar_binary_ops: Vec<(ScalarBinaryOpExecutor, &dyn Fn(u64, u64) -> u64, String)> = vec![
|
||||
(
|
||||
Box::new(scalar_add_executor),
|
||||
&clear_add,
|
||||
"scalar add".to_string(),
|
||||
),
|
||||
(
|
||||
Box::new(scalar_sub_executor),
|
||||
&clear_sub,
|
||||
"scalar sub".to_string(),
|
||||
),
|
||||
//(
|
||||
// Box::new(scalar_add_executor),
|
||||
// &clear_add,
|
||||
// "scalar add".to_string(),
|
||||
//),
|
||||
//(
|
||||
// Box::new(scalar_sub_executor),
|
||||
// &clear_sub,
|
||||
// "scalar sub".to_string(),
|
||||
//),
|
||||
(
|
||||
Box::new(scalar_bitwise_and_executor),
|
||||
&clear_bitwise_and,
|
||||
@@ -165,26 +165,26 @@ where
|
||||
&clear_mul,
|
||||
"scalar mul".to_string(),
|
||||
),
|
||||
(
|
||||
Box::new(scalar_rotate_left_executor),
|
||||
&clear_rotate_left,
|
||||
"scalar rotate left".to_string(),
|
||||
),
|
||||
(
|
||||
Box::new(scalar_left_shift_executor),
|
||||
&clear_left_shift,
|
||||
"scalar left shift".to_string(),
|
||||
),
|
||||
(
|
||||
Box::new(scalar_rotate_right_executor),
|
||||
&clear_rotate_right,
|
||||
"scalar rotate right".to_string(),
|
||||
),
|
||||
(
|
||||
Box::new(scalar_right_shift_executor),
|
||||
&clear_right_shift,
|
||||
"scalar right shift".to_string(),
|
||||
),
|
||||
//(
|
||||
// Box::new(scalar_rotate_left_executor),
|
||||
// &clear_rotate_left,
|
||||
// "scalar rotate left".to_string(),
|
||||
//),
|
||||
//(
|
||||
// Box::new(scalar_left_shift_executor),
|
||||
// &clear_left_shift,
|
||||
// "scalar left shift".to_string(),
|
||||
//),
|
||||
//(
|
||||
// Box::new(scalar_rotate_right_executor),
|
||||
// &clear_rotate_right,
|
||||
// "scalar rotate right".to_string(),
|
||||
//),
|
||||
//(
|
||||
// Box::new(scalar_right_shift_executor),
|
||||
// &clear_right_shift,
|
||||
// "scalar right shift".to_string(),
|
||||
//),
|
||||
];
|
||||
|
||||
// Overflowing Ops Executors
|
||||
@@ -249,37 +249,37 @@ where
|
||||
|
||||
// Comparison Ops Executors
|
||||
let gt_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::gt);
|
||||
let ge_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::ge);
|
||||
let lt_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::lt);
|
||||
let le_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::le);
|
||||
let eq_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::eq);
|
||||
let ne_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::ne);
|
||||
//let ge_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::ge);
|
||||
//let lt_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::lt);
|
||||
//let le_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::le);
|
||||
//let eq_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::eq);
|
||||
//let ne_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::ne);
|
||||
|
||||
// Comparison Ops Clear functions
|
||||
let clear_gt = |x: u64, y: u64| -> bool { x > y };
|
||||
let clear_ge = |x: u64, y: u64| -> bool { x >= y };
|
||||
let clear_lt = |x: u64, y: u64| -> bool { x < y };
|
||||
let clear_le = |x: u64, y: u64| -> bool { x <= y };
|
||||
let clear_eq = |x: u64, y: u64| -> bool { x == y };
|
||||
let clear_ne = |x: u64, y: u64| -> bool { x != y };
|
||||
//let clear_ge = |x: u64, y: u64| -> bool { x >= y };
|
||||
//let clear_lt = |x: u64, y: u64| -> bool { x < y };
|
||||
//let clear_le = |x: u64, y: u64| -> bool { x <= y };
|
||||
//let clear_eq = |x: u64, y: u64| -> bool { x == y };
|
||||
//let clear_ne = |x: u64, y: u64| -> bool { x != y };
|
||||
|
||||
#[allow(clippy::type_complexity)]
|
||||
let mut comparison_ops: Vec<(ComparisonOpExecutor, &dyn Fn(u64, u64) -> bool, String)> = vec![
|
||||
(Box::new(gt_executor), &clear_gt, "gt".to_string()),
|
||||
(Box::new(ge_executor), &clear_ge, "ge".to_string()),
|
||||
(Box::new(lt_executor), &clear_lt, "lt".to_string()),
|
||||
(Box::new(le_executor), &clear_le, "le".to_string()),
|
||||
(Box::new(eq_executor), &clear_eq, "eq".to_string()),
|
||||
(Box::new(ne_executor), &clear_ne, "ne".to_string()),
|
||||
//(Box::new(ge_executor), &clear_ge, "ge".to_string()),
|
||||
//(Box::new(lt_executor), &clear_lt, "lt".to_string()),
|
||||
//(Box::new(le_executor), &clear_le, "le".to_string()),
|
||||
//(Box::new(eq_executor), &clear_eq, "eq".to_string()),
|
||||
//(Box::new(ne_executor), &clear_ne, "ne".to_string()),
|
||||
];
|
||||
|
||||
// Scalar Comparison Ops Executors
|
||||
let scalar_gt_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_gt);
|
||||
let scalar_ge_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_ge);
|
||||
let scalar_lt_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_lt);
|
||||
let scalar_le_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_le);
|
||||
let scalar_eq_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_eq);
|
||||
let scalar_ne_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_ne);
|
||||
//let scalar_ge_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_ge);
|
||||
//let scalar_lt_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_lt);
|
||||
//let scalar_le_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_le);
|
||||
//let scalar_eq_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_eq);
|
||||
//let scalar_ne_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_ne);
|
||||
|
||||
#[allow(clippy::type_complexity)]
|
||||
let mut scalar_comparison_ops: Vec<(
|
||||
@@ -292,31 +292,31 @@ where
|
||||
&clear_gt,
|
||||
"scalar gt".to_string(),
|
||||
),
|
||||
(
|
||||
Box::new(scalar_ge_executor),
|
||||
&clear_ge,
|
||||
"scalar ge".to_string(),
|
||||
),
|
||||
(
|
||||
Box::new(scalar_lt_executor),
|
||||
&clear_lt,
|
||||
"scalar lt".to_string(),
|
||||
),
|
||||
(
|
||||
Box::new(scalar_le_executor),
|
||||
&clear_le,
|
||||
"scalar le".to_string(),
|
||||
),
|
||||
(
|
||||
Box::new(scalar_eq_executor),
|
||||
&clear_eq,
|
||||
"scalar eq".to_string(),
|
||||
),
|
||||
(
|
||||
Box::new(scalar_ne_executor),
|
||||
&clear_ne,
|
||||
"scalar ne".to_string(),
|
||||
),
|
||||
//(
|
||||
// Box::new(scalar_ge_executor),
|
||||
// &clear_ge,
|
||||
// "scalar ge".to_string(),
|
||||
//),
|
||||
//(
|
||||
// Box::new(scalar_lt_executor),
|
||||
// &clear_lt,
|
||||
// "scalar lt".to_string(),
|
||||
//),
|
||||
//(
|
||||
// Box::new(scalar_le_executor),
|
||||
// &clear_le,
|
||||
// "scalar le".to_string(),
|
||||
//),
|
||||
//(
|
||||
// Box::new(scalar_eq_executor),
|
||||
// &clear_eq,
|
||||
// "scalar eq".to_string(),
|
||||
//),
|
||||
//(
|
||||
// Box::new(scalar_ne_executor),
|
||||
// &clear_ne,
|
||||
// "scalar ne".to_string(),
|
||||
//),
|
||||
];
|
||||
|
||||
// Select Executor
|
||||
|
||||
@@ -19,29 +19,29 @@ where
|
||||
P: Into<TestParameters> + Clone,
|
||||
{
|
||||
// Binary Ops Executors
|
||||
let add_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::add);
|
||||
let sub_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::sub);
|
||||
//let add_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::add);
|
||||
//let sub_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::sub);
|
||||
let bitwise_and_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::bitand);
|
||||
let bitwise_or_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::bitor);
|
||||
let bitwise_xor_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::bitxor);
|
||||
let mul_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::mul);
|
||||
let max_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::max);
|
||||
let min_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::min);
|
||||
//let max_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::max);
|
||||
//let min_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::min);
|
||||
|
||||
// Binary Ops Clear functions
|
||||
let clear_add = |x, y| x + y;
|
||||
let clear_sub = |x, y| x - y;
|
||||
//let clear_add = |x, y| x + y;
|
||||
//let clear_sub = |x, y| x - y;
|
||||
let clear_bitwise_and = |x, y| x & y;
|
||||
let clear_bitwise_or = |x, y| x | y;
|
||||
let clear_bitwise_xor = |x, y| x ^ y;
|
||||
let clear_mul = |x, y| x * y;
|
||||
let clear_max = |x: i64, y: i64| max(x, y);
|
||||
let clear_min = |x: i64, y: i64| min(x, y);
|
||||
//let clear_max = |x: i64, y: i64| max(x, y);
|
||||
//let clear_min = |x: i64, y: i64| min(x, y);
|
||||
|
||||
#[allow(clippy::type_complexity)]
|
||||
let mut binary_ops: Vec<(SignedBinaryOpExecutor, &dyn Fn(i64, i64) -> i64, String)> = vec![
|
||||
(Box::new(add_executor), &clear_add, "add".to_string()),
|
||||
(Box::new(sub_executor), &clear_sub, "sub".to_string()),
|
||||
//(Box::new(add_executor), &clear_add, "add".to_string()),
|
||||
//(Box::new(sub_executor), &clear_sub, "sub".to_string()),
|
||||
(
|
||||
Box::new(bitwise_and_executor),
|
||||
&clear_bitwise_and,
|
||||
@@ -58,14 +58,14 @@ where
|
||||
"bitxor".to_string(),
|
||||
),
|
||||
(Box::new(mul_executor), &clear_mul, "mul".to_string()),
|
||||
(Box::new(max_executor), &clear_max, "max".to_string()),
|
||||
(Box::new(min_executor), &clear_min, "min".to_string()),
|
||||
//(Box::new(max_executor), &clear_max, "max".to_string()),
|
||||
//(Box::new(min_executor), &clear_min, "min".to_string()),
|
||||
];
|
||||
|
||||
let rotate_left_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::rotate_left);
|
||||
let left_shift_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::left_shift);
|
||||
let rotate_right_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::rotate_right);
|
||||
let right_shift_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::right_shift);
|
||||
//let left_shift_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::left_shift);
|
||||
//let rotate_right_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::rotate_right);
|
||||
//let right_shift_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::right_shift);
|
||||
// Warning this rotate definition only works with 64-bit ciphertexts
|
||||
let clear_rotate_left = |x: i64, y: u64| x.rotate_left(y as u32);
|
||||
let clear_left_shift = |x: i64, y: u64| x << y;
|
||||
@@ -83,21 +83,21 @@ where
|
||||
&clear_rotate_left,
|
||||
"rotate left".to_string(),
|
||||
),
|
||||
(
|
||||
Box::new(left_shift_executor),
|
||||
&clear_left_shift,
|
||||
"left shift".to_string(),
|
||||
),
|
||||
(
|
||||
Box::new(rotate_right_executor),
|
||||
&clear_rotate_right,
|
||||
"rotate right".to_string(),
|
||||
),
|
||||
(
|
||||
Box::new(right_shift_executor),
|
||||
&clear_right_shift,
|
||||
"right shift".to_string(),
|
||||
),
|
||||
//(
|
||||
// Box::new(left_shift_executor),
|
||||
// &clear_left_shift,
|
||||
// "left shift".to_string(),
|
||||
//),
|
||||
//(
|
||||
// Box::new(rotate_right_executor),
|
||||
// &clear_rotate_right,
|
||||
// "rotate right".to_string(),
|
||||
//),
|
||||
//(
|
||||
// Box::new(right_shift_executor),
|
||||
// &clear_right_shift,
|
||||
// "right shift".to_string(),
|
||||
//),
|
||||
];
|
||||
|
||||
// Unary Ops Executors
|
||||
@@ -125,8 +125,8 @@ where
|
||||
];
|
||||
|
||||
// Scalar binary Ops Executors
|
||||
let scalar_add_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_add);
|
||||
let scalar_sub_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_sub);
|
||||
//let scalar_add_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_add);
|
||||
//let scalar_sub_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_sub);
|
||||
let scalar_bitwise_and_executor =
|
||||
GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_bitand);
|
||||
let scalar_bitwise_or_executor =
|
||||
@@ -141,16 +141,16 @@ where
|
||||
&dyn Fn(i64, i64) -> i64,
|
||||
String,
|
||||
)> = vec![
|
||||
(
|
||||
Box::new(scalar_add_executor),
|
||||
&clear_add,
|
||||
"scalar add".to_string(),
|
||||
),
|
||||
(
|
||||
Box::new(scalar_sub_executor),
|
||||
&clear_sub,
|
||||
"scalar sub".to_string(),
|
||||
),
|
||||
//(
|
||||
// Box::new(scalar_add_executor),
|
||||
// &clear_add,
|
||||
// "scalar add".to_string(),
|
||||
//),
|
||||
//(
|
||||
// Box::new(scalar_sub_executor),
|
||||
// &clear_sub,
|
||||
// "scalar sub".to_string(),
|
||||
//),
|
||||
(
|
||||
Box::new(scalar_bitwise_and_executor),
|
||||
&clear_bitwise_and,
|
||||
@@ -175,12 +175,12 @@ where
|
||||
|
||||
let scalar_rotate_left_executor =
|
||||
GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_rotate_left);
|
||||
let scalar_left_shift_executor =
|
||||
GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_left_shift);
|
||||
let scalar_rotate_right_executor =
|
||||
GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_rotate_right);
|
||||
let scalar_right_shift_executor =
|
||||
GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_right_shift);
|
||||
//let scalar_left_shift_executor =
|
||||
// GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_left_shift);
|
||||
//let scalar_rotate_right_executor =
|
||||
// GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_rotate_right);
|
||||
//let scalar_right_shift_executor =
|
||||
// GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_right_shift);
|
||||
#[allow(clippy::type_complexity)]
|
||||
let mut scalar_shift_rotate_ops: Vec<(
|
||||
SignedScalarShiftRotateExecutor,
|
||||
@@ -192,21 +192,21 @@ where
|
||||
&clear_rotate_left,
|
||||
"scalar rotate left".to_string(),
|
||||
),
|
||||
(
|
||||
Box::new(scalar_left_shift_executor),
|
||||
&clear_left_shift,
|
||||
"scalar left shift".to_string(),
|
||||
),
|
||||
(
|
||||
Box::new(scalar_rotate_right_executor),
|
||||
&clear_rotate_right,
|
||||
"scalar rotate right".to_string(),
|
||||
),
|
||||
(
|
||||
Box::new(scalar_right_shift_executor),
|
||||
&clear_right_shift,
|
||||
"scalar right shift".to_string(),
|
||||
),
|
||||
//(
|
||||
// Box::new(scalar_left_shift_executor),
|
||||
// &clear_left_shift,
|
||||
// "scalar left shift".to_string(),
|
||||
//),
|
||||
//(
|
||||
// Box::new(scalar_rotate_right_executor),
|
||||
// &clear_rotate_right,
|
||||
// "scalar rotate right".to_string(),
|
||||
//),
|
||||
//(
|
||||
// Box::new(scalar_right_shift_executor),
|
||||
// &clear_right_shift,
|
||||
// "scalar right shift".to_string(),
|
||||
//),
|
||||
];
|
||||
|
||||
// Overflowing Ops Executors
|
||||
@@ -271,11 +271,11 @@ where
|
||||
|
||||
// Comparison Ops Executors
|
||||
let gt_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::gt);
|
||||
let ge_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::ge);
|
||||
let lt_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::lt);
|
||||
let le_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::le);
|
||||
let eq_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::eq);
|
||||
let ne_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::ne);
|
||||
//let ge_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::ge);
|
||||
//let lt_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::lt);
|
||||
//let le_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::le);
|
||||
//let eq_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::eq);
|
||||
//let ne_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::ne);
|
||||
|
||||
// Comparison Ops Clear functions
|
||||
let clear_gt = |x: i64, y: i64| -> bool { x > y };
|
||||
@@ -292,20 +292,20 @@ where
|
||||
String,
|
||||
)> = vec![
|
||||
(Box::new(gt_executor), &clear_gt, "gt".to_string()),
|
||||
(Box::new(ge_executor), &clear_ge, "ge".to_string()),
|
||||
(Box::new(lt_executor), &clear_lt, "lt".to_string()),
|
||||
(Box::new(le_executor), &clear_le, "le".to_string()),
|
||||
(Box::new(eq_executor), &clear_eq, "eq".to_string()),
|
||||
(Box::new(ne_executor), &clear_ne, "ne".to_string()),
|
||||
//(Box::new(ge_executor), &clear_ge, "ge".to_string()),
|
||||
//(Box::new(lt_executor), &clear_lt, "lt".to_string()),
|
||||
//(Box::new(le_executor), &clear_le, "le".to_string()),
|
||||
//(Box::new(eq_executor), &clear_eq, "eq".to_string()),
|
||||
//(Box::new(ne_executor), &clear_ne, "ne".to_string()),
|
||||
];
|
||||
|
||||
// Scalar Comparison Ops Executors
|
||||
let scalar_gt_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_gt);
|
||||
let scalar_ge_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_ge);
|
||||
let scalar_lt_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_lt);
|
||||
let scalar_le_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_le);
|
||||
let scalar_eq_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_eq);
|
||||
let scalar_ne_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_ne);
|
||||
//let scalar_ge_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_ge);
|
||||
//let scalar_lt_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_lt);
|
||||
//let scalar_le_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_le);
|
||||
//let scalar_eq_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_eq);
|
||||
//let scalar_ne_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_ne);
|
||||
|
||||
#[allow(clippy::type_complexity)]
|
||||
let mut scalar_comparison_ops: Vec<(
|
||||
@@ -318,31 +318,31 @@ where
|
||||
&clear_gt,
|
||||
"scalar gt".to_string(),
|
||||
),
|
||||
(
|
||||
Box::new(scalar_ge_executor),
|
||||
&clear_ge,
|
||||
"scalar ge".to_string(),
|
||||
),
|
||||
(
|
||||
Box::new(scalar_lt_executor),
|
||||
&clear_lt,
|
||||
"scalar lt".to_string(),
|
||||
),
|
||||
(
|
||||
Box::new(scalar_le_executor),
|
||||
&clear_le,
|
||||
"scalar le".to_string(),
|
||||
),
|
||||
(
|
||||
Box::new(scalar_eq_executor),
|
||||
&clear_eq,
|
||||
"scalar eq".to_string(),
|
||||
),
|
||||
(
|
||||
Box::new(scalar_ne_executor),
|
||||
&clear_ne,
|
||||
"scalar ne".to_string(),
|
||||
),
|
||||
//(
|
||||
// Box::new(scalar_ge_executor),
|
||||
// &clear_ge,
|
||||
// "scalar ge".to_string(),
|
||||
//),
|
||||
//(
|
||||
// Box::new(scalar_lt_executor),
|
||||
// &clear_lt,
|
||||
// "scalar lt".to_string(),
|
||||
//),
|
||||
//(
|
||||
// Box::new(scalar_le_executor),
|
||||
// &clear_le,
|
||||
// "scalar le".to_string(),
|
||||
//),
|
||||
//(
|
||||
// Box::new(scalar_eq_executor),
|
||||
// &clear_eq,
|
||||
// "scalar eq".to_string(),
|
||||
//),
|
||||
//(
|
||||
// Box::new(scalar_ne_executor),
|
||||
// &clear_ne,
|
||||
// "scalar ne".to_string(),
|
||||
//),
|
||||
];
|
||||
|
||||
// Select Executor
|
||||
|
||||
@@ -18,6 +18,7 @@ use crate::integer::{CompactPublicKey, ProvenCompactCiphertextList};
|
||||
use crate::shortint::ciphertext::{Degree, NoiseLevel};
|
||||
use crate::shortint::AtomicPatternKind;
|
||||
use crate::zk::CompactPkeCrs;
|
||||
use crate::GpuIndex;
|
||||
use itertools::Itertools;
|
||||
use rayon::iter::{IntoParallelRefIterator, ParallelIterator};
|
||||
use tfhe_cuda_backend::cuda_bind::cuda_memcpy_async_gpu_to_gpu;
|
||||
@@ -28,6 +29,28 @@ pub struct CudaProvenCompactCiphertextList {
|
||||
}
|
||||
|
||||
impl CudaProvenCompactCiphertextList {
|
||||
/// Builds a new list from this one.
///
/// `h_proved_lists` is copied with `clone()`, and every entry of
/// `d_compact_lists` is copied by calling its own `duplicate(streams)`.
///
/// `streams` is forwarded unchanged to each per-list `duplicate` call.
// NOTE(review): presumably the per-list `duplicate` performs a device-side copy on
// `streams` — confirm against `CudaCompactCiphertextList::duplicate`.
pub fn duplicate(&self, streams: &CudaStreams) -> Self {
|
||||
Self {
|
||||
h_proved_lists: self.h_proved_lists.clone(),
|
||||
d_compact_lists: self
|
||||
.d_compact_lists
|
||||
.iter()
|
||||
.map(|ct_list| ct_list.duplicate(streams))
|
||||
.collect_vec(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the GPU indexes recorded on the first compact list
/// (`d_compact_lists[0].d_ct_list.0.d_vec.gpu_indexes`) as a slice.
///
/// Only the first entry is inspected; the other entries are not checked here.
///
/// # Panics
///
/// Panics if `d_compact_lists` is empty, because of the `unwrap()` on `first()`.
pub fn gpu_indexes(&self) -> &[GpuIndex] {
|
||||
self.d_compact_lists
|
||||
.first()
|
||||
.unwrap()
|
||||
.d_ct_list
|
||||
.0
|
||||
.d_vec
|
||||
.gpu_indexes
|
||||
.as_slice()
|
||||
}
|
||||
|
||||
unsafe fn flatten_async(
|
||||
slice_ciphertext_list: &[CudaCompactCiphertextList],
|
||||
streams: &CudaStreams,
|
||||
@@ -152,13 +175,9 @@ impl CudaProvenCompactCiphertextList {
|
||||
DataKind::Boolean => 1,
|
||||
DataKind::Signed(x) => *x,
|
||||
DataKind::Unsigned(x) => *x,
|
||||
_ => panic!("DataKind not supported on GPUs"),
|
||||
DataKind::String { .. } => panic!("DataKind not supported on GPUs"),
|
||||
};
|
||||
std::iter::repeat(match data_kind {
|
||||
DataKind::Boolean => true,
|
||||
_ => false,
|
||||
})
|
||||
.take(repetitions)
|
||||
std::iter::repeat_n(matches!(data_kind, DataKind::Boolean), repetitions)
|
||||
})
|
||||
.collect_vec();
|
||||
|
||||
@@ -194,25 +213,22 @@ impl CudaProvenCompactCiphertextList {
|
||||
streams,
|
||||
);
|
||||
|
||||
let d_input = &CudaProvenCompactCiphertextList::flatten_async(
|
||||
self.d_compact_lists.as_slice(),
|
||||
streams,
|
||||
);
|
||||
let casting_key = &key.key_switching_key;
|
||||
let sks = key.dest_server_key;
|
||||
let d_input = &Self::flatten_async(self.d_compact_lists.as_slice(), streams);
|
||||
let casting_key = &key.key_switching_key_material;
|
||||
let sks = &key.dest_server_key;
|
||||
let computing_ks_key = &key.dest_server_key.key_switching_key;
|
||||
|
||||
let casting_key_type: KsType = key.destination_key.into();
|
||||
let casting_key_type: KsType = casting_key.destination_key.into();
|
||||
|
||||
match &sks.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => {
|
||||
expand_async(
|
||||
streams,
|
||||
&mut d_output,
|
||||
&d_input,
|
||||
d_input,
|
||||
&d_bsk.d_vec,
|
||||
&computing_ks_key.d_vec,
|
||||
&casting_key.d_vec,
|
||||
&casting_key.lwe_keyswitch_key.d_vec,
|
||||
sks.message_modulus,
|
||||
sks.carry_modulus,
|
||||
d_bsk.glwe_dimension(),
|
||||
@@ -220,10 +236,16 @@ impl CudaProvenCompactCiphertextList {
|
||||
d_bsk.input_lwe_dimension(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
casting_key.input_key_lwe_size().to_lwe_dimension(),
|
||||
casting_key.output_key_lwe_size().to_lwe_dimension(),
|
||||
casting_key.decomposition_level_count(),
|
||||
casting_key.decomposition_base_log(),
|
||||
casting_key
|
||||
.lwe_keyswitch_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
casting_key
|
||||
.lwe_keyswitch_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
casting_key.lwe_keyswitch_key.decomposition_level_count(),
|
||||
casting_key.lwe_keyswitch_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count,
|
||||
d_bsk.decomp_base_log,
|
||||
PBSType::Classical,
|
||||
@@ -238,13 +260,10 @@ impl CudaProvenCompactCiphertextList {
|
||||
expand_async(
|
||||
streams,
|
||||
&mut d_output,
|
||||
&CudaProvenCompactCiphertextList::flatten_async(
|
||||
self.d_compact_lists.as_slice(),
|
||||
streams,
|
||||
),
|
||||
&Self::flatten_async(self.d_compact_lists.as_slice(), streams),
|
||||
&d_multibit_bsk.d_vec,
|
||||
&computing_ks_key.d_vec,
|
||||
&casting_key.d_vec,
|
||||
&casting_key.lwe_keyswitch_key.d_vec,
|
||||
sks.message_modulus,
|
||||
sks.carry_modulus,
|
||||
d_multibit_bsk.glwe_dimension(),
|
||||
@@ -252,10 +271,16 @@ impl CudaProvenCompactCiphertextList {
|
||||
d_multibit_bsk.input_lwe_dimension(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
casting_key.input_key_lwe_size().to_lwe_dimension(),
|
||||
casting_key.output_key_lwe_size().to_lwe_dimension(),
|
||||
casting_key.decomposition_level_count(),
|
||||
casting_key.decomposition_base_log(),
|
||||
casting_key
|
||||
.lwe_keyswitch_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
casting_key
|
||||
.lwe_keyswitch_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
casting_key.lwe_keyswitch_key.decomposition_level_count(),
|
||||
casting_key.lwe_keyswitch_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count,
|
||||
d_multibit_bsk.decomp_base_log,
|
||||
PBSType::MultiBit,
|
||||
@@ -294,7 +319,7 @@ impl CudaProvenCompactCiphertextList {
|
||||
})
|
||||
.collect();
|
||||
|
||||
CudaProvenCompactCiphertextList {
|
||||
Self {
|
||||
h_proved_lists: h_proved_lists.clone(),
|
||||
d_compact_lists,
|
||||
}
|
||||
@@ -328,6 +353,18 @@ impl CudaProvenCompactCiphertextList {
|
||||
}
|
||||
}
|
||||
|
||||
/// Deserialization goes through `ProvenCompactCiphertextList`: that type is
/// deserialized first, then converted with `from_proven_compact_ciphertext_list`.
impl<'de> serde::Deserialize<'de> for CudaProvenCompactCiphertextList {
|
||||
/// Deserializes a `ProvenCompactCiphertextList` and converts it into `Self`.
///
/// # Errors
///
/// Propagates any error returned by `ProvenCompactCiphertextList::deserialize`.
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
||||
where
|
||||
D: serde::Deserializer<'de>,
|
||||
{
|
||||
let cpu_ct = ProvenCompactCiphertextList::deserialize(deserializer)?;
|
||||
// A fresh `CudaStreams::new_multi_gpu()` is created on every call; the caller
// cannot choose the streams used for the conversion.
// NOTE(review): presumably this targets every available GPU — confirm.
let streams = CudaStreams::new_multi_gpu();
|
||||
|
||||
Ok(Self::from_proven_compact_ciphertext_list(&cpu_ct, &streams))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "zk-pok")]
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
@@ -344,9 +381,12 @@ mod tests {
|
||||
use crate::integer::ciphertext::{CompactCiphertextList, DataKind};
|
||||
use crate::integer::gpu::ciphertext::boolean_value::CudaBooleanBlock;
|
||||
use crate::integer::gpu::ciphertext::CudaUnsignedRadixCiphertext;
|
||||
use crate::integer::gpu::key_switching_key::CudaKeySwitchingKey;
|
||||
use crate::integer::gpu::key_switching_key::{
|
||||
CudaKeySwitchingKey, CudaKeySwitchingKeyMaterial,
|
||||
};
|
||||
use crate::integer::gpu::zk::CudaProvenCompactCiphertextList;
|
||||
use crate::integer::gpu::CudaServerKey;
|
||||
use crate::integer::key_switching_key::KeySwitchingKey;
|
||||
use crate::integer::{
|
||||
ClientKey, CompactPrivateKey, CompactPublicKey, CompressedServerKey,
|
||||
ProvenCompactCiphertextList,
|
||||
@@ -405,15 +445,16 @@ mod tests {
|
||||
let compressed_server_key = CompressedServerKey::new_radix_compressed_server_key(&cks);
|
||||
|
||||
let streams = CudaStreams::new_multi_gpu();
|
||||
let sk = compressed_server_key.decompress();
|
||||
let gpu_sk = CudaServerKey::decompress_from_cpu(&compressed_server_key, &streams);
|
||||
|
||||
let compact_private_key = CompactPrivateKey::new(pke_params);
|
||||
let d_ksk = CudaKeySwitchingKey::new(
|
||||
(&compact_private_key, None),
|
||||
(&cks, &gpu_sk),
|
||||
ksk_params,
|
||||
&streams,
|
||||
);
|
||||
let ksk = KeySwitchingKey::new((&compact_private_key, None), (&cks, &sk), ksk_params);
|
||||
let d_ksk_material =
|
||||
CudaKeySwitchingKeyMaterial::from_key_switching_key(&ksk, &streams);
|
||||
let d_ksk =
|
||||
CudaKeySwitchingKey::from_cuda_key_switching_key_material(&d_ksk_material, &gpu_sk);
|
||||
|
||||
let pk = CompactPublicKey::new(&compact_private_key);
|
||||
|
||||
let msgs = (0..512)
|
||||
@@ -493,17 +534,17 @@ mod tests {
|
||||
.unwrap();
|
||||
let cks = ClientKey::new(fhe_params);
|
||||
let compressed_server_key = CompressedServerKey::new_radix_compressed_server_key(&cks);
|
||||
|
||||
let sk = compressed_server_key.decompress();
|
||||
let streams = CudaStreams::new_multi_gpu();
|
||||
let gpu_sk = CudaServerKey::decompress_from_cpu(&compressed_server_key, &streams);
|
||||
|
||||
let compact_private_key = CompactPrivateKey::new(pke_params);
|
||||
let d_ksk = CudaKeySwitchingKey::new(
|
||||
(&compact_private_key, None),
|
||||
(&cks, &gpu_sk),
|
||||
ksk_params,
|
||||
&streams,
|
||||
);
|
||||
let ksk = KeySwitchingKey::new((&compact_private_key, None), (&cks, &sk), ksk_params);
|
||||
let d_ksk_material =
|
||||
CudaKeySwitchingKeyMaterial::from_key_switching_key(&ksk, &streams);
|
||||
let d_ksk =
|
||||
CudaKeySwitchingKey::from_cuda_key_switching_key_material(&d_ksk_material, &gpu_sk);
|
||||
|
||||
let pk = CompactPublicKey::new(&compact_private_key);
|
||||
|
||||
let msgs = (0..2).map(|_| random::<u64>()).collect::<Vec<_>>();
|
||||
@@ -583,17 +624,17 @@ mod tests {
|
||||
.unwrap();
|
||||
let cks = ClientKey::new(fhe_params);
|
||||
let compressed_server_key = CompressedServerKey::new_radix_compressed_server_key(&cks);
|
||||
|
||||
let sk = compressed_server_key.decompress();
|
||||
let streams = CudaStreams::new_multi_gpu();
|
||||
let gpu_sk = CudaServerKey::decompress_from_cpu(&compressed_server_key, &streams);
|
||||
|
||||
let compact_private_key = CompactPrivateKey::new(pke_params);
|
||||
let d_ksk = CudaKeySwitchingKey::new(
|
||||
(&compact_private_key, None),
|
||||
(&cks, &gpu_sk),
|
||||
ksk_params,
|
||||
&streams,
|
||||
);
|
||||
let ksk = KeySwitchingKey::new((&compact_private_key, None), (&cks, &sk), ksk_params);
|
||||
let d_ksk_material =
|
||||
CudaKeySwitchingKeyMaterial::from_key_switching_key(&ksk, &streams);
|
||||
let d_ksk =
|
||||
CudaKeySwitchingKey::from_cuda_key_switching_key_material(&d_ksk_material, &gpu_sk);
|
||||
|
||||
let pk = CompactPublicKey::new(&compact_private_key);
|
||||
|
||||
let msgs = (0..2).map(|_| random::<u64>()).collect::<Vec<_>>();
|
||||
|
||||
@@ -97,8 +97,8 @@ impl ServerKey {
|
||||
// It should remain just enough space to add a carry
|
||||
let client_key = cks.as_ref();
|
||||
let max_degree = MaxDegree::integer_radix_server_key(
|
||||
client_key.key.parameters.message_modulus(),
|
||||
client_key.key.parameters.carry_modulus(),
|
||||
client_key.key.parameters().message_modulus(),
|
||||
client_key.key.parameters().carry_modulus(),
|
||||
);
|
||||
|
||||
let sks = crate::shortint::server_key::ServerKey::new_with_max_degree(
|
||||
@@ -115,8 +115,8 @@ impl ServerKey {
|
||||
{
|
||||
let client_key = cks.as_ref();
|
||||
let max_degree = MaxDegree::integer_crt_server_key(
|
||||
client_key.key.parameters.message_modulus(),
|
||||
client_key.key.parameters.carry_modulus(),
|
||||
client_key.key.parameters().message_modulus(),
|
||||
client_key.key.parameters().carry_modulus(),
|
||||
);
|
||||
|
||||
let sks = crate::shortint::server_key::ServerKey::new_with_max_degree(
|
||||
@@ -255,8 +255,8 @@ pub struct CompressedServerKey {
|
||||
impl CompressedServerKey {
|
||||
pub fn new_radix_compressed_server_key(client_key: &ClientKey) -> Self {
|
||||
let max_degree = MaxDegree::integer_radix_server_key(
|
||||
client_key.key.parameters.message_modulus(),
|
||||
client_key.key.parameters.carry_modulus(),
|
||||
client_key.key.parameters().message_modulus(),
|
||||
client_key.key.parameters().carry_modulus(),
|
||||
);
|
||||
|
||||
let key =
|
||||
|
||||
@@ -3,4 +3,4 @@ pub(crate) mod test_random_op_sequence;
|
||||
pub(crate) mod test_signed_erc20;
|
||||
pub(crate) mod test_signed_random_op_sequence;
|
||||
pub(crate) const NB_CTXT_LONG_RUN: usize = 32;
|
||||
pub(crate) const NB_TESTS_LONG_RUN: usize = 20000;
|
||||
pub(crate) const NB_TESTS_LONG_RUN: usize = 200;
|
||||
|
||||
@@ -588,6 +588,9 @@ pub(crate) fn random_op_sequence_test<P>(
|
||||
left_vec[i].blocks.iter().map(|b| b.degree.0).collect();
|
||||
let input_degrees_right: Vec<u64> =
|
||||
right_vec[i].blocks.iter().map(|b| b.degree.0).collect();
|
||||
let output_degrees: Vec<u64> =
|
||||
res.blocks.iter().map(|b| b.degree.0).collect();
|
||||
println!("Input degrees left: {input_degrees_left:?}, right {input_degrees_right:?}, Output degrees {:?}", output_degrees);
|
||||
let decrypted_res: u64 = cks.decrypt(&res);
|
||||
let expected_res: u64 = clear_fn(clear_left, clear_right);
|
||||
|
||||
|
||||
@@ -428,7 +428,7 @@ where
|
||||
// Div executor
|
||||
let div_rem_executor = CpuFunctionExecutor::new(&ServerKey::div_rem_parallelized);
|
||||
// Div Rem Clear functions
|
||||
let clear_div_rem = |x: i64, y: i64| -> (i64, i64) { (x.wrapping_div(y), x % y) };
|
||||
let clear_div_rem = |x: i64, y: i64| -> (i64, i64) { (x.wrapping_div(y), x.wrapping_rem(y)) };
|
||||
#[allow(clippy::type_complexity)]
|
||||
let mut div_rem_op: Vec<(
|
||||
SignedDivRemOpExecutor,
|
||||
@@ -680,6 +680,9 @@ pub(crate) fn signed_random_op_sequence_test<P>(
|
||||
left_vec[i].blocks.iter().map(|b| b.degree.0).collect();
|
||||
let input_degrees_right: Vec<u64> =
|
||||
right_vec[i].blocks.iter().map(|b| b.degree.0).collect();
|
||||
let output_degrees: Vec<u64> =
|
||||
res.blocks.iter().map(|b| b.degree.0).collect();
|
||||
println!("Input degrees left: {input_degrees_left:?}, right {input_degrees_right:?}, Output degrees {:?}", output_degrees);
|
||||
let decrypt_signed_res: i64 = cks.decrypt_signed(&res);
|
||||
let expected_res: i64 = clear_fn(clear_left, clear_right);
|
||||
|
||||
@@ -731,6 +734,9 @@ pub(crate) fn signed_random_op_sequence_test<P>(
|
||||
"Determinism check failed on unary op {fn_name} with clear input {clear_input}.",
|
||||
);
|
||||
let input_degrees: Vec<u64> = input.blocks.iter().map(|b| b.degree.0).collect();
|
||||
let output_degrees: Vec<u64> =
|
||||
res.blocks.iter().map(|b| b.degree.0).collect();
|
||||
println!("Output degrees {:?}", output_degrees);
|
||||
let decrypt_signed_res: i64 = cks.decrypt_signed(&res);
|
||||
let expected_res: i64 = clear_fn(clear_input);
|
||||
if i % 2 == 0 {
|
||||
@@ -774,6 +780,9 @@ pub(crate) fn signed_random_op_sequence_test<P>(
|
||||
);
|
||||
let input_degrees_left: Vec<u64> =
|
||||
left_vec[i].blocks.iter().map(|b| b.degree.0).collect();
|
||||
let output_degrees: Vec<u64> =
|
||||
res.blocks.iter().map(|b| b.degree.0).collect();
|
||||
println!("Output degrees {:?}", output_degrees);
|
||||
let decrypt_signed_res: i64 = cks.decrypt_signed(&res);
|
||||
let expected_res: i64 = clear_fn(clear_left, clear_right);
|
||||
|
||||
@@ -829,6 +838,9 @@ pub(crate) fn signed_random_op_sequence_test<P>(
|
||||
left_vec[i].blocks.iter().map(|b| b.degree.0).collect();
|
||||
let input_degrees_right: Vec<u64> =
|
||||
right_vec[i].blocks.iter().map(|b| b.degree.0).collect();
|
||||
let output_degrees: Vec<u64> =
|
||||
res.blocks.iter().map(|b| b.degree.0).collect();
|
||||
println!("Output degrees {:?}", output_degrees);
|
||||
let decrypt_signed_res: i64 = cks.decrypt_signed(&res);
|
||||
let decrypt_signed_overflow = cks.decrypt_bool(&overflow);
|
||||
let (expected_res, expected_overflow) = clear_fn(clear_left, clear_right);
|
||||
@@ -889,6 +901,9 @@ pub(crate) fn signed_random_op_sequence_test<P>(
|
||||
);
|
||||
let input_degrees_left: Vec<u64> =
|
||||
left_vec[i].blocks.iter().map(|b| b.degree.0).collect();
|
||||
let output_degrees: Vec<u64> =
|
||||
res.blocks.iter().map(|b| b.degree.0).collect();
|
||||
println!("Output degrees {:?}", output_degrees);
|
||||
let decrypt_signed_res: i64 = cks.decrypt_signed(&res);
|
||||
let decrypt_signed_overflow = cks.decrypt_bool(&overflow);
|
||||
let (expected_res, expected_overflow) = clear_fn(clear_left, clear_right);
|
||||
@@ -1020,6 +1035,9 @@ pub(crate) fn signed_random_op_sequence_test<P>(
|
||||
left_vec[i].blocks.iter().map(|b| b.degree.0).collect();
|
||||
let input_degrees_right: Vec<u64> =
|
||||
right_vec[i].blocks.iter().map(|b| b.degree.0).collect();
|
||||
let output_degrees: Vec<u64> =
|
||||
res.blocks.iter().map(|b| b.degree.0).collect();
|
||||
println!("Output degrees {:?}", output_degrees);
|
||||
let decrypt_signed_res: i64 = cks.decrypt_signed(&res);
|
||||
let expected_res = clear_fn(clear_bool, clear_left, clear_right);
|
||||
|
||||
@@ -1081,6 +1099,12 @@ pub(crate) fn signed_random_op_sequence_test<P>(
|
||||
left_vec[i].blocks.iter().map(|b| b.degree.0).collect();
|
||||
let input_degrees_right: Vec<u64> =
|
||||
right_vec[i].blocks.iter().map(|b| b.degree.0).collect();
|
||||
let output_degrees_q: Vec<u64> =
|
||||
res_q.blocks.iter().map(|b| b.degree.0).collect();
|
||||
let output_degrees_r: Vec<u64> =
|
||||
res_r.blocks.iter().map(|b| b.degree.0).collect();
|
||||
println!("Output degrees {:?}", output_degrees_q);
|
||||
println!("Output degrees {:?}", output_degrees_r);
|
||||
let decrypt_signed_res_q: i64 = cks.decrypt_signed(&res_q);
|
||||
let decrypt_signed_res_r: i64 = cks.decrypt_signed(&res_r);
|
||||
let (expected_res_q, expected_res_r) = clear_fn(clear_left, clear_right);
|
||||
@@ -1147,6 +1171,12 @@ pub(crate) fn signed_random_op_sequence_test<P>(
|
||||
);
|
||||
let input_degrees_left: Vec<u64> =
|
||||
left_vec[i].blocks.iter().map(|b| b.degree.0).collect();
|
||||
// Degrees of the quotient blocks: read them from `res_q`, not `res_r`, so the
// "Output q degrees" debug line reports the quotient and not the remainder again.
let output_q_degrees: Vec<u64> =
|
||||
res_q.blocks.iter().map(|b| b.degree.0).collect();
|
||||
let output_r_degrees: Vec<u64> =
|
||||
res_r.blocks.iter().map(|b| b.degree.0).collect();
|
||||
println!("Output r degrees {:?}", output_r_degrees);
|
||||
println!("Output q degrees {:?}", output_q_degrees);
|
||||
let decrypt_signed_res_q: i64 = cks.decrypt_signed(&res_q);
|
||||
let decrypt_signed_res_r: i64 = cks.decrypt_signed(&res_r);
|
||||
let (expected_res_q, expected_res_r) = clear_fn(clear_left, clear_right);
|
||||
@@ -1205,6 +1235,9 @@ pub(crate) fn signed_random_op_sequence_test<P>(
|
||||
"Determinism check failed on op {fn_name} with clear input {clear_input}.",
|
||||
);
|
||||
let input_degrees: Vec<u64> = input.blocks.iter().map(|b| b.degree.0).collect();
|
||||
let output_degrees: Vec<u64> =
|
||||
res.blocks.iter().map(|b| b.degree.0).collect();
|
||||
println!("Output degrees {:?}", output_degrees);
|
||||
let cast_res = sks.cast_to_signed(res, NB_CTXT_LONG_RUN);
|
||||
let decrypt_signed_res: i64 = cks.decrypt_signed(&cast_res);
|
||||
let expected_res = clear_fn(clear_input) as i64;
|
||||
@@ -1252,6 +1285,9 @@ pub(crate) fn signed_random_op_sequence_test<P>(
|
||||
left_vec[i].blocks.iter().map(|b| b.degree.0).collect();
|
||||
let input_degrees_right: Vec<u64> =
|
||||
unsigned_right.blocks.iter().map(|b| b.degree.0).collect();
|
||||
let output_degrees: Vec<u64> =
|
||||
res.blocks.iter().map(|b| b.degree.0).collect();
|
||||
println!("Output degrees {:?}", output_degrees);
|
||||
let decrypt_signed_res: i64 = cks.decrypt_signed(&res);
|
||||
let expected_res: i64 = clear_fn(clear_left, clear_right as u64);
|
||||
|
||||
@@ -1297,6 +1333,9 @@ pub(crate) fn signed_random_op_sequence_test<P>(
|
||||
);
|
||||
let input_degrees_left: Vec<u64> =
|
||||
left_vec[i].blocks.iter().map(|b| b.degree.0).collect();
|
||||
let output_degrees: Vec<u64> =
|
||||
res.blocks.iter().map(|b| b.degree.0).collect();
|
||||
println!("Output degrees {:?}", output_degrees);
|
||||
let decrypt_signed_res: i64 = cks.decrypt_signed(&res);
|
||||
let expected_res: i64 = clear_fn(clear_left, clear_right as u64);
|
||||
|
||||
|
||||
@@ -349,7 +349,7 @@ where
|
||||
{
|
||||
let cks = cks.as_ref();
|
||||
|
||||
let max_degree_acceptable = cks.key.parameters.message_modulus().0 - 1;
|
||||
let max_degree_acceptable = cks.key.parameters().message_modulus().0 - 1;
|
||||
let num_blocks = ct.blocks.len();
|
||||
|
||||
for (i, block) in ct.blocks.iter().enumerate() {
|
||||
@@ -385,7 +385,7 @@ where
|
||||
{
|
||||
let cks = cks.as_ref();
|
||||
|
||||
let max_degree_acceptable = cks.key.parameters.message_modulus().0 - 1;
|
||||
let max_degree_acceptable = cks.key.parameters().message_modulus().0 - 1;
|
||||
|
||||
for (i, block) in ct.blocks.iter().enumerate() {
|
||||
if block.is_trivial() {
|
||||
|
||||
@@ -24,6 +24,8 @@ mod experimental {
|
||||
use crate::integer::{ClientKey, CrtCiphertext, IntegerCiphertext, RadixCiphertext, ServerKey};
|
||||
use crate::shortint::atomic_pattern::AtomicPattern;
|
||||
use crate::shortint::ciphertext::{Degree, NoiseLevel};
|
||||
use crate::shortint::client_key::atomic_pattern::EncryptionAtomicPattern;
|
||||
use crate::shortint::client_key::StandardClientKeyView;
|
||||
use crate::shortint::server_key::StandardServerKeyView;
|
||||
use crate::shortint::WopbsParameters;
|
||||
|
||||
@@ -238,12 +240,17 @@ mod experimental {
|
||||
sks.key.atomic_pattern.kind()
|
||||
)
|
||||
});
|
||||
|
||||
let cks = cks.as_ref();
|
||||
let ck = StandardClientKeyView::try_from(cks.key.as_view()).unwrap_or_else(|_| {
|
||||
panic!(
|
||||
"Wopbs is not supported by the chosen encryption atomic pattern: {:?}",
|
||||
cks.key.atomic_pattern.kind()
|
||||
)
|
||||
});
|
||||
|
||||
Self {
|
||||
wopbs_key: crate::shortint::wopbs::WopbsKey::new_wopbs_key(
|
||||
&cks.as_ref().key,
|
||||
sk,
|
||||
parameters,
|
||||
),
|
||||
wopbs_key: crate::shortint::wopbs::WopbsKey::new_wopbs_key(ck, sk, parameters),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -268,11 +275,16 @@ mod experimental {
|
||||
)
|
||||
});
|
||||
|
||||
let cks = cks.as_ref();
|
||||
let ck = StandardClientKeyView::try_from(cks.key.as_view()).unwrap_or_else(|_| {
|
||||
panic!(
|
||||
"Wopbs is not supported by the chosen encryption atomic pattern: {:?}",
|
||||
cks.key.atomic_pattern.kind()
|
||||
)
|
||||
});
|
||||
|
||||
Self {
|
||||
wopbs_key: crate::shortint::wopbs::WopbsKey::new_wopbs_key_only_for_wopbs(
|
||||
&cks.as_ref().key,
|
||||
sk,
|
||||
),
|
||||
wopbs_key: crate::shortint::wopbs::WopbsKey::new_wopbs_key_only_for_wopbs(ck, sk),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -563,7 +563,7 @@ impl Shortint {
|
||||
let mut seeder = DeterministicSeeder::<DefaultRandomGenerator>::new(Seed(seed));
|
||||
ShortintClientKey(
|
||||
crate::shortint::engine::ShortintEngine::new_from_seeder(&mut seeder)
|
||||
.new_client_key(parameters.0.into()),
|
||||
.new_client_key(parameters.0),
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
use crate::conformance::ParameterSetConformant;
|
||||
use crate::core_crypto::algorithms::lwe_keyswitch_key_generation::allocate_and_generate_new_seeded_lwe_keyswitch_key;
|
||||
use crate::core_crypto::entities::lwe_secret_key::LweSecretKey;
|
||||
use crate::core_crypto::entities::seeded_lwe_keyswitch_key::SeededLweKeyswitchKeyOwned;
|
||||
use crate::shortint::atomic_pattern::ks32::KS32AtomicPatternServerKey;
|
||||
use crate::shortint::backward_compatibility::atomic_pattern::CompressedKS32AtomicPatternServerKeyVersions;
|
||||
use crate::shortint::client_key::ClientKey;
|
||||
use crate::shortint::client_key::atomic_pattern::KS32AtomicPatternClientKey;
|
||||
use crate::shortint::engine::ShortintEngine;
|
||||
use crate::shortint::parameters::{KeySwitch32PBSParameters, LweDimension};
|
||||
use crate::shortint::server_key::ShortintCompressedBootstrappingKey;
|
||||
@@ -22,24 +21,15 @@ pub struct CompressedKS32AtomicPatternServerKey {
|
||||
}
|
||||
|
||||
impl CompressedKS32AtomicPatternServerKey {
|
||||
pub fn new(cks: &ClientKey, engine: &mut ShortintEngine) -> Self {
|
||||
pub fn new(cks: &KS32AtomicPatternClientKey, engine: &mut ShortintEngine) -> Self {
|
||||
let params = &cks.parameters;
|
||||
|
||||
let pbs_params = params.ks32_parameters().unwrap();
|
||||
|
||||
let in_key = LweSecretKey::from_container(
|
||||
cks.small_lwe_secret_key()
|
||||
.as_ref()
|
||||
.iter()
|
||||
.copied()
|
||||
.map(|x| x as u32)
|
||||
.collect::<Vec<_>>(),
|
||||
);
|
||||
let in_key = cks.small_lwe_secret_key();
|
||||
|
||||
let out_key = &cks.glwe_secret_key;
|
||||
|
||||
let bootstrapping_key_base =
|
||||
engine.new_compressed_bootstrapping_key_ks32(pbs_params, &in_key, out_key);
|
||||
engine.new_compressed_bootstrapping_key_ks32(*params, &in_key, out_key);
|
||||
|
||||
// Creation of the key switching key
|
||||
let key_switching_key = allocate_and_generate_new_seeded_lwe_keyswitch_key(
|
||||
@@ -47,8 +37,8 @@ impl CompressedKS32AtomicPatternServerKey {
|
||||
&in_key,
|
||||
params.ks_base_log(),
|
||||
params.ks_level(),
|
||||
pbs_params.lwe_noise_distribution(),
|
||||
pbs_params.post_keyswitch_ciphertext_modulus(),
|
||||
params.lwe_noise_distribution(),
|
||||
params.post_keyswitch_ciphertext_modulus(),
|
||||
&mut engine.seeder,
|
||||
);
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ pub use standard::*;
|
||||
use super::AtomicPatternServerKey;
|
||||
use crate::conformance::ParameterSetConformant;
|
||||
use crate::shortint::backward_compatibility::atomic_pattern::CompressedAtomicPatternServerKeyVersions;
|
||||
use crate::shortint::client_key::atomic_pattern::AtomicPatternClientKey;
|
||||
use crate::shortint::client_key::ClientKey;
|
||||
use crate::shortint::engine::ShortintEngine;
|
||||
use crate::shortint::parameters::{AtomicPatternParameters, CiphertextModulus, LweDimension};
|
||||
@@ -24,14 +25,12 @@ pub enum CompressedAtomicPatternServerKey {
|
||||
|
||||
impl CompressedAtomicPatternServerKey {
|
||||
pub fn new(cks: &ClientKey, engine: &mut ShortintEngine) -> Self {
|
||||
let params = &cks.parameters;
|
||||
|
||||
match params.ap_parameters().unwrap() {
|
||||
AtomicPatternParameters::Standard(_) => {
|
||||
Self::Standard(CompressedStandardAtomicPatternServerKey::new(cks, engine))
|
||||
}
|
||||
AtomicPatternParameters::KeySwitch32(_) => {
|
||||
Self::KeySwitch32(CompressedKS32AtomicPatternServerKey::new(cks, engine))
|
||||
match &cks.atomic_pattern {
|
||||
AtomicPatternClientKey::Standard(ap_cks) => Self::Standard(
|
||||
CompressedStandardAtomicPatternServerKey::new(ap_cks, engine),
|
||||
),
|
||||
AtomicPatternClientKey::KeySwitch32(ap_cks) => {
|
||||
Self::KeySwitch32(CompressedKS32AtomicPatternServerKey::new(ap_cks, engine))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,7 +3,7 @@ use crate::core_crypto::algorithms::lwe_keyswitch_key_generation::allocate_and_g
|
||||
use crate::core_crypto::entities::seeded_lwe_keyswitch_key::SeededLweKeyswitchKeyOwned;
|
||||
use crate::shortint::atomic_pattern::standard::StandardAtomicPatternServerKey;
|
||||
use crate::shortint::backward_compatibility::atomic_pattern::CompressedStandardAtomicPatternServerKeyVersions;
|
||||
use crate::shortint::client_key::ClientKey;
|
||||
use crate::shortint::client_key::atomic_pattern::StandardAtomicPatternClientKey;
|
||||
use crate::shortint::engine::ShortintEngine;
|
||||
use crate::shortint::parameters::{CiphertextModulus, LweDimension, PBSOrder, PBSParameters};
|
||||
use crate::shortint::server_key::ShortintCompressedBootstrappingKey;
|
||||
@@ -22,17 +22,15 @@ pub struct CompressedStandardAtomicPatternServerKey {
|
||||
}
|
||||
|
||||
impl CompressedStandardAtomicPatternServerKey {
|
||||
pub fn new(cks: &ClientKey, engine: &mut ShortintEngine) -> Self {
|
||||
pub fn new(cks: &StandardAtomicPatternClientKey, engine: &mut ShortintEngine) -> Self {
|
||||
let params = &cks.parameters;
|
||||
|
||||
let pbs_params_base = params.pbs_parameters().unwrap();
|
||||
|
||||
let in_key = &cks.small_lwe_secret_key();
|
||||
|
||||
let out_key = &cks.glwe_secret_key;
|
||||
|
||||
let bootstrapping_key_base =
|
||||
engine.new_compressed_bootstrapping_key(pbs_params_base, in_key, out_key);
|
||||
engine.new_compressed_bootstrapping_key(*params, in_key, out_key);
|
||||
|
||||
// Creation of the key switching key
|
||||
let key_switching_key = allocate_and_generate_new_seeded_lwe_keyswitch_key(
|
||||
@@ -48,7 +46,7 @@ impl CompressedStandardAtomicPatternServerKey {
|
||||
Self::from_raw_parts(
|
||||
key_switching_key,
|
||||
bootstrapping_key_base,
|
||||
pbs_params_base.encryption_key_choice().into(),
|
||||
params.encryption_key_choice().into(),
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@@ -10,11 +10,12 @@ use crate::conformance::ParameterSetConformant;
|
||||
use crate::core_crypto::prelude::{
|
||||
allocate_and_generate_new_lwe_keyswitch_key, extract_lwe_sample_from_glwe_ciphertext,
|
||||
keyswitch_lwe_ciphertext_with_scalar_change, CiphertextModulus as CoreCiphertextModulus,
|
||||
LweCiphertext, LweCiphertextOwned, LweDimension, LweKeyswitchKeyOwned, LweSecretKey,
|
||||
MonomialDegree, MsDecompressionType,
|
||||
LweCiphertext, LweCiphertextOwned, LweDimension, LweKeyswitchKeyOwned, MonomialDegree,
|
||||
MsDecompressionType,
|
||||
};
|
||||
use crate::shortint::backward_compatibility::atomic_pattern::KS32AtomicPatternServerKeyVersions;
|
||||
use crate::shortint::ciphertext::{CompressedModulusSwitchedCiphertext, Degree, NoiseLevel};
|
||||
use crate::shortint::client_key::atomic_pattern::KS32AtomicPatternClientKey;
|
||||
use crate::shortint::engine::ShortintEngine;
|
||||
use crate::shortint::oprf::generate_pseudo_random_from_pbs;
|
||||
use crate::shortint::parameters::KeySwitch32PBSParameters;
|
||||
@@ -22,7 +23,7 @@ use crate::shortint::server_key::{
|
||||
decompress_and_apply_lookup_table, switch_modulus_and_compress, LookupTableOwned,
|
||||
LookupTableSize, ManyLookupTableOwned, ShortintBootstrappingKey,
|
||||
};
|
||||
use crate::shortint::{Ciphertext, CiphertextModulus, ClientKey};
|
||||
use crate::shortint::{Ciphertext, CiphertextModulus};
|
||||
|
||||
/// The definition of the server key elements used in the
|
||||
/// [`KeySwitch32`](AtomicPatternKind::KeySwitch32) atomic pattern
|
||||
@@ -57,24 +58,14 @@ impl ParameterSetConformant for KS32AtomicPatternServerKey {
|
||||
}
|
||||
|
||||
impl KS32AtomicPatternServerKey {
|
||||
pub fn new(cks: &ClientKey, engine: &mut ShortintEngine) -> Self {
|
||||
pub fn new(cks: &KS32AtomicPatternClientKey, engine: &mut ShortintEngine) -> Self {
|
||||
let params = &cks.parameters;
|
||||
|
||||
let pbs_params = params.ks32_parameters().unwrap();
|
||||
|
||||
let in_key = LweSecretKey::from_container(
|
||||
cks.small_lwe_secret_key()
|
||||
.as_ref()
|
||||
.iter()
|
||||
.copied()
|
||||
.map(|x| x as u32)
|
||||
.collect::<Vec<_>>(),
|
||||
);
|
||||
let in_key = cks.small_lwe_secret_key();
|
||||
|
||||
let out_key = &cks.glwe_secret_key;
|
||||
|
||||
let bootstrapping_key_base =
|
||||
engine.new_bootstrapping_key_ks32(pbs_params, &in_key, out_key);
|
||||
let bootstrapping_key_base = engine.new_bootstrapping_key_ks32(*params, &in_key, out_key);
|
||||
|
||||
// Creation of the key switching key
|
||||
let key_switching_key = allocate_and_generate_new_lwe_keyswitch_key(
|
||||
@@ -82,8 +73,8 @@ impl KS32AtomicPatternServerKey {
|
||||
&in_key,
|
||||
params.ks_base_log(),
|
||||
params.ks_level(),
|
||||
pbs_params.lwe_noise_distribution(),
|
||||
pbs_params.post_keyswitch_ciphertext_modulus(),
|
||||
params.lwe_noise_distribution(),
|
||||
params.post_keyswitch_ciphertext_modulus(),
|
||||
&mut engine.encryption_generator,
|
||||
);
|
||||
|
||||
|
||||
@@ -21,6 +21,7 @@ use crate::core_crypto::prelude::{
|
||||
|
||||
use super::backward_compatibility::atomic_pattern::*;
|
||||
use super::ciphertext::{CompressedModulusSwitchedCiphertext, Degree};
|
||||
use super::client_key::atomic_pattern::AtomicPatternClientKey;
|
||||
use super::engine::ShortintEngine;
|
||||
use super::parameters::{DynamicDistribution, KeySwitch32PBSParameters};
|
||||
use super::prelude::{DecompositionBaseLog, DecompositionLevelCount};
|
||||
@@ -259,14 +260,12 @@ pub enum AtomicPatternServerKey {
|
||||
|
||||
impl AtomicPatternServerKey {
|
||||
pub fn new(cks: &ClientKey, engine: &mut ShortintEngine) -> Self {
|
||||
let params = &cks.parameters;
|
||||
|
||||
match params.ap_parameters().unwrap() {
|
||||
AtomicPatternParameters::Standard(_) => {
|
||||
Self::Standard(StandardAtomicPatternServerKey::new(cks, engine))
|
||||
match &cks.atomic_pattern {
|
||||
AtomicPatternClientKey::Standard(ap_cks) => {
|
||||
Self::Standard(StandardAtomicPatternServerKey::new(ap_cks, engine))
|
||||
}
|
||||
AtomicPatternParameters::KeySwitch32(_) => {
|
||||
Self::KeySwitch32(KS32AtomicPatternServerKey::new(cks, engine))
|
||||
AtomicPatternClientKey::KeySwitch32(ap_cks) => {
|
||||
Self::KeySwitch32(KS32AtomicPatternServerKey::new(ap_cks, engine))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -14,13 +14,14 @@ use crate::core_crypto::prelude::{
|
||||
};
|
||||
use crate::shortint::backward_compatibility::atomic_pattern::StandardAtomicPatternServerKeyVersions;
|
||||
use crate::shortint::ciphertext::{CompressedModulusSwitchedCiphertext, Degree, NoiseLevel};
|
||||
use crate::shortint::client_key::atomic_pattern::StandardAtomicPatternClientKey;
|
||||
use crate::shortint::engine::ShortintEngine;
|
||||
use crate::shortint::oprf::generate_pseudo_random_from_pbs;
|
||||
use crate::shortint::server_key::{
|
||||
decompress_and_apply_lookup_table, switch_modulus_and_compress, LookupTableOwned,
|
||||
LookupTableSize, ManyLookupTableOwned, ShortintBootstrappingKey,
|
||||
};
|
||||
use crate::shortint::{Ciphertext, CiphertextModulus, ClientKey, PBSOrder, PBSParameters};
|
||||
use crate::shortint::{Ciphertext, CiphertextModulus, PBSOrder, PBSParameters};
|
||||
|
||||
/// The definition of the server key elements used in the [`Standard`](AtomicPatternKind::Standard)
|
||||
/// atomic pattern
|
||||
@@ -58,16 +59,14 @@ impl ParameterSetConformant for StandardAtomicPatternServerKey {
|
||||
}
|
||||
|
||||
impl StandardAtomicPatternServerKey {
|
||||
pub fn new(cks: &ClientKey, engine: &mut ShortintEngine) -> Self {
|
||||
pub fn new(cks: &StandardAtomicPatternClientKey, engine: &mut ShortintEngine) -> Self {
|
||||
let params = &cks.parameters;
|
||||
|
||||
let pbs_params_base = params.pbs_parameters().unwrap();
|
||||
|
||||
let in_key = &cks.small_lwe_secret_key();
|
||||
|
||||
let out_key = &cks.glwe_secret_key;
|
||||
|
||||
let bootstrapping_key_base = engine.new_bootstrapping_key(pbs_params_base, in_key, out_key);
|
||||
let bootstrapping_key_base = engine.new_bootstrapping_key(*params, in_key, out_key);
|
||||
|
||||
// Creation of the key switching key
|
||||
let key_switching_key = allocate_and_generate_new_lwe_keyswitch_key(
|
||||
@@ -83,7 +82,7 @@ impl StandardAtomicPatternServerKey {
|
||||
Self::from_raw_parts(
|
||||
key_switching_key,
|
||||
bootstrapping_key_base,
|
||||
pbs_params_base.encryption_key_choice().into(),
|
||||
params.encryption_key_choice().into(),
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,20 @@
|
||||
use tfhe_versionable::VersionsDispatch;
|
||||
|
||||
use crate::shortint::client_key::atomic_pattern::{
|
||||
AtomicPatternClientKey, KS32AtomicPatternClientKey, StandardAtomicPatternClientKey,
|
||||
};
|
||||
|
||||
/// Version dispatch enum for [`AtomicPatternClientKey`].
///
/// `V0` is currently the only recorded version and wraps the type itself.
#[derive(VersionsDispatch)]
|
||||
pub enum AtomicPatternClientKeyVersions {
|
||||
V0(AtomicPatternClientKey),
|
||||
}
|
||||
|
||||
/// Version dispatch enum for [`StandardAtomicPatternClientKey`].
///
/// `V0` is currently the only recorded version and wraps the type itself.
#[derive(VersionsDispatch)]
|
||||
pub enum StandardAtomicPatternClientKeyVersions {
|
||||
V0(StandardAtomicPatternClientKey),
|
||||
}
|
||||
|
||||
/// Version dispatch enum for [`KS32AtomicPatternClientKey`].
///
/// `V0` is currently the only recorded version and wraps the type itself.
#[derive(VersionsDispatch)]
|
||||
pub enum KS32AtomicPatternClientKeyVersions {
|
||||
V0(KS32AtomicPatternClientKey),
|
||||
}
|
||||
@@ -1,8 +1,67 @@
|
||||
use tfhe_versionable::VersionsDispatch;
|
||||
pub mod atomic_pattern;
|
||||
|
||||
use crate::shortint::ClientKey;
|
||||
use std::any::{Any, TypeId};
|
||||
|
||||
use tfhe_versionable::{Upgrade, Version, VersionsDispatch};
|
||||
|
||||
use crate::core_crypto::prelude::{GlweSecretKeyOwned, LweSecretKeyOwned};
|
||||
use crate::shortint::client_key::atomic_pattern::{
|
||||
AtomicPatternClientKey, StandardAtomicPatternClientKey,
|
||||
};
|
||||
use crate::shortint::client_key::GenericClientKey;
|
||||
use crate::shortint::ShortintParameterSet;
|
||||
use crate::Error;
|
||||
|
||||
#[derive(Version)]
|
||||
pub struct ClientKeyV0 {
|
||||
glwe_secret_key: GlweSecretKeyOwned<u64>,
|
||||
lwe_secret_key: LweSecretKeyOwned<u64>,
|
||||
parameters: ShortintParameterSet,
|
||||
}
|
||||
|
||||
impl<AP: 'static> Upgrade<GenericClientKey<AP>> for ClientKeyV0 {
|
||||
type Error = Error;
|
||||
|
||||
fn upgrade(self) -> Result<GenericClientKey<AP>, Self::Error> {
|
||||
let ap_params = self.parameters.pbs_parameters().ok_or_else(|| {
|
||||
Error::new(
|
||||
"ClientKey from TFHE-rs 1.2 and before needs PBS parameters to be upgraded to the latest version"
|
||||
.to_string(),
|
||||
)
|
||||
})?;
|
||||
|
||||
let std_ap = StandardAtomicPatternClientKey::from_raw_parts(
|
||||
self.glwe_secret_key,
|
||||
self.lwe_secret_key,
|
||||
ap_params,
|
||||
self.parameters.wopbs_parameters(),
|
||||
);
|
||||
|
||||
if TypeId::of::<AP>() == TypeId::of::<AtomicPatternClientKey>() {
|
||||
let atomic_pattern = AtomicPatternClientKey::Standard(std_ap);
|
||||
let ck: Box<dyn Any + 'static> = Box::new(GenericClientKey { atomic_pattern });
|
||||
Ok(*ck.downcast::<GenericClientKey<AP>>().unwrap()) // We know from the TypeId that
|
||||
// AP is of the right type so we
|
||||
// can unwrap
|
||||
} else if TypeId::of::<AP>() == TypeId::of::<StandardAtomicPatternClientKey>() {
|
||||
let ck: Box<dyn Any + 'static> = Box::new(GenericClientKey {
|
||||
atomic_pattern: std_ap,
|
||||
});
|
||||
Ok(*ck.downcast::<GenericClientKey<AP>>().unwrap()) // We know from the TypeId that
|
||||
// AP is of the right type so we
|
||||
// can unwrap
|
||||
} else {
|
||||
Err(Error::new(
|
||||
"ClientKey from TFHE-rs 1.2 and before can only be deserialized to the standard \
|
||||
Atomic Pattern"
|
||||
.to_string(),
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(VersionsDispatch)]
|
||||
pub enum ClientKeyVersions {
|
||||
V0(ClientKey),
|
||||
pub enum ClientKeyVersions<AP> {
|
||||
V0(ClientKeyV0),
|
||||
V1(GenericClientKey<AP>),
|
||||
}
|
||||
|
||||
@@ -82,7 +82,7 @@ pub struct ServerKeyV1 {
|
||||
pbs_order: PBSOrder,
|
||||
}
|
||||
|
||||
impl<AP: Clone + 'static> Upgrade<GenericServerKey<AP>> for ServerKeyV1 {
|
||||
impl<AP: 'static> Upgrade<GenericServerKey<AP>> for ServerKeyV1 {
|
||||
type Error = Error;
|
||||
|
||||
fn upgrade(self) -> Result<GenericServerKey<AP>, Self::Error> {
|
||||
@@ -94,32 +94,30 @@ impl<AP: Clone + 'static> Upgrade<GenericServerKey<AP>> for ServerKeyV1 {
|
||||
|
||||
if TypeId::of::<AP>() == TypeId::of::<AtomicPatternServerKey>() {
|
||||
let ap = AtomicPatternServerKey::Standard(std_ap);
|
||||
let sk = ServerKey::from_raw_parts(
|
||||
let sk: Box<dyn Any + 'static> = Box::new(ServerKey::from_raw_parts(
|
||||
ap,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
self.max_degree,
|
||||
self.max_noise_level,
|
||||
);
|
||||
Ok((&sk as &dyn Any)
|
||||
.downcast_ref::<GenericServerKey<AP>>()
|
||||
.unwrap() // We know from the TypeId that AP is of the right type so we can unwrap
|
||||
.clone())
|
||||
));
|
||||
Ok(*sk.downcast::<GenericServerKey<AP>>().unwrap()) // We know from the TypeId that
|
||||
// AP is of the right type so we
|
||||
// can unwrap
|
||||
} else if TypeId::of::<AP>() == TypeId::of::<StandardAtomicPatternServerKey>() {
|
||||
let sk = StandardServerKey::from_raw_parts(
|
||||
let sk: Box<dyn Any + 'static> = Box::new(StandardServerKey::from_raw_parts(
|
||||
std_ap,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
self.max_degree,
|
||||
self.max_noise_level,
|
||||
);
|
||||
Ok((&sk as &dyn Any)
|
||||
.downcast_ref::<GenericServerKey<AP>>()
|
||||
.unwrap() // We know from the TypeId that AP is of the right type so we can unwrap
|
||||
.clone())
|
||||
));
|
||||
Ok(*sk.downcast::<GenericServerKey<AP>>().unwrap()) // We know from the TypeId that
|
||||
// AP is of the right type so we
|
||||
// can unwrap
|
||||
} else {
|
||||
Err(Error::new(
|
||||
"ServerKey from TFHE-rs 1.0 and before can only be deserialized to the classical \
|
||||
"ServerKey from TFHE-rs 1.0 and before can only be deserialized to the standard \
|
||||
Atomic Pattern"
|
||||
.to_string(),
|
||||
))
|
||||
|
||||
@@ -26,21 +26,20 @@ where
|
||||
type Error = Error;
|
||||
|
||||
fn upgrade(self) -> Result<ModulusSwitchNoiseReductionKey<InputScalar>, Self::Error> {
|
||||
let modulus_switch_zeros = &self.modulus_switch_zeros as &dyn Any;
|
||||
let modulus_switch_zeros: Box<dyn Any> = Box::new(self.modulus_switch_zeros);
|
||||
|
||||
// Keys from previous versions were only stored as u64, we check if the destination
|
||||
// key is also u64 or we return an error
|
||||
Ok(ModulusSwitchNoiseReductionKey {
|
||||
modulus_switch_zeros: modulus_switch_zeros
|
||||
.downcast_ref::<LweCiphertextListOwned<InputScalar>>()
|
||||
.ok_or_else(|| {
|
||||
modulus_switch_zeros: *modulus_switch_zeros
|
||||
.downcast::<LweCiphertextListOwned<InputScalar>>()
|
||||
.map_err(|_| {
|
||||
Error::new(format!(
|
||||
"Expected u64 as InputScalar while upgrading \
|
||||
ModulusSwitchNoiseReductionKey, got {}",
|
||||
std::any::type_name::<InputScalar>(),
|
||||
))
|
||||
})?
|
||||
.clone(),
|
||||
})?,
|
||||
ms_bound: self.ms_bound,
|
||||
ms_r_sigma_factor: self.ms_r_sigma_factor,
|
||||
ms_input_variance: self.ms_input_variance,
|
||||
@@ -73,21 +72,20 @@ where
|
||||
type Error = Error;
|
||||
|
||||
fn upgrade(self) -> Result<CompressedModulusSwitchNoiseReductionKey<InputScalar>, Self::Error> {
|
||||
let modulus_switch_zeros = &self.modulus_switch_zeros as &dyn Any;
|
||||
let modulus_switch_zeros: Box<dyn Any> = Box::new(self.modulus_switch_zeros);
|
||||
|
||||
// Keys from previous versions were only stored as u64, we check if the destination
|
||||
// key is also u64 or we return an error
|
||||
Ok(CompressedModulusSwitchNoiseReductionKey {
|
||||
modulus_switch_zeros: modulus_switch_zeros
|
||||
.downcast_ref::<SeededLweCiphertextListOwned<InputScalar>>()
|
||||
.ok_or_else(|| {
|
||||
modulus_switch_zeros: *modulus_switch_zeros
|
||||
.downcast::<SeededLweCiphertextListOwned<InputScalar>>()
|
||||
.map_err(|_| {
|
||||
Error::new(format!(
|
||||
"Expected u64 as InputScalar while upgrading \
|
||||
CompressedModulusSwitchNoiseReductionKey, got {}",
|
||||
std::any::type_name::<InputScalar>(),
|
||||
))
|
||||
})?
|
||||
.clone(),
|
||||
})?,
|
||||
ms_bound: self.ms_bound,
|
||||
ms_r_sigma_factor: self.ms_r_sigma_factor,
|
||||
ms_input_variance: self.ms_input_variance,
|
||||
|
||||
169
tfhe/src/shortint/client_key/atomic_pattern/ks32.rs
Normal file
169
tfhe/src/shortint/client_key/atomic_pattern/ks32.rs
Normal file
@@ -0,0 +1,169 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tfhe_versionable::Versionize;
|
||||
|
||||
use crate::core_crypto::prelude::{
|
||||
allocate_and_generate_new_binary_glwe_secret_key,
|
||||
allocate_and_generate_new_binary_lwe_secret_key,
|
||||
};
|
||||
use crate::shortint::backward_compatibility::client_key::atomic_pattern::KS32AtomicPatternClientKeyVersions;
|
||||
use crate::shortint::client_key::{GlweSecretKeyOwned, LweSecretKeyOwned, LweSecretKeyView};
|
||||
use crate::shortint::engine::ShortintEngine;
|
||||
use crate::shortint::parameters::{DynamicDistribution, KeySwitch32PBSParameters};
|
||||
use crate::shortint::{AtomicPatternKind, ShortintParameterSet};
|
||||
|
||||
use super::EncryptionAtomicPattern;
|
||||
|
||||
/// Client key material for the KeySwitch32 atomic pattern.
///
/// Holds a 64-bit GLWE secret key, a 32-bit LWE secret key and the
/// [`KeySwitch32PBSParameters`] they were generated for.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, Versionize)]
|
||||
#[versionize(KS32AtomicPatternClientKeyVersions)]
|
||||
pub struct KS32AtomicPatternClientKey {
|
||||
/// GLWE secret key; viewed as an LWE secret key it is the key used for encryption
pub(crate) glwe_secret_key: GlweSecretKeyOwned<u64>,
|
||||
/// Key used as the output of the keyswitch operation
|
||||
pub(crate) lwe_secret_key: LweSecretKeyOwned<u32>,
|
||||
/// Parameters the two secret keys are sized against
pub parameters: KeySwitch32PBSParameters,
|
||||
}
|
||||
|
||||
impl KS32AtomicPatternClientKey {
|
||||
pub(crate) fn new_with_engine(
|
||||
parameters: KeySwitch32PBSParameters,
|
||||
engine: &mut ShortintEngine,
|
||||
) -> Self {
|
||||
// generate the lwe secret key
|
||||
let lwe_secret_key = allocate_and_generate_new_binary_lwe_secret_key(
|
||||
parameters.lwe_dimension(),
|
||||
&mut engine.secret_generator,
|
||||
);
|
||||
|
||||
// generate the rlwe secret key
|
||||
let glwe_secret_key = allocate_and_generate_new_binary_glwe_secret_key(
|
||||
parameters.glwe_dimension(),
|
||||
parameters.polynomial_size(),
|
||||
&mut engine.secret_generator,
|
||||
);
|
||||
|
||||
// pack the keys in the client key set
|
||||
Self {
|
||||
glwe_secret_key,
|
||||
lwe_secret_key,
|
||||
parameters,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn new(parameters: KeySwitch32PBSParameters) -> Self {
|
||||
ShortintEngine::with_thread_local_mut(|engine| Self::new_with_engine(parameters, engine))
|
||||
}
|
||||
|
||||
pub fn into_raw_parts(
|
||||
self,
|
||||
) -> (
|
||||
GlweSecretKeyOwned<u64>,
|
||||
LweSecretKeyOwned<u32>,
|
||||
KeySwitch32PBSParameters,
|
||||
) {
|
||||
let Self {
|
||||
glwe_secret_key,
|
||||
lwe_secret_key,
|
||||
parameters,
|
||||
} = self;
|
||||
|
||||
(glwe_secret_key, lwe_secret_key, parameters)
|
||||
}
|
||||
|
||||
pub fn from_raw_parts(
|
||||
glwe_secret_key: GlweSecretKeyOwned<u64>,
|
||||
lwe_secret_key: LweSecretKeyOwned<u32>,
|
||||
parameters: KeySwitch32PBSParameters,
|
||||
) -> Self {
|
||||
assert_eq!(
|
||||
lwe_secret_key.lwe_dimension(),
|
||||
parameters.lwe_dimension(),
|
||||
"Mismatch between the LweSecretKey LweDimension ({:?}) \
|
||||
and the parameters LweDimension ({:?})",
|
||||
lwe_secret_key.lwe_dimension(),
|
||||
parameters.lwe_dimension()
|
||||
);
|
||||
assert_eq!(
|
||||
glwe_secret_key.glwe_dimension(),
|
||||
parameters.glwe_dimension(),
|
||||
"Mismatch between the GlweSecretKey GlweDimension ({:?}) \
|
||||
and the parameters GlweDimension ({:?})",
|
||||
glwe_secret_key.glwe_dimension(),
|
||||
parameters.glwe_dimension()
|
||||
);
|
||||
assert_eq!(
|
||||
glwe_secret_key.polynomial_size(),
|
||||
parameters.polynomial_size(),
|
||||
"Mismatch between the GlweSecretKey PolynomialSize ({:?}) \
|
||||
and the parameters PolynomialSize ({:?})",
|
||||
glwe_secret_key.polynomial_size(),
|
||||
parameters.polynomial_size()
|
||||
);
|
||||
|
||||
Self {
|
||||
glwe_secret_key,
|
||||
lwe_secret_key,
|
||||
parameters,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn try_from_lwe_encryption_key(
|
||||
encryption_key: LweSecretKeyOwned<u64>,
|
||||
parameters: KeySwitch32PBSParameters,
|
||||
) -> crate::Result<Self> {
|
||||
let expected_lwe_dimension = parameters.encryption_lwe_dimension();
|
||||
if encryption_key.lwe_dimension() != expected_lwe_dimension {
|
||||
return Err(
|
||||
crate::Error::new(
|
||||
format!(
|
||||
"The given encryption key does not have the correct LweDimension, expected: {:?}, got: {:?}",
|
||||
encryption_key.lwe_dimension(),
|
||||
expected_lwe_dimension)));
|
||||
}
|
||||
|
||||
// The key we got is the one used to encrypt,
|
||||
// we have to generate the other key. The KS32 ap only support KS-PBS order so we need to
|
||||
// generate the small key.
|
||||
let small_key = ShortintEngine::with_thread_local_mut(|engine| {
|
||||
allocate_and_generate_new_binary_lwe_secret_key(
|
||||
parameters.lwe_dimension(),
|
||||
&mut engine.secret_generator,
|
||||
)
|
||||
});
|
||||
|
||||
Ok(Self {
|
||||
glwe_secret_key: GlweSecretKeyOwned::from_container(
|
||||
encryption_key.into_container(),
|
||||
parameters.polynomial_size(),
|
||||
),
|
||||
lwe_secret_key: small_key,
|
||||
parameters,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn large_lwe_secret_key(&self) -> LweSecretKeyView<'_, u64> {
|
||||
self.glwe_secret_key.as_lwe_secret_key()
|
||||
}
|
||||
|
||||
pub fn small_lwe_secret_key(&self) -> LweSecretKeyView<'_, u32> {
|
||||
self.lwe_secret_key.as_view()
|
||||
}
|
||||
}
|
||||
|
||||
impl EncryptionAtomicPattern for KS32AtomicPatternClientKey {
|
||||
fn parameters(&self) -> ShortintParameterSet {
|
||||
self.parameters.into()
|
||||
}
|
||||
|
||||
fn encryption_key(&self) -> LweSecretKeyView<'_, u64> {
|
||||
// The KS32 atomic pattern is only supported with the KsPbs order
|
||||
self.glwe_secret_key.as_lwe_secret_key()
|
||||
}
|
||||
|
||||
fn encryption_noise(&self) -> DynamicDistribution<u64> {
|
||||
// The KS32 atomic pattern is only supported with the KsPbs order
|
||||
self.parameters.glwe_noise_distribution()
|
||||
}
|
||||
|
||||
fn kind(&self) -> AtomicPatternKind {
|
||||
AtomicPatternKind::KeySwitch32
|
||||
}
|
||||
}
|
||||
135
tfhe/src/shortint/client_key/atomic_pattern/mod.rs
Normal file
135
tfhe/src/shortint/client_key/atomic_pattern/mod.rs
Normal file
@@ -0,0 +1,135 @@
|
||||
pub mod ks32;
|
||||
pub mod standard;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tfhe_versionable::Versionize;
|
||||
|
||||
use crate::shortint::backward_compatibility::client_key::atomic_pattern::AtomicPatternClientKeyVersions;
|
||||
use crate::shortint::engine::ShortintEngine;
|
||||
use crate::shortint::parameters::DynamicDistribution;
|
||||
use crate::shortint::{AtomicPatternKind, AtomicPatternParameters, ShortintParameterSet};
|
||||
|
||||
use super::{LweSecretKeyOwned, LweSecretKeyView};
|
||||
|
||||
pub use ks32::*;
|
||||
pub use standard::*;
|
||||
|
||||
/// An atomic pattern used for encryption
|
||||
///
|
||||
/// This category of atomic patterns will be used by the [`ClientKey`](super::ClientKey) to encrypt
|
||||
/// ciphertexts, and to generate a [`ServerKey`](crate::shortint::ServerKey) that can be used for
|
||||
/// evaluation.
|
||||
pub trait EncryptionAtomicPattern {
|
||||
/// The parameters associated with this client key
|
||||
fn parameters(&self) -> ShortintParameterSet;
|
||||
|
||||
/// The secret key used for encryption
|
||||
fn encryption_key(&self) -> LweSecretKeyView<'_, u64>;
|
||||
|
||||
/// The noise distribution used for encryption
|
||||
fn encryption_noise(&self) -> DynamicDistribution<u64>;
|
||||
|
||||
/// The kind of atomic pattern that will be used by the generated
|
||||
/// [`ServerKey`](crate::shortint::ServerKey)
|
||||
fn kind(&self) -> AtomicPatternKind;
|
||||
|
||||
/// Convenience accessor returning the encryption key and the encryption noise distribution
/// as a pair; the default implementation calls the two dedicated accessors.
fn encryption_key_and_noise(&self) -> (LweSecretKeyView<'_, u64>, DynamicDistribution<u64>) {
|
||||
(self.encryption_key(), self.encryption_noise())
|
||||
}
|
||||
}
|
||||
|
||||
// This blanket impl is used to allow "views" of client keys, without having to re-implement the
|
||||
// trait
|
||||
impl<T: EncryptionAtomicPattern> EncryptionAtomicPattern for &T {
|
||||
fn parameters(&self) -> ShortintParameterSet {
|
||||
(*self).parameters()
|
||||
}
|
||||
|
||||
fn encryption_key(&self) -> LweSecretKeyView<'_, u64> {
|
||||
(*self).encryption_key()
|
||||
}
|
||||
|
||||
fn encryption_noise(&self) -> DynamicDistribution<u64> {
|
||||
(*self).encryption_noise()
|
||||
}
|
||||
|
||||
fn kind(&self) -> AtomicPatternKind {
|
||||
(*self).kind()
|
||||
}
|
||||
}
|
||||
|
||||
/// The client key materials for all the supported Atomic Patterns
///
/// Each variant wraps the dedicated client key type of one atomic pattern.
|
||||
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, Versionize)]
|
||||
#[versionize(AtomicPatternClientKeyVersions)]
|
||||
#[allow(clippy::large_enum_variant)] // The difference in size is just because of the wopbs params in std key
|
||||
pub enum AtomicPatternClientKey {
|
||||
/// Key material built from [`AtomicPatternParameters::Standard`] parameters
Standard(StandardAtomicPatternClientKey),
|
||||
/// Key material built from [`AtomicPatternParameters::KeySwitch32`] parameters
KeySwitch32(KS32AtomicPatternClientKey),
|
||||
}
|
||||
|
||||
impl AtomicPatternClientKey {
|
||||
pub(crate) fn new_with_engine(
|
||||
parameters: AtomicPatternParameters,
|
||||
engine: &mut ShortintEngine,
|
||||
) -> Self {
|
||||
match parameters {
|
||||
AtomicPatternParameters::Standard(ap_params) => Self::Standard(
|
||||
StandardAtomicPatternClientKey::new_with_engine(ap_params, None, engine),
|
||||
),
|
||||
AtomicPatternParameters::KeySwitch32(ap_params) => Self::KeySwitch32(
|
||||
KS32AtomicPatternClientKey::new_with_engine(ap_params, engine),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn new(parameters: AtomicPatternParameters) -> Self {
|
||||
ShortintEngine::with_thread_local_mut(|engine| Self::new_with_engine(parameters, engine))
|
||||
}
|
||||
|
||||
pub fn try_from_lwe_encryption_key(
|
||||
encryption_key: LweSecretKeyOwned<u64>,
|
||||
parameters: AtomicPatternParameters,
|
||||
) -> crate::Result<Self> {
|
||||
match parameters {
|
||||
AtomicPatternParameters::Standard(ap_params) => Ok(Self::Standard(
|
||||
StandardAtomicPatternClientKey::try_from_lwe_encryption_key(
|
||||
encryption_key,
|
||||
ap_params,
|
||||
)?,
|
||||
)),
|
||||
AtomicPatternParameters::KeySwitch32(ap_params) => Ok(Self::KeySwitch32(
|
||||
KS32AtomicPatternClientKey::try_from_lwe_encryption_key(encryption_key, ap_params)?,
|
||||
)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl EncryptionAtomicPattern for AtomicPatternClientKey {
|
||||
fn parameters(&self) -> ShortintParameterSet {
|
||||
match self {
|
||||
Self::Standard(ap) => ap.parameters.into(),
|
||||
Self::KeySwitch32(ap) => ap.parameters.into(),
|
||||
}
|
||||
}
|
||||
|
||||
fn encryption_key(&self) -> LweSecretKeyView<'_, u64> {
|
||||
match self {
|
||||
Self::Standard(ap) => ap.encryption_key(),
|
||||
Self::KeySwitch32(ap) => ap.encryption_key(),
|
||||
}
|
||||
}
|
||||
|
||||
fn encryption_noise(&self) -> DynamicDistribution<u64> {
|
||||
match self {
|
||||
Self::Standard(ap) => ap.encryption_noise(),
|
||||
Self::KeySwitch32(ap) => ap.encryption_noise(),
|
||||
}
|
||||
}
|
||||
|
||||
fn kind(&self) -> AtomicPatternKind {
|
||||
match self {
|
||||
Self::Standard(ap_cks) => ap_cks.kind(),
|
||||
Self::KeySwitch32(ap_cks) => ap_cks.kind(),
|
||||
}
|
||||
}
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user