mirror of
https://github.com/zama-ai/tfhe-rs.git
synced 2026-04-28 03:01:21 -04:00
Compare commits
34 Commits
tfhe-versi
...
as/lut_cac
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b4ea48165b | ||
|
|
0a6b62627d | ||
|
|
6deeb66bf8 | ||
|
|
17022dae69 | ||
|
|
09802dd5ee | ||
|
|
e3fe433a35 | ||
|
|
2bea35a3b5 | ||
|
|
e2bf226276 | ||
|
|
c66f1c6d8b | ||
|
|
9bfe190ad3 | ||
|
|
e40070db0e | ||
|
|
e8d5ceac68 | ||
|
|
f1526b29d8 | ||
|
|
602e0c5a19 | ||
|
|
163c1eeffb | ||
|
|
3bcb9c8360 | ||
|
|
20a64abaf1 | ||
|
|
4b31987a45 | ||
|
|
b27fbc5d78 | ||
|
|
3d797e4823 | ||
|
|
51ef40ace3 | ||
|
|
c560462a4a | ||
|
|
67ed05a008 | ||
|
|
236eea5bd7 | ||
|
|
a1d3262726 | ||
|
|
afbeebc1b4 | ||
|
|
09cd5c1727 | ||
|
|
521f1516bb | ||
|
|
3c171136ad | ||
|
|
6f360968df | ||
|
|
37a0c58cb9 | ||
|
|
99590e3b0f | ||
|
|
6300a025d9 | ||
|
|
7222bff5d6 |
2
.github/workflows/benchmark_hpu_common.yml
vendored
2
.github/workflows/benchmark_hpu_common.yml
vendored
@@ -187,7 +187,7 @@ jobs:
|
||||
- name: Upload parsed results artifact
|
||||
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
|
||||
with:
|
||||
name: ${{ github.sha }}_${{ matrix.bench_type }}_integer_benchmarks
|
||||
name: ${{ github.sha }}_${{ matrix.bench_type }}_${{ matrix.command }}_benchmarks
|
||||
path: ${{ env.RESULTS_FILENAME }}
|
||||
|
||||
- name: Checkout Slab repo
|
||||
|
||||
67
.github/workflows/pr_milestone_check.yml
vendored
Normal file
67
.github/workflows/pr_milestone_check.yml
vendored
Normal file
@@ -0,0 +1,67 @@
|
||||
name: pr_milestone_check
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
types: [opened, edited, synchronize, reopened, milestoned, demilestoned]
|
||||
|
||||
permissions: {}
|
||||
|
||||
# zizmor: ignore[concurrency-limits] only Zama organization members can trigger this workflow
|
||||
# external contributors workflows are manually approved
|
||||
|
||||
jobs:
|
||||
check-empty-milestone:
|
||||
name: pr_milestone_check/check-empty-milestone
|
||||
runs-on: ubuntu-latest
|
||||
if: github.event.pull_request.milestone == null
|
||||
permissions:
|
||||
pull-requests: write # Need write access on pull requests to post comment
|
||||
|
||||
steps:
|
||||
- name: Post Reminder Comment
|
||||
uses: octokit/request-action@dad4362715b7fb2ddedf9772c8670824af564f0d # v2.4.0
|
||||
with:
|
||||
route: POST /repos/${{ github.repository }}/issues/${{ github.event.pull_request.number }}/comments
|
||||
body: |
|
||||
'### ❌ Milestone Missing
|
||||
|
||||
Please assign a milestone to this pull request. If your PR targets the next version of
|
||||
TFHE-rs please use the current quarter milestone, e.g. "Q1 26".
|
||||
|
||||
If your PR targets a patch version for previous releases: consider creating a dedicated
|
||||
milestone e.g. v1.5.1 if it does not exist yet.'
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: Check Final Status
|
||||
run: |
|
||||
echo "::error::Milestone is missing. This check is failing."
|
||||
exit 1
|
||||
|
||||
check-milestone-open:
|
||||
name: pr_milestone_check/check-milestone-open
|
||||
runs-on: ubuntu-latest
|
||||
if: github.event.pull_request.milestone != null && github.event.pull_request.milestone.state == 'closed'
|
||||
permissions:
|
||||
pull-requests: write # Need write access on pull requests to post comment
|
||||
|
||||
steps:
|
||||
- name: Post Reminder Comment
|
||||
uses: octokit/request-action@dad4362715b7fb2ddedf9772c8670824af564f0d # v2.4.0
|
||||
with:
|
||||
route: POST /repos/${{ github.repository }}/issues/${{ github.event.pull_request.number }}/comments
|
||||
body: |
|
||||
'### ❌ Milestone is closed
|
||||
|
||||
Please assign an open milestone to this pull request. If your PR targets the next version of
|
||||
TFHE-rs please use the current quarter milestone, e.g. "Q1 26".
|
||||
|
||||
If your PR targets a patch version for previous releases: consider creating a dedicated
|
||||
milestone e.g. v1.5.1 if it does not exist yet.'
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: Check Final Status
|
||||
run: |
|
||||
echo "::error::Milestone is closed. This check is failing."
|
||||
exit 1
|
||||
7
Makefile
7
Makefile
@@ -1454,6 +1454,13 @@ bench_integer_aes256_gpu: install_rs_check_toolchain
|
||||
--bench integer-aes256 \
|
||||
--features=integer,internal-keycache,gpu, -p tfhe-benchmark --profile release_lto_off --
|
||||
|
||||
.PHONY: bench_integer_trivium_gpu # Run benchmarks for trivium on GPU backend
|
||||
bench_integer_trivium_gpu: install_rs_check_toolchain
|
||||
RUSTFLAGS="$(RUSTFLAGS)" __TFHE_RS_BENCH_TYPE=$(BENCH_TYPE) \
|
||||
cargo $(CARGO_RS_CHECK_TOOLCHAIN) bench \
|
||||
--bench integer-trivium \
|
||||
--features=integer,internal-keycache,gpu, -p tfhe-benchmark --profile release_lto_off --
|
||||
|
||||
.PHONY: bench_integer_multi_bit # Run benchmarks for unsigned integer using multi-bit parameters
|
||||
bench_integer_multi_bit: install_rs_check_toolchain
|
||||
RUSTFLAGS="$(RUSTFLAGS)" __TFHE_RS_PARAM_TYPE=MULTI_BIT __TFHE_RS_BENCH_TYPE=$(BENCH_TYPE) \
|
||||
|
||||
@@ -79,7 +79,7 @@ tfhe = { version = "*", features = ["boolean", "shortint", "integer"] }
|
||||
```
|
||||
|
||||
> [!Note]
|
||||
> Note: You need Rust version 1.84 or newer to compile TFHE-rs. You can check your version with `rustc --version`.
|
||||
> Note: You need Rust version 1.91.1 or newer to compile TFHE-rs. You can check your version with `rustc --version`.
|
||||
|
||||
> [!Note]
|
||||
> Note: AArch64-based machines are not supported for Windows as it's currently missing an entropy source to be able to seed the [CSPRNGs](https://en.wikipedia.org/wiki/Cryptographically_secure_pseudorandom_number_generator) used in TFHE-rs.
|
||||
|
||||
@@ -1,43 +1,43 @@
|
||||
# Test vectors for TFHE
|
||||
These test vectors are generated using [TFHE-rs](https://github.com/zama-ai/tfhe-rs), with the git tag `tfhe-test-vectors-0.2.0`.
|
||||
|
||||
They are TFHE-rs objects serialized in the [cbor format](https://cbor.io/). You can deserialize them using any cbor library for the language of your choice. For example, using the [cbor2](https://pypi.org/project/cbor2/) program, run: `cbor2 --pretty toy_params/lwe_a.cbor`.
|
||||
They are TFHE-rs objects serialized in the [CBOR format](https://cbor.io/). They can be deserialized using any CBOR library, in any programming language. For example, using the [cbor2](https://pypi.org/project/cbor2/) program, the command to run is: `cbor2 --pretty toy_params/lwe_a.cbor`.
|
||||
|
||||
You will find 2 folders with test vectors for different parameter sets:
|
||||
- `valid_params_128`: valid classical PBS parameters using a gaussian noise distribution, providing 128bits of security in the IND-CPA model and a bootstrapping probability of failure of 2^{-64}.
|
||||
- `toy_params`: insecure parameters that yield smaller values
|
||||
There are 2 folders with test vectors for different parameter sets:
|
||||
- `valid_params_128`: valid classical PBS parameters using a Gaussian noise distribution, providing 128 bits of security in the IND-CPA model (i.e., the probability of failure is smaller than 2^{-64}).
|
||||
- `toy_params`: insecure parameters that yield smaller values to simplify the bit comparison of the results.
|
||||
|
||||
The values are generated for the keyswitch -> bootstrap (KS-PBS) atomic pattern. The cleartext inputs are 2 values, A and B defined below.
|
||||
The values are generated to compute a keyswitch (KS) followed by a bootstrap (PBS). The cleartext inputs are 2 values, A and B defined below.
|
||||
|
||||
All the random values are generated from a fixed seed, which can be found in the `RAND_SEED` constant below. The PRNG used is the one based on the AES block cipher in counter mode, from the `tfhe-csprng` crate.
|
||||
|
||||
The programmable bootstrap is applied twice, with 2 different lut, the identity lut and a specific one (currently a x2 operation)
|
||||
The bootstrap is applied twice, with 2 different LUTs: the identity LUT and a specific one computing the double of the input value (i.e., f(x) = 2*x).
|
||||
|
||||
## Vectors
|
||||
The following values are generated:
|
||||
|
||||
### Keys
|
||||
| name | description | TFHE-rs type |
|
||||
|------------------------|---------------------------------------------------------------------------------------|-----------------------------|
|
||||
| `large_lwe_secret_key` | Encryption secret key, before the KS and after the PBS | `LweSecretKey<Vec<u64>>` |
|
||||
| `small_lwe_secret_key` | Secret key encrypting ciphertexts between the KS and the PBS | `LweSecretKey<Vec<u64>>` |
|
||||
| `ksk` | The keyswitching key to convert a ct from the large key to the small one | `LweKeyswitchKey<Vec<u64>>` |
|
||||
| name | description | TFHE-rs type |
|
||||
|------------------------|-----------------------------------------------------------------------------------------|-----------------------------|
|
||||
| `large_lwe_secret_key` | Encryption secret key, before the KS and after the PBS | `LweSecretKey<Vec<u64>>` |
|
||||
| `small_lwe_secret_key` | Secret key encrypting ciphertexts between the KS and the PBS | `LweSecretKey<Vec<u64>>` |
|
||||
| `ksk` | The keyswitching key to convert a ct from the large key to the small one | `LweKeyswitchKey<Vec<u64>>` |
|
||||
| `bsk` | The bootstrapping key to perform a programmable bootstrap on the keyswitched ciphertext | `LweBootstrapKey<Vec<u64>>` |
|
||||
|
||||
|
||||
### Ciphertexts
|
||||
| name | description | TFHE-rs type | Cleartext |
|
||||
|----------------------|--------------------------------------------------------------------------------------------------------------|----------------------------|--------------|
|
||||
| `lwe_a` | Lwe encryption of A | `LweCiphertext<Vec<u64>>` | `A` |
|
||||
| `lwe_b` | Lwe encryption of B | `LweCiphertext<Vec<u64>>` | `B` |
|
||||
| `lwe_sum` | Lwe encryption of A plus lwe encryption of B | `LweCiphertext<Vec<u64>>` | `A+B` |
|
||||
| `lwe_prod` | Lwe encryption of A times cleartext B | `LweCiphertext<Vec<u64>>` | `A*B` |
|
||||
| `lwe_ms` | The lwe ciphertext after the modswitch part of the PBS ([note](#non-native-encoding)) | `LweCiphertext<Vec<u64>>` | `A` |
|
||||
| `lwe_ks` | The lwe ciphertext after the keyswitch | `LweCiphertext<Vec<u64>>` | `A` |
|
||||
| `glwe_after_id_br` | The glwe returned by the application of the identity blind rotation on the mod switched ciphertexts. | `GlweCiphertext<Vec<u64>>` | rot id LUT |
|
||||
| `lwe_after_id_pbs` | The lwe returned by the application of the sample extract operation on the output of the id blind rotation | `LweCiphertext<Vec<u64>>` | `A` |
|
||||
| `glwe_after_spec_br` | The glwe returned by the application of the spec blind rotation on the mod switched ciphertexts. | `GlweCiphertext<Vec<u64>>` | rot spec LUT |
|
||||
| `lwe_after_spec_pbs` | The lwe returned by the application of the sample extract operation on the output of the spec blind rotation | `LweCiphertext<Vec<u64>>` | `spec(A)` |
|
||||
| name | description | TFHE-rs type | Cleartext |
|
||||
|----------------------|-----------------------------------------------------------------------------------------------------|----------------------------|----------------------|
|
||||
| `lwe_a` | LWE Ciphertext encrypting A | `LweCiphertext<Vec<u64>>` | `A` |
|
||||
| `lwe_b` | LWE Ciphertext encrypting B | `LweCiphertext<Vec<u64>>` | `B` |
|
||||
| `lwe_sum` | LWE Ciphertext encrypting A plus lwe encryption of B | `LweCiphertext<Vec<u64>>` | `A+B` |
|
||||
| `lwe_prod` | LWE Ciphertext encrypting A times cleartext B | `LweCiphertext<Vec<u64>>` | `A*B` |
|
||||
| `lwe_ms` | LWE Ciphertext encrypting A after a Modulus Switch from q to 2*N ([note](#non-native-encoding)) | `LweCiphertext<Vec<u64>>` | `A` |
|
||||
| `lwe_ks` | LWE Ciphertext encrypting A after a keyswitch from `large_lwe_secret_key` to `small_lwe_secret_key` | `LweCiphertext<Vec<u64>>` | `A` |
|
||||
| `glwe_after_id_br` | GLWE Ciphertext encrypting A after the application of the identity blind rotation on `lwe_ms` | `GlweCiphertext<Vec<u64>>` | rotation of id LUT |
|
||||
| `lwe_after_id_pbs` | LWE Ciphertext encrypting A after the sample extract operation on `glwe_after_id_br` | `LweCiphertext<Vec<u64>>` | `A` |
|
||||
| `glwe_after_spec_br` | GLWE Ciphertext encrypting spec(A) after the application of the spec blind rotation on `lwe_ms` | `GlweCiphertext<Vec<u64>>` | rotation of spec LUT |
|
||||
| `lwe_after_spec_pbs` | LWE Ciphertext encrypting spec(A) after the sample extract operation on `glwe_after_spec_br` | `LweCiphertext<Vec<u64>>` | `spec(A)` |
|
||||
|
||||
Ciphertexts with the `_karatsuba` suffix are generated using the Karatsuba polynomial multiplication algorithm in the blind rotation, while default ciphertexts are generated using an FFT multiplication.
|
||||
This makes it easier to reproduce bit-exact results.
|
||||
|
||||
@@ -86,6 +86,7 @@ fn main() {
|
||||
"cuda/include/integer/integer.h",
|
||||
"cuda/include/integer/rerand.h",
|
||||
"cuda/include/aes/aes.h",
|
||||
"cuda/include/trivium/trivium.h",
|
||||
"cuda/include/zk/zk.h",
|
||||
"cuda/include/keyswitch/keyswitch.h",
|
||||
"cuda/include/keyswitch/ks_enums.h",
|
||||
|
||||
@@ -29,15 +29,13 @@ template <typename Torus> struct int_aes_lut_buffers {
|
||||
allocate_gpu_memory, size_tracker);
|
||||
std::function<Torus(Torus, Torus)> and_lambda =
|
||||
[](Torus a, Torus b) -> Torus { return a & b; };
|
||||
generate_device_accumulator_bivariate<Torus>(
|
||||
streams.stream(0), streams.gpu_index(0), this->and_lut->get_lut(0, 0),
|
||||
this->and_lut->get_degree(0), this->and_lut->get_max_degree(0),
|
||||
params.glwe_dimension, params.polynomial_size, params.message_modulus,
|
||||
params.carry_modulus, and_lambda, allocate_gpu_memory);
|
||||
|
||||
auto active_streams_and_lut = streams.active_gpu_subset(
|
||||
SBOX_MAX_AND_GATES * num_aes_inputs * sbox_parallelism,
|
||||
params.pbs_type);
|
||||
this->and_lut->broadcast_lut(active_streams_and_lut);
|
||||
this->and_lut->generate_and_broadcast_bivariate_lut(
|
||||
active_streams_and_lut, {0}, {and_lambda}, allocate_gpu_memory);
|
||||
|
||||
this->and_lut->setup_gemm_batch_ks_temp_buffers(size_tracker);
|
||||
|
||||
this->flush_lut = new int_radix_lut<Torus>(
|
||||
@@ -46,14 +44,11 @@ template <typename Torus> struct int_aes_lut_buffers {
|
||||
std::function<Torus(Torus)> flush_lambda = [](Torus x) -> Torus {
|
||||
return x & 1;
|
||||
};
|
||||
generate_device_accumulator(
|
||||
streams.stream(0), streams.gpu_index(0), this->flush_lut->get_lut(0, 0),
|
||||
this->flush_lut->get_degree(0), this->flush_lut->get_max_degree(0),
|
||||
params.glwe_dimension, params.polynomial_size, params.message_modulus,
|
||||
params.carry_modulus, flush_lambda, allocate_gpu_memory);
|
||||
|
||||
auto active_streams_flush_lut = streams.active_gpu_subset(
|
||||
AES_STATE_BITS * num_aes_inputs, params.pbs_type);
|
||||
this->flush_lut->broadcast_lut(active_streams_flush_lut);
|
||||
this->flush_lut->generate_and_broadcast_lut(
|
||||
active_streams_flush_lut, {0}, {flush_lambda}, allocate_gpu_memory);
|
||||
this->flush_lut->setup_gemm_batch_ks_temp_buffers(size_tracker);
|
||||
|
||||
this->carry_lut = new int_radix_lut<Torus>(
|
||||
@@ -61,14 +56,11 @@ template <typename Torus> struct int_aes_lut_buffers {
|
||||
std::function<Torus(Torus)> carry_lambda = [](Torus x) -> Torus {
|
||||
return (x >> 1) & 1;
|
||||
};
|
||||
generate_device_accumulator(
|
||||
streams.stream(0), streams.gpu_index(0), this->carry_lut->get_lut(0, 0),
|
||||
this->carry_lut->get_degree(0), this->carry_lut->get_max_degree(0),
|
||||
params.glwe_dimension, params.polynomial_size, params.message_modulus,
|
||||
params.carry_modulus, carry_lambda, allocate_gpu_memory);
|
||||
|
||||
auto active_streams_carry_lut =
|
||||
streams.active_gpu_subset(num_aes_inputs, params.pbs_type);
|
||||
this->carry_lut->broadcast_lut(active_streams_carry_lut);
|
||||
this->carry_lut->generate_and_broadcast_lut(
|
||||
active_streams_carry_lut, {0}, {carry_lambda}, allocate_gpu_memory);
|
||||
this->carry_lut->setup_gemm_batch_ks_temp_buffers(size_tracker);
|
||||
}
|
||||
|
||||
|
||||
@@ -65,14 +65,8 @@ template <typename Torus> struct boolean_bitop_buffer {
|
||||
return x % params.message_modulus;
|
||||
};
|
||||
|
||||
generate_device_accumulator<Torus>(
|
||||
streams.stream(0), streams.gpu_index(0),
|
||||
message_extract_lut->get_lut(0, 0),
|
||||
message_extract_lut->get_degree(0),
|
||||
message_extract_lut->get_max_degree(0), params.glwe_dimension,
|
||||
params.polynomial_size, params.message_modulus, params.carry_modulus,
|
||||
lut_f_message_extract, gpu_memory_allocated);
|
||||
message_extract_lut->broadcast_lut(active_streams);
|
||||
message_extract_lut->generate_and_broadcast_lut(
|
||||
active_streams, {0}, {lut_f_message_extract}, gpu_memory_allocated);
|
||||
}
|
||||
tmp_lwe_left = new CudaRadixCiphertextFFI;
|
||||
create_zero_radix_ciphertext_async<Torus>(
|
||||
@@ -142,12 +136,8 @@ template <typename Torus> struct int_bitop_buffer {
|
||||
}
|
||||
};
|
||||
|
||||
generate_device_accumulator_bivariate<Torus>(
|
||||
streams.stream(0), streams.gpu_index(0), lut->get_lut(0, 0),
|
||||
lut->get_degree(0), lut->get_max_degree(0), params.glwe_dimension,
|
||||
params.polynomial_size, params.message_modulus,
|
||||
params.carry_modulus, lut_bivariate_f, gpu_memory_allocated);
|
||||
lut->broadcast_lut(active_streams);
|
||||
lut->generate_and_broadcast_bivariate_lut(
|
||||
active_streams, {0}, {lut_bivariate_f}, gpu_memory_allocated);
|
||||
}
|
||||
break;
|
||||
default:
|
||||
@@ -156,6 +146,8 @@ template <typename Torus> struct int_bitop_buffer {
|
||||
num_radix_blocks, allocate_gpu_memory,
|
||||
size_tracker);
|
||||
|
||||
std::vector<std::function<Torus(Torus)>> lut_funcs;
|
||||
std::vector<uint32_t> lut_indices;
|
||||
for (int i = 0; i < params.message_modulus; i++) {
|
||||
auto rhs = i;
|
||||
|
||||
@@ -171,14 +163,13 @@ template <typename Torus> struct int_bitop_buffer {
|
||||
return x ^ rhs;
|
||||
}
|
||||
};
|
||||
generate_device_accumulator<Torus>(
|
||||
streams.stream(0), streams.gpu_index(0), lut->get_lut(0, i),
|
||||
lut->get_degree(i), lut->get_max_degree(i), params.glwe_dimension,
|
||||
params.polynomial_size, params.message_modulus,
|
||||
params.carry_modulus, lut_univariate_scalar_f,
|
||||
gpu_memory_allocated);
|
||||
lut->broadcast_lut(active_streams);
|
||||
|
||||
lut_funcs.push_back(lut_univariate_scalar_f);
|
||||
lut_indices.push_back(i);
|
||||
}
|
||||
|
||||
lut->generate_and_broadcast_lut(active_streams, lut_indices, lut_funcs,
|
||||
gpu_memory_allocated);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -211,16 +202,11 @@ template <typename Torus> struct boolean_bitnot_buffer {
|
||||
return x % message_modulus;
|
||||
};
|
||||
|
||||
generate_device_accumulator<Torus>(
|
||||
streams.stream(0), streams.gpu_index(0),
|
||||
message_extract_lut->get_lut(0, 0),
|
||||
message_extract_lut->get_degree(0),
|
||||
message_extract_lut->get_max_degree(0), params.glwe_dimension,
|
||||
params.polynomial_size, params.message_modulus, params.carry_modulus,
|
||||
lut_f_message_extract, gpu_memory_allocated);
|
||||
auto active_streams =
|
||||
streams.active_gpu_subset(lwe_ciphertext_count, params.pbs_type);
|
||||
message_extract_lut->broadcast_lut(active_streams);
|
||||
|
||||
message_extract_lut->generate_and_broadcast_lut(
|
||||
active_streams, {0}, {lut_f_message_extract}, gpu_memory_allocated);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -28,21 +28,17 @@ template <typename Torus> struct int_extend_radix_with_sign_msb_buffer {
|
||||
uint32_t bits_per_block = std::log2(params.message_modulus);
|
||||
uint32_t msg_modulus = params.message_modulus;
|
||||
|
||||
generate_device_accumulator<Torus>(
|
||||
streams.stream(0), streams.gpu_index(0), lut->get_lut(0, 0),
|
||||
lut->get_degree(0), lut->get_max_degree(0), params.glwe_dimension,
|
||||
params.polynomial_size, params.message_modulus, params.carry_modulus,
|
||||
[msg_modulus, bits_per_block](Torus x) {
|
||||
auto active_streams =
|
||||
streams.active_gpu_subset(num_radix_blocks, params.pbs_type);
|
||||
|
||||
lut->generate_and_broadcast_lut(
|
||||
active_streams, {0}, {[msg_modulus, bits_per_block](Torus x) {
|
||||
const auto xm = x % msg_modulus;
|
||||
const auto sign_bit = (xm >> (bits_per_block - 1)) & 1;
|
||||
return (Torus)((msg_modulus - 1) * sign_bit);
|
||||
},
|
||||
}},
|
||||
allocate_gpu_memory);
|
||||
|
||||
auto active_streams =
|
||||
streams.active_gpu_subset(num_radix_blocks, params.pbs_type);
|
||||
lut->broadcast_lut(active_streams);
|
||||
|
||||
this->last_block = new CudaRadixCiphertextFFI;
|
||||
|
||||
create_zero_radix_ciphertext_async<Torus>(
|
||||
|
||||
@@ -85,24 +85,6 @@ template <typename Torus> struct int_cmux_buffer {
|
||||
new int_radix_lut<Torus>(streams, params, 1, num_radix_blocks,
|
||||
allocate_gpu_memory, size_tracker);
|
||||
|
||||
generate_device_accumulator_bivariate<Torus>(
|
||||
streams.stream(0), streams.gpu_index(0), predicate_lut->get_lut(0, 0),
|
||||
predicate_lut->get_degree(0), predicate_lut->get_max_degree(0),
|
||||
params.glwe_dimension, params.polynomial_size, params.message_modulus,
|
||||
params.carry_modulus, inverted_lut_f, gpu_memory_allocated);
|
||||
|
||||
generate_device_accumulator_bivariate<Torus>(
|
||||
streams.stream(0), streams.gpu_index(0), predicate_lut->get_lut(0, 1),
|
||||
predicate_lut->get_degree(1), predicate_lut->get_max_degree(1),
|
||||
params.glwe_dimension, params.polynomial_size, params.message_modulus,
|
||||
params.carry_modulus, lut_f, gpu_memory_allocated);
|
||||
|
||||
generate_device_accumulator<Torus>(
|
||||
streams.stream(0), streams.gpu_index(0),
|
||||
message_extract_lut->get_lut(0, 0), message_extract_lut->get_degree(0),
|
||||
message_extract_lut->get_max_degree(0), params.glwe_dimension,
|
||||
params.polynomial_size, params.message_modulus, params.carry_modulus,
|
||||
message_extract_lut_f, gpu_memory_allocated);
|
||||
Torus *h_lut_indexes = predicate_lut->h_lut_indexes;
|
||||
for (int index = 0; index < 2 * num_radix_blocks; index++) {
|
||||
if (index < num_radix_blocks) {
|
||||
@@ -115,12 +97,18 @@ template <typename Torus> struct int_cmux_buffer {
|
||||
predicate_lut->get_lut_indexes(0, 0), h_lut_indexes,
|
||||
2 * num_radix_blocks * sizeof(Torus), streams.stream(0),
|
||||
streams.gpu_index(0), allocate_gpu_memory);
|
||||
|
||||
auto active_streams_pred =
|
||||
streams.active_gpu_subset(2 * num_radix_blocks, params.pbs_type);
|
||||
predicate_lut->broadcast_lut(active_streams_pred);
|
||||
predicate_lut->generate_and_broadcast_bivariate_lut(
|
||||
active_streams_pred, {0, 1}, {inverted_lut_f, lut_f},
|
||||
gpu_memory_allocated);
|
||||
|
||||
auto active_streams_msg =
|
||||
streams.active_gpu_subset(num_radix_blocks, params.pbs_type);
|
||||
message_extract_lut->broadcast_lut(active_streams_msg);
|
||||
|
||||
message_extract_lut->generate_and_broadcast_lut(
|
||||
active_streams_msg, {0}, {message_extract_lut_f}, gpu_memory_allocated);
|
||||
}
|
||||
|
||||
void release(CudaStreams streams) {
|
||||
|
||||
@@ -39,22 +39,21 @@ template <typename Torus> struct int_are_all_block_true_buffer {
|
||||
max_chunks, params.big_lwe_dimension, size_tracker,
|
||||
allocate_gpu_memory);
|
||||
|
||||
is_max_value = new int_radix_lut<Torus>(streams, params, 2, max_chunks,
|
||||
allocate_gpu_memory, size_tracker);
|
||||
auto is_max_value_f = [max_value](Torus x) -> Torus {
|
||||
return x == max_value;
|
||||
};
|
||||
preallocated_h_lut = (Torus *)malloc(
|
||||
(params.glwe_dimension + 1) * params.polynomial_size * sizeof(Torus));
|
||||
generate_device_accumulator<Torus>(
|
||||
streams.stream(0), streams.gpu_index(0), is_max_value->get_lut(0, 0),
|
||||
is_max_value->get_degree(0), is_max_value->get_max_degree(0),
|
||||
params.glwe_dimension, params.polynomial_size, params.message_modulus,
|
||||
params.carry_modulus, is_max_value_f, gpu_memory_allocated);
|
||||
|
||||
is_max_value = new int_radix_lut<Torus>(streams, params, 2, max_chunks,
|
||||
allocate_gpu_memory, size_tracker);
|
||||
|
||||
auto active_streams =
|
||||
streams.active_gpu_subset(max_chunks, params.pbs_type);
|
||||
is_max_value->broadcast_lut(active_streams);
|
||||
|
||||
auto is_max_value_f = [max_value](Torus x) -> Torus {
|
||||
return x == max_value;
|
||||
};
|
||||
|
||||
is_max_value->generate_and_broadcast_lut(
|
||||
active_streams, {0}, {is_max_value_f}, gpu_memory_allocated);
|
||||
}
|
||||
|
||||
void release(CudaStreams streams) {
|
||||
@@ -103,15 +102,10 @@ template <typename Torus> struct int_comparison_eq_buffer {
|
||||
new int_radix_lut<Torus>(streams, params, 1, num_radix_blocks,
|
||||
allocate_gpu_memory, size_tracker);
|
||||
|
||||
generate_device_accumulator<Torus>(
|
||||
streams.stream(0), streams.gpu_index(0), is_non_zero_lut->get_lut(0, 0),
|
||||
is_non_zero_lut->get_degree(0), is_non_zero_lut->get_max_degree(0),
|
||||
params.glwe_dimension, params.polynomial_size, params.message_modulus,
|
||||
params.carry_modulus, is_non_zero_lut_f, gpu_memory_allocated);
|
||||
|
||||
auto active_streams =
|
||||
streams.active_gpu_subset(num_radix_blocks, params.pbs_type);
|
||||
is_non_zero_lut->broadcast_lut(active_streams);
|
||||
is_non_zero_lut->generate_and_broadcast_lut(
|
||||
active_streams, {0}, {is_non_zero_lut_f}, gpu_memory_allocated);
|
||||
|
||||
// Scalar may have up to num_radix_blocks blocks
|
||||
scalar_comparison_luts = new int_radix_lut<Torus>(
|
||||
@@ -129,32 +123,28 @@ template <typename Torus> struct int_comparison_eq_buffer {
|
||||
return (lhs == rhs);
|
||||
}
|
||||
};
|
||||
|
||||
std::vector<std::function<Torus(Torus)>> lut_funcs;
|
||||
std::vector<uint32_t> lut_indices;
|
||||
for (int i = 0; i < total_modulus; i++) {
|
||||
auto lut_f = [i, operator_f](Torus x) -> Torus {
|
||||
return operator_f(i, x);
|
||||
};
|
||||
|
||||
generate_device_accumulator<Torus>(
|
||||
streams.stream(0), streams.gpu_index(0),
|
||||
scalar_comparison_luts->get_lut(0, i),
|
||||
scalar_comparison_luts->get_degree(i),
|
||||
scalar_comparison_luts->get_max_degree(i), params.glwe_dimension,
|
||||
params.polynomial_size, params.message_modulus, params.carry_modulus,
|
||||
lut_f, gpu_memory_allocated);
|
||||
lut_funcs.push_back(lut_f);
|
||||
lut_indices.push_back(i);
|
||||
}
|
||||
scalar_comparison_luts->broadcast_lut(active_streams);
|
||||
|
||||
scalar_comparison_luts->generate_and_broadcast_lut(
|
||||
active_streams, lut_indices, lut_funcs, gpu_memory_allocated);
|
||||
|
||||
if (op == COMPARISON_TYPE::EQ || op == COMPARISON_TYPE::NE) {
|
||||
operator_lut =
|
||||
new int_radix_lut<Torus>(streams, params, 1, num_radix_blocks,
|
||||
allocate_gpu_memory, size_tracker);
|
||||
|
||||
generate_device_accumulator_bivariate<Torus>(
|
||||
streams.stream(0), streams.gpu_index(0), operator_lut->get_lut(0, 0),
|
||||
operator_lut->get_degree(0), operator_lut->get_max_degree(0),
|
||||
params.glwe_dimension, params.polynomial_size, params.message_modulus,
|
||||
params.carry_modulus, operator_f, gpu_memory_allocated);
|
||||
|
||||
operator_lut->broadcast_lut(active_streams);
|
||||
operator_lut->generate_and_broadcast_bivariate_lut(
|
||||
active_streams, {0}, {operator_f}, gpu_memory_allocated);
|
||||
// operator_lut->broadcast_lut(active_streams);
|
||||
} else {
|
||||
operator_lut = nullptr;
|
||||
}
|
||||
@@ -221,9 +211,6 @@ template <typename Torus> struct int_tree_sign_reduction_buffer {
|
||||
streams.stream(0), streams.gpu_index(0), tmp_y, num_radix_blocks,
|
||||
params.big_lwe_dimension, size_tracker, allocate_gpu_memory);
|
||||
// LUTs
|
||||
tree_inner_leaf_lut =
|
||||
new int_radix_lut<Torus>(streams, params, 1, num_radix_blocks,
|
||||
allocate_gpu_memory, size_tracker);
|
||||
|
||||
tree_last_leaf_lut = new int_radix_lut<Torus>(
|
||||
streams, params, 1, 1, allocate_gpu_memory, size_tracker);
|
||||
@@ -234,15 +221,14 @@ template <typename Torus> struct int_tree_sign_reduction_buffer {
|
||||
tree_last_leaf_scalar_lut = new int_radix_lut<Torus>(
|
||||
streams, params, 1, 1, allocate_gpu_memory, size_tracker);
|
||||
|
||||
generate_device_accumulator_bivariate<Torus>(
|
||||
streams.stream(0), streams.gpu_index(0),
|
||||
tree_inner_leaf_lut->get_lut(0, 0), tree_inner_leaf_lut->get_degree(0),
|
||||
tree_inner_leaf_lut->get_max_degree(0), params.glwe_dimension,
|
||||
params.polynomial_size, params.message_modulus, params.carry_modulus,
|
||||
block_selector_f, gpu_memory_allocated);
|
||||
tree_inner_leaf_lut =
|
||||
new int_radix_lut<Torus>(streams, params, 1, num_radix_blocks,
|
||||
allocate_gpu_memory, size_tracker);
|
||||
|
||||
auto active_streams =
|
||||
streams.active_gpu_subset(num_radix_blocks, params.pbs_type);
|
||||
tree_inner_leaf_lut->broadcast_lut(active_streams);
|
||||
tree_inner_leaf_lut->generate_and_broadcast_bivariate_lut(
|
||||
active_streams, {0}, {block_selector_f}, allocate_gpu_memory);
|
||||
}
|
||||
|
||||
void release(CudaStreams streams) {
|
||||
@@ -426,12 +412,8 @@ template <typename Torus> struct int_comparison_buffer {
|
||||
new int_radix_lut<Torus>(streams, params, 1, num_radix_blocks,
|
||||
allocate_gpu_memory, size_tracker);
|
||||
|
||||
generate_device_accumulator<Torus>(
|
||||
streams.stream(0), streams.gpu_index(0), identity_lut->get_lut(0, 0),
|
||||
identity_lut->get_degree(0), identity_lut->get_max_degree(0),
|
||||
params.glwe_dimension, params.polynomial_size, params.message_modulus,
|
||||
params.carry_modulus, identity_lut_f, gpu_memory_allocated);
|
||||
identity_lut->broadcast_lut(active_streams);
|
||||
identity_lut->generate_and_broadcast_lut(
|
||||
active_streams, {0}, {identity_lut_f}, gpu_memory_allocated);
|
||||
|
||||
uint32_t total_modulus = params.message_modulus * params.carry_modulus;
|
||||
auto is_zero_f = [total_modulus](Torus x) -> Torus {
|
||||
@@ -441,13 +423,8 @@ template <typename Torus> struct int_comparison_buffer {
|
||||
is_zero_lut = new int_radix_lut<Torus>(streams, params, 1, num_radix_blocks,
|
||||
allocate_gpu_memory, size_tracker);
|
||||
|
||||
generate_device_accumulator<Torus>(
|
||||
streams.stream(0), streams.gpu_index(0), is_zero_lut->get_lut(0, 0),
|
||||
is_zero_lut->get_degree(0), is_zero_lut->get_max_degree(0),
|
||||
params.glwe_dimension, params.polynomial_size, params.message_modulus,
|
||||
params.carry_modulus, is_zero_f, gpu_memory_allocated);
|
||||
|
||||
is_zero_lut->broadcast_lut(active_streams);
|
||||
is_zero_lut->generate_and_broadcast_lut(active_streams, {0}, {is_zero_f},
|
||||
gpu_memory_allocated);
|
||||
|
||||
switch (op) {
|
||||
case COMPARISON_TYPE::MAX:
|
||||
@@ -522,13 +499,9 @@ template <typename Torus> struct int_comparison_buffer {
|
||||
PANIC("Cuda error: sign_lut creation failed due to wrong function.")
|
||||
};
|
||||
|
||||
generate_device_accumulator_bivariate<Torus>(
|
||||
streams.stream(0), streams.gpu_index(0), signed_lut->get_lut(0, 0),
|
||||
signed_lut->get_degree(0), signed_lut->get_max_degree(0),
|
||||
params.glwe_dimension, params.polynomial_size, params.message_modulus,
|
||||
params.carry_modulus, signed_lut_f, gpu_memory_allocated);
|
||||
auto active_streams = streams.active_gpu_subset(1, params.pbs_type);
|
||||
signed_lut->broadcast_lut(active_streams);
|
||||
signed_lut->generate_and_broadcast_bivariate_lut(
|
||||
active_streams, {0}, {signed_lut_f}, gpu_memory_allocated);
|
||||
}
|
||||
preallocated_h_lut = (Torus *)malloc(
|
||||
(params.glwe_dimension + 1) * params.polynomial_size * sizeof(Torus));
|
||||
|
||||
@@ -283,12 +283,9 @@ template <typename Torus> struct unsigned_int_div_rem_2_2_memory {
|
||||
zero_out_if_not_1_lut_2};
|
||||
size_t lut_gpu_indexes[2] = {0, 3};
|
||||
for (int j = 0; j < 2; j++) {
|
||||
generate_device_accumulator<Torus>(
|
||||
streams.stream(lut_gpu_indexes[j]),
|
||||
streams.gpu_index(lut_gpu_indexes[j]), luts[j]->get_lut(0, 0),
|
||||
luts[j]->get_degree(0), luts[j]->get_max_degree(0),
|
||||
params.glwe_dimension, params.polynomial_size, params.message_modulus,
|
||||
params.carry_modulus, zero_out_if_not_1_lut_f, gpu_memory_allocated);
|
||||
luts[j]->generate_and_broadcast_lut(streams.get_ith(lut_gpu_indexes[j]),
|
||||
{0}, {zero_out_if_not_1_lut_f},
|
||||
gpu_memory_allocated);
|
||||
}
|
||||
|
||||
luts[0] = zero_out_if_not_2_lut_1;
|
||||
@@ -296,12 +293,9 @@ template <typename Torus> struct unsigned_int_div_rem_2_2_memory {
|
||||
lut_gpu_indexes[0] = 1;
|
||||
lut_gpu_indexes[1] = 2;
|
||||
for (int j = 0; j < 2; j++) {
|
||||
generate_device_accumulator<Torus>(
|
||||
streams.stream(lut_gpu_indexes[j]),
|
||||
streams.gpu_index(lut_gpu_indexes[j]), luts[j]->get_lut(0, 0),
|
||||
luts[j]->get_degree(0), luts[j]->get_max_degree(0),
|
||||
params.glwe_dimension, params.polynomial_size, params.message_modulus,
|
||||
params.carry_modulus, zero_out_if_not_2_lut_f, gpu_memory_allocated);
|
||||
luts[j]->generate_and_broadcast_lut(streams.get_ith(lut_gpu_indexes[j]),
|
||||
{0}, {zero_out_if_not_2_lut_f},
|
||||
gpu_memory_allocated);
|
||||
}
|
||||
|
||||
quotient_lut_1 =
|
||||
@@ -321,21 +315,12 @@ template <typename Torus> struct unsigned_int_div_rem_2_2_memory {
|
||||
};
|
||||
auto quotient_lut_3_f = [](Torus cond) -> Torus { return cond * 3; };
|
||||
|
||||
generate_device_accumulator<Torus>(
|
||||
streams.stream(2), streams.gpu_index(2), quotient_lut_1->get_lut(0, 0),
|
||||
quotient_lut_1->get_degree(0), quotient_lut_1->get_max_degree(0),
|
||||
params.glwe_dimension, params.polynomial_size, params.message_modulus,
|
||||
params.carry_modulus, quotient_lut_1_f, gpu_memory_allocated);
|
||||
generate_device_accumulator<Torus>(
|
||||
streams.stream(1), streams.gpu_index(1), quotient_lut_2->get_lut(0, 0),
|
||||
quotient_lut_2->get_degree(0), quotient_lut_2->get_max_degree(0),
|
||||
params.glwe_dimension, params.polynomial_size, params.message_modulus,
|
||||
params.carry_modulus, quotient_lut_2_f, gpu_memory_allocated);
|
||||
generate_device_accumulator<Torus>(
|
||||
streams.stream(0), streams.gpu_index(0), quotient_lut_3->get_lut(0, 0),
|
||||
quotient_lut_3->get_degree(0), quotient_lut_3->get_max_degree(0),
|
||||
params.glwe_dimension, params.polynomial_size, params.message_modulus,
|
||||
params.carry_modulus, quotient_lut_3_f, gpu_memory_allocated);
|
||||
quotient_lut_1->generate_and_broadcast_lut(
|
||||
streams.get_ith(2), {0}, {quotient_lut_1_f}, gpu_memory_allocated);
|
||||
quotient_lut_2->generate_and_broadcast_lut(
|
||||
streams.get_ith(1), {0}, {quotient_lut_2_f}, gpu_memory_allocated);
|
||||
quotient_lut_3->generate_and_broadcast_lut(
|
||||
streams.get_ith(0), {0}, {quotient_lut_3_f}, gpu_memory_allocated);
|
||||
|
||||
message_extract_lut_1 = new int_radix_lut<Torus>(
|
||||
streams, params, 1, num_blocks, allocate_gpu_memory, size_tracker);
|
||||
@@ -350,15 +335,12 @@ template <typename Torus> struct unsigned_int_div_rem_2_2_memory {
|
||||
luts[0] = message_extract_lut_1;
|
||||
luts[1] = message_extract_lut_2;
|
||||
|
||||
auto active_streams =
|
||||
streams.active_gpu_subset(num_blocks, params.pbs_type);
|
||||
|
||||
for (int j = 0; j < 2; j++) {
|
||||
generate_device_accumulator<Torus>(
|
||||
streams.stream(0), streams.gpu_index(0), luts[j]->get_lut(0, 0),
|
||||
luts[j]->get_degree(0), luts[j]->get_max_degree(0),
|
||||
params.glwe_dimension, params.polynomial_size, params.message_modulus,
|
||||
params.carry_modulus, lut_f_message_extract, gpu_memory_allocated);
|
||||
auto active_streams =
|
||||
streams.active_gpu_subset(num_blocks, params.pbs_type);
|
||||
luts[j]->broadcast_lut(active_streams);
|
||||
luts[j]->generate_and_broadcast_lut(
|
||||
active_streams, {0}, {lut_f_message_extract}, gpu_memory_allocated);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1007,24 +989,14 @@ template <typename Torus> struct unsigned_int_div_rem_memory {
|
||||
masking_luts_2[i] = new int_radix_lut<Torus>(
|
||||
streams, params, 1, num_blocks, allocate_gpu_memory, size_tracker);
|
||||
|
||||
generate_device_accumulator<Torus>(
|
||||
streams.stream(0), streams.gpu_index(0),
|
||||
masking_luts_1[i]->get_lut(0, 0), masking_luts_1[i]->get_degree(0),
|
||||
masking_luts_1[i]->get_max_degree(0), params.glwe_dimension,
|
||||
params.polynomial_size, params.message_modulus, params.carry_modulus,
|
||||
lut_f_masking, gpu_memory_allocated);
|
||||
auto active_streams_1 = streams.active_gpu_subset(1, params.pbs_type);
|
||||
masking_luts_1[i]->broadcast_lut(active_streams_1);
|
||||
masking_luts_1[i]->generate_and_broadcast_lut(
|
||||
active_streams_1, {0}, {lut_f_masking}, gpu_memory_allocated);
|
||||
|
||||
generate_device_accumulator<Torus>(
|
||||
streams.stream(0), streams.gpu_index(0),
|
||||
masking_luts_2[i]->get_lut(0, 0), masking_luts_2[i]->get_degree(0),
|
||||
masking_luts_2[i]->get_max_degree(0), params.glwe_dimension,
|
||||
params.polynomial_size, params.message_modulus, params.carry_modulus,
|
||||
lut_f_masking, gpu_memory_allocated);
|
||||
auto active_streams_2 =
|
||||
streams.active_gpu_subset(num_blocks, params.pbs_type);
|
||||
masking_luts_2[i]->broadcast_lut(active_streams_2);
|
||||
masking_luts_2[i]->generate_and_broadcast_lut(
|
||||
active_streams_2, {0}, {lut_f_masking}, gpu_memory_allocated);
|
||||
}
|
||||
|
||||
// create and generate message_extract_lut_1 and message_extract_lut_2
|
||||
@@ -1042,15 +1014,12 @@ template <typename Torus> struct unsigned_int_div_rem_memory {
|
||||
|
||||
int_radix_lut<Torus> *luts[2] = {message_extract_lut_1,
|
||||
message_extract_lut_2};
|
||||
|
||||
auto active_streams =
|
||||
streams.active_gpu_subset(num_blocks, params.pbs_type);
|
||||
for (int j = 0; j < 2; j++) {
|
||||
generate_device_accumulator<Torus>(
|
||||
streams.stream(0), streams.gpu_index(0), luts[j]->get_lut(0, 0),
|
||||
luts[j]->get_degree(0), luts[j]->get_max_degree(0),
|
||||
params.glwe_dimension, params.polynomial_size, params.message_modulus,
|
||||
params.carry_modulus, lut_f_message_extract, gpu_memory_allocated);
|
||||
luts[j]->broadcast_lut(active_streams);
|
||||
luts[j]->generate_and_broadcast_lut(
|
||||
active_streams, {0}, {lut_f_message_extract}, gpu_memory_allocated);
|
||||
}
|
||||
|
||||
// Give name to closures to improve readability
|
||||
@@ -1141,14 +1110,8 @@ template <typename Torus> struct unsigned_int_div_rem_memory {
|
||||
merge_overflow_flags_luts[i] = new int_radix_lut<Torus>(
|
||||
streams, params, 1, 1, allocate_gpu_memory, size_tracker);
|
||||
|
||||
generate_device_accumulator_bivariate<Torus>(
|
||||
streams.stream(0), streams.gpu_index(0),
|
||||
merge_overflow_flags_luts[i]->get_lut(0, 0),
|
||||
merge_overflow_flags_luts[i]->get_degree(0),
|
||||
merge_overflow_flags_luts[i]->get_max_degree(0),
|
||||
params.glwe_dimension, params.polynomial_size, params.message_modulus,
|
||||
params.carry_modulus, lut_f_bit, gpu_memory_allocated);
|
||||
merge_overflow_flags_luts[i]->broadcast_lut(active_gpu_count_for_bits);
|
||||
merge_overflow_flags_luts[i]->generate_and_broadcast_bivariate_lut(
|
||||
active_gpu_count_for_bits, {0}, {lut_f_bit}, gpu_memory_allocated);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1557,16 +1520,12 @@ template <typename Torus> struct int_div_rem_memory {
|
||||
compare_signed_bits_lut = new int_radix_lut<Torus>(
|
||||
streams, params, 1, 1, allocate_gpu_memory, size_tracker);
|
||||
|
||||
generate_device_accumulator_bivariate<Torus>(
|
||||
streams.stream(0), streams.gpu_index(0),
|
||||
compare_signed_bits_lut->get_lut(0, 0),
|
||||
compare_signed_bits_lut->get_degree(0),
|
||||
compare_signed_bits_lut->get_max_degree(0), params.glwe_dimension,
|
||||
params.polynomial_size, params.message_modulus, params.carry_modulus,
|
||||
f_compare_extracted_signed_bits, gpu_memory_allocated);
|
||||
auto active_gpu_count_cmp =
|
||||
streams.active_gpu_subset(1, params.pbs_type); // only 1 block needed
|
||||
compare_signed_bits_lut->broadcast_lut(active_gpu_count_cmp);
|
||||
|
||||
compare_signed_bits_lut->generate_and_broadcast_bivariate_lut(
|
||||
active_gpu_count_cmp, {0}, {f_compare_extracted_signed_bits},
|
||||
gpu_memory_allocated);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -53,13 +53,8 @@ template <typename Torus> struct int_prepare_count_of_consecutive_bits_buffer {
|
||||
return count;
|
||||
};
|
||||
|
||||
generate_device_accumulator<Torus>(
|
||||
streams.stream(0), streams.gpu_index(0), univ_lut_mem->get_lut(0, 0),
|
||||
univ_lut_mem->get_degree(0), univ_lut_mem->get_max_degree(0),
|
||||
params.glwe_dimension, params.polynomial_size, params.message_modulus,
|
||||
params.carry_modulus, generate_uni_lut_lambda, allocate_gpu_memory);
|
||||
|
||||
univ_lut_mem->broadcast_lut(active_streams);
|
||||
univ_lut_mem->generate_and_broadcast_lut(
|
||||
active_streams, {0}, {generate_uni_lut_lambda}, allocate_gpu_memory);
|
||||
|
||||
auto generate_bi_lut_lambda =
|
||||
[num_bits](Torus block_num_bit_count,
|
||||
@@ -70,13 +65,8 @@ template <typename Torus> struct int_prepare_count_of_consecutive_bits_buffer {
|
||||
return 0;
|
||||
};
|
||||
|
||||
generate_device_accumulator_bivariate<Torus>(
|
||||
streams.stream(0), streams.gpu_index(0), biv_lut_mem->get_lut(0, 0),
|
||||
biv_lut_mem->get_degree(0), biv_lut_mem->get_max_degree(0),
|
||||
params.glwe_dimension, params.polynomial_size, params.message_modulus,
|
||||
params.carry_modulus, generate_bi_lut_lambda, allocate_gpu_memory);
|
||||
|
||||
biv_lut_mem->broadcast_lut(active_streams);
|
||||
biv_lut_mem->generate_and_broadcast_bivariate_lut(
|
||||
active_streams, {0}, {generate_bi_lut_lambda}, allocate_gpu_memory);
|
||||
|
||||
this->tmp_ct = new CudaRadixCiphertextFFI;
|
||||
create_zero_radix_ciphertext_async<Torus>(
|
||||
@@ -232,7 +222,7 @@ template <typename Torus> struct int_ilog2_buffer {
|
||||
this->sum_output_not_propagated, counter_num_blocks,
|
||||
params.big_lwe_dimension, size_tracker, allocate_gpu_memory);
|
||||
|
||||
this->lut_message_not =
|
||||
lut_message_not =
|
||||
new int_radix_lut<Torus>(streams, params, 1, counter_num_blocks,
|
||||
allocate_gpu_memory, size_tracker);
|
||||
std::function<Torus(Torus)> lut_message_lambda =
|
||||
@@ -240,16 +230,11 @@ template <typename Torus> struct int_ilog2_buffer {
|
||||
uint64_t message = x % this->params.message_modulus;
|
||||
return (~message) % this->params.message_modulus;
|
||||
};
|
||||
generate_device_accumulator(streams.stream(0), streams.gpu_index(0),
|
||||
this->lut_message_not->get_lut(0, 0),
|
||||
this->lut_message_not->get_degree(0),
|
||||
this->lut_message_not->get_max_degree(0),
|
||||
params.glwe_dimension, params.polynomial_size,
|
||||
params.message_modulus, params.carry_modulus,
|
||||
lut_message_lambda, allocate_gpu_memory);
|
||||
|
||||
auto active_streams =
|
||||
streams.active_gpu_subset(counter_num_blocks, params.pbs_type);
|
||||
lut_message_not->broadcast_lut(active_streams);
|
||||
lut_message_not->generate_and_broadcast_lut(
|
||||
active_streams, {0}, {lut_message_lambda}, allocate_gpu_memory);
|
||||
|
||||
this->lut_carry_not =
|
||||
new int_radix_lut<Torus>(streams, params, 1, counter_num_blocks,
|
||||
@@ -259,13 +244,8 @@ template <typename Torus> struct int_ilog2_buffer {
|
||||
uint64_t carry = x / this->params.message_modulus;
|
||||
return (~carry) % this->params.message_modulus;
|
||||
};
|
||||
generate_device_accumulator(
|
||||
streams.stream(0), streams.gpu_index(0),
|
||||
this->lut_carry_not->get_lut(0, 0), this->lut_carry_not->get_degree(0),
|
||||
this->lut_carry_not->get_max_degree(0), params.glwe_dimension,
|
||||
params.polynomial_size, params.message_modulus, params.carry_modulus,
|
||||
lut_carry_lambda, allocate_gpu_memory);
|
||||
lut_carry_not->broadcast_lut(active_streams);
|
||||
lut_carry_not->generate_and_broadcast_lut(
|
||||
active_streams, {0}, {lut_carry_lambda}, allocate_gpu_memory);
|
||||
|
||||
this->message_blocks_not = new CudaRadixCiphertextFFI;
|
||||
create_zero_radix_ciphertext_async<Torus>(
|
||||
|
||||
@@ -9,6 +9,7 @@
|
||||
#include "utils/helper_multi_gpu.cuh"
|
||||
#include <cmath>
|
||||
#include <functional>
|
||||
#include <map>
|
||||
#include <queue>
|
||||
|
||||
#include <stdio.h>
|
||||
@@ -835,6 +836,56 @@ struct int_radix_lut_custom_input_output {
|
||||
}
|
||||
}
|
||||
|
||||
void generate_and_broadcast_lut(
|
||||
const CudaStreams &streams, std::vector<uint32_t> lut_indexes,
|
||||
std::vector<std::function<OutputTorus(OutputTorus)>> f,
|
||||
bool gpu_memory_allocated) {
|
||||
// streams should be a subset of active_streams
|
||||
|
||||
for (uint32_t i = 0; i < lut_indexes.size(); ++i) {
|
||||
generate_device_accumulator<OutputTorus>(
|
||||
streams.stream(0), streams.gpu_index(0), get_lut(0, lut_indexes[i]),
|
||||
get_degree(lut_indexes[i]), get_max_degree(lut_indexes[i]),
|
||||
params.glwe_dimension, params.polynomial_size, params.message_modulus,
|
||||
params.carry_modulus, f[i], gpu_memory_allocated);
|
||||
}
|
||||
//broadcast_lut(streams);
|
||||
}
|
||||
|
||||
void generate_and_broadcast_bivariate_lut(
|
||||
const CudaStreams &streams, std::vector<uint32_t> lut_indexes,
|
||||
std::vector<std::function<OutputTorus(OutputTorus, OutputTorus)>> f,
|
||||
bool gpu_memory_allocated) {
|
||||
// streams should be a subset of active_streams
|
||||
|
||||
/* for (int fidx = 0; fidx < f.size(); ++fidx) {
|
||||
__int128_t f_hash = 0;
|
||||
uint32_t bits_per_lut_val = 5;
|
||||
uint32_t input_modulus_sup =
|
||||
params.message_modulus * params.carry_modulus;
|
||||
for (uint32_t i = 0; i < input_modulus_sup; ++i) {
|
||||
OutputTorus f_eval =
|
||||
f[fidx](i / params.message_modulus, i % params.message_modulus);
|
||||
GPU_ASSERT(f_eval < (1 << bits_per_lut_val),
|
||||
"LUT value expected bitwidth overflow");
|
||||
f_hash |= f_eval;
|
||||
f_hash <<= bits_per_lut_val;
|
||||
}
|
||||
printf("%016llX%016llX\n",
|
||||
(unsigned long long)((f_hash >> 64) & 0xFFFFFFFFFFFFFFFF),
|
||||
(unsigned long long)(f_hash & 0xFFFFFFFFFFFFFFFF));
|
||||
}
|
||||
*/
|
||||
for (uint32_t i = 0; i < lut_indexes.size(); ++i) {
|
||||
generate_device_accumulator_bivariate<InputTorus>(
|
||||
streams.stream(0), streams.gpu_index(0), get_lut(0, lut_indexes[i]),
|
||||
get_degree(lut_indexes[i]), get_max_degree(lut_indexes[i]),
|
||||
params.glwe_dimension, params.polynomial_size, params.message_modulus,
|
||||
params.carry_modulus, f[i], gpu_memory_allocated);
|
||||
}
|
||||
//broadcast_lut(streams);
|
||||
}
|
||||
|
||||
void release(CudaStreams streams) {
|
||||
PANIC_IF_FALSE(lut_indexes_vec.size() == lut_vec.size(),
|
||||
"Lut vec and Lut vec indexes must have the same size");
|
||||
@@ -985,18 +1036,15 @@ template <typename Torus> struct int_bit_extract_luts_buffer {
|
||||
bits_per_block * num_radix_blocks,
|
||||
allocate_gpu_memory, size_tracker);
|
||||
|
||||
std::vector<std::function<Torus(Torus)>> lut_funs;
|
||||
std::vector<uint32_t> lut_indices;
|
||||
for (int i = 0; i < bits_per_block; i++) {
|
||||
|
||||
auto operator_f = [i, final_offset](Torus x) -> Torus {
|
||||
Torus y = (x >> i) & 1;
|
||||
return y << final_offset;
|
||||
};
|
||||
|
||||
generate_device_accumulator<Torus>(
|
||||
streams.stream(0), streams.gpu_index(0), lut->get_lut(0, i),
|
||||
lut->get_degree(i), lut->get_max_degree(i), params.glwe_dimension,
|
||||
params.polynomial_size, params.message_modulus, params.carry_modulus,
|
||||
operator_f, gpu_memory_allocated);
|
||||
lut_funs.push_back(operator_f);
|
||||
lut_indices.push_back(i);
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -1015,7 +1063,10 @@ template <typename Torus> struct int_bit_extract_luts_buffer {
|
||||
|
||||
auto active_streams = streams.active_gpu_subset(
|
||||
bits_per_block * num_radix_blocks, params.pbs_type);
|
||||
lut->broadcast_lut(active_streams);
|
||||
|
||||
lut->generate_and_broadcast_lut(active_streams, lut_indices, lut_funs,
|
||||
gpu_memory_allocated);
|
||||
// lut->broadcast_lut(active_streams);
|
||||
|
||||
/**
|
||||
* the input indexes should take the first bits_per_block PBS to target
|
||||
@@ -1091,24 +1142,6 @@ template <typename Torus> struct int_fullprop_buffer {
|
||||
};
|
||||
|
||||
//
|
||||
Torus *lut_buffer_message = lut->get_lut(0, 0);
|
||||
uint64_t *message_degree = lut->get_degree(0);
|
||||
uint64_t *message_max_degree = lut->get_max_degree(0);
|
||||
Torus *lut_buffer_carry = lut->get_lut(0, 1);
|
||||
uint64_t *carry_degree = lut->get_degree(1);
|
||||
uint64_t *carry_max_degree = lut->get_max_degree(1);
|
||||
|
||||
generate_device_accumulator<Torus>(
|
||||
streams.stream(0), streams.gpu_index(0), lut_buffer_message,
|
||||
message_degree, message_max_degree, params.glwe_dimension,
|
||||
params.polynomial_size, params.message_modulus, params.carry_modulus,
|
||||
lut_f_message, gpu_memory_allocated);
|
||||
|
||||
generate_device_accumulator<Torus>(
|
||||
streams.stream(0), streams.gpu_index(0), lut_buffer_carry, carry_degree,
|
||||
carry_max_degree, params.glwe_dimension, params.polynomial_size,
|
||||
params.message_modulus, params.carry_modulus, lut_f_carry,
|
||||
gpu_memory_allocated);
|
||||
|
||||
uint64_t lwe_indexes_size = 2 * sizeof(Torus);
|
||||
Torus *h_lwe_indexes = (Torus *)malloc(lwe_indexes_size);
|
||||
@@ -1118,9 +1151,15 @@ template <typename Torus> struct int_fullprop_buffer {
|
||||
cuda_memcpy_with_size_tracking_async_to_gpu(
|
||||
lwe_indexes, h_lwe_indexes, lwe_indexes_size, streams.stream(0),
|
||||
streams.gpu_index(0), allocate_gpu_memory);
|
||||
|
||||
//
|
||||
// No broadcast is needed because full prop is done on 1 single GPU.
|
||||
// By passing a single-GPU CudaStreams with streams.get_ith(0) the LUT is
|
||||
// not broadcast.
|
||||
//
|
||||
lut->generate_and_broadcast_lut(streams.get_ith(0), {0, 1},
|
||||
{lut_f_message, lut_f_carry},
|
||||
gpu_memory_allocated);
|
||||
|
||||
tmp_small_lwe_vector = new CudaRadixCiphertextFFI;
|
||||
create_zero_radix_ciphertext_async<Torus>(
|
||||
@@ -1238,9 +1277,10 @@ template <typename Torus> struct int_sum_ciphertexts_vec_memory {
|
||||
if (total_ciphertexts > 0 ||
|
||||
reduce_degrees_for_single_carry_propagation) {
|
||||
uint64_t size_tracker = 0;
|
||||
allocated_luts_message_carry = true;
|
||||
luts_message_carry = new int_radix_lut<Torus>(
|
||||
streams, params, 2, pbs_count, true, size_tracker);
|
||||
allocated_luts_message_carry = true;
|
||||
|
||||
uint64_t message_modulus_bits =
|
||||
(uint64_t)std::log2(params.message_modulus);
|
||||
uint64_t carry_modulus_bits = (uint64_t)std::log2(params.carry_modulus);
|
||||
@@ -1256,7 +1296,9 @@ template <typename Torus> struct int_sum_ciphertexts_vec_memory {
|
||||
streams, upper_bound_num_blocks, size_tracker, true);
|
||||
}
|
||||
}
|
||||
|
||||
if (allocated_luts_message_carry) {
|
||||
|
||||
auto message_acc = luts_message_carry->get_lut(0, 0);
|
||||
auto carry_acc = luts_message_carry->get_lut(0, 1);
|
||||
|
||||
@@ -1268,22 +1310,11 @@ template <typename Torus> struct int_sum_ciphertexts_vec_memory {
|
||||
return x / message_modulus;
|
||||
};
|
||||
|
||||
// generate accumulators
|
||||
generate_device_accumulator<Torus>(
|
||||
streams.stream(0), streams.gpu_index(0), message_acc,
|
||||
luts_message_carry->get_degree(0),
|
||||
luts_message_carry->get_max_degree(0), params.glwe_dimension,
|
||||
params.polynomial_size, message_modulus, params.carry_modulus,
|
||||
lut_f_message, gpu_memory_allocated);
|
||||
generate_device_accumulator<Torus>(
|
||||
streams.stream(0), streams.gpu_index(0), carry_acc,
|
||||
luts_message_carry->get_degree(1),
|
||||
luts_message_carry->get_max_degree(1), params.glwe_dimension,
|
||||
params.polynomial_size, message_modulus, params.carry_modulus,
|
||||
lut_f_carry, gpu_memory_allocated);
|
||||
auto active_gpu_count_mc =
|
||||
streams.active_gpu_subset(pbs_count, params.pbs_type);
|
||||
luts_message_carry->broadcast_lut(active_gpu_count_mc);
|
||||
luts_message_carry->generate_and_broadcast_lut(
|
||||
active_gpu_count_mc, {0, 1}, {lut_f_message, lut_f_carry},
|
||||
gpu_memory_allocated);
|
||||
}
|
||||
}
|
||||
int_sum_ciphertexts_vec_memory(
|
||||
@@ -1418,10 +1449,6 @@ template <typename Torus> struct int_seq_group_prop_memory {
|
||||
uint32_t group_size, uint32_t big_lwe_size_bytes,
|
||||
bool allocate_gpu_memory, uint64_t &size_tracker) {
|
||||
gpu_memory_allocated = allocate_gpu_memory;
|
||||
auto glwe_dimension = params.glwe_dimension;
|
||||
auto polynomial_size = params.polynomial_size;
|
||||
auto message_modulus = params.message_modulus;
|
||||
auto carry_modulus = params.carry_modulus;
|
||||
|
||||
grouping_size = group_size;
|
||||
group_resolved_carries = new CudaRadixCiphertextFFI;
|
||||
@@ -1431,22 +1458,20 @@ template <typename Torus> struct int_seq_group_prop_memory {
|
||||
allocate_gpu_memory);
|
||||
|
||||
int num_seq_luts = grouping_size - 1;
|
||||
Torus *h_seq_lut_indexes = (Torus *)malloc(num_seq_luts * sizeof(Torus));
|
||||
lut_sequential_algorithm =
|
||||
new int_radix_lut<Torus>(streams, params, num_seq_luts, num_seq_luts,
|
||||
allocate_gpu_memory, size_tracker);
|
||||
std::vector<std::function<Torus(Torus)>> lut_funcs;
|
||||
std::vector<uint32_t> lut_indices;
|
||||
Torus *h_seq_lut_indexes = (Torus *)malloc(num_seq_luts * sizeof(Torus));
|
||||
|
||||
for (int index = 0; index < num_seq_luts; index++) {
|
||||
auto f_lut_sequential = [index](Torus propa_cum_sum_block) {
|
||||
return (propa_cum_sum_block >> (index + 1)) & 1;
|
||||
};
|
||||
auto seq_lut = lut_sequential_algorithm->get_lut(0, index);
|
||||
generate_device_accumulator<Torus>(
|
||||
streams.stream(0), streams.gpu_index(0), seq_lut,
|
||||
lut_sequential_algorithm->get_degree(index),
|
||||
lut_sequential_algorithm->get_max_degree(index), glwe_dimension,
|
||||
polynomial_size, message_modulus, carry_modulus, f_lut_sequential,
|
||||
gpu_memory_allocated);
|
||||
lut_funcs.push_back(f_lut_sequential);
|
||||
h_seq_lut_indexes[index] = index;
|
||||
lut_indices.push_back(index);
|
||||
}
|
||||
Torus *seq_lut_indexes = lut_sequential_algorithm->get_lut_indexes(0, 0);
|
||||
cuda_memcpy_with_size_tracking_async_to_gpu(
|
||||
@@ -1454,9 +1479,12 @@ template <typename Torus> struct int_seq_group_prop_memory {
|
||||
streams.stream(0), streams.gpu_index(0), allocate_gpu_memory);
|
||||
auto active_streams =
|
||||
streams.active_gpu_subset(num_seq_luts, params.pbs_type);
|
||||
lut_sequential_algorithm->broadcast_lut(active_streams);
|
||||
lut_sequential_algorithm->generate_and_broadcast_lut(
|
||||
active_streams, lut_indices, lut_funcs, gpu_memory_allocated);
|
||||
// lut_sequential_algorithm->broadcast_lut(active_streams);
|
||||
free(h_seq_lut_indexes);
|
||||
};
|
||||
}
|
||||
|
||||
void release(CudaStreams streams) {
|
||||
release_radix_ciphertext_async(streams.stream(0), streams.gpu_index(0),
|
||||
group_resolved_carries,
|
||||
@@ -1478,10 +1506,6 @@ template <typename Torus> struct int_hs_group_prop_memory {
|
||||
uint32_t num_groups, uint32_t big_lwe_size_bytes,
|
||||
bool allocate_gpu_memory, uint64_t &size_tracker) {
|
||||
gpu_memory_allocated = allocate_gpu_memory;
|
||||
auto glwe_dimension = params.glwe_dimension;
|
||||
auto polynomial_size = params.polynomial_size;
|
||||
auto message_modulus = params.message_modulus;
|
||||
auto carry_modulus = params.carry_modulus;
|
||||
|
||||
auto f_lut_hillis_steele = [](Torus msb, Torus lsb) -> Torus {
|
||||
if (msb == 2) {
|
||||
@@ -1501,16 +1525,11 @@ template <typename Torus> struct int_hs_group_prop_memory {
|
||||
lut_hillis_steele = new int_radix_lut<Torus>(
|
||||
streams, params, 1, num_groups, allocate_gpu_memory, size_tracker);
|
||||
|
||||
generate_device_accumulator_bivariate<Torus>(
|
||||
streams.stream(0), streams.gpu_index(0),
|
||||
lut_hillis_steele->get_lut(0, 0), lut_hillis_steele->get_degree(0),
|
||||
lut_hillis_steele->get_max_degree(0), glwe_dimension, polynomial_size,
|
||||
message_modulus, carry_modulus, f_lut_hillis_steele,
|
||||
gpu_memory_allocated);
|
||||
auto active_streams =
|
||||
streams.active_gpu_subset(num_groups, params.pbs_type);
|
||||
lut_hillis_steele->broadcast_lut(active_streams);
|
||||
};
|
||||
lut_hillis_steele->generate_and_broadcast_bivariate_lut(
|
||||
active_streams, {0}, {f_lut_hillis_steele}, gpu_memory_allocated);
|
||||
}
|
||||
void release(CudaStreams streams) {
|
||||
|
||||
lut_hillis_steele->release(streams);
|
||||
@@ -1800,112 +1819,6 @@ template <typename Torus> struct int_prop_simu_group_carries_memory {
|
||||
num_extra_luts = 1;
|
||||
}
|
||||
|
||||
uint32_t num_luts_second_step = 2 * grouping_size + num_extra_luts;
|
||||
luts_array_second_step = new int_radix_lut<Torus>(
|
||||
streams, params, num_luts_second_step, num_radix_blocks,
|
||||
allocate_gpu_memory, size_tracker);
|
||||
|
||||
// luts for first group inner propagation
|
||||
for (int lut_id = 0; lut_id < grouping_size - 1; lut_id++) {
|
||||
auto f_first_grouping_inner_propagation =
|
||||
[lut_id](Torus propa_cum_sum_block) -> Torus {
|
||||
uint64_t carry = (propa_cum_sum_block >> lut_id) & 1;
|
||||
|
||||
if (carry != 0) {
|
||||
return 2ull; // Generates Carry
|
||||
} else {
|
||||
return 0ull; // Does not generate carry
|
||||
}
|
||||
};
|
||||
|
||||
generate_device_accumulator<Torus>(
|
||||
streams.stream(0), streams.gpu_index(0),
|
||||
luts_array_second_step->get_lut(0, lut_id),
|
||||
luts_array_second_step->get_degree(lut_id),
|
||||
luts_array_second_step->get_max_degree(lut_id), glwe_dimension,
|
||||
polynomial_size, message_modulus, carry_modulus,
|
||||
f_first_grouping_inner_propagation, gpu_memory_allocated);
|
||||
}
|
||||
|
||||
auto f_first_grouping_outer_propagation =
|
||||
[num_bits_in_block](Torus block) -> Torus {
|
||||
return (block >> (num_bits_in_block - 1)) & 1;
|
||||
};
|
||||
|
||||
int lut_id = grouping_size - 1;
|
||||
generate_device_accumulator<Torus>(
|
||||
streams.stream(0), streams.gpu_index(0),
|
||||
luts_array_second_step->get_lut(0, lut_id),
|
||||
luts_array_second_step->get_degree(lut_id),
|
||||
luts_array_second_step->get_max_degree(lut_id), glwe_dimension,
|
||||
polynomial_size, message_modulus, carry_modulus,
|
||||
f_first_grouping_outer_propagation, gpu_memory_allocated);
|
||||
|
||||
// for other groupings inner propagation
|
||||
for (int index = 0; index < grouping_size; index++) {
|
||||
uint32_t lut_id = index + grouping_size;
|
||||
|
||||
auto f_other_groupings_inner_propagation =
|
||||
[index](Torus propa_cum_sum_block) -> Torus {
|
||||
uint64_t mask = (2 << index) - 1;
|
||||
if (propa_cum_sum_block >= (2 << index)) {
|
||||
return 2ull; // Generates
|
||||
} else if ((propa_cum_sum_block & mask) == mask) {
|
||||
return 1ull; // Propagate
|
||||
} else {
|
||||
return 0ull; // Nothing
|
||||
}
|
||||
};
|
||||
|
||||
generate_device_accumulator<Torus>(
|
||||
streams.stream(0), streams.gpu_index(0),
|
||||
luts_array_second_step->get_lut(0, lut_id),
|
||||
luts_array_second_step->get_degree(lut_id),
|
||||
luts_array_second_step->get_max_degree(lut_id), glwe_dimension,
|
||||
polynomial_size, message_modulus, carry_modulus,
|
||||
f_other_groupings_inner_propagation, gpu_memory_allocated);
|
||||
}
|
||||
|
||||
if (use_sequential_algorithm_to_resolve_group_carries) {
|
||||
for (int index = 0; index < grouping_size - 1; index++) {
|
||||
uint32_t lut_id = index + 2 * grouping_size;
|
||||
|
||||
auto f_group_propagation = [index, block_modulus,
|
||||
num_bits_in_block](Torus block) -> Torus {
|
||||
if (block == (block_modulus - 1)) {
|
||||
return 0ull;
|
||||
} else {
|
||||
return ((UINT64_MAX << index) % (1ull << (num_bits_in_block + 1)));
|
||||
}
|
||||
};
|
||||
|
||||
generate_device_accumulator<Torus>(
|
||||
streams.stream(0), streams.gpu_index(0),
|
||||
luts_array_second_step->get_lut(0, lut_id),
|
||||
luts_array_second_step->get_degree(lut_id),
|
||||
luts_array_second_step->get_max_degree(lut_id), glwe_dimension,
|
||||
polynomial_size, message_modulus, carry_modulus,
|
||||
f_group_propagation, gpu_memory_allocated);
|
||||
}
|
||||
} else {
|
||||
uint32_t lut_id = 2 * grouping_size;
|
||||
auto f_group_propagation = [block_modulus](Torus block) {
|
||||
if (block == (block_modulus - 1)) {
|
||||
return 2ull;
|
||||
} else {
|
||||
return UINT64_MAX % (block_modulus * 2ull);
|
||||
}
|
||||
};
|
||||
|
||||
generate_device_accumulator<Torus>(
|
||||
streams.stream(0), streams.gpu_index(0),
|
||||
luts_array_second_step->get_lut(0, lut_id),
|
||||
luts_array_second_step->get_degree(lut_id),
|
||||
luts_array_second_step->get_max_degree(lut_id), glwe_dimension,
|
||||
polynomial_size, message_modulus, carry_modulus, f_group_propagation,
|
||||
gpu_memory_allocated);
|
||||
}
|
||||
|
||||
Torus *h_second_lut_indexes = (Torus *)malloc(lut_indexes_size);
|
||||
|
||||
for (int index = 0; index < num_radix_blocks; index++) {
|
||||
@@ -1941,6 +1854,11 @@ template <typename Torus> struct int_prop_simu_group_carries_memory {
|
||||
}
|
||||
}
|
||||
|
||||
uint32_t num_luts_second_step = 2 * grouping_size + num_extra_luts;
|
||||
luts_array_second_step = new int_radix_lut<Torus>(
|
||||
streams, params, num_luts_second_step, num_radix_blocks,
|
||||
allocate_gpu_memory, size_tracker);
|
||||
|
||||
// copy the indexes to the gpu
|
||||
Torus *second_lut_indexes = luts_array_second_step->get_lut_indexes(0, 0);
|
||||
cuda_memcpy_with_size_tracking_async_to_gpu(
|
||||
@@ -1951,9 +1869,92 @@ template <typename Torus> struct int_prop_simu_group_carries_memory {
|
||||
scalar_array_cum_sum, h_scalar_array_cum_sum,
|
||||
num_radix_blocks * sizeof(Torus), streams.stream(0),
|
||||
streams.gpu_index(0), allocate_gpu_memory);
|
||||
|
||||
std::vector<std::function<Torus(Torus)>> lut_funcs;
|
||||
std::vector<uint32_t> lut_ids;
|
||||
|
||||
// luts for first group inner propagation
|
||||
for (int lut_id = 0; lut_id < grouping_size - 1; lut_id++) {
|
||||
auto f_first_grouping_inner_propagation =
|
||||
[lut_id](Torus propa_cum_sum_block) -> Torus {
|
||||
uint64_t carry = (propa_cum_sum_block >> lut_id) & 1;
|
||||
|
||||
if (carry != 0) {
|
||||
return 2ull; // Generates Carry
|
||||
} else {
|
||||
return 0ull; // Does not generate carry
|
||||
}
|
||||
};
|
||||
lut_funcs.push_back(f_first_grouping_inner_propagation);
|
||||
lut_ids.push_back(lut_id);
|
||||
}
|
||||
|
||||
auto f_first_grouping_outer_propagation =
|
||||
[num_bits_in_block](Torus block) -> Torus {
|
||||
return (block >> (num_bits_in_block - 1)) & 1;
|
||||
};
|
||||
|
||||
int lut_id = grouping_size - 1;
|
||||
|
||||
lut_funcs.push_back(f_first_grouping_outer_propagation);
|
||||
lut_ids.push_back(lut_id);
|
||||
|
||||
// for other groupings inner propagation
|
||||
for (int index = 0; index < grouping_size; index++) {
|
||||
uint32_t lut_id = index + grouping_size;
|
||||
|
||||
auto f_other_groupings_inner_propagation =
|
||||
[index](Torus propa_cum_sum_block) -> Torus {
|
||||
uint64_t mask = (2 << index) - 1;
|
||||
if (propa_cum_sum_block >= (2 << index)) {
|
||||
return 2ull; // Generates
|
||||
} else if ((propa_cum_sum_block & mask) == mask) {
|
||||
return 1ull; // Propagate
|
||||
} else {
|
||||
return 0ull; // Nothing
|
||||
}
|
||||
};
|
||||
|
||||
lut_funcs.push_back(f_other_groupings_inner_propagation);
|
||||
lut_ids.push_back(lut_id);
|
||||
}
|
||||
|
||||
if (use_sequential_algorithm_to_resolve_group_carries) {
|
||||
for (int index = 0; index < grouping_size - 1; index++) {
|
||||
uint32_t lut_id = index + 2 * grouping_size;
|
||||
|
||||
auto f_group_propagation = [index, block_modulus,
|
||||
num_bits_in_block](Torus block) -> Torus {
|
||||
if (block == (block_modulus - 1)) {
|
||||
return 0ull;
|
||||
} else {
|
||||
return ((UINT64_MAX << index) % (1ull << (num_bits_in_block + 1)));
|
||||
}
|
||||
};
|
||||
|
||||
lut_funcs.push_back(f_group_propagation);
|
||||
lut_ids.push_back(lut_id);
|
||||
}
|
||||
} else {
|
||||
uint32_t lut_id = 2 * grouping_size;
|
||||
auto f_group_propagation = [block_modulus](Torus block) {
|
||||
if (block == (block_modulus - 1)) {
|
||||
return 2ull;
|
||||
} else {
|
||||
return UINT64_MAX % (block_modulus * 2ull);
|
||||
}
|
||||
};
|
||||
|
||||
lut_funcs.push_back(f_group_propagation);
|
||||
lut_ids.push_back(lut_id);
|
||||
}
|
||||
|
||||
auto active_streams =
|
||||
streams.active_gpu_subset(num_radix_blocks, params.pbs_type);
|
||||
luts_array_second_step->broadcast_lut(active_streams);
|
||||
luts_array_second_step->generate_and_broadcast_lut(
|
||||
active_streams, lut_ids, lut_funcs, gpu_memory_allocated);
|
||||
|
||||
// luts_array_second_step->broadcast_lut(active_streams);
|
||||
|
||||
if (use_sequential_algorithm_to_resolve_group_carries) {
|
||||
|
||||
@@ -2041,12 +2042,28 @@ template <typename Torus> struct int_sc_prop_memory {
|
||||
uint32_t requested_flag;
|
||||
bool gpu_memory_allocated;
|
||||
|
||||
// Fills the host-side LUT index table of `lut_message_extract`: the first
// `num_radix_blocks` slots select LUT 0 and the single extra trailing slot
// selects LUT 1. The table (num_radix_blocks + 1 entries) is then handed to
// an asynchronous host-to-device copy issued on stream 0 / GPU 0.
// `allocate_gpu_memory` is forwarded unchanged to the copy helper.
void setup_message_extract_indices_for_carry_async(CudaStreams streams,
                                                   uint32_t num_radix_blocks,
                                                   bool allocate_gpu_memory) {
  const uint32_t num_slots = num_radix_blocks + 1;
  Torus *host_indexes = lut_message_extract->h_lut_indexes;

  // Every regular block uses LUT 0; only the last (extra) slot uses LUT 1.
  for (uint32_t slot = 0; slot < num_slots; ++slot)
    host_indexes[slot] = (slot < num_radix_blocks) ? 0 : 1;

  cuda_memcpy_with_size_tracking_async_to_gpu(
      lut_message_extract->get_lut_indexes(0, 0), host_indexes,
      num_slots * sizeof(Torus), streams.stream(0), streams.gpu_index(0),
      allocate_gpu_memory);
}
|
||||
|
||||
int_sc_prop_memory(CudaStreams streams, int_radix_params params,
|
||||
uint32_t num_radix_blocks, uint32_t requested_flag_in,
|
||||
bool allocate_gpu_memory, uint64_t &size_tracker) {
|
||||
gpu_memory_allocated = allocate_gpu_memory;
|
||||
this->params = params;
|
||||
auto glwe_dimension = params.glwe_dimension;
|
||||
auto polynomial_size = params.polynomial_size;
|
||||
auto message_modulus = params.message_modulus;
|
||||
auto carry_modulus = params.carry_modulus;
|
||||
@@ -2069,24 +2086,6 @@ template <typename Torus> struct int_sc_prop_memory {
|
||||
streams, params, num_radix_blocks, grouping_size, num_groups,
|
||||
allocate_gpu_memory, size_tracker);
|
||||
|
||||
// Step 3 elements
|
||||
int num_luts_message_extract =
|
||||
requested_flag == outputFlag::FLAG_NONE ? 1 : 2;
|
||||
lut_message_extract = new int_radix_lut<Torus>(
|
||||
streams, params, num_luts_message_extract, num_radix_blocks + 1,
|
||||
allocate_gpu_memory, size_tracker);
|
||||
// lut for the first block in the first grouping
|
||||
auto f_message_extract = [message_modulus](Torus block) -> Torus {
|
||||
return (block >> 1) % message_modulus;
|
||||
};
|
||||
|
||||
generate_device_accumulator<Torus>(
|
||||
streams.stream(0), streams.gpu_index(0),
|
||||
lut_message_extract->get_lut(0, 0), lut_message_extract->get_degree(0),
|
||||
lut_message_extract->get_max_degree(0), glwe_dimension, polynomial_size,
|
||||
message_modulus, carry_modulus, f_message_extract,
|
||||
gpu_memory_allocated);
|
||||
|
||||
// This stores a single block that will be used to store the overflow or
|
||||
// carry results
|
||||
output_flag = new CudaRadixCiphertextFFI;
|
||||
@@ -2137,22 +2136,30 @@ template <typename Torus> struct int_sc_prop_memory {
|
||||
return output1 << 3 | output2 << 2;
|
||||
};
|
||||
|
||||
generate_device_accumulator_bivariate<Torus>(
|
||||
streams.stream(0), streams.gpu_index(0),
|
||||
lut_overflow_flag_prep->get_lut(0, 0),
|
||||
lut_overflow_flag_prep->get_degree(0),
|
||||
lut_overflow_flag_prep->get_max_degree(0), glwe_dimension,
|
||||
polynomial_size, message_modulus, carry_modulus, f_overflow_fp,
|
||||
gpu_memory_allocated);
|
||||
|
||||
auto active_streams = streams.active_gpu_subset(1, params.pbs_type);
|
||||
lut_overflow_flag_prep->broadcast_lut(active_streams);
|
||||
lut_overflow_flag_prep->generate_and_broadcast_bivariate_lut(
|
||||
active_streams, {0}, {f_overflow_fp}, gpu_memory_allocated);
|
||||
}
|
||||
|
||||
// Step 3 elements
|
||||
int num_luts_message_extract =
|
||||
requested_flag == outputFlag::FLAG_NONE ? 1 : 2;
|
||||
lut_message_extract = new int_radix_lut<Torus>(
|
||||
streams, params, num_luts_message_extract, num_radix_blocks + 1,
|
||||
allocate_gpu_memory, size_tracker);
|
||||
// lut for the first block in the first grouping
|
||||
auto f_message_extract = [message_modulus](Torus block) -> Torus {
|
||||
return (block >> 1) % message_modulus;
|
||||
};
|
||||
|
||||
auto active_streams =
|
||||
streams.active_gpu_subset(num_radix_blocks + 1, params.pbs_type);
|
||||
|
||||
// For the final cleanup in case of overflow or carry (it seems that I can)
|
||||
// It seems that this lut could be applied together with the other one but for
|
||||
// now we won't do it
|
||||
if (requested_flag == outputFlag::FLAG_OVERFLOW) { // Overflow case
|
||||
switch (requested_flag) {
|
||||
case outputFlag::FLAG_OVERFLOW: { // Overflow case
|
||||
auto f_overflow_last = [num_radix_blocks,
|
||||
requested_flag_in](Torus block) -> Torus {
|
||||
uint32_t position = (num_radix_blocks == 1 &&
|
||||
@@ -2164,62 +2171,38 @@ template <typename Torus> struct int_sc_prop_memory {
|
||||
Torus does_overflow_if_carry_is_0 = (block >> 2) & 1;
|
||||
if (input_carry == outputFlag::FLAG_OVERFLOW) {
|
||||
return does_overflow_if_carry_is_1;
|
||||
} else {
|
||||
return does_overflow_if_carry_is_0;
|
||||
}
|
||||
return does_overflow_if_carry_is_0;
|
||||
};
|
||||
setup_message_extract_indices_for_carry_async(streams, num_radix_blocks,
|
||||
allocate_gpu_memory);
|
||||
|
||||
generate_device_accumulator<Torus>(
|
||||
streams.stream(0), streams.gpu_index(0),
|
||||
lut_message_extract->get_lut(0, 1),
|
||||
lut_message_extract->get_degree(1),
|
||||
lut_message_extract->get_max_degree(1), glwe_dimension,
|
||||
polynomial_size, message_modulus, carry_modulus, f_overflow_last,
|
||||
lut_message_extract->generate_and_broadcast_lut(
|
||||
active_streams, {0, 1}, {f_message_extract, f_overflow_last},
|
||||
gpu_memory_allocated);
|
||||
|
||||
Torus *h_lut_indexes = lut_message_extract->h_lut_indexes;
|
||||
for (int index = 0; index < num_radix_blocks + 1; index++) {
|
||||
if (index < num_radix_blocks) {
|
||||
h_lut_indexes[index] = 0;
|
||||
} else {
|
||||
h_lut_indexes[index] = 1;
|
||||
}
|
||||
}
|
||||
cuda_memcpy_with_size_tracking_async_to_gpu(
|
||||
lut_message_extract->get_lut_indexes(0, 0), h_lut_indexes,
|
||||
(num_radix_blocks + 1) * sizeof(Torus), streams.stream(0),
|
||||
streams.gpu_index(0), allocate_gpu_memory);
|
||||
break;
|
||||
}
|
||||
if (requested_flag == outputFlag::FLAG_CARRY) { // Carry case
|
||||
case outputFlag::FLAG_CARRY: { // Carry case
|
||||
|
||||
setup_message_extract_indices_for_carry_async(streams, num_radix_blocks,
|
||||
allocate_gpu_memory);
|
||||
|
||||
auto f_carry_last = [](Torus block) -> Torus {
|
||||
return ((block >> 2) & 1);
|
||||
};
|
||||
|
||||
generate_device_accumulator<Torus>(
|
||||
streams.stream(0), streams.gpu_index(0),
|
||||
lut_message_extract->get_lut(0, 1),
|
||||
lut_message_extract->get_degree(1),
|
||||
lut_message_extract->get_max_degree(1), glwe_dimension,
|
||||
polynomial_size, message_modulus, carry_modulus, f_carry_last,
|
||||
lut_message_extract->generate_and_broadcast_lut(
|
||||
active_streams, {0, 1}, {f_message_extract, f_carry_last},
|
||||
gpu_memory_allocated);
|
||||
|
||||
Torus *h_lut_indexes = lut_message_extract->h_lut_indexes;
|
||||
for (int index = 0; index < num_radix_blocks + 1; index++) {
|
||||
if (index < num_radix_blocks) {
|
||||
h_lut_indexes[index] = 0;
|
||||
} else {
|
||||
h_lut_indexes[index] = 1;
|
||||
}
|
||||
}
|
||||
cuda_memcpy_with_size_tracking_async_to_gpu(
|
||||
lut_message_extract->get_lut_indexes(0, 0), h_lut_indexes,
|
||||
(num_radix_blocks + 1) * sizeof(Torus), streams.stream(0),
|
||||
streams.gpu_index(0), allocate_gpu_memory);
|
||||
break;
|
||||
}
|
||||
auto active_streams =
|
||||
streams.active_gpu_subset(num_radix_blocks + 1, params.pbs_type);
|
||||
lut_message_extract->broadcast_lut(active_streams);
|
||||
default:
|
||||
lut_message_extract->generate_and_broadcast_lut(
|
||||
active_streams, {0}, {f_message_extract}, gpu_memory_allocated);
|
||||
break;
|
||||
}
|
||||
|
||||
// lut_message_extract->broadcast_lut(active_streams);
|
||||
};
|
||||
|
||||
void release(CudaStreams streams) {
|
||||
@@ -2517,16 +2500,11 @@ template <typename Torus> struct int_borrow_prop_memory {
|
||||
return (block >> 1) % message_modulus;
|
||||
};
|
||||
|
||||
generate_device_accumulator<Torus>(
|
||||
streams.stream(0), streams.gpu_index(0),
|
||||
lut_message_extract->get_lut(0, 0), lut_message_extract->get_degree(0),
|
||||
lut_message_extract->get_max_degree(0), glwe_dimension, polynomial_size,
|
||||
message_modulus, carry_modulus, f_message_extract,
|
||||
gpu_memory_allocated);
|
||||
active_streams =
|
||||
streams.active_gpu_subset(num_radix_blocks, params.pbs_type);
|
||||
|
||||
lut_message_extract->broadcast_lut(active_streams);
|
||||
lut_message_extract->generate_and_broadcast_lut(
|
||||
active_streams, {0}, {f_message_extract}, gpu_memory_allocated);
|
||||
|
||||
if (compute_overflow) {
|
||||
lut_borrow_flag =
|
||||
@@ -2537,12 +2515,8 @@ template <typename Torus> struct int_borrow_prop_memory {
|
||||
return ((block >> 2) & 1);
|
||||
};
|
||||
|
||||
generate_device_accumulator<Torus>(
|
||||
streams.stream(0), streams.gpu_index(0),
|
||||
lut_borrow_flag->get_lut(0, 0), lut_borrow_flag->get_degree(0),
|
||||
lut_borrow_flag->get_max_degree(0), glwe_dimension, polynomial_size,
|
||||
message_modulus, carry_modulus, f_borrow_flag, gpu_memory_allocated);
|
||||
lut_borrow_flag->broadcast_lut(active_streams);
|
||||
lut_borrow_flag->generate_and_broadcast_lut(
|
||||
active_streams, {0}, {f_borrow_flag}, gpu_memory_allocated);
|
||||
}
|
||||
|
||||
active_streams =
|
||||
|
||||
@@ -37,17 +37,14 @@ template <typename Torus> struct int_mul_memory {
|
||||
zero_out_predicate_lut =
|
||||
new int_radix_lut<Torus>(streams, params, 1, num_radix_blocks,
|
||||
allocate_gpu_memory, size_tracker);
|
||||
generate_device_accumulator_bivariate<Torus>(
|
||||
streams.stream(0), streams.gpu_index(0),
|
||||
zero_out_predicate_lut->get_lut(0, 0),
|
||||
zero_out_predicate_lut->get_degree(0),
|
||||
zero_out_predicate_lut->get_max_degree(0), params.glwe_dimension,
|
||||
params.polynomial_size, params.message_modulus, params.carry_modulus,
|
||||
zero_out_predicate_lut_f, gpu_memory_allocated);
|
||||
|
||||
auto active_streams =
|
||||
streams.active_gpu_subset(num_radix_blocks, params.pbs_type);
|
||||
zero_out_predicate_lut->broadcast_lut(active_streams);
|
||||
zero_out_predicate_lut->generate_and_broadcast_bivariate_lut(
|
||||
active_streams, {0}, {zero_out_predicate_lut_f},
|
||||
gpu_memory_allocated);
|
||||
|
||||
// zero_out_predicate_lut->broadcast_lut(active_streams);
|
||||
|
||||
zero_out_mem = new int_zero_out_if_buffer<Torus>(
|
||||
streams, params, num_radix_blocks, allocate_gpu_memory, size_tracker);
|
||||
@@ -55,10 +52,7 @@ template <typename Torus> struct int_mul_memory {
|
||||
return;
|
||||
}
|
||||
|
||||
auto glwe_dimension = params.glwe_dimension;
|
||||
auto polynomial_size = params.polynomial_size;
|
||||
auto message_modulus = params.message_modulus;
|
||||
auto carry_modulus = params.carry_modulus;
|
||||
|
||||
// 'vector_result_lsb' contains blocks from all possible shifts of
|
||||
// radix_lwe_left excluding zero ciphertext blocks
|
||||
@@ -102,18 +96,6 @@ template <typename Torus> struct int_mul_memory {
|
||||
return (x * y) / message_modulus;
|
||||
};
|
||||
|
||||
// generate accumulators
|
||||
generate_device_accumulator_bivariate<Torus>(
|
||||
streams.stream(0), streams.gpu_index(0), lsb_acc,
|
||||
luts_array->get_degree(0), luts_array->get_max_degree(0),
|
||||
glwe_dimension, polynomial_size, message_modulus, carry_modulus,
|
||||
lut_f_lsb, gpu_memory_allocated);
|
||||
generate_device_accumulator_bivariate<Torus>(
|
||||
streams.stream(0), streams.gpu_index(0), msb_acc,
|
||||
luts_array->get_degree(1), luts_array->get_max_degree(1),
|
||||
glwe_dimension, polynomial_size, message_modulus, carry_modulus,
|
||||
lut_f_msb, gpu_memory_allocated);
|
||||
|
||||
// lut_indexes_vec for luts_array should be reinitialized
|
||||
// first lsb_vector_block_count value should reference to lsb_acc
|
||||
// last msb_vector_block_count values should reference to msb_acc
|
||||
@@ -123,9 +105,12 @@ template <typename Torus> struct int_mul_memory {
|
||||
streams.stream(0), streams.gpu_index(0),
|
||||
luts_array->get_lut_indexes(0, lsb_vector_block_count), 1,
|
||||
msb_vector_block_count);
|
||||
|
||||
auto active_streams =
|
||||
streams.active_gpu_subset(total_block_count, params.pbs_type);
|
||||
luts_array->broadcast_lut(active_streams);
|
||||
luts_array->generate_and_broadcast_bivariate_lut(
|
||||
active_streams, {0, 1}, {lut_f_lsb, lut_f_msb}, gpu_memory_allocated);
|
||||
|
||||
// create memory object for sum ciphertexts
|
||||
sum_ciphertexts_mem = new int_sum_ciphertexts_vec_memory<Torus>(
|
||||
streams, params, num_radix_blocks, 2 * num_radix_blocks,
|
||||
|
||||
@@ -85,15 +85,11 @@ template <typename Torus> struct int_logical_scalar_shift_buffer {
|
||||
}
|
||||
|
||||
// right shift
|
||||
generate_device_accumulator_bivariate<Torus>(
|
||||
streams.stream(0), streams.gpu_index(0),
|
||||
cur_lut_bivariate->get_lut(0, 0), cur_lut_bivariate->get_degree(0),
|
||||
cur_lut_bivariate->get_max_degree(0), params.glwe_dimension,
|
||||
params.polynomial_size, params.message_modulus, params.carry_modulus,
|
||||
shift_lut_f, gpu_memory_allocated);
|
||||
|
||||
auto active_streams =
|
||||
streams.active_gpu_subset(num_radix_blocks, params.pbs_type);
|
||||
cur_lut_bivariate->broadcast_lut(active_streams);
|
||||
cur_lut_bivariate->generate_and_broadcast_bivariate_lut(
|
||||
active_streams, {0}, {shift_lut_f}, gpu_memory_allocated);
|
||||
|
||||
lut_buffers_bivariate.push_back(cur_lut_bivariate);
|
||||
}
|
||||
@@ -172,16 +168,10 @@ template <typename Torus> struct int_logical_scalar_shift_buffer {
|
||||
}
|
||||
|
||||
// right shift
|
||||
generate_device_accumulator_bivariate<Torus>(
|
||||
streams.stream(0), streams.gpu_index(0),
|
||||
cur_lut_bivariate->get_lut(0, 0), cur_lut_bivariate->get_degree(0),
|
||||
cur_lut_bivariate->get_max_degree(0), params.glwe_dimension,
|
||||
params.polynomial_size, params.message_modulus, params.carry_modulus,
|
||||
shift_lut_f, gpu_memory_allocated);
|
||||
auto active_streams =
|
||||
streams.active_gpu_subset(num_radix_blocks, params.pbs_type);
|
||||
cur_lut_bivariate->broadcast_lut(active_streams);
|
||||
|
||||
cur_lut_bivariate->generate_and_broadcast_bivariate_lut(
|
||||
active_streams, {0}, {shift_lut_f}, gpu_memory_allocated);
|
||||
lut_buffers_bivariate.push_back(cur_lut_bivariate);
|
||||
}
|
||||
}
|
||||
@@ -271,16 +261,11 @@ template <typename Torus> struct int_arithmetic_scalar_shift_buffer {
|
||||
return shifted | padding;
|
||||
};
|
||||
|
||||
generate_device_accumulator<Torus>(
|
||||
streams.stream(0), streams.gpu_index(0),
|
||||
shift_last_block_lut_univariate->get_lut(0, 0),
|
||||
shift_last_block_lut_univariate->get_degree(0),
|
||||
shift_last_block_lut_univariate->get_max_degree(0),
|
||||
params.glwe_dimension, params.polynomial_size, params.message_modulus,
|
||||
params.carry_modulus, last_block_lut_f, gpu_memory_allocated);
|
||||
auto active_streams_shift_last =
|
||||
streams.active_gpu_subset(1, params.pbs_type);
|
||||
shift_last_block_lut_univariate->broadcast_lut(active_streams_shift_last);
|
||||
shift_last_block_lut_univariate->generate_and_broadcast_lut(
|
||||
active_streams_shift_last, {0}, {last_block_lut_f},
|
||||
gpu_memory_allocated);
|
||||
|
||||
lut_buffers_univariate.push_back(shift_last_block_lut_univariate);
|
||||
}
|
||||
@@ -298,15 +283,8 @@ template <typename Torus> struct int_arithmetic_scalar_shift_buffer {
|
||||
return (params.message_modulus - 1) * x_sign_bit;
|
||||
};
|
||||
|
||||
generate_device_accumulator<Torus>(
|
||||
streams.stream(0), streams.gpu_index(0),
|
||||
padding_block_lut_univariate->get_lut(0, 0),
|
||||
padding_block_lut_univariate->get_degree(0),
|
||||
padding_block_lut_univariate->get_max_degree(0), params.glwe_dimension,
|
||||
params.polynomial_size, params.message_modulus, params.carry_modulus,
|
||||
padding_block_lut_f, gpu_memory_allocated);
|
||||
// auto active_streams = streams.active_gpu_subset(1, params.pbs_type);
|
||||
padding_block_lut_univariate->broadcast_lut(active_streams);
|
||||
padding_block_lut_univariate->generate_and_broadcast_lut(
|
||||
active_streams, {0}, {padding_block_lut_f}, gpu_memory_allocated);
|
||||
|
||||
lut_buffers_univariate.push_back(padding_block_lut_univariate);
|
||||
|
||||
@@ -339,16 +317,11 @@ template <typename Torus> struct int_arithmetic_scalar_shift_buffer {
|
||||
return message_of_current_block + carry_of_previous_block;
|
||||
};
|
||||
|
||||
generate_device_accumulator_bivariate<Torus>(
|
||||
streams.stream(0), streams.gpu_index(0),
|
||||
shift_blocks_lut_bivariate->get_lut(0, 0),
|
||||
shift_blocks_lut_bivariate->get_degree(0),
|
||||
shift_blocks_lut_bivariate->get_max_degree(0), params.glwe_dimension,
|
||||
params.polynomial_size, params.message_modulus, params.carry_modulus,
|
||||
blocks_lut_f, gpu_memory_allocated);
|
||||
auto active_streams_shift_blocks =
|
||||
streams.active_gpu_subset(num_radix_blocks, params.pbs_type);
|
||||
shift_blocks_lut_bivariate->broadcast_lut(active_streams_shift_blocks);
|
||||
shift_blocks_lut_bivariate->generate_and_broadcast_bivariate_lut(
|
||||
active_streams_shift_blocks, {0}, {blocks_lut_f},
|
||||
gpu_memory_allocated);
|
||||
|
||||
lut_buffers_bivariate.push_back(shift_blocks_lut_bivariate);
|
||||
}
|
||||
|
||||
@@ -113,27 +113,21 @@ template <typename Torus> struct int_shift_and_rotate_buffer {
|
||||
else
|
||||
return current_bit;
|
||||
};
|
||||
|
||||
generate_device_accumulator<Torus>(
|
||||
streams.stream(0), streams.gpu_index(0), mux_lut->get_lut(0, 0),
|
||||
mux_lut->get_degree(0), mux_lut->get_max_degree(0),
|
||||
params.glwe_dimension, params.polynomial_size, params.message_modulus,
|
||||
params.carry_modulus, mux_lut_f, gpu_memory_allocated);
|
||||
;
|
||||
auto active_gpu_count_mux = streams.active_gpu_subset(
|
||||
bits_per_block * num_radix_blocks, params.pbs_type);
|
||||
mux_lut->broadcast_lut(active_gpu_count_mux);
|
||||
|
||||
mux_lut->generate_and_broadcast_lut(active_gpu_count_mux, {0}, {mux_lut_f},
|
||||
gpu_memory_allocated);
|
||||
|
||||
auto cleaning_lut_f = [params](Torus x) -> Torus {
|
||||
return x % params.message_modulus;
|
||||
};
|
||||
generate_device_accumulator<Torus>(
|
||||
streams.stream(0), streams.gpu_index(0), cleaning_lut->get_lut(0, 0),
|
||||
cleaning_lut->get_degree(0), cleaning_lut->get_max_degree(0),
|
||||
params.glwe_dimension, params.polynomial_size, params.message_modulus,
|
||||
params.carry_modulus, cleaning_lut_f, gpu_memory_allocated);
|
||||
|
||||
auto active_gpu_count_cleaning =
|
||||
streams.active_gpu_subset(num_radix_blocks, params.pbs_type);
|
||||
cleaning_lut->broadcast_lut(active_gpu_count_cleaning);
|
||||
cleaning_lut->generate_and_broadcast_lut(
|
||||
active_gpu_count_cleaning, {0}, {cleaning_lut_f}, gpu_memory_allocated);
|
||||
}
|
||||
|
||||
void release(CudaStreams streams) {
|
||||
|
||||
@@ -74,45 +74,26 @@ template <typename Torus> struct int_overflowing_sub_memory {
|
||||
luts_array, size_tracker,
|
||||
allocate_gpu_memory, size_tracker);
|
||||
|
||||
auto lut_does_block_generate_carry = luts_array->get_lut(0, 0);
|
||||
auto lut_does_block_generate_or_propagate = luts_array->get_lut(0, 1);
|
||||
|
||||
// generate luts (aka accumulators)
|
||||
generate_device_accumulator<Torus>(
|
||||
streams.stream(0), streams.gpu_index(0), lut_does_block_generate_carry,
|
||||
luts_array->get_degree(0), luts_array->get_max_degree(0),
|
||||
glwe_dimension, polynomial_size, message_modulus, carry_modulus,
|
||||
f_lut_does_block_generate_carry, gpu_memory_allocated);
|
||||
generate_device_accumulator<Torus>(
|
||||
streams.stream(0), streams.gpu_index(0),
|
||||
lut_does_block_generate_or_propagate, luts_array->get_degree(1),
|
||||
luts_array->get_max_degree(1), glwe_dimension, polynomial_size,
|
||||
message_modulus, carry_modulus, f_lut_does_block_generate_or_propagate,
|
||||
gpu_memory_allocated);
|
||||
if (allocate_gpu_memory)
|
||||
cuda_set_value_async<Torus>(streams.stream(0), streams.gpu_index(0),
|
||||
luts_array->get_lut_indexes(0, 1), 1,
|
||||
num_radix_blocks - 1);
|
||||
|
||||
generate_device_accumulator_bivariate<Torus>(
|
||||
streams.stream(0), streams.gpu_index(0),
|
||||
luts_borrow_propagation_sum->get_lut(0, 0),
|
||||
luts_borrow_propagation_sum->get_degree(0),
|
||||
luts_borrow_propagation_sum->get_max_degree(0), glwe_dimension,
|
||||
polynomial_size, message_modulus, carry_modulus,
|
||||
f_luts_borrow_propagation_sum, gpu_memory_allocated);
|
||||
|
||||
generate_device_accumulator<Torus>(
|
||||
streams.stream(0), streams.gpu_index(0), message_acc->get_lut(0, 0),
|
||||
message_acc->get_degree(0), message_acc->get_max_degree(0),
|
||||
glwe_dimension, polynomial_size, message_modulus, carry_modulus,
|
||||
f_message_acc, gpu_memory_allocated);
|
||||
|
||||
auto active_streams =
|
||||
streams.active_gpu_subset(num_radix_blocks, params.pbs_type);
|
||||
luts_array->broadcast_lut(active_streams);
|
||||
luts_borrow_propagation_sum->broadcast_lut(active_streams);
|
||||
message_acc->broadcast_lut(active_streams);
|
||||
luts_borrow_propagation_sum->generate_and_broadcast_bivariate_lut(
|
||||
active_streams, {0}, {f_luts_borrow_propagation_sum},
|
||||
gpu_memory_allocated);
|
||||
|
||||
luts_array->generate_and_broadcast_lut(
|
||||
active_streams, {0, 1},
|
||||
{f_lut_does_block_generate_carry,
|
||||
f_lut_does_block_generate_or_propagate},
|
||||
gpu_memory_allocated);
|
||||
// generate luts (aka accumulators)
|
||||
|
||||
message_acc->generate_and_broadcast_lut(
|
||||
active_streams, {0}, {f_message_acc}, gpu_memory_allocated);
|
||||
}
|
||||
|
||||
void release(CudaStreams streams) {
|
||||
|
||||
@@ -298,14 +298,10 @@ template <typename Torus> struct int_aggregate_one_hot_buffer {
|
||||
int_radix_lut<Torus> *lut = new int_radix_lut<Torus>(
|
||||
streams, params, 1, num_blocks, allocate_gpu_memory, size_tracker);
|
||||
|
||||
generate_device_accumulator<Torus>(
|
||||
streams.stream(0), streams.gpu_index(0), lut->get_lut(0, 0),
|
||||
lut->get_degree(0), lut->get_max_degree(0), params.glwe_dimension,
|
||||
params.polynomial_size, params.message_modulus, params.carry_modulus,
|
||||
id_fn, allocate_gpu_memory);
|
||||
lut->generate_and_broadcast_lut(
|
||||
streams.active_gpu_subset(num_blocks, params.pbs_type), {0}, {id_fn},
|
||||
allocate_gpu_memory);
|
||||
|
||||
lut->broadcast_lut(
|
||||
streams.active_gpu_subset(num_blocks, params.pbs_type));
|
||||
this->stream_identity_luts[i] = lut;
|
||||
}
|
||||
|
||||
@@ -318,27 +314,17 @@ template <typename Torus> struct int_aggregate_one_hot_buffer {
|
||||
|
||||
this->message_extract_lut = new int_radix_lut<Torus>(
|
||||
streams, params, 1, num_blocks, allocate_gpu_memory, size_tracker);
|
||||
generate_device_accumulator<Torus>(
|
||||
streams.stream(0), streams.gpu_index(0),
|
||||
this->message_extract_lut->get_lut(0, 0),
|
||||
this->message_extract_lut->get_degree(0),
|
||||
this->message_extract_lut->get_max_degree(0), params.glwe_dimension,
|
||||
params.polynomial_size, params.message_modulus, params.carry_modulus,
|
||||
msg_fn, allocate_gpu_memory);
|
||||
this->message_extract_lut->broadcast_lut(
|
||||
streams.active_gpu_subset(num_blocks, params.pbs_type));
|
||||
|
||||
this->message_extract_lut->generate_and_broadcast_lut(
|
||||
streams.active_gpu_subset(num_blocks, params.pbs_type), {0}, {msg_fn},
|
||||
allocate_gpu_memory);
|
||||
|
||||
this->carry_extract_lut = new int_radix_lut<Torus>(
|
||||
streams, params, 1, num_blocks, allocate_gpu_memory, size_tracker);
|
||||
generate_device_accumulator<Torus>(
|
||||
streams.stream(0), streams.gpu_index(0),
|
||||
this->carry_extract_lut->get_lut(0, 0),
|
||||
this->carry_extract_lut->get_degree(0),
|
||||
this->carry_extract_lut->get_max_degree(0), params.glwe_dimension,
|
||||
params.polynomial_size, params.message_modulus, params.carry_modulus,
|
||||
carry_fn, allocate_gpu_memory);
|
||||
this->carry_extract_lut->broadcast_lut(
|
||||
streams.active_gpu_subset(num_blocks, params.pbs_type));
|
||||
|
||||
this->carry_extract_lut->generate_and_broadcast_lut(
|
||||
streams.active_gpu_subset(num_blocks, params.pbs_type), {0}, {carry_fn},
|
||||
allocate_gpu_memory);
|
||||
|
||||
this->partial_aggregated_vectors =
|
||||
new CudaRadixCiphertextFFI *[num_streams];
|
||||
@@ -1185,15 +1171,9 @@ template <typename Torus> struct int_unchecked_first_index_of_clear_buffer {
|
||||
this->prefix_sum_lut = new int_radix_lut<Torus>(
|
||||
streams, params, 2, num_inputs, allocate_gpu_memory, size_tracker);
|
||||
|
||||
generate_device_accumulator_bivariate<Torus>(
|
||||
streams.stream(0), streams.gpu_index(0),
|
||||
this->prefix_sum_lut->get_lut(0, 0),
|
||||
this->prefix_sum_lut->get_degree(0),
|
||||
this->prefix_sum_lut->get_max_degree(0), params.glwe_dimension,
|
||||
params.polynomial_size, params.message_modulus, params.carry_modulus,
|
||||
prefix_sum_fn, allocate_gpu_memory);
|
||||
this->prefix_sum_lut->broadcast_lut(
|
||||
streams.active_gpu_subset(num_inputs, params.pbs_type));
|
||||
this->prefix_sum_lut->generate_and_broadcast_bivariate_lut(
|
||||
streams.active_gpu_subset(num_inputs, params.pbs_type), {0},
|
||||
{prefix_sum_fn}, allocate_gpu_memory);
|
||||
|
||||
auto cleanup_fn = [ALREADY_SEEN, params](Torus x) -> Torus {
|
||||
Torus val = x % params.message_modulus;
|
||||
@@ -1203,14 +1183,9 @@ template <typename Torus> struct int_unchecked_first_index_of_clear_buffer {
|
||||
};
|
||||
this->cleanup_lut = new int_radix_lut<Torus>(
|
||||
streams, params, 1, num_inputs, allocate_gpu_memory, size_tracker);
|
||||
generate_device_accumulator<Torus>(
|
||||
streams.stream(0), streams.gpu_index(0),
|
||||
this->cleanup_lut->get_lut(0, 0), this->cleanup_lut->get_degree(0),
|
||||
this->cleanup_lut->get_max_degree(0), params.glwe_dimension,
|
||||
params.polynomial_size, params.message_modulus, params.carry_modulus,
|
||||
cleanup_fn, allocate_gpu_memory);
|
||||
this->cleanup_lut->broadcast_lut(
|
||||
streams.active_gpu_subset(num_inputs, params.pbs_type));
|
||||
this->cleanup_lut->generate_and_broadcast_lut(
|
||||
streams.active_gpu_subset(num_inputs, params.pbs_type), {0},
|
||||
{cleanup_fn}, allocate_gpu_memory);
|
||||
}
|
||||
|
||||
void release(CudaStreams streams) {
|
||||
@@ -1376,15 +1351,9 @@ template <typename Torus> struct int_unchecked_first_index_of_buffer {
|
||||
this->prefix_sum_lut = new int_radix_lut<Torus>(
|
||||
streams, params, 2, num_inputs, allocate_gpu_memory, size_tracker);
|
||||
|
||||
generate_device_accumulator_bivariate<Torus>(
|
||||
streams.stream(0), streams.gpu_index(0),
|
||||
this->prefix_sum_lut->get_lut(0, 0),
|
||||
this->prefix_sum_lut->get_degree(0),
|
||||
this->prefix_sum_lut->get_max_degree(0), params.glwe_dimension,
|
||||
params.polynomial_size, params.message_modulus, params.carry_modulus,
|
||||
prefix_sum_fn, allocate_gpu_memory);
|
||||
this->prefix_sum_lut->broadcast_lut(
|
||||
streams.active_gpu_subset(num_inputs, params.pbs_type));
|
||||
this->prefix_sum_lut->generate_and_broadcast_bivariate_lut(
|
||||
streams.active_gpu_subset(num_inputs, params.pbs_type), {0},
|
||||
{prefix_sum_fn}, allocate_gpu_memory);
|
||||
|
||||
auto cleanup_fn = [ALREADY_SEEN, params](Torus x) -> Torus {
|
||||
Torus val = x % params.message_modulus;
|
||||
@@ -1394,14 +1363,9 @@ template <typename Torus> struct int_unchecked_first_index_of_buffer {
|
||||
};
|
||||
this->cleanup_lut = new int_radix_lut<Torus>(
|
||||
streams, params, 1, num_inputs, allocate_gpu_memory, size_tracker);
|
||||
generate_device_accumulator<Torus>(
|
||||
streams.stream(0), streams.gpu_index(0),
|
||||
this->cleanup_lut->get_lut(0, 0), this->cleanup_lut->get_degree(0),
|
||||
this->cleanup_lut->get_max_degree(0), params.glwe_dimension,
|
||||
params.polynomial_size, params.message_modulus, params.carry_modulus,
|
||||
cleanup_fn, allocate_gpu_memory);
|
||||
this->cleanup_lut->broadcast_lut(
|
||||
streams.active_gpu_subset(num_inputs, params.pbs_type));
|
||||
this->cleanup_lut->generate_and_broadcast_lut(
|
||||
streams.active_gpu_subset(num_inputs, params.pbs_type), {0},
|
||||
{cleanup_fn}, allocate_gpu_memory);
|
||||
}
|
||||
|
||||
void release(CudaStreams streams) {
|
||||
|
||||
24
backends/tfhe-cuda-backend/cuda/include/trivium/trivium.h
Normal file
24
backends/tfhe-cuda-backend/cuda/include/trivium/trivium.h
Normal file
@@ -0,0 +1,24 @@
|
||||
#ifndef TRIVIUM_H
#define TRIVIUM_H

#include "../integer/integer.h"

// C-linkage entry points for Trivium keystream generation (64-bit variant).
extern "C" {

// Scratch/setup entry point.
//
// `mem_ptr` is an out-parameter; the remaining arguments carry the
// cryptographic parameters (GLWE/LWE dimensions, key-switch and PBS
// decomposition settings, message/carry moduli, PBS type, noise reduction
// type) and the number of Trivium inputs handled together.
// NOTE(review): presumably allocates the buffer later passed as `mem_ptr` to
// cuda_trivium_generate_keystream_64 and returns the tracked size in bytes;
// confirm against the implementation.
uint64_t scratch_cuda_trivium_64(
    CudaStreamsFFI streams, int8_t **mem_ptr, uint32_t glwe_dimension,
    uint32_t polynomial_size, uint32_t lwe_dimension, uint32_t ks_level,
    uint32_t ks_base_log, uint32_t pbs_level, uint32_t pbs_base_log,
    uint32_t grouping_factor, uint32_t message_modulus, uint32_t carry_modulus,
    PBS_TYPE pbs_type, bool allocate_gpu_memory,
    PBS_MS_REDUCTION_T noise_reduction_type, uint32_t num_inputs);

// Keystream generation entry point.
//
// Reads the encrypted `key` and `iv`, writes to `keystream_output`, and
// receives the buffer `mem_ptr` together with the bootstrapping keys (`bsks`)
// and key-switching keys (`ksks`).
// NOTE(review): `num_steps` is presumably the number of keystream steps to
// produce per input; confirm against the implementation.
void cuda_trivium_generate_keystream_64(
    CudaStreamsFFI streams, CudaRadixCiphertextFFI *keystream_output,
    const CudaRadixCiphertextFFI *key, const CudaRadixCiphertextFFI *iv,
    uint32_t num_inputs, uint32_t num_steps, int8_t *mem_ptr, void *const *bsks,
    void *const *ksks);

// Cleanup entry point, taking the address of the buffer pointer.
// NOTE(review): presumably releases what scratch_cuda_trivium_64 set up;
// confirm against the implementation.
void cleanup_cuda_trivium_64(CudaStreamsFFI streams, int8_t **mem_ptr_void);
}

#endif
|
||||
@@ -0,0 +1,295 @@
|
||||
#ifndef TRIVIUM_UTILITIES_H
|
||||
#define TRIVIUM_UTILITIES_H
|
||||
#include "../integer/integer_utilities.h"
|
||||
|
||||
/// Struct to hold the LUTs.
|
||||
template <typename Torus> struct int_trivium_lut_buffers {
|
||||
// Bivariate AND Gate LUT:
|
||||
// AND operation: f(a, b) = (a & 1) & (b & 1).
|
||||
// This is a Bivariate PBS used for the non-linear parts of Trivium.
|
||||
int_radix_lut<Torus> *and_lut;
|
||||
|
||||
// Univariate Identity LUT:
|
||||
// MESSAGE EXTRACTION operation: f(x) = x & 1.
|
||||
// This is a Univariate PBS used to "flush" the state: it resets the noise
|
||||
// after additions and ensures the message stays within the binary message
|
||||
// space.
|
||||
int_radix_lut<Torus> *flush_lut;
|
||||
|
||||
int_trivium_lut_buffers(CudaStreams streams, const int_radix_params ¶ms,
|
||||
bool allocate_gpu_memory, uint32_t num_trivium_inputs,
|
||||
uint64_t &size_tracker) {
|
||||
|
||||
constexpr uint32_t BATCH_SIZE = 64;
|
||||
constexpr uint32_t MAX_AND_PER_STEP = 3;
|
||||
uint32_t total_lut_ops = num_trivium_inputs * BATCH_SIZE * MAX_AND_PER_STEP;
|
||||
|
||||
this->and_lut = new int_radix_lut<Torus>(streams, params, 1, total_lut_ops,
|
||||
allocate_gpu_memory, size_tracker);
|
||||
|
||||
std::function<Torus(Torus, Torus)> and_lambda =
|
||||
[](Torus a, Torus b) -> Torus { return (a & 1) & (b & 1); };
|
||||
|
||||
auto active_streams_and =
|
||||
streams.active_gpu_subset(total_lut_ops, params.pbs_type);
|
||||
this->and_lut->generate_and_broadcast_bivariate_lut(
|
||||
active_streams_and, {0}, {and_lambda}, allocate_gpu_memory);
|
||||
this->and_lut->setup_gemm_batch_ks_temp_buffers(size_tracker);
|
||||
|
||||
uint32_t total_flush_ops = num_trivium_inputs * BATCH_SIZE * 4;
|
||||
|
||||
this->flush_lut = new int_radix_lut<Torus>(
|
||||
streams, params, 1, total_flush_ops, allocate_gpu_memory, size_tracker);
|
||||
|
||||
std::function<Torus(Torus)> flush_lambda = [](Torus x) -> Torus {
|
||||
return x & 1;
|
||||
};
|
||||
|
||||
auto active_streams_flush =
|
||||
streams.active_gpu_subset(total_flush_ops, params.pbs_type);
|
||||
this->flush_lut->generate_and_broadcast_lut(
|
||||
active_streams_flush, {0}, {flush_lambda}, allocate_gpu_memory);
|
||||
this->flush_lut->setup_gemm_batch_ks_temp_buffers(size_tracker);
|
||||
}
|
||||
|
||||
void release(CudaStreams streams) {
|
||||
this->and_lut->release(streams);
|
||||
delete this->and_lut;
|
||||
this->and_lut = nullptr;
|
||||
|
||||
this->flush_lut->release(streams);
|
||||
delete this->flush_lut;
|
||||
this->flush_lut = nullptr;
|
||||
}
|
||||
};
|
||||
|
||||
/// Struct to hold the state and temporary workspaces required for
|
||||
/// Trivium execution on the GPU.
|
||||
///
|
||||
/// This struct manages the memory for the internal registers (A, B, C),
|
||||
/// temporary buffers used during the update function, and buffers used for
|
||||
/// packing data before and after PBS.
|
||||
template <typename Torus> struct int_trivium_state_workspaces {
|
||||
// Trivium Internal State Registers:
|
||||
// Register A: 93 bits
|
||||
CudaRadixCiphertextFFI *a_reg;
|
||||
// Register B: 84 bits
|
||||
CudaRadixCiphertextFFI *b_reg;
|
||||
// Register C: 111 bits
|
||||
CudaRadixCiphertextFFI *c_reg;
|
||||
|
||||
// Shift Workspace:
|
||||
// Used to manage bitshifting operations on the registers
|
||||
CudaRadixCiphertextFFI *shift_workspace;
|
||||
|
||||
// Temporary Update Buffers:
|
||||
// Intermediate buffers for the trivium update logic (t1, t2, t3)
|
||||
CudaRadixCiphertextFFI *temp_t1;
|
||||
CudaRadixCiphertextFFI *temp_t2;
|
||||
CudaRadixCiphertextFFI *temp_t3;
|
||||
|
||||
// Buffers to hold the new values for the registers after an update step
|
||||
CudaRadixCiphertextFFI *new_a;
|
||||
CudaRadixCiphertextFFI *new_b;
|
||||
CudaRadixCiphertextFFI *new_c;
|
||||
|
||||
// PBS Packing Buffers:
|
||||
// Buffers for packing inputs into the bivariate lookup table (AND gate)
|
||||
CudaRadixCiphertextFFI *packed_pbs_lhs;
|
||||
CudaRadixCiphertextFFI *packed_pbs_rhs;
|
||||
// Buffer for the output of the bivariate PBS
|
||||
CudaRadixCiphertextFFI *packed_pbs_out;
|
||||
|
||||
// Flush/Cleanup Packing Buffers:
|
||||
// Buffers for the "flush" LUT which cleans up noise after additions
|
||||
CudaRadixCiphertextFFI *packed_flush_in;
|
||||
CudaRadixCiphertextFFI *packed_flush_out;
|
||||
|
||||
int_trivium_state_workspaces(CudaStreams streams,
|
||||
const int_radix_params ¶ms,
|
||||
bool allocate_gpu_memory, uint32_t num_inputs,
|
||||
uint64_t &size_tracker) {
|
||||
|
||||
this->a_reg = new CudaRadixCiphertextFFI;
|
||||
create_zero_radix_ciphertext_async<Torus>(
|
||||
streams.stream(0), streams.gpu_index(0), this->a_reg, 93 * num_inputs,
|
||||
params.big_lwe_dimension, size_tracker, allocate_gpu_memory);
|
||||
|
||||
this->b_reg = new CudaRadixCiphertextFFI;
|
||||
create_zero_radix_ciphertext_async<Torus>(
|
||||
streams.stream(0), streams.gpu_index(0), this->b_reg, 84 * num_inputs,
|
||||
params.big_lwe_dimension, size_tracker, allocate_gpu_memory);
|
||||
|
||||
this->c_reg = new CudaRadixCiphertextFFI;
|
||||
create_zero_radix_ciphertext_async<Torus>(
|
||||
streams.stream(0), streams.gpu_index(0), this->c_reg, 111 * num_inputs,
|
||||
params.big_lwe_dimension, size_tracker, allocate_gpu_memory);
|
||||
|
||||
this->shift_workspace = new CudaRadixCiphertextFFI;
|
||||
create_zero_radix_ciphertext_async<Torus>(
|
||||
streams.stream(0), streams.gpu_index(0), this->shift_workspace,
|
||||
128 * num_inputs, params.big_lwe_dimension, size_tracker,
|
||||
allocate_gpu_memory);
|
||||
|
||||
uint32_t batch_blocks = 64 * num_inputs;
|
||||
|
||||
this->temp_t1 = new CudaRadixCiphertextFFI;
|
||||
create_zero_radix_ciphertext_async<Torus>(
|
||||
streams.stream(0), streams.gpu_index(0), this->temp_t1, batch_blocks,
|
||||
params.big_lwe_dimension, size_tracker, allocate_gpu_memory);
|
||||
|
||||
this->temp_t2 = new CudaRadixCiphertextFFI;
|
||||
create_zero_radix_ciphertext_async<Torus>(
|
||||
streams.stream(0), streams.gpu_index(0), this->temp_t2, batch_blocks,
|
||||
params.big_lwe_dimension, size_tracker, allocate_gpu_memory);
|
||||
|
||||
this->temp_t3 = new CudaRadixCiphertextFFI;
|
||||
create_zero_radix_ciphertext_async<Torus>(
|
||||
streams.stream(0), streams.gpu_index(0), this->temp_t3, batch_blocks,
|
||||
params.big_lwe_dimension, size_tracker, allocate_gpu_memory);
|
||||
|
||||
this->new_a = new CudaRadixCiphertextFFI;
|
||||
create_zero_radix_ciphertext_async<Torus>(
|
||||
streams.stream(0), streams.gpu_index(0), this->new_a, batch_blocks,
|
||||
params.big_lwe_dimension, size_tracker, allocate_gpu_memory);
|
||||
|
||||
this->new_b = new CudaRadixCiphertextFFI;
|
||||
create_zero_radix_ciphertext_async<Torus>(
|
||||
streams.stream(0), streams.gpu_index(0), this->new_b, batch_blocks,
|
||||
params.big_lwe_dimension, size_tracker, allocate_gpu_memory);
|
||||
|
||||
this->new_c = new CudaRadixCiphertextFFI;
|
||||
create_zero_radix_ciphertext_async<Torus>(
|
||||
streams.stream(0), streams.gpu_index(0), this->new_c, batch_blocks,
|
||||
params.big_lwe_dimension, size_tracker, allocate_gpu_memory);
|
||||
|
||||
this->packed_pbs_lhs = new CudaRadixCiphertextFFI;
|
||||
create_zero_radix_ciphertext_async<Torus>(
|
||||
streams.stream(0), streams.gpu_index(0), this->packed_pbs_lhs,
|
||||
3 * batch_blocks, params.big_lwe_dimension, size_tracker,
|
||||
allocate_gpu_memory);
|
||||
|
||||
this->packed_pbs_rhs = new CudaRadixCiphertextFFI;
|
||||
create_zero_radix_ciphertext_async<Torus>(
|
||||
streams.stream(0), streams.gpu_index(0), this->packed_pbs_rhs,
|
||||
3 * batch_blocks, params.big_lwe_dimension, size_tracker,
|
||||
allocate_gpu_memory);
|
||||
|
||||
this->packed_pbs_out = new CudaRadixCiphertextFFI;
|
||||
create_zero_radix_ciphertext_async<Torus>(
|
||||
streams.stream(0), streams.gpu_index(0), this->packed_pbs_out,
|
||||
3 * batch_blocks, params.big_lwe_dimension, size_tracker,
|
||||
allocate_gpu_memory);
|
||||
|
||||
this->packed_flush_in = new CudaRadixCiphertextFFI;
|
||||
create_zero_radix_ciphertext_async<Torus>(
|
||||
streams.stream(0), streams.gpu_index(0), this->packed_flush_in,
|
||||
4 * batch_blocks, params.big_lwe_dimension, size_tracker,
|
||||
allocate_gpu_memory);
|
||||
|
||||
this->packed_flush_out = new CudaRadixCiphertextFFI;
|
||||
create_zero_radix_ciphertext_async<Torus>(
|
||||
streams.stream(0), streams.gpu_index(0), this->packed_flush_out,
|
||||
4 * batch_blocks, params.big_lwe_dimension, size_tracker,
|
||||
allocate_gpu_memory);
|
||||
}
|
||||
|
||||
void release(CudaStreams streams, bool allocate_gpu_memory) {
|
||||
release_radix_ciphertext_async(streams.stream(0), streams.gpu_index(0),
|
||||
this->a_reg, allocate_gpu_memory);
|
||||
delete this->a_reg;
|
||||
|
||||
release_radix_ciphertext_async(streams.stream(0), streams.gpu_index(0),
|
||||
this->b_reg, allocate_gpu_memory);
|
||||
delete this->b_reg;
|
||||
|
||||
release_radix_ciphertext_async(streams.stream(0), streams.gpu_index(0),
|
||||
this->c_reg, allocate_gpu_memory);
|
||||
delete this->c_reg;
|
||||
|
||||
release_radix_ciphertext_async(streams.stream(0), streams.gpu_index(0),
|
||||
this->shift_workspace, allocate_gpu_memory);
|
||||
delete this->shift_workspace;
|
||||
|
||||
release_radix_ciphertext_async(streams.stream(0), streams.gpu_index(0),
|
||||
this->temp_t1, allocate_gpu_memory);
|
||||
delete this->temp_t1;
|
||||
|
||||
release_radix_ciphertext_async(streams.stream(0), streams.gpu_index(0),
|
||||
this->temp_t2, allocate_gpu_memory);
|
||||
delete this->temp_t2;
|
||||
|
||||
release_radix_ciphertext_async(streams.stream(0), streams.gpu_index(0),
|
||||
this->temp_t3, allocate_gpu_memory);
|
||||
delete this->temp_t3;
|
||||
|
||||
release_radix_ciphertext_async(streams.stream(0), streams.gpu_index(0),
|
||||
this->new_a, allocate_gpu_memory);
|
||||
delete this->new_a;
|
||||
|
||||
release_radix_ciphertext_async(streams.stream(0), streams.gpu_index(0),
|
||||
this->new_b, allocate_gpu_memory);
|
||||
delete this->new_b;
|
||||
|
||||
release_radix_ciphertext_async(streams.stream(0), streams.gpu_index(0),
|
||||
this->new_c, allocate_gpu_memory);
|
||||
delete this->new_c;
|
||||
|
||||
release_radix_ciphertext_async(streams.stream(0), streams.gpu_index(0),
|
||||
this->packed_pbs_lhs, allocate_gpu_memory);
|
||||
delete this->packed_pbs_lhs;
|
||||
|
||||
release_radix_ciphertext_async(streams.stream(0), streams.gpu_index(0),
|
||||
this->packed_pbs_rhs, allocate_gpu_memory);
|
||||
delete this->packed_pbs_rhs;
|
||||
|
||||
release_radix_ciphertext_async(streams.stream(0), streams.gpu_index(0),
|
||||
this->packed_pbs_out, allocate_gpu_memory);
|
||||
delete this->packed_pbs_out;
|
||||
|
||||
release_radix_ciphertext_async(streams.stream(0), streams.gpu_index(0),
|
||||
this->packed_flush_in, allocate_gpu_memory);
|
||||
delete this->packed_flush_in;
|
||||
|
||||
release_radix_ciphertext_async(streams.stream(0), streams.gpu_index(0),
|
||||
this->packed_flush_out, allocate_gpu_memory);
|
||||
delete this->packed_flush_out;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Torus> struct int_trivium_buffer {
|
||||
int_radix_params params;
|
||||
bool allocate_gpu_memory;
|
||||
uint32_t num_inputs;
|
||||
|
||||
int_trivium_lut_buffers<Torus> *luts;
|
||||
int_trivium_state_workspaces<Torus> *state;
|
||||
|
||||
int_trivium_buffer(CudaStreams streams, const int_radix_params ¶ms,
|
||||
bool allocate_gpu_memory, uint32_t num_inputs,
|
||||
uint64_t &size_tracker) {
|
||||
this->params = params;
|
||||
this->allocate_gpu_memory = allocate_gpu_memory;
|
||||
this->num_inputs = num_inputs;
|
||||
|
||||
this->luts = new int_trivium_lut_buffers<Torus>(
|
||||
streams, params, allocate_gpu_memory, num_inputs, size_tracker);
|
||||
|
||||
this->state = new int_trivium_state_workspaces<Torus>(
|
||||
streams, params, allocate_gpu_memory, num_inputs, size_tracker);
|
||||
}
|
||||
|
||||
void release(CudaStreams streams) {
|
||||
luts->release(streams);
|
||||
delete luts;
|
||||
luts = nullptr;
|
||||
|
||||
state->release(streams, allocate_gpu_memory);
|
||||
delete state;
|
||||
state = nullptr;
|
||||
|
||||
cuda_synchronize_stream(streams.stream(0), streams.gpu_index(0));
|
||||
}
|
||||
};
|
||||
|
||||
#endif
|
||||
@@ -174,40 +174,6 @@ template <typename Torus> struct zk_expand_mem {
|
||||
message_and_carry_extract_luts = new int_radix_lut<Torus>(
|
||||
streams, params, 4, 2 * num_lwes, allocate_gpu_memory, size_tracker);
|
||||
|
||||
generate_device_accumulator<Torus>(
|
||||
streams.stream(0), streams.gpu_index(0),
|
||||
message_and_carry_extract_luts->get_lut(0, 0),
|
||||
message_and_carry_extract_luts->get_degree(0),
|
||||
message_and_carry_extract_luts->get_max_degree(0),
|
||||
params.glwe_dimension, params.polynomial_size, params.message_modulus,
|
||||
params.carry_modulus, message_extract_lut_f, gpu_memory_allocated);
|
||||
|
||||
generate_device_accumulator<Torus>(
|
||||
streams.stream(0), streams.gpu_index(0),
|
||||
message_and_carry_extract_luts->get_lut(0, 1),
|
||||
message_and_carry_extract_luts->get_degree(1),
|
||||
message_and_carry_extract_luts->get_max_degree(1),
|
||||
params.glwe_dimension, params.polynomial_size, params.message_modulus,
|
||||
params.carry_modulus, carry_extract_lut_f, gpu_memory_allocated);
|
||||
|
||||
generate_device_accumulator<Torus>(
|
||||
streams.stream(0), streams.gpu_index(0),
|
||||
message_and_carry_extract_luts->get_lut(0, 2),
|
||||
message_and_carry_extract_luts->get_degree(2),
|
||||
message_and_carry_extract_luts->get_max_degree(2),
|
||||
params.glwe_dimension, params.polynomial_size, params.message_modulus,
|
||||
params.carry_modulus, message_extract_and_sanitize_bool_lut_f,
|
||||
gpu_memory_allocated);
|
||||
|
||||
generate_device_accumulator<Torus>(
|
||||
streams.stream(0), streams.gpu_index(0),
|
||||
message_and_carry_extract_luts->get_lut(0, 3),
|
||||
message_and_carry_extract_luts->get_degree(3),
|
||||
message_and_carry_extract_luts->get_max_degree(3),
|
||||
params.glwe_dimension, params.polynomial_size, params.message_modulus,
|
||||
params.carry_modulus, carry_extract_and_sanitize_bool_lut_f,
|
||||
gpu_memory_allocated);
|
||||
|
||||
// We are always packing two LWEs. We just need to be sure we have enough
|
||||
// space in the carry part to store a message of the same size as is in the
|
||||
// message part.
|
||||
@@ -292,7 +258,13 @@ template <typename Torus> struct zk_expand_mem {
|
||||
|
||||
auto active_streams =
|
||||
streams.active_gpu_subset(2 * num_lwes, params.pbs_type);
|
||||
message_and_carry_extract_luts->broadcast_lut(active_streams);
|
||||
|
||||
message_and_carry_extract_luts->generate_and_broadcast_lut(
|
||||
active_streams, {0, 1, 2, 3},
|
||||
{message_extract_lut_f, carry_extract_lut_f,
|
||||
message_extract_and_sanitize_bool_lut_f,
|
||||
carry_extract_and_sanitize_bool_lut_f},
|
||||
gpu_memory_allocated);
|
||||
|
||||
message_and_carry_extract_luts->allocate_lwe_vector_for_non_trivial_indexes(
|
||||
active_streams, 2 * num_lwes, size_tracker, allocate_gpu_memory);
|
||||
|
||||
@@ -1067,6 +1067,85 @@ void generate_device_accumulator_bivariate(
|
||||
POP_RANGE()
|
||||
}
|
||||
|
||||
template <typename Torus> struct int_lut_cache {
|
||||
int_lut_cache() {}
|
||||
|
||||
Torus *get_cached_univariate_lut(std::function<Torus(Torus)> &f, uint64_t *degree,
|
||||
uint64_t *max_degree, uint32_t glwe_dimension,
|
||||
uint32_t polynomial_size,
|
||||
uint32_t input_message_modulus,
|
||||
uint32_t input_carry_modulus,
|
||||
uint32_t output_message_modulus,
|
||||
uint32_t output_carry_modulus) {
|
||||
/*__int128_t f_hash = 0;
|
||||
uint32_t bits_per_lut_val = 5;
|
||||
uint32_t input_modulus_sup = input_message_modulus * input_carry_modulus;
|
||||
for (uint32_t i = 0; i < input_modulus_sup; ++i) {
|
||||
Torus f_eval = f(i);
|
||||
GPU_ASSERT(f_eval < (1 << bits_per_lut_val),
|
||||
"LUT value expected bitwidth overflow");
|
||||
f_hash |= f_eval;
|
||||
f_hash <<= bits_per_lut_val;
|
||||
}
|
||||
|
||||
std::lock_guard cache_lock(_mutex);
|
||||
if (_lut_cache.find(f_hash) != _lut_cache.end()) {
|
||||
lut_ptr &ptr = _lut_cache[f_hash];
|
||||
GPU_ASSERT(ptr.output_message_modulus == output_message_modulus,
|
||||
"Error modulus");
|
||||
GPU_ASSERT(ptr.input_message_modulus == input_message_modulus,
|
||||
"Error modulus");
|
||||
GPU_ASSERT(ptr.glwe_dimension == glwe_dimension, "Error modulus");
|
||||
*max_degree = ptr.max_degree;
|
||||
*degree = ptr.degree;
|
||||
return ptr.ptr;
|
||||
}*/
|
||||
|
||||
// host lut
|
||||
Torus *h_lut =
|
||||
(Torus *)malloc((glwe_dimension + 1) * polynomial_size * sizeof(Torus));
|
||||
|
||||
*max_degree = input_message_modulus * input_carry_modulus - 1;
|
||||
*degree = generate_lookup_table_with_encoding<Torus>(
|
||||
h_lut, glwe_dimension, polynomial_size, input_message_modulus,
|
||||
input_carry_modulus, output_message_modulus, output_carry_modulus, f);
|
||||
|
||||
/*lut_ptr new_ptr = {h_lut,
|
||||
glwe_dimension,
|
||||
input_message_modulus,
|
||||
input_carry_modulus,
|
||||
output_message_modulus,
|
||||
output_carry_modulus,
|
||||
*max_degree,
|
||||
*degree};*/
|
||||
//_lut_cache[f_hash] = new_ptr;
|
||||
return h_lut;
|
||||
}
|
||||
|
||||
~int_lut_cache() {
|
||||
std::lock_guard cache_lock(_mutex);
|
||||
for (auto v : _lut_cache) {
|
||||
free(v.second.ptr);
|
||||
}
|
||||
_lut_cache.clear();
|
||||
}
|
||||
|
||||
private:
|
||||
struct lut_ptr {
|
||||
Torus *ptr;
|
||||
uint32_t glwe_dimension;
|
||||
uint32_t input_message_modulus;
|
||||
uint32_t input_carry_modulus;
|
||||
uint32_t output_message_modulus;
|
||||
uint32_t output_carry_modulus;
|
||||
uint64_t max_degree;
|
||||
uint64_t degree;
|
||||
};
|
||||
std::map<__int128_t, lut_ptr> _lut_cache;
|
||||
std::mutex _mutex;
|
||||
};
|
||||
// Cache instance used for uint64_t lookup tables.
// NOTE(review): a namespace-scope `static` object has internal linkage. If
// this file is a header included from several translation units, each of them
// gets its own separate cache -- confirm whether one shared instance is
// intended.
static int_lut_cache<uint64_t> g_LutCache64;
|
||||
|
||||
/*
|
||||
* generate bivariate accumulator with factor scaling for device pointer
|
||||
* v_stream - cuda stream
|
||||
@@ -1098,8 +1177,8 @@ void generate_device_accumulator_bivariate_with_factor(
|
||||
(glwe_dimension + 1) * polynomial_size * sizeof(Torus), stream, gpu_index,
|
||||
gpu_memory_allocated);
|
||||
|
||||
cuda_synchronize_stream(stream, gpu_index);
|
||||
free(h_lut);
|
||||
// cuda_synchronize_stream(stream, gpu_index);
|
||||
// free(h_lut);
|
||||
}
|
||||
/*
|
||||
* generate bivariate accumulator for device pointer
|
||||
@@ -1145,23 +1224,36 @@ void generate_device_accumulator_with_encoding(
|
||||
uint32_t output_message_modulus, uint32_t output_carry_modulus,
|
||||
std::function<Torus(Torus)> f, bool gpu_memory_allocated) {
|
||||
|
||||
static constexpr auto is_u64 = std::is_same_v<Torus, uint64_t>;
|
||||
Torus *h_lut = nullptr;
|
||||
// host lut
|
||||
Torus *h_lut =
|
||||
(Torus *)malloc((glwe_dimension + 1) * polynomial_size * sizeof(Torus));
|
||||
|
||||
*max_degree = input_message_modulus * input_carry_modulus - 1;
|
||||
// fill accumulator
|
||||
*degree = generate_lookup_table_with_encoding<Torus>(
|
||||
h_lut, glwe_dimension, polynomial_size, input_message_modulus,
|
||||
input_carry_modulus, output_message_modulus, output_carry_modulus, f);
|
||||
if constexpr (is_u64) {
|
||||
h_lut = g_LutCache64.get_cached_univariate_lut(
|
||||
f, degree, max_degree, glwe_dimension, polynomial_size,
|
||||
input_message_modulus, input_carry_modulus, output_message_modulus,
|
||||
output_carry_modulus);
|
||||
} else {
|
||||
h_lut =
|
||||
(Torus *)malloc((glwe_dimension + 1) * polynomial_size * sizeof(Torus));
|
||||
|
||||
*max_degree = input_message_modulus * input_carry_modulus - 1;
|
||||
// fill accumulator
|
||||
*degree = generate_lookup_table_with_encoding<Torus>(
|
||||
h_lut, glwe_dimension, polynomial_size, input_message_modulus,
|
||||
input_carry_modulus, output_message_modulus, output_carry_modulus, f);
|
||||
}
|
||||
/*
|
||||
// copy host lut and lut_indexes_vec to device
|
||||
cuda_memcpy_with_size_tracking_async_to_gpu(
|
||||
acc, h_lut, (glwe_dimension + 1) * polynomial_size * sizeof(Torus),
|
||||
stream, gpu_index, gpu_memory_allocated);
|
||||
cuda_synchronize_stream(stream, gpu_index);
|
||||
free(h_lut);
|
||||
*/
|
||||
if (!std::is_same_v<Torus, uint64_t>) {
|
||||
cuda_synchronize_stream(stream, gpu_index);
|
||||
free(h_lut);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Torus>
|
||||
void generate_device_accumulator_with_encoding_with_cpu_prealloc(
|
||||
cudaStream_t stream, uint32_t gpu_index, Torus *acc, uint64_t *degree,
|
||||
@@ -1264,8 +1356,8 @@ void generate_many_lut_device_accumulator(
|
||||
acc, h_lut, (glwe_dimension + 1) * polynomial_size * sizeof(Torus),
|
||||
stream, gpu_index, gpu_memory_allocated);
|
||||
|
||||
cuda_synchronize_stream(stream, gpu_index);
|
||||
free(h_lut);
|
||||
//cuda_synchronize_stream(stream, gpu_index);
|
||||
//free(h_lut);
|
||||
POP_RANGE()
|
||||
}
|
||||
|
||||
|
||||
45
backends/tfhe-cuda-backend/cuda/src/trivium/trivium.cu
Normal file
45
backends/tfhe-cuda-backend/cuda/src/trivium/trivium.cu
Normal file
@@ -0,0 +1,45 @@
|
||||
#include "../../include/trivium/trivium.h"
|
||||
#include "trivium.cuh"
|
||||
|
||||
// C entry point that builds the Trivium scratch buffer for 64-bit torus
// ciphertexts.
//
// Packs the scalar parameters into an `int_radix_params` and delegates to
// `scratch_cuda_trivium_encrypt<uint64_t>`, whose return value is passed
// straight back to the caller. `mem_ptr` is reinterpreted as the address of
// an `int_trivium_buffer<uint64_t> *` for the callee.
uint64_t scratch_cuda_trivium_64(
    CudaStreamsFFI streams, int8_t **mem_ptr, uint32_t glwe_dimension,
    uint32_t polynomial_size, uint32_t lwe_dimension, uint32_t ks_level,
    uint32_t ks_base_log, uint32_t pbs_level, uint32_t pbs_base_log,
    uint32_t grouping_factor, uint32_t message_modulus, uint32_t carry_modulus,
    PBS_TYPE pbs_type, bool allocate_gpu_memory,
    PBS_MS_REDUCTION_T noise_reduction_type, uint32_t num_inputs) {

  // Product handed to int_radix_params right after polynomial_size.
  const uint32_t glwe_times_poly_size = glwe_dimension * polynomial_size;

  int_radix_params params(pbs_type, glwe_dimension, polynomial_size,
                          glwe_times_poly_size, lwe_dimension, ks_level,
                          ks_base_log, pbs_level, pbs_base_log, grouping_factor,
                          message_modulus, carry_modulus, noise_reduction_type);

  auto **trivium_mem =
      reinterpret_cast<int_trivium_buffer<uint64_t> **>(mem_ptr);

  return scratch_cuda_trivium_encrypt<uint64_t>(
      CudaStreams(streams), trivium_mem, params, allocate_gpu_memory,
      num_inputs);
}
|
||||
|
||||
// C entry point for Trivium keystream generation on 64-bit torus ciphertexts.
//
// Recovers the typed scratch buffer from the opaque `mem_ptr`, reinterprets
// the keyswitching key array as `uint64_t *const *`, and forwards everything
// to `host_trivium_generate_keystream<uint64_t>`.
void cuda_trivium_generate_keystream_64(
    CudaStreamsFFI streams, CudaRadixCiphertextFFI *keystream_output,
    const CudaRadixCiphertextFFI *key, const CudaRadixCiphertextFFI *iv,
    uint32_t num_inputs, uint32_t num_steps, int8_t *mem_ptr, void *const *bsks,
    void *const *ksks) {

  auto *trivium_mem = reinterpret_cast<int_trivium_buffer<uint64_t> *>(mem_ptr);
  auto typed_ksks = reinterpret_cast<uint64_t *const *>(ksks);

  host_trivium_generate_keystream<uint64_t>(
      CudaStreams(streams), keystream_output, key, iv, num_inputs, num_steps,
      trivium_mem, bsks, typed_ksks);
}
|
||||
|
||||
// C entry point that tears down the Trivium scratch buffer.
//
// Interprets `*mem_ptr_void` as an `int_trivium_buffer<uint64_t> *`, releases
// its resources, deletes the object, and resets the caller's pointer to
// nullptr.
void cleanup_cuda_trivium_64(CudaStreamsFFI streams, int8_t **mem_ptr_void) {
  auto *trivium_mem =
      reinterpret_cast<int_trivium_buffer<uint64_t> *>(*mem_ptr_void);

  trivium_mem->release(CudaStreams(streams));
  delete trivium_mem;

  *mem_ptr_void = nullptr;
}
|
||||
341
backends/tfhe-cuda-backend/cuda/src/trivium/trivium.cuh
Normal file
341
backends/tfhe-cuda-backend/cuda/src/trivium/trivium.cuh
Normal file
@@ -0,0 +1,341 @@
|
||||
#ifndef TRIVIUM_CUH
|
||||
#define TRIVIUM_CUH
|
||||
|
||||
#include "../../include/trivium/trivium_utilities.h"
|
||||
#include "../integer/integer.cuh"
|
||||
#include "../integer/radix_ciphertext.cuh"
|
||||
#include "../integer/scalar_addition.cuh"
|
||||
#include "../linearalgebra/addition.cuh"
|
||||
|
||||
// Reverses the order of bits (blocks) in a ciphertext buffer.
|
||||
// Used to align the input Key/IV with the internal state format if needed.
|
||||
template <typename Torus>
|
||||
void reverse_bitsliced_radix_inplace(CudaStreams streams,
|
||||
int_trivium_buffer<Torus> *mem,
|
||||
CudaRadixCiphertextFFI *radix,
|
||||
uint32_t num_bits_in_reg) {
|
||||
uint32_t N = mem->num_inputs;
|
||||
CudaRadixCiphertextFFI *temp = mem->state->shift_workspace;
|
||||
|
||||
for (uint32_t i = 0; i < num_bits_in_reg; i++) {
|
||||
uint32_t src_start = i * N;
|
||||
uint32_t src_end = (i + 1) * N;
|
||||
|
||||
uint32_t dest_start = (num_bits_in_reg - 1 - i) * N;
|
||||
uint32_t dest_end = (num_bits_in_reg - i) * N;
|
||||
|
||||
copy_radix_ciphertext_slice_async<Torus>(
|
||||
streams.stream(0), streams.gpu_index(0), temp, dest_start, dest_end,
|
||||
radix, src_start, src_end);
|
||||
}
|
||||
|
||||
copy_radix_ciphertext_slice_async<Torus>(
|
||||
streams.stream(0), streams.gpu_index(0), radix, 0, num_bits_in_reg * N,
|
||||
temp, 0, num_bits_in_reg * N);
|
||||
}
|
||||
|
||||
// Creates a slice of specific bits in a register without copying data.
|
||||
template <typename Torus>
|
||||
__host__ void slice_reg_batch(CudaRadixCiphertextFFI *slice,
|
||||
const CudaRadixCiphertextFFI *reg,
|
||||
uint32_t start_bit_idx, uint32_t num_bits,
|
||||
uint32_t num_inputs) {
|
||||
as_radix_ciphertext_slice<Torus>(slice, reg, start_bit_idx * num_inputs,
|
||||
(start_bit_idx + num_bits) * num_inputs);
|
||||
}
|
||||
|
||||
// Handles the shift-register update: discards old bits, shifts the rest,
|
||||
// and inserts the newly computed bits at the beginning.
|
||||
template <typename Torus>
|
||||
__host__ void shift_and_insert_batch(CudaStreams streams,
|
||||
int_trivium_buffer<Torus> *mem,
|
||||
CudaRadixCiphertextFFI *reg,
|
||||
CudaRadixCiphertextFFI *new_bits,
|
||||
uint32_t reg_size, uint32_t num_inputs) {
|
||||
|
||||
constexpr uint32_t BATCH = 64;
|
||||
CudaRadixCiphertextFFI *temp = mem->state->shift_workspace;
|
||||
|
||||
uint32_t num_blocks_to_keep = (reg_size - BATCH) * num_inputs;
|
||||
|
||||
copy_radix_ciphertext_slice_async<Torus>(
|
||||
streams.stream(0), streams.gpu_index(0), temp, 0, BATCH * num_inputs,
|
||||
new_bits, 0, BATCH * num_inputs);
|
||||
|
||||
copy_radix_ciphertext_slice_async<Torus>(
|
||||
streams.stream(0), streams.gpu_index(0), temp, BATCH * num_inputs,
|
||||
reg_size * num_inputs, reg, 0, num_blocks_to_keep);
|
||||
|
||||
copy_radix_ciphertext_slice_async<Torus>(
|
||||
streams.stream(0), streams.gpu_index(0), reg, 0, reg_size * num_inputs,
|
||||
temp, 0, reg_size * num_inputs);
|
||||
}
|
||||
|
||||
// core logic: computes 64 parallel updates for the state registers.
|
||||
// It performs the XORs (additions) and the AND gates (using Bivariate PBS),
|
||||
// then updates the registers and writes to output if needed.
|
||||
template <typename Torus>
|
||||
__host__ void
|
||||
trivium_compute_64_steps(CudaStreams streams, int_trivium_buffer<Torus> *mem,
|
||||
CudaRadixCiphertextFFI *output_dest, void *const *bsks,
|
||||
uint64_t *const *ksks) {
|
||||
|
||||
uint32_t N = mem->num_inputs;
|
||||
constexpr uint32_t BATCH = 64;
|
||||
uint32_t batch_size_blocks = BATCH * N;
|
||||
auto s = mem->state;
|
||||
|
||||
CudaRadixCiphertextFFI a65_slice, a92_slice, a91_slice, a90_slice, a68_slice;
|
||||
slice_reg_batch<Torus>(&a65_slice, s->a_reg, 2, BATCH, N);
|
||||
slice_reg_batch<Torus>(&a92_slice, s->a_reg, 29, BATCH, N);
|
||||
slice_reg_batch<Torus>(&a91_slice, s->a_reg, 28, BATCH, N);
|
||||
slice_reg_batch<Torus>(&a90_slice, s->a_reg, 27, BATCH, N);
|
||||
slice_reg_batch<Torus>(&a68_slice, s->a_reg, 5, BATCH, N);
|
||||
|
||||
CudaRadixCiphertextFFI b68_slice, b83_slice, b82_slice, b81_slice, b77_slice;
|
||||
slice_reg_batch<Torus>(&b68_slice, s->b_reg, 5, BATCH, N);
|
||||
slice_reg_batch<Torus>(&b83_slice, s->b_reg, 20, BATCH, N);
|
||||
slice_reg_batch<Torus>(&b82_slice, s->b_reg, 19, BATCH, N);
|
||||
slice_reg_batch<Torus>(&b81_slice, s->b_reg, 18, BATCH, N);
|
||||
slice_reg_batch<Torus>(&b77_slice, s->b_reg, 14, BATCH, N);
|
||||
|
||||
CudaRadixCiphertextFFI c65_slice, c110_slice, c109_slice, c108_slice,
|
||||
c86_slice;
|
||||
slice_reg_batch<Torus>(&c65_slice, s->c_reg, 2, BATCH, N);
|
||||
slice_reg_batch<Torus>(&c110_slice, s->c_reg, 47, BATCH, N);
|
||||
slice_reg_batch<Torus>(&c109_slice, s->c_reg, 46, BATCH, N);
|
||||
slice_reg_batch<Torus>(&c108_slice, s->c_reg, 45, BATCH, N);
|
||||
slice_reg_batch<Torus>(&c86_slice, s->c_reg, 23, BATCH, N);
|
||||
|
||||
// t1 = a66 + a93
|
||||
host_addition<Torus>(streams.stream(0), streams.gpu_index(0), s->temp_t1,
|
||||
&a65_slice, &a92_slice, s->temp_t1->num_radix_blocks,
|
||||
mem->params.message_modulus, mem->params.carry_modulus);
|
||||
|
||||
// t2 = b69 + b84
|
||||
host_addition<Torus>(streams.stream(0), streams.gpu_index(0), s->temp_t2,
|
||||
&b68_slice, &b83_slice, s->temp_t2->num_radix_blocks,
|
||||
mem->params.message_modulus, mem->params.carry_modulus);
|
||||
|
||||
// t3 = c66 + c111
|
||||
host_addition<Torus>(streams.stream(0), streams.gpu_index(0), s->temp_t3,
|
||||
&c65_slice, &c110_slice, s->temp_t3->num_radix_blocks,
|
||||
mem->params.message_modulus, mem->params.carry_modulus);
|
||||
|
||||
copy_radix_ciphertext_slice_async<Torus>(
|
||||
streams.stream(0), streams.gpu_index(0), s->packed_pbs_lhs, 0,
|
||||
batch_size_blocks, &c109_slice, 0, batch_size_blocks);
|
||||
|
||||
copy_radix_ciphertext_slice_async<Torus>(
|
||||
streams.stream(0), streams.gpu_index(0), s->packed_pbs_lhs,
|
||||
batch_size_blocks, 2 * batch_size_blocks, &a91_slice, 0,
|
||||
batch_size_blocks);
|
||||
|
||||
copy_radix_ciphertext_slice_async<Torus>(
|
||||
streams.stream(0), streams.gpu_index(0), s->packed_pbs_lhs,
|
||||
2 * batch_size_blocks, 3 * batch_size_blocks, &b82_slice, 0,
|
||||
batch_size_blocks);
|
||||
|
||||
copy_radix_ciphertext_slice_async<Torus>(
|
||||
streams.stream(0), streams.gpu_index(0), s->packed_pbs_rhs, 0,
|
||||
batch_size_blocks, &c108_slice, 0, batch_size_blocks);
|
||||
|
||||
copy_radix_ciphertext_slice_async<Torus>(
|
||||
streams.stream(0), streams.gpu_index(0), s->packed_pbs_rhs,
|
||||
batch_size_blocks, 2 * batch_size_blocks, &a90_slice, 0,
|
||||
batch_size_blocks);
|
||||
|
||||
copy_radix_ciphertext_slice_async<Torus>(
|
||||
streams.stream(0), streams.gpu_index(0), s->packed_pbs_rhs,
|
||||
2 * batch_size_blocks, 3 * batch_size_blocks, &b81_slice, 0,
|
||||
batch_size_blocks);
|
||||
|
||||
integer_radix_apply_bivariate_lookup_table<Torus>(
|
||||
streams, s->packed_pbs_out, s->packed_pbs_lhs, s->packed_pbs_rhs, bsks,
|
||||
ksks, mem->luts->and_lut, 3 * batch_size_blocks,
|
||||
mem->params.message_modulus);
|
||||
|
||||
CudaRadixCiphertextFFI and_res_a, and_res_b, and_res_c;
|
||||
as_radix_ciphertext_slice<Torus>(&and_res_a, s->packed_pbs_out, 0,
|
||||
batch_size_blocks);
|
||||
as_radix_ciphertext_slice<Torus>(&and_res_b, s->packed_pbs_out,
|
||||
batch_size_blocks, 2 * batch_size_blocks);
|
||||
as_radix_ciphertext_slice<Torus>(&and_res_c, s->packed_pbs_out,
|
||||
2 * batch_size_blocks,
|
||||
3 * batch_size_blocks);
|
||||
|
||||
// a = t3 + a69 + and_res_a
|
||||
host_addition<Torus>(streams.stream(0), streams.gpu_index(0), s->new_a,
|
||||
s->temp_t3, &a68_slice, s->new_a->num_radix_blocks,
|
||||
mem->params.message_modulus, mem->params.carry_modulus);
|
||||
host_addition<Torus>(streams.stream(0), streams.gpu_index(0), s->new_a,
|
||||
s->new_a, &and_res_a, s->new_a->num_radix_blocks,
|
||||
mem->params.message_modulus, mem->params.carry_modulus);
|
||||
|
||||
// b = t1 + b78 + and_res_b
|
||||
host_addition<Torus>(streams.stream(0), streams.gpu_index(0), s->new_b,
|
||||
s->temp_t1, &b77_slice, s->new_b->num_radix_blocks,
|
||||
mem->params.message_modulus, mem->params.carry_modulus);
|
||||
host_addition<Torus>(streams.stream(0), streams.gpu_index(0), s->new_b,
|
||||
s->new_b, &and_res_b, s->new_b->num_radix_blocks,
|
||||
mem->params.message_modulus, mem->params.carry_modulus);
|
||||
|
||||
// c = t2 + c87 + and_res_c
|
||||
host_addition<Torus>(streams.stream(0), streams.gpu_index(0), s->new_c,
|
||||
s->temp_t2, &c86_slice, s->new_c->num_radix_blocks,
|
||||
mem->params.message_modulus, mem->params.carry_modulus);
|
||||
host_addition<Torus>(streams.stream(0), streams.gpu_index(0), s->new_c,
|
||||
s->new_c, &and_res_c, s->new_c->num_radix_blocks,
|
||||
mem->params.message_modulus, mem->params.carry_modulus);
|
||||
|
||||
if (output_dest != nullptr) {
|
||||
// z = t1 + t2 + t3
|
||||
host_addition<Torus>(streams.stream(0), streams.gpu_index(0), output_dest,
|
||||
s->temp_t1, s->temp_t2, output_dest->num_radix_blocks,
|
||||
mem->params.message_modulus,
|
||||
mem->params.carry_modulus);
|
||||
host_addition<Torus>(streams.stream(0), streams.gpu_index(0), output_dest,
|
||||
output_dest, s->temp_t3, output_dest->num_radix_blocks,
|
||||
mem->params.message_modulus,
|
||||
mem->params.carry_modulus);
|
||||
}
|
||||
|
||||
copy_radix_ciphertext_slice_async<Torus>(
|
||||
streams.stream(0), streams.gpu_index(0), s->packed_flush_in, 0,
|
||||
batch_size_blocks, s->new_a, 0, batch_size_blocks);
|
||||
|
||||
copy_radix_ciphertext_slice_async<Torus>(
|
||||
streams.stream(0), streams.gpu_index(0), s->packed_flush_in,
|
||||
batch_size_blocks, 2 * batch_size_blocks, s->new_b, 0, batch_size_blocks);
|
||||
|
||||
copy_radix_ciphertext_slice_async<Torus>(
|
||||
streams.stream(0), streams.gpu_index(0), s->packed_flush_in,
|
||||
2 * batch_size_blocks, 3 * batch_size_blocks, s->new_c, 0,
|
||||
batch_size_blocks);
|
||||
|
||||
uint32_t total_flush_blocks = 3 * batch_size_blocks;
|
||||
|
||||
if (output_dest != nullptr) {
|
||||
copy_radix_ciphertext_slice_async<Torus>(
|
||||
streams.stream(0), streams.gpu_index(0), s->packed_flush_in,
|
||||
3 * batch_size_blocks, 4 * batch_size_blocks, output_dest, 0,
|
||||
batch_size_blocks);
|
||||
total_flush_blocks += batch_size_blocks;
|
||||
}
|
||||
|
||||
integer_radix_apply_univariate_lookup_table<Torus>(
|
||||
streams, s->packed_flush_out, s->packed_flush_in, bsks, ksks,
|
||||
mem->luts->flush_lut, total_flush_blocks);
|
||||
|
||||
CudaRadixCiphertextFFI flushed_a, flushed_b, flushed_c;
|
||||
as_radix_ciphertext_slice<Torus>(&flushed_a, s->packed_flush_out, 0,
|
||||
batch_size_blocks);
|
||||
as_radix_ciphertext_slice<Torus>(&flushed_b, s->packed_flush_out,
|
||||
batch_size_blocks, 2 * batch_size_blocks);
|
||||
as_radix_ciphertext_slice<Torus>(&flushed_c, s->packed_flush_out,
|
||||
2 * batch_size_blocks,
|
||||
3 * batch_size_blocks);
|
||||
|
||||
shift_and_insert_batch(streams, mem, s->a_reg, &flushed_a, 93, N);
|
||||
shift_and_insert_batch(streams, mem, s->b_reg, &flushed_b, 84, N);
|
||||
shift_and_insert_batch(streams, mem, s->c_reg, &flushed_c, 111, N);
|
||||
|
||||
if (output_dest != nullptr) {
|
||||
CudaRadixCiphertextFFI flushed_out;
|
||||
as_radix_ciphertext_slice<Torus>(&flushed_out, s->packed_flush_out,
|
||||
3 * batch_size_blocks,
|
||||
4 * batch_size_blocks);
|
||||
|
||||
copy_radix_ciphertext_slice_async<Torus>(
|
||||
streams.stream(0), streams.gpu_index(0), output_dest, 0,
|
||||
batch_size_blocks, &flushed_out, 0, batch_size_blocks);
|
||||
|
||||
reverse_bitsliced_radix_inplace<Torus>(streams, mem, output_dest, 64);
|
||||
}
|
||||
}
|
||||
|
||||
// Sets up the initial state: loads Key and IV, fixes constants,
|
||||
// and runs the warm-up phase (1152 steps).
|
||||
template <typename Torus>
|
||||
__host__ void trivium_init(CudaStreams streams, int_trivium_buffer<Torus> *mem,
|
||||
CudaRadixCiphertextFFI const *key_bitsliced,
|
||||
CudaRadixCiphertextFFI const *iv_bitsliced,
|
||||
void *const *bsks, uint64_t *const *ksks) {
|
||||
uint32_t N = mem->num_inputs;
|
||||
auto s = mem->state;
|
||||
|
||||
CudaRadixCiphertextFFI src_key_slice;
|
||||
slice_reg_batch<Torus>(&src_key_slice, key_bitsliced, 0, 80, N);
|
||||
|
||||
CudaRadixCiphertextFFI dest_a_slice;
|
||||
slice_reg_batch<Torus>(&dest_a_slice, s->a_reg, 0, 80, N);
|
||||
|
||||
copy_radix_ciphertext_async<Torus>(streams.stream(0), streams.gpu_index(0),
|
||||
&dest_a_slice, &src_key_slice);
|
||||
|
||||
reverse_bitsliced_radix_inplace<Torus>(streams, mem, s->a_reg, 80);
|
||||
|
||||
CudaRadixCiphertextFFI src_iv_slice;
|
||||
slice_reg_batch<Torus>(&src_iv_slice, iv_bitsliced, 0, 80, N);
|
||||
|
||||
CudaRadixCiphertextFFI dest_b_slice;
|
||||
slice_reg_batch<Torus>(&dest_b_slice, s->b_reg, 0, 80, N);
|
||||
|
||||
copy_radix_ciphertext_async<Torus>(streams.stream(0), streams.gpu_index(0),
|
||||
&dest_b_slice, &src_iv_slice);
|
||||
|
||||
reverse_bitsliced_radix_inplace<Torus>(streams, mem, s->b_reg, 80);
|
||||
|
||||
CudaRadixCiphertextFFI dest_c_ones;
|
||||
slice_reg_batch<Torus>(&dest_c_ones, s->c_reg, 108, 3, N);
|
||||
|
||||
host_add_scalar_one_inplace<Torus>(streams, &dest_c_ones,
|
||||
mem->params.message_modulus,
|
||||
mem->params.carry_modulus);
|
||||
integer_radix_apply_univariate_lookup_table<Torus>(
|
||||
streams, &dest_c_ones, &dest_c_ones, bsks, ksks, mem->luts->flush_lut,
|
||||
dest_c_ones.num_radix_blocks);
|
||||
|
||||
for (int i = 0; i < 18; i++) {
|
||||
trivium_compute_64_steps(streams, mem, nullptr, bsks, ksks);
|
||||
}
|
||||
}
|
||||
|
||||
// Main entry point: checks input validity, initializes state,
|
||||
// and loops to generate the keystream in batches of 64.
|
||||
template <typename Torus>
|
||||
__host__ void host_trivium_generate_keystream(
|
||||
CudaStreams streams, CudaRadixCiphertextFFI *keystream_output,
|
||||
CudaRadixCiphertextFFI const *key_bitsliced,
|
||||
CudaRadixCiphertextFFI const *iv_bitsliced, uint32_t num_inputs,
|
||||
uint32_t num_steps, int_trivium_buffer<Torus> *mem, void *const *bsks,
|
||||
uint64_t *const *ksks) {
|
||||
|
||||
PANIC_IF_FALSE(num_steps % 64 == 0,
|
||||
"Trivium Error: num_steps must be a multiple of 64.\n");
|
||||
|
||||
trivium_init(streams, mem, key_bitsliced, iv_bitsliced, bsks, ksks);
|
||||
|
||||
uint32_t num_batches = num_steps / 64;
|
||||
for (uint32_t i = 0; i < num_batches; i++) {
|
||||
CudaRadixCiphertextFFI batch_out_slice;
|
||||
slice_reg_batch<Torus>(&batch_out_slice, keystream_output, i * 64, 64,
|
||||
num_inputs);
|
||||
trivium_compute_64_steps(streams, mem, &batch_out_slice, bsks, ksks);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Torus>
|
||||
uint64_t scratch_cuda_trivium_encrypt(CudaStreams streams,
|
||||
int_trivium_buffer<Torus> **mem_ptr,
|
||||
int_radix_params params,
|
||||
bool allocate_gpu_memory,
|
||||
uint32_t num_inputs) {
|
||||
|
||||
uint64_t size_tracker = 0;
|
||||
*mem_ptr = new int_trivium_buffer<Torus>(streams, params, allocate_gpu_memory,
|
||||
num_inputs, size_tracker);
|
||||
return size_tracker;
|
||||
}
|
||||
|
||||
#endif
|
||||
@@ -2511,6 +2511,42 @@ unsafe extern "C" {
|
||||
mem_ptr_void: *mut *mut i8,
|
||||
);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub fn scratch_cuda_trivium_64(
|
||||
streams: CudaStreamsFFI,
|
||||
mem_ptr: *mut *mut i8,
|
||||
glwe_dimension: u32,
|
||||
polynomial_size: u32,
|
||||
lwe_dimension: u32,
|
||||
ks_level: u32,
|
||||
ks_base_log: u32,
|
||||
pbs_level: u32,
|
||||
pbs_base_log: u32,
|
||||
grouping_factor: u32,
|
||||
message_modulus: u32,
|
||||
carry_modulus: u32,
|
||||
pbs_type: PBS_TYPE,
|
||||
allocate_gpu_memory: bool,
|
||||
noise_reduction_type: PBS_MS_REDUCTION_T,
|
||||
num_inputs: u32,
|
||||
) -> u64;
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub fn cuda_trivium_generate_keystream_64(
|
||||
streams: CudaStreamsFFI,
|
||||
keystream_output: *mut CudaRadixCiphertextFFI,
|
||||
key: *const CudaRadixCiphertextFFI,
|
||||
iv: *const CudaRadixCiphertextFFI,
|
||||
num_inputs: u32,
|
||||
num_steps: u32,
|
||||
mem_ptr: *mut i8,
|
||||
bsks: *const *mut ffi::c_void,
|
||||
ksks: *const *mut ffi::c_void,
|
||||
);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub fn cleanup_cuda_trivium_64(streams: CudaStreamsFFI, mem_ptr_void: *mut *mut i8);
|
||||
}
|
||||
pub const KS_TYPE_BIG_TO_SMALL: KS_TYPE = 0;
|
||||
pub const KS_TYPE_SMALL_TO_BIG: KS_TYPE = 1;
|
||||
pub type KS_TYPE = ffi::c_uint;
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
#include "cuda/include/integer/integer.h"
|
||||
#include "cuda/include/integer/rerand.h"
|
||||
#include "cuda/include/aes/aes.h"
|
||||
#include "cuda/include/trivium/trivium.h"
|
||||
#include "cuda/include/zk/zk.h"
|
||||
#include "cuda/include/keyswitch/keyswitch.h"
|
||||
#include "cuda/include/keyswitch/ks_enums.h"
|
||||
|
||||
@@ -133,6 +133,12 @@ path = "benches/integer/aes.rs"
|
||||
harness = false
|
||||
required-features = ["integer", "internal-keycache"]
|
||||
|
||||
[[bench]]
|
||||
name = "integer-trivium"
|
||||
path = "benches/integer/trivium.rs"
|
||||
harness = false
|
||||
required-features = ["integer", "internal-keycache"]
|
||||
|
||||
[[bench]]
|
||||
name = "integer-aes256"
|
||||
path = "benches/integer/aes256.rs"
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
mod aes;
|
||||
mod aes256;
|
||||
mod oprf;
|
||||
mod trivium;
|
||||
mod vector_find;
|
||||
|
||||
mod rerand;
|
||||
|
||||
89
tfhe-benchmark/benches/integer/trivium.rs
Normal file
89
tfhe-benchmark/benches/integer/trivium.rs
Normal file
@@ -0,0 +1,89 @@
|
||||
use criterion::Criterion;
|
||||
|
||||
#[cfg(feature = "gpu")]
|
||||
pub mod cuda {
|
||||
use benchmark::params_aliases::BENCH_PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128;
|
||||
use benchmark::utilities::{write_to_json, OperatorType};
|
||||
use criterion::{black_box, criterion_group, Criterion};
|
||||
use tfhe::core_crypto::gpu::CudaStreams;
|
||||
use tfhe::integer::gpu::ciphertext::CudaUnsignedRadixCiphertext;
|
||||
use tfhe::integer::gpu::CudaServerKey;
|
||||
use tfhe::integer::keycache::KEY_CACHE;
|
||||
use tfhe::integer::{IntegerKeyKind, RadixCiphertext, RadixClientKey};
|
||||
use tfhe::keycache::NamedParam;
|
||||
use tfhe::shortint::AtomicPatternParameters;
|
||||
|
||||
fn encrypt_bit_stream(cks: &RadixClientKey, bits: &[u64]) -> RadixCiphertext {
|
||||
RadixCiphertext::from(
|
||||
bits.iter()
|
||||
.map(|&bit| cks.encrypt_one_block(bit))
|
||||
.collect::<Vec<_>>(),
|
||||
)
|
||||
}
|
||||
|
||||
pub fn cuda_trivium(c: &mut Criterion) {
|
||||
let bench_name = "integer::cuda::trivium";
|
||||
|
||||
let mut bench_group = c.benchmark_group(bench_name);
|
||||
bench_group
|
||||
.sample_size(15)
|
||||
.measurement_time(std::time::Duration::from_secs(60))
|
||||
.warm_up_time(std::time::Duration::from_secs(5));
|
||||
|
||||
let param = BENCH_PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128;
|
||||
let atomic_param: AtomicPatternParameters = param.into();
|
||||
|
||||
let key_bits = vec![0u64; 80];
|
||||
let iv_bits = vec![0u64; 80];
|
||||
|
||||
let param_name = param.name();
|
||||
|
||||
let streams = CudaStreams::new_multi_gpu();
|
||||
let (cpu_cks, _) = KEY_CACHE.get_from_params(atomic_param, IntegerKeyKind::Radix);
|
||||
let sks = CudaServerKey::new(&cpu_cks, &streams);
|
||||
let cks = RadixClientKey::from((cpu_cks, 1));
|
||||
|
||||
let ct_key = encrypt_bit_stream(&cks, &key_bits);
|
||||
let ct_iv = encrypt_bit_stream(&cks, &iv_bits);
|
||||
|
||||
let d_key = CudaUnsignedRadixCiphertext::from_radix_ciphertext(&ct_key, &streams);
|
||||
let d_iv = CudaUnsignedRadixCiphertext::from_radix_ciphertext(&ct_iv, &streams);
|
||||
|
||||
for num_steps in [64, 512] {
|
||||
let bench_id = format!("{bench_name}::{param_name}::generate_{num_steps}_bits");
|
||||
|
||||
bench_group.bench_function(&bench_id, |b| {
|
||||
b.iter(|| {
|
||||
black_box(
|
||||
sks.trivium_generate_keystream(&d_key, &d_iv, num_steps, &streams)
|
||||
.unwrap(),
|
||||
);
|
||||
})
|
||||
});
|
||||
|
||||
write_to_json::<u64, _>(
|
||||
&bench_id,
|
||||
atomic_param,
|
||||
param.name(),
|
||||
&format!("trivium_generation_{}_bits", num_steps),
|
||||
&OperatorType::Atomic,
|
||||
80,
|
||||
vec![atomic_param.message_modulus().0.ilog2(); 80],
|
||||
);
|
||||
}
|
||||
|
||||
bench_group.finish();
|
||||
}
|
||||
|
||||
criterion_group!(gpu_trivium, cuda_trivium);
|
||||
}
|
||||
|
||||
#[cfg(feature = "gpu")]
|
||||
use cuda::gpu_trivium;
|
||||
|
||||
fn main() {
|
||||
#[cfg(feature = "gpu")]
|
||||
gpu_trivium();
|
||||
|
||||
Criterion::default().configure_from_args().final_summary();
|
||||
}
|
||||
1
tfhe/docs/.gitbook/assets/api-levels.svg
Normal file
1
tfhe/docs/.gitbook/assets/api-levels.svg
Normal file
File diff suppressed because one or more lines are too long
|
After Width: | Height: | Size: 181 KiB |
@@ -19,11 +19,13 @@ The overall process to write an homomorphic program is the same for all types. T
|
||||
|
||||
This library has different modules, with different levels of abstraction.
|
||||
|
||||
There is the **core\_crypto** module, which is the lowest level API with the primitive functions and types of the TFHE scheme.
|
||||
There is the [core\_crypto](../core-crypto-api/presentation.md) module, which is the lowest level API with the primitive functions and types of the TFHE scheme.
|
||||
|
||||
Above the core\_crypto module, there are the **Boolean**, **shortint**, and **integer** modules, which contain easy to use APIs enabling evaluation of Boolean, short integer, and integer circuits.
|
||||
Above the core\_crypto module, there are the [Boolean](boolean/README.md), [shortint](shortint/README.md), and [integer](integer/README.md) modules, which contain easy to use APIs enabling evaluation of Boolean, short integer, and integer circuits.
|
||||
|
||||
Finally, there is the high-level module built on top of the Boolean, shortint, integer modules. This module is meant to abstract cryptographic complexities: no cryptographical knowledge is required to start developing an FHE application. Another benefit of the high-level module is the drastically simplified development process compared to lower level modules.
|
||||
Finally, there is the high-level module built on top of the shortint and integer modules. This module is meant to abstract cryptographic complexities: no cryptographical knowledge is required to start developing an FHE application. Another benefit of the high-level module is the drastically simplified development process compared to lower level modules.
|
||||
|
||||

|
||||
|
||||
#### high-level API
|
||||
|
||||
|
||||
@@ -1084,18 +1084,10 @@ pub fn multi_bit_programmable_bootstrap_lwe_ciphertext<
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
input.ciphertext_modulus(),
|
||||
output.ciphertext_modulus(),
|
||||
"Mismatched CiphertextModulus between input ({:?}) and output ({:?})",
|
||||
input.ciphertext_modulus(),
|
||||
output.ciphertext_modulus(),
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
input.ciphertext_modulus(),
|
||||
accumulator.ciphertext_modulus(),
|
||||
"Mismatched CiphertextModulus between input ({:?}) and accumulator ({:?})",
|
||||
input.ciphertext_modulus(),
|
||||
"Mismatched CiphertextModulus between output ({:?}) and accumulator ({:?})",
|
||||
output.ciphertext_modulus(),
|
||||
accumulator.ciphertext_modulus(),
|
||||
);
|
||||
|
||||
@@ -1795,18 +1787,10 @@ pub fn std_multi_bit_programmable_bootstrap_lwe_ciphertext<
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
input.ciphertext_modulus(),
|
||||
output.ciphertext_modulus(),
|
||||
"Mismatched CiphertextModulus between input ({:?}) and output ({:?})",
|
||||
input.ciphertext_modulus(),
|
||||
output.ciphertext_modulus(),
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
input.ciphertext_modulus(),
|
||||
accumulator.ciphertext_modulus(),
|
||||
"Mismatched CiphertextModulus between input ({:?}) and accumulator ({:?})",
|
||||
input.ciphertext_modulus(),
|
||||
"Mismatched CiphertextModulus between output ({:?}) and accumulator ({:?})",
|
||||
output.ciphertext_modulus(),
|
||||
accumulator.ciphertext_modulus(),
|
||||
);
|
||||
|
||||
@@ -1868,14 +1852,6 @@ pub fn std_multi_bit_f128_blind_rotate_assign<Scalar, InputCont, OutputCont, Key
|
||||
multi_bit_bsk.input_lwe_dimension(),
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
input.ciphertext_modulus(),
|
||||
accumulator.ciphertext_modulus(),
|
||||
"Mismatched CiphertextModulus between input ({:?}) and accumulator ({:?})",
|
||||
input.ciphertext_modulus(),
|
||||
accumulator.ciphertext_modulus(),
|
||||
);
|
||||
|
||||
let grouping_factor = multi_bit_bsk.grouping_factor();
|
||||
|
||||
let lut_poly_size = accumulator.polynomial_size();
|
||||
@@ -2213,18 +2189,10 @@ pub fn std_multi_bit_programmable_bootstrap_f128_lwe_ciphertext<
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
input.ciphertext_modulus(),
|
||||
output.ciphertext_modulus(),
|
||||
"Mismatched CiphertextModulus between input ({:?}) and output ({:?})",
|
||||
input.ciphertext_modulus(),
|
||||
output.ciphertext_modulus(),
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
input.ciphertext_modulus(),
|
||||
accumulator.ciphertext_modulus(),
|
||||
"Mismatched CiphertextModulus between input ({:?}) and accumulator ({:?})",
|
||||
input.ciphertext_modulus(),
|
||||
"Mismatched CiphertextModulus between output ({:?}) and accumulator ({:?})",
|
||||
output.ciphertext_modulus(),
|
||||
accumulator.ciphertext_modulus(),
|
||||
);
|
||||
|
||||
@@ -2718,3 +2686,224 @@ pub fn multi_bit_programmable_bootstrap_f128_lwe_ciphertext<
|
||||
|
||||
extract_lwe_sample_from_glwe_ciphertext(&local_accumulator, output, MonomialDegree(0));
|
||||
}
|
||||
|
||||
// ============== Noise measurement trait implementations ============== //
|
||||
use crate::core_crypto::commons::noise_formulas::noise_simulation::traits::{
|
||||
AllocateMultiBitModSwitchResult, LweMultiBitFft128BlindRotate, LweMultiBitFft128Bootstrap,
|
||||
LweMultiBitFftBlindRotate, LweMultiBitFftBootstrap, MultiBitModSwitch,
|
||||
};
|
||||
|
||||
impl<
|
||||
Scalar: UnsignedInteger + CastInto<usize> + CastFrom<usize>,
|
||||
C: Container<Element = Scalar> + Sync,
|
||||
> AllocateMultiBitModSwitchResult for LweCiphertext<C>
|
||||
{
|
||||
type Output = StandardMultiBitModulusSwitchedCt<Scalar, Vec<Scalar>>;
|
||||
type SideResources = ();
|
||||
|
||||
fn allocate_multi_bit_mod_switch_result(
|
||||
&self,
|
||||
_side_resources: &mut Self::SideResources,
|
||||
) -> Self::Output {
|
||||
// We will mod switch but we keep the current modulus as the noise is interesting in the
|
||||
// context of the input modulus
|
||||
Self::Output {
|
||||
input: LweCiphertextOwned::from_container(
|
||||
self.as_ref().to_vec(),
|
||||
self.ciphertext_modulus(),
|
||||
),
|
||||
// Placeholder values for those as they will be filled during mod switch
|
||||
// Choose defaults that should crash things
|
||||
grouping_factor: LweBskGroupingFactor(usize::MAX),
|
||||
log_modulus: CiphertextModulusLog(usize::MAX),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<
|
||||
Scalar: UnsignedInteger + CastInto<usize> + CastFrom<usize>,
|
||||
C: Container<Element = Scalar>,
|
||||
OutCont: ContainerMut<Element = Scalar> + Sync,
|
||||
> MultiBitModSwitch<StandardMultiBitModulusSwitchedCt<Scalar, OutCont>> for LweCiphertext<C>
|
||||
{
|
||||
type SideResources = ();
|
||||
|
||||
fn multi_bit_mod_switch(
|
||||
&self,
|
||||
grouping_factor: LweBskGroupingFactor,
|
||||
output_modulus_log: CiphertextModulusLog,
|
||||
output: &mut StandardMultiBitModulusSwitchedCt<Scalar, OutCont>,
|
||||
_side_resources: &mut Self::SideResources,
|
||||
) {
|
||||
assert_eq!(output.input.ciphertext_modulus(), self.ciphertext_modulus());
|
||||
|
||||
let StandardMultiBitModulusSwitchedCt {
|
||||
input: output_input,
|
||||
grouping_factor: output_grouping_factor,
|
||||
log_modulus: output_log_modulus,
|
||||
} = output;
|
||||
|
||||
*output_grouping_factor = grouping_factor;
|
||||
*output_log_modulus = output_modulus_log;
|
||||
output_input.as_mut().copy_from_slice(self.as_ref());
|
||||
|
||||
// Nothing to do given the mod switch will be lazily evaluated by the following multi bit
|
||||
// blind rotation
|
||||
}
|
||||
}
|
||||
|
||||
impl<
|
||||
InputScalar: UnsignedInteger + CastInto<usize> + CastFrom<usize>,
|
||||
InputCont: Container<Element = InputScalar> + Sync,
|
||||
OutputScalar: UnsignedTorus + Sync,
|
||||
OutputCont: ContainerMut<Element = OutputScalar>,
|
||||
AccCont: Container<Element = OutputScalar>,
|
||||
KeyCont: Container<Element = c64> + Sync,
|
||||
>
|
||||
LweMultiBitFftBlindRotate<
|
||||
StandardMultiBitModulusSwitchedCt<InputScalar, InputCont>,
|
||||
LweCiphertext<OutputCont>,
|
||||
GlweCiphertext<AccCont>,
|
||||
> for FourierLweMultiBitBootstrapKey<KeyCont>
|
||||
{
|
||||
type SideResources = (ThreadCount, bool);
|
||||
|
||||
fn lwe_multi_bit_fft_blind_rotate(
|
||||
&self,
|
||||
input: &StandardMultiBitModulusSwitchedCt<InputScalar, InputCont>,
|
||||
output: &mut LweCiphertext<OutputCont>,
|
||||
accumulator: &GlweCiphertext<AccCont>,
|
||||
side_resources: &mut Self::SideResources,
|
||||
) {
|
||||
let (thread_count, deterministic_execution) = *side_resources;
|
||||
|
||||
let mut local_accumulator = GlweCiphertext::from_container(
|
||||
accumulator.as_ref().to_vec(),
|
||||
accumulator.polynomial_size(),
|
||||
accumulator.ciphertext_modulus(),
|
||||
);
|
||||
|
||||
multi_bit_blind_rotate_assign(
|
||||
input,
|
||||
&mut local_accumulator,
|
||||
self,
|
||||
thread_count,
|
||||
deterministic_execution,
|
||||
);
|
||||
|
||||
extract_lwe_sample_from_glwe_ciphertext(&local_accumulator, output, MonomialDegree(0));
|
||||
}
|
||||
}
|
||||
|
||||
impl<
|
||||
InputScalar: UnsignedInteger + CastInto<usize> + CastFrom<usize>,
|
||||
InputCont: Container<Element = InputScalar> + Sync,
|
||||
OutputScalar: UnsignedTorus + Sync,
|
||||
OutputCont: ContainerMut<Element = OutputScalar>,
|
||||
AccCont: Container<Element = OutputScalar>,
|
||||
KeyCont: Container<Element = f64> + Sync + Split,
|
||||
>
|
||||
LweMultiBitFft128BlindRotate<
|
||||
StandardMultiBitModulusSwitchedCt<InputScalar, InputCont>,
|
||||
LweCiphertext<OutputCont>,
|
||||
GlweCiphertext<AccCont>,
|
||||
> for Fourier128LweMultiBitBootstrapKey<KeyCont>
|
||||
{
|
||||
type SideResources = ThreadCount;
|
||||
|
||||
fn lwe_multi_bit_fft_128_blind_rotate(
|
||||
&self,
|
||||
input: &StandardMultiBitModulusSwitchedCt<InputScalar, InputCont>,
|
||||
output: &mut LweCiphertext<OutputCont>,
|
||||
accumulator: &GlweCiphertext<AccCont>,
|
||||
side_resources: &mut Self::SideResources,
|
||||
) {
|
||||
let thread_count = *side_resources;
|
||||
|
||||
let mut local_accumulator = GlweCiphertext::from_container(
|
||||
accumulator.as_ref().to_vec(),
|
||||
accumulator.polynomial_size(),
|
||||
accumulator.ciphertext_modulus(),
|
||||
);
|
||||
|
||||
multi_bit_f128_deterministic_blind_rotate_assign(
|
||||
input,
|
||||
&mut local_accumulator,
|
||||
self,
|
||||
thread_count,
|
||||
);
|
||||
|
||||
extract_lwe_sample_from_glwe_ciphertext(&local_accumulator, output, MonomialDegree(0));
|
||||
}
|
||||
}
|
||||
|
||||
impl<
|
||||
Scalar: UnsignedTorus + CastInto<usize> + CastFrom<usize>,
|
||||
InputCont: Container<Element = Scalar> + Sync,
|
||||
OutputCont: ContainerMut<Element = Scalar>,
|
||||
AccCont: Container<Element = Scalar>,
|
||||
KeyCont: Container<Element = c64> + Sync,
|
||||
>
|
||||
LweMultiBitFftBootstrap<
|
||||
LweCiphertext<InputCont>,
|
||||
LweCiphertext<OutputCont>,
|
||||
GlweCiphertext<AccCont>,
|
||||
> for FourierLweMultiBitBootstrapKey<KeyCont>
|
||||
{
|
||||
type SideResources = (ThreadCount, bool);
|
||||
|
||||
fn lwe_multi_bit_fft_bootstrap(
|
||||
&self,
|
||||
input: &LweCiphertext<InputCont>,
|
||||
output: &mut LweCiphertext<OutputCont>,
|
||||
accumulator: &GlweCiphertext<AccCont>,
|
||||
side_resources: &mut Self::SideResources,
|
||||
) {
|
||||
let (thread_count, deterministic_execution) = *side_resources;
|
||||
|
||||
multi_bit_programmable_bootstrap_lwe_ciphertext(
|
||||
input,
|
||||
output,
|
||||
accumulator,
|
||||
self,
|
||||
thread_count,
|
||||
deterministic_execution,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
impl<
|
||||
InputScalar: UnsignedTorus + CastInto<usize> + CastFrom<usize>,
|
||||
InputCont: Container<Element = InputScalar> + Sync,
|
||||
OutputScalar: UnsignedTorus + Sync,
|
||||
OutputCont: ContainerMut<Element = OutputScalar>,
|
||||
AccCont: Container<Element = OutputScalar>,
|
||||
KeyCont: Container<Element = f64> + Sync + Split,
|
||||
>
|
||||
LweMultiBitFft128Bootstrap<
|
||||
LweCiphertext<InputCont>,
|
||||
LweCiphertext<OutputCont>,
|
||||
GlweCiphertext<AccCont>,
|
||||
> for Fourier128LweMultiBitBootstrapKey<KeyCont>
|
||||
{
|
||||
type SideResources = (ThreadCount, bool);
|
||||
|
||||
fn lwe_multi_bit_fft_128_bootstrap(
|
||||
&self,
|
||||
input: &LweCiphertext<InputCont>,
|
||||
output: &mut LweCiphertext<OutputCont>,
|
||||
accumulator: &GlweCiphertext<AccCont>,
|
||||
side_resources: &mut Self::SideResources,
|
||||
) {
|
||||
let (thread_count, deterministic_execution) = *side_resources;
|
||||
|
||||
multi_bit_programmable_bootstrap_f128_lwe_ciphertext(
|
||||
input,
|
||||
output,
|
||||
accumulator,
|
||||
self,
|
||||
thread_count,
|
||||
deterministic_execution,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -77,7 +77,9 @@ where
|
||||
}
|
||||
|
||||
// ============== Noise measurement trait implementations ============== //
|
||||
use crate::core_crypto::commons::noise_formulas::noise_simulation::traits::AllocateLweBootstrapResult;
|
||||
use crate::core_crypto::commons::noise_formulas::noise_simulation::traits::{
|
||||
AllocateLweBootstrapResult, AllocateLweMultiBitBlindRotateResult,
|
||||
};
|
||||
|
||||
impl<Scalar: UnsignedInteger, AccCont: Container<Element = Scalar>> AllocateLweBootstrapResult
|
||||
for GlweCiphertext<AccCont>
|
||||
@@ -100,3 +102,25 @@ impl<Scalar: UnsignedInteger, AccCont: Container<Element = Scalar>> AllocateLweB
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl<Scalar: UnsignedInteger, AccCont: Container<Element = Scalar>>
|
||||
AllocateLweMultiBitBlindRotateResult for GlweCiphertext<AccCont>
|
||||
{
|
||||
type Output = LweCiphertextOwned<Scalar>;
|
||||
type SideResources = ();
|
||||
|
||||
fn allocate_lwe_multi_bit_blind_rotate_result(
|
||||
&self,
|
||||
_side_resources: &mut Self::SideResources,
|
||||
) -> Self::Output {
|
||||
let glwe_dim = self.glwe_size().to_glwe_dimension();
|
||||
let poly_size = self.polynomial_size();
|
||||
let equivalent_lwe_dim = glwe_dim.to_equivalent_lwe_dimension(poly_size);
|
||||
|
||||
LweCiphertext::new(
|
||||
Scalar::ZERO,
|
||||
equivalent_lwe_dim.to_lwe_size(),
|
||||
self.ciphertext_modulus(),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ use crate::core_crypto::commons::noise_formulas::lwe_multi_bit_programmable_boot
|
||||
multi_bit_pbs_variance_132_bits_security_gaussian_gf_3_fft_mul,
|
||||
multi_bit_pbs_variance_132_bits_security_tuniform_gf_3_fft_mul,
|
||||
};
|
||||
use crate::core_crypto::commons::noise_formulas::noise_simulation::PBS_FFT_64_MANTISSA_SIZE;
|
||||
use crate::core_crypto::commons::noise_formulas::secure_noise::{
|
||||
minimal_lwe_variance_for_132_bits_security_gaussian,
|
||||
minimal_lwe_variance_for_132_bits_security_tuniform,
|
||||
@@ -48,6 +49,7 @@ where
|
||||
polynomial_size,
|
||||
pbs_decomposition_base_log,
|
||||
pbs_decomposition_level_count,
|
||||
PBS_FFT_64_MANTISSA_SIZE,
|
||||
modulus_as_f64,
|
||||
)
|
||||
}
|
||||
@@ -58,6 +60,7 @@ where
|
||||
polynomial_size,
|
||||
pbs_decomposition_base_log,
|
||||
pbs_decomposition_level_count,
|
||||
PBS_FFT_64_MANTISSA_SIZE,
|
||||
modulus_as_f64,
|
||||
)
|
||||
}
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
use super::*;
|
||||
use crate::core_crypto::commons::noise_formulas::lwe_programmable_bootstrap::pbs_variance_132_bits_security_gaussian_fft_mul;
|
||||
use crate::core_crypto::commons::noise_formulas::noise_simulation::PBS_FFT_64_MANTISSA_SIZE;
|
||||
use crate::core_crypto::commons::noise_formulas::secure_noise::minimal_lwe_variance_for_132_bits_security_gaussian;
|
||||
use crate::core_crypto::commons::test_tools::{torus_modular_diff, variance};
|
||||
use rayon::prelude::*;
|
||||
@@ -38,6 +39,7 @@ where
|
||||
polynomial_size,
|
||||
pbs_decomposition_base_log,
|
||||
pbs_decomposition_level_count,
|
||||
PBS_FFT_64_MANTISSA_SIZE,
|
||||
modulus_as_f64,
|
||||
);
|
||||
|
||||
|
||||
@@ -16,6 +16,7 @@ pub fn multi_bit_pbs_variance_132_bits_security_gaussian_gf_2_fft_mul(
|
||||
output_polynomial_size: PolynomialSize,
|
||||
decomposition_base_log: DecompositionBaseLog,
|
||||
decomposition_level_count: DecompositionLevelCount,
|
||||
mantissa_size: f64,
|
||||
modulus: f64,
|
||||
) -> Variance {
|
||||
Variance(
|
||||
@@ -25,6 +26,7 @@ pub fn multi_bit_pbs_variance_132_bits_security_gaussian_gf_2_fft_mul(
|
||||
output_polynomial_size.0 as f64,
|
||||
2.0f64.powi(decomposition_base_log.0 as i32),
|
||||
decomposition_level_count.0 as f64,
|
||||
mantissa_size,
|
||||
modulus,
|
||||
),
|
||||
)
|
||||
@@ -40,16 +42,17 @@ pub fn multi_bit_pbs_variance_132_bits_security_gaussian_gf_2_fft_mul_impl(
|
||||
output_polynomial_size: f64,
|
||||
decomposition_base: f64,
|
||||
decomposition_level_count: f64,
|
||||
mantissa_size: f64,
|
||||
modulus: f64,
|
||||
) -> f64 {
|
||||
(1_f64 / 2.0)
|
||||
* input_lwe_dimension
|
||||
* (0.0022
|
||||
* (2.0
|
||||
* if (core::f64::consts::LOG2_E * modulus.ln() - 53.0 <= 0.0) {
|
||||
* if (1.0 * mantissa_size - core::f64::consts::LOG2_E * modulus.ln() >= 0.0) {
|
||||
0.0
|
||||
} else {
|
||||
core::f64::consts::LOG2_E * modulus.ln() - 53.0
|
||||
-1.0 * mantissa_size + core::f64::consts::LOG2_E * modulus.ln()
|
||||
}
|
||||
+ 2.88539008177793 * decomposition_base.ln()
|
||||
- 2.88539008177793 * modulus.ln())
|
||||
@@ -86,6 +89,7 @@ pub fn multi_bit_pbs_variance_132_bits_security_gaussian_gf_3_fft_mul(
|
||||
output_polynomial_size: PolynomialSize,
|
||||
decomposition_base_log: DecompositionBaseLog,
|
||||
decomposition_level_count: DecompositionLevelCount,
|
||||
mantissa_size: f64,
|
||||
modulus: f64,
|
||||
) -> Variance {
|
||||
Variance(
|
||||
@@ -95,6 +99,7 @@ pub fn multi_bit_pbs_variance_132_bits_security_gaussian_gf_3_fft_mul(
|
||||
output_polynomial_size.0 as f64,
|
||||
2.0f64.powi(decomposition_base_log.0 as i32),
|
||||
decomposition_level_count.0 as f64,
|
||||
mantissa_size,
|
||||
modulus,
|
||||
),
|
||||
)
|
||||
@@ -110,16 +115,17 @@ pub fn multi_bit_pbs_variance_132_bits_security_gaussian_gf_3_fft_mul_impl(
|
||||
output_polynomial_size: f64,
|
||||
decomposition_base: f64,
|
||||
decomposition_level_count: f64,
|
||||
mantissa_size: f64,
|
||||
modulus: f64,
|
||||
) -> f64 {
|
||||
(1_f64 / 3.0)
|
||||
* input_lwe_dimension
|
||||
* (0.00492
|
||||
* (2.0
|
||||
* if (core::f64::consts::LOG2_E * modulus.ln() - 53.0 <= 0.0) {
|
||||
* if (1.0 * mantissa_size - core::f64::consts::LOG2_E * modulus.ln() >= 0.0) {
|
||||
0.0
|
||||
} else {
|
||||
core::f64::consts::LOG2_E * modulus.ln() - 53.0
|
||||
-1.0 * mantissa_size + core::f64::consts::LOG2_E * modulus.ln()
|
||||
}
|
||||
+ 2.88539008177793 * decomposition_base.ln()
|
||||
- 2.88539008177793 * modulus.ln())
|
||||
@@ -156,6 +162,7 @@ pub fn multi_bit_pbs_variance_132_bits_security_gaussian_gf_4_fft_mul(
|
||||
output_polynomial_size: PolynomialSize,
|
||||
decomposition_base_log: DecompositionBaseLog,
|
||||
decomposition_level_count: DecompositionLevelCount,
|
||||
mantissa_size: f64,
|
||||
modulus: f64,
|
||||
) -> Variance {
|
||||
Variance(
|
||||
@@ -165,6 +172,7 @@ pub fn multi_bit_pbs_variance_132_bits_security_gaussian_gf_4_fft_mul(
|
||||
output_polynomial_size.0 as f64,
|
||||
2.0f64.powi(decomposition_base_log.0 as i32),
|
||||
decomposition_level_count.0 as f64,
|
||||
mantissa_size,
|
||||
modulus,
|
||||
),
|
||||
)
|
||||
@@ -180,16 +188,17 @@ pub fn multi_bit_pbs_variance_132_bits_security_gaussian_gf_4_fft_mul_impl(
|
||||
output_polynomial_size: f64,
|
||||
decomposition_base: f64,
|
||||
decomposition_level_count: f64,
|
||||
mantissa_size: f64,
|
||||
modulus: f64,
|
||||
) -> f64 {
|
||||
(1_f64 / 4.0)
|
||||
* input_lwe_dimension
|
||||
* (0.00855
|
||||
* (2.0
|
||||
* if (core::f64::consts::LOG2_E * modulus.ln() - 53.0 <= 0.0) {
|
||||
* if (1.0 * mantissa_size - core::f64::consts::LOG2_E * modulus.ln() >= 0.0) {
|
||||
0.0
|
||||
} else {
|
||||
core::f64::consts::LOG2_E * modulus.ln() - 53.0
|
||||
-1.0 * mantissa_size + core::f64::consts::LOG2_E * modulus.ln()
|
||||
}
|
||||
+ 2.88539008177793 * decomposition_base.ln()
|
||||
- 2.88539008177793 * modulus.ln())
|
||||
@@ -226,6 +235,7 @@ pub fn multi_bit_pbs_variance_132_bits_security_tuniform_gf_2_fft_mul(
|
||||
output_polynomial_size: PolynomialSize,
|
||||
decomposition_base_log: DecompositionBaseLog,
|
||||
decomposition_level_count: DecompositionLevelCount,
|
||||
mantissa_size: f64,
|
||||
modulus: f64,
|
||||
) -> Variance {
|
||||
Variance(
|
||||
@@ -235,6 +245,7 @@ pub fn multi_bit_pbs_variance_132_bits_security_tuniform_gf_2_fft_mul(
|
||||
output_polynomial_size.0 as f64,
|
||||
2.0f64.powi(decomposition_base_log.0 as i32),
|
||||
decomposition_level_count.0 as f64,
|
||||
mantissa_size,
|
||||
modulus,
|
||||
),
|
||||
)
|
||||
@@ -250,16 +261,17 @@ pub fn multi_bit_pbs_variance_132_bits_security_tuniform_gf_2_fft_mul_impl(
|
||||
output_polynomial_size: f64,
|
||||
decomposition_base: f64,
|
||||
decomposition_level_count: f64,
|
||||
mantissa_size: f64,
|
||||
modulus: f64,
|
||||
) -> f64 {
|
||||
(1_f64 / 2.0)
|
||||
* input_lwe_dimension
|
||||
* (0.0022
|
||||
* (2.0
|
||||
* if (core::f64::consts::LOG2_E * modulus.ln() - 53.0 <= 0.0) {
|
||||
* if (1.0 * mantissa_size - core::f64::consts::LOG2_E * modulus.ln() >= 0.0) {
|
||||
0.0
|
||||
} else {
|
||||
core::f64::consts::LOG2_E * modulus.ln() - 53.0
|
||||
-1.0 * mantissa_size + core::f64::consts::LOG2_E * modulus.ln()
|
||||
}
|
||||
+ 2.88539008177793 * decomposition_base.ln()
|
||||
- 2.88539008177793 * modulus.ln())
|
||||
@@ -302,6 +314,7 @@ pub fn multi_bit_pbs_variance_132_bits_security_tuniform_gf_3_fft_mul(
|
||||
output_polynomial_size: PolynomialSize,
|
||||
decomposition_base_log: DecompositionBaseLog,
|
||||
decomposition_level_count: DecompositionLevelCount,
|
||||
mantissa_size: f64,
|
||||
modulus: f64,
|
||||
) -> Variance {
|
||||
Variance(
|
||||
@@ -311,6 +324,7 @@ pub fn multi_bit_pbs_variance_132_bits_security_tuniform_gf_3_fft_mul(
|
||||
output_polynomial_size.0 as f64,
|
||||
2.0f64.powi(decomposition_base_log.0 as i32),
|
||||
decomposition_level_count.0 as f64,
|
||||
mantissa_size,
|
||||
modulus,
|
||||
),
|
||||
)
|
||||
@@ -326,16 +340,17 @@ pub fn multi_bit_pbs_variance_132_bits_security_tuniform_gf_3_fft_mul_impl(
|
||||
output_polynomial_size: f64,
|
||||
decomposition_base: f64,
|
||||
decomposition_level_count: f64,
|
||||
mantissa_size: f64,
|
||||
modulus: f64,
|
||||
) -> f64 {
|
||||
(1_f64 / 3.0)
|
||||
* input_lwe_dimension
|
||||
* (0.00492
|
||||
* (2.0
|
||||
* if (core::f64::consts::LOG2_E * modulus.ln() - 53.0 <= 0.0) {
|
||||
* if (1.0 * mantissa_size - core::f64::consts::LOG2_E * modulus.ln() >= 0.0) {
|
||||
0.0
|
||||
} else {
|
||||
core::f64::consts::LOG2_E * modulus.ln() - 53.0
|
||||
-1.0 * mantissa_size + core::f64::consts::LOG2_E * modulus.ln()
|
||||
}
|
||||
+ 2.88539008177793 * decomposition_base.ln()
|
||||
- 2.88539008177793 * modulus.ln())
|
||||
@@ -378,6 +393,7 @@ pub fn multi_bit_pbs_variance_132_bits_security_tuniform_gf_4_fft_mul(
|
||||
output_polynomial_size: PolynomialSize,
|
||||
decomposition_base_log: DecompositionBaseLog,
|
||||
decomposition_level_count: DecompositionLevelCount,
|
||||
mantissa_size: f64,
|
||||
modulus: f64,
|
||||
) -> Variance {
|
||||
Variance(
|
||||
@@ -387,6 +403,7 @@ pub fn multi_bit_pbs_variance_132_bits_security_tuniform_gf_4_fft_mul(
|
||||
output_polynomial_size.0 as f64,
|
||||
2.0f64.powi(decomposition_base_log.0 as i32),
|
||||
decomposition_level_count.0 as f64,
|
||||
mantissa_size,
|
||||
modulus,
|
||||
),
|
||||
)
|
||||
@@ -402,16 +419,17 @@ pub fn multi_bit_pbs_variance_132_bits_security_tuniform_gf_4_fft_mul_impl(
|
||||
output_polynomial_size: f64,
|
||||
decomposition_base: f64,
|
||||
decomposition_level_count: f64,
|
||||
mantissa_size: f64,
|
||||
modulus: f64,
|
||||
) -> f64 {
|
||||
(1_f64 / 4.0)
|
||||
* input_lwe_dimension
|
||||
* (0.00855
|
||||
* (2.0
|
||||
* if (core::f64::consts::LOG2_E * modulus.ln() - 53.0 <= 0.0) {
|
||||
* if (1.0 * mantissa_size - core::f64::consts::LOG2_E * modulus.ln() >= 0.0) {
|
||||
0.0
|
||||
} else {
|
||||
core::f64::consts::LOG2_E * modulus.ln() - 53.0
|
||||
-1.0 * mantissa_size + core::f64::consts::LOG2_E * modulus.ln()
|
||||
}
|
||||
+ 2.88539008177793 * decomposition_base.ln()
|
||||
- 2.88539008177793 * modulus.ln())
|
||||
|
||||
@@ -16,6 +16,7 @@ pub fn pbs_variance_132_bits_security_gaussian_fft_mul(
|
||||
output_polynomial_size: PolynomialSize,
|
||||
decomposition_base_log: DecompositionBaseLog,
|
||||
decomposition_level_count: DecompositionLevelCount,
|
||||
mantissa_size: f64,
|
||||
modulus: f64,
|
||||
) -> Variance {
|
||||
Variance(pbs_variance_132_bits_security_gaussian_fft_mul_impl(
|
||||
@@ -24,6 +25,7 @@ pub fn pbs_variance_132_bits_security_gaussian_fft_mul(
|
||||
output_polynomial_size.0 as f64,
|
||||
2.0f64.powi(decomposition_base_log.0 as i32),
|
||||
decomposition_level_count.0 as f64,
|
||||
mantissa_size,
|
||||
modulus,
|
||||
))
|
||||
}
|
||||
@@ -38,15 +40,16 @@ pub fn pbs_variance_132_bits_security_gaussian_fft_mul_impl(
|
||||
output_polynomial_size: f64,
|
||||
decomposition_base: f64,
|
||||
decomposition_level_count: f64,
|
||||
mantissa_size: f64,
|
||||
modulus: f64,
|
||||
) -> f64 {
|
||||
input_lwe_dimension
|
||||
* (0.00705
|
||||
* (2.0
|
||||
* if (core::f64::consts::LOG2_E * modulus.ln() - 53.0 <= 0.0) {
|
||||
* if (1.0 * mantissa_size - core::f64::consts::LOG2_E * modulus.ln() >= 0.0) {
|
||||
0.0
|
||||
} else {
|
||||
core::f64::consts::LOG2_E * modulus.ln() - 53.0
|
||||
-1.0 * mantissa_size + core::f64::consts::LOG2_E * modulus.ln()
|
||||
}
|
||||
+ 2.88539008177793 * decomposition_base.ln()
|
||||
- 2.88539008177793 * modulus.ln())
|
||||
@@ -83,6 +86,7 @@ pub fn pbs_variance_132_bits_security_tuniform_fft_mul(
|
||||
output_polynomial_size: PolynomialSize,
|
||||
decomposition_base_log: DecompositionBaseLog,
|
||||
decomposition_level_count: DecompositionLevelCount,
|
||||
mantissa_size: f64,
|
||||
modulus: f64,
|
||||
) -> Variance {
|
||||
Variance(pbs_variance_132_bits_security_tuniform_fft_mul_impl(
|
||||
@@ -91,6 +95,7 @@ pub fn pbs_variance_132_bits_security_tuniform_fft_mul(
|
||||
output_polynomial_size.0 as f64,
|
||||
2.0f64.powi(decomposition_base_log.0 as i32),
|
||||
decomposition_level_count.0 as f64,
|
||||
mantissa_size,
|
||||
modulus,
|
||||
))
|
||||
}
|
||||
@@ -105,15 +110,16 @@ pub fn pbs_variance_132_bits_security_tuniform_fft_mul_impl(
|
||||
output_polynomial_size: f64,
|
||||
decomposition_base: f64,
|
||||
decomposition_level_count: f64,
|
||||
mantissa_size: f64,
|
||||
modulus: f64,
|
||||
) -> f64 {
|
||||
input_lwe_dimension
|
||||
* (0.00705
|
||||
* (2.0
|
||||
* if (core::f64::consts::LOG2_E * modulus.ln() - 53.0 <= 0.0) {
|
||||
* if (1.0 * mantissa_size - core::f64::consts::LOG2_E * modulus.ln() >= 0.0) {
|
||||
0.0
|
||||
} else {
|
||||
core::f64::consts::LOG2_E * modulus.ln() - 53.0
|
||||
-1.0 * mantissa_size + core::f64::consts::LOG2_E * modulus.ln()
|
||||
}
|
||||
+ 2.88539008177793 * decomposition_base.ln()
|
||||
- 2.88539008177793 * modulus.ln())
|
||||
|
||||
@@ -1,153 +0,0 @@
|
||||
// This file was autogenerated, do not modify by hand.
|
||||
#![allow(unused_parens)]
|
||||
#![allow(clippy::neg_multiply)]
|
||||
#![allow(clippy::suspicious_operation_groupings)]
|
||||
|
||||
use crate::core_crypto::commons::dispersion::Variance;
|
||||
use crate::core_crypto::commons::parameters::*;
|
||||
|
||||
/// This formula is only valid if the proper noise distributions are used and
|
||||
/// if the keys used are encrypted using secure noise given by the
|
||||
/// [`minimal_glwe_variance`](`super::secure_noise`)
|
||||
/// and [`minimal_lwe_variance`](`super::secure_noise`) family of functions.
|
||||
pub fn pbs_128_variance_132_bits_security_gaussian_fft_mul(
|
||||
input_lwe_dimension: LweDimension,
|
||||
output_glwe_dimension: GlweDimension,
|
||||
output_polynomial_size: PolynomialSize,
|
||||
decomposition_base_log: DecompositionBaseLog,
|
||||
decomposition_level_count: DecompositionLevelCount,
|
||||
mantissa_size: f64,
|
||||
modulus: f64,
|
||||
) -> Variance {
|
||||
Variance(pbs_128_variance_132_bits_security_gaussian_fft_mul_impl(
|
||||
input_lwe_dimension.0 as f64,
|
||||
output_glwe_dimension.0 as f64,
|
||||
output_polynomial_size.0 as f64,
|
||||
2.0f64.powi(decomposition_base_log.0 as i32),
|
||||
decomposition_level_count.0 as f64,
|
||||
mantissa_size,
|
||||
modulus,
|
||||
))
|
||||
}
|
||||
|
||||
/// This formula is only valid if the proper noise distributions are used and
|
||||
/// if the keys used are encrypted using secure noise given by the
|
||||
/// [`minimal_glwe_variance`](`super::secure_noise`)
|
||||
/// and [`minimal_lwe_variance`](`super::secure_noise`) family of functions.
|
||||
pub fn pbs_128_variance_132_bits_security_gaussian_fft_mul_impl(
|
||||
input_lwe_dimension: f64,
|
||||
output_glwe_dimension: f64,
|
||||
output_polynomial_size: f64,
|
||||
decomposition_base: f64,
|
||||
decomposition_level_count: f64,
|
||||
mantissa_size: f64,
|
||||
modulus: f64,
|
||||
) -> f64 {
|
||||
input_lwe_dimension
|
||||
* (0.00705
|
||||
* (2.0
|
||||
* if (1.0 * mantissa_size - core::f64::consts::LOG2_E * modulus.ln() >= 0.0) {
|
||||
0.0
|
||||
} else {
|
||||
-1.0 * mantissa_size + core::f64::consts::LOG2_E * modulus.ln()
|
||||
}
|
||||
+ 2.88539008177793 * decomposition_base.ln()
|
||||
- 2.88539008177793 * modulus.ln())
|
||||
.exp2()
|
||||
* decomposition_level_count.powf(1.01827)
|
||||
* output_glwe_dimension.powf(1.22003)
|
||||
* output_polynomial_size.powf(2.22003)
|
||||
* (output_glwe_dimension + 1.0).powf(1.01827)
|
||||
+ decomposition_level_count
|
||||
* output_polynomial_size
|
||||
* ((4.0 - 2.88539008177793 * modulus.ln()).exp2()
|
||||
+ (-0.0497829131652661 * output_glwe_dimension * output_polynomial_size
|
||||
+ 5.31469187675068)
|
||||
.exp2())
|
||||
* ((1_f64 / 12.0) * decomposition_base.powf(2.0) + 0.166666666666667)
|
||||
* (output_glwe_dimension + 1.0)
|
||||
- 1_f64 / 24.0 * modulus.powf(-2.0)
|
||||
+ (1_f64 / 2.0)
|
||||
* output_glwe_dimension
|
||||
* output_polynomial_size
|
||||
* (0.0208333333333333 * modulus.powf(-2.0)
|
||||
+ 0.0416666666666667
|
||||
* decomposition_base.powf(-2.0 * decomposition_level_count))
|
||||
+ (1_f64 / 24.0) * decomposition_base.powf(-2.0 * decomposition_level_count))
|
||||
}
|
||||
|
||||
/// This formula is only valid if the proper noise distributions are used and
|
||||
/// if the keys used are encrypted using secure noise given by the
|
||||
/// [`minimal_glwe_variance`](`super::secure_noise`)
|
||||
/// and [`minimal_lwe_variance`](`super::secure_noise`) family of functions.
|
||||
pub fn pbs_128_variance_132_bits_security_tuniform_fft_mul(
|
||||
input_lwe_dimension: LweDimension,
|
||||
output_glwe_dimension: GlweDimension,
|
||||
output_polynomial_size: PolynomialSize,
|
||||
decomposition_base_log: DecompositionBaseLog,
|
||||
decomposition_level_count: DecompositionLevelCount,
|
||||
mantissa_size: f64,
|
||||
modulus: f64,
|
||||
) -> Variance {
|
||||
Variance(pbs_128_variance_132_bits_security_tuniform_fft_mul_impl(
|
||||
input_lwe_dimension.0 as f64,
|
||||
output_glwe_dimension.0 as f64,
|
||||
output_polynomial_size.0 as f64,
|
||||
2.0f64.powi(decomposition_base_log.0 as i32),
|
||||
decomposition_level_count.0 as f64,
|
||||
mantissa_size,
|
||||
modulus,
|
||||
))
|
||||
}
|
||||
|
||||
/// This formula is only valid if the proper noise distributions are used and
|
||||
/// if the keys used are encrypted using secure noise given by the
|
||||
/// [`minimal_glwe_variance`](`super::secure_noise`)
|
||||
/// and [`minimal_lwe_variance`](`super::secure_noise`) family of functions.
|
||||
pub fn pbs_128_variance_132_bits_security_tuniform_fft_mul_impl(
|
||||
input_lwe_dimension: f64,
|
||||
output_glwe_dimension: f64,
|
||||
output_polynomial_size: f64,
|
||||
decomposition_base: f64,
|
||||
decomposition_level_count: f64,
|
||||
mantissa_size: f64,
|
||||
modulus: f64,
|
||||
) -> f64 {
|
||||
input_lwe_dimension
|
||||
* (0.00705
|
||||
* (2.0
|
||||
* if (1.0 * mantissa_size - core::f64::consts::LOG2_E * modulus.ln() >= 0.0) {
|
||||
0.0
|
||||
} else {
|
||||
-1.0 * mantissa_size + core::f64::consts::LOG2_E * modulus.ln()
|
||||
}
|
||||
+ 2.88539008177793 * decomposition_base.ln()
|
||||
- 2.88539008177793 * modulus.ln())
|
||||
.exp2()
|
||||
* decomposition_level_count.powf(1.01827)
|
||||
* output_glwe_dimension.powf(1.22003)
|
||||
* output_polynomial_size.powf(2.22003)
|
||||
* (output_glwe_dimension + 1.0).powf(1.01827)
|
||||
+ decomposition_level_count
|
||||
* output_polynomial_size
|
||||
* ((4.44 - 2.88539008177793 * modulus.ln()).exp2()
|
||||
+ (1_f64 / 3.0)
|
||||
* modulus.powf(-2.0)
|
||||
* ((2.0
|
||||
* (-0.025167785 * output_glwe_dimension * output_polynomial_size
|
||||
+ core::f64::consts::LOG2_E * modulus.ln()
|
||||
+ 4.10067100000001)
|
||||
.ceil())
|
||||
.exp2()
|
||||
+ 0.5))
|
||||
* ((1_f64 / 12.0) * decomposition_base.powf(2.0) + 0.166666666666667)
|
||||
* (output_glwe_dimension + 1.0)
|
||||
- 1_f64 / 24.0 * modulus.powf(-2.0)
|
||||
+ (1_f64 / 2.0)
|
||||
* output_glwe_dimension
|
||||
* output_polynomial_size
|
||||
* (0.0208333333333333 * modulus.powf(-2.0)
|
||||
+ 0.0416666666666667
|
||||
* decomposition_base.powf(-2.0 * decomposition_level_count))
|
||||
+ (1_f64 / 24.0) * decomposition_base.powf(-2.0 * decomposition_level_count))
|
||||
}
|
||||
@@ -5,7 +5,6 @@ pub mod lwe_keyswitch;
|
||||
pub mod lwe_multi_bit_programmable_bootstrap;
|
||||
pub mod lwe_packing_keyswitch;
|
||||
pub mod lwe_programmable_bootstrap;
|
||||
pub mod lwe_programmable_bootstrap_128;
|
||||
pub mod modulus_switch;
|
||||
pub mod multi_bit_modulus_switch;
|
||||
pub mod secure_noise;
|
||||
|
||||
@@ -7,16 +7,22 @@ use crate::core_crypto::commons::noise_formulas::lwe_multi_bit_programmable_boot
|
||||
multi_bit_pbs_variance_132_bits_security_tuniform_gf_3_fft_mul,
|
||||
multi_bit_pbs_variance_132_bits_security_tuniform_gf_4_fft_mul,
|
||||
};
|
||||
use crate::core_crypto::commons::noise_formulas::noise_simulation::traits::LweMultiBitFftBlindRotate;
|
||||
use crate::core_crypto::commons::noise_formulas::noise_simulation::traits::{
|
||||
LweMultiBitFft128BlindRotate, LweMultiBitFft128Bootstrap, LweMultiBitFftBlindRotate,
|
||||
LweMultiBitFftBootstrap,
|
||||
};
|
||||
use crate::core_crypto::commons::noise_formulas::noise_simulation::{
|
||||
NoiseSimulationGlwe, NoiseSimulationLwe, NoiseSimulationModulus,
|
||||
NoiseSimulationGlwe, NoiseSimulationLwe, NoiseSimulationModulus, PBS_FFT_128_MANTISSA_SIZE,
|
||||
PBS_FFT_64_MANTISSA_SIZE,
|
||||
};
|
||||
use crate::core_crypto::commons::parameters::{
|
||||
DecompositionBaseLog, DecompositionLevelCount, DynamicDistribution, GlweSize,
|
||||
LweBskGroupingFactor, LweDimension, PolynomialSize,
|
||||
};
|
||||
use crate::core_crypto::commons::traits::container::Container;
|
||||
use crate::core_crypto::entities::lwe_multi_bit_bootstrap_key::FourierLweMultiBitBootstrapKey;
|
||||
use crate::core_crypto::entities::lwe_multi_bit_bootstrap_key::{
|
||||
Fourier128LweMultiBitBootstrapKey, FourierLweMultiBitBootstrapKey,
|
||||
};
|
||||
use crate::core_crypto::fft_impl::fft64::c64;
|
||||
|
||||
#[derive(Clone, Copy)]
|
||||
@@ -116,6 +122,11 @@ impl NoiseSimulationLweMultiBitFourierBsk {
|
||||
pub fn modulus(&self) -> NoiseSimulationModulus {
|
||||
self.modulus
|
||||
}
|
||||
|
||||
pub fn mantissa_size(&self) -> f64 {
|
||||
let _ = self;
|
||||
PBS_FFT_64_MANTISSA_SIZE
|
||||
}
|
||||
}
|
||||
|
||||
impl LweMultiBitFftBlindRotate<NoiseSimulationLwe, NoiseSimulationLwe, NoiseSimulationGlwe>
|
||||
@@ -147,6 +158,7 @@ impl LweMultiBitFftBlindRotate<NoiseSimulationLwe, NoiseSimulationLwe, NoiseSimu
|
||||
self.output_polynomial_size(),
|
||||
self.decomp_base_log(),
|
||||
self.decomp_level_count(),
|
||||
self.mantissa_size(),
|
||||
self.modulus().as_f64(),
|
||||
),
|
||||
3 => multi_bit_pbs_variance_132_bits_security_gaussian_gf_3_fft_mul(
|
||||
@@ -155,6 +167,7 @@ impl LweMultiBitFftBlindRotate<NoiseSimulationLwe, NoiseSimulationLwe, NoiseSimu
|
||||
self.output_polynomial_size(),
|
||||
self.decomp_base_log(),
|
||||
self.decomp_level_count(),
|
||||
self.mantissa_size(),
|
||||
self.modulus().as_f64(),
|
||||
),
|
||||
4 => multi_bit_pbs_variance_132_bits_security_gaussian_gf_4_fft_mul(
|
||||
@@ -163,6 +176,7 @@ impl LweMultiBitFftBlindRotate<NoiseSimulationLwe, NoiseSimulationLwe, NoiseSimu
|
||||
self.output_polynomial_size(),
|
||||
self.decomp_base_log(),
|
||||
self.decomp_level_count(),
|
||||
self.mantissa_size(),
|
||||
self.modulus().as_f64(),
|
||||
),
|
||||
gf => panic!("Unsupported grouping factor: {gf}"),
|
||||
@@ -174,6 +188,7 @@ impl LweMultiBitFftBlindRotate<NoiseSimulationLwe, NoiseSimulationLwe, NoiseSimu
|
||||
self.output_polynomial_size(),
|
||||
self.decomp_base_log(),
|
||||
self.decomp_level_count(),
|
||||
self.mantissa_size(),
|
||||
self.modulus().as_f64(),
|
||||
),
|
||||
3 => multi_bit_pbs_variance_132_bits_security_tuniform_gf_3_fft_mul(
|
||||
@@ -182,6 +197,7 @@ impl LweMultiBitFftBlindRotate<NoiseSimulationLwe, NoiseSimulationLwe, NoiseSimu
|
||||
self.output_polynomial_size(),
|
||||
self.decomp_base_log(),
|
||||
self.decomp_level_count(),
|
||||
self.mantissa_size(),
|
||||
self.modulus().as_f64(),
|
||||
),
|
||||
4 => multi_bit_pbs_variance_132_bits_security_tuniform_gf_4_fft_mul(
|
||||
@@ -190,6 +206,7 @@ impl LweMultiBitFftBlindRotate<NoiseSimulationLwe, NoiseSimulationLwe, NoiseSimu
|
||||
self.output_polynomial_size(),
|
||||
self.decomp_base_log(),
|
||||
self.decomp_level_count(),
|
||||
self.mantissa_size(),
|
||||
self.modulus().as_f64(),
|
||||
),
|
||||
gf => panic!("Unsupported grouping factor: {gf}"),
|
||||
@@ -208,3 +225,238 @@ impl LweMultiBitFftBlindRotate<NoiseSimulationLwe, NoiseSimulationLwe, NoiseSimu
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
impl LweMultiBitFftBootstrap<NoiseSimulationLwe, NoiseSimulationLwe, NoiseSimulationGlwe>
|
||||
for NoiseSimulationLweMultiBitFourierBsk
|
||||
{
|
||||
type SideResources = ();
|
||||
|
||||
fn lwe_multi_bit_fft_bootstrap(
|
||||
&self,
|
||||
input: &NoiseSimulationLwe,
|
||||
output: &mut NoiseSimulationLwe,
|
||||
accumulator: &NoiseSimulationGlwe,
|
||||
side_resources: &mut Self::SideResources,
|
||||
) {
|
||||
// Noise-wise it is the same
|
||||
self.lwe_multi_bit_fft_blind_rotate(input, output, accumulator, side_resources);
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy)]
|
||||
pub struct NoiseSimulationLweMultiBitFourier128Bsk {
|
||||
input_lwe_dimension: LweDimension,
|
||||
output_glwe_size: GlweSize,
|
||||
output_polynomial_size: PolynomialSize,
|
||||
decomp_base_log: DecompositionBaseLog,
|
||||
decomp_level_count: DecompositionLevelCount,
|
||||
grouping_factor: LweBskGroupingFactor,
|
||||
noise_distribution: DynamicDistribution<u128>,
|
||||
modulus: NoiseSimulationModulus,
|
||||
}
|
||||
|
||||
impl NoiseSimulationLweMultiBitFourier128Bsk {
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn new(
|
||||
input_lwe_dimension: LweDimension,
|
||||
output_glwe_size: GlweSize,
|
||||
output_polynomial_size: PolynomialSize,
|
||||
decomp_base_log: DecompositionBaseLog,
|
||||
decomp_level_count: DecompositionLevelCount,
|
||||
grouping_factor: LweBskGroupingFactor,
|
||||
noise_distribution: DynamicDistribution<u128>,
|
||||
modulus: NoiseSimulationModulus,
|
||||
) -> Self {
|
||||
Self {
|
||||
input_lwe_dimension,
|
||||
output_glwe_size,
|
||||
output_polynomial_size,
|
||||
decomp_base_log,
|
||||
decomp_level_count,
|
||||
grouping_factor,
|
||||
noise_distribution,
|
||||
modulus,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn matches_actual_bsk<C: Container<Element = f64>>(
|
||||
&self,
|
||||
lwe_bsk: &Fourier128LweMultiBitBootstrapKey<C>,
|
||||
) -> bool {
|
||||
let Self {
|
||||
input_lwe_dimension,
|
||||
output_glwe_size: glwe_size,
|
||||
output_polynomial_size: polynomial_size,
|
||||
decomp_base_log,
|
||||
decomp_level_count,
|
||||
grouping_factor,
|
||||
noise_distribution: _,
|
||||
modulus: _,
|
||||
} = *self;
|
||||
|
||||
let bsk_input_lwe_dimension = lwe_bsk.input_lwe_dimension();
|
||||
let bsk_glwe_size = lwe_bsk.glwe_size();
|
||||
let bsk_polynomial_size = lwe_bsk.polynomial_size();
|
||||
let bsk_decomp_base_log = lwe_bsk.decomposition_base_log();
|
||||
let bsk_decomp_level_count = lwe_bsk.decomposition_level_count();
|
||||
let bsk_grouping_factor = lwe_bsk.grouping_factor();
|
||||
|
||||
input_lwe_dimension == bsk_input_lwe_dimension
|
||||
&& glwe_size == bsk_glwe_size
|
||||
&& polynomial_size == bsk_polynomial_size
|
||||
&& decomp_base_log == bsk_decomp_base_log
|
||||
&& decomp_level_count == bsk_decomp_level_count
|
||||
&& grouping_factor == bsk_grouping_factor
|
||||
}
|
||||
|
||||
pub fn input_lwe_dimension(&self) -> LweDimension {
|
||||
self.input_lwe_dimension
|
||||
}
|
||||
|
||||
pub fn output_glwe_size(&self) -> GlweSize {
|
||||
self.output_glwe_size
|
||||
}
|
||||
|
||||
pub fn output_polynomial_size(&self) -> PolynomialSize {
|
||||
self.output_polynomial_size
|
||||
}
|
||||
|
||||
pub fn decomp_base_log(&self) -> DecompositionBaseLog {
|
||||
self.decomp_base_log
|
||||
}
|
||||
|
||||
pub fn decomp_level_count(&self) -> DecompositionLevelCount {
|
||||
self.decomp_level_count
|
||||
}
|
||||
|
||||
pub fn grouping_factor(&self) -> LweBskGroupingFactor {
|
||||
self.grouping_factor
|
||||
}
|
||||
|
||||
pub fn noise_distribution(&self) -> DynamicDistribution<u128> {
|
||||
self.noise_distribution
|
||||
}
|
||||
|
||||
pub fn modulus(&self) -> NoiseSimulationModulus {
|
||||
self.modulus
|
||||
}
|
||||
|
||||
pub fn mantissa_size(&self) -> f64 {
|
||||
let _ = self;
|
||||
PBS_FFT_128_MANTISSA_SIZE
|
||||
}
|
||||
}
|
||||
|
||||
impl LweMultiBitFft128BlindRotate<NoiseSimulationLwe, NoiseSimulationLwe, NoiseSimulationGlwe>
|
||||
for NoiseSimulationLweMultiBitFourier128Bsk
|
||||
{
|
||||
type SideResources = ();
|
||||
|
||||
fn lwe_multi_bit_fft_128_blind_rotate(
|
||||
&self,
|
||||
input: &NoiseSimulationLwe,
|
||||
output: &mut NoiseSimulationLwe,
|
||||
accumulator: &NoiseSimulationGlwe,
|
||||
_side_resources: &mut Self::SideResources,
|
||||
) {
|
||||
assert_eq!(self.input_lwe_dimension(), input.lwe_dimension());
|
||||
assert_eq!(
|
||||
self.output_glwe_size(),
|
||||
accumulator.glwe_dimension().to_glwe_size()
|
||||
);
|
||||
assert_eq!(self.output_polynomial_size(), accumulator.polynomial_size());
|
||||
assert_eq!(self.modulus(), accumulator.modulus());
|
||||
let grouping_factor = self.grouping_factor();
|
||||
|
||||
let br_additive_variance = match self.noise_distribution() {
|
||||
DynamicDistribution::Gaussian(_) => match grouping_factor.0 {
|
||||
2 => multi_bit_pbs_variance_132_bits_security_gaussian_gf_2_fft_mul(
|
||||
self.input_lwe_dimension(),
|
||||
self.output_glwe_size().to_glwe_dimension(),
|
||||
self.output_polynomial_size(),
|
||||
self.decomp_base_log(),
|
||||
self.decomp_level_count(),
|
||||
self.mantissa_size(),
|
||||
self.modulus().as_f64(),
|
||||
),
|
||||
3 => multi_bit_pbs_variance_132_bits_security_gaussian_gf_3_fft_mul(
|
||||
self.input_lwe_dimension(),
|
||||
self.output_glwe_size().to_glwe_dimension(),
|
||||
self.output_polynomial_size(),
|
||||
self.decomp_base_log(),
|
||||
self.decomp_level_count(),
|
||||
self.mantissa_size(),
|
||||
self.modulus().as_f64(),
|
||||
),
|
||||
4 => multi_bit_pbs_variance_132_bits_security_gaussian_gf_4_fft_mul(
|
||||
self.input_lwe_dimension(),
|
||||
self.output_glwe_size().to_glwe_dimension(),
|
||||
self.output_polynomial_size(),
|
||||
self.decomp_base_log(),
|
||||
self.decomp_level_count(),
|
||||
self.mantissa_size(),
|
||||
self.modulus().as_f64(),
|
||||
),
|
||||
gf => panic!("Unsupported grouping factor: {gf}"),
|
||||
},
|
||||
DynamicDistribution::TUniform(_) => match grouping_factor.0 {
|
||||
2 => multi_bit_pbs_variance_132_bits_security_tuniform_gf_2_fft_mul(
|
||||
self.input_lwe_dimension(),
|
||||
self.output_glwe_size().to_glwe_dimension(),
|
||||
self.output_polynomial_size(),
|
||||
self.decomp_base_log(),
|
||||
self.decomp_level_count(),
|
||||
self.mantissa_size(),
|
||||
self.modulus().as_f64(),
|
||||
),
|
||||
3 => multi_bit_pbs_variance_132_bits_security_tuniform_gf_3_fft_mul(
|
||||
self.input_lwe_dimension(),
|
||||
self.output_glwe_size().to_glwe_dimension(),
|
||||
self.output_polynomial_size(),
|
||||
self.decomp_base_log(),
|
||||
self.decomp_level_count(),
|
||||
self.mantissa_size(),
|
||||
self.modulus().as_f64(),
|
||||
),
|
||||
4 => multi_bit_pbs_variance_132_bits_security_tuniform_gf_4_fft_mul(
|
||||
self.input_lwe_dimension(),
|
||||
self.output_glwe_size().to_glwe_dimension(),
|
||||
self.output_polynomial_size(),
|
||||
self.decomp_base_log(),
|
||||
self.decomp_level_count(),
|
||||
self.mantissa_size(),
|
||||
self.modulus().as_f64(),
|
||||
),
|
||||
gf => panic!("Unsupported grouping factor: {gf}"),
|
||||
},
|
||||
};
|
||||
|
||||
let output_lwe_dimension = self
|
||||
.output_glwe_size()
|
||||
.to_glwe_dimension()
|
||||
.to_equivalent_lwe_dimension(self.output_polynomial_size());
|
||||
|
||||
*output = NoiseSimulationLwe::new(
|
||||
output_lwe_dimension,
|
||||
Variance(accumulator.variance_per_occupied_slot().0 + br_additive_variance.0),
|
||||
accumulator.modulus(),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
impl LweMultiBitFft128Bootstrap<NoiseSimulationLwe, NoiseSimulationLwe, NoiseSimulationGlwe>
|
||||
for NoiseSimulationLweMultiBitFourier128Bsk
|
||||
{
|
||||
type SideResources = ();
|
||||
|
||||
fn lwe_multi_bit_fft_128_bootstrap(
|
||||
&self,
|
||||
input: &NoiseSimulationLwe,
|
||||
output: &mut NoiseSimulationLwe,
|
||||
accumulator: &NoiseSimulationGlwe,
|
||||
side_resources: &mut Self::SideResources,
|
||||
) {
|
||||
// Noise-wise it is the same
|
||||
self.lwe_multi_bit_fft_128_blind_rotate(input, output, accumulator, side_resources);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,15 +3,12 @@ use crate::core_crypto::commons::noise_formulas::lwe_programmable_bootstrap::{
|
||||
pbs_variance_132_bits_security_gaussian_fft_mul,
|
||||
pbs_variance_132_bits_security_tuniform_fft_mul,
|
||||
};
|
||||
use crate::core_crypto::commons::noise_formulas::lwe_programmable_bootstrap_128::{
|
||||
pbs_128_variance_132_bits_security_gaussian_fft_mul,
|
||||
pbs_128_variance_132_bits_security_tuniform_fft_mul,
|
||||
};
|
||||
use crate::core_crypto::commons::noise_formulas::noise_simulation::traits::{
|
||||
LweClassicFft128Bootstrap, LweClassicFftBootstrap,
|
||||
};
|
||||
use crate::core_crypto::commons::noise_formulas::noise_simulation::{
|
||||
NoiseSimulationGlwe, NoiseSimulationLwe, NoiseSimulationModulus,
|
||||
NoiseSimulationGlwe, NoiseSimulationLwe, NoiseSimulationModulus, PBS_FFT_128_MANTISSA_SIZE,
|
||||
PBS_FFT_64_MANTISSA_SIZE,
|
||||
};
|
||||
use crate::core_crypto::commons::parameters::{
|
||||
DecompositionBaseLog, DecompositionLevelCount, DynamicDistribution, GlweSize, LweDimension,
|
||||
@@ -108,6 +105,11 @@ impl NoiseSimulationLweFourierBsk {
|
||||
pub fn modulus(&self) -> NoiseSimulationModulus {
|
||||
self.modulus
|
||||
}
|
||||
|
||||
pub fn mantissa_size(&self) -> f64 {
|
||||
let _ = self;
|
||||
PBS_FFT_64_MANTISSA_SIZE
|
||||
}
|
||||
}
|
||||
|
||||
impl LweClassicFftBootstrap<NoiseSimulationLwe, NoiseSimulationLwe, NoiseSimulationGlwe>
|
||||
@@ -137,6 +139,7 @@ impl LweClassicFftBootstrap<NoiseSimulationLwe, NoiseSimulationLwe, NoiseSimulat
|
||||
self.output_polynomial_size(),
|
||||
self.decomp_base_log(),
|
||||
self.decomp_level_count(),
|
||||
self.mantissa_size(),
|
||||
self.modulus().as_f64(),
|
||||
),
|
||||
DynamicDistribution::TUniform(_) => pbs_variance_132_bits_security_tuniform_fft_mul(
|
||||
@@ -145,6 +148,7 @@ impl LweClassicFftBootstrap<NoiseSimulationLwe, NoiseSimulationLwe, NoiseSimulat
|
||||
self.output_polynomial_size(),
|
||||
self.decomp_base_log(),
|
||||
self.decomp_level_count(),
|
||||
self.mantissa_size(),
|
||||
self.modulus().as_f64(),
|
||||
),
|
||||
};
|
||||
@@ -248,6 +252,11 @@ impl NoiseSimulationLweFourier128Bsk {
|
||||
pub fn modulus(&self) -> NoiseSimulationModulus {
|
||||
self.modulus
|
||||
}
|
||||
|
||||
pub fn mantissa_size(&self) -> f64 {
|
||||
let _ = self;
|
||||
PBS_FFT_128_MANTISSA_SIZE
|
||||
}
|
||||
}
|
||||
|
||||
impl LweClassicFft128Bootstrap<NoiseSimulationLwe, NoiseSimulationLwe, NoiseSimulationGlwe>
|
||||
@@ -271,30 +280,24 @@ impl LweClassicFft128Bootstrap<NoiseSimulationLwe, NoiseSimulationLwe, NoiseSimu
|
||||
assert_eq!(self.modulus(), accumulator.modulus());
|
||||
|
||||
let br_additive_variance = match self.noise_distribution() {
|
||||
DynamicDistribution::Gaussian(_) => {
|
||||
pbs_128_variance_132_bits_security_gaussian_fft_mul(
|
||||
self.input_lwe_dimension(),
|
||||
self.output_glwe_size().to_glwe_dimension(),
|
||||
self.output_polynomial_size(),
|
||||
self.decomp_base_log(),
|
||||
self.decomp_level_count(),
|
||||
// Current PBS 128 implem has 104 bits of equivalent mantissa
|
||||
104.0f64,
|
||||
self.modulus().as_f64(),
|
||||
)
|
||||
}
|
||||
DynamicDistribution::TUniform(_) => {
|
||||
pbs_128_variance_132_bits_security_tuniform_fft_mul(
|
||||
self.input_lwe_dimension(),
|
||||
self.output_glwe_size().to_glwe_dimension(),
|
||||
self.output_polynomial_size(),
|
||||
self.decomp_base_log(),
|
||||
self.decomp_level_count(),
|
||||
// Current PBS 128 implem has 104 bits of equivalent mantissa
|
||||
104.0f64,
|
||||
self.modulus().as_f64(),
|
||||
)
|
||||
}
|
||||
DynamicDistribution::Gaussian(_) => pbs_variance_132_bits_security_gaussian_fft_mul(
|
||||
self.input_lwe_dimension(),
|
||||
self.output_glwe_size().to_glwe_dimension(),
|
||||
self.output_polynomial_size(),
|
||||
self.decomp_base_log(),
|
||||
self.decomp_level_count(),
|
||||
self.mantissa_size(),
|
||||
self.modulus().as_f64(),
|
||||
),
|
||||
DynamicDistribution::TUniform(_) => pbs_variance_132_bits_security_tuniform_fft_mul(
|
||||
self.input_lwe_dimension(),
|
||||
self.output_glwe_size().to_glwe_dimension(),
|
||||
self.output_polynomial_size(),
|
||||
self.decomp_base_log(),
|
||||
self.decomp_level_count(),
|
||||
self.mantissa_size(),
|
||||
self.modulus().as_f64(),
|
||||
),
|
||||
};
|
||||
|
||||
let output_lwe_dimension = self
|
||||
|
||||
@@ -6,6 +6,9 @@ pub mod modulus_switch;
|
||||
pub mod traits;
|
||||
|
||||
pub use lwe_keyswitch::NoiseSimulationLweKeyswitchKey;
|
||||
pub use lwe_multi_bit_programmable_bootstrap::{
|
||||
NoiseSimulationLweMultiBitFourier128Bsk, NoiseSimulationLweMultiBitFourierBsk,
|
||||
};
|
||||
pub use lwe_packing_keyswitch::NoiseSimulationLwePackingKeyswitchKey;
|
||||
pub use lwe_programmable_bootstrap::{
|
||||
NoiseSimulationLweFourier128Bsk, NoiseSimulationLweFourierBsk,
|
||||
@@ -23,6 +26,9 @@ use crate::core_crypto::commons::parameters::{
|
||||
CiphertextModulusLog, GlweDimension, GlweSize, LweDimension, PolynomialSize,
|
||||
};
|
||||
|
||||
pub const PBS_FFT_64_MANTISSA_SIZE: f64 = 53.;
|
||||
pub const PBS_FFT_128_MANTISSA_SIZE: f64 = 104.;
|
||||
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
|
||||
pub enum NoiseSimulationModulus {
|
||||
NativeU128,
|
||||
|
||||
@@ -31,3 +31,49 @@ pub trait LweMultiBitFft128BlindRotate<Input, Output, Accumulator> {
|
||||
side_resources: &mut Self::SideResources,
|
||||
);
|
||||
}
|
||||
|
||||
pub trait AllocateLweMultiBitBootstrapResult {
|
||||
type Output;
|
||||
type SideResources;
|
||||
|
||||
fn allocate_lwe_multi_bit_bootstrap_result(
|
||||
&self,
|
||||
side_resources: &mut Self::SideResources,
|
||||
) -> Self::Output;
|
||||
}
|
||||
|
||||
impl<T: AllocateLweMultiBitBlindRotateResult> AllocateLweMultiBitBootstrapResult for T {
|
||||
type Output = <T as AllocateLweMultiBitBlindRotateResult>::Output;
|
||||
type SideResources = <T as AllocateLweMultiBitBlindRotateResult>::SideResources;
|
||||
|
||||
fn allocate_lwe_multi_bit_bootstrap_result(
|
||||
&self,
|
||||
side_resources: &mut Self::SideResources,
|
||||
) -> Self::Output {
|
||||
self.allocate_lwe_multi_bit_blind_rotate_result(side_resources)
|
||||
}
|
||||
}
|
||||
|
||||
/// Multi-bit LWE bootstrap computed with the FFT backend.
///
/// `self` plays the role of the bootstrapping key; `Input`, `Output` and `Accumulator` are
/// left generic so that different ciphertext representations can implement the operation.
pub trait LweMultiBitFftBootstrap<Input, Output, Accumulator> {
    /// Auxiliary resources the bootstrap may read or update.
    type SideResources;

    /// Bootstraps `input` using `accumulator` and writes the result into `output`.
    fn lwe_multi_bit_fft_bootstrap(
        &self,
        input: &Input,
        output: &mut Output,
        accumulator: &Accumulator,
        side_resources: &mut Self::SideResources,
    );
}
|
||||
|
||||
/// Multi-bit LWE bootstrap computed with the 128-bit FFT backend.
///
/// `self` plays the role of the bootstrapping key; `Input`, `Output` and `Accumulator` are
/// left generic so that different ciphertext representations can implement the operation.
pub trait LweMultiBitFft128Bootstrap<Input, Output, Accumulator> {
    /// Auxiliary resources the bootstrap may read or update.
    type SideResources;

    /// Bootstraps `input` using `accumulator` and writes the result into `output`.
    fn lwe_multi_bit_fft_128_bootstrap(
        &self,
        input: &Input,
        output: &mut Output,
        accumulator: &Accumulator,
        side_resources: &mut Self::SideResources,
    );
}
|
||||
|
||||
@@ -9,7 +9,9 @@ pub mod scalar_mul;
|
||||
pub use add_sub::{LweUncorrelatedAdd, LweUncorrelatedSub};
|
||||
pub use lwe_keyswitch::{AllocateLweKeyswitchResult, LweKeyswitch};
|
||||
pub use lwe_multi_bit_programmable_bootstrap::{
|
||||
AllocateLweMultiBitBlindRotateResult, LweMultiBitFft128BlindRotate, LweMultiBitFftBlindRotate,
|
||||
AllocateLweMultiBitBlindRotateResult, AllocateLweMultiBitBootstrapResult,
|
||||
LweMultiBitFft128BlindRotate, LweMultiBitFft128Bootstrap, LweMultiBitFftBlindRotate,
|
||||
LweMultiBitFftBootstrap,
|
||||
};
|
||||
pub use lwe_packing_keyswitch::{AllocateLwePackingKeyswitchResult, LwePackingKeyswitch};
|
||||
pub use lwe_programmable_bootstrap::{
|
||||
|
||||
@@ -3,6 +3,7 @@ use crate::core_crypto::commons::noise_formulas::lwe_multi_bit_programmable_boot
|
||||
multi_bit_pbs_variance_132_bits_security_gaussian_gf_3_fft_mul,
|
||||
multi_bit_pbs_variance_132_bits_security_tuniform_gf_3_fft_mul,
|
||||
};
|
||||
use crate::core_crypto::commons::noise_formulas::noise_simulation::PBS_FFT_64_MANTISSA_SIZE;
|
||||
use crate::core_crypto::commons::noise_formulas::secure_noise::{
|
||||
minimal_lwe_variance_for_132_bits_security_gaussian,
|
||||
minimal_lwe_variance_for_132_bits_security_tuniform,
|
||||
@@ -58,6 +59,7 @@ where
|
||||
polynomial_size,
|
||||
pbs_decomposition_base_log,
|
||||
pbs_decomposition_level_count,
|
||||
PBS_FFT_64_MANTISSA_SIZE,
|
||||
modulus_as_f64,
|
||||
)
|
||||
}
|
||||
@@ -68,6 +70,7 @@ where
|
||||
polynomial_size,
|
||||
pbs_decomposition_base_log,
|
||||
pbs_decomposition_level_count,
|
||||
PBS_FFT_64_MANTISSA_SIZE,
|
||||
modulus_as_f64,
|
||||
)
|
||||
}
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
use super::*;
|
||||
use crate::core_crypto::commons::noise_formulas::lwe_programmable_bootstrap::pbs_variance_132_bits_security_gaussian_fft_mul;
|
||||
use crate::core_crypto::commons::noise_formulas::noise_simulation::PBS_FFT_64_MANTISSA_SIZE;
|
||||
use crate::core_crypto::commons::noise_formulas::secure_noise::minimal_lwe_variance_for_132_bits_security_gaussian;
|
||||
use crate::core_crypto::commons::test_tools::{torus_modular_diff, variance};
|
||||
use crate::core_crypto::gpu::glwe_ciphertext_list::CudaGlweCiphertextList;
|
||||
@@ -49,6 +50,7 @@ where
|
||||
polynomial_size,
|
||||
pbs_decomposition_base_log,
|
||||
pbs_decomposition_level_count,
|
||||
PBS_FFT_64_MANTISSA_SIZE,
|
||||
modulus_as_f64,
|
||||
);
|
||||
|
||||
|
||||
@@ -10424,6 +10424,7 @@ pub unsafe fn unchecked_small_scalar_mul_integer_async(
|
||||
carry_modulus.0 as u32,
|
||||
);
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
/// # Safety
|
||||
///
|
||||
@@ -10465,3 +10466,98 @@ pub unsafe fn extract_glwe_async<T: UnsignedInteger>(
|
||||
panic!("Unsupported integer size for CUDA GLWE extraction");
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
/// # Safety
|
||||
///
|
||||
/// - The data must not be moved or dropped while being used by the CUDA kernel.
|
||||
/// - This function assumes exclusive access to the passed data; violating this may lead to
|
||||
/// undefined behavior.
|
||||
pub(crate) unsafe fn cuda_backend_trivium_generate_keystream<T: UnsignedInteger, B: Numeric>(
|
||||
streams: &CudaStreams,
|
||||
keystream_output: &mut CudaRadixCiphertext,
|
||||
key: &CudaRadixCiphertext,
|
||||
iv: &CudaRadixCiphertext,
|
||||
bootstrapping_key: &CudaVec<B>,
|
||||
keyswitch_key: &CudaVec<T>,
|
||||
message_modulus: MessageModulus,
|
||||
carry_modulus: CarryModulus,
|
||||
glwe_dimension: GlweDimension,
|
||||
polynomial_size: PolynomialSize,
|
||||
lwe_dimension: LweDimension,
|
||||
ks_level: DecompositionLevelCount,
|
||||
ks_base_log: DecompositionBaseLog,
|
||||
pbs_level: DecompositionLevelCount,
|
||||
pbs_base_log: DecompositionBaseLog,
|
||||
grouping_factor: LweBskGroupingFactor,
|
||||
pbs_type: PBSType,
|
||||
ms_noise_reduction_configuration: Option<&CudaModulusSwitchNoiseReductionConfiguration>,
|
||||
num_steps: u32,
|
||||
) {
|
||||
let mut keystream_degrees = keystream_output
|
||||
.info
|
||||
.blocks
|
||||
.iter()
|
||||
.map(|b| b.degree.0)
|
||||
.collect();
|
||||
let mut keystream_noise_levels = keystream_output
|
||||
.info
|
||||
.blocks
|
||||
.iter()
|
||||
.map(|b| b.noise_level.0)
|
||||
.collect();
|
||||
let mut cuda_ffi_keystream = prepare_cuda_radix_ffi(
|
||||
keystream_output,
|
||||
&mut keystream_degrees,
|
||||
&mut keystream_noise_levels,
|
||||
);
|
||||
|
||||
let mut key_degrees = key.info.blocks.iter().map(|b| b.degree.0).collect();
|
||||
let mut key_noise_levels = key.info.blocks.iter().map(|b| b.noise_level.0).collect();
|
||||
let cuda_ffi_key = prepare_cuda_radix_ffi(key, &mut key_degrees, &mut key_noise_levels);
|
||||
|
||||
let mut iv_degrees = iv.info.blocks.iter().map(|b| b.degree.0).collect();
|
||||
let mut iv_noise_levels = iv.info.blocks.iter().map(|b| b.noise_level.0).collect();
|
||||
let cuda_ffi_iv = prepare_cuda_radix_ffi(iv, &mut iv_degrees, &mut iv_noise_levels);
|
||||
|
||||
let num_inputs = (key.info.blocks.len() / 80) as u32;
|
||||
|
||||
let noise_reduction_type = resolve_ms_noise_reduction_config(ms_noise_reduction_configuration);
|
||||
|
||||
let mut mem_ptr: *mut i8 = std::ptr::null_mut();
|
||||
|
||||
scratch_cuda_trivium_64(
|
||||
streams.ffi(),
|
||||
std::ptr::addr_of_mut!(mem_ptr),
|
||||
glwe_dimension.0 as u32,
|
||||
polynomial_size.0 as u32,
|
||||
lwe_dimension.0 as u32,
|
||||
ks_level.0 as u32,
|
||||
ks_base_log.0 as u32,
|
||||
pbs_level.0 as u32,
|
||||
pbs_base_log.0 as u32,
|
||||
grouping_factor.0 as u32,
|
||||
message_modulus.0 as u32,
|
||||
carry_modulus.0 as u32,
|
||||
pbs_type as u32,
|
||||
true,
|
||||
noise_reduction_type as u32,
|
||||
num_inputs,
|
||||
);
|
||||
|
||||
cuda_trivium_generate_keystream_64(
|
||||
streams.ffi(),
|
||||
&raw mut cuda_ffi_keystream,
|
||||
&raw const cuda_ffi_key,
|
||||
&raw const cuda_ffi_iv,
|
||||
num_inputs,
|
||||
num_steps,
|
||||
mem_ptr,
|
||||
bootstrapping_key.ptr.as_ptr(),
|
||||
keyswitch_key.ptr.as_ptr(),
|
||||
);
|
||||
|
||||
cleanup_cuda_trivium_64(streams.ffi(), std::ptr::addr_of_mut!(mem_ptr));
|
||||
|
||||
update_noise_degree(keystream_output, &cuda_ffi_keystream);
|
||||
}
|
||||
|
||||
@@ -66,6 +66,7 @@ mod tests_noise_distribution;
|
||||
mod tests_signed;
|
||||
#[cfg(test)]
|
||||
mod tests_unsigned;
|
||||
mod trivium;
|
||||
|
||||
impl CudaServerKey {
|
||||
/// Create a trivial ciphertext filled with zeros on the GPU.
|
||||
|
||||
@@ -26,7 +26,7 @@ use crate::shortint::parameters::{AtomicPatternParameters, MetaParameters, Varia
|
||||
use crate::shortint::server_key::tests::noise_distribution::br_dp_ks_ms::br_dp_ks_any_ms;
|
||||
use crate::shortint::server_key::tests::noise_distribution::should_use_single_key_debug;
|
||||
use crate::shortint::server_key::tests::noise_distribution::utils::noise_simulation::{
|
||||
NoiseSimulationGlwe, NoiseSimulationLwe, NoiseSimulationLweFourierBsk,
|
||||
NoiseSimulationGenericBootstrapKey, NoiseSimulationGlwe, NoiseSimulationLwe,
|
||||
NoiseSimulationLweKeyswitchKey, NoiseSimulationModulus,
|
||||
};
|
||||
use crate::shortint::server_key::tests::noise_distribution::utils::{
|
||||
@@ -360,7 +360,7 @@ fn noise_check_encrypt_br_dp_ks_ms_noise(params: MetaParameters) {
|
||||
let noise_simulation_modulus_switch_config =
|
||||
NoiseSimulationModulusSwitchConfig::new_from_atomic_pattern_parameters(params);
|
||||
let noise_simulation_bsk =
|
||||
NoiseSimulationLweFourierBsk::new_from_atomic_pattern_parameters(params);
|
||||
NoiseSimulationGenericBootstrapKey::new_from_atomic_pattern_parameters(params);
|
||||
let gpu_index = 0;
|
||||
let streams = CudaStreams::new_single_gpu(GpuIndex::new(gpu_index));
|
||||
|
||||
|
||||
@@ -17,7 +17,7 @@ use crate::shortint::parameters::test_params::TEST_META_PARAM_CPU_2_2_KS_PBS_PKE
|
||||
use crate::shortint::parameters::{CompressionParameters, MetaParameters, Variance};
|
||||
use crate::shortint::server_key::tests::noise_distribution::br_dp_packingks_ms::br_dp_packing_ks_ms;
|
||||
use crate::shortint::server_key::tests::noise_distribution::utils::noise_simulation::{
|
||||
NoiseSimulationGlwe, NoiseSimulationLwe, NoiseSimulationLweFourierBsk,
|
||||
NoiseSimulationGenericBootstrapKey, NoiseSimulationGlwe, NoiseSimulationLwe,
|
||||
NoiseSimulationLwePackingKeyswitchKey, NoiseSimulationModulus,
|
||||
};
|
||||
use crate::shortint::server_key::tests::noise_distribution::utils::{
|
||||
@@ -412,7 +412,7 @@ fn noise_check_encrypt_br_dp_packing_ks_ms_noise_gpu(meta_params: MetaParameters
|
||||
let cuda_compression_key = compressed_compression_key.decompress_to_cuda(&streams);
|
||||
|
||||
let noise_simulation_bsk =
|
||||
NoiseSimulationLweFourierBsk::new_from_atomic_pattern_parameters(params);
|
||||
NoiseSimulationGenericBootstrapKey::new_from_atomic_pattern_parameters(params);
|
||||
let noise_simulation_packing_key =
|
||||
NoiseSimulationLwePackingKeyswitchKey::new_from_comp_parameters(params, comp_params);
|
||||
|
||||
@@ -591,7 +591,7 @@ fn noise_check_encrypt_br_dp_packing_ks_ms_pfail_gpu(meta_params: MetaParameters
|
||||
assert_eq!(original_carry_modulus.0, 4);
|
||||
|
||||
let noise_simulation_bsk =
|
||||
NoiseSimulationLweFourierBsk::new_from_atomic_pattern_parameters(params);
|
||||
NoiseSimulationGenericBootstrapKey::new_from_atomic_pattern_parameters(params);
|
||||
let noise_simulation_packing_key =
|
||||
NoiseSimulationLwePackingKeyswitchKey::new_from_comp_parameters(params, comp_params);
|
||||
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
use super::utils::noise_simulation::{CudaDynLwe, CudaSideResources};
|
||||
use crate::core_crypto::commons::noise_formulas::noise_simulation::{
|
||||
NoiseSimulationLweFourier128Bsk, NoiseSimulationLwePackingKeyswitchKey,
|
||||
};
|
||||
use crate::core_crypto::commons::noise_formulas::noise_simulation::NoiseSimulationLwePackingKeyswitchKey;
|
||||
use crate::core_crypto::gpu::glwe_ciphertext_list::CudaGlweCiphertextList;
|
||||
use crate::core_crypto::gpu::CudaStreams;
|
||||
use crate::core_crypto::prelude::{GlweCiphertext, LweCiphertextCount};
|
||||
@@ -28,8 +26,9 @@ use crate::shortint::server_key::tests::noise_distribution::dp_ks_pbs128_packing
|
||||
};
|
||||
use crate::shortint::server_key::tests::noise_distribution::should_use_single_key_debug;
|
||||
use crate::shortint::server_key::tests::noise_distribution::utils::noise_simulation::{
|
||||
NoiseSimulationGlwe, NoiseSimulationLwe, NoiseSimulationLweFourierBsk,
|
||||
NoiseSimulationLweKeyswitchKey, NoiseSimulationModulusSwitchConfig,
|
||||
NoiseSimulationGenericBootstrapKey128, NoiseSimulationGlwe, NoiseSimulationLwe,
|
||||
NoiseSimulationLweFourierBsk, NoiseSimulationLweKeyswitchKey,
|
||||
NoiseSimulationModulusSwitchConfig,
|
||||
};
|
||||
use crate::shortint::server_key::tests::noise_distribution::utils::{
|
||||
mean_and_variance_check, DecryptionAndNoiseResult, NoiseSample,
|
||||
@@ -718,8 +717,10 @@ fn noise_check_encrypt_dp_ks_standard_pbs128_packing_ks_noise_gpu(meta_params: M
|
||||
NoiseSimulationLweFourierBsk::new_from_atomic_pattern_parameters(atomic_params);
|
||||
let noise_simulation_modulus_switch_config =
|
||||
NoiseSimulationModulusSwitchConfig::new_from_atomic_pattern_parameters(atomic_params);
|
||||
let noise_simulation_bsk128 =
|
||||
NoiseSimulationLweFourier128Bsk::new_from_parameters(atomic_params, noise_squashing_params);
|
||||
let noise_simulation_bsk128 = NoiseSimulationGenericBootstrapKey128::new_from_parameters(
|
||||
atomic_params,
|
||||
noise_squashing_params,
|
||||
);
|
||||
let noise_simulation_packing_key =
|
||||
NoiseSimulationLwePackingKeyswitchKey::new_from_noise_squashing_parameters(
|
||||
noise_squashing_params,
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
use crate::core_crypto::commons::noise_formulas::noise_simulation::traits::{
|
||||
AllocateCenteredBinaryShiftedStandardModSwitchResult,
|
||||
AllocateDriftTechniqueStandardModSwitchResult, AllocateLweBootstrapResult,
|
||||
AllocateLweKeyswitchResult, AllocateLwePackingKeyswitchResult, AllocateStandardModSwitchResult,
|
||||
CenteredBinaryShiftedStandardModSwitch, DriftTechniqueStandardModSwitch,
|
||||
LweClassicFftBootstrap, LweKeyswitch, ScalarMul, StandardModSwitch,
|
||||
AllocateLweKeyswitchResult, AllocateLwePackingKeyswitchResult, AllocateMultiBitModSwitchResult,
|
||||
AllocateStandardModSwitchResult, CenteredBinaryShiftedStandardModSwitch,
|
||||
DriftTechniqueStandardModSwitch, LweClassicFft128Bootstrap, LweClassicFftBootstrap,
|
||||
LweKeyswitch, MultiBitModSwitch, ScalarMul, StandardModSwitch,
|
||||
};
|
||||
use crate::core_crypto::commons::noise_formulas::noise_simulation::{
|
||||
NoiseSimulationLweFourier128Bsk, NoiseSimulationLweFourierBsk,
|
||||
@@ -25,8 +26,12 @@ use crate::integer::gpu::server_key::{
|
||||
use crate::integer::gpu::{
|
||||
cuda_centered_modulus_switch_64, unchecked_small_scalar_mul_integer_async, CudaStreams,
|
||||
};
|
||||
use crate::shortint::server_key::tests::noise_distribution::utils::noise_simulation::NoiseSimulationModulusSwitchConfig;
|
||||
use crate::shortint::server_key::tests::noise_distribution::utils::traits::LwePackingKeyswitch;
|
||||
use crate::shortint::server_key::tests::noise_distribution::utils::noise_simulation::{
|
||||
NoiseSimulationGenericBootstrapKey, NoiseSimulationModulusSwitchConfig,
|
||||
};
|
||||
use crate::shortint::server_key::tests::noise_distribution::utils::traits::{
|
||||
LweGenericBlindRotate128, LweGenericBootstrap, LwePackingKeyswitch,
|
||||
};
|
||||
/// Side resources for CUDA operations in noise simulation
|
||||
#[derive(Clone)]
|
||||
pub struct CudaSideResources {
|
||||
@@ -228,19 +233,22 @@ impl NoiseSimulationLweFourierBsk {
|
||||
&& decomp_base_log == bsk_decomp_base_log
|
||||
&& decomp_level_count == bsk_decomp_level_count
|
||||
}
|
||||
CudaBootstrappingKey::MultiBit(cuda_mb_bsk) => {
|
||||
let bsk_input_lwe_dimension = cuda_mb_bsk.input_lwe_dimension();
|
||||
let bsk_glwe_size = cuda_mb_bsk.glwe_dimension().to_glwe_size();
|
||||
let bsk_polynomial_size = cuda_mb_bsk.polynomial_size();
|
||||
let bsk_decomp_base_log = cuda_mb_bsk.decomp_base_log();
|
||||
let bsk_decomp_level_count = cuda_mb_bsk.decomp_level_count();
|
||||
// MultiBit key cannot match classic key
|
||||
CudaBootstrappingKey::MultiBit(_) => false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
input_lwe_dimension == bsk_input_lwe_dimension
|
||||
&& glwe_size == bsk_glwe_size
|
||||
&& polynomial_size == bsk_polynomial_size
|
||||
&& decomp_base_log == bsk_decomp_base_log
|
||||
&& decomp_level_count == bsk_decomp_level_count
|
||||
// Extensions for NoiseSimulationGenericBootstrapKey to support GPU operations
|
||||
impl NoiseSimulationGenericBootstrapKey {
|
||||
pub fn matches_actual_bsk_gpu(&self, lwe_bsk: &CudaBootstrappingKey<u64>) -> bool {
|
||||
match self {
|
||||
Self::Classic(noise_simulation_lwe_fourier_bsk) => {
|
||||
noise_simulation_lwe_fourier_bsk.matches_actual_bsk_gpu(lwe_bsk)
|
||||
}
|
||||
Self::MultiBit(_) => todo!(
|
||||
"Implement the matching for NoiseSimulationLweMultiBitFourierBsk and forward here"
|
||||
),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -743,12 +751,8 @@ impl AllocateLweBootstrapResult for CudaGlweCiphertextList<u128> {
|
||||
}
|
||||
|
||||
// Implement LweClassicFft128Bootstrap for CudaNoiseSquashingKey using 128-bit PBS CUDA function
|
||||
impl
|
||||
crate::core_crypto::commons::noise_formulas::noise_simulation::traits::LweClassicFft128Bootstrap<
|
||||
CudaDynLwe,
|
||||
CudaDynLwe,
|
||||
CudaGlweCiphertextList<u128>,
|
||||
> for crate::integer::gpu::noise_squashing::keys::CudaNoiseSquashingKey
|
||||
impl LweClassicFft128Bootstrap<CudaDynLwe, CudaDynLwe, CudaGlweCiphertextList<u128>>
|
||||
for CudaNoiseSquashingKey
|
||||
{
|
||||
type SideResources = CudaSideResources;
|
||||
|
||||
@@ -931,3 +935,83 @@ impl LwePackingKeyswitch<[&CudaDynLwe], CudaGlweCiphertextList<u128>>
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// Multi bit and generic extensions
|
||||
impl LweGenericBootstrap<CudaDynLwe, CudaDynLwe, CudaGlweCiphertextList<u64>> for CudaServerKey {
|
||||
type SideResources = CudaSideResources;
|
||||
|
||||
fn lwe_generic_bootstrap(
|
||||
&self,
|
||||
input: &CudaDynLwe,
|
||||
output: &mut CudaDynLwe,
|
||||
accumulator: &CudaGlweCiphertextList<u64>,
|
||||
side_resources: &mut Self::SideResources,
|
||||
) {
|
||||
match self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(_) => {
|
||||
self.lwe_classic_fft_pbs(input, output, accumulator, side_resources);
|
||||
}
|
||||
CudaBootstrappingKey::MultiBit(_) => {
|
||||
todo!("TODO: this currently only manages classic PBS")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl AllocateMultiBitModSwitchResult for CudaDynLwe {
|
||||
type Output = Self;
|
||||
type SideResources = CudaSideResources;
|
||||
|
||||
fn allocate_multi_bit_mod_switch_result(
|
||||
&self,
|
||||
_side_resources: &mut Self::SideResources,
|
||||
) -> Self::Output {
|
||||
todo!(
|
||||
"TODO: the output type likely needs to be a specialized enum CudaDynModSwitchedLwe\n\
|
||||
See shortint CPU impls, the standard mod switch results likely \
|
||||
need an update for the output type"
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl MultiBitModSwitch<Self> for CudaDynLwe {
|
||||
type SideResources = CudaSideResources;
|
||||
|
||||
fn multi_bit_mod_switch(
|
||||
&self,
|
||||
_grouping_factor: LweBskGroupingFactor,
|
||||
_output_modulus_log: CiphertextModulusLog,
|
||||
_output: &mut Self,
|
||||
_side_resources: &mut Self::SideResources,
|
||||
) {
|
||||
todo!(
|
||||
"TODO: the output type likely needs to be a specialized enum CudaDynModSwitchedLwe\n\
|
||||
See shortint CPU impls, the standard mod switch results likely \
|
||||
need an update for the output type"
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl LweGenericBlindRotate128<CudaDynLwe, CudaDynLwe, CudaGlweCiphertextList<u128>>
|
||||
for CudaNoiseSquashingKey
|
||||
{
|
||||
type SideResources = CudaSideResources;
|
||||
|
||||
fn lwe_generic_blind_rotate_128(
|
||||
&self,
|
||||
input: &CudaDynLwe,
|
||||
output: &mut CudaDynLwe,
|
||||
accumulator: &CudaGlweCiphertextList<u128>,
|
||||
side_resources: &mut Self::SideResources,
|
||||
) {
|
||||
match self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(_) => {
|
||||
self.lwe_classic_fft_128_pbs(input, output, accumulator, side_resources)
|
||||
}
|
||||
CudaBootstrappingKey::MultiBit(_) => todo!(
|
||||
"CPU manages this by taking a modswitched type to be able to apply \
|
||||
the blind rotate correctly without redoing the modswitch, to adapt for the GPU case"
|
||||
),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -20,6 +20,7 @@ pub(crate) mod test_scalar_shift;
|
||||
pub(crate) mod test_scalar_sub;
|
||||
pub(crate) mod test_shift;
|
||||
pub(crate) mod test_sub;
|
||||
pub(crate) mod test_trivium;
|
||||
pub(crate) mod test_vector_comparisons;
|
||||
pub(crate) mod test_vector_find;
|
||||
|
||||
@@ -85,6 +86,50 @@ impl<F> GpuFunctionExecutor<F> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, F>
|
||||
FunctionExecutor<
|
||||
(&'a RadixCiphertext, &'a RadixCiphertext, usize),
|
||||
crate::Result<RadixCiphertext>,
|
||||
> for GpuFunctionExecutor<F>
|
||||
where
|
||||
F: Fn(
|
||||
&CudaServerKey,
|
||||
&CudaUnsignedRadixCiphertext,
|
||||
&CudaUnsignedRadixCiphertext,
|
||||
usize,
|
||||
&CudaStreams,
|
||||
) -> crate::Result<CudaUnsignedRadixCiphertext>,
|
||||
{
|
||||
fn setup(&mut self, cks: &RadixClientKey, sks: Arc<ServerKey>) {
|
||||
self.setup_from_keys(cks, &sks);
|
||||
}
|
||||
|
||||
fn execute(
|
||||
&mut self,
|
||||
input: (&'a RadixCiphertext, &'a RadixCiphertext, usize),
|
||||
) -> crate::Result<RadixCiphertext> {
|
||||
let context = self
|
||||
.context
|
||||
.as_ref()
|
||||
.expect("setup was not properly called");
|
||||
|
||||
let d_ctxt_1 =
|
||||
CudaUnsignedRadixCiphertext::from_radix_ciphertext(input.0, &context.streams);
|
||||
let d_ctxt_2 =
|
||||
CudaUnsignedRadixCiphertext::from_radix_ciphertext(input.1, &context.streams);
|
||||
|
||||
let gpu_result = (self.func)(
|
||||
&context.sks,
|
||||
&d_ctxt_1,
|
||||
&d_ctxt_2,
|
||||
input.2,
|
||||
&context.streams,
|
||||
)?;
|
||||
|
||||
Ok(gpu_result.to_radix_ciphertext(&context.streams))
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, F>
|
||||
FunctionExecutor<
|
||||
(&'a RadixCiphertext, &'a RadixCiphertext, u128, usize, usize),
|
||||
|
||||
@@ -0,0 +1,71 @@
|
||||
use crate::integer::gpu::server_key::radix::tests_unsigned::{
|
||||
create_gpu_parameterized_test, GpuFunctionExecutor,
|
||||
};
|
||||
use crate::integer::gpu::CudaServerKey;
|
||||
use crate::integer::server_key::radix_parallel::tests_unsigned::test_trivium::{
|
||||
trivium_comparison_test, trivium_test_vector_1_test, trivium_test_vector_2_test,
|
||||
trivium_test_vector_3_test, trivium_test_vector_4_test,
|
||||
};
|
||||
use crate::shortint::parameters::{
|
||||
TestParameters, PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128,
|
||||
};
|
||||
|
||||
create_gpu_parameterized_test!(integer_trivium_test_vector_1 {
|
||||
PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128
|
||||
});
|
||||
|
||||
create_gpu_parameterized_test!(integer_trivium_test_vector_2 {
|
||||
PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128
|
||||
});
|
||||
|
||||
create_gpu_parameterized_test!(integer_trivium_test_vector_3 {
|
||||
PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128
|
||||
});
|
||||
|
||||
create_gpu_parameterized_test!(integer_trivium_test_vector_4 {
|
||||
PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128
|
||||
});
|
||||
|
||||
create_gpu_parameterized_test!(integer_trivium_comparison {
|
||||
PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128
|
||||
});
|
||||
|
||||
fn integer_trivium_test_vector_1<P>(param: P)
|
||||
where
|
||||
P: Into<TestParameters>,
|
||||
{
|
||||
let executor = GpuFunctionExecutor::new(&CudaServerKey::trivium_generate_keystream);
|
||||
trivium_test_vector_1_test(param, executor);
|
||||
}
|
||||
|
||||
fn integer_trivium_test_vector_2<P>(param: P)
|
||||
where
|
||||
P: Into<TestParameters>,
|
||||
{
|
||||
let executor = GpuFunctionExecutor::new(&CudaServerKey::trivium_generate_keystream);
|
||||
trivium_test_vector_2_test(param, executor);
|
||||
}
|
||||
|
||||
fn integer_trivium_test_vector_3<P>(param: P)
|
||||
where
|
||||
P: Into<TestParameters>,
|
||||
{
|
||||
let executor = GpuFunctionExecutor::new(&CudaServerKey::trivium_generate_keystream);
|
||||
trivium_test_vector_3_test(param, executor);
|
||||
}
|
||||
|
||||
fn integer_trivium_test_vector_4<P>(param: P)
|
||||
where
|
||||
P: Into<TestParameters>,
|
||||
{
|
||||
let executor = GpuFunctionExecutor::new(&CudaServerKey::trivium_generate_keystream);
|
||||
trivium_test_vector_4_test(param, executor);
|
||||
}
|
||||
|
||||
fn integer_trivium_comparison<P>(param: P)
|
||||
where
|
||||
P: Into<TestParameters>,
|
||||
{
|
||||
let executor = GpuFunctionExecutor::new(&CudaServerKey::trivium_generate_keystream);
|
||||
trivium_comparison_test(param, executor);
|
||||
}
|
||||
111
tfhe/src/integer/gpu/server_key/radix/trivium.rs
Normal file
111
tfhe/src/integer/gpu/server_key/radix/trivium.rs
Normal file
@@ -0,0 +1,111 @@
|
||||
use crate::core_crypto::gpu::CudaStreams;
|
||||
use crate::integer::gpu::ciphertext::{CudaIntegerRadixCiphertext, CudaUnsignedRadixCiphertext};
|
||||
use crate::integer::gpu::server_key::{
|
||||
CudaBootstrappingKey, CudaDynamicKeyswitchingKey, CudaServerKey,
|
||||
};
|
||||
use crate::integer::gpu::{cuda_backend_trivium_generate_keystream, LweBskGroupingFactor, PBSType};
|
||||
|
||||
impl CudaServerKey {
|
||||
/// Generates a Trivium keystream homomorphically on the GPU.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `key` - The encrypted secret key.
|
||||
/// * `iv` - The encrypted initialization vector.
|
||||
/// * `num_steps` - The number of keystream bits to generate per input.
|
||||
/// * `streams` - The CUDA streams to use for execution.
|
||||
pub fn trivium_generate_keystream(
|
||||
&self,
|
||||
key: &CudaUnsignedRadixCiphertext,
|
||||
iv: &CudaUnsignedRadixCiphertext,
|
||||
num_steps: usize,
|
||||
streams: &CudaStreams,
|
||||
) -> crate::Result<CudaUnsignedRadixCiphertext> {
|
||||
let num_key_bits = 80;
|
||||
let num_iv_bits = 80;
|
||||
let batch_size = 64;
|
||||
|
||||
if key.as_ref().d_blocks.lwe_ciphertext_count().0 != num_key_bits {
|
||||
return Err(format!(
|
||||
"Input key must contain {} encrypted bits, but contains {}",
|
||||
num_key_bits,
|
||||
key.as_ref().d_blocks.lwe_ciphertext_count().0
|
||||
)
|
||||
.into());
|
||||
}
|
||||
if iv.as_ref().d_blocks.lwe_ciphertext_count().0 != num_iv_bits {
|
||||
return Err(format!(
|
||||
"Input IV must contain {} encrypted bits, but contains {}",
|
||||
num_iv_bits,
|
||||
iv.as_ref().d_blocks.lwe_ciphertext_count().0
|
||||
)
|
||||
.into());
|
||||
}
|
||||
|
||||
if !num_steps.is_multiple_of(batch_size) {
|
||||
return Err(format!(
|
||||
"The number of steps must be a multiple of {batch_size}, but is {num_steps}"
|
||||
)
|
||||
.into());
|
||||
}
|
||||
|
||||
let num_output_bits = num_steps;
|
||||
let mut keystream: CudaUnsignedRadixCiphertext =
|
||||
self.create_trivial_zero_radix(num_output_bits, streams);
|
||||
|
||||
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
|
||||
panic!("Only the standard atomic pattern is supported on GPU")
|
||||
};
|
||||
|
||||
unsafe {
|
||||
match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => {
|
||||
cuda_backend_trivium_generate_keystream(
|
||||
streams,
|
||||
keystream.as_mut(),
|
||||
key.as_ref(),
|
||||
iv.as_ref(),
|
||||
&d_bsk.d_vec,
|
||||
&computing_ks_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_bsk.glwe_dimension,
|
||||
d_bsk.polynomial_size,
|
||||
d_bsk.input_lwe_dimension,
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count,
|
||||
d_bsk.decomp_base_log,
|
||||
LweBskGroupingFactor(0),
|
||||
PBSType::Classical,
|
||||
d_bsk.ms_noise_reduction_configuration.as_ref(),
|
||||
num_steps as u32,
|
||||
);
|
||||
}
|
||||
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
|
||||
cuda_backend_trivium_generate_keystream(
|
||||
streams,
|
||||
keystream.as_mut(),
|
||||
key.as_ref(),
|
||||
iv.as_ref(),
|
||||
&d_multibit_bsk.d_vec,
|
||||
&computing_ks_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_multibit_bsk.glwe_dimension,
|
||||
d_multibit_bsk.polynomial_size,
|
||||
d_multibit_bsk.input_lwe_dimension,
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count,
|
||||
d_multibit_bsk.decomp_base_log,
|
||||
d_multibit_bsk.grouping_factor,
|
||||
PBSType::MultiBit,
|
||||
None,
|
||||
num_steps as u32,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(keystream)
|
||||
}
|
||||
}
|
||||
@@ -28,6 +28,8 @@ pub(crate) mod test_shift;
|
||||
pub(crate) mod test_slice;
|
||||
pub(crate) mod test_sub;
|
||||
pub(crate) mod test_sum;
|
||||
#[cfg(feature = "gpu")]
|
||||
pub(crate) mod test_trivium;
|
||||
pub(crate) mod test_vector_comparisons;
|
||||
pub(crate) mod test_vector_find;
|
||||
|
||||
|
||||
@@ -0,0 +1,306 @@
|
||||
use crate::integer::keycache::KEY_CACHE;
|
||||
use crate::integer::server_key::radix_parallel::tests_cases_unsigned::FunctionExecutor;
|
||||
use crate::integer::{IntegerKeyKind, RadixCiphertext, RadixClientKey};
|
||||
use crate::shortint::parameters::TestParameters;
|
||||
use rand::Rng;
|
||||
use std::sync::Arc;
|
||||
|
||||
fn encrypt_bits(cks: &RadixClientKey, bits: &[u64]) -> RadixCiphertext {
|
||||
RadixCiphertext::from(
|
||||
bits.iter()
|
||||
.map(|&bit| cks.encrypt_one_block(bit))
|
||||
.collect::<Vec<_>>(),
|
||||
)
|
||||
}
|
||||
|
||||
/// Decrypts every block of `ct` independently and returns the decrypted values, in block
/// order, narrowed to `u8`.
fn decrypt_bits(cks: &RadixClientKey, ct: &RadixCiphertext) -> Vec<u8> {
    let mut bits = Vec::with_capacity(ct.blocks.len());
    for block in ct.blocks.iter() {
        bits.push(cks.decrypt_one_block(block) as u8);
    }
    bits
}
|
||||
|
||||
/// Clear (non homomorphic) reference implementation of the Trivium stream cipher, used to
/// check the encrypted keystream.
///
/// The 288-bit state is kept as three shift registers holding one bit per `u8`; new bits enter
/// each register at index 0 and the oldest bit leaves from the end.
struct TriviumRef {
    /// 93-bit register, initially loaded with the key.
    a: Vec<u8>,
    /// 84-bit register, initially loaded with the IV.
    b: Vec<u8>,
    /// 111-bit register, initially zero except for its last three bits.
    c: Vec<u8>,
}

impl TriviumRef {
    /// Number of key bits and of IV bits consumed by the cipher.
    const INPUT_BITS: usize = 80;
    /// Number of warm-up rounds run by [`Self::new`] (4 * 288).
    const WARMUP_ROUNDS: usize = 18 * 64;

    /// Builds the cipher state from `key` and `iv` (one bit per element) and runs the warm-up
    /// rounds, discarding their output.
    ///
    /// Only the first 80 elements of each slice are used; they are loaded in reverse order.
    ///
    /// # Panics
    ///
    /// Panics if `key` or `iv` holds fewer than 80 elements.
    fn new(key: &[u8], iv: &[u8]) -> Self {
        assert!(
            key.len() >= Self::INPUT_BITS,
            "Trivium key must hold at least 80 bits, got {}",
            key.len()
        );
        assert!(
            iv.len() >= Self::INPUT_BITS,
            "Trivium IV must hold at least 80 bits, got {}",
            iv.len()
        );

        let mut a = vec![0u8; 93];
        let mut b = vec![0u8; 84];
        let mut c = vec![0u8; 111];

        // a[i] = key[79 - i] and b[i] = iv[79 - i] for i in 0..80.
        for (dst, &bit) in a.iter_mut().zip(key[..Self::INPUT_BITS].iter().rev()) {
            *dst = bit;
        }
        for (dst, &bit) in b.iter_mut().zip(iv[..Self::INPUT_BITS].iter().rev()) {
            *dst = bit;
        }

        c[108..].fill(1);

        let mut triv = Self { a, b, c };
        for _ in 0..Self::WARMUP_ROUNDS {
            triv.next();
        }
        triv
    }

    /// Shifts `register` by one position towards its end, dropping the last bit and storing
    /// `bit` at index 0.
    fn shift_in(register: &mut [u8], bit: u8) {
        register.rotate_right(1);
        register[0] = bit;
    }

    /// Advances the cipher by one round and returns the keystream bit it produced.
    fn next(&mut self) -> u8 {
        let t1 = self.a[65] ^ self.a[92];
        let t2 = self.b[68] ^ self.b[83];
        let t3 = self.c[65] ^ self.c[110];

        let out = t1 ^ t2 ^ t3;

        // All feedback bits are computed from the current state before any register moves.
        let a_in = t3 ^ self.a[68] ^ (self.c[108] & self.c[109]);
        let b_in = t1 ^ self.b[77] ^ (self.a[90] & self.a[91]);
        let c_in = t2 ^ self.c[86] ^ (self.b[81] & self.b[82]);

        Self::shift_in(&mut self.a, a_in);
        Self::shift_in(&mut self.b, b_in);
        Self::shift_in(&mut self.c, c_in);

        out
    }
}
|
||||
|
||||
/// Checks the clear reference implementation against the first 64 keystream bits expected for
/// an all-zero key and IV.
#[test]
fn test_trivium_ref_consistency() {
    let zero_bits = vec![0u8; 80];
    let mut trivium = TriviumRef::new(&zero_bits, &zero_bits);

    let output_bits: Vec<u8> = std::iter::repeat_with(|| trivium.next())
        .take(64)
        .collect();
    let packed = get_hexadecimal_string_from_lsb_first_stream(&output_bits);

    let expected_hex = "FBE0BF265859051B";
    assert_eq!(&packed[0..16], expected_hex);
}
|
||||
|
||||
/// Packs a stream of bits (one per element, least significant bit of each byte first) into an
/// uppercase hexadecimal string, two digits per byte, high nibble first.
///
/// An element counts as a set bit only when it is exactly `1`; any other value is read as `0`.
///
/// # Panics
///
/// Panics if the length of `a` is not a multiple of 8.
fn get_hexadecimal_string_from_lsb_first_stream(a: &[u8]) -> String {
    const HEX_DIGITS: &[u8; 16] = b"0123456789ABCDEF";

    let bytes = a.chunks_exact(8);
    assert!(
        bytes.remainder().is_empty(),
        "bit stream length must be a multiple of 8, got {}",
        a.len()
    );

    // Two hexadecimal digits per group of 8 bits.
    let mut hexadecimal = String::with_capacity(a.len() / 4);
    for bits in bytes {
        // The element at position `pos` of the group is bit `pos` of the byte.
        let byte = bits
            .iter()
            .enumerate()
            .fold(0u8, |acc, (pos, &bit)| acc | (u8::from(bit == 1) << pos));

        hexadecimal.push(char::from(HEX_DIGITS[usize::from(byte >> 4)]));
        hexadecimal.push(char::from(HEX_DIGITS[usize::from(byte & 0x0F)]));
    }
    hexadecimal
}
|
||||
|
||||
pub fn trivium_test_vector_1_test<P, E>(param: P, mut executor: E)
|
||||
where
|
||||
P: Into<TestParameters>,
|
||||
E: for<'a> FunctionExecutor<
|
||||
(&'a RadixCiphertext, &'a RadixCiphertext, usize),
|
||||
crate::Result<RadixCiphertext>,
|
||||
>,
|
||||
{
|
||||
let param = param.into();
|
||||
let (cks, sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
|
||||
let cks = RadixClientKey::from((cks, 1));
|
||||
let sks = Arc::new(sks);
|
||||
executor.setup(&cks, sks);
|
||||
|
||||
let key = vec![0u64; 80];
|
||||
let iv = vec![0u64; 80];
|
||||
|
||||
let expected_output_0_63 = "FBE0BF265859051B517A2E4E239FC97F563203161907CF2DE7A8790FA1B2E9CDF75292030268B7382B4C1A759AA2599A285549986E74805903801A4CB5A5D4F2";
|
||||
|
||||
let ct_key = encrypt_bits(&cks, &key);
|
||||
let ct_iv = encrypt_bits(&cks, &iv);
|
||||
|
||||
let num_steps = 512;
|
||||
let output_radix = executor.execute((&ct_key, &ct_iv, num_steps)).unwrap();
|
||||
|
||||
let decrypted_bits = decrypt_bits(&cks, &output_radix);
|
||||
let hex_string = get_hexadecimal_string_from_lsb_first_stream(&decrypted_bits);
|
||||
|
||||
assert_eq!(expected_output_0_63, &hex_string[0..64 * 2]);
|
||||
}
|
||||
|
||||
pub fn trivium_test_vector_2_test<P, E>(param: P, mut executor: E)
|
||||
where
|
||||
P: Into<TestParameters>,
|
||||
E: for<'a> FunctionExecutor<
|
||||
(&'a RadixCiphertext, &'a RadixCiphertext, usize),
|
||||
crate::Result<RadixCiphertext>,
|
||||
>,
|
||||
{
|
||||
let param = param.into();
|
||||
let (cks, sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
|
||||
let cks = RadixClientKey::from((cks, 1));
|
||||
let sks = Arc::new(sks);
|
||||
executor.setup(&cks, sks);
|
||||
|
||||
let mut key = vec![0u64; 80];
|
||||
let iv = vec![0u64; 80];
|
||||
key[7] = 1;
|
||||
|
||||
let expected_output_0_63 = "38EB86FF730D7A9CAF8DF13A4420540DBB7B651464C87501552041C249F29A64D2FBF515610921EBE06C8F92CECF7F8098FF20CCCC6A62B97BE8EF7454FC80F9";
|
||||
|
||||
let ct_key = encrypt_bits(&cks, &key);
|
||||
let ct_iv = encrypt_bits(&cks, &iv);
|
||||
|
||||
let num_steps = 512;
|
||||
let output_radix = executor.execute((&ct_key, &ct_iv, num_steps)).unwrap();
|
||||
|
||||
let decrypted_bits = decrypt_bits(&cks, &output_radix);
|
||||
let hex_string = get_hexadecimal_string_from_lsb_first_stream(&decrypted_bits);
|
||||
|
||||
assert_eq!(expected_output_0_63, &hex_string[0..64 * 2]);
|
||||
}
|
||||
|
||||
pub fn trivium_test_vector_3_test<P, E>(param: P, mut executor: E)
|
||||
where
|
||||
P: Into<TestParameters>,
|
||||
E: for<'a> FunctionExecutor<
|
||||
(&'a RadixCiphertext, &'a RadixCiphertext, usize),
|
||||
crate::Result<RadixCiphertext>,
|
||||
>,
|
||||
{
|
||||
let param = param.into();
|
||||
let (cks, sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
|
||||
let cks = RadixClientKey::from((cks, 1));
|
||||
let sks = Arc::new(sks);
|
||||
executor.setup(&cks, sks);
|
||||
|
||||
let key = vec![0u64; 80];
|
||||
let mut iv = vec![0u64; 80];
|
||||
iv[7] = 1;
|
||||
|
||||
let expected_output_0_63 = "F8901736640549E3BA7D42EA2D07B9F49233C18D773008BD755585B1A8CBAB86C1E9A9B91F1AD33483FD6EE3696D659C9374260456A36AAE11F033A519CBD5D7";
|
||||
|
||||
let ct_key = encrypt_bits(&cks, &key);
|
||||
let ct_iv = encrypt_bits(&cks, &iv);
|
||||
|
||||
let num_steps = 512;
|
||||
let output_radix = executor.execute((&ct_key, &ct_iv, num_steps)).unwrap();
|
||||
|
||||
let decrypted_bits = decrypt_bits(&cks, &output_radix);
|
||||
let hex_string = get_hexadecimal_string_from_lsb_first_stream(&decrypted_bits);
|
||||
|
||||
assert_eq!(expected_output_0_63, &hex_string[0..64 * 2]);
|
||||
}
|
||||
|
||||
pub fn trivium_test_vector_4_test<P, E>(param: P, mut executor: E)
|
||||
where
|
||||
P: Into<TestParameters>,
|
||||
E: for<'a> FunctionExecutor<
|
||||
(&'a RadixCiphertext, &'a RadixCiphertext, usize),
|
||||
crate::Result<RadixCiphertext>,
|
||||
>,
|
||||
{
|
||||
let param = param.into();
|
||||
let (cks, sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
|
||||
let cks = RadixClientKey::from((cks, 1));
|
||||
let sks = Arc::new(sks);
|
||||
executor.setup(&cks, sks);
|
||||
|
||||
let key_string = "0053A6F94C9FF24598EB";
|
||||
let mut key = vec![0u64; 80];
|
||||
for i in (0..key_string.len()).step_by(2) {
|
||||
let mut val = u8::from_str_radix(&key_string[i..i + 2], 16).unwrap();
|
||||
for j in 0..8 {
|
||||
key[8 * (i >> 1) + j] = (val % 2) as u64;
|
||||
val >>= 1;
|
||||
}
|
||||
}
|
||||
|
||||
let iv_string = "0D74DB42A91077DE45AC";
|
||||
let mut iv = vec![0u64; 80];
|
||||
for i in (0..iv_string.len()).step_by(2) {
|
||||
let mut val = u8::from_str_radix(&iv_string[i..i + 2], 16).unwrap();
|
||||
for j in 0..8 {
|
||||
iv[8 * (i >> 1) + j] = (val % 2) as u64;
|
||||
val >>= 1;
|
||||
}
|
||||
}
|
||||
|
||||
let expected_output_0_63 = "F4CD954A717F26A7D6930830C4E7CF0819F80E03F25F342C64ADC66ABA7F8A8E6EAA49F23632AE3CD41A7BD290A0132F81C6D4043B6E397D7388F3A03B5FE358";
|
||||
|
||||
let ct_key = encrypt_bits(&cks, &key);
|
||||
let ct_iv = encrypt_bits(&cks, &iv);
|
||||
|
||||
let num_steps = 512;
|
||||
let output_radix = executor.execute((&ct_key, &ct_iv, num_steps)).unwrap();
|
||||
|
||||
let decrypted_bits = decrypt_bits(&cks, &output_radix);
|
||||
let hex_string = get_hexadecimal_string_from_lsb_first_stream(&decrypted_bits);
|
||||
|
||||
assert_eq!(expected_output_0_63, &hex_string[0..64 * 2]);
|
||||
}
|
||||
|
||||
pub fn trivium_comparison_test<P, E>(param: P, mut executor: E)
|
||||
where
|
||||
P: Into<TestParameters>,
|
||||
E: for<'a> FunctionExecutor<
|
||||
(&'a RadixCiphertext, &'a RadixCiphertext, usize),
|
||||
crate::Result<RadixCiphertext>,
|
||||
>,
|
||||
{
|
||||
let param = param.into();
|
||||
let (cks, sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
|
||||
let cks = RadixClientKey::from((cks, 1));
|
||||
let sks = Arc::new(sks);
|
||||
executor.setup(&cks, sks);
|
||||
|
||||
let num_runs = 5;
|
||||
let num_steps = 512;
|
||||
|
||||
for i in 0..num_runs {
|
||||
let mut rng = rand::thread_rng();
|
||||
let plain_key: Vec<u8> = (0..80).map(|_| rng.gen_range(0..=1)).collect();
|
||||
let plain_iv: Vec<u8> = (0..80).map(|_| rng.gen_range(0..=1)).collect();
|
||||
|
||||
let key_bits_u64: Vec<u64> = plain_key.iter().map(|&x| x as u64).collect();
|
||||
let iv_bits_u64: Vec<u64> = plain_iv.iter().map(|&x| x as u64).collect();
|
||||
|
||||
let ct_key = encrypt_bits(&cks, &key_bits_u64);
|
||||
let ct_iv = encrypt_bits(&cks, &iv_bits_u64);
|
||||
|
||||
let mut cpu_trivium = TriviumRef::new(&plain_key, &plain_iv);
|
||||
let mut cpu_output = Vec::with_capacity(num_steps);
|
||||
for _ in 0..num_steps {
|
||||
cpu_output.push(cpu_trivium.next());
|
||||
}
|
||||
|
||||
let output_radix = executor.execute((&ct_key, &ct_iv, num_steps)).unwrap();
|
||||
let fhe_output = decrypt_bits(&cks, &output_radix);
|
||||
|
||||
assert_eq!(cpu_output.len(), fhe_output.len());
|
||||
assert_eq!(cpu_output, fhe_output, "Mismatch at iteration {i}");
|
||||
}
|
||||
}
|
||||
@@ -624,6 +624,24 @@ impl AtomicPatternParameters {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn set_deterministic_execution(&mut self, do_it: bool) {
|
||||
match self {
|
||||
Self::Standard(pbsparameters) => match pbsparameters {
|
||||
PBSParameters::PBS(_) => (),
|
||||
PBSParameters::MultiBitPBS(multi_bit_pbsparameters) => {
|
||||
multi_bit_pbsparameters.deterministic_execution = do_it
|
||||
}
|
||||
},
|
||||
Self::KeySwitch32(_) => (),
|
||||
}
|
||||
}
|
||||
|
||||
/// Consumes the parameters and returns them after calling
/// `set_deterministic_execution(true)`.
pub fn with_deterministic_execution(self) -> Self {
    let mut updated = self;
    updated.set_deterministic_execution(true);
    updated
}
|
||||
}
|
||||
|
||||
impl ParameterSetConformant for AtomicPatternServerKey {
|
||||
|
||||
@@ -36,6 +36,7 @@ pub(crate) fn compute_delta<Scalar: UnsignedInteger + CastFrom<u64>>(
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy)]
|
||||
pub(crate) struct ShortintEncoding<Scalar: UnsignedInteger> {
|
||||
pub(crate) ciphertext_modulus: CoreCiphertextModulus<Scalar>,
|
||||
pub(crate) message_modulus: MessageModulus,
|
||||
|
||||
@@ -99,6 +99,19 @@ impl NoiseSquashingParameters {
|
||||
}
|
||||
}
|
||||
}
|
||||
pub fn set_deterministic_execution(&mut self, do_it: bool) {
|
||||
match self {
|
||||
Self::Classic(_) => (),
|
||||
Self::MultiBit(noise_squashing_multi_bit_parameters) => {
|
||||
noise_squashing_multi_bit_parameters.deterministic_execution = do_it
|
||||
}
|
||||
}
|
||||
}
|
||||
/// Consumes the parameters and returns them after calling
/// `set_deterministic_execution(true)`.
pub fn with_deterministic_execution(self) -> Self {
    let mut updated = self;
    updated.set_deterministic_execution(true);
    updated
}
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Debug, PartialEq, Serialize, Deserialize, Versionize)]
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
use super::current_params::meta::cpu::*;
|
||||
use super::current_params::meta::gpu::*;
|
||||
use super::current_params::*;
|
||||
use super::{
|
||||
AtomicPatternParameters, ClassicPBSParameters, CompactPublicKeyEncryptionParameters,
|
||||
@@ -229,6 +230,7 @@ pub const TEST_PARAM_NOISE_SQUASHING_COMP_PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFOR
|
||||
pub const TEST_PARAM_NOISE_SQUASHING_PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128:
|
||||
NoiseSquashingParameters = V1_5_NOISE_SQUASHING_PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128;
|
||||
|
||||
// Meta params
|
||||
pub const TEST_META_PARAM_CPU_2_2_KS_PBS_PKE_TO_SMALL_ZKV2_TUNIFORM_2M128: MetaParameters =
|
||||
V1_5_META_PARAM_CPU_2_2_KS_PBS_PKE_TO_SMALL_ZKV2_TUNIFORM_2M128;
|
||||
|
||||
@@ -237,3 +239,8 @@ pub const TEST_META_PARAM_CPU_2_2_KS32_PBS_PKE_TO_SMALL_ZKV2_TUNIFORM_2M128: Met
|
||||
|
||||
pub const TEST_META_PARAM_CPU_2_2_KS_PBS_GAUSSIAN_2M128: MetaParameters =
|
||||
V1_5_META_PARAM_CPU_2_2_KS_PBS_GAUSSIAN_2M128;
|
||||
|
||||
// GPU params we want to check for full scenario
|
||||
pub const TEST_META_PARAM_GPU_2_2_MULTI_BIT_GROUP_4_KS_PBS_PKE_TO_SMALL_ZKV2_TUNIFORM_2M128:
|
||||
MetaParameters =
|
||||
V1_5_META_PARAM_GPU_2_2_MULTI_BIT_GROUP_4_KS_PBS_PKE_TO_SMALL_ZKV2_TUNIFORM_2M128;
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
use super::dp_ks_ms::dp_ks_any_ms;
|
||||
use super::utils::noise_simulation::{
|
||||
NoiseSimulationGlwe, NoiseSimulationLwe, NoiseSimulationLweFourierBsk,
|
||||
NoiseSimulationGenericBootstrapKey, NoiseSimulationGlwe, NoiseSimulationLwe,
|
||||
NoiseSimulationLweKeyswitchKey, NoiseSimulationModulusSwitchConfig,
|
||||
};
|
||||
use super::utils::traits::*;
|
||||
@@ -13,20 +13,19 @@ use crate::core_crypto::commons::dispersion::Variance;
|
||||
use crate::core_crypto::commons::parameters::CiphertextModulusLog;
|
||||
use crate::shortint::atomic_pattern::AtomicPattern;
|
||||
use crate::shortint::ciphertext::NoiseLevel;
|
||||
use crate::shortint::client_key::atomic_pattern::AtomicPatternClientKey;
|
||||
use crate::shortint::client_key::ClientKey;
|
||||
use crate::shortint::encoding::ShortintEncoding;
|
||||
use crate::shortint::engine::ShortintEngine;
|
||||
use crate::shortint::parameters::test_params::{
|
||||
TEST_META_PARAM_CPU_2_2_KS32_PBS_PKE_TO_SMALL_ZKV2_TUNIFORM_2M128,
|
||||
TEST_META_PARAM_CPU_2_2_KS_PBS_GAUSSIAN_2M128,
|
||||
TEST_META_PARAM_CPU_2_2_KS_PBS_PKE_TO_SMALL_ZKV2_TUNIFORM_2M128,
|
||||
TEST_META_PARAM_GPU_2_2_MULTI_BIT_GROUP_4_KS_PBS_PKE_TO_SMALL_ZKV2_TUNIFORM_2M128,
|
||||
};
|
||||
use crate::shortint::parameters::{AtomicPatternParameters, CarryModulus, MetaParameters};
|
||||
use crate::shortint::server_key::tests::noise_distribution::utils::noise_simulation::NoiseSimulationModulus;
|
||||
use crate::shortint::server_key::tests::parameterized_test::create_parameterized_test;
|
||||
use crate::shortint::server_key::ServerKey;
|
||||
use crate::shortint::{Ciphertext, PaddingBit};
|
||||
use crate::shortint::Ciphertext;
|
||||
use rayon::prelude::*;
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
@@ -63,7 +62,7 @@ pub fn br_dp_ks_any_ms<
|
||||
where
|
||||
// We need to be able to allocate the result and bootstrap the Input
|
||||
Accumulator: AllocateLweBootstrapResult<Output = PBSResult, SideResources = Resources>,
|
||||
PBSKey: LweClassicFftBootstrap<InputCt, PBSResult, Accumulator, SideResources = Resources>,
|
||||
PBSKey: LweGenericBootstrap<InputCt, PBSResult, Accumulator, SideResources = Resources>,
|
||||
// Result of the PBS/Blind rotate needs to be multipliable by the scalar
|
||||
PBSResult: ScalarMul<DPScalar, Output = ScalarMulResult, SideResources = Resources>,
|
||||
// We need to be able to allocate the result and keyswitch the result of the ScalarMul
|
||||
@@ -74,7 +73,9 @@ where
|
||||
+ AllocateCenteredBinaryShiftedStandardModSwitchResult<
|
||||
Output = MsResult,
|
||||
SideResources = Resources,
|
||||
> + CenteredBinaryShiftedStandardModSwitch<MsResult, SideResources = Resources>,
|
||||
> + CenteredBinaryShiftedStandardModSwitch<MsResult, SideResources = Resources>
|
||||
+ AllocateMultiBitModSwitchResult<Output = MsResult, SideResources = Resources>
|
||||
+ MultiBitModSwitch<MsResult, SideResources = Resources>,
|
||||
// We need to be able to allocate the result and apply drift technique + mod switch it
|
||||
DriftKey: AllocateDriftTechniqueStandardModSwitchResult<
|
||||
AfterDriftOutput = DriftTechniqueResult,
|
||||
@@ -88,7 +89,7 @@ where
|
||||
>,
|
||||
{
|
||||
let mut pbs_result = accumulator.allocate_lwe_bootstrap_result(side_resources);
|
||||
bsk.lwe_classic_fft_pbs(&input, &mut pbs_result, accumulator, side_resources);
|
||||
bsk.lwe_generic_bootstrap(&input, &mut pbs_result, accumulator, side_resources);
|
||||
let (pbs_result, after_dp, ks_result, drift_technique_result, ms_result) = dp_ks_any_ms(
|
||||
pbs_result,
|
||||
scalar,
|
||||
@@ -111,7 +112,9 @@ where
|
||||
/// Test function to verify that the noise checking tools match the actual atomic patterns
|
||||
/// implemented in shortint
|
||||
fn sanity_check_encrypt_br_dp_ks_pbs(meta_params: MetaParameters) {
|
||||
let params = meta_params.compute_parameters;
|
||||
let params = meta_params
|
||||
.compute_parameters
|
||||
.with_deterministic_execution();
|
||||
let cks = ClientKey::new(params);
|
||||
let sks = ServerKey::new(&cks);
|
||||
|
||||
@@ -139,7 +142,8 @@ fn sanity_check_encrypt_br_dp_ks_pbs(meta_params: MetaParameters) {
|
||||
|
||||
// Complete the AP by computing the PBS to match shortint
|
||||
let mut pbs_result = id_lut.allocate_lwe_bootstrap_result(&mut ());
|
||||
sks.lwe_classic_fft_pbs(&ms_result, &mut pbs_result, &id_lut, &mut ());
|
||||
|
||||
sks.apply_generic_blind_rotation(&ms_result, &mut pbs_result, &id_lut);
|
||||
|
||||
// Shortint APIs are not granular enough to compare ciphertexts at the MS level
|
||||
// and inject arbitrary LWEs as input to the blind rotate step of the PBS.
|
||||
@@ -169,6 +173,7 @@ create_parameterized_test!(sanity_check_encrypt_br_dp_ks_pbs {
|
||||
TEST_META_PARAM_CPU_2_2_KS_PBS_GAUSSIAN_2M128,
|
||||
TEST_META_PARAM_CPU_2_2_KS_PBS_PKE_TO_SMALL_ZKV2_TUNIFORM_2M128,
|
||||
TEST_META_PARAM_CPU_2_2_KS32_PBS_PKE_TO_SMALL_ZKV2_TUNIFORM_2M128,
|
||||
TEST_META_PARAM_GPU_2_2_MULTI_BIT_GROUP_4_KS_PBS_PKE_TO_SMALL_ZKV2_TUNIFORM_2M128,
|
||||
});
|
||||
|
||||
fn encrypt_br_dp_ks_any_ms_inner_helper(
|
||||
@@ -217,113 +222,21 @@ fn encrypt_br_dp_ks_any_ms_inner_helper(
|
||||
|
||||
let before_ms = after_drift.as_ref().unwrap_or(&after_ks);
|
||||
|
||||
match &cks.atomic_pattern {
|
||||
AtomicPatternClientKey::Standard(standard_atomic_pattern_client_key) => {
|
||||
let params = standard_atomic_pattern_client_key.parameters;
|
||||
let output_encoding = ShortintEncoding {
|
||||
ciphertext_modulus: params.ciphertext_modulus(),
|
||||
message_modulus: params.message_modulus(),
|
||||
carry_modulus: params.carry_modulus(),
|
||||
padding_bit: PaddingBit::Yes,
|
||||
};
|
||||
let large_lwe_secret_key = standard_atomic_pattern_client_key.large_lwe_secret_key();
|
||||
let small_lwe_secret_key = standard_atomic_pattern_client_key.small_lwe_secret_key();
|
||||
(
|
||||
DecryptionAndNoiseResult::new_from_lwe(
|
||||
&input.as_lwe_64(),
|
||||
&small_lwe_secret_key,
|
||||
msg,
|
||||
&output_encoding,
|
||||
),
|
||||
DecryptionAndNoiseResult::new_from_lwe(
|
||||
&after_br.as_lwe_64(),
|
||||
&large_lwe_secret_key,
|
||||
msg,
|
||||
&output_encoding,
|
||||
),
|
||||
DecryptionAndNoiseResult::new_from_lwe(
|
||||
&after_dp.as_lwe_64(),
|
||||
&large_lwe_secret_key,
|
||||
msg,
|
||||
&output_encoding,
|
||||
),
|
||||
DecryptionAndNoiseResult::new_from_lwe(
|
||||
&after_ks.as_lwe_64(),
|
||||
&small_lwe_secret_key,
|
||||
msg,
|
||||
&output_encoding,
|
||||
),
|
||||
DecryptionAndNoiseResult::new_from_lwe(
|
||||
&before_ms.as_lwe_64(),
|
||||
&small_lwe_secret_key,
|
||||
msg,
|
||||
&output_encoding,
|
||||
),
|
||||
DecryptionAndNoiseResult::new_from_lwe(
|
||||
&after_ms.as_lwe_64(),
|
||||
&small_lwe_secret_key,
|
||||
msg,
|
||||
&output_encoding,
|
||||
),
|
||||
)
|
||||
}
|
||||
AtomicPatternClientKey::KeySwitch32(ks32_atomic_pattern_client_key) => {
|
||||
let msg_u32: u32 = msg.try_into().unwrap();
|
||||
let params = ks32_atomic_pattern_client_key.parameters;
|
||||
let small_key_encoding = ShortintEncoding {
|
||||
ciphertext_modulus: params.post_keyswitch_ciphertext_modulus(),
|
||||
message_modulus: params.message_modulus(),
|
||||
carry_modulus: params.carry_modulus(),
|
||||
padding_bit: PaddingBit::Yes,
|
||||
};
|
||||
let big_key_encoding = ShortintEncoding {
|
||||
ciphertext_modulus: params.ciphertext_modulus(),
|
||||
message_modulus: params.message_modulus(),
|
||||
carry_modulus: params.carry_modulus(),
|
||||
padding_bit: PaddingBit::Yes,
|
||||
};
|
||||
let large_lwe_secret_key = ks32_atomic_pattern_client_key.large_lwe_secret_key();
|
||||
let small_lwe_secret_key = ks32_atomic_pattern_client_key.small_lwe_secret_key();
|
||||
(
|
||||
DecryptionAndNoiseResult::new_from_lwe(
|
||||
&input.as_lwe_32(),
|
||||
&small_lwe_secret_key,
|
||||
msg_u32,
|
||||
&small_key_encoding,
|
||||
),
|
||||
DecryptionAndNoiseResult::new_from_lwe(
|
||||
&after_br.as_lwe_64(),
|
||||
&large_lwe_secret_key,
|
||||
msg,
|
||||
&big_key_encoding,
|
||||
),
|
||||
DecryptionAndNoiseResult::new_from_lwe(
|
||||
&after_dp.as_lwe_64(),
|
||||
&large_lwe_secret_key,
|
||||
msg,
|
||||
&big_key_encoding,
|
||||
),
|
||||
DecryptionAndNoiseResult::new_from_lwe(
|
||||
&after_ks.as_lwe_32(),
|
||||
&small_lwe_secret_key,
|
||||
msg_u32,
|
||||
&small_key_encoding,
|
||||
),
|
||||
DecryptionAndNoiseResult::new_from_lwe(
|
||||
&before_ms.as_lwe_32(),
|
||||
&small_lwe_secret_key,
|
||||
msg_u32,
|
||||
&small_key_encoding,
|
||||
),
|
||||
DecryptionAndNoiseResult::new_from_lwe(
|
||||
&after_ms.as_lwe_32(),
|
||||
&small_lwe_secret_key,
|
||||
msg_u32,
|
||||
&small_key_encoding,
|
||||
),
|
||||
)
|
||||
}
|
||||
}
|
||||
let large_lwe_secret_key_dyn = cks.large_lwe_secret_key_as_dyn();
|
||||
let small_lwe_secret_key_dyn = cks.small_lwe_secret_key_as_dyn();
|
||||
|
||||
(
|
||||
DecryptionAndNoiseResult::new_from_dyn_lwe(&input, &small_lwe_secret_key_dyn, msg),
|
||||
DecryptionAndNoiseResult::new_from_dyn_lwe(&after_br, &large_lwe_secret_key_dyn, msg),
|
||||
DecryptionAndNoiseResult::new_from_dyn_lwe(&after_dp, &large_lwe_secret_key_dyn, msg),
|
||||
DecryptionAndNoiseResult::new_from_dyn_lwe(&after_ks, &small_lwe_secret_key_dyn, msg),
|
||||
DecryptionAndNoiseResult::new_from_dyn_lwe(before_ms, &small_lwe_secret_key_dyn, msg),
|
||||
DecryptionAndNoiseResult::new_from_dyn_modswitched_lwe(
|
||||
&after_ms,
|
||||
&small_lwe_secret_key_dyn,
|
||||
msg,
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
fn encrypt_br_dp_ks_any_ms_noise_helper(
|
||||
@@ -391,7 +304,9 @@ fn encrypt_br_dp_ks_any_ms_pfail_helper(
|
||||
}
|
||||
|
||||
fn noise_check_encrypt_br_dp_ks_ms_noise(meta_params: MetaParameters) {
|
||||
let params = meta_params.compute_parameters;
|
||||
let params = meta_params
|
||||
.compute_parameters
|
||||
.with_deterministic_execution();
|
||||
let cks = ClientKey::new(params);
|
||||
let sks = ServerKey::new(&cks);
|
||||
|
||||
@@ -400,7 +315,7 @@ fn noise_check_encrypt_br_dp_ks_ms_noise(meta_params: MetaParameters) {
|
||||
let noise_simulation_modulus_switch_config =
|
||||
NoiseSimulationModulusSwitchConfig::new_from_atomic_pattern_parameters(params);
|
||||
let noise_simulation_bsk =
|
||||
NoiseSimulationLweFourierBsk::new_from_atomic_pattern_parameters(params);
|
||||
NoiseSimulationGenericBootstrapKey::new_from_atomic_pattern_parameters(params);
|
||||
|
||||
let modulus_switch_config = sks.noise_simulation_modulus_switch_config();
|
||||
let br_input_modulus_log = sks.br_input_modulus_log();
|
||||
@@ -502,11 +417,14 @@ create_parameterized_test!(noise_check_encrypt_br_dp_ks_ms_noise {
|
||||
TEST_META_PARAM_CPU_2_2_KS_PBS_GAUSSIAN_2M128,
|
||||
TEST_META_PARAM_CPU_2_2_KS_PBS_PKE_TO_SMALL_ZKV2_TUNIFORM_2M128,
|
||||
TEST_META_PARAM_CPU_2_2_KS32_PBS_PKE_TO_SMALL_ZKV2_TUNIFORM_2M128,
|
||||
TEST_META_PARAM_GPU_2_2_MULTI_BIT_GROUP_4_KS_PBS_PKE_TO_SMALL_ZKV2_TUNIFORM_2M128,
|
||||
});
|
||||
|
||||
fn noise_check_encrypt_br_dp_ks_ms_pfail(meta_params: MetaParameters) {
|
||||
let (pfail_test_meta, params) = {
|
||||
let mut ap_params = meta_params.compute_parameters;
|
||||
let mut ap_params = meta_params
|
||||
.compute_parameters
|
||||
.with_deterministic_execution();
|
||||
|
||||
let original_message_modulus = ap_params.message_modulus();
|
||||
let original_carry_modulus = ap_params.carry_modulus();
|
||||
@@ -568,4 +486,5 @@ create_parameterized_test!(noise_check_encrypt_br_dp_ks_ms_pfail {
|
||||
TEST_META_PARAM_CPU_2_2_KS_PBS_GAUSSIAN_2M128,
|
||||
TEST_META_PARAM_CPU_2_2_KS_PBS_PKE_TO_SMALL_ZKV2_TUNIFORM_2M128,
|
||||
TEST_META_PARAM_CPU_2_2_KS32_PBS_PKE_TO_SMALL_ZKV2_TUNIFORM_2M128,
|
||||
TEST_META_PARAM_GPU_2_2_MULTI_BIT_GROUP_4_KS_PBS_PKE_TO_SMALL_ZKV2_TUNIFORM_2M128,
|
||||
});
|
||||
|
||||
@@ -8,13 +8,13 @@ use super::utils::{
|
||||
use super::{should_run_short_pfail_tests_debug, should_use_single_key_debug};
|
||||
use crate::shortint::atomic_pattern::AtomicPattern;
|
||||
use crate::shortint::ciphertext::{Ciphertext, Degree, NoiseLevel};
|
||||
use crate::shortint::client_key::atomic_pattern::AtomicPatternClientKey;
|
||||
use crate::shortint::client_key::ClientKey;
|
||||
use crate::shortint::engine::ShortintEngine;
|
||||
use crate::shortint::list_compression::{CompressionKey, CompressionPrivateKeys};
|
||||
use crate::shortint::parameters::test_params::{
|
||||
TEST_META_PARAM_CPU_2_2_KS32_PBS_PKE_TO_SMALL_ZKV2_TUNIFORM_2M128,
|
||||
TEST_META_PARAM_CPU_2_2_KS_PBS_PKE_TO_SMALL_ZKV2_TUNIFORM_2M128,
|
||||
TEST_META_PARAM_GPU_2_2_MULTI_BIT_GROUP_4_KS_PBS_PKE_TO_SMALL_ZKV2_TUNIFORM_2M128,
|
||||
};
|
||||
use crate::shortint::parameters::{
|
||||
AtomicPatternParameters, CarryModulus, CiphertextModulusLog, CompressionParameters,
|
||||
@@ -53,8 +53,7 @@ pub fn br_dp_packing_ks_ms<
|
||||
)
|
||||
where
|
||||
Accumulator: AllocateLweBootstrapResult<Output = PBSResult, SideResources = Resources> + Sync,
|
||||
PBSKey:
|
||||
LweClassicFftBootstrap<InputCt, PBSResult, Accumulator, SideResources = Resources> + Sync,
|
||||
PBSKey: LweGenericBootstrap<InputCt, PBSResult, Accumulator, SideResources = Resources> + Sync,
|
||||
PBSResult: ScalarMul<DPScalar, Output = DPResult, SideResources = Resources> + Send,
|
||||
PackingKsk: AllocateLwePackingKeyswitchResult<Output = PackingKsResult, SideResources = Resources>
|
||||
+ for<'a> LwePackingKeyswitch<[&'a DPResult], PackingKsResult, SideResources = Resources>,
|
||||
@@ -70,7 +69,7 @@ where
|
||||
.zip(side_resources.par_iter_mut())
|
||||
.map(|(input, side_resources)| {
|
||||
let mut pbs_result = accumulator.allocate_lwe_bootstrap_result(side_resources);
|
||||
bsk.lwe_classic_fft_pbs(&input, &mut pbs_result, accumulator, side_resources);
|
||||
bsk.lwe_generic_bootstrap(&input, &mut pbs_result, accumulator, side_resources);
|
||||
let after_dp = pbs_result.scalar_mul(scalar, side_resources);
|
||||
|
||||
(input, pbs_result, after_dp)
|
||||
@@ -98,7 +97,9 @@ where
|
||||
|
||||
fn sanity_check_encrypt_br_dp_packing_ks_ms(meta_params: MetaParameters) {
|
||||
let (params, comp_params) = (
|
||||
meta_params.compute_parameters,
|
||||
meta_params
|
||||
.compute_parameters
|
||||
.with_deterministic_execution(),
|
||||
meta_params.compression_parameters.unwrap(),
|
||||
);
|
||||
let cks = ClientKey::new(params);
|
||||
@@ -161,6 +162,7 @@ fn sanity_check_encrypt_br_dp_packing_ks_ms(meta_params: MetaParameters) {
|
||||
create_parameterized_test!(sanity_check_encrypt_br_dp_packing_ks_ms {
|
||||
TEST_META_PARAM_CPU_2_2_KS_PBS_PKE_TO_SMALL_ZKV2_TUNIFORM_2M128,
|
||||
TEST_META_PARAM_CPU_2_2_KS32_PBS_PKE_TO_SMALL_ZKV2_TUNIFORM_2M128,
|
||||
TEST_META_PARAM_GPU_2_2_MULTI_BIT_GROUP_4_KS_PBS_PKE_TO_SMALL_ZKV2_TUNIFORM_2M128,
|
||||
});
|
||||
|
||||
#[allow(clippy::type_complexity)]
|
||||
@@ -237,57 +239,41 @@ fn encrypt_br_dp_packing_ks_ms_inner_helper(
|
||||
&mut side_resources,
|
||||
);
|
||||
|
||||
let compute_large_lwe_secret_key = cks.encryption_key();
|
||||
let compression_glwe_secret_key = &compression_private_key.post_packing_ks_key;
|
||||
|
||||
let compute_encoding = sks.encoding(PaddingBit::Yes);
|
||||
let compression_encoding = ShortintEncoding {
|
||||
carry_modulus: CarryModulus(1),
|
||||
..compute_encoding
|
||||
};
|
||||
|
||||
let compute_large_lwe_secret_key = cks.encryption_key();
|
||||
let compute_large_lwe_secret_key_with_compression_encoding_dyn = DynLweSecretKeyView::U64 {
|
||||
key: compute_large_lwe_secret_key,
|
||||
encoding: compression_encoding,
|
||||
};
|
||||
let compression_glwe_secret_key = &compression_private_key.post_packing_ks_key;
|
||||
|
||||
(
|
||||
before_packing
|
||||
.into_iter()
|
||||
.map(|(input, pbs_result, dp_result)| {
|
||||
(
|
||||
match &cks.atomic_pattern {
|
||||
AtomicPatternClientKey::Standard(standard_atomic_pattern_client_key) => {
|
||||
DecryptionAndNoiseResult::new_from_lwe(
|
||||
input.as_ref_64(),
|
||||
&standard_atomic_pattern_client_key.lwe_secret_key,
|
||||
msg,
|
||||
&compute_encoding,
|
||||
)
|
||||
}
|
||||
AtomicPatternClientKey::KeySwitch32(ks32_atomic_pattern_client_key) => {
|
||||
let ks32_params = ks32_atomic_pattern_client_key.parameters;
|
||||
let compute_encoding_32 = ShortintEncoding {
|
||||
ciphertext_modulus: ks32_params.post_keyswitch_ciphertext_modulus,
|
||||
message_modulus: ks32_params.message_modulus,
|
||||
carry_modulus: ks32_params.carry_modulus,
|
||||
padding_bit: PaddingBit::Yes,
|
||||
};
|
||||
let compute_large_lwe_secret_key_dyn = cks.large_lwe_secret_key_as_dyn();
|
||||
let compute_small_lwe_secret_key_dyn = cks.small_lwe_secret_key_as_dyn();
|
||||
|
||||
DecryptionAndNoiseResult::new_from_lwe(
|
||||
input.as_ref_32(),
|
||||
&ks32_atomic_pattern_client_key.lwe_secret_key,
|
||||
msg.try_into().unwrap(),
|
||||
&compute_encoding_32,
|
||||
)
|
||||
}
|
||||
},
|
||||
DecryptionAndNoiseResult::new_from_lwe(
|
||||
pbs_result.as_ref_64(),
|
||||
&compute_large_lwe_secret_key,
|
||||
(
|
||||
DecryptionAndNoiseResult::new_from_dyn_lwe(
|
||||
&input,
|
||||
&compute_small_lwe_secret_key_dyn,
|
||||
msg,
|
||||
&compute_encoding,
|
||||
),
|
||||
DecryptionAndNoiseResult::new_from_lwe(
|
||||
dp_result.as_ref_64(),
|
||||
&compute_large_lwe_secret_key,
|
||||
DecryptionAndNoiseResult::new_from_dyn_lwe(
|
||||
&pbs_result,
|
||||
&compute_large_lwe_secret_key_dyn,
|
||||
msg,
|
||||
),
|
||||
DecryptionAndNoiseResult::new_from_dyn_lwe(
|
||||
&dp_result,
|
||||
&compute_large_lwe_secret_key_with_compression_encoding_dyn,
|
||||
msg,
|
||||
&compression_encoding,
|
||||
),
|
||||
)
|
||||
})
|
||||
@@ -392,7 +378,9 @@ fn encrypt_br_dp_packing_ks_ms_pfail_helper(
|
||||
|
||||
fn noise_check_encrypt_br_dp_packing_ks_ms_noise(meta_params: MetaParameters) {
|
||||
let (params, comp_params) = (
|
||||
meta_params.compute_parameters,
|
||||
meta_params
|
||||
.compute_parameters
|
||||
.with_deterministic_execution(),
|
||||
meta_params.compression_parameters.unwrap(),
|
||||
);
|
||||
let cks = ClientKey::new(params);
|
||||
@@ -401,7 +389,7 @@ fn noise_check_encrypt_br_dp_packing_ks_ms_noise(meta_params: MetaParameters) {
|
||||
let compression_key = cks.new_compression_key(&compression_private_key);
|
||||
|
||||
let noise_simulation_bsk =
|
||||
NoiseSimulationLweFourierBsk::new_from_atomic_pattern_parameters(params);
|
||||
NoiseSimulationGenericBootstrapKey::new_from_atomic_pattern_parameters(params);
|
||||
let noise_simulation_packing_key =
|
||||
NoiseSimulationLwePackingKeyswitchKey::new_from_comp_parameters(params, comp_params);
|
||||
|
||||
@@ -520,12 +508,15 @@ fn noise_check_encrypt_br_dp_packing_ks_ms_noise(meta_params: MetaParameters) {
|
||||
create_parameterized_test!(noise_check_encrypt_br_dp_packing_ks_ms_noise {
|
||||
TEST_META_PARAM_CPU_2_2_KS_PBS_PKE_TO_SMALL_ZKV2_TUNIFORM_2M128,
|
||||
TEST_META_PARAM_CPU_2_2_KS32_PBS_PKE_TO_SMALL_ZKV2_TUNIFORM_2M128,
|
||||
TEST_META_PARAM_GPU_2_2_MULTI_BIT_GROUP_4_KS_PBS_PKE_TO_SMALL_ZKV2_TUNIFORM_2M128,
|
||||
});
|
||||
|
||||
fn noise_check_encrypt_br_dp_packing_ks_ms_pfail(meta_params: MetaParameters) {
|
||||
let (pfail_test_meta, params, comp_params) = {
|
||||
let (mut params, comp_params) = (
|
||||
meta_params.compute_parameters,
|
||||
meta_params
|
||||
.compute_parameters
|
||||
.with_deterministic_execution(),
|
||||
meta_params.compression_parameters.unwrap(),
|
||||
);
|
||||
|
||||
@@ -537,7 +528,7 @@ fn noise_check_encrypt_br_dp_packing_ks_ms_pfail(meta_params: MetaParameters) {
|
||||
assert_eq!(original_carry_modulus.0, 4);
|
||||
|
||||
let noise_simulation_bsk =
|
||||
NoiseSimulationLweFourierBsk::new_from_atomic_pattern_parameters(params);
|
||||
NoiseSimulationGenericBootstrapKey::new_from_atomic_pattern_parameters(params);
|
||||
let noise_simulation_packing_key =
|
||||
NoiseSimulationLwePackingKeyswitchKey::new_from_comp_parameters(params, comp_params);
|
||||
|
||||
@@ -685,4 +676,5 @@ fn noise_check_encrypt_br_dp_packing_ks_ms_pfail(meta_params: MetaParameters) {
|
||||
create_parameterized_test!(noise_check_encrypt_br_dp_packing_ks_ms_pfail {
|
||||
TEST_META_PARAM_CPU_2_2_KS_PBS_PKE_TO_SMALL_ZKV2_TUNIFORM_2M128,
|
||||
TEST_META_PARAM_CPU_2_2_KS32_PBS_PKE_TO_SMALL_ZKV2_TUNIFORM_2M128,
|
||||
TEST_META_PARAM_GPU_2_2_MULTI_BIT_GROUP_4_KS_PBS_PKE_TO_SMALL_ZKV2_TUNIFORM_2M128,
|
||||
});
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use super::dp_ks_ms::any_ms;
|
||||
use super::utils::noise_simulation::{
|
||||
DynLwe, NoiseSimulationGlwe, NoiseSimulationLwe, NoiseSimulationLweFourierBsk,
|
||||
NoiseSimulationLweKeyswitchKey, NoiseSimulationModulusSwitchConfig,
|
||||
DynLwe, DynLweSecretKeyView, NoiseSimulationGenericBootstrapKey, NoiseSimulationGlwe,
|
||||
NoiseSimulationLwe, NoiseSimulationLweKeyswitchKey, NoiseSimulationModulusSwitchConfig,
|
||||
};
|
||||
use super::utils::traits::*;
|
||||
use super::utils::{
|
||||
@@ -22,7 +22,6 @@ use crate::shortint::atomic_pattern::AtomicPattern;
|
||||
use crate::shortint::ciphertext::{
|
||||
CompressedCiphertextList, CompressedCiphertextListMeta, ReRandomizationSeed,
|
||||
};
|
||||
use crate::shortint::client_key::atomic_pattern::AtomicPatternClientKey;
|
||||
use crate::shortint::client_key::ClientKey;
|
||||
use crate::shortint::encoding::ShortintEncoding;
|
||||
use crate::shortint::engine::ShortintEngine;
|
||||
@@ -31,6 +30,7 @@ use crate::shortint::list_compression::{CompressionPrivateKeys, DecompressionKey
|
||||
use crate::shortint::parameters::test_params::{
|
||||
TEST_META_PARAM_CPU_2_2_KS32_PBS_PKE_TO_SMALL_ZKV2_TUNIFORM_2M128,
|
||||
TEST_META_PARAM_CPU_2_2_KS_PBS_PKE_TO_SMALL_ZKV2_TUNIFORM_2M128,
|
||||
TEST_META_PARAM_GPU_2_2_MULTI_BIT_GROUP_4_KS_PBS_PKE_TO_SMALL_ZKV2_TUNIFORM_2M128,
|
||||
};
|
||||
use crate::shortint::parameters::{
|
||||
AtomicPatternParameters, CarryModulus, CompactCiphertextListExpansionKind,
|
||||
@@ -85,8 +85,7 @@ pub fn br_rerand_dp_ks_any_ms<
|
||||
)
|
||||
where
|
||||
Accumulator: AllocateLweBootstrapResult<Output = PBSResult, SideResources = Resources>,
|
||||
DecompPBSKey:
|
||||
LweClassicFftBootstrap<InputCt, PBSResult, Accumulator, SideResources = Resources>,
|
||||
DecompPBSKey: LweGenericBootstrap<InputCt, PBSResult, Accumulator, SideResources = Resources>,
|
||||
KsKeyRerand: AllocateLweKeyswitchResult<Output = KsedZeroReRand, SideResources = Resources>
|
||||
+ LweKeyswitch<InputZeroRerand, KsedZeroReRand, SideResources = Resources>,
|
||||
PBSResult: for<'a> LweUncorrelatedAdd<
|
||||
@@ -102,7 +101,9 @@ where
|
||||
+ AllocateCenteredBinaryShiftedStandardModSwitchResult<
|
||||
Output = MsResult,
|
||||
SideResources = Resources,
|
||||
> + CenteredBinaryShiftedStandardModSwitch<MsResult, SideResources = Resources>,
|
||||
> + CenteredBinaryShiftedStandardModSwitch<MsResult, SideResources = Resources>
|
||||
+ AllocateMultiBitModSwitchResult<Output = MsResult, SideResources = Resources>
|
||||
+ MultiBitModSwitch<MsResult, SideResources = Resources>,
|
||||
DriftKey: AllocateDriftTechniqueStandardModSwitchResult<
|
||||
AfterDriftOutput = DriftTechniqueResult,
|
||||
AfterMsOutput = MsResult,
|
||||
@@ -116,7 +117,7 @@ where
|
||||
{
|
||||
// BR to decomp
|
||||
let mut br_result = decomp_accumulator.allocate_lwe_bootstrap_result(side_resources);
|
||||
decomp_bsk.lwe_classic_fft_pbs(&input, &mut br_result, decomp_accumulator, side_resources);
|
||||
decomp_bsk.lwe_generic_bootstrap(&input, &mut br_result, decomp_accumulator, side_resources);
|
||||
|
||||
// Ks the CPK encryption of 0 to be added to BR result
|
||||
let mut ksed_zero_rerand = ksk_rerand.allocate_lwe_keyswitch_result(side_resources);
|
||||
@@ -278,185 +279,76 @@ fn encrypt_decomp_br_rerand_dp_ks_any_ms_inner_helper(
|
||||
|
||||
let before_ms = after_drift.as_ref().unwrap_or(&after_ks);
|
||||
|
||||
match &cks.atomic_pattern {
|
||||
AtomicPatternClientKey::Standard(standard_atomic_pattern_client_key) => {
|
||||
let params = standard_atomic_pattern_client_key.parameters;
|
||||
let comp_encoding = ShortintEncoding {
|
||||
ciphertext_modulus: params.ciphertext_modulus(),
|
||||
message_modulus: params.message_modulus(),
|
||||
// Adapt to the compression which has no carry bits
|
||||
carry_modulus: CarryModulus(1),
|
||||
padding_bit: PaddingBit::Yes,
|
||||
};
|
||||
let compute_encoding = ShortintEncoding {
|
||||
ciphertext_modulus: params.ciphertext_modulus(),
|
||||
message_modulus: params.message_modulus(),
|
||||
carry_modulus: params.carry_modulus(),
|
||||
padding_bit: PaddingBit::Yes,
|
||||
};
|
||||
let params = cks.parameters();
|
||||
let compute_encoding = ShortintEncoding {
|
||||
ciphertext_modulus: params.ciphertext_modulus(),
|
||||
message_modulus: params.message_modulus(),
|
||||
carry_modulus: params.carry_modulus(),
|
||||
padding_bit: PaddingBit::Yes,
|
||||
};
|
||||
let comp_encoding = ShortintEncoding {
|
||||
// Adapt to the compression which has no carry bits
|
||||
carry_modulus: CarryModulus(1),
|
||||
..compute_encoding
|
||||
};
|
||||
|
||||
let cpk_lwe_secret_key = cpk_private_key.key();
|
||||
let comp_lwe_secret_key = comp_private_key.post_packing_ks_key.as_lwe_secret_key();
|
||||
let cpk_lwe_secret_key_dyn = cpk_private_key.lwe_secret_key_as_dyn();
|
||||
let comp_lwe_secret_key = comp_private_key.post_packing_ks_key.as_lwe_secret_key();
|
||||
let comp_lwe_secret_key_dyn = DynLweSecretKeyView::U64 {
|
||||
key: comp_lwe_secret_key,
|
||||
encoding: comp_encoding,
|
||||
};
|
||||
|
||||
let large_compute_lwe_secret_key =
|
||||
standard_atomic_pattern_client_key.large_lwe_secret_key();
|
||||
let small_compute_lwe_secret_key =
|
||||
standard_atomic_pattern_client_key.small_lwe_secret_key();
|
||||
(
|
||||
(
|
||||
DecryptionAndNoiseResult::new_from_lwe(
|
||||
&input.as_lwe_64(),
|
||||
&comp_lwe_secret_key,
|
||||
msg,
|
||||
&comp_encoding,
|
||||
),
|
||||
DecryptionAndNoiseResult::new_from_lwe(
|
||||
&after_br.as_lwe_64(),
|
||||
&large_compute_lwe_secret_key,
|
||||
msg,
|
||||
&compute_encoding,
|
||||
),
|
||||
),
|
||||
(
|
||||
DecryptionAndNoiseResult::new_from_lwe(
|
||||
&input_zero_rerand.as_lwe_64(),
|
||||
&cpk_lwe_secret_key,
|
||||
msg,
|
||||
&compute_encoding,
|
||||
),
|
||||
DecryptionAndNoiseResult::new_from_lwe(
|
||||
&after_ksed_zero_rerand.as_lwe_64(),
|
||||
&large_compute_lwe_secret_key,
|
||||
msg,
|
||||
&compute_encoding,
|
||||
),
|
||||
),
|
||||
DecryptionAndNoiseResult::new_from_lwe(
|
||||
&after_rerand.as_lwe_64(),
|
||||
&large_compute_lwe_secret_key,
|
||||
msg,
|
||||
&compute_encoding,
|
||||
),
|
||||
DecryptionAndNoiseResult::new_from_lwe(
|
||||
&after_dp.as_lwe_64(),
|
||||
&large_compute_lwe_secret_key,
|
||||
msg,
|
||||
&compute_encoding,
|
||||
),
|
||||
DecryptionAndNoiseResult::new_from_lwe(
|
||||
&after_ks.as_lwe_64(),
|
||||
&small_compute_lwe_secret_key,
|
||||
msg,
|
||||
&compute_encoding,
|
||||
),
|
||||
DecryptionAndNoiseResult::new_from_lwe(
|
||||
&before_ms.as_lwe_64(),
|
||||
&small_compute_lwe_secret_key,
|
||||
msg,
|
||||
&compute_encoding,
|
||||
),
|
||||
DecryptionAndNoiseResult::new_from_lwe(
|
||||
&after_ms.as_lwe_64(),
|
||||
&small_compute_lwe_secret_key,
|
||||
msg,
|
||||
&compute_encoding,
|
||||
),
|
||||
)
|
||||
}
|
||||
AtomicPatternClientKey::KeySwitch32(ks32_atomic_pattern_client_key) => {
|
||||
let params = ks32_atomic_pattern_client_key.parameters;
|
||||
let comp_encoding = ShortintEncoding {
|
||||
ciphertext_modulus: params.ciphertext_modulus(),
|
||||
message_modulus: params.message_modulus(),
|
||||
// Adapt to the compression which has no carry bits
|
||||
carry_modulus: CarryModulus(1),
|
||||
padding_bit: PaddingBit::Yes,
|
||||
};
|
||||
let compute_encoding_u32 = ShortintEncoding {
|
||||
ciphertext_modulus: params.post_keyswitch_ciphertext_modulus(),
|
||||
message_modulus: params.message_modulus(),
|
||||
carry_modulus: params.carry_modulus(),
|
||||
padding_bit: PaddingBit::Yes,
|
||||
};
|
||||
let compute_encoding_u64 = ShortintEncoding {
|
||||
ciphertext_modulus: params.ciphertext_modulus(),
|
||||
message_modulus: params.message_modulus(),
|
||||
carry_modulus: params.carry_modulus(),
|
||||
padding_bit: PaddingBit::Yes,
|
||||
};
|
||||
let large_compute_lwe_secret_key_dyn = cks.large_lwe_secret_key_as_dyn();
|
||||
let small_compute_lwe_secret_key_dyn = cks.small_lwe_secret_key_as_dyn();
|
||||
|
||||
let cpk_lwe_secret_key = cpk_private_key.key();
|
||||
let comp_lwe_secret_key = comp_private_key.post_packing_ks_key.as_lwe_secret_key();
|
||||
|
||||
let large_compute_lwe_secret_key =
|
||||
ks32_atomic_pattern_client_key.large_lwe_secret_key();
|
||||
let small_compute_lwe_secret_key =
|
||||
ks32_atomic_pattern_client_key.small_lwe_secret_key();
|
||||
|
||||
let msg_u32: u32 = msg.try_into().unwrap();
|
||||
|
||||
(
|
||||
(
|
||||
DecryptionAndNoiseResult::new_from_lwe(
|
||||
&input.as_lwe_64(),
|
||||
&comp_lwe_secret_key,
|
||||
msg,
|
||||
&comp_encoding,
|
||||
),
|
||||
DecryptionAndNoiseResult::new_from_lwe(
|
||||
&after_br.as_lwe_64(),
|
||||
&large_compute_lwe_secret_key,
|
||||
msg,
|
||||
&compute_encoding_u64,
|
||||
),
|
||||
),
|
||||
(
|
||||
DecryptionAndNoiseResult::new_from_lwe(
|
||||
&input_zero_rerand.as_lwe_64(),
|
||||
&cpk_lwe_secret_key,
|
||||
msg,
|
||||
&compute_encoding_u64,
|
||||
),
|
||||
DecryptionAndNoiseResult::new_from_lwe(
|
||||
&after_ksed_zero_rerand.as_lwe_64(),
|
||||
&large_compute_lwe_secret_key,
|
||||
msg,
|
||||
&compute_encoding_u64,
|
||||
),
|
||||
),
|
||||
DecryptionAndNoiseResult::new_from_lwe(
|
||||
&after_rerand.as_lwe_64(),
|
||||
&large_compute_lwe_secret_key,
|
||||
msg,
|
||||
&compute_encoding_u64,
|
||||
),
|
||||
DecryptionAndNoiseResult::new_from_lwe(
|
||||
&after_dp.as_lwe_64(),
|
||||
&large_compute_lwe_secret_key,
|
||||
msg,
|
||||
&compute_encoding_u64,
|
||||
),
|
||||
DecryptionAndNoiseResult::new_from_lwe(
|
||||
&after_ks.as_lwe_32(),
|
||||
&small_compute_lwe_secret_key,
|
||||
msg_u32,
|
||||
&compute_encoding_u32,
|
||||
),
|
||||
DecryptionAndNoiseResult::new_from_lwe(
|
||||
&before_ms.as_lwe_32(),
|
||||
&small_compute_lwe_secret_key,
|
||||
msg_u32,
|
||||
&compute_encoding_u32,
|
||||
),
|
||||
DecryptionAndNoiseResult::new_from_lwe(
|
||||
&after_ms.as_lwe_32(),
|
||||
&small_compute_lwe_secret_key,
|
||||
msg_u32,
|
||||
&compute_encoding_u32,
|
||||
),
|
||||
)
|
||||
}
|
||||
}
|
||||
(
|
||||
(
|
||||
DecryptionAndNoiseResult::new_from_dyn_lwe(&input, &comp_lwe_secret_key_dyn, msg),
|
||||
DecryptionAndNoiseResult::new_from_dyn_lwe(
|
||||
&after_br,
|
||||
&large_compute_lwe_secret_key_dyn,
|
||||
msg,
|
||||
),
|
||||
),
|
||||
(
|
||||
DecryptionAndNoiseResult::new_from_dyn_lwe(
|
||||
&input_zero_rerand,
|
||||
&cpk_lwe_secret_key_dyn,
|
||||
msg,
|
||||
),
|
||||
DecryptionAndNoiseResult::new_from_dyn_lwe(
|
||||
&after_ksed_zero_rerand,
|
||||
&large_compute_lwe_secret_key_dyn,
|
||||
msg,
|
||||
),
|
||||
),
|
||||
DecryptionAndNoiseResult::new_from_dyn_lwe(
|
||||
&after_rerand,
|
||||
&large_compute_lwe_secret_key_dyn,
|
||||
msg,
|
||||
),
|
||||
DecryptionAndNoiseResult::new_from_dyn_lwe(
|
||||
&after_dp,
|
||||
&large_compute_lwe_secret_key_dyn,
|
||||
msg,
|
||||
),
|
||||
DecryptionAndNoiseResult::new_from_dyn_lwe(
|
||||
&after_ks,
|
||||
&small_compute_lwe_secret_key_dyn,
|
||||
msg,
|
||||
),
|
||||
DecryptionAndNoiseResult::new_from_dyn_lwe(
|
||||
before_ms,
|
||||
&small_compute_lwe_secret_key_dyn,
|
||||
msg,
|
||||
),
|
||||
DecryptionAndNoiseResult::new_from_dyn_modswitched_lwe(
|
||||
&after_ms,
|
||||
&small_compute_lwe_secret_key_dyn,
|
||||
msg,
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
@@ -589,7 +481,9 @@ fn encrypt_br_rerand_dp_ks_any_ms_pfail_helper(
|
||||
|
||||
fn noise_check_encrypt_br_rerand_dp_ks_ms_noise(meta_params: MetaParameters) {
|
||||
let (params, cpk_params, rerand_ksk_params, compression_params) = {
|
||||
let compute_params = meta_params.compute_parameters;
|
||||
let compute_params = meta_params
|
||||
.compute_parameters
|
||||
.with_deterministic_execution();
|
||||
let dedicated_cpk_params = meta_params.dedicated_compact_public_key_parameters.unwrap();
|
||||
// To avoid the expand logic of shortint which would force a keyswitch + LUT eval after
|
||||
// expand
|
||||
@@ -627,7 +521,7 @@ fn noise_check_encrypt_br_rerand_dp_ks_ms_noise(meta_params: MetaParameters) {
|
||||
let noise_simulation_modulus_switch_config =
|
||||
NoiseSimulationModulusSwitchConfig::new_from_atomic_pattern_parameters(params);
|
||||
let noise_simulation_decomp_bsk =
|
||||
NoiseSimulationLweFourierBsk::new_from_comp_parameters(params, compression_params);
|
||||
NoiseSimulationGenericBootstrapKey::new_from_comp_parameters(params, compression_params);
|
||||
|
||||
let modulus_switch_config = sks.noise_simulation_modulus_switch_config();
|
||||
let compute_br_input_modulus_log = sks.br_input_modulus_log();
|
||||
@@ -791,11 +685,14 @@ fn noise_check_encrypt_br_rerand_dp_ks_ms_noise(meta_params: MetaParameters) {
|
||||
create_parameterized_test!(noise_check_encrypt_br_rerand_dp_ks_ms_noise {
|
||||
TEST_META_PARAM_CPU_2_2_KS_PBS_PKE_TO_SMALL_ZKV2_TUNIFORM_2M128,
|
||||
TEST_META_PARAM_CPU_2_2_KS32_PBS_PKE_TO_SMALL_ZKV2_TUNIFORM_2M128,
|
||||
TEST_META_PARAM_GPU_2_2_MULTI_BIT_GROUP_4_KS_PBS_PKE_TO_SMALL_ZKV2_TUNIFORM_2M128,
|
||||
});
|
||||
|
||||
fn noise_check_encrypt_br_rerand_dp_ks_ms_pfail(meta_params: MetaParameters) {
|
||||
let (params, cpk_params, rerand_ksk_params, compression_params) = {
|
||||
let compute_params = meta_params.compute_parameters;
|
||||
let compute_params = meta_params
|
||||
.compute_parameters
|
||||
.with_deterministic_execution();
|
||||
let dedicated_cpk_params = meta_params.dedicated_compact_public_key_parameters.unwrap();
|
||||
// To avoid the expand logic of shortint which would force a keyswitch + LUT eval after
|
||||
// expand
|
||||
@@ -898,11 +795,14 @@ fn noise_check_encrypt_br_rerand_dp_ks_ms_pfail(meta_params: MetaParameters) {
|
||||
create_parameterized_test!(noise_check_encrypt_br_rerand_dp_ks_ms_pfail {
|
||||
TEST_META_PARAM_CPU_2_2_KS_PBS_PKE_TO_SMALL_ZKV2_TUNIFORM_2M128,
|
||||
TEST_META_PARAM_CPU_2_2_KS32_PBS_PKE_TO_SMALL_ZKV2_TUNIFORM_2M128,
|
||||
TEST_META_PARAM_GPU_2_2_MULTI_BIT_GROUP_4_KS_PBS_PKE_TO_SMALL_ZKV2_TUNIFORM_2M128,
|
||||
});
|
||||
|
||||
fn sanity_check_encrypt_br_rerand_dp_ks_ms_pbs(meta_params: MetaParameters) {
|
||||
let (params, cpk_params, rerand_ksk_params, compression_params) = {
|
||||
let compute_params = meta_params.compute_parameters;
|
||||
let compute_params = meta_params
|
||||
.compute_parameters
|
||||
.with_deterministic_execution();
|
||||
let dedicated_cpk_params = meta_params.dedicated_compact_public_key_parameters.unwrap();
|
||||
// To avoid the expand logic of shortint which would force a keyswitch + LUT eval after
|
||||
// expand
|
||||
@@ -1050,7 +950,7 @@ fn sanity_check_encrypt_br_rerand_dp_ks_ms_pbs(meta_params: MetaParameters) {
|
||||
|
||||
// Complete the AP by computing the PBS to match shortint
|
||||
let mut pbs_result = id_lut.allocate_lwe_bootstrap_result(&mut ());
|
||||
sks.lwe_classic_fft_pbs(&after_ms, &mut pbs_result, &id_lut, &mut ());
|
||||
sks.apply_generic_blind_rotation(&after_ms, &mut pbs_result, &id_lut);
|
||||
|
||||
assert_eq!(pbs_result.as_lwe_64(), shortint_res.ct.as_view());
|
||||
}
|
||||
@@ -1059,4 +959,5 @@ fn sanity_check_encrypt_br_rerand_dp_ks_ms_pbs(meta_params: MetaParameters) {
|
||||
create_parameterized_test!(sanity_check_encrypt_br_rerand_dp_ks_ms_pbs {
|
||||
TEST_META_PARAM_CPU_2_2_KS_PBS_PKE_TO_SMALL_ZKV2_TUNIFORM_2M128,
|
||||
TEST_META_PARAM_CPU_2_2_KS32_PBS_PKE_TO_SMALL_ZKV2_TUNIFORM_2M128,
|
||||
TEST_META_PARAM_GPU_2_2_MULTI_BIT_GROUP_4_KS_PBS_PKE_TO_SMALL_ZKV2_TUNIFORM_2M128,
|
||||
});
|
||||
|
||||
@@ -9,14 +9,13 @@ use super::utils::{
|
||||
};
|
||||
use super::{should_run_short_pfail_tests_debug, should_use_single_key_debug};
|
||||
use crate::core_crypto::commons::parameters::CiphertextModulusLog;
|
||||
use crate::shortint::client_key::atomic_pattern::AtomicPatternClientKey;
|
||||
use crate::shortint::client_key::ClientKey;
|
||||
use crate::shortint::encoding::ShortintEncoding;
|
||||
use crate::shortint::engine::ShortintEngine;
|
||||
use crate::shortint::key_switching_key::{KeySwitchingKeyBuildHelper, KeySwitchingKeyView};
|
||||
use crate::shortint::parameters::test_params::{
|
||||
TEST_META_PARAM_CPU_2_2_KS32_PBS_PKE_TO_SMALL_ZKV2_TUNIFORM_2M128,
|
||||
TEST_META_PARAM_CPU_2_2_KS_PBS_PKE_TO_SMALL_ZKV2_TUNIFORM_2M128,
|
||||
TEST_META_PARAM_GPU_2_2_MULTI_BIT_GROUP_4_KS_PBS_PKE_TO_SMALL_ZKV2_TUNIFORM_2M128,
|
||||
};
|
||||
use crate::shortint::parameters::{
|
||||
AtomicPatternParameters, CarryModulus, CompactCiphertextListExpansionKind,
|
||||
@@ -26,7 +25,6 @@ use crate::shortint::parameters::{
|
||||
use crate::shortint::public_key::compact::{CompactPrivateKey, CompactPublicKey};
|
||||
use crate::shortint::server_key::tests::parameterized_test::create_parameterized_test;
|
||||
use crate::shortint::server_key::ServerKey;
|
||||
use crate::shortint::PaddingBit;
|
||||
use rayon::prelude::*;
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
@@ -54,7 +52,9 @@ where
|
||||
+ AllocateCenteredBinaryShiftedStandardModSwitchResult<
|
||||
Output = MsResult,
|
||||
SideResources = Resources,
|
||||
> + CenteredBinaryShiftedStandardModSwitch<MsResult, SideResources = Resources>,
|
||||
> + CenteredBinaryShiftedStandardModSwitch<MsResult, SideResources = Resources>
|
||||
+ AllocateMultiBitModSwitchResult<Output = MsResult, SideResources = Resources>
|
||||
+ MultiBitModSwitch<MsResult, SideResources = Resources>,
|
||||
DriftKey: AllocateDriftTechniqueStandardModSwitchResult<
|
||||
AfterDriftOutput = DriftTechniqueResult,
|
||||
AfterMsOutput = MsResult,
|
||||
@@ -163,97 +163,19 @@ fn cpk_ks_any_ms_inner_helper(
|
||||
|
||||
let before_ms = after_drift.as_ref().unwrap_or(&after_ks_ds);
|
||||
|
||||
match &cks.atomic_pattern {
|
||||
AtomicPatternClientKey::Standard(standard_atomic_pattern_client_key) => {
|
||||
let params = standard_atomic_pattern_client_key.parameters;
|
||||
let encoding = ShortintEncoding {
|
||||
ciphertext_modulus: params.ciphertext_modulus(),
|
||||
message_modulus: params.message_modulus(),
|
||||
carry_modulus: params.carry_modulus(),
|
||||
padding_bit: PaddingBit::Yes,
|
||||
};
|
||||
let cpk_lwe_secret_key_dyn = cpk_private_key.lwe_secret_key_as_dyn();
|
||||
let small_lwe_secret_key_dyn = cks.small_lwe_secret_key_as_dyn();
|
||||
|
||||
let cpk_lwe_secret_key = cpk_private_key.key();
|
||||
|
||||
let small_compute_lwe_secret_key =
|
||||
standard_atomic_pattern_client_key.small_lwe_secret_key();
|
||||
(
|
||||
DecryptionAndNoiseResult::new_from_lwe(
|
||||
&input.as_lwe_64(),
|
||||
&cpk_lwe_secret_key,
|
||||
msg,
|
||||
&encoding,
|
||||
),
|
||||
DecryptionAndNoiseResult::new_from_lwe(
|
||||
&after_ks_ds.as_lwe_64(),
|
||||
&small_compute_lwe_secret_key,
|
||||
msg,
|
||||
&encoding,
|
||||
),
|
||||
DecryptionAndNoiseResult::new_from_lwe(
|
||||
&before_ms.as_lwe_64(),
|
||||
&small_compute_lwe_secret_key,
|
||||
msg,
|
||||
&encoding,
|
||||
),
|
||||
DecryptionAndNoiseResult::new_from_lwe(
|
||||
&after_ms.as_lwe_64(),
|
||||
&small_compute_lwe_secret_key,
|
||||
msg,
|
||||
&encoding,
|
||||
),
|
||||
)
|
||||
}
|
||||
AtomicPatternClientKey::KeySwitch32(ks32_atomic_pattern_client_key) => {
|
||||
let params = ks32_atomic_pattern_client_key.parameters;
|
||||
let compute_encoding_u32 = ShortintEncoding {
|
||||
ciphertext_modulus: params.post_keyswitch_ciphertext_modulus(),
|
||||
message_modulus: params.message_modulus(),
|
||||
carry_modulus: params.carry_modulus(),
|
||||
padding_bit: PaddingBit::Yes,
|
||||
};
|
||||
let compute_encoding_u64 = ShortintEncoding {
|
||||
ciphertext_modulus: params.ciphertext_modulus(),
|
||||
message_modulus: params.message_modulus(),
|
||||
carry_modulus: params.carry_modulus(),
|
||||
padding_bit: PaddingBit::Yes,
|
||||
};
|
||||
|
||||
let cpk_lwe_secret_key = cpk_private_key.key();
|
||||
|
||||
let small_compute_lwe_secret_key =
|
||||
ks32_atomic_pattern_client_key.small_lwe_secret_key();
|
||||
|
||||
let msg_u32: u32 = msg.try_into().unwrap();
|
||||
|
||||
(
|
||||
DecryptionAndNoiseResult::new_from_lwe(
|
||||
&input.as_lwe_64(),
|
||||
&cpk_lwe_secret_key,
|
||||
msg,
|
||||
&compute_encoding_u64,
|
||||
),
|
||||
DecryptionAndNoiseResult::new_from_lwe(
|
||||
&after_ks_ds.as_lwe_32(),
|
||||
&small_compute_lwe_secret_key,
|
||||
msg_u32,
|
||||
&compute_encoding_u32,
|
||||
),
|
||||
DecryptionAndNoiseResult::new_from_lwe(
|
||||
&before_ms.as_lwe_32(),
|
||||
&small_compute_lwe_secret_key,
|
||||
msg_u32,
|
||||
&compute_encoding_u32,
|
||||
),
|
||||
DecryptionAndNoiseResult::new_from_lwe(
|
||||
&after_ms.as_lwe_32(),
|
||||
&small_compute_lwe_secret_key,
|
||||
msg_u32,
|
||||
&compute_encoding_u32,
|
||||
),
|
||||
)
|
||||
}
|
||||
}
|
||||
(
|
||||
DecryptionAndNoiseResult::new_from_dyn_lwe(&input, &cpk_lwe_secret_key_dyn, msg),
|
||||
DecryptionAndNoiseResult::new_from_dyn_lwe(&after_ks_ds, &small_lwe_secret_key_dyn, msg),
|
||||
DecryptionAndNoiseResult::new_from_dyn_lwe(before_ms, &small_lwe_secret_key_dyn, msg),
|
||||
DecryptionAndNoiseResult::new_from_dyn_modswitched_lwe(
|
||||
&after_ms,
|
||||
&small_lwe_secret_key_dyn,
|
||||
msg,
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
@@ -327,7 +249,9 @@ fn cpk_ks_any_ms_pfail_helper(
|
||||
|
||||
fn noise_check_encrypt_cpk_ks_ms_noise(meta_params: MetaParameters) {
|
||||
let (params, cpk_params, ksk_ds_params) = {
|
||||
let compute_params = meta_params.compute_parameters;
|
||||
let compute_params = meta_params
|
||||
.compute_parameters
|
||||
.with_deterministic_execution();
|
||||
let dedicated_cpk_params = meta_params.dedicated_compact_public_key_parameters.unwrap();
|
||||
// To avoid the expand logic of shortint which would force a keyswitch + LUT eval after
|
||||
// expand
|
||||
@@ -454,11 +378,14 @@ fn noise_check_encrypt_cpk_ks_ms_noise(meta_params: MetaParameters) {
|
||||
create_parameterized_test!(noise_check_encrypt_cpk_ks_ms_noise {
|
||||
TEST_META_PARAM_CPU_2_2_KS_PBS_PKE_TO_SMALL_ZKV2_TUNIFORM_2M128,
|
||||
TEST_META_PARAM_CPU_2_2_KS32_PBS_PKE_TO_SMALL_ZKV2_TUNIFORM_2M128,
|
||||
TEST_META_PARAM_GPU_2_2_MULTI_BIT_GROUP_4_KS_PBS_PKE_TO_SMALL_ZKV2_TUNIFORM_2M128,
|
||||
});
|
||||
|
||||
fn noise_check_encrypt_cpk_ks_ms_pfail(meta_params: MetaParameters) {
|
||||
let (params, cpk_params, ksk_ds_params) = {
|
||||
let compute_params = meta_params.compute_parameters;
|
||||
let compute_params = meta_params
|
||||
.compute_parameters
|
||||
.with_deterministic_execution();
|
||||
let dedicated_cpk_params = meta_params.dedicated_compact_public_key_parameters.unwrap();
|
||||
// To avoid the expand logic of shortint which would force a keyswitch + LUT eval after
|
||||
// expand
|
||||
@@ -548,11 +475,14 @@ fn noise_check_encrypt_cpk_ks_ms_pfail(meta_params: MetaParameters) {
|
||||
create_parameterized_test!(noise_check_encrypt_cpk_ks_ms_pfail {
|
||||
TEST_META_PARAM_CPU_2_2_KS_PBS_PKE_TO_SMALL_ZKV2_TUNIFORM_2M128,
|
||||
TEST_META_PARAM_CPU_2_2_KS32_PBS_PKE_TO_SMALL_ZKV2_TUNIFORM_2M128,
|
||||
TEST_META_PARAM_GPU_2_2_MULTI_BIT_GROUP_4_KS_PBS_PKE_TO_SMALL_ZKV2_TUNIFORM_2M128,
|
||||
});
|
||||
|
||||
fn sanity_check_encrypt_cpk_ks_ms_pbs(meta_params: MetaParameters) {
|
||||
let (params, cpk_params, ksk_ds_params, orig_cast_mode) = {
|
||||
let compute_params = meta_params.compute_parameters;
|
||||
let compute_params = meta_params
|
||||
.compute_parameters
|
||||
.with_deterministic_execution();
|
||||
let dedicated_cpk_params = meta_params.dedicated_compact_public_key_parameters.unwrap();
|
||||
// To avoid the expand logic of shortint which would force a keyswitch + LUT eval after
|
||||
// expand
|
||||
@@ -635,7 +565,7 @@ fn sanity_check_encrypt_cpk_ks_ms_pbs(meta_params: MetaParameters) {
|
||||
|
||||
// Complete the AP by computing the PBS to match shortint
|
||||
let mut pbs_result = id_lut.allocate_lwe_bootstrap_result(&mut ());
|
||||
sks.lwe_classic_fft_pbs(&after_ms, &mut pbs_result, &id_lut, &mut ());
|
||||
sks.apply_generic_blind_rotation(&after_ms, &mut pbs_result, &id_lut);
|
||||
|
||||
assert_eq!(pbs_result.as_lwe_64(), shortint_res.ct.as_view());
|
||||
}
|
||||
@@ -644,4 +574,5 @@ fn sanity_check_encrypt_cpk_ks_ms_pbs(meta_params: MetaParameters) {
|
||||
create_parameterized_test!(sanity_check_encrypt_cpk_ks_ms_pbs {
|
||||
TEST_META_PARAM_CPU_2_2_KS_PBS_PKE_TO_SMALL_ZKV2_TUNIFORM_2M128,
|
||||
TEST_META_PARAM_CPU_2_2_KS32_PBS_PKE_TO_SMALL_ZKV2_TUNIFORM_2M128,
|
||||
TEST_META_PARAM_GPU_2_2_MULTI_BIT_GROUP_4_KS_PBS_PKE_TO_SMALL_ZKV2_TUNIFORM_2M128,
|
||||
});
|
||||
|
||||
@@ -6,14 +6,13 @@ use super::utils::{
|
||||
};
|
||||
use super::{should_run_short_pfail_tests_debug, should_use_single_key_debug};
|
||||
use crate::core_crypto::commons::parameters::CiphertextModulusLog;
|
||||
use crate::shortint::client_key::atomic_pattern::AtomicPatternClientKey;
|
||||
use crate::shortint::client_key::ClientKey;
|
||||
use crate::shortint::encoding::{PaddingBit, ShortintEncoding};
|
||||
use crate::shortint::engine::ShortintEngine;
|
||||
use crate::shortint::parameters::test_params::{
|
||||
TEST_META_PARAM_CPU_2_2_KS32_PBS_PKE_TO_SMALL_ZKV2_TUNIFORM_2M128,
|
||||
TEST_META_PARAM_CPU_2_2_KS_PBS_GAUSSIAN_2M128,
|
||||
TEST_META_PARAM_CPU_2_2_KS_PBS_PKE_TO_SMALL_ZKV2_TUNIFORM_2M128,
|
||||
TEST_META_PARAM_GPU_2_2_MULTI_BIT_GROUP_4_KS_PBS_PKE_TO_SMALL_ZKV2_TUNIFORM_2M128,
|
||||
};
|
||||
use crate::shortint::parameters::{AtomicPatternParameters, CarryModulus, MetaParameters};
|
||||
use crate::shortint::server_key::tests::parameterized_test::create_parameterized_test;
|
||||
@@ -32,7 +31,9 @@ where
|
||||
+ AllocateCenteredBinaryShiftedStandardModSwitchResult<
|
||||
Output = MsResult,
|
||||
SideResources = Resources,
|
||||
> + CenteredBinaryShiftedStandardModSwitch<MsResult, SideResources = Resources>,
|
||||
> + CenteredBinaryShiftedStandardModSwitch<MsResult, SideResources = Resources>
|
||||
+ AllocateMultiBitModSwitchResult<Output = MsResult, SideResources = Resources>
|
||||
+ MultiBitModSwitch<MsResult, SideResources = Resources>,
|
||||
// We need to be able to allocate the result and apply drift technique + mod switch it
|
||||
DriftKey: AllocateDriftTechniqueStandardModSwitchResult<
|
||||
AfterDriftOutput = DriftTechniqueResult,
|
||||
@@ -76,6 +77,17 @@ where
|
||||
side_resources,
|
||||
);
|
||||
|
||||
(None, ms_result)
|
||||
}
|
||||
NoiseSimulationModulusSwitchConfig::MultiBit(grouping_factor) => {
|
||||
let mut ms_result = input.allocate_multi_bit_mod_switch_result(side_resources);
|
||||
input.multi_bit_mod_switch(
|
||||
grouping_factor,
|
||||
br_input_modulus_log,
|
||||
&mut ms_result,
|
||||
side_resources,
|
||||
);
|
||||
|
||||
(None, ms_result)
|
||||
}
|
||||
}
|
||||
@@ -116,7 +128,9 @@ where
|
||||
+ AllocateCenteredBinaryShiftedStandardModSwitchResult<
|
||||
Output = MsResult,
|
||||
SideResources = Resources,
|
||||
> + CenteredBinaryShiftedStandardModSwitch<MsResult, SideResources = Resources>,
|
||||
> + CenteredBinaryShiftedStandardModSwitch<MsResult, SideResources = Resources>
|
||||
+ AllocateMultiBitModSwitchResult<Output = MsResult, SideResources = Resources>
|
||||
+ MultiBitModSwitch<MsResult, SideResources = Resources>,
|
||||
// We need to be able to allocate the result and apply drift technique + mod switch it
|
||||
DriftKey: AllocateDriftTechniqueStandardModSwitchResult<
|
||||
AfterDriftOutput = DriftTechniqueResult,
|
||||
@@ -152,7 +166,9 @@ where
|
||||
/// Test function to verify that the noise checking tools match the actual atomic patterns
|
||||
/// implemented in shortint
|
||||
fn sanity_check_encrypt_dp_ks_pbs(meta_params: MetaParameters) {
|
||||
let params = meta_params.compute_parameters;
|
||||
let params = meta_params
|
||||
.compute_parameters
|
||||
.with_deterministic_execution();
|
||||
let cks = ClientKey::new(params);
|
||||
let sks = ServerKey::new(&cks);
|
||||
|
||||
@@ -178,7 +194,7 @@ fn sanity_check_encrypt_dp_ks_pbs(meta_params: MetaParameters) {
|
||||
|
||||
// Complete the AP by computing the PBS to match shortint
|
||||
let mut pbs_result = id_lut.allocate_lwe_bootstrap_result(&mut ());
|
||||
sks.lwe_classic_fft_pbs(&after_ms, &mut pbs_result, &id_lut, &mut ());
|
||||
sks.apply_generic_blind_rotation(&after_ms, &mut pbs_result, &id_lut);
|
||||
|
||||
let mut shortint_res =
|
||||
sks.unchecked_scalar_mul(&input_zero, max_scalar_mul.try_into().unwrap());
|
||||
@@ -192,6 +208,7 @@ create_parameterized_test!(sanity_check_encrypt_dp_ks_pbs {
|
||||
TEST_META_PARAM_CPU_2_2_KS_PBS_GAUSSIAN_2M128,
|
||||
TEST_META_PARAM_CPU_2_2_KS_PBS_PKE_TO_SMALL_ZKV2_TUNIFORM_2M128,
|
||||
TEST_META_PARAM_CPU_2_2_KS32_PBS_PKE_TO_SMALL_ZKV2_TUNIFORM_2M128,
|
||||
TEST_META_PARAM_GPU_2_2_MULTI_BIT_GROUP_4_KS_PBS_PKE_TO_SMALL_ZKV2_TUNIFORM_2M128,
|
||||
});
|
||||
|
||||
fn encrypt_dp_ks_any_ms_inner_helper(
|
||||
@@ -235,100 +252,20 @@ fn encrypt_dp_ks_any_ms_inner_helper(
|
||||
|
||||
let before_ms = after_drift.as_ref().unwrap_or(&after_ks);
|
||||
|
||||
match &cks.atomic_pattern {
|
||||
AtomicPatternClientKey::Standard(standard_atomic_pattern_client_key) => {
|
||||
let output_encoding = ShortintEncoding {
|
||||
ciphertext_modulus: params.ciphertext_modulus(),
|
||||
message_modulus: params.message_modulus(),
|
||||
carry_modulus: params.carry_modulus(),
|
||||
padding_bit: PaddingBit::Yes,
|
||||
};
|
||||
let large_lwe_secret_key = standard_atomic_pattern_client_key.large_lwe_secret_key();
|
||||
let small_lwe_secret_key = standard_atomic_pattern_client_key.small_lwe_secret_key();
|
||||
(
|
||||
DecryptionAndNoiseResult::new_from_lwe(
|
||||
&input.as_lwe_64(),
|
||||
&large_lwe_secret_key,
|
||||
msg,
|
||||
&output_encoding,
|
||||
),
|
||||
DecryptionAndNoiseResult::new_from_lwe(
|
||||
&after_dp.as_lwe_64(),
|
||||
&large_lwe_secret_key,
|
||||
msg,
|
||||
&output_encoding,
|
||||
),
|
||||
DecryptionAndNoiseResult::new_from_lwe(
|
||||
&after_ks.as_lwe_64(),
|
||||
&small_lwe_secret_key,
|
||||
msg,
|
||||
&output_encoding,
|
||||
),
|
||||
DecryptionAndNoiseResult::new_from_lwe(
|
||||
&before_ms.as_lwe_64(),
|
||||
&small_lwe_secret_key,
|
||||
msg,
|
||||
&output_encoding,
|
||||
),
|
||||
DecryptionAndNoiseResult::new_from_lwe(
|
||||
&after_ms.as_lwe_64(),
|
||||
&small_lwe_secret_key,
|
||||
msg,
|
||||
&output_encoding,
|
||||
),
|
||||
)
|
||||
}
|
||||
AtomicPatternClientKey::KeySwitch32(ks32_atomic_pattern_client_key) => {
|
||||
let msg_u32: u32 = msg.try_into().unwrap();
|
||||
let params = ks32_atomic_pattern_client_key.parameters;
|
||||
let small_key_encoding = ShortintEncoding {
|
||||
ciphertext_modulus: params.post_keyswitch_ciphertext_modulus(),
|
||||
message_modulus: params.message_modulus(),
|
||||
carry_modulus: params.carry_modulus(),
|
||||
padding_bit: PaddingBit::Yes,
|
||||
};
|
||||
let big_key_encoding = ShortintEncoding {
|
||||
ciphertext_modulus: params.ciphertext_modulus(),
|
||||
message_modulus: params.message_modulus(),
|
||||
carry_modulus: params.carry_modulus(),
|
||||
padding_bit: PaddingBit::Yes,
|
||||
};
|
||||
let large_lwe_secret_key = ks32_atomic_pattern_client_key.large_lwe_secret_key();
|
||||
let small_lwe_secret_key = ks32_atomic_pattern_client_key.small_lwe_secret_key();
|
||||
(
|
||||
DecryptionAndNoiseResult::new_from_lwe(
|
||||
&input.as_lwe_64(),
|
||||
&large_lwe_secret_key,
|
||||
msg,
|
||||
&big_key_encoding,
|
||||
),
|
||||
DecryptionAndNoiseResult::new_from_lwe(
|
||||
&after_dp.as_lwe_64(),
|
||||
&large_lwe_secret_key,
|
||||
msg,
|
||||
&big_key_encoding,
|
||||
),
|
||||
DecryptionAndNoiseResult::new_from_lwe(
|
||||
&after_ks.as_lwe_32(),
|
||||
&small_lwe_secret_key,
|
||||
msg_u32,
|
||||
&small_key_encoding,
|
||||
),
|
||||
DecryptionAndNoiseResult::new_from_lwe(
|
||||
&before_ms.as_lwe_32(),
|
||||
&small_lwe_secret_key,
|
||||
msg_u32,
|
||||
&small_key_encoding,
|
||||
),
|
||||
DecryptionAndNoiseResult::new_from_lwe(
|
||||
&after_ms.as_lwe_32(),
|
||||
&small_lwe_secret_key,
|
||||
msg_u32,
|
||||
&small_key_encoding,
|
||||
),
|
||||
)
|
||||
}
|
||||
}
|
||||
let large_lwe_secret_key_dyn = cks.large_lwe_secret_key_as_dyn();
|
||||
let small_lwe_secret_key_dyn = cks.small_lwe_secret_key_as_dyn();
|
||||
|
||||
(
|
||||
DecryptionAndNoiseResult::new_from_dyn_lwe(&input, &large_lwe_secret_key_dyn, msg),
|
||||
DecryptionAndNoiseResult::new_from_dyn_lwe(&after_dp, &large_lwe_secret_key_dyn, msg),
|
||||
DecryptionAndNoiseResult::new_from_dyn_lwe(&after_ks, &small_lwe_secret_key_dyn, msg),
|
||||
DecryptionAndNoiseResult::new_from_dyn_lwe(before_ms, &small_lwe_secret_key_dyn, msg),
|
||||
DecryptionAndNoiseResult::new_from_dyn_modswitched_lwe(
|
||||
&after_ms,
|
||||
&small_lwe_secret_key_dyn,
|
||||
msg,
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
fn encrypt_dp_ks_any_ms_noise_helper(
|
||||
@@ -390,7 +327,9 @@ fn encrypt_dp_ks_any_ms_pfail_helper(
|
||||
}
|
||||
|
||||
fn noise_check_encrypt_dp_ks_ms_noise(meta_params: MetaParameters) {
|
||||
let params = meta_params.compute_parameters;
|
||||
let params = meta_params
|
||||
.compute_parameters
|
||||
.with_deterministic_execution();
|
||||
let cks = ClientKey::new(params);
|
||||
let sks = ServerKey::new(&cks);
|
||||
|
||||
@@ -482,11 +421,14 @@ create_parameterized_test!(noise_check_encrypt_dp_ks_ms_noise {
|
||||
TEST_META_PARAM_CPU_2_2_KS_PBS_GAUSSIAN_2M128,
|
||||
TEST_META_PARAM_CPU_2_2_KS_PBS_PKE_TO_SMALL_ZKV2_TUNIFORM_2M128,
|
||||
TEST_META_PARAM_CPU_2_2_KS32_PBS_PKE_TO_SMALL_ZKV2_TUNIFORM_2M128,
|
||||
TEST_META_PARAM_GPU_2_2_MULTI_BIT_GROUP_4_KS_PBS_PKE_TO_SMALL_ZKV2_TUNIFORM_2M128,
|
||||
});
|
||||
|
||||
fn noise_check_encrypt_dp_ks_ms_pfail(meta_params: MetaParameters) {
|
||||
let (pfail_test_meta, params) = {
|
||||
let mut ap_params = meta_params.compute_parameters;
|
||||
let mut ap_params = meta_params
|
||||
.compute_parameters
|
||||
.with_deterministic_execution();
|
||||
|
||||
let original_message_modulus = ap_params.message_modulus();
|
||||
let original_carry_modulus = ap_params.carry_modulus();
|
||||
@@ -548,4 +490,5 @@ create_parameterized_test!(noise_check_encrypt_dp_ks_ms_pfail {
|
||||
TEST_META_PARAM_CPU_2_2_KS_PBS_GAUSSIAN_2M128,
|
||||
TEST_META_PARAM_CPU_2_2_KS_PBS_PKE_TO_SMALL_ZKV2_TUNIFORM_2M128,
|
||||
TEST_META_PARAM_CPU_2_2_KS32_PBS_PKE_TO_SMALL_ZKV2_TUNIFORM_2M128,
|
||||
TEST_META_PARAM_GPU_2_2_MULTI_BIT_GROUP_4_KS_PBS_PKE_TO_SMALL_ZKV2_TUNIFORM_2M128,
|
||||
});
|
||||
|
||||
@@ -6,7 +6,6 @@ use super::utils::{mean_and_variance_check, DecryptionAndNoiseResult, NoiseSampl
|
||||
use crate::core_crypto::algorithms::lwe_programmable_bootstrapping::generate_programmable_bootstrap_glwe_lut;
|
||||
use crate::core_crypto::commons::dispersion::Variance;
|
||||
use crate::core_crypto::commons::parameters::CiphertextModulusLog;
|
||||
use crate::shortint::client_key::atomic_pattern::AtomicPatternClientKey;
|
||||
use crate::shortint::client_key::ClientKey;
|
||||
use crate::shortint::encoding::{PaddingBit, ShortintEncoding};
|
||||
use crate::shortint::engine::ShortintEngine;
|
||||
@@ -18,6 +17,7 @@ use crate::shortint::parameters::noise_squashing::NoiseSquashingParameters;
|
||||
use crate::shortint::parameters::test_params::{
|
||||
TEST_META_PARAM_CPU_2_2_KS32_PBS_PKE_TO_SMALL_ZKV2_TUNIFORM_2M128,
|
||||
TEST_META_PARAM_CPU_2_2_KS_PBS_PKE_TO_SMALL_ZKV2_TUNIFORM_2M128,
|
||||
TEST_META_PARAM_GPU_2_2_MULTI_BIT_GROUP_4_KS_PBS_PKE_TO_SMALL_ZKV2_TUNIFORM_2M128,
|
||||
};
|
||||
use crate::shortint::parameters::{
|
||||
AtomicPatternParameters, MetaParameters, NoiseSquashingCompressionParameters,
|
||||
@@ -68,7 +68,9 @@ where
|
||||
+ AllocateCenteredBinaryShiftedStandardModSwitchResult<
|
||||
Output = MsResult,
|
||||
SideResources = Resources,
|
||||
> + CenteredBinaryShiftedStandardModSwitch<MsResult, SideResources = Resources>,
|
||||
> + CenteredBinaryShiftedStandardModSwitch<MsResult, SideResources = Resources>
|
||||
+ AllocateMultiBitModSwitchResult<Output = MsResult, SideResources = Resources>
|
||||
+ MultiBitModSwitch<MsResult, SideResources = Resources>,
|
||||
// We need to be able to allocate the result and apply drift technique + mod switch it
|
||||
DriftKey: AllocateDriftTechniqueStandardModSwitchResult<
|
||||
AfterDriftOutput = DriftTechniqueResult,
|
||||
@@ -84,7 +86,7 @@ where
|
||||
// one to allocate the blind rotation result
|
||||
Accumulator: AllocateLweBootstrapResult<Output = PbsResult, SideResources = Resources>,
|
||||
// We need to be able to apply the PBS
|
||||
Bsk: LweClassicFft128Bootstrap<MsResult, PbsResult, Accumulator, SideResources = Resources>,
|
||||
Bsk: LweGenericBlindRotate128<MsResult, PbsResult, Accumulator, SideResources = Resources>,
|
||||
{
|
||||
let after_dp = input.scalar_mul(scalar, side_resources);
|
||||
let mut ks_result = ksk.allocate_lwe_keyswitch_result(side_resources);
|
||||
@@ -98,7 +100,7 @@ where
|
||||
);
|
||||
|
||||
let mut pbs_result = accumulator.allocate_lwe_bootstrap_result(side_resources);
|
||||
bsk_128.lwe_classic_fft_128_pbs(&ms_result, &mut pbs_result, accumulator, side_resources);
|
||||
bsk_128.lwe_generic_blind_rotate_128(&ms_result, &mut pbs_result, accumulator, side_resources);
|
||||
(
|
||||
input,
|
||||
after_dp,
|
||||
@@ -159,7 +161,9 @@ where
|
||||
+ AllocateCenteredBinaryShiftedStandardModSwitchResult<
|
||||
Output = MsResult,
|
||||
SideResources = Resources,
|
||||
> + CenteredBinaryShiftedStandardModSwitch<MsResult, SideResources = Resources>,
|
||||
> + CenteredBinaryShiftedStandardModSwitch<MsResult, SideResources = Resources>
|
||||
+ AllocateMultiBitModSwitchResult<Output = MsResult, SideResources = Resources>
|
||||
+ MultiBitModSwitch<MsResult, SideResources = Resources>,
|
||||
// We need to be able to allocate the result and apply drift technique + mod switch it
|
||||
DriftKey: AllocateDriftTechniqueStandardModSwitchResult<
|
||||
AfterDriftOutput = DriftTechniqueResult,
|
||||
@@ -175,7 +179,7 @@ where
|
||||
// one to allocate the blind rotation result
|
||||
Accumulator: AllocateLweBootstrapResult<Output = PbsResult, SideResources = Resources> + Sync,
|
||||
// We need to be able to apply the PBS
|
||||
Bsk: LweClassicFft128Bootstrap<MsResult, PbsResult, Accumulator, SideResources = Resources>
|
||||
Bsk: LweGenericBlindRotate128<MsResult, PbsResult, Accumulator, SideResources = Resources>
|
||||
+ Sync,
|
||||
PackingKey: AllocateLwePackingKeyswitchResult<Output = PackingResult, SideResources = Resources>
|
||||
+ for<'a> LwePackingKeyswitch<[&'a PbsResult], PackingResult, SideResources = Resources>,
|
||||
@@ -240,8 +244,12 @@ fn sanity_check_encrypt_dp_ks_standard_pbs128_packing_ks(meta_params: MetaParame
|
||||
let (params, noise_squashing_params, noise_squashing_compression_params) = {
|
||||
let meta_noise_squashing_params = meta_params.noise_squashing_parameters.unwrap();
|
||||
(
|
||||
meta_params.compute_parameters,
|
||||
meta_noise_squashing_params.parameters,
|
||||
meta_params
|
||||
.compute_parameters
|
||||
.with_deterministic_execution(),
|
||||
meta_noise_squashing_params
|
||||
.parameters
|
||||
.with_deterministic_execution(),
|
||||
meta_noise_squashing_params.compression_parameters.unwrap(),
|
||||
)
|
||||
};
|
||||
@@ -314,14 +322,14 @@ fn sanity_check_encrypt_dp_ks_standard_pbs128_packing_ks(meta_params: MetaParame
|
||||
let compressed = noise_squashing_compression_key
|
||||
.compress_noise_squashed_ciphertexts_into_list(&noise_squashed);
|
||||
|
||||
let underlying_glwes = compressed.glwe_ciphertext_list;
|
||||
let underlying_glwes = &compressed.glwe_ciphertext_list;
|
||||
|
||||
assert_eq!(underlying_glwes.len(), 1);
|
||||
|
||||
let extracted = underlying_glwes[0].extract();
|
||||
|
||||
// Bodies that were not filled are discarded
|
||||
after_packing.get_mut_body().as_mut()[lwe_per_glwe.0..].fill(0);
|
||||
after_packing.get_mut_body().as_mut()[compressed.len()..].fill(0);
|
||||
|
||||
assert_eq!(after_packing.as_view(), extracted.as_view());
|
||||
}
|
||||
@@ -329,6 +337,7 @@ fn sanity_check_encrypt_dp_ks_standard_pbs128_packing_ks(meta_params: MetaParame
|
||||
create_parameterized_test!(sanity_check_encrypt_dp_ks_standard_pbs128_packing_ks {
|
||||
TEST_META_PARAM_CPU_2_2_KS_PBS_PKE_TO_SMALL_ZKV2_TUNIFORM_2M128,
|
||||
TEST_META_PARAM_CPU_2_2_KS32_PBS_PKE_TO_SMALL_ZKV2_TUNIFORM_2M128,
|
||||
TEST_META_PARAM_GPU_2_2_MULTI_BIT_GROUP_4_KS_PBS_PKE_TO_SMALL_ZKV2_TUNIFORM_2M128,
|
||||
});
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
@@ -451,117 +460,45 @@ fn encrypt_dp_ks_standard_pbs128_packing_ks_inner_helper(
|
||||
.map(
|
||||
|(input, after_dp, after_ks, after_drift, after_ms, after_pbs128)| {
|
||||
let before_ms = after_drift.as_ref().unwrap_or(&after_ks);
|
||||
match &cks.atomic_pattern {
|
||||
AtomicPatternClientKey::Standard(standard_atomic_pattern_client_key) => {
|
||||
let params = standard_atomic_pattern_client_key.parameters;
|
||||
let u64_encoding = ShortintEncoding {
|
||||
ciphertext_modulus: params.ciphertext_modulus(),
|
||||
message_modulus: params.message_modulus(),
|
||||
carry_modulus: params.carry_modulus(),
|
||||
padding_bit: PaddingBit::Yes,
|
||||
};
|
||||
let large_lwe_secret_key =
|
||||
standard_atomic_pattern_client_key.large_lwe_secret_key();
|
||||
let small_lwe_secret_key =
|
||||
standard_atomic_pattern_client_key.small_lwe_secret_key();
|
||||
(
|
||||
DecryptionAndNoiseResult::new_from_lwe(
|
||||
&input.as_lwe_64(),
|
||||
&large_lwe_secret_key,
|
||||
msg,
|
||||
&u64_encoding,
|
||||
),
|
||||
DecryptionAndNoiseResult::new_from_lwe(
|
||||
&after_dp.as_lwe_64(),
|
||||
&large_lwe_secret_key,
|
||||
msg,
|
||||
&u64_encoding,
|
||||
),
|
||||
DecryptionAndNoiseResult::new_from_lwe(
|
||||
&after_ks.as_lwe_64(),
|
||||
&small_lwe_secret_key,
|
||||
msg,
|
||||
&u64_encoding,
|
||||
),
|
||||
DecryptionAndNoiseResult::new_from_lwe(
|
||||
&before_ms.as_lwe_64(),
|
||||
&small_lwe_secret_key,
|
||||
msg,
|
||||
&u64_encoding,
|
||||
),
|
||||
DecryptionAndNoiseResult::new_from_lwe(
|
||||
&after_ms.as_lwe_64(),
|
||||
&small_lwe_secret_key,
|
||||
msg,
|
||||
&u64_encoding,
|
||||
),
|
||||
DecryptionAndNoiseResult::new_from_lwe(
|
||||
&after_pbs128,
|
||||
&noise_squashing_private_key.post_noise_squashing_lwe_secret_key(),
|
||||
msg.into(),
|
||||
&u128_encoding,
|
||||
),
|
||||
)
|
||||
}
|
||||
AtomicPatternClientKey::KeySwitch32(ks32_atomic_pattern_client_key) => {
|
||||
let msg_u32: u32 = msg.try_into().unwrap();
|
||||
let params = ks32_atomic_pattern_client_key.parameters;
|
||||
let u32_encoding = ShortintEncoding {
|
||||
ciphertext_modulus: params.post_keyswitch_ciphertext_modulus(),
|
||||
message_modulus: params.message_modulus(),
|
||||
carry_modulus: params.carry_modulus(),
|
||||
padding_bit: PaddingBit::Yes,
|
||||
};
|
||||
let u64_encoding = ShortintEncoding {
|
||||
ciphertext_modulus: params.ciphertext_modulus(),
|
||||
message_modulus: params.message_modulus(),
|
||||
carry_modulus: params.carry_modulus(),
|
||||
padding_bit: PaddingBit::Yes,
|
||||
};
|
||||
let large_lwe_secret_key =
|
||||
ks32_atomic_pattern_client_key.large_lwe_secret_key();
|
||||
let small_lwe_secret_key =
|
||||
ks32_atomic_pattern_client_key.small_lwe_secret_key();
|
||||
(
|
||||
DecryptionAndNoiseResult::new_from_lwe(
|
||||
&input.as_lwe_64(),
|
||||
&large_lwe_secret_key,
|
||||
msg,
|
||||
&u64_encoding,
|
||||
),
|
||||
DecryptionAndNoiseResult::new_from_lwe(
|
||||
&after_dp.as_lwe_64(),
|
||||
&large_lwe_secret_key,
|
||||
msg,
|
||||
&u64_encoding,
|
||||
),
|
||||
DecryptionAndNoiseResult::new_from_lwe(
|
||||
&after_ks.as_lwe_32(),
|
||||
&small_lwe_secret_key,
|
||||
msg_u32,
|
||||
&u32_encoding,
|
||||
),
|
||||
DecryptionAndNoiseResult::new_from_lwe(
|
||||
&before_ms.as_lwe_32(),
|
||||
&small_lwe_secret_key,
|
||||
msg_u32,
|
||||
&u32_encoding,
|
||||
),
|
||||
DecryptionAndNoiseResult::new_from_lwe(
|
||||
&after_ms.as_lwe_32(),
|
||||
&small_lwe_secret_key,
|
||||
msg_u32,
|
||||
&u32_encoding,
|
||||
),
|
||||
DecryptionAndNoiseResult::new_from_lwe(
|
||||
&after_pbs128,
|
||||
&noise_squashing_private_key.post_noise_squashing_lwe_secret_key(),
|
||||
msg.into(),
|
||||
&u128_encoding,
|
||||
),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
let large_lwe_secret_key_dyn = cks.large_lwe_secret_key_as_dyn();
|
||||
let small_lwe_secret_key_dyn = cks.small_lwe_secret_key_as_dyn();
|
||||
|
||||
(
|
||||
DecryptionAndNoiseResult::new_from_dyn_lwe(
|
||||
&input,
|
||||
&large_lwe_secret_key_dyn,
|
||||
msg,
|
||||
),
|
||||
DecryptionAndNoiseResult::new_from_dyn_lwe(
|
||||
&after_dp,
|
||||
&large_lwe_secret_key_dyn,
|
||||
msg,
|
||||
),
|
||||
DecryptionAndNoiseResult::new_from_dyn_lwe(
|
||||
&after_ks,
|
||||
&small_lwe_secret_key_dyn,
|
||||
msg,
|
||||
),
|
||||
DecryptionAndNoiseResult::new_from_dyn_lwe(
|
||||
before_ms,
|
||||
&small_lwe_secret_key_dyn,
|
||||
msg,
|
||||
),
|
||||
DecryptionAndNoiseResult::new_from_dyn_modswitched_lwe(
|
||||
&after_ms,
|
||||
&small_lwe_secret_key_dyn,
|
||||
msg,
|
||||
),
|
||||
// This one is a "raw" LWE since today we don't need to manage several unsigned
|
||||
// integer types after a PBS128
|
||||
DecryptionAndNoiseResult::new_from_lwe(
|
||||
&after_pbs128,
|
||||
&noise_squashing_private_key.post_noise_squashing_lwe_secret_key(),
|
||||
msg.into(),
|
||||
&u128_encoding,
|
||||
),
|
||||
)
|
||||
},
|
||||
)
|
||||
.collect();
|
||||
@@ -662,8 +599,12 @@ fn noise_check_encrypt_dp_ks_standard_pbs128_packing_ks_noise(meta_params: MetaP
|
||||
let (params, noise_squashing_params, noise_squashing_compression_params) = {
|
||||
let meta_noise_squashing_params = meta_params.noise_squashing_parameters.unwrap();
|
||||
(
|
||||
meta_params.compute_parameters,
|
||||
meta_noise_squashing_params.parameters,
|
||||
meta_params
|
||||
.compute_parameters
|
||||
.with_deterministic_execution(),
|
||||
meta_noise_squashing_params
|
||||
.parameters
|
||||
.with_deterministic_execution(),
|
||||
meta_noise_squashing_params.compression_parameters.unwrap(),
|
||||
)
|
||||
};
|
||||
@@ -684,7 +625,7 @@ fn noise_check_encrypt_dp_ks_standard_pbs128_packing_ks_noise(meta_params: MetaP
|
||||
let noise_simulation_modulus_switch_config =
|
||||
NoiseSimulationModulusSwitchConfig::new_from_atomic_pattern_parameters(params);
|
||||
let noise_simulation_bsk128 =
|
||||
NoiseSimulationLweFourier128Bsk::new_from_parameters(params, noise_squashing_params);
|
||||
NoiseSimulationGenericBootstrapKey128::new_from_parameters(params, noise_squashing_params);
|
||||
let noise_simulation_packing_key =
|
||||
NoiseSimulationLwePackingKeyswitchKey::new_from_noise_squashing_parameters(
|
||||
noise_squashing_params,
|
||||
@@ -806,4 +747,5 @@ fn noise_check_encrypt_dp_ks_standard_pbs128_packing_ks_noise(meta_params: MetaP
|
||||
create_parameterized_test!(noise_check_encrypt_dp_ks_standard_pbs128_packing_ks_noise {
|
||||
TEST_META_PARAM_CPU_2_2_KS_PBS_PKE_TO_SMALL_ZKV2_TUNIFORM_2M128,
|
||||
TEST_META_PARAM_CPU_2_2_KS32_PBS_PKE_TO_SMALL_ZKV2_TUNIFORM_2M128,
|
||||
TEST_META_PARAM_GPU_2_2_MULTI_BIT_GROUP_4_KS_PBS_PKE_TO_SMALL_ZKV2_TUNIFORM_2M128,
|
||||
});
|
||||
|
||||
@@ -1,10 +1,13 @@
|
||||
pub mod noise_simulation;
|
||||
pub use noise_simulation::traits;
|
||||
pub mod traits;
|
||||
|
||||
use crate::core_crypto::algorithms::glwe_encryption::decrypt_glwe_ciphertext;
|
||||
use crate::core_crypto::algorithms::lwe_encryption::{
|
||||
allocate_and_encrypt_new_lwe_ciphertext, decrypt_lwe_ciphertext,
|
||||
};
|
||||
use crate::core_crypto::algorithms::lwe_multi_bit_programmable_bootstrapping::{
|
||||
MultiBitModulusSwitchedLweCiphertext, StandardMultiBitModulusSwitchedCt,
|
||||
};
|
||||
use crate::core_crypto::algorithms::misc::torus_modular_diff;
|
||||
use crate::core_crypto::algorithms::test::round_decode;
|
||||
use crate::core_crypto::commons::dispersion::{DispersionParameter, Variance};
|
||||
@@ -14,7 +17,7 @@ use crate::core_crypto::commons::noise_formulas::secure_noise::{
|
||||
minimal_lwe_variance_for_132_bits_security_gaussian,
|
||||
minimal_lwe_variance_for_132_bits_security_tuniform,
|
||||
};
|
||||
use crate::core_crypto::commons::numeric::{CastFrom, UnsignedInteger};
|
||||
use crate::core_crypto::commons::numeric::{CastFrom, CastInto, UnsignedInteger};
|
||||
use crate::core_crypto::commons::parameters::{
|
||||
CiphertextModulus, DynamicDistribution, LweCiphertextCount, LweDimension, PlaintextCount,
|
||||
};
|
||||
@@ -29,11 +32,14 @@ use crate::core_crypto::entities::glwe_ciphertext::GlweCiphertext;
|
||||
use crate::core_crypto::entities::glwe_secret_key::GlweSecretKey;
|
||||
use crate::core_crypto::entities::lwe_ciphertext::{LweCiphertext, LweCiphertextOwned};
|
||||
use crate::core_crypto::entities::lwe_secret_key::LweSecretKey;
|
||||
use crate::core_crypto::entities::{Cleartext, PlaintextList};
|
||||
use crate::core_crypto::entities::{Cleartext, Plaintext, PlaintextList};
|
||||
use crate::shortint::encoding::ShortintEncoding;
|
||||
use crate::shortint::parameters::{
|
||||
AtomicPatternParameters, CarryModulus, MessageModulus, PBSParameters,
|
||||
};
|
||||
use crate::shortint::server_key::tests::noise_distribution::utils::noise_simulation::{
|
||||
DynLwe, DynLweSecretKeyView, DynModSwitchedLwe, DynStandardMultiBitModulusSwitchedCt,
|
||||
};
|
||||
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
|
||||
pub struct PrecisionWithPadding {
|
||||
@@ -399,6 +405,40 @@ pub enum DecryptionAndNoiseResult {
|
||||
}
|
||||
|
||||
impl DecryptionAndNoiseResult {
|
||||
/// Builds a [`DecryptionAndNoiseResult`] from an already decrypted plaintext.
///
/// The plaintext is decoded with `round_decode` using `encoding.delta()`, then reduced modulo
/// `encoding.full_cleartext_space()`. The noise is the `torus_modular_diff` between the
/// decrypted plaintext and `expected_msg * delta`, taken under `encoding.ciphertext_modulus`.
///
/// Returns `DecryptionSucceeded` carrying that noise sample when the decoded message equals
/// `expected_msg`, and `DecryptionFailed` otherwise.
pub fn new_from_plaintext<Scalar: UnsignedInteger + CastFrom<u64>>(
|
||||
decrypted_plaintext: Plaintext<Scalar>,
|
||||
expected_msg: Scalar,
|
||||
encoding: &ShortintEncoding<Scalar>,
|
||||
) -> Self {
|
||||
let delta = encoding.delta();
|
||||
let cleartext_modulus_with_padding = encoding.full_cleartext_space();
|
||||
|
||||
// We apply the modulus on the cleartext + the padding bit
|
||||
let decoded_msg =
|
||||
round_decode(decrypted_plaintext.0, delta) % cleartext_modulus_with_padding;
|
||||
|
||||
let expected_plaintext = expected_msg * delta;
|
||||
|
||||
// decrypted_plaintext = expected_plaintext + error
|
||||
// The order below computes:
|
||||
// decrypted_plaintext - expected_plaintext in a modular way, which is what we want
|
||||
// It only changes the average value sign, so that it is more intuitive when comparing to
|
||||
// theory
|
||||
let noise = torus_modular_diff(
|
||||
decrypted_plaintext.0,
|
||||
expected_plaintext,
|
||||
encoding.ciphertext_modulus,
|
||||
);
|
||||
|
||||
// The noise sample is only reported when the decoded message is the expected one
|
||||
if decoded_msg == expected_msg {
|
||||
Self::DecryptionSucceeded {
|
||||
noise: NoiseSample { value: noise },
|
||||
}
|
||||
} else {
|
||||
Self::DecryptionFailed
|
||||
}
|
||||
}
|
||||
|
||||
pub fn new_from_lwe<Scalar: UnsignedInteger + CastFrom<u64>, CtCont, KeyCont>(
|
||||
ct: &LweCiphertext<CtCont>,
|
||||
secret_key: &LweSecretKey<KeyCont>,
|
||||
@@ -409,33 +449,85 @@ impl DecryptionAndNoiseResult {
|
||||
CtCont: Container<Element = Scalar>,
|
||||
KeyCont: Container<Element = Scalar>,
|
||||
{
|
||||
let decrypted_plaintext = decrypt_lwe_ciphertext(secret_key, ct).0;
|
||||
let decrypted_plaintext = decrypt_lwe_ciphertext(secret_key, ct);
|
||||
|
||||
let delta = encoding.delta();
|
||||
let cleartext_modulus_with_padding = encoding.full_cleartext_space();
|
||||
Self::new_from_plaintext(decrypted_plaintext, expected_msg, encoding)
|
||||
}
|
||||
|
||||
// We apply the modulus on the cleartext + the padding bit
|
||||
let decoded_msg = round_decode(decrypted_plaintext, delta) % cleartext_modulus_with_padding;
|
||||
|
||||
let expected_plaintext = expected_msg * delta;
|
||||
|
||||
// decrypted_plaintext = expected_plaintext + error
|
||||
// The order below computes:
|
||||
// decrypted_plaintext - expected_plaintext in a modular way, which is what we want
|
||||
// It only changes the average value sign, so that it is more intuitive when comparing to
|
||||
// theory
|
||||
let noise = torus_modular_diff(
|
||||
decrypted_plaintext,
|
||||
expected_plaintext,
|
||||
ct.ciphertext_modulus(),
|
||||
);
|
||||
|
||||
if decoded_msg == expected_msg {
|
||||
Self::DecryptionSucceeded {
|
||||
noise: NoiseSample { value: noise },
|
||||
pub fn new_from_dyn_lwe(
|
||||
ct: &DynLwe,
|
||||
secret_key: &DynLweSecretKeyView<'_>,
|
||||
expected_msg: u64,
|
||||
) -> Self {
|
||||
match (ct, secret_key) {
|
||||
(DynLwe::U32(lwe_ciphertext), DynLweSecretKeyView::U32 { key, encoding }) => {
|
||||
Self::new_from_lwe(
|
||||
lwe_ciphertext,
|
||||
key,
|
||||
expected_msg.try_into().unwrap(),
|
||||
encoding,
|
||||
)
|
||||
}
|
||||
} else {
|
||||
Self::DecryptionFailed
|
||||
(DynLwe::U64(lwe_ciphertext), DynLweSecretKeyView::U64 { key, encoding }) => {
|
||||
Self::new_from_lwe(lwe_ciphertext, key, expected_msg, encoding)
|
||||
}
|
||||
_ => panic!("Incompatible types in DecryptionAndNoiseResult::new_from_dyn_lwe"),
|
||||
}
|
||||
}
|
||||
|
||||
/// Builds a [`DecryptionAndNoiseResult`] from a dynamically typed multi-bit mod switched
/// ciphertext.
///
/// When `ct` and `secret_key` are both `U32` or both `U64`, the ciphertext is decrypted with
/// `decrypt_multi_bit_mod_switched_lwe_ciphertext` and the resulting plaintext is forwarded to
/// `new_from_plaintext` together with the encoding carried by the secret key view.
///
/// # Panics
///
/// Panics when the `ct` and `secret_key` variants do not match, or, in the `U32` case, when
/// `expected_msg.try_into()` fails.
pub fn new_from_dyn_multi_bit_mod_switched_lwe(
|
||||
ct: &DynStandardMultiBitModulusSwitchedCt,
|
||||
secret_key: &DynLweSecretKeyView<'_>,
|
||||
expected_msg: u64,
|
||||
) -> Self {
|
||||
match (ct, secret_key) {
|
||||
(
|
||||
DynStandardMultiBitModulusSwitchedCt::U32(standard_multi_bit_modulus_switched_ct),
|
||||
DynLweSecretKeyView::U32 { key, encoding },
|
||||
) => {
|
||||
let decrypted_plaintext = decrypt_multi_bit_mod_switched_lwe_ciphertext(
|
||||
key,
|
||||
standard_multi_bit_modulus_switched_ct,
|
||||
);
|
||||
Self::new_from_plaintext(
|
||||
decrypted_plaintext,
|
||||
expected_msg.try_into().unwrap(),
|
||||
encoding,
|
||||
)
|
||||
}
|
||||
(
|
||||
DynStandardMultiBitModulusSwitchedCt::U64(standard_multi_bit_modulus_switched_ct),
|
||||
DynLweSecretKeyView::U64 { key, encoding },
|
||||
) => {
|
||||
let decrypted_plaintext = decrypt_multi_bit_mod_switched_lwe_ciphertext(
|
||||
key,
|
||||
standard_multi_bit_modulus_switched_ct,
|
||||
);
|
||||
Self::new_from_plaintext(decrypted_plaintext, expected_msg, encoding)
|
||||
}
|
||||
_ => panic!(
|
||||
"Incompatible types in \
|
||||
DecryptionAndNoiseResult::new_from_dyn_multi_bit_mod_switched_lwe"
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
/// Builds a [`DecryptionAndNoiseResult`] from a dynamically typed mod switched ciphertext.
///
/// Dispatches on the `DynModSwitchedLwe` variant: `ModSwitchedLwe` is forwarded to
/// `new_from_dyn_lwe`, `MultiBitModSwitchedLwe` is forwarded to
/// `new_from_dyn_multi_bit_mod_switched_lwe`. `secret_key` and `expected_msg` are passed
/// through unchanged.
pub fn new_from_dyn_modswitched_lwe(
|
||||
ct: &DynModSwitchedLwe,
|
||||
secret_key: &DynLweSecretKeyView<'_>,
|
||||
expected_msg: u64,
|
||||
) -> Self {
|
||||
match ct {
|
||||
DynModSwitchedLwe::ModSwitchedLwe(dyn_lwe) => {
|
||||
Self::new_from_dyn_lwe(dyn_lwe, secret_key, expected_msg)
|
||||
}
|
||||
DynModSwitchedLwe::MultiBitModSwitchedLwe(
|
||||
dyn_standard_multi_bit_modulus_switched_ct,
|
||||
) => Self::new_from_dyn_multi_bit_mod_switched_lwe(
|
||||
dyn_standard_multi_bit_modulus_switched_ct,
|
||||
secret_key,
|
||||
expected_msg,
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -599,6 +691,64 @@ pub fn expected_pfail_for_precision(
|
||||
statrs::function::erf::erfc(measured_std_score / core::f64::consts::SQRT_2)
|
||||
}
|
||||
|
||||
/// Decrypts a standard multi-bit modulus switched LWE ciphertext and returns the plaintext
/// placed in the high bits of `Scalar`.
///
/// The switched body is shifted left by `Scalar::BITS - log_modulus`. The secret key is then
/// walked in chunks of `grouping_factor` elements; the elements of a chunk are packed into a
/// selector, the first element ending up as the most significant bit. For a non-zero selector,
/// the value at index `selector - 1` of that group's switched mask is shifted the same way and
/// subtracted from the result with wrapping arithmetic.
///
/// NOTE(review): key elements are OR-ed in as single bits, so this presumably assumes a binary
/// secret key; confirm.
/// NOTE(review): `chunks_exact` leaves out trailing key elements when the key length is not a
/// multiple of the grouping factor; presumably it always is, confirm against callers.
///
/// # Panics
///
/// Panics if a group's switched mask iterator yields fewer than `selector + 1` values
/// (`nth(..).unwrap()`).
pub fn decrypt_multi_bit_mod_switched_lwe_ciphertext<Scalar, CtCont, KeyCont>(
|
||||
lwe_secret_key: &LweSecretKey<KeyCont>,
|
||||
mod_switched_lwe: &StandardMultiBitModulusSwitchedCt<Scalar, CtCont>,
|
||||
) -> Plaintext<Scalar>
|
||||
where
|
||||
Scalar: UnsignedInteger + CastFrom<usize> + CastInto<usize>,
|
||||
CtCont: Container<Element = Scalar> + Sync,
|
||||
KeyCont: Container<Element = Scalar>,
|
||||
{
|
||||
let mut result: Scalar = mod_switched_lwe
|
||||
.switched_modulus_input_lwe_body()
|
||||
.cast_into();
|
||||
|
||||
let log_modulus = mod_switched_lwe.log_modulus;
|
||||
let grouping_factor = mod_switched_lwe.grouping_factor();
|
||||
|
||||
let shift_to_native = Scalar::BITS - log_modulus.0;
|
||||
|
||||
result <<= shift_to_native;
|
||||
|
||||
for (loop_idx, lwe_key_bits) in lwe_secret_key
|
||||
.as_ref()
|
||||
.chunks_exact(grouping_factor.0)
|
||||
.enumerate()
|
||||
{
|
||||
let selector = {
|
||||
let mut selector = 0usize;
|
||||
for bit in lwe_key_bits.iter() {
|
||||
let bit: usize = (*bit).cast_into();
|
||||
selector <<= 1;
|
||||
selector |= bit;
|
||||
}
|
||||
if selector == 0 {
|
||||
// We don't generate a mod switched value for selector == 0: it corresponds to key
|
||||
// bits == 0
|
||||
None
|
||||
} else {
|
||||
// We subtract 1 to be coherent with the fact the first mod switched value is not
|
||||
// generated
|
||||
Some(selector - 1)
|
||||
}
|
||||
};
|
||||
|
||||
if let Some(selector) = selector {
|
||||
let mod_switched: Scalar = mod_switched_lwe
|
||||
.switched_modulus_input_mask_per_group(loop_idx)
|
||||
.nth(selector)
|
||||
.unwrap()
|
||||
.cast_into();
|
||||
// Put in the high bits same as the body to be able to measure the noise in the
|
||||
// encompassing modulus
|
||||
let mod_switched = mod_switched << shift_to_native;
|
||||
result = result.wrapping_sub(mod_switched);
|
||||
}
|
||||
}
|
||||
Plaintext(result)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_expected_pfail_for_ci_run_filter() {
|
||||
// Practical check on a compression-like scenario, of interest because pfail is known to be very
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,28 @@
|
||||
pub use super::noise_simulation::traits::*;
|
||||
|
||||
/// Abstracts several bootstrapping implementations in the same way shortint's ServerKey does
///
/// The trait is generic over the `Input`, `Output` and `Accumulator` types; implementors pick
/// the [`Self::SideResources`] type handed to the bootstrap call.
pub trait LweGenericBootstrap<Input, Output, Accumulator> {
|
||||
/// Type of the extra resources passed mutably to [`Self::lwe_generic_bootstrap`]
type SideResources;
|
||||
|
||||
/// Bootstrap entry point: `input` and `accumulator` are taken by shared reference, `output`
/// and `side_resources` by mutable reference; nothing is returned.
fn lwe_generic_bootstrap(
|
||||
&self,
|
||||
input: &Input,
|
||||
output: &mut Output,
|
||||
accumulator: &Accumulator,
|
||||
side_resources: &mut Self::SideResources,
|
||||
);
|
||||
}
|
||||
|
||||
/// Abstracts several blind rotate implementations in the same way shortint's ServerKey does, this
|
||||
/// one is specific to the PBS 128 blind rotation
///
/// The trait is generic over the `Input`, `Output` and `Accumulator` types; implementors pick
/// the [`Self::SideResources`] type handed to the blind rotate call.
pub trait LweGenericBlindRotate128<Input, Output, Accumulator> {
|
||||
/// Type of the extra resources passed mutably to [`Self::lwe_generic_blind_rotate_128`]
type SideResources;
|
||||
|
||||
/// Blind rotate entry point: `input` and `accumulator` are taken by shared reference,
/// `output` and `side_resources` by mutable reference; nothing is returned.
fn lwe_generic_blind_rotate_128(
|
||||
&self,
|
||||
input: &Input,
|
||||
output: &mut Output,
|
||||
accumulator: &Accumulator,
|
||||
side_resources: &mut Self::SideResources,
|
||||
);
|
||||
}
|
||||
Reference in New Issue
Block a user