Compare commits

..

34 Commits

Author SHA1 Message Date
Andrei Stoian
b4ea48165b fix(gpu): disable cache 2026-01-23 15:50:37 +01:00
Andrei Stoian
0a6b62627d fix(gpu): remove broadcast 2026-01-23 13:00:44 +01:00
Andrei Stoian
6deeb66bf8 fix(gpu): test remove sync on lut create 2026-01-23 11:27:47 +01:00
Andrei Stoian
17022dae69 feat(gpu): lut cache univariate 2026-01-22 16:56:14 +01:00
Andrei Stoian
09802dd5ee feat(gpu): lut cache 2026-01-22 10:04:39 +01:00
Andrei Stoian
e3fe433a35 fix(gpu): univariate fix 2026-01-21 17:24:16 +01:00
Andrei Stoian
2bea35a3b5 fix(gpu): finish bivariate 2026-01-21 16:13:21 +01:00
Andrei Stoian
e2bf226276 fix(gpu): start bivariate, fix all univariate 2026-01-21 15:24:51 +01:00
Andrei Stoian
c66f1c6d8b fix(gpu): all univariate luts 2026-01-21 12:06:21 +01:00
Andrei Stoian
9bfe190ad3 fix(gpu): sc prop fix 2026-01-21 11:48:21 +01:00
Andrei Stoian
e40070db0e fix(gpu): sc prop encapsulate lut 2026-01-21 10:35:21 +01:00
Andrei Stoian
e8d5ceac68 fix(gpu): more lut encaps 2026-01-20 15:43:24 +01:00
Andrei Stoian
f1526b29d8 fix(gpu): more lut 2026-01-19 17:55:45 +01:00
Andrei Stoian
602e0c5a19 fix(gpu): more lut encaps 2026-01-19 15:13:07 +01:00
Andrei Stoian
163c1eeffb chore(gpu): refactor lut generation 2026-01-16 16:02:17 +01:00
Nicolas Sarlin
3bcb9c8360 chore(test-vectors): update README 2026-01-15 17:43:41 +01:00
Arthur Meyre
20a64abaf1 chore: update README.md indicating the wrong MSRV 2026-01-15 13:23:34 +01:00
Arthur Meyre
4b31987a45 chore(ci): warn if a milestone is not set on a Pull Request 2026-01-15 13:23:34 +01:00
Enzo Di Maria
b27fbc5d78 feat(gpu): trivium 2026-01-15 11:26:12 +01:00
Arthur Meyre
3d797e4823 chore: update GPU code to still work with new test harnesses
- multi bit implementations are placeholders to be updated
2026-01-15 10:02:46 +01:00
Arthur Meyre
51ef40ace3 test: add multi bit support to dp_ks_pbs128_packingks 2026-01-15 10:02:46 +01:00
Arthur Meyre
c560462a4a chore: update PBS noise formulas 2026-01-15 10:02:46 +01:00
Arthur Meyre
67ed05a008 chore: add pbs128 multi bit formulas and noise simulation primitives 2026-01-15 10:02:46 +01:00
Arthur Meyre
236eea5bd7 test: add multi bit support to br_rerand_dp_ks_ms 2026-01-15 10:02:46 +01:00
Arthur Meyre
a1d3262726 test: add multi bit support to br_dp_packingks_ms 2026-01-15 10:02:46 +01:00
Arthur Meyre
afbeebc1b4 test: add multi bit support to cpk_ks_ms, add test params 2026-01-15 10:02:46 +01:00
Arthur Meyre
09cd5c1727 test: add multi bit case to dp_ks_ms 2026-01-15 10:02:46 +01:00
Arthur Meyre
521f1516bb test: add multi-bit parameters to br_dp_ks_ms noise checks
- support added for generic bootstrap to keep existing code
2026-01-15 10:02:46 +01:00
Arthur Meyre
3c171136ad chore: add multi bit noise primitives in core
- add a fully fledged MultiBit PBS trait required for BR -> ... APs
2026-01-15 10:02:46 +01:00
Arthur Meyre
6f360968df test: add multi bit modswitch in any_ms
- update implems to manage the right dynamic types to keep atomic patterns
coherent
2026-01-15 10:02:46 +01:00
Arthur Meyre
37a0c58cb9 test: update noise check tests to manage several mod switch types
- current primitives have a placeholder for the multi bit case
- generic PBS to handle classic and multi bit case to come in next PR
2026-01-15 10:02:46 +01:00
Arthur Meyre
99590e3b0f chore: prepare primitives for multi bit PBS
- implement traits on core primitives
2026-01-15 10:02:46 +01:00
Nicolas Sarlin
6300a025d9 chore(docs): fix api levels description 2026-01-13 09:43:49 +01:00
David Testé
7222bff5d6 chore(ci): fix artifact naming for hpu benchmarks
Prior to this commit, all generated artifacts would be identified
as integer benchmarks.
2026-01-12 15:42:24 +01:00
71 changed files with 4796 additions and 1782 deletions

View File

@@ -187,7 +187,7 @@ jobs:
- name: Upload parsed results artifact
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
with:
name: ${{ github.sha }}_${{ matrix.bench_type }}_integer_benchmarks
name: ${{ github.sha }}_${{ matrix.bench_type }}_${{ matrix.command }}_benchmarks
path: ${{ env.RESULTS_FILENAME }}
- name: Checkout Slab repo

View File

@@ -0,0 +1,67 @@
name: pr_milestone_check
on:
pull_request:
types: [opened, edited, synchronize, reopened, milestoned, demilestoned]
permissions: {}
# zizmor: ignore[concurrency-limits] only Zama organization members can trigger this workflow
# external contributors workflows are manually approved
jobs:
check-empty-milestone:
name: pr_milestone_check/check-empty-milestone
runs-on: ubuntu-latest
if: github.event.pull_request.milestone == null
permissions:
pull-requests: write # Need write access on pull requests to post comment
steps:
- name: Post Reminder Comment
uses: octokit/request-action@dad4362715b7fb2ddedf9772c8670824af564f0d # v2.4.0
with:
route: POST /repos/${{ github.repository }}/issues/${{ github.event.pull_request.number }}/comments
body: |
'### ❌ Milestone Missing
Please assign a milestone to this pull request. If your PR targets the next version of
TFHE-rs please use the current quarter milestone, e.g. "Q1 26".
If your PR targets a patch version for previous releases: consider creating a dedicated
milestone e.g. v1.5.1 if it does not exist yet.'
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
- name: Check Final Status
run: |
echo "::error::Milestone is missing. This check is failing."
exit 1
check-milestone-open:
name: pr_milestone_check/check-milestone-open
runs-on: ubuntu-latest
if: github.event.pull_request.milestone != null && github.event.pull_request.milestone.state == 'closed'
permissions:
pull-requests: write # Need write access on pull requests to post comment
steps:
- name: Post Reminder Comment
uses: octokit/request-action@dad4362715b7fb2ddedf9772c8670824af564f0d # v2.4.0
with:
route: POST /repos/${{ github.repository }}/issues/${{ github.event.pull_request.number }}/comments
body: |
'### ❌ Milestone is closed
Please assign an open milestone to this pull request. If your PR targets the next version of
TFHE-rs please use the current quarter milestone, e.g. "Q1 26".
If your PR targets a patch version for previous releases: consider creating a dedicated
milestone e.g. v1.5.1 if it does not exist yet.'
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
- name: Check Final Status
run: |
echo "::error::Milestone is closed. This check is failing."
exit 1

View File

@@ -1454,6 +1454,13 @@ bench_integer_aes256_gpu: install_rs_check_toolchain
--bench integer-aes256 \
--features=integer,internal-keycache,gpu, -p tfhe-benchmark --profile release_lto_off --
.PHONY: bench_integer_trivium_gpu # Run benchmarks for trivium on GPU backend
bench_integer_trivium_gpu: install_rs_check_toolchain
RUSTFLAGS="$(RUSTFLAGS)" __TFHE_RS_BENCH_TYPE=$(BENCH_TYPE) \
cargo $(CARGO_RS_CHECK_TOOLCHAIN) bench \
--bench integer-trivium \
--features=integer,internal-keycache,gpu, -p tfhe-benchmark --profile release_lto_off --
.PHONY: bench_integer_multi_bit # Run benchmarks for unsigned integer using multi-bit parameters
bench_integer_multi_bit: install_rs_check_toolchain
RUSTFLAGS="$(RUSTFLAGS)" __TFHE_RS_PARAM_TYPE=MULTI_BIT __TFHE_RS_BENCH_TYPE=$(BENCH_TYPE) \

View File

@@ -79,7 +79,7 @@ tfhe = { version = "*", features = ["boolean", "shortint", "integer"] }
```
> [!Note]
> Note: You need Rust version 1.84 or newer to compile TFHE-rs. You can check your version with `rustc --version`.
> Note: You need Rust version 1.91.1 or newer to compile TFHE-rs. You can check your version with `rustc --version`.
> [!Note]
> Note: AArch64-based machines are not supported for Windows as it's currently missing an entropy source to be able to seed the [CSPRNGs](https://en.wikipedia.org/wiki/Cryptographically_secure_pseudorandom_number_generator) used in TFHE-rs.

View File

@@ -1,43 +1,43 @@
# Test vectors for TFHE
These test vectors are generated using [TFHE-rs](https://github.com/zama-ai/tfhe-rs), with the git tag `tfhe-test-vectors-0.2.0`.
They are TFHE-rs objects serialized in the [cbor format](https://cbor.io/). You can deserialize them using any cbor library for the language of your choice. For example, using the [cbor2](https://pypi.org/project/cbor2/) program, run: `cbor2 --pretty toy_params/lwe_a.cbor`.
They are TFHE-rs objects serialized in the [cbor format](https://cbor.io/). These can be deserialized using any cbor library for any programming languages. For example, using the [cbor2](https://pypi.org/project/cbor2/) program, the command to run is: `cbor2 --pretty toy_params/lwe_a.cbor`.
You will find 2 folders with test vectors for different parameter sets:
- `valid_params_128`: valid classical PBS parameters using a gaussian noise distribution, providing 128bits of security in the IND-CPA model and a bootstrapping probability of failure of 2^{-64}.
- `toy_params`: insecure parameters that yield smaller values
There are 2 folders with test vectors for different parameter sets:
- `valid_params_128`: valid classical PBS parameters using a Gaussian noise distribution, providing 128-bits of security in the IND-CPA model (i.e., the probability of failure is smaller than 2^{-64}).
- `toy_params`: insecure parameters that yield smaller values to simplify the bit comparison of the results.
The values are generated for the keyswitch -> bootstrap (KS-PBS) atomic pattern. The cleartext inputs are 2 values, A and B defined below.
The values are generated to compute a keyswitch (KS) followed by a bootstrap (PBS). The cleartext inputs are 2 values, A and B defined below.
All the random values are generated from a fixed seed, that can be found in the `RAND_SEED` constant below. The PRNG used is the one based on the AES block cipher in counter mode, from tfhe `tfhe-csprng` crate.
The programmable bootstrap is applied twice, with 2 different lut, the identity lut and a specific one (currently a x2 operation)
The bootstrap is applied twice, with 2 different lut, the identity lut and a specific one computing the double of the input value (i.e., f(x) = 2*x).
## Vectors
The following values are generated:
### Keys
| name | description | TFHE-rs type |
|------------------------|---------------------------------------------------------------------------------------|-----------------------------|
| `large_lwe_secret_key` | Encryption secret key, before the KS and after the PBS | `LweSecretKey<Vec<u64>>` |
| `small_lwe_secret_key` | Secret key encrypting ciphertexts between the KS and the PBS | `LweSecretKey<Vec<u64>>` |
| `ksk` | The keyswitching key to convert a ct from the large key to the small one | `LweKeyswitchKey<Vec<u64>>` |
| name | description | TFHE-rs type |
|------------------------|-----------------------------------------------------------------------------------------|-----------------------------|
| `large_lwe_secret_key` | Encryption secret key, before the KS and after the PBS | `LweSecretKey<Vec<u64>>` |
| `small_lwe_secret_key` | Secret key encrypting ciphertexts between the KS and the PBS | `LweSecretKey<Vec<u64>>` |
| `ksk` | The keyswitching key to convert a ct from the large key to the small one | `LweKeyswitchKey<Vec<u64>>` |
| `bsk` | the bootstrapping key to perform a programmable bootstrap on the keyswitched ciphertext | `LweBootstrapKey<Vec<u64>>` |
### Ciphertexts
| name | description | TFHE-rs type | Cleartext |
|----------------------|--------------------------------------------------------------------------------------------------------------|----------------------------|--------------|
| `lwe_a` | Lwe encryption of A | `LweCiphertext<Vec<u64>>` | `A` |
| `lwe_b` | Lwe encryption of B | `LweCiphertext<Vec<u64>>` | `B` |
| `lwe_sum` | Lwe encryption of A plus lwe encryption of B | `LweCiphertext<Vec<u64>>` | `A+B` |
| `lwe_prod` | Lwe encryption of A times cleartext B | `LweCiphertext<Vec<u64>>` | `A*B` |
| `lwe_ms` | The lwe ciphertext after the modswitch part of the PBS ([note](#non-native-encoding)) | `LweCiphertext<Vec<u64>>` | `A` |
| `lwe_ks` | The lwe ciphertext after the keyswitch | `LweCiphertext<Vec<u64>>` | `A` |
| `glwe_after_id_br` | The glwe returned by the application of the identity blind rotation on the mod switched ciphertexts. | `GlweCiphertext<Vec<u64>>` | rot id LUT |
| `lwe_after_id_pbs` | The lwe returned by the application of the sample extract operation on the output of the id blind rotation | `LweCiphertext<Vec<u64>>` | `A` |
| `glwe_after_spec_br` | The glwe returned by the application of the spec blind rotation on the mod switched ciphertexts. | `GlweCiphertext<Vec<u64>>` | rot spec LUT |
| `lwe_after_spec_pbs` | The lwe returned by the application of the sample extract operation on the output of the spec blind rotation | `LweCiphertext<Vec<u64>>` | `spec(A)` |
| name | description | TFHE-rs type | Cleartext |
|----------------------|-----------------------------------------------------------------------------------------------------|----------------------------|----------------------|
| `lwe_a` | LWE Ciphertext encrypting A | `LweCiphertext<Vec<u64>>` | `A` |
| `lwe_b` | LWE Ciphertext encrypting B | `LweCiphertext<Vec<u64>>` | `B` |
| `lwe_sum` | LWE Ciphertext encrypting A plus lwe encryption of B | `LweCiphertext<Vec<u64>>` | `A+B` |
| `lwe_prod` | LWE Ciphertext encrypting A times cleartext B | `LweCiphertext<Vec<u64>>` | `A*B` |
| `lwe_ms` | LWE Ciphertext encrypting A after a Modulus Switch from q to 2*N ([note](#non-native-encoding)) | `LweCiphertext<Vec<u64>>` | `A` |
| `lwe_ks` | LWE Ciphertext encrypting A after a keyswitch from `large_lwe_secret_key` to `small_lwe_secret_key` | `LweCiphertext<Vec<u64>>` | `A` |
| `glwe_after_id_br` | GLWE Ciphertext encrypting A after the application of the identity blind rotation on `lwe_ms` | `GlweCiphertext<Vec<u64>>` | rotation of id LUT |
| `lwe_after_id_pbs` | LWE Ciphertext encrypting A after the sample extract operation on `glwe_after_id_br` | `LweCiphertext<Vec<u64>>` | `A` |
| `glwe_after_spec_br` | GLWE Ciphertext encrypting spec(A) after the application of the spec blind rotation on `lwe_ms` | `GlweCiphertext<Vec<u64>>` | rotation of spec LUT |
| `lwe_after_spec_pbs` | LWE Ciphertext encrypting spec(A) after the sample extract operation on `glwe_after_spec_br` | `LweCiphertext<Vec<u64>>` | `spec(A)` |
Ciphertexts with the `_karatsuba` suffix are generated using the Karatsuba polynomial multiplication algorithm in the blind rotation, while default ciphertexts are generated using an FFT multiplication.
This makes it easier to reproduce bit exact results.

View File

@@ -86,6 +86,7 @@ fn main() {
"cuda/include/integer/integer.h",
"cuda/include/integer/rerand.h",
"cuda/include/aes/aes.h",
"cuda/include/trivium/trivium.h",
"cuda/include/zk/zk.h",
"cuda/include/keyswitch/keyswitch.h",
"cuda/include/keyswitch/ks_enums.h",

View File

@@ -29,15 +29,13 @@ template <typename Torus> struct int_aes_lut_buffers {
allocate_gpu_memory, size_tracker);
std::function<Torus(Torus, Torus)> and_lambda =
[](Torus a, Torus b) -> Torus { return a & b; };
generate_device_accumulator_bivariate<Torus>(
streams.stream(0), streams.gpu_index(0), this->and_lut->get_lut(0, 0),
this->and_lut->get_degree(0), this->and_lut->get_max_degree(0),
params.glwe_dimension, params.polynomial_size, params.message_modulus,
params.carry_modulus, and_lambda, allocate_gpu_memory);
auto active_streams_and_lut = streams.active_gpu_subset(
SBOX_MAX_AND_GATES * num_aes_inputs * sbox_parallelism,
params.pbs_type);
this->and_lut->broadcast_lut(active_streams_and_lut);
this->and_lut->generate_and_broadcast_bivariate_lut(
active_streams_and_lut, {0}, {and_lambda}, allocate_gpu_memory);
this->and_lut->setup_gemm_batch_ks_temp_buffers(size_tracker);
this->flush_lut = new int_radix_lut<Torus>(
@@ -46,14 +44,11 @@ template <typename Torus> struct int_aes_lut_buffers {
std::function<Torus(Torus)> flush_lambda = [](Torus x) -> Torus {
return x & 1;
};
generate_device_accumulator(
streams.stream(0), streams.gpu_index(0), this->flush_lut->get_lut(0, 0),
this->flush_lut->get_degree(0), this->flush_lut->get_max_degree(0),
params.glwe_dimension, params.polynomial_size, params.message_modulus,
params.carry_modulus, flush_lambda, allocate_gpu_memory);
auto active_streams_flush_lut = streams.active_gpu_subset(
AES_STATE_BITS * num_aes_inputs, params.pbs_type);
this->flush_lut->broadcast_lut(active_streams_flush_lut);
this->flush_lut->generate_and_broadcast_lut(
active_streams_flush_lut, {0}, {flush_lambda}, allocate_gpu_memory);
this->flush_lut->setup_gemm_batch_ks_temp_buffers(size_tracker);
this->carry_lut = new int_radix_lut<Torus>(
@@ -61,14 +56,11 @@ template <typename Torus> struct int_aes_lut_buffers {
std::function<Torus(Torus)> carry_lambda = [](Torus x) -> Torus {
return (x >> 1) & 1;
};
generate_device_accumulator(
streams.stream(0), streams.gpu_index(0), this->carry_lut->get_lut(0, 0),
this->carry_lut->get_degree(0), this->carry_lut->get_max_degree(0),
params.glwe_dimension, params.polynomial_size, params.message_modulus,
params.carry_modulus, carry_lambda, allocate_gpu_memory);
auto active_streams_carry_lut =
streams.active_gpu_subset(num_aes_inputs, params.pbs_type);
this->carry_lut->broadcast_lut(active_streams_carry_lut);
this->carry_lut->generate_and_broadcast_lut(
active_streams_carry_lut, {0}, {carry_lambda}, allocate_gpu_memory);
this->carry_lut->setup_gemm_batch_ks_temp_buffers(size_tracker);
}

View File

@@ -65,14 +65,8 @@ template <typename Torus> struct boolean_bitop_buffer {
return x % params.message_modulus;
};
generate_device_accumulator<Torus>(
streams.stream(0), streams.gpu_index(0),
message_extract_lut->get_lut(0, 0),
message_extract_lut->get_degree(0),
message_extract_lut->get_max_degree(0), params.glwe_dimension,
params.polynomial_size, params.message_modulus, params.carry_modulus,
lut_f_message_extract, gpu_memory_allocated);
message_extract_lut->broadcast_lut(active_streams);
message_extract_lut->generate_and_broadcast_lut(
active_streams, {0}, {lut_f_message_extract}, gpu_memory_allocated);
}
tmp_lwe_left = new CudaRadixCiphertextFFI;
create_zero_radix_ciphertext_async<Torus>(
@@ -142,12 +136,8 @@ template <typename Torus> struct int_bitop_buffer {
}
};
generate_device_accumulator_bivariate<Torus>(
streams.stream(0), streams.gpu_index(0), lut->get_lut(0, 0),
lut->get_degree(0), lut->get_max_degree(0), params.glwe_dimension,
params.polynomial_size, params.message_modulus,
params.carry_modulus, lut_bivariate_f, gpu_memory_allocated);
lut->broadcast_lut(active_streams);
lut->generate_and_broadcast_bivariate_lut(
active_streams, {0}, {lut_bivariate_f}, gpu_memory_allocated);
}
break;
default:
@@ -156,6 +146,8 @@ template <typename Torus> struct int_bitop_buffer {
num_radix_blocks, allocate_gpu_memory,
size_tracker);
std::vector<std::function<Torus(Torus)>> lut_funcs;
std::vector<uint32_t> lut_indices;
for (int i = 0; i < params.message_modulus; i++) {
auto rhs = i;
@@ -171,14 +163,13 @@ template <typename Torus> struct int_bitop_buffer {
return x ^ rhs;
}
};
generate_device_accumulator<Torus>(
streams.stream(0), streams.gpu_index(0), lut->get_lut(0, i),
lut->get_degree(i), lut->get_max_degree(i), params.glwe_dimension,
params.polynomial_size, params.message_modulus,
params.carry_modulus, lut_univariate_scalar_f,
gpu_memory_allocated);
lut->broadcast_lut(active_streams);
lut_funcs.push_back(lut_univariate_scalar_f);
lut_indices.push_back(i);
}
lut->generate_and_broadcast_lut(active_streams, lut_indices, lut_funcs,
gpu_memory_allocated);
}
}
@@ -211,16 +202,11 @@ template <typename Torus> struct boolean_bitnot_buffer {
return x % message_modulus;
};
generate_device_accumulator<Torus>(
streams.stream(0), streams.gpu_index(0),
message_extract_lut->get_lut(0, 0),
message_extract_lut->get_degree(0),
message_extract_lut->get_max_degree(0), params.glwe_dimension,
params.polynomial_size, params.message_modulus, params.carry_modulus,
lut_f_message_extract, gpu_memory_allocated);
auto active_streams =
streams.active_gpu_subset(lwe_ciphertext_count, params.pbs_type);
message_extract_lut->broadcast_lut(active_streams);
message_extract_lut->generate_and_broadcast_lut(
active_streams, {0}, {lut_f_message_extract}, gpu_memory_allocated);
}
}

View File

@@ -28,21 +28,17 @@ template <typename Torus> struct int_extend_radix_with_sign_msb_buffer {
uint32_t bits_per_block = std::log2(params.message_modulus);
uint32_t msg_modulus = params.message_modulus;
generate_device_accumulator<Torus>(
streams.stream(0), streams.gpu_index(0), lut->get_lut(0, 0),
lut->get_degree(0), lut->get_max_degree(0), params.glwe_dimension,
params.polynomial_size, params.message_modulus, params.carry_modulus,
[msg_modulus, bits_per_block](Torus x) {
auto active_streams =
streams.active_gpu_subset(num_radix_blocks, params.pbs_type);
lut->generate_and_broadcast_lut(
active_streams, {0}, {[msg_modulus, bits_per_block](Torus x) {
const auto xm = x % msg_modulus;
const auto sign_bit = (xm >> (bits_per_block - 1)) & 1;
return (Torus)((msg_modulus - 1) * sign_bit);
},
}},
allocate_gpu_memory);
auto active_streams =
streams.active_gpu_subset(num_radix_blocks, params.pbs_type);
lut->broadcast_lut(active_streams);
this->last_block = new CudaRadixCiphertextFFI;
create_zero_radix_ciphertext_async<Torus>(

View File

@@ -85,24 +85,6 @@ template <typename Torus> struct int_cmux_buffer {
new int_radix_lut<Torus>(streams, params, 1, num_radix_blocks,
allocate_gpu_memory, size_tracker);
generate_device_accumulator_bivariate<Torus>(
streams.stream(0), streams.gpu_index(0), predicate_lut->get_lut(0, 0),
predicate_lut->get_degree(0), predicate_lut->get_max_degree(0),
params.glwe_dimension, params.polynomial_size, params.message_modulus,
params.carry_modulus, inverted_lut_f, gpu_memory_allocated);
generate_device_accumulator_bivariate<Torus>(
streams.stream(0), streams.gpu_index(0), predicate_lut->get_lut(0, 1),
predicate_lut->get_degree(1), predicate_lut->get_max_degree(1),
params.glwe_dimension, params.polynomial_size, params.message_modulus,
params.carry_modulus, lut_f, gpu_memory_allocated);
generate_device_accumulator<Torus>(
streams.stream(0), streams.gpu_index(0),
message_extract_lut->get_lut(0, 0), message_extract_lut->get_degree(0),
message_extract_lut->get_max_degree(0), params.glwe_dimension,
params.polynomial_size, params.message_modulus, params.carry_modulus,
message_extract_lut_f, gpu_memory_allocated);
Torus *h_lut_indexes = predicate_lut->h_lut_indexes;
for (int index = 0; index < 2 * num_radix_blocks; index++) {
if (index < num_radix_blocks) {
@@ -115,12 +97,18 @@ template <typename Torus> struct int_cmux_buffer {
predicate_lut->get_lut_indexes(0, 0), h_lut_indexes,
2 * num_radix_blocks * sizeof(Torus), streams.stream(0),
streams.gpu_index(0), allocate_gpu_memory);
auto active_streams_pred =
streams.active_gpu_subset(2 * num_radix_blocks, params.pbs_type);
predicate_lut->broadcast_lut(active_streams_pred);
predicate_lut->generate_and_broadcast_bivariate_lut(
active_streams_pred, {0, 1}, {inverted_lut_f, lut_f},
gpu_memory_allocated);
auto active_streams_msg =
streams.active_gpu_subset(num_radix_blocks, params.pbs_type);
message_extract_lut->broadcast_lut(active_streams_msg);
message_extract_lut->generate_and_broadcast_lut(
active_streams_msg, {0}, {message_extract_lut_f}, gpu_memory_allocated);
}
void release(CudaStreams streams) {

View File

@@ -39,22 +39,21 @@ template <typename Torus> struct int_are_all_block_true_buffer {
max_chunks, params.big_lwe_dimension, size_tracker,
allocate_gpu_memory);
is_max_value = new int_radix_lut<Torus>(streams, params, 2, max_chunks,
allocate_gpu_memory, size_tracker);
auto is_max_value_f = [max_value](Torus x) -> Torus {
return x == max_value;
};
preallocated_h_lut = (Torus *)malloc(
(params.glwe_dimension + 1) * params.polynomial_size * sizeof(Torus));
generate_device_accumulator<Torus>(
streams.stream(0), streams.gpu_index(0), is_max_value->get_lut(0, 0),
is_max_value->get_degree(0), is_max_value->get_max_degree(0),
params.glwe_dimension, params.polynomial_size, params.message_modulus,
params.carry_modulus, is_max_value_f, gpu_memory_allocated);
is_max_value = new int_radix_lut<Torus>(streams, params, 2, max_chunks,
allocate_gpu_memory, size_tracker);
auto active_streams =
streams.active_gpu_subset(max_chunks, params.pbs_type);
is_max_value->broadcast_lut(active_streams);
auto is_max_value_f = [max_value](Torus x) -> Torus {
return x == max_value;
};
is_max_value->generate_and_broadcast_lut(
active_streams, {0}, {is_max_value_f}, gpu_memory_allocated);
}
void release(CudaStreams streams) {
@@ -103,15 +102,10 @@ template <typename Torus> struct int_comparison_eq_buffer {
new int_radix_lut<Torus>(streams, params, 1, num_radix_blocks,
allocate_gpu_memory, size_tracker);
generate_device_accumulator<Torus>(
streams.stream(0), streams.gpu_index(0), is_non_zero_lut->get_lut(0, 0),
is_non_zero_lut->get_degree(0), is_non_zero_lut->get_max_degree(0),
params.glwe_dimension, params.polynomial_size, params.message_modulus,
params.carry_modulus, is_non_zero_lut_f, gpu_memory_allocated);
auto active_streams =
streams.active_gpu_subset(num_radix_blocks, params.pbs_type);
is_non_zero_lut->broadcast_lut(active_streams);
is_non_zero_lut->generate_and_broadcast_lut(
active_streams, {0}, {is_non_zero_lut_f}, gpu_memory_allocated);
// Scalar may have up to num_radix_blocks blocks
scalar_comparison_luts = new int_radix_lut<Torus>(
@@ -129,32 +123,28 @@ template <typename Torus> struct int_comparison_eq_buffer {
return (lhs == rhs);
}
};
std::vector<std::function<Torus(Torus)>> lut_funcs;
std::vector<uint32_t> lut_indices;
for (int i = 0; i < total_modulus; i++) {
auto lut_f = [i, operator_f](Torus x) -> Torus {
return operator_f(i, x);
};
generate_device_accumulator<Torus>(
streams.stream(0), streams.gpu_index(0),
scalar_comparison_luts->get_lut(0, i),
scalar_comparison_luts->get_degree(i),
scalar_comparison_luts->get_max_degree(i), params.glwe_dimension,
params.polynomial_size, params.message_modulus, params.carry_modulus,
lut_f, gpu_memory_allocated);
lut_funcs.push_back(lut_f);
lut_indices.push_back(i);
}
scalar_comparison_luts->broadcast_lut(active_streams);
scalar_comparison_luts->generate_and_broadcast_lut(
active_streams, lut_indices, lut_funcs, gpu_memory_allocated);
if (op == COMPARISON_TYPE::EQ || op == COMPARISON_TYPE::NE) {
operator_lut =
new int_radix_lut<Torus>(streams, params, 1, num_radix_blocks,
allocate_gpu_memory, size_tracker);
generate_device_accumulator_bivariate<Torus>(
streams.stream(0), streams.gpu_index(0), operator_lut->get_lut(0, 0),
operator_lut->get_degree(0), operator_lut->get_max_degree(0),
params.glwe_dimension, params.polynomial_size, params.message_modulus,
params.carry_modulus, operator_f, gpu_memory_allocated);
operator_lut->broadcast_lut(active_streams);
operator_lut->generate_and_broadcast_bivariate_lut(
active_streams, {0}, {operator_f}, gpu_memory_allocated);
// operator_lut->broadcast_lut(active_streams);
} else {
operator_lut = nullptr;
}
@@ -221,9 +211,6 @@ template <typename Torus> struct int_tree_sign_reduction_buffer {
streams.stream(0), streams.gpu_index(0), tmp_y, num_radix_blocks,
params.big_lwe_dimension, size_tracker, allocate_gpu_memory);
// LUTs
tree_inner_leaf_lut =
new int_radix_lut<Torus>(streams, params, 1, num_radix_blocks,
allocate_gpu_memory, size_tracker);
tree_last_leaf_lut = new int_radix_lut<Torus>(
streams, params, 1, 1, allocate_gpu_memory, size_tracker);
@@ -234,15 +221,14 @@ template <typename Torus> struct int_tree_sign_reduction_buffer {
tree_last_leaf_scalar_lut = new int_radix_lut<Torus>(
streams, params, 1, 1, allocate_gpu_memory, size_tracker);
generate_device_accumulator_bivariate<Torus>(
streams.stream(0), streams.gpu_index(0),
tree_inner_leaf_lut->get_lut(0, 0), tree_inner_leaf_lut->get_degree(0),
tree_inner_leaf_lut->get_max_degree(0), params.glwe_dimension,
params.polynomial_size, params.message_modulus, params.carry_modulus,
block_selector_f, gpu_memory_allocated);
tree_inner_leaf_lut =
new int_radix_lut<Torus>(streams, params, 1, num_radix_blocks,
allocate_gpu_memory, size_tracker);
auto active_streams =
streams.active_gpu_subset(num_radix_blocks, params.pbs_type);
tree_inner_leaf_lut->broadcast_lut(active_streams);
tree_inner_leaf_lut->generate_and_broadcast_bivariate_lut(
active_streams, {0}, {block_selector_f}, allocate_gpu_memory);
}
void release(CudaStreams streams) {
@@ -426,12 +412,8 @@ template <typename Torus> struct int_comparison_buffer {
new int_radix_lut<Torus>(streams, params, 1, num_radix_blocks,
allocate_gpu_memory, size_tracker);
generate_device_accumulator<Torus>(
streams.stream(0), streams.gpu_index(0), identity_lut->get_lut(0, 0),
identity_lut->get_degree(0), identity_lut->get_max_degree(0),
params.glwe_dimension, params.polynomial_size, params.message_modulus,
params.carry_modulus, identity_lut_f, gpu_memory_allocated);
identity_lut->broadcast_lut(active_streams);
identity_lut->generate_and_broadcast_lut(
active_streams, {0}, {identity_lut_f}, gpu_memory_allocated);
uint32_t total_modulus = params.message_modulus * params.carry_modulus;
auto is_zero_f = [total_modulus](Torus x) -> Torus {
@@ -441,13 +423,8 @@ template <typename Torus> struct int_comparison_buffer {
is_zero_lut = new int_radix_lut<Torus>(streams, params, 1, num_radix_blocks,
allocate_gpu_memory, size_tracker);
generate_device_accumulator<Torus>(
streams.stream(0), streams.gpu_index(0), is_zero_lut->get_lut(0, 0),
is_zero_lut->get_degree(0), is_zero_lut->get_max_degree(0),
params.glwe_dimension, params.polynomial_size, params.message_modulus,
params.carry_modulus, is_zero_f, gpu_memory_allocated);
is_zero_lut->broadcast_lut(active_streams);
is_zero_lut->generate_and_broadcast_lut(active_streams, {0}, {is_zero_f},
gpu_memory_allocated);
switch (op) {
case COMPARISON_TYPE::MAX:
@@ -522,13 +499,9 @@ template <typename Torus> struct int_comparison_buffer {
PANIC("Cuda error: sign_lut creation failed due to wrong function.")
};
generate_device_accumulator_bivariate<Torus>(
streams.stream(0), streams.gpu_index(0), signed_lut->get_lut(0, 0),
signed_lut->get_degree(0), signed_lut->get_max_degree(0),
params.glwe_dimension, params.polynomial_size, params.message_modulus,
params.carry_modulus, signed_lut_f, gpu_memory_allocated);
auto active_streams = streams.active_gpu_subset(1, params.pbs_type);
signed_lut->broadcast_lut(active_streams);
signed_lut->generate_and_broadcast_bivariate_lut(
active_streams, {0}, {signed_lut_f}, gpu_memory_allocated);
}
preallocated_h_lut = (Torus *)malloc(
(params.glwe_dimension + 1) * params.polynomial_size * sizeof(Torus));

View File

@@ -283,12 +283,9 @@ template <typename Torus> struct unsigned_int_div_rem_2_2_memory {
zero_out_if_not_1_lut_2};
size_t lut_gpu_indexes[2] = {0, 3};
for (int j = 0; j < 2; j++) {
generate_device_accumulator<Torus>(
streams.stream(lut_gpu_indexes[j]),
streams.gpu_index(lut_gpu_indexes[j]), luts[j]->get_lut(0, 0),
luts[j]->get_degree(0), luts[j]->get_max_degree(0),
params.glwe_dimension, params.polynomial_size, params.message_modulus,
params.carry_modulus, zero_out_if_not_1_lut_f, gpu_memory_allocated);
luts[j]->generate_and_broadcast_lut(streams.get_ith(lut_gpu_indexes[j]),
{0}, {zero_out_if_not_1_lut_f},
gpu_memory_allocated);
}
luts[0] = zero_out_if_not_2_lut_1;
@@ -296,12 +293,9 @@ template <typename Torus> struct unsigned_int_div_rem_2_2_memory {
lut_gpu_indexes[0] = 1;
lut_gpu_indexes[1] = 2;
for (int j = 0; j < 2; j++) {
generate_device_accumulator<Torus>(
streams.stream(lut_gpu_indexes[j]),
streams.gpu_index(lut_gpu_indexes[j]), luts[j]->get_lut(0, 0),
luts[j]->get_degree(0), luts[j]->get_max_degree(0),
params.glwe_dimension, params.polynomial_size, params.message_modulus,
params.carry_modulus, zero_out_if_not_2_lut_f, gpu_memory_allocated);
luts[j]->generate_and_broadcast_lut(streams.get_ith(lut_gpu_indexes[j]),
{0}, {zero_out_if_not_2_lut_f},
gpu_memory_allocated);
}
quotient_lut_1 =
@@ -321,21 +315,12 @@ template <typename Torus> struct unsigned_int_div_rem_2_2_memory {
};
auto quotient_lut_3_f = [](Torus cond) -> Torus { return cond * 3; };
generate_device_accumulator<Torus>(
streams.stream(2), streams.gpu_index(2), quotient_lut_1->get_lut(0, 0),
quotient_lut_1->get_degree(0), quotient_lut_1->get_max_degree(0),
params.glwe_dimension, params.polynomial_size, params.message_modulus,
params.carry_modulus, quotient_lut_1_f, gpu_memory_allocated);
generate_device_accumulator<Torus>(
streams.stream(1), streams.gpu_index(1), quotient_lut_2->get_lut(0, 0),
quotient_lut_2->get_degree(0), quotient_lut_2->get_max_degree(0),
params.glwe_dimension, params.polynomial_size, params.message_modulus,
params.carry_modulus, quotient_lut_2_f, gpu_memory_allocated);
generate_device_accumulator<Torus>(
streams.stream(0), streams.gpu_index(0), quotient_lut_3->get_lut(0, 0),
quotient_lut_3->get_degree(0), quotient_lut_3->get_max_degree(0),
params.glwe_dimension, params.polynomial_size, params.message_modulus,
params.carry_modulus, quotient_lut_3_f, gpu_memory_allocated);
quotient_lut_1->generate_and_broadcast_lut(
streams.get_ith(2), {0}, {quotient_lut_1_f}, gpu_memory_allocated);
quotient_lut_2->generate_and_broadcast_lut(
streams.get_ith(1), {0}, {quotient_lut_2_f}, gpu_memory_allocated);
quotient_lut_3->generate_and_broadcast_lut(
streams.get_ith(0), {0}, {quotient_lut_3_f}, gpu_memory_allocated);
message_extract_lut_1 = new int_radix_lut<Torus>(
streams, params, 1, num_blocks, allocate_gpu_memory, size_tracker);
@@ -350,15 +335,12 @@ template <typename Torus> struct unsigned_int_div_rem_2_2_memory {
luts[0] = message_extract_lut_1;
luts[1] = message_extract_lut_2;
auto active_streams =
streams.active_gpu_subset(num_blocks, params.pbs_type);
for (int j = 0; j < 2; j++) {
generate_device_accumulator<Torus>(
streams.stream(0), streams.gpu_index(0), luts[j]->get_lut(0, 0),
luts[j]->get_degree(0), luts[j]->get_max_degree(0),
params.glwe_dimension, params.polynomial_size, params.message_modulus,
params.carry_modulus, lut_f_message_extract, gpu_memory_allocated);
auto active_streams =
streams.active_gpu_subset(num_blocks, params.pbs_type);
luts[j]->broadcast_lut(active_streams);
luts[j]->generate_and_broadcast_lut(
active_streams, {0}, {lut_f_message_extract}, gpu_memory_allocated);
}
}
@@ -1007,24 +989,14 @@ template <typename Torus> struct unsigned_int_div_rem_memory {
masking_luts_2[i] = new int_radix_lut<Torus>(
streams, params, 1, num_blocks, allocate_gpu_memory, size_tracker);
generate_device_accumulator<Torus>(
streams.stream(0), streams.gpu_index(0),
masking_luts_1[i]->get_lut(0, 0), masking_luts_1[i]->get_degree(0),
masking_luts_1[i]->get_max_degree(0), params.glwe_dimension,
params.polynomial_size, params.message_modulus, params.carry_modulus,
lut_f_masking, gpu_memory_allocated);
auto active_streams_1 = streams.active_gpu_subset(1, params.pbs_type);
masking_luts_1[i]->broadcast_lut(active_streams_1);
masking_luts_1[i]->generate_and_broadcast_lut(
active_streams_1, {0}, {lut_f_masking}, gpu_memory_allocated);
generate_device_accumulator<Torus>(
streams.stream(0), streams.gpu_index(0),
masking_luts_2[i]->get_lut(0, 0), masking_luts_2[i]->get_degree(0),
masking_luts_2[i]->get_max_degree(0), params.glwe_dimension,
params.polynomial_size, params.message_modulus, params.carry_modulus,
lut_f_masking, gpu_memory_allocated);
auto active_streams_2 =
streams.active_gpu_subset(num_blocks, params.pbs_type);
masking_luts_2[i]->broadcast_lut(active_streams_2);
masking_luts_2[i]->generate_and_broadcast_lut(
active_streams_2, {0}, {lut_f_masking}, gpu_memory_allocated);
}
// create and generate message_extract_lut_1 and message_extract_lut_2
@@ -1042,15 +1014,12 @@ template <typename Torus> struct unsigned_int_div_rem_memory {
int_radix_lut<Torus> *luts[2] = {message_extract_lut_1,
message_extract_lut_2};
auto active_streams =
streams.active_gpu_subset(num_blocks, params.pbs_type);
for (int j = 0; j < 2; j++) {
generate_device_accumulator<Torus>(
streams.stream(0), streams.gpu_index(0), luts[j]->get_lut(0, 0),
luts[j]->get_degree(0), luts[j]->get_max_degree(0),
params.glwe_dimension, params.polynomial_size, params.message_modulus,
params.carry_modulus, lut_f_message_extract, gpu_memory_allocated);
luts[j]->broadcast_lut(active_streams);
luts[j]->generate_and_broadcast_lut(
active_streams, {0}, {lut_f_message_extract}, gpu_memory_allocated);
}
// Give name to closures to improve readability
@@ -1141,14 +1110,8 @@ template <typename Torus> struct unsigned_int_div_rem_memory {
merge_overflow_flags_luts[i] = new int_radix_lut<Torus>(
streams, params, 1, 1, allocate_gpu_memory, size_tracker);
generate_device_accumulator_bivariate<Torus>(
streams.stream(0), streams.gpu_index(0),
merge_overflow_flags_luts[i]->get_lut(0, 0),
merge_overflow_flags_luts[i]->get_degree(0),
merge_overflow_flags_luts[i]->get_max_degree(0),
params.glwe_dimension, params.polynomial_size, params.message_modulus,
params.carry_modulus, lut_f_bit, gpu_memory_allocated);
merge_overflow_flags_luts[i]->broadcast_lut(active_gpu_count_for_bits);
merge_overflow_flags_luts[i]->generate_and_broadcast_bivariate_lut(
active_gpu_count_for_bits, {0}, {lut_f_bit}, gpu_memory_allocated);
}
}
@@ -1557,16 +1520,12 @@ template <typename Torus> struct int_div_rem_memory {
compare_signed_bits_lut = new int_radix_lut<Torus>(
streams, params, 1, 1, allocate_gpu_memory, size_tracker);
generate_device_accumulator_bivariate<Torus>(
streams.stream(0), streams.gpu_index(0),
compare_signed_bits_lut->get_lut(0, 0),
compare_signed_bits_lut->get_degree(0),
compare_signed_bits_lut->get_max_degree(0), params.glwe_dimension,
params.polynomial_size, params.message_modulus, params.carry_modulus,
f_compare_extracted_signed_bits, gpu_memory_allocated);
auto active_gpu_count_cmp =
streams.active_gpu_subset(1, params.pbs_type); // only 1 block needed
compare_signed_bits_lut->broadcast_lut(active_gpu_count_cmp);
compare_signed_bits_lut->generate_and_broadcast_bivariate_lut(
active_gpu_count_cmp, {0}, {f_compare_extracted_signed_bits},
gpu_memory_allocated);
}
}

View File

@@ -53,13 +53,8 @@ template <typename Torus> struct int_prepare_count_of_consecutive_bits_buffer {
return count;
};
generate_device_accumulator<Torus>(
streams.stream(0), streams.gpu_index(0), univ_lut_mem->get_lut(0, 0),
univ_lut_mem->get_degree(0), univ_lut_mem->get_max_degree(0),
params.glwe_dimension, params.polynomial_size, params.message_modulus,
params.carry_modulus, generate_uni_lut_lambda, allocate_gpu_memory);
univ_lut_mem->broadcast_lut(active_streams);
univ_lut_mem->generate_and_broadcast_lut(
active_streams, {0}, {generate_uni_lut_lambda}, allocate_gpu_memory);
auto generate_bi_lut_lambda =
[num_bits](Torus block_num_bit_count,
@@ -70,13 +65,8 @@ template <typename Torus> struct int_prepare_count_of_consecutive_bits_buffer {
return 0;
};
generate_device_accumulator_bivariate<Torus>(
streams.stream(0), streams.gpu_index(0), biv_lut_mem->get_lut(0, 0),
biv_lut_mem->get_degree(0), biv_lut_mem->get_max_degree(0),
params.glwe_dimension, params.polynomial_size, params.message_modulus,
params.carry_modulus, generate_bi_lut_lambda, allocate_gpu_memory);
biv_lut_mem->broadcast_lut(active_streams);
biv_lut_mem->generate_and_broadcast_bivariate_lut(
active_streams, {0}, {generate_bi_lut_lambda}, allocate_gpu_memory);
this->tmp_ct = new CudaRadixCiphertextFFI;
create_zero_radix_ciphertext_async<Torus>(
@@ -232,7 +222,7 @@ template <typename Torus> struct int_ilog2_buffer {
this->sum_output_not_propagated, counter_num_blocks,
params.big_lwe_dimension, size_tracker, allocate_gpu_memory);
this->lut_message_not =
lut_message_not =
new int_radix_lut<Torus>(streams, params, 1, counter_num_blocks,
allocate_gpu_memory, size_tracker);
std::function<Torus(Torus)> lut_message_lambda =
@@ -240,16 +230,11 @@ template <typename Torus> struct int_ilog2_buffer {
uint64_t message = x % this->params.message_modulus;
return (~message) % this->params.message_modulus;
};
generate_device_accumulator(streams.stream(0), streams.gpu_index(0),
this->lut_message_not->get_lut(0, 0),
this->lut_message_not->get_degree(0),
this->lut_message_not->get_max_degree(0),
params.glwe_dimension, params.polynomial_size,
params.message_modulus, params.carry_modulus,
lut_message_lambda, allocate_gpu_memory);
auto active_streams =
streams.active_gpu_subset(counter_num_blocks, params.pbs_type);
lut_message_not->broadcast_lut(active_streams);
lut_message_not->generate_and_broadcast_lut(
active_streams, {0}, {lut_message_lambda}, allocate_gpu_memory);
this->lut_carry_not =
new int_radix_lut<Torus>(streams, params, 1, counter_num_blocks,
@@ -259,13 +244,8 @@ template <typename Torus> struct int_ilog2_buffer {
uint64_t carry = x / this->params.message_modulus;
return (~carry) % this->params.message_modulus;
};
generate_device_accumulator(
streams.stream(0), streams.gpu_index(0),
this->lut_carry_not->get_lut(0, 0), this->lut_carry_not->get_degree(0),
this->lut_carry_not->get_max_degree(0), params.glwe_dimension,
params.polynomial_size, params.message_modulus, params.carry_modulus,
lut_carry_lambda, allocate_gpu_memory);
lut_carry_not->broadcast_lut(active_streams);
lut_carry_not->generate_and_broadcast_lut(
active_streams, {0}, {lut_carry_lambda}, allocate_gpu_memory);
this->message_blocks_not = new CudaRadixCiphertextFFI;
create_zero_radix_ciphertext_async<Torus>(

View File

@@ -9,6 +9,7 @@
#include "utils/helper_multi_gpu.cuh"
#include <cmath>
#include <functional>
#include <map>
#include <queue>
#include <stdio.h>
@@ -835,6 +836,56 @@ struct int_radix_lut_custom_input_output {
}
}
void generate_and_broadcast_lut(
const CudaStreams &streams, std::vector<uint32_t> lut_indexes,
std::vector<std::function<OutputTorus(OutputTorus)>> f,
bool gpu_memory_allocated) {
// streams should be a subset of active_streams
for (uint32_t i = 0; i < lut_indexes.size(); ++i) {
generate_device_accumulator<OutputTorus>(
streams.stream(0), streams.gpu_index(0), get_lut(0, lut_indexes[i]),
get_degree(lut_indexes[i]), get_max_degree(lut_indexes[i]),
params.glwe_dimension, params.polynomial_size, params.message_modulus,
params.carry_modulus, f[i], gpu_memory_allocated);
}
//broadcast_lut(streams);
}
void generate_and_broadcast_bivariate_lut(
const CudaStreams &streams, std::vector<uint32_t> lut_indexes,
std::vector<std::function<OutputTorus(OutputTorus, OutputTorus)>> f,
bool gpu_memory_allocated) {
// streams should be a subset of active_streams
/* for (int fidx = 0; fidx < f.size(); ++fidx) {
__int128_t f_hash = 0;
uint32_t bits_per_lut_val = 5;
uint32_t input_modulus_sup =
params.message_modulus * params.carry_modulus;
for (uint32_t i = 0; i < input_modulus_sup; ++i) {
OutputTorus f_eval =
f[fidx](i / params.message_modulus, i % params.message_modulus);
GPU_ASSERT(f_eval < (1 << bits_per_lut_val),
"LUT value expected bitwidth overflow");
f_hash |= f_eval;
f_hash <<= bits_per_lut_val;
}
printf("%016llX%016llX\n",
(unsigned long long)((f_hash >> 64) & 0xFFFFFFFFFFFFFFFF),
(unsigned long long)(f_hash & 0xFFFFFFFFFFFFFFFF));
}
*/
for (uint32_t i = 0; i < lut_indexes.size(); ++i) {
generate_device_accumulator_bivariate<InputTorus>(
streams.stream(0), streams.gpu_index(0), get_lut(0, lut_indexes[i]),
get_degree(lut_indexes[i]), get_max_degree(lut_indexes[i]),
params.glwe_dimension, params.polynomial_size, params.message_modulus,
params.carry_modulus, f[i], gpu_memory_allocated);
}
//broadcast_lut(streams);
}
void release(CudaStreams streams) {
PANIC_IF_FALSE(lut_indexes_vec.size() == lut_vec.size(),
"Lut vec and Lut vec indexes must have the same size");
@@ -985,18 +1036,15 @@ template <typename Torus> struct int_bit_extract_luts_buffer {
bits_per_block * num_radix_blocks,
allocate_gpu_memory, size_tracker);
std::vector<std::function<Torus(Torus)>> lut_funs;
std::vector<uint32_t> lut_indices;
for (int i = 0; i < bits_per_block; i++) {
auto operator_f = [i, final_offset](Torus x) -> Torus {
Torus y = (x >> i) & 1;
return y << final_offset;
};
generate_device_accumulator<Torus>(
streams.stream(0), streams.gpu_index(0), lut->get_lut(0, i),
lut->get_degree(i), lut->get_max_degree(i), params.glwe_dimension,
params.polynomial_size, params.message_modulus, params.carry_modulus,
operator_f, gpu_memory_allocated);
lut_funs.push_back(operator_f);
lut_indices.push_back(i);
}
/**
@@ -1015,7 +1063,10 @@ template <typename Torus> struct int_bit_extract_luts_buffer {
auto active_streams = streams.active_gpu_subset(
bits_per_block * num_radix_blocks, params.pbs_type);
lut->broadcast_lut(active_streams);
lut->generate_and_broadcast_lut(active_streams, lut_indices, lut_funs,
gpu_memory_allocated);
// lut->broadcast_lut(active_streams);
/**
* the input indexes should take the first bits_per_block PBS to target
@@ -1091,24 +1142,6 @@ template <typename Torus> struct int_fullprop_buffer {
};
//
Torus *lut_buffer_message = lut->get_lut(0, 0);
uint64_t *message_degree = lut->get_degree(0);
uint64_t *message_max_degree = lut->get_max_degree(0);
Torus *lut_buffer_carry = lut->get_lut(0, 1);
uint64_t *carry_degree = lut->get_degree(1);
uint64_t *carry_max_degree = lut->get_max_degree(1);
generate_device_accumulator<Torus>(
streams.stream(0), streams.gpu_index(0), lut_buffer_message,
message_degree, message_max_degree, params.glwe_dimension,
params.polynomial_size, params.message_modulus, params.carry_modulus,
lut_f_message, gpu_memory_allocated);
generate_device_accumulator<Torus>(
streams.stream(0), streams.gpu_index(0), lut_buffer_carry, carry_degree,
carry_max_degree, params.glwe_dimension, params.polynomial_size,
params.message_modulus, params.carry_modulus, lut_f_carry,
gpu_memory_allocated);
uint64_t lwe_indexes_size = 2 * sizeof(Torus);
Torus *h_lwe_indexes = (Torus *)malloc(lwe_indexes_size);
@@ -1118,9 +1151,15 @@ template <typename Torus> struct int_fullprop_buffer {
cuda_memcpy_with_size_tracking_async_to_gpu(
lwe_indexes, h_lwe_indexes, lwe_indexes_size, streams.stream(0),
streams.gpu_index(0), allocate_gpu_memory);
//
// No broadcast is needed because full prop is done on 1 single GPU.
// By passing a single-GPU CudaStreams with streams.get_ith(0) the LUT is
// not broadcast.
//
lut->generate_and_broadcast_lut(streams.get_ith(0), {0, 1},
{lut_f_message, lut_f_carry},
gpu_memory_allocated);
tmp_small_lwe_vector = new CudaRadixCiphertextFFI;
create_zero_radix_ciphertext_async<Torus>(
@@ -1238,9 +1277,10 @@ template <typename Torus> struct int_sum_ciphertexts_vec_memory {
if (total_ciphertexts > 0 ||
reduce_degrees_for_single_carry_propagation) {
uint64_t size_tracker = 0;
allocated_luts_message_carry = true;
luts_message_carry = new int_radix_lut<Torus>(
streams, params, 2, pbs_count, true, size_tracker);
allocated_luts_message_carry = true;
uint64_t message_modulus_bits =
(uint64_t)std::log2(params.message_modulus);
uint64_t carry_modulus_bits = (uint64_t)std::log2(params.carry_modulus);
@@ -1256,7 +1296,9 @@ template <typename Torus> struct int_sum_ciphertexts_vec_memory {
streams, upper_bound_num_blocks, size_tracker, true);
}
}
if (allocated_luts_message_carry) {
auto message_acc = luts_message_carry->get_lut(0, 0);
auto carry_acc = luts_message_carry->get_lut(0, 1);
@@ -1268,22 +1310,11 @@ template <typename Torus> struct int_sum_ciphertexts_vec_memory {
return x / message_modulus;
};
// generate accumulators
generate_device_accumulator<Torus>(
streams.stream(0), streams.gpu_index(0), message_acc,
luts_message_carry->get_degree(0),
luts_message_carry->get_max_degree(0), params.glwe_dimension,
params.polynomial_size, message_modulus, params.carry_modulus,
lut_f_message, gpu_memory_allocated);
generate_device_accumulator<Torus>(
streams.stream(0), streams.gpu_index(0), carry_acc,
luts_message_carry->get_degree(1),
luts_message_carry->get_max_degree(1), params.glwe_dimension,
params.polynomial_size, message_modulus, params.carry_modulus,
lut_f_carry, gpu_memory_allocated);
auto active_gpu_count_mc =
streams.active_gpu_subset(pbs_count, params.pbs_type);
luts_message_carry->broadcast_lut(active_gpu_count_mc);
luts_message_carry->generate_and_broadcast_lut(
active_gpu_count_mc, {0, 1}, {lut_f_message, lut_f_carry},
gpu_memory_allocated);
}
}
int_sum_ciphertexts_vec_memory(
@@ -1418,10 +1449,6 @@ template <typename Torus> struct int_seq_group_prop_memory {
uint32_t group_size, uint32_t big_lwe_size_bytes,
bool allocate_gpu_memory, uint64_t &size_tracker) {
gpu_memory_allocated = allocate_gpu_memory;
auto glwe_dimension = params.glwe_dimension;
auto polynomial_size = params.polynomial_size;
auto message_modulus = params.message_modulus;
auto carry_modulus = params.carry_modulus;
grouping_size = group_size;
group_resolved_carries = new CudaRadixCiphertextFFI;
@@ -1431,22 +1458,20 @@ template <typename Torus> struct int_seq_group_prop_memory {
allocate_gpu_memory);
int num_seq_luts = grouping_size - 1;
Torus *h_seq_lut_indexes = (Torus *)malloc(num_seq_luts * sizeof(Torus));
lut_sequential_algorithm =
new int_radix_lut<Torus>(streams, params, num_seq_luts, num_seq_luts,
allocate_gpu_memory, size_tracker);
std::vector<std::function<Torus(Torus)>> lut_funcs;
std::vector<uint32_t> lut_indices;
Torus *h_seq_lut_indexes = (Torus *)malloc(num_seq_luts * sizeof(Torus));
for (int index = 0; index < num_seq_luts; index++) {
auto f_lut_sequential = [index](Torus propa_cum_sum_block) {
return (propa_cum_sum_block >> (index + 1)) & 1;
};
auto seq_lut = lut_sequential_algorithm->get_lut(0, index);
generate_device_accumulator<Torus>(
streams.stream(0), streams.gpu_index(0), seq_lut,
lut_sequential_algorithm->get_degree(index),
lut_sequential_algorithm->get_max_degree(index), glwe_dimension,
polynomial_size, message_modulus, carry_modulus, f_lut_sequential,
gpu_memory_allocated);
lut_funcs.push_back(f_lut_sequential);
h_seq_lut_indexes[index] = index;
lut_indices.push_back(index);
}
Torus *seq_lut_indexes = lut_sequential_algorithm->get_lut_indexes(0, 0);
cuda_memcpy_with_size_tracking_async_to_gpu(
@@ -1454,9 +1479,12 @@ template <typename Torus> struct int_seq_group_prop_memory {
streams.stream(0), streams.gpu_index(0), allocate_gpu_memory);
auto active_streams =
streams.active_gpu_subset(num_seq_luts, params.pbs_type);
lut_sequential_algorithm->broadcast_lut(active_streams);
lut_sequential_algorithm->generate_and_broadcast_lut(
active_streams, lut_indices, lut_funcs, gpu_memory_allocated);
// lut_sequential_algorithm->broadcast_lut(active_streams);
free(h_seq_lut_indexes);
};
}
void release(CudaStreams streams) {
release_radix_ciphertext_async(streams.stream(0), streams.gpu_index(0),
group_resolved_carries,
@@ -1478,10 +1506,6 @@ template <typename Torus> struct int_hs_group_prop_memory {
uint32_t num_groups, uint32_t big_lwe_size_bytes,
bool allocate_gpu_memory, uint64_t &size_tracker) {
gpu_memory_allocated = allocate_gpu_memory;
auto glwe_dimension = params.glwe_dimension;
auto polynomial_size = params.polynomial_size;
auto message_modulus = params.message_modulus;
auto carry_modulus = params.carry_modulus;
auto f_lut_hillis_steele = [](Torus msb, Torus lsb) -> Torus {
if (msb == 2) {
@@ -1501,16 +1525,11 @@ template <typename Torus> struct int_hs_group_prop_memory {
lut_hillis_steele = new int_radix_lut<Torus>(
streams, params, 1, num_groups, allocate_gpu_memory, size_tracker);
generate_device_accumulator_bivariate<Torus>(
streams.stream(0), streams.gpu_index(0),
lut_hillis_steele->get_lut(0, 0), lut_hillis_steele->get_degree(0),
lut_hillis_steele->get_max_degree(0), glwe_dimension, polynomial_size,
message_modulus, carry_modulus, f_lut_hillis_steele,
gpu_memory_allocated);
auto active_streams =
streams.active_gpu_subset(num_groups, params.pbs_type);
lut_hillis_steele->broadcast_lut(active_streams);
};
lut_hillis_steele->generate_and_broadcast_bivariate_lut(
active_streams, {0}, {f_lut_hillis_steele}, gpu_memory_allocated);
}
void release(CudaStreams streams) {
lut_hillis_steele->release(streams);
@@ -1800,112 +1819,6 @@ template <typename Torus> struct int_prop_simu_group_carries_memory {
num_extra_luts = 1;
}
uint32_t num_luts_second_step = 2 * grouping_size + num_extra_luts;
luts_array_second_step = new int_radix_lut<Torus>(
streams, params, num_luts_second_step, num_radix_blocks,
allocate_gpu_memory, size_tracker);
// luts for first group inner propagation
for (int lut_id = 0; lut_id < grouping_size - 1; lut_id++) {
auto f_first_grouping_inner_propagation =
[lut_id](Torus propa_cum_sum_block) -> Torus {
uint64_t carry = (propa_cum_sum_block >> lut_id) & 1;
if (carry != 0) {
return 2ull; // Generates Carry
} else {
return 0ull; // Does not generate carry
}
};
generate_device_accumulator<Torus>(
streams.stream(0), streams.gpu_index(0),
luts_array_second_step->get_lut(0, lut_id),
luts_array_second_step->get_degree(lut_id),
luts_array_second_step->get_max_degree(lut_id), glwe_dimension,
polynomial_size, message_modulus, carry_modulus,
f_first_grouping_inner_propagation, gpu_memory_allocated);
}
auto f_first_grouping_outer_propagation =
[num_bits_in_block](Torus block) -> Torus {
return (block >> (num_bits_in_block - 1)) & 1;
};
int lut_id = grouping_size - 1;
generate_device_accumulator<Torus>(
streams.stream(0), streams.gpu_index(0),
luts_array_second_step->get_lut(0, lut_id),
luts_array_second_step->get_degree(lut_id),
luts_array_second_step->get_max_degree(lut_id), glwe_dimension,
polynomial_size, message_modulus, carry_modulus,
f_first_grouping_outer_propagation, gpu_memory_allocated);
// for other groupings inner propagation
for (int index = 0; index < grouping_size; index++) {
uint32_t lut_id = index + grouping_size;
auto f_other_groupings_inner_propagation =
[index](Torus propa_cum_sum_block) -> Torus {
uint64_t mask = (2 << index) - 1;
if (propa_cum_sum_block >= (2 << index)) {
return 2ull; // Generates
} else if ((propa_cum_sum_block & mask) == mask) {
return 1ull; // Propagate
} else {
return 0ull; // Nothing
}
};
generate_device_accumulator<Torus>(
streams.stream(0), streams.gpu_index(0),
luts_array_second_step->get_lut(0, lut_id),
luts_array_second_step->get_degree(lut_id),
luts_array_second_step->get_max_degree(lut_id), glwe_dimension,
polynomial_size, message_modulus, carry_modulus,
f_other_groupings_inner_propagation, gpu_memory_allocated);
}
if (use_sequential_algorithm_to_resolve_group_carries) {
for (int index = 0; index < grouping_size - 1; index++) {
uint32_t lut_id = index + 2 * grouping_size;
auto f_group_propagation = [index, block_modulus,
num_bits_in_block](Torus block) -> Torus {
if (block == (block_modulus - 1)) {
return 0ull;
} else {
return ((UINT64_MAX << index) % (1ull << (num_bits_in_block + 1)));
}
};
generate_device_accumulator<Torus>(
streams.stream(0), streams.gpu_index(0),
luts_array_second_step->get_lut(0, lut_id),
luts_array_second_step->get_degree(lut_id),
luts_array_second_step->get_max_degree(lut_id), glwe_dimension,
polynomial_size, message_modulus, carry_modulus,
f_group_propagation, gpu_memory_allocated);
}
} else {
uint32_t lut_id = 2 * grouping_size;
auto f_group_propagation = [block_modulus](Torus block) {
if (block == (block_modulus - 1)) {
return 2ull;
} else {
return UINT64_MAX % (block_modulus * 2ull);
}
};
generate_device_accumulator<Torus>(
streams.stream(0), streams.gpu_index(0),
luts_array_second_step->get_lut(0, lut_id),
luts_array_second_step->get_degree(lut_id),
luts_array_second_step->get_max_degree(lut_id), glwe_dimension,
polynomial_size, message_modulus, carry_modulus, f_group_propagation,
gpu_memory_allocated);
}
Torus *h_second_lut_indexes = (Torus *)malloc(lut_indexes_size);
for (int index = 0; index < num_radix_blocks; index++) {
@@ -1941,6 +1854,11 @@ template <typename Torus> struct int_prop_simu_group_carries_memory {
}
}
uint32_t num_luts_second_step = 2 * grouping_size + num_extra_luts;
luts_array_second_step = new int_radix_lut<Torus>(
streams, params, num_luts_second_step, num_radix_blocks,
allocate_gpu_memory, size_tracker);
// copy the indexes to the gpu
Torus *second_lut_indexes = luts_array_second_step->get_lut_indexes(0, 0);
cuda_memcpy_with_size_tracking_async_to_gpu(
@@ -1951,9 +1869,92 @@ template <typename Torus> struct int_prop_simu_group_carries_memory {
scalar_array_cum_sum, h_scalar_array_cum_sum,
num_radix_blocks * sizeof(Torus), streams.stream(0),
streams.gpu_index(0), allocate_gpu_memory);
std::vector<std::function<Torus(Torus)>> lut_funcs;
std::vector<uint32_t> lut_ids;
// luts for first group inner propagation
for (int lut_id = 0; lut_id < grouping_size - 1; lut_id++) {
auto f_first_grouping_inner_propagation =
[lut_id](Torus propa_cum_sum_block) -> Torus {
uint64_t carry = (propa_cum_sum_block >> lut_id) & 1;
if (carry != 0) {
return 2ull; // Generates Carry
} else {
return 0ull; // Does not generate carry
}
};
lut_funcs.push_back(f_first_grouping_inner_propagation);
lut_ids.push_back(lut_id);
}
auto f_first_grouping_outer_propagation =
[num_bits_in_block](Torus block) -> Torus {
return (block >> (num_bits_in_block - 1)) & 1;
};
int lut_id = grouping_size - 1;
lut_funcs.push_back(f_first_grouping_outer_propagation);
lut_ids.push_back(lut_id);
// for other groupings inner propagation
for (int index = 0; index < grouping_size; index++) {
uint32_t lut_id = index + grouping_size;
auto f_other_groupings_inner_propagation =
[index](Torus propa_cum_sum_block) -> Torus {
uint64_t mask = (2 << index) - 1;
if (propa_cum_sum_block >= (2 << index)) {
return 2ull; // Generates
} else if ((propa_cum_sum_block & mask) == mask) {
return 1ull; // Propagate
} else {
return 0ull; // Nothing
}
};
lut_funcs.push_back(f_other_groupings_inner_propagation);
lut_ids.push_back(lut_id);
}
if (use_sequential_algorithm_to_resolve_group_carries) {
for (int index = 0; index < grouping_size - 1; index++) {
uint32_t lut_id = index + 2 * grouping_size;
auto f_group_propagation = [index, block_modulus,
num_bits_in_block](Torus block) -> Torus {
if (block == (block_modulus - 1)) {
return 0ull;
} else {
return ((UINT64_MAX << index) % (1ull << (num_bits_in_block + 1)));
}
};
lut_funcs.push_back(f_group_propagation);
lut_ids.push_back(lut_id);
}
} else {
uint32_t lut_id = 2 * grouping_size;
auto f_group_propagation = [block_modulus](Torus block) {
if (block == (block_modulus - 1)) {
return 2ull;
} else {
return UINT64_MAX % (block_modulus * 2ull);
}
};
lut_funcs.push_back(f_group_propagation);
lut_ids.push_back(lut_id);
}
auto active_streams =
streams.active_gpu_subset(num_radix_blocks, params.pbs_type);
luts_array_second_step->broadcast_lut(active_streams);
luts_array_second_step->generate_and_broadcast_lut(
active_streams, lut_ids, lut_funcs, gpu_memory_allocated);
// luts_array_second_step->broadcast_lut(active_streams);
if (use_sequential_algorithm_to_resolve_group_carries) {
@@ -2041,12 +2042,28 @@ template <typename Torus> struct int_sc_prop_memory {
uint32_t requested_flag;
bool gpu_memory_allocated;
void setup_message_extract_indices_for_carry_async(CudaStreams streams,
uint32_t num_radix_blocks,
bool allocate_gpu_memory) {
Torus *h_lut_indexes = lut_message_extract->h_lut_indexes;
for (int index = 0; index < num_radix_blocks + 1; index++) {
if (index < num_radix_blocks) {
h_lut_indexes[index] = 0;
} else {
h_lut_indexes[index] = 1;
}
}
cuda_memcpy_with_size_tracking_async_to_gpu(
lut_message_extract->get_lut_indexes(0, 0), h_lut_indexes,
(num_radix_blocks + 1) * sizeof(Torus), streams.stream(0),
streams.gpu_index(0), allocate_gpu_memory);
}
int_sc_prop_memory(CudaStreams streams, int_radix_params params,
uint32_t num_radix_blocks, uint32_t requested_flag_in,
bool allocate_gpu_memory, uint64_t &size_tracker) {
gpu_memory_allocated = allocate_gpu_memory;
this->params = params;
auto glwe_dimension = params.glwe_dimension;
auto polynomial_size = params.polynomial_size;
auto message_modulus = params.message_modulus;
auto carry_modulus = params.carry_modulus;
@@ -2069,24 +2086,6 @@ template <typename Torus> struct int_sc_prop_memory {
streams, params, num_radix_blocks, grouping_size, num_groups,
allocate_gpu_memory, size_tracker);
// Step 3 elements
int num_luts_message_extract =
requested_flag == outputFlag::FLAG_NONE ? 1 : 2;
lut_message_extract = new int_radix_lut<Torus>(
streams, params, num_luts_message_extract, num_radix_blocks + 1,
allocate_gpu_memory, size_tracker);
// lut for the first block in the first grouping
auto f_message_extract = [message_modulus](Torus block) -> Torus {
return (block >> 1) % message_modulus;
};
generate_device_accumulator<Torus>(
streams.stream(0), streams.gpu_index(0),
lut_message_extract->get_lut(0, 0), lut_message_extract->get_degree(0),
lut_message_extract->get_max_degree(0), glwe_dimension, polynomial_size,
message_modulus, carry_modulus, f_message_extract,
gpu_memory_allocated);
// This store a single block that with be used to store the overflow or
// carry results
output_flag = new CudaRadixCiphertextFFI;
@@ -2137,22 +2136,30 @@ template <typename Torus> struct int_sc_prop_memory {
return output1 << 3 | output2 << 2;
};
generate_device_accumulator_bivariate<Torus>(
streams.stream(0), streams.gpu_index(0),
lut_overflow_flag_prep->get_lut(0, 0),
lut_overflow_flag_prep->get_degree(0),
lut_overflow_flag_prep->get_max_degree(0), glwe_dimension,
polynomial_size, message_modulus, carry_modulus, f_overflow_fp,
gpu_memory_allocated);
auto active_streams = streams.active_gpu_subset(1, params.pbs_type);
lut_overflow_flag_prep->broadcast_lut(active_streams);
lut_overflow_flag_prep->generate_and_broadcast_bivariate_lut(
active_streams, {0}, {f_overflow_fp}, gpu_memory_allocated);
}
// Step 3 elements
int num_luts_message_extract =
requested_flag == outputFlag::FLAG_NONE ? 1 : 2;
lut_message_extract = new int_radix_lut<Torus>(
streams, params, num_luts_message_extract, num_radix_blocks + 1,
allocate_gpu_memory, size_tracker);
// lut for the first block in the first grouping
auto f_message_extract = [message_modulus](Torus block) -> Torus {
return (block >> 1) % message_modulus;
};
auto active_streams =
streams.active_gpu_subset(num_radix_blocks + 1, params.pbs_type);
// For the final cleanup in case of overflow or carry (it seems that I can)
// It seems that this lut could be apply together with the other one but for
// now we won't do it
if (requested_flag == outputFlag::FLAG_OVERFLOW) { // Overflow case
switch (requested_flag) {
case outputFlag::FLAG_OVERFLOW: { // Overflow case
auto f_overflow_last = [num_radix_blocks,
requested_flag_in](Torus block) -> Torus {
uint32_t position = (num_radix_blocks == 1 &&
@@ -2164,62 +2171,38 @@ template <typename Torus> struct int_sc_prop_memory {
Torus does_overflow_if_carry_is_0 = (block >> 2) & 1;
if (input_carry == outputFlag::FLAG_OVERFLOW) {
return does_overflow_if_carry_is_1;
} else {
return does_overflow_if_carry_is_0;
}
return does_overflow_if_carry_is_0;
};
setup_message_extract_indices_for_carry_async(streams, num_radix_blocks,
allocate_gpu_memory);
generate_device_accumulator<Torus>(
streams.stream(0), streams.gpu_index(0),
lut_message_extract->get_lut(0, 1),
lut_message_extract->get_degree(1),
lut_message_extract->get_max_degree(1), glwe_dimension,
polynomial_size, message_modulus, carry_modulus, f_overflow_last,
lut_message_extract->generate_and_broadcast_lut(
active_streams, {0, 1}, {f_message_extract, f_overflow_last},
gpu_memory_allocated);
Torus *h_lut_indexes = lut_message_extract->h_lut_indexes;
for (int index = 0; index < num_radix_blocks + 1; index++) {
if (index < num_radix_blocks) {
h_lut_indexes[index] = 0;
} else {
h_lut_indexes[index] = 1;
}
}
cuda_memcpy_with_size_tracking_async_to_gpu(
lut_message_extract->get_lut_indexes(0, 0), h_lut_indexes,
(num_radix_blocks + 1) * sizeof(Torus), streams.stream(0),
streams.gpu_index(0), allocate_gpu_memory);
break;
}
if (requested_flag == outputFlag::FLAG_CARRY) { // Carry case
case outputFlag::FLAG_CARRY: { // Carry case
setup_message_extract_indices_for_carry_async(streams, num_radix_blocks,
allocate_gpu_memory);
auto f_carry_last = [](Torus block) -> Torus {
return ((block >> 2) & 1);
};
generate_device_accumulator<Torus>(
streams.stream(0), streams.gpu_index(0),
lut_message_extract->get_lut(0, 1),
lut_message_extract->get_degree(1),
lut_message_extract->get_max_degree(1), glwe_dimension,
polynomial_size, message_modulus, carry_modulus, f_carry_last,
lut_message_extract->generate_and_broadcast_lut(
active_streams, {0, 1}, {f_message_extract, f_carry_last},
gpu_memory_allocated);
Torus *h_lut_indexes = lut_message_extract->h_lut_indexes;
for (int index = 0; index < num_radix_blocks + 1; index++) {
if (index < num_radix_blocks) {
h_lut_indexes[index] = 0;
} else {
h_lut_indexes[index] = 1;
}
}
cuda_memcpy_with_size_tracking_async_to_gpu(
lut_message_extract->get_lut_indexes(0, 0), h_lut_indexes,
(num_radix_blocks + 1) * sizeof(Torus), streams.stream(0),
streams.gpu_index(0), allocate_gpu_memory);
break;
}
auto active_streams =
streams.active_gpu_subset(num_radix_blocks + 1, params.pbs_type);
lut_message_extract->broadcast_lut(active_streams);
default:
lut_message_extract->generate_and_broadcast_lut(
active_streams, {0}, {f_message_extract}, gpu_memory_allocated);
break;
}
// lut_message_extract->broadcast_lut(active_streams);
};
void release(CudaStreams streams) {
@@ -2517,16 +2500,11 @@ template <typename Torus> struct int_borrow_prop_memory {
return (block >> 1) % message_modulus;
};
generate_device_accumulator<Torus>(
streams.stream(0), streams.gpu_index(0),
lut_message_extract->get_lut(0, 0), lut_message_extract->get_degree(0),
lut_message_extract->get_max_degree(0), glwe_dimension, polynomial_size,
message_modulus, carry_modulus, f_message_extract,
gpu_memory_allocated);
active_streams =
streams.active_gpu_subset(num_radix_blocks, params.pbs_type);
lut_message_extract->broadcast_lut(active_streams);
lut_message_extract->generate_and_broadcast_lut(
active_streams, {0}, {f_message_extract}, gpu_memory_allocated);
if (compute_overflow) {
lut_borrow_flag =
@@ -2537,12 +2515,8 @@ template <typename Torus> struct int_borrow_prop_memory {
return ((block >> 2) & 1);
};
generate_device_accumulator<Torus>(
streams.stream(0), streams.gpu_index(0),
lut_borrow_flag->get_lut(0, 0), lut_borrow_flag->get_degree(0),
lut_borrow_flag->get_max_degree(0), glwe_dimension, polynomial_size,
message_modulus, carry_modulus, f_borrow_flag, gpu_memory_allocated);
lut_borrow_flag->broadcast_lut(active_streams);
lut_borrow_flag->generate_and_broadcast_lut(
active_streams, {0}, {f_borrow_flag}, gpu_memory_allocated);
}
active_streams =

View File

@@ -37,17 +37,14 @@ template <typename Torus> struct int_mul_memory {
zero_out_predicate_lut =
new int_radix_lut<Torus>(streams, params, 1, num_radix_blocks,
allocate_gpu_memory, size_tracker);
generate_device_accumulator_bivariate<Torus>(
streams.stream(0), streams.gpu_index(0),
zero_out_predicate_lut->get_lut(0, 0),
zero_out_predicate_lut->get_degree(0),
zero_out_predicate_lut->get_max_degree(0), params.glwe_dimension,
params.polynomial_size, params.message_modulus, params.carry_modulus,
zero_out_predicate_lut_f, gpu_memory_allocated);
auto active_streams =
streams.active_gpu_subset(num_radix_blocks, params.pbs_type);
zero_out_predicate_lut->broadcast_lut(active_streams);
zero_out_predicate_lut->generate_and_broadcast_bivariate_lut(
active_streams, {0}, {zero_out_predicate_lut_f},
gpu_memory_allocated);
// zero_out_predicate_lut->broadcast_lut(active_streams);
zero_out_mem = new int_zero_out_if_buffer<Torus>(
streams, params, num_radix_blocks, allocate_gpu_memory, size_tracker);
@@ -55,10 +52,7 @@ template <typename Torus> struct int_mul_memory {
return;
}
auto glwe_dimension = params.glwe_dimension;
auto polynomial_size = params.polynomial_size;
auto message_modulus = params.message_modulus;
auto carry_modulus = params.carry_modulus;
// 'vector_result_lsb' contains blocks from all possible shifts of
// radix_lwe_left excluding zero ciphertext blocks
@@ -102,18 +96,6 @@ template <typename Torus> struct int_mul_memory {
return (x * y) / message_modulus;
};
// generate accumulators
generate_device_accumulator_bivariate<Torus>(
streams.stream(0), streams.gpu_index(0), lsb_acc,
luts_array->get_degree(0), luts_array->get_max_degree(0),
glwe_dimension, polynomial_size, message_modulus, carry_modulus,
lut_f_lsb, gpu_memory_allocated);
generate_device_accumulator_bivariate<Torus>(
streams.stream(0), streams.gpu_index(0), msb_acc,
luts_array->get_degree(1), luts_array->get_max_degree(1),
glwe_dimension, polynomial_size, message_modulus, carry_modulus,
lut_f_msb, gpu_memory_allocated);
// lut_indexes_vec for luts_array should be reinitialized
// first lsb_vector_block_count value should reference to lsb_acc
// last msb_vector_block_count values should reference to msb_acc
@@ -123,9 +105,12 @@ template <typename Torus> struct int_mul_memory {
streams.stream(0), streams.gpu_index(0),
luts_array->get_lut_indexes(0, lsb_vector_block_count), 1,
msb_vector_block_count);
auto active_streams =
streams.active_gpu_subset(total_block_count, params.pbs_type);
luts_array->broadcast_lut(active_streams);
luts_array->generate_and_broadcast_bivariate_lut(
active_streams, {0, 1}, {lut_f_lsb, lut_f_msb}, gpu_memory_allocated);
// create memory object for sum ciphertexts
sum_ciphertexts_mem = new int_sum_ciphertexts_vec_memory<Torus>(
streams, params, num_radix_blocks, 2 * num_radix_blocks,

View File

@@ -85,15 +85,11 @@ template <typename Torus> struct int_logical_scalar_shift_buffer {
}
// right shift
generate_device_accumulator_bivariate<Torus>(
streams.stream(0), streams.gpu_index(0),
cur_lut_bivariate->get_lut(0, 0), cur_lut_bivariate->get_degree(0),
cur_lut_bivariate->get_max_degree(0), params.glwe_dimension,
params.polynomial_size, params.message_modulus, params.carry_modulus,
shift_lut_f, gpu_memory_allocated);
auto active_streams =
streams.active_gpu_subset(num_radix_blocks, params.pbs_type);
cur_lut_bivariate->broadcast_lut(active_streams);
cur_lut_bivariate->generate_and_broadcast_bivariate_lut(
active_streams, {0}, {shift_lut_f}, gpu_memory_allocated);
lut_buffers_bivariate.push_back(cur_lut_bivariate);
}
@@ -172,16 +168,10 @@ template <typename Torus> struct int_logical_scalar_shift_buffer {
}
// right shift
generate_device_accumulator_bivariate<Torus>(
streams.stream(0), streams.gpu_index(0),
cur_lut_bivariate->get_lut(0, 0), cur_lut_bivariate->get_degree(0),
cur_lut_bivariate->get_max_degree(0), params.glwe_dimension,
params.polynomial_size, params.message_modulus, params.carry_modulus,
shift_lut_f, gpu_memory_allocated);
auto active_streams =
streams.active_gpu_subset(num_radix_blocks, params.pbs_type);
cur_lut_bivariate->broadcast_lut(active_streams);
cur_lut_bivariate->generate_and_broadcast_bivariate_lut(
active_streams, {0}, {shift_lut_f}, gpu_memory_allocated);
lut_buffers_bivariate.push_back(cur_lut_bivariate);
}
}
@@ -271,16 +261,11 @@ template <typename Torus> struct int_arithmetic_scalar_shift_buffer {
return shifted | padding;
};
generate_device_accumulator<Torus>(
streams.stream(0), streams.gpu_index(0),
shift_last_block_lut_univariate->get_lut(0, 0),
shift_last_block_lut_univariate->get_degree(0),
shift_last_block_lut_univariate->get_max_degree(0),
params.glwe_dimension, params.polynomial_size, params.message_modulus,
params.carry_modulus, last_block_lut_f, gpu_memory_allocated);
auto active_streams_shift_last =
streams.active_gpu_subset(1, params.pbs_type);
shift_last_block_lut_univariate->broadcast_lut(active_streams_shift_last);
shift_last_block_lut_univariate->generate_and_broadcast_lut(
active_streams_shift_last, {0}, {last_block_lut_f},
gpu_memory_allocated);
lut_buffers_univariate.push_back(shift_last_block_lut_univariate);
}
@@ -298,15 +283,8 @@ template <typename Torus> struct int_arithmetic_scalar_shift_buffer {
return (params.message_modulus - 1) * x_sign_bit;
};
generate_device_accumulator<Torus>(
streams.stream(0), streams.gpu_index(0),
padding_block_lut_univariate->get_lut(0, 0),
padding_block_lut_univariate->get_degree(0),
padding_block_lut_univariate->get_max_degree(0), params.glwe_dimension,
params.polynomial_size, params.message_modulus, params.carry_modulus,
padding_block_lut_f, gpu_memory_allocated);
// auto active_streams = streams.active_gpu_subset(1, params.pbs_type);
padding_block_lut_univariate->broadcast_lut(active_streams);
padding_block_lut_univariate->generate_and_broadcast_lut(
active_streams, {0}, {padding_block_lut_f}, gpu_memory_allocated);
lut_buffers_univariate.push_back(padding_block_lut_univariate);
@@ -339,16 +317,11 @@ template <typename Torus> struct int_arithmetic_scalar_shift_buffer {
return message_of_current_block + carry_of_previous_block;
};
generate_device_accumulator_bivariate<Torus>(
streams.stream(0), streams.gpu_index(0),
shift_blocks_lut_bivariate->get_lut(0, 0),
shift_blocks_lut_bivariate->get_degree(0),
shift_blocks_lut_bivariate->get_max_degree(0), params.glwe_dimension,
params.polynomial_size, params.message_modulus, params.carry_modulus,
blocks_lut_f, gpu_memory_allocated);
auto active_streams_shift_blocks =
streams.active_gpu_subset(num_radix_blocks, params.pbs_type);
shift_blocks_lut_bivariate->broadcast_lut(active_streams_shift_blocks);
shift_blocks_lut_bivariate->generate_and_broadcast_bivariate_lut(
active_streams_shift_blocks, {0}, {blocks_lut_f},
gpu_memory_allocated);
lut_buffers_bivariate.push_back(shift_blocks_lut_bivariate);
}

View File

@@ -113,27 +113,21 @@ template <typename Torus> struct int_shift_and_rotate_buffer {
else
return current_bit;
};
generate_device_accumulator<Torus>(
streams.stream(0), streams.gpu_index(0), mux_lut->get_lut(0, 0),
mux_lut->get_degree(0), mux_lut->get_max_degree(0),
params.glwe_dimension, params.polynomial_size, params.message_modulus,
params.carry_modulus, mux_lut_f, gpu_memory_allocated);
;
auto active_gpu_count_mux = streams.active_gpu_subset(
bits_per_block * num_radix_blocks, params.pbs_type);
mux_lut->broadcast_lut(active_gpu_count_mux);
mux_lut->generate_and_broadcast_lut(active_gpu_count_mux, {0}, {mux_lut_f},
gpu_memory_allocated);
auto cleaning_lut_f = [params](Torus x) -> Torus {
return x % params.message_modulus;
};
generate_device_accumulator<Torus>(
streams.stream(0), streams.gpu_index(0), cleaning_lut->get_lut(0, 0),
cleaning_lut->get_degree(0), cleaning_lut->get_max_degree(0),
params.glwe_dimension, params.polynomial_size, params.message_modulus,
params.carry_modulus, cleaning_lut_f, gpu_memory_allocated);
auto active_gpu_count_cleaning =
streams.active_gpu_subset(num_radix_blocks, params.pbs_type);
cleaning_lut->broadcast_lut(active_gpu_count_cleaning);
cleaning_lut->generate_and_broadcast_lut(
active_gpu_count_cleaning, {0}, {cleaning_lut_f}, gpu_memory_allocated);
}
void release(CudaStreams streams) {

View File

@@ -74,45 +74,26 @@ template <typename Torus> struct int_overflowing_sub_memory {
luts_array, size_tracker,
allocate_gpu_memory, size_tracker);
auto lut_does_block_generate_carry = luts_array->get_lut(0, 0);
auto lut_does_block_generate_or_propagate = luts_array->get_lut(0, 1);
// generate luts (aka accumulators)
generate_device_accumulator<Torus>(
streams.stream(0), streams.gpu_index(0), lut_does_block_generate_carry,
luts_array->get_degree(0), luts_array->get_max_degree(0),
glwe_dimension, polynomial_size, message_modulus, carry_modulus,
f_lut_does_block_generate_carry, gpu_memory_allocated);
generate_device_accumulator<Torus>(
streams.stream(0), streams.gpu_index(0),
lut_does_block_generate_or_propagate, luts_array->get_degree(1),
luts_array->get_max_degree(1), glwe_dimension, polynomial_size,
message_modulus, carry_modulus, f_lut_does_block_generate_or_propagate,
gpu_memory_allocated);
if (allocate_gpu_memory)
cuda_set_value_async<Torus>(streams.stream(0), streams.gpu_index(0),
luts_array->get_lut_indexes(0, 1), 1,
num_radix_blocks - 1);
generate_device_accumulator_bivariate<Torus>(
streams.stream(0), streams.gpu_index(0),
luts_borrow_propagation_sum->get_lut(0, 0),
luts_borrow_propagation_sum->get_degree(0),
luts_borrow_propagation_sum->get_max_degree(0), glwe_dimension,
polynomial_size, message_modulus, carry_modulus,
f_luts_borrow_propagation_sum, gpu_memory_allocated);
generate_device_accumulator<Torus>(
streams.stream(0), streams.gpu_index(0), message_acc->get_lut(0, 0),
message_acc->get_degree(0), message_acc->get_max_degree(0),
glwe_dimension, polynomial_size, message_modulus, carry_modulus,
f_message_acc, gpu_memory_allocated);
auto active_streams =
streams.active_gpu_subset(num_radix_blocks, params.pbs_type);
luts_array->broadcast_lut(active_streams);
luts_borrow_propagation_sum->broadcast_lut(active_streams);
message_acc->broadcast_lut(active_streams);
luts_borrow_propagation_sum->generate_and_broadcast_bivariate_lut(
active_streams, {0}, {f_luts_borrow_propagation_sum},
gpu_memory_allocated);
luts_array->generate_and_broadcast_lut(
active_streams, {0, 1},
{f_lut_does_block_generate_carry,
f_lut_does_block_generate_or_propagate},
gpu_memory_allocated);
// generate luts (aka accumulators)
message_acc->generate_and_broadcast_lut(
active_streams, {0}, {f_message_acc}, gpu_memory_allocated);
}
void release(CudaStreams streams) {

View File

@@ -298,14 +298,10 @@ template <typename Torus> struct int_aggregate_one_hot_buffer {
int_radix_lut<Torus> *lut = new int_radix_lut<Torus>(
streams, params, 1, num_blocks, allocate_gpu_memory, size_tracker);
generate_device_accumulator<Torus>(
streams.stream(0), streams.gpu_index(0), lut->get_lut(0, 0),
lut->get_degree(0), lut->get_max_degree(0), params.glwe_dimension,
params.polynomial_size, params.message_modulus, params.carry_modulus,
id_fn, allocate_gpu_memory);
lut->generate_and_broadcast_lut(
streams.active_gpu_subset(num_blocks, params.pbs_type), {0}, {id_fn},
allocate_gpu_memory);
lut->broadcast_lut(
streams.active_gpu_subset(num_blocks, params.pbs_type));
this->stream_identity_luts[i] = lut;
}
@@ -318,27 +314,17 @@ template <typename Torus> struct int_aggregate_one_hot_buffer {
this->message_extract_lut = new int_radix_lut<Torus>(
streams, params, 1, num_blocks, allocate_gpu_memory, size_tracker);
generate_device_accumulator<Torus>(
streams.stream(0), streams.gpu_index(0),
this->message_extract_lut->get_lut(0, 0),
this->message_extract_lut->get_degree(0),
this->message_extract_lut->get_max_degree(0), params.glwe_dimension,
params.polynomial_size, params.message_modulus, params.carry_modulus,
msg_fn, allocate_gpu_memory);
this->message_extract_lut->broadcast_lut(
streams.active_gpu_subset(num_blocks, params.pbs_type));
this->message_extract_lut->generate_and_broadcast_lut(
streams.active_gpu_subset(num_blocks, params.pbs_type), {0}, {msg_fn},
allocate_gpu_memory);
this->carry_extract_lut = new int_radix_lut<Torus>(
streams, params, 1, num_blocks, allocate_gpu_memory, size_tracker);
generate_device_accumulator<Torus>(
streams.stream(0), streams.gpu_index(0),
this->carry_extract_lut->get_lut(0, 0),
this->carry_extract_lut->get_degree(0),
this->carry_extract_lut->get_max_degree(0), params.glwe_dimension,
params.polynomial_size, params.message_modulus, params.carry_modulus,
carry_fn, allocate_gpu_memory);
this->carry_extract_lut->broadcast_lut(
streams.active_gpu_subset(num_blocks, params.pbs_type));
this->carry_extract_lut->generate_and_broadcast_lut(
streams.active_gpu_subset(num_blocks, params.pbs_type), {0}, {carry_fn},
allocate_gpu_memory);
this->partial_aggregated_vectors =
new CudaRadixCiphertextFFI *[num_streams];
@@ -1185,15 +1171,9 @@ template <typename Torus> struct int_unchecked_first_index_of_clear_buffer {
this->prefix_sum_lut = new int_radix_lut<Torus>(
streams, params, 2, num_inputs, allocate_gpu_memory, size_tracker);
generate_device_accumulator_bivariate<Torus>(
streams.stream(0), streams.gpu_index(0),
this->prefix_sum_lut->get_lut(0, 0),
this->prefix_sum_lut->get_degree(0),
this->prefix_sum_lut->get_max_degree(0), params.glwe_dimension,
params.polynomial_size, params.message_modulus, params.carry_modulus,
prefix_sum_fn, allocate_gpu_memory);
this->prefix_sum_lut->broadcast_lut(
streams.active_gpu_subset(num_inputs, params.pbs_type));
this->prefix_sum_lut->generate_and_broadcast_bivariate_lut(
streams.active_gpu_subset(num_inputs, params.pbs_type), {0},
{prefix_sum_fn}, allocate_gpu_memory);
auto cleanup_fn = [ALREADY_SEEN, params](Torus x) -> Torus {
Torus val = x % params.message_modulus;
@@ -1203,14 +1183,9 @@ template <typename Torus> struct int_unchecked_first_index_of_clear_buffer {
};
this->cleanup_lut = new int_radix_lut<Torus>(
streams, params, 1, num_inputs, allocate_gpu_memory, size_tracker);
generate_device_accumulator<Torus>(
streams.stream(0), streams.gpu_index(0),
this->cleanup_lut->get_lut(0, 0), this->cleanup_lut->get_degree(0),
this->cleanup_lut->get_max_degree(0), params.glwe_dimension,
params.polynomial_size, params.message_modulus, params.carry_modulus,
cleanup_fn, allocate_gpu_memory);
this->cleanup_lut->broadcast_lut(
streams.active_gpu_subset(num_inputs, params.pbs_type));
this->cleanup_lut->generate_and_broadcast_lut(
streams.active_gpu_subset(num_inputs, params.pbs_type), {0},
{cleanup_fn}, allocate_gpu_memory);
}
void release(CudaStreams streams) {
@@ -1376,15 +1351,9 @@ template <typename Torus> struct int_unchecked_first_index_of_buffer {
this->prefix_sum_lut = new int_radix_lut<Torus>(
streams, params, 2, num_inputs, allocate_gpu_memory, size_tracker);
generate_device_accumulator_bivariate<Torus>(
streams.stream(0), streams.gpu_index(0),
this->prefix_sum_lut->get_lut(0, 0),
this->prefix_sum_lut->get_degree(0),
this->prefix_sum_lut->get_max_degree(0), params.glwe_dimension,
params.polynomial_size, params.message_modulus, params.carry_modulus,
prefix_sum_fn, allocate_gpu_memory);
this->prefix_sum_lut->broadcast_lut(
streams.active_gpu_subset(num_inputs, params.pbs_type));
this->prefix_sum_lut->generate_and_broadcast_bivariate_lut(
streams.active_gpu_subset(num_inputs, params.pbs_type), {0},
{prefix_sum_fn}, allocate_gpu_memory);
auto cleanup_fn = [ALREADY_SEEN, params](Torus x) -> Torus {
Torus val = x % params.message_modulus;
@@ -1394,14 +1363,9 @@ template <typename Torus> struct int_unchecked_first_index_of_buffer {
};
this->cleanup_lut = new int_radix_lut<Torus>(
streams, params, 1, num_inputs, allocate_gpu_memory, size_tracker);
generate_device_accumulator<Torus>(
streams.stream(0), streams.gpu_index(0),
this->cleanup_lut->get_lut(0, 0), this->cleanup_lut->get_degree(0),
this->cleanup_lut->get_max_degree(0), params.glwe_dimension,
params.polynomial_size, params.message_modulus, params.carry_modulus,
cleanup_fn, allocate_gpu_memory);
this->cleanup_lut->broadcast_lut(
streams.active_gpu_subset(num_inputs, params.pbs_type));
this->cleanup_lut->generate_and_broadcast_lut(
streams.active_gpu_subset(num_inputs, params.pbs_type), {0},
{cleanup_fn}, allocate_gpu_memory);
}
void release(CudaStreams streams) {

View File

@@ -0,0 +1,24 @@
#ifndef TRIVIUM_H
#define TRIVIUM_H
#include "../integer/integer.h"
extern "C" {
uint64_t scratch_cuda_trivium_64(
CudaStreamsFFI streams, int8_t **mem_ptr, uint32_t glwe_dimension,
uint32_t polynomial_size, uint32_t lwe_dimension, uint32_t ks_level,
uint32_t ks_base_log, uint32_t pbs_level, uint32_t pbs_base_log,
uint32_t grouping_factor, uint32_t message_modulus, uint32_t carry_modulus,
PBS_TYPE pbs_type, bool allocate_gpu_memory,
PBS_MS_REDUCTION_T noise_reduction_type, uint32_t num_inputs);
void cuda_trivium_generate_keystream_64(
CudaStreamsFFI streams, CudaRadixCiphertextFFI *keystream_output,
const CudaRadixCiphertextFFI *key, const CudaRadixCiphertextFFI *iv,
uint32_t num_inputs, uint32_t num_steps, int8_t *mem_ptr, void *const *bsks,
void *const *ksks);
void cleanup_cuda_trivium_64(CudaStreamsFFI streams, int8_t **mem_ptr_void);
}
#endif

View File

@@ -0,0 +1,295 @@
#ifndef TRIVIUM_UTILITIES_H
#define TRIVIUM_UTILITIES_H
#include "../integer/integer_utilities.h"
/// Struct to hold the LUTs.
template <typename Torus> struct int_trivium_lut_buffers {
// Bivariate AND Gate LUT:
// AND operation: f(a, b) = (a & 1) & (b & 1).
// This is a Bivariate PBS used for the non-linear parts of Trivium.
int_radix_lut<Torus> *and_lut;
// Univariate Identity LUT:
// MESSAGE EXTRACTION operation: f(x) = x & 1.
// This is a Univariate PBS used to "flush" the state: it resets the noise
// after additions and ensures the message stays within the binary message
// space.
int_radix_lut<Torus> *flush_lut;
int_trivium_lut_buffers(CudaStreams streams, const int_radix_params &params,
bool allocate_gpu_memory, uint32_t num_trivium_inputs,
uint64_t &size_tracker) {
constexpr uint32_t BATCH_SIZE = 64;
constexpr uint32_t MAX_AND_PER_STEP = 3;
uint32_t total_lut_ops = num_trivium_inputs * BATCH_SIZE * MAX_AND_PER_STEP;
this->and_lut = new int_radix_lut<Torus>(streams, params, 1, total_lut_ops,
allocate_gpu_memory, size_tracker);
std::function<Torus(Torus, Torus)> and_lambda =
[](Torus a, Torus b) -> Torus { return (a & 1) & (b & 1); };
auto active_streams_and =
streams.active_gpu_subset(total_lut_ops, params.pbs_type);
this->and_lut->generate_and_broadcast_bivariate_lut(
active_streams_and, {0}, {and_lambda}, allocate_gpu_memory);
this->and_lut->setup_gemm_batch_ks_temp_buffers(size_tracker);
uint32_t total_flush_ops = num_trivium_inputs * BATCH_SIZE * 4;
this->flush_lut = new int_radix_lut<Torus>(
streams, params, 1, total_flush_ops, allocate_gpu_memory, size_tracker);
std::function<Torus(Torus)> flush_lambda = [](Torus x) -> Torus {
return x & 1;
};
auto active_streams_flush =
streams.active_gpu_subset(total_flush_ops, params.pbs_type);
this->flush_lut->generate_and_broadcast_lut(
active_streams_flush, {0}, {flush_lambda}, allocate_gpu_memory);
this->flush_lut->setup_gemm_batch_ks_temp_buffers(size_tracker);
}
void release(CudaStreams streams) {
this->and_lut->release(streams);
delete this->and_lut;
this->and_lut = nullptr;
this->flush_lut->release(streams);
delete this->flush_lut;
this->flush_lut = nullptr;
}
};
/// Struct to hold the state and temporary workspaces required for
/// Trivium execution on the GPU.
///
/// This struct manages the memory for the internal registers (A, B, C),
/// temporary buffers used during the update function, and buffers used for
/// packing data before and after PBS.
template <typename Torus> struct int_trivium_state_workspaces {
// Trivium Internal State Registers:
// Register A: 93 bits
CudaRadixCiphertextFFI *a_reg;
// Register B: 84 bits
CudaRadixCiphertextFFI *b_reg;
// Register C: 111 bits
CudaRadixCiphertextFFI *c_reg;
// Shift Workspace:
// Used to manage bitshifting operations on the registers
CudaRadixCiphertextFFI *shift_workspace;
// Temporary Update Buffers:
// Intermediate buffers for the trivium update logic (t1, t2, t3)
CudaRadixCiphertextFFI *temp_t1;
CudaRadixCiphertextFFI *temp_t2;
CudaRadixCiphertextFFI *temp_t3;
// Buffers to hold the new values for the registers after an update step
CudaRadixCiphertextFFI *new_a;
CudaRadixCiphertextFFI *new_b;
CudaRadixCiphertextFFI *new_c;
// PBS Packing Buffers:
// Buffers for packing inputs into the bivariate lookup table (AND gate)
CudaRadixCiphertextFFI *packed_pbs_lhs;
CudaRadixCiphertextFFI *packed_pbs_rhs;
// Buffer for the output of the bivariate PBS
CudaRadixCiphertextFFI *packed_pbs_out;
// Flush/Cleanup Packing Buffers:
// Buffers for the "flush" LUT which cleans up noise after additions
CudaRadixCiphertextFFI *packed_flush_in;
CudaRadixCiphertextFFI *packed_flush_out;
int_trivium_state_workspaces(CudaStreams streams,
const int_radix_params &params,
bool allocate_gpu_memory, uint32_t num_inputs,
uint64_t &size_tracker) {
this->a_reg = new CudaRadixCiphertextFFI;
create_zero_radix_ciphertext_async<Torus>(
streams.stream(0), streams.gpu_index(0), this->a_reg, 93 * num_inputs,
params.big_lwe_dimension, size_tracker, allocate_gpu_memory);
this->b_reg = new CudaRadixCiphertextFFI;
create_zero_radix_ciphertext_async<Torus>(
streams.stream(0), streams.gpu_index(0), this->b_reg, 84 * num_inputs,
params.big_lwe_dimension, size_tracker, allocate_gpu_memory);
this->c_reg = new CudaRadixCiphertextFFI;
create_zero_radix_ciphertext_async<Torus>(
streams.stream(0), streams.gpu_index(0), this->c_reg, 111 * num_inputs,
params.big_lwe_dimension, size_tracker, allocate_gpu_memory);
this->shift_workspace = new CudaRadixCiphertextFFI;
create_zero_radix_ciphertext_async<Torus>(
streams.stream(0), streams.gpu_index(0), this->shift_workspace,
128 * num_inputs, params.big_lwe_dimension, size_tracker,
allocate_gpu_memory);
uint32_t batch_blocks = 64 * num_inputs;
this->temp_t1 = new CudaRadixCiphertextFFI;
create_zero_radix_ciphertext_async<Torus>(
streams.stream(0), streams.gpu_index(0), this->temp_t1, batch_blocks,
params.big_lwe_dimension, size_tracker, allocate_gpu_memory);
this->temp_t2 = new CudaRadixCiphertextFFI;
create_zero_radix_ciphertext_async<Torus>(
streams.stream(0), streams.gpu_index(0), this->temp_t2, batch_blocks,
params.big_lwe_dimension, size_tracker, allocate_gpu_memory);
this->temp_t3 = new CudaRadixCiphertextFFI;
create_zero_radix_ciphertext_async<Torus>(
streams.stream(0), streams.gpu_index(0), this->temp_t3, batch_blocks,
params.big_lwe_dimension, size_tracker, allocate_gpu_memory);
this->new_a = new CudaRadixCiphertextFFI;
create_zero_radix_ciphertext_async<Torus>(
streams.stream(0), streams.gpu_index(0), this->new_a, batch_blocks,
params.big_lwe_dimension, size_tracker, allocate_gpu_memory);
this->new_b = new CudaRadixCiphertextFFI;
create_zero_radix_ciphertext_async<Torus>(
streams.stream(0), streams.gpu_index(0), this->new_b, batch_blocks,
params.big_lwe_dimension, size_tracker, allocate_gpu_memory);
this->new_c = new CudaRadixCiphertextFFI;
create_zero_radix_ciphertext_async<Torus>(
streams.stream(0), streams.gpu_index(0), this->new_c, batch_blocks,
params.big_lwe_dimension, size_tracker, allocate_gpu_memory);
this->packed_pbs_lhs = new CudaRadixCiphertextFFI;
create_zero_radix_ciphertext_async<Torus>(
streams.stream(0), streams.gpu_index(0), this->packed_pbs_lhs,
3 * batch_blocks, params.big_lwe_dimension, size_tracker,
allocate_gpu_memory);
this->packed_pbs_rhs = new CudaRadixCiphertextFFI;
create_zero_radix_ciphertext_async<Torus>(
streams.stream(0), streams.gpu_index(0), this->packed_pbs_rhs,
3 * batch_blocks, params.big_lwe_dimension, size_tracker,
allocate_gpu_memory);
this->packed_pbs_out = new CudaRadixCiphertextFFI;
create_zero_radix_ciphertext_async<Torus>(
streams.stream(0), streams.gpu_index(0), this->packed_pbs_out,
3 * batch_blocks, params.big_lwe_dimension, size_tracker,
allocate_gpu_memory);
this->packed_flush_in = new CudaRadixCiphertextFFI;
create_zero_radix_ciphertext_async<Torus>(
streams.stream(0), streams.gpu_index(0), this->packed_flush_in,
4 * batch_blocks, params.big_lwe_dimension, size_tracker,
allocate_gpu_memory);
this->packed_flush_out = new CudaRadixCiphertextFFI;
create_zero_radix_ciphertext_async<Torus>(
streams.stream(0), streams.gpu_index(0), this->packed_flush_out,
4 * batch_blocks, params.big_lwe_dimension, size_tracker,
allocate_gpu_memory);
}
void release(CudaStreams streams, bool allocate_gpu_memory) {
release_radix_ciphertext_async(streams.stream(0), streams.gpu_index(0),
this->a_reg, allocate_gpu_memory);
delete this->a_reg;
release_radix_ciphertext_async(streams.stream(0), streams.gpu_index(0),
this->b_reg, allocate_gpu_memory);
delete this->b_reg;
release_radix_ciphertext_async(streams.stream(0), streams.gpu_index(0),
this->c_reg, allocate_gpu_memory);
delete this->c_reg;
release_radix_ciphertext_async(streams.stream(0), streams.gpu_index(0),
this->shift_workspace, allocate_gpu_memory);
delete this->shift_workspace;
release_radix_ciphertext_async(streams.stream(0), streams.gpu_index(0),
this->temp_t1, allocate_gpu_memory);
delete this->temp_t1;
release_radix_ciphertext_async(streams.stream(0), streams.gpu_index(0),
this->temp_t2, allocate_gpu_memory);
delete this->temp_t2;
release_radix_ciphertext_async(streams.stream(0), streams.gpu_index(0),
this->temp_t3, allocate_gpu_memory);
delete this->temp_t3;
release_radix_ciphertext_async(streams.stream(0), streams.gpu_index(0),
this->new_a, allocate_gpu_memory);
delete this->new_a;
release_radix_ciphertext_async(streams.stream(0), streams.gpu_index(0),
this->new_b, allocate_gpu_memory);
delete this->new_b;
release_radix_ciphertext_async(streams.stream(0), streams.gpu_index(0),
this->new_c, allocate_gpu_memory);
delete this->new_c;
release_radix_ciphertext_async(streams.stream(0), streams.gpu_index(0),
this->packed_pbs_lhs, allocate_gpu_memory);
delete this->packed_pbs_lhs;
release_radix_ciphertext_async(streams.stream(0), streams.gpu_index(0),
this->packed_pbs_rhs, allocate_gpu_memory);
delete this->packed_pbs_rhs;
release_radix_ciphertext_async(streams.stream(0), streams.gpu_index(0),
this->packed_pbs_out, allocate_gpu_memory);
delete this->packed_pbs_out;
release_radix_ciphertext_async(streams.stream(0), streams.gpu_index(0),
this->packed_flush_in, allocate_gpu_memory);
delete this->packed_flush_in;
release_radix_ciphertext_async(streams.stream(0), streams.gpu_index(0),
this->packed_flush_out, allocate_gpu_memory);
delete this->packed_flush_out;
}
};
template <typename Torus> struct int_trivium_buffer {
int_radix_params params;
bool allocate_gpu_memory;
uint32_t num_inputs;
int_trivium_lut_buffers<Torus> *luts;
int_trivium_state_workspaces<Torus> *state;
int_trivium_buffer(CudaStreams streams, const int_radix_params &params,
bool allocate_gpu_memory, uint32_t num_inputs,
uint64_t &size_tracker) {
this->params = params;
this->allocate_gpu_memory = allocate_gpu_memory;
this->num_inputs = num_inputs;
this->luts = new int_trivium_lut_buffers<Torus>(
streams, params, allocate_gpu_memory, num_inputs, size_tracker);
this->state = new int_trivium_state_workspaces<Torus>(
streams, params, allocate_gpu_memory, num_inputs, size_tracker);
}
void release(CudaStreams streams) {
luts->release(streams);
delete luts;
luts = nullptr;
state->release(streams, allocate_gpu_memory);
delete state;
state = nullptr;
cuda_synchronize_stream(streams.stream(0), streams.gpu_index(0));
}
};
#endif

View File

@@ -174,40 +174,6 @@ template <typename Torus> struct zk_expand_mem {
message_and_carry_extract_luts = new int_radix_lut<Torus>(
streams, params, 4, 2 * num_lwes, allocate_gpu_memory, size_tracker);
generate_device_accumulator<Torus>(
streams.stream(0), streams.gpu_index(0),
message_and_carry_extract_luts->get_lut(0, 0),
message_and_carry_extract_luts->get_degree(0),
message_and_carry_extract_luts->get_max_degree(0),
params.glwe_dimension, params.polynomial_size, params.message_modulus,
params.carry_modulus, message_extract_lut_f, gpu_memory_allocated);
generate_device_accumulator<Torus>(
streams.stream(0), streams.gpu_index(0),
message_and_carry_extract_luts->get_lut(0, 1),
message_and_carry_extract_luts->get_degree(1),
message_and_carry_extract_luts->get_max_degree(1),
params.glwe_dimension, params.polynomial_size, params.message_modulus,
params.carry_modulus, carry_extract_lut_f, gpu_memory_allocated);
generate_device_accumulator<Torus>(
streams.stream(0), streams.gpu_index(0),
message_and_carry_extract_luts->get_lut(0, 2),
message_and_carry_extract_luts->get_degree(2),
message_and_carry_extract_luts->get_max_degree(2),
params.glwe_dimension, params.polynomial_size, params.message_modulus,
params.carry_modulus, message_extract_and_sanitize_bool_lut_f,
gpu_memory_allocated);
generate_device_accumulator<Torus>(
streams.stream(0), streams.gpu_index(0),
message_and_carry_extract_luts->get_lut(0, 3),
message_and_carry_extract_luts->get_degree(3),
message_and_carry_extract_luts->get_max_degree(3),
params.glwe_dimension, params.polynomial_size, params.message_modulus,
params.carry_modulus, carry_extract_and_sanitize_bool_lut_f,
gpu_memory_allocated);
// We are always packing two LWEs. We just need to be sure we have enough
// space in the carry part to store a message of the same size as is in the
// message part.
@@ -292,7 +258,13 @@ template <typename Torus> struct zk_expand_mem {
auto active_streams =
streams.active_gpu_subset(2 * num_lwes, params.pbs_type);
message_and_carry_extract_luts->broadcast_lut(active_streams);
message_and_carry_extract_luts->generate_and_broadcast_lut(
active_streams, {0, 1, 2, 3},
{message_extract_lut_f, carry_extract_lut_f,
message_extract_and_sanitize_bool_lut_f,
carry_extract_and_sanitize_bool_lut_f},
gpu_memory_allocated);
message_and_carry_extract_luts->allocate_lwe_vector_for_non_trivial_indexes(
active_streams, 2 * num_lwes, size_tracker, allocate_gpu_memory);

View File

@@ -1067,6 +1067,85 @@ void generate_device_accumulator_bivariate(
POP_RANGE()
}
template <typename Torus> struct int_lut_cache {
int_lut_cache() {}
Torus *get_cached_univariate_lut(std::function<Torus(Torus)> &f, uint64_t *degree,
uint64_t *max_degree, uint32_t glwe_dimension,
uint32_t polynomial_size,
uint32_t input_message_modulus,
uint32_t input_carry_modulus,
uint32_t output_message_modulus,
uint32_t output_carry_modulus) {
/*__int128_t f_hash = 0;
uint32_t bits_per_lut_val = 5;
uint32_t input_modulus_sup = input_message_modulus * input_carry_modulus;
for (uint32_t i = 0; i < input_modulus_sup; ++i) {
Torus f_eval = f(i);
GPU_ASSERT(f_eval < (1 << bits_per_lut_val),
"LUT value expected bitwidth overflow");
f_hash |= f_eval;
f_hash <<= bits_per_lut_val;
}
std::lock_guard cache_lock(_mutex);
if (_lut_cache.find(f_hash) != _lut_cache.end()) {
lut_ptr &ptr = _lut_cache[f_hash];
GPU_ASSERT(ptr.output_message_modulus == output_message_modulus,
"Error modulus");
GPU_ASSERT(ptr.input_message_modulus == input_message_modulus,
"Error modulus");
GPU_ASSERT(ptr.glwe_dimension == glwe_dimension, "Error modulus");
*max_degree = ptr.max_degree;
*degree = ptr.degree;
return ptr.ptr;
}*/
// host lut
Torus *h_lut =
(Torus *)malloc((glwe_dimension + 1) * polynomial_size * sizeof(Torus));
*max_degree = input_message_modulus * input_carry_modulus - 1;
*degree = generate_lookup_table_with_encoding<Torus>(
h_lut, glwe_dimension, polynomial_size, input_message_modulus,
input_carry_modulus, output_message_modulus, output_carry_modulus, f);
/*lut_ptr new_ptr = {h_lut,
glwe_dimension,
input_message_modulus,
input_carry_modulus,
output_message_modulus,
output_carry_modulus,
*max_degree,
*degree};*/
//_lut_cache[f_hash] = new_ptr;
return h_lut;
}
~int_lut_cache() {
std::lock_guard cache_lock(_mutex);
for (auto v : _lut_cache) {
free(v.second.ptr);
}
_lut_cache.clear();
}
private:
struct lut_ptr {
Torus *ptr;
uint32_t glwe_dimension;
uint32_t input_message_modulus;
uint32_t input_carry_modulus;
uint32_t output_message_modulus;
uint32_t output_carry_modulus;
uint64_t max_degree;
uint64_t degree;
};
std::map<__int128_t, lut_ptr> _lut_cache;
std::mutex _mutex;
};
static int_lut_cache<uint64_t> g_LutCache64;
/*
* generate bivariate accumulator with factor scaling for device pointer
* v_stream - cuda stream
@@ -1098,8 +1177,8 @@ void generate_device_accumulator_bivariate_with_factor(
(glwe_dimension + 1) * polynomial_size * sizeof(Torus), stream, gpu_index,
gpu_memory_allocated);
cuda_synchronize_stream(stream, gpu_index);
free(h_lut);
// cuda_synchronize_stream(stream, gpu_index);
// free(h_lut);
}
/*
* generate bivariate accumulator for device pointer
@@ -1145,23 +1224,36 @@ void generate_device_accumulator_with_encoding(
uint32_t output_message_modulus, uint32_t output_carry_modulus,
std::function<Torus(Torus)> f, bool gpu_memory_allocated) {
static constexpr auto is_u64 = std::is_same_v<Torus, uint64_t>;
Torus *h_lut = nullptr;
// host lut
Torus *h_lut =
(Torus *)malloc((glwe_dimension + 1) * polynomial_size * sizeof(Torus));
*max_degree = input_message_modulus * input_carry_modulus - 1;
// fill accumulator
*degree = generate_lookup_table_with_encoding<Torus>(
h_lut, glwe_dimension, polynomial_size, input_message_modulus,
input_carry_modulus, output_message_modulus, output_carry_modulus, f);
if constexpr (is_u64) {
h_lut = g_LutCache64.get_cached_univariate_lut(
f, degree, max_degree, glwe_dimension, polynomial_size,
input_message_modulus, input_carry_modulus, output_message_modulus,
output_carry_modulus);
} else {
h_lut =
(Torus *)malloc((glwe_dimension + 1) * polynomial_size * sizeof(Torus));
*max_degree = input_message_modulus * input_carry_modulus - 1;
// fill accumulator
*degree = generate_lookup_table_with_encoding<Torus>(
h_lut, glwe_dimension, polynomial_size, input_message_modulus,
input_carry_modulus, output_message_modulus, output_carry_modulus, f);
}
/*
// copy host lut and lut_indexes_vec to device
cuda_memcpy_with_size_tracking_async_to_gpu(
acc, h_lut, (glwe_dimension + 1) * polynomial_size * sizeof(Torus),
stream, gpu_index, gpu_memory_allocated);
cuda_synchronize_stream(stream, gpu_index);
free(h_lut);
*/
if (!std::is_same_v<Torus, uint64_t>) {
cuda_synchronize_stream(stream, gpu_index);
free(h_lut);
}
}
template <typename Torus>
void generate_device_accumulator_with_encoding_with_cpu_prealloc(
cudaStream_t stream, uint32_t gpu_index, Torus *acc, uint64_t *degree,
@@ -1264,8 +1356,8 @@ void generate_many_lut_device_accumulator(
acc, h_lut, (glwe_dimension + 1) * polynomial_size * sizeof(Torus),
stream, gpu_index, gpu_memory_allocated);
cuda_synchronize_stream(stream, gpu_index);
free(h_lut);
//cuda_synchronize_stream(stream, gpu_index);
//free(h_lut);
POP_RANGE()
}

View File

@@ -0,0 +1,45 @@
#include "../../include/trivium/trivium.h"
#include "trivium.cuh"
uint64_t scratch_cuda_trivium_64(
CudaStreamsFFI streams, int8_t **mem_ptr, uint32_t glwe_dimension,
uint32_t polynomial_size, uint32_t lwe_dimension, uint32_t ks_level,
uint32_t ks_base_log, uint32_t pbs_level, uint32_t pbs_base_log,
uint32_t grouping_factor, uint32_t message_modulus, uint32_t carry_modulus,
PBS_TYPE pbs_type, bool allocate_gpu_memory,
PBS_MS_REDUCTION_T noise_reduction_type, uint32_t num_inputs) {
int_radix_params params(pbs_type, glwe_dimension, polynomial_size,
glwe_dimension * polynomial_size, lwe_dimension,
ks_level, ks_base_log, pbs_level, pbs_base_log,
grouping_factor, message_modulus, carry_modulus,
noise_reduction_type);
return scratch_cuda_trivium_encrypt<uint64_t>(
CudaStreams(streams), (int_trivium_buffer<uint64_t> **)mem_ptr, params,
allocate_gpu_memory, num_inputs);
}
void cuda_trivium_generate_keystream_64(
CudaStreamsFFI streams, CudaRadixCiphertextFFI *keystream_output,
const CudaRadixCiphertextFFI *key, const CudaRadixCiphertextFFI *iv,
uint32_t num_inputs, uint32_t num_steps, int8_t *mem_ptr, void *const *bsks,
void *const *ksks) {
auto buffer = (int_trivium_buffer<uint64_t> *)mem_ptr;
host_trivium_generate_keystream<uint64_t>(
CudaStreams(streams), keystream_output, key, iv, num_inputs, num_steps,
buffer, bsks, (uint64_t *const *)ksks);
}
void cleanup_cuda_trivium_64(CudaStreamsFFI streams, int8_t **mem_ptr_void) {
int_trivium_buffer<uint64_t> *mem_ptr =
(int_trivium_buffer<uint64_t> *)(*mem_ptr_void);
mem_ptr->release(CudaStreams(streams));
delete mem_ptr;
*mem_ptr_void = nullptr;
}

View File

@@ -0,0 +1,341 @@
#ifndef TRIVIUM_CUH
#define TRIVIUM_CUH
#include "../../include/trivium/trivium_utilities.h"
#include "../integer/integer.cuh"
#include "../integer/radix_ciphertext.cuh"
#include "../integer/scalar_addition.cuh"
#include "../linearalgebra/addition.cuh"
// Reverses the order of bits (blocks) in a ciphertext buffer.
// Used to align the input Key/IV with the internal state format if needed.
template <typename Torus>
void reverse_bitsliced_radix_inplace(CudaStreams streams,
int_trivium_buffer<Torus> *mem,
CudaRadixCiphertextFFI *radix,
uint32_t num_bits_in_reg) {
uint32_t N = mem->num_inputs;
CudaRadixCiphertextFFI *temp = mem->state->shift_workspace;
for (uint32_t i = 0; i < num_bits_in_reg; i++) {
uint32_t src_start = i * N;
uint32_t src_end = (i + 1) * N;
uint32_t dest_start = (num_bits_in_reg - 1 - i) * N;
uint32_t dest_end = (num_bits_in_reg - i) * N;
copy_radix_ciphertext_slice_async<Torus>(
streams.stream(0), streams.gpu_index(0), temp, dest_start, dest_end,
radix, src_start, src_end);
}
copy_radix_ciphertext_slice_async<Torus>(
streams.stream(0), streams.gpu_index(0), radix, 0, num_bits_in_reg * N,
temp, 0, num_bits_in_reg * N);
}
// Creates a slice of specific bits in a register without copying data.
template <typename Torus>
__host__ void slice_reg_batch(CudaRadixCiphertextFFI *slice,
const CudaRadixCiphertextFFI *reg,
uint32_t start_bit_idx, uint32_t num_bits,
uint32_t num_inputs) {
as_radix_ciphertext_slice<Torus>(slice, reg, start_bit_idx * num_inputs,
(start_bit_idx + num_bits) * num_inputs);
}
// Handles the shift-register update: discards old bits, shifts the rest,
// and inserts the newly computed bits at the beginning.
template <typename Torus>
__host__ void shift_and_insert_batch(CudaStreams streams,
int_trivium_buffer<Torus> *mem,
CudaRadixCiphertextFFI *reg,
CudaRadixCiphertextFFI *new_bits,
uint32_t reg_size, uint32_t num_inputs) {
constexpr uint32_t BATCH = 64;
CudaRadixCiphertextFFI *temp = mem->state->shift_workspace;
uint32_t num_blocks_to_keep = (reg_size - BATCH) * num_inputs;
copy_radix_ciphertext_slice_async<Torus>(
streams.stream(0), streams.gpu_index(0), temp, 0, BATCH * num_inputs,
new_bits, 0, BATCH * num_inputs);
copy_radix_ciphertext_slice_async<Torus>(
streams.stream(0), streams.gpu_index(0), temp, BATCH * num_inputs,
reg_size * num_inputs, reg, 0, num_blocks_to_keep);
copy_radix_ciphertext_slice_async<Torus>(
streams.stream(0), streams.gpu_index(0), reg, 0, reg_size * num_inputs,
temp, 0, reg_size * num_inputs);
}
// core logic: computes 64 parallel updates for the state registers.
// It performs the XORs (additions) and the AND gates (using Bivariate PBS),
// then updates the registers and writes to output if needed.
template <typename Torus>
__host__ void
trivium_compute_64_steps(CudaStreams streams, int_trivium_buffer<Torus> *mem,
CudaRadixCiphertextFFI *output_dest, void *const *bsks,
uint64_t *const *ksks) {
uint32_t N = mem->num_inputs;
constexpr uint32_t BATCH = 64;
uint32_t batch_size_blocks = BATCH * N;
auto s = mem->state;
CudaRadixCiphertextFFI a65_slice, a92_slice, a91_slice, a90_slice, a68_slice;
slice_reg_batch<Torus>(&a65_slice, s->a_reg, 2, BATCH, N);
slice_reg_batch<Torus>(&a92_slice, s->a_reg, 29, BATCH, N);
slice_reg_batch<Torus>(&a91_slice, s->a_reg, 28, BATCH, N);
slice_reg_batch<Torus>(&a90_slice, s->a_reg, 27, BATCH, N);
slice_reg_batch<Torus>(&a68_slice, s->a_reg, 5, BATCH, N);
CudaRadixCiphertextFFI b68_slice, b83_slice, b82_slice, b81_slice, b77_slice;
slice_reg_batch<Torus>(&b68_slice, s->b_reg, 5, BATCH, N);
slice_reg_batch<Torus>(&b83_slice, s->b_reg, 20, BATCH, N);
slice_reg_batch<Torus>(&b82_slice, s->b_reg, 19, BATCH, N);
slice_reg_batch<Torus>(&b81_slice, s->b_reg, 18, BATCH, N);
slice_reg_batch<Torus>(&b77_slice, s->b_reg, 14, BATCH, N);
CudaRadixCiphertextFFI c65_slice, c110_slice, c109_slice, c108_slice,
c86_slice;
slice_reg_batch<Torus>(&c65_slice, s->c_reg, 2, BATCH, N);
slice_reg_batch<Torus>(&c110_slice, s->c_reg, 47, BATCH, N);
slice_reg_batch<Torus>(&c109_slice, s->c_reg, 46, BATCH, N);
slice_reg_batch<Torus>(&c108_slice, s->c_reg, 45, BATCH, N);
slice_reg_batch<Torus>(&c86_slice, s->c_reg, 23, BATCH, N);
// t1 = a66 + a93
host_addition<Torus>(streams.stream(0), streams.gpu_index(0), s->temp_t1,
&a65_slice, &a92_slice, s->temp_t1->num_radix_blocks,
mem->params.message_modulus, mem->params.carry_modulus);
// t2 = b69 + b84
host_addition<Torus>(streams.stream(0), streams.gpu_index(0), s->temp_t2,
&b68_slice, &b83_slice, s->temp_t2->num_radix_blocks,
mem->params.message_modulus, mem->params.carry_modulus);
// t3 = c66 + c111
host_addition<Torus>(streams.stream(0), streams.gpu_index(0), s->temp_t3,
&c65_slice, &c110_slice, s->temp_t3->num_radix_blocks,
mem->params.message_modulus, mem->params.carry_modulus);
copy_radix_ciphertext_slice_async<Torus>(
streams.stream(0), streams.gpu_index(0), s->packed_pbs_lhs, 0,
batch_size_blocks, &c109_slice, 0, batch_size_blocks);
copy_radix_ciphertext_slice_async<Torus>(
streams.stream(0), streams.gpu_index(0), s->packed_pbs_lhs,
batch_size_blocks, 2 * batch_size_blocks, &a91_slice, 0,
batch_size_blocks);
copy_radix_ciphertext_slice_async<Torus>(
streams.stream(0), streams.gpu_index(0), s->packed_pbs_lhs,
2 * batch_size_blocks, 3 * batch_size_blocks, &b82_slice, 0,
batch_size_blocks);
copy_radix_ciphertext_slice_async<Torus>(
streams.stream(0), streams.gpu_index(0), s->packed_pbs_rhs, 0,
batch_size_blocks, &c108_slice, 0, batch_size_blocks);
copy_radix_ciphertext_slice_async<Torus>(
streams.stream(0), streams.gpu_index(0), s->packed_pbs_rhs,
batch_size_blocks, 2 * batch_size_blocks, &a90_slice, 0,
batch_size_blocks);
copy_radix_ciphertext_slice_async<Torus>(
streams.stream(0), streams.gpu_index(0), s->packed_pbs_rhs,
2 * batch_size_blocks, 3 * batch_size_blocks, &b81_slice, 0,
batch_size_blocks);
integer_radix_apply_bivariate_lookup_table<Torus>(
streams, s->packed_pbs_out, s->packed_pbs_lhs, s->packed_pbs_rhs, bsks,
ksks, mem->luts->and_lut, 3 * batch_size_blocks,
mem->params.message_modulus);
CudaRadixCiphertextFFI and_res_a, and_res_b, and_res_c;
as_radix_ciphertext_slice<Torus>(&and_res_a, s->packed_pbs_out, 0,
batch_size_blocks);
as_radix_ciphertext_slice<Torus>(&and_res_b, s->packed_pbs_out,
batch_size_blocks, 2 * batch_size_blocks);
as_radix_ciphertext_slice<Torus>(&and_res_c, s->packed_pbs_out,
2 * batch_size_blocks,
3 * batch_size_blocks);
// a = t3 + a69 + and_res_a
host_addition<Torus>(streams.stream(0), streams.gpu_index(0), s->new_a,
s->temp_t3, &a68_slice, s->new_a->num_radix_blocks,
mem->params.message_modulus, mem->params.carry_modulus);
host_addition<Torus>(streams.stream(0), streams.gpu_index(0), s->new_a,
s->new_a, &and_res_a, s->new_a->num_radix_blocks,
mem->params.message_modulus, mem->params.carry_modulus);
// b = t1 + b78 + and_res_b
host_addition<Torus>(streams.stream(0), streams.gpu_index(0), s->new_b,
s->temp_t1, &b77_slice, s->new_b->num_radix_blocks,
mem->params.message_modulus, mem->params.carry_modulus);
host_addition<Torus>(streams.stream(0), streams.gpu_index(0), s->new_b,
s->new_b, &and_res_b, s->new_b->num_radix_blocks,
mem->params.message_modulus, mem->params.carry_modulus);
// c = t2 + c87 + and_res_c
host_addition<Torus>(streams.stream(0), streams.gpu_index(0), s->new_c,
s->temp_t2, &c86_slice, s->new_c->num_radix_blocks,
mem->params.message_modulus, mem->params.carry_modulus);
host_addition<Torus>(streams.stream(0), streams.gpu_index(0), s->new_c,
s->new_c, &and_res_c, s->new_c->num_radix_blocks,
mem->params.message_modulus, mem->params.carry_modulus);
if (output_dest != nullptr) {
// z = t1 + t2 + t3
host_addition<Torus>(streams.stream(0), streams.gpu_index(0), output_dest,
s->temp_t1, s->temp_t2, output_dest->num_radix_blocks,
mem->params.message_modulus,
mem->params.carry_modulus);
host_addition<Torus>(streams.stream(0), streams.gpu_index(0), output_dest,
output_dest, s->temp_t3, output_dest->num_radix_blocks,
mem->params.message_modulus,
mem->params.carry_modulus);
}
copy_radix_ciphertext_slice_async<Torus>(
streams.stream(0), streams.gpu_index(0), s->packed_flush_in, 0,
batch_size_blocks, s->new_a, 0, batch_size_blocks);
copy_radix_ciphertext_slice_async<Torus>(
streams.stream(0), streams.gpu_index(0), s->packed_flush_in,
batch_size_blocks, 2 * batch_size_blocks, s->new_b, 0, batch_size_blocks);
copy_radix_ciphertext_slice_async<Torus>(
streams.stream(0), streams.gpu_index(0), s->packed_flush_in,
2 * batch_size_blocks, 3 * batch_size_blocks, s->new_c, 0,
batch_size_blocks);
uint32_t total_flush_blocks = 3 * batch_size_blocks;
if (output_dest != nullptr) {
copy_radix_ciphertext_slice_async<Torus>(
streams.stream(0), streams.gpu_index(0), s->packed_flush_in,
3 * batch_size_blocks, 4 * batch_size_blocks, output_dest, 0,
batch_size_blocks);
total_flush_blocks += batch_size_blocks;
}
integer_radix_apply_univariate_lookup_table<Torus>(
streams, s->packed_flush_out, s->packed_flush_in, bsks, ksks,
mem->luts->flush_lut, total_flush_blocks);
CudaRadixCiphertextFFI flushed_a, flushed_b, flushed_c;
as_radix_ciphertext_slice<Torus>(&flushed_a, s->packed_flush_out, 0,
batch_size_blocks);
as_radix_ciphertext_slice<Torus>(&flushed_b, s->packed_flush_out,
batch_size_blocks, 2 * batch_size_blocks);
as_radix_ciphertext_slice<Torus>(&flushed_c, s->packed_flush_out,
2 * batch_size_blocks,
3 * batch_size_blocks);
shift_and_insert_batch(streams, mem, s->a_reg, &flushed_a, 93, N);
shift_and_insert_batch(streams, mem, s->b_reg, &flushed_b, 84, N);
shift_and_insert_batch(streams, mem, s->c_reg, &flushed_c, 111, N);
if (output_dest != nullptr) {
CudaRadixCiphertextFFI flushed_out;
as_radix_ciphertext_slice<Torus>(&flushed_out, s->packed_flush_out,
3 * batch_size_blocks,
4 * batch_size_blocks);
copy_radix_ciphertext_slice_async<Torus>(
streams.stream(0), streams.gpu_index(0), output_dest, 0,
batch_size_blocks, &flushed_out, 0, batch_size_blocks);
reverse_bitsliced_radix_inplace<Torus>(streams, mem, output_dest, 64);
}
}
// Sets up the initial state: loads Key and IV, fixes constants,
// and runs the warm-up phase (1152 steps).
template <typename Torus>
__host__ void trivium_init(CudaStreams streams, int_trivium_buffer<Torus> *mem,
CudaRadixCiphertextFFI const *key_bitsliced,
CudaRadixCiphertextFFI const *iv_bitsliced,
void *const *bsks, uint64_t *const *ksks) {
uint32_t N = mem->num_inputs;
auto s = mem->state;
CudaRadixCiphertextFFI src_key_slice;
slice_reg_batch<Torus>(&src_key_slice, key_bitsliced, 0, 80, N);
CudaRadixCiphertextFFI dest_a_slice;
slice_reg_batch<Torus>(&dest_a_slice, s->a_reg, 0, 80, N);
copy_radix_ciphertext_async<Torus>(streams.stream(0), streams.gpu_index(0),
&dest_a_slice, &src_key_slice);
reverse_bitsliced_radix_inplace<Torus>(streams, mem, s->a_reg, 80);
CudaRadixCiphertextFFI src_iv_slice;
slice_reg_batch<Torus>(&src_iv_slice, iv_bitsliced, 0, 80, N);
CudaRadixCiphertextFFI dest_b_slice;
slice_reg_batch<Torus>(&dest_b_slice, s->b_reg, 0, 80, N);
copy_radix_ciphertext_async<Torus>(streams.stream(0), streams.gpu_index(0),
&dest_b_slice, &src_iv_slice);
reverse_bitsliced_radix_inplace<Torus>(streams, mem, s->b_reg, 80);
CudaRadixCiphertextFFI dest_c_ones;
slice_reg_batch<Torus>(&dest_c_ones, s->c_reg, 108, 3, N);
host_add_scalar_one_inplace<Torus>(streams, &dest_c_ones,
mem->params.message_modulus,
mem->params.carry_modulus);
integer_radix_apply_univariate_lookup_table<Torus>(
streams, &dest_c_ones, &dest_c_ones, bsks, ksks, mem->luts->flush_lut,
dest_c_ones.num_radix_blocks);
for (int i = 0; i < 18; i++) {
trivium_compute_64_steps(streams, mem, nullptr, bsks, ksks);
}
}
// Main entry point: checks input validity, initializes state,
// and loops to generate the keystream in batches of 64.
template <typename Torus>
__host__ void host_trivium_generate_keystream(
CudaStreams streams, CudaRadixCiphertextFFI *keystream_output,
CudaRadixCiphertextFFI const *key_bitsliced,
CudaRadixCiphertextFFI const *iv_bitsliced, uint32_t num_inputs,
uint32_t num_steps, int_trivium_buffer<Torus> *mem, void *const *bsks,
uint64_t *const *ksks) {
PANIC_IF_FALSE(num_steps % 64 == 0,
"Trivium Error: num_steps must be a multiple of 64.\n");
trivium_init(streams, mem, key_bitsliced, iv_bitsliced, bsks, ksks);
uint32_t num_batches = num_steps / 64;
for (uint32_t i = 0; i < num_batches; i++) {
CudaRadixCiphertextFFI batch_out_slice;
slice_reg_batch<Torus>(&batch_out_slice, keystream_output, i * 64, 64,
num_inputs);
trivium_compute_64_steps(streams, mem, &batch_out_slice, bsks, ksks);
}
}
template <typename Torus>
uint64_t scratch_cuda_trivium_encrypt(CudaStreams streams,
int_trivium_buffer<Torus> **mem_ptr,
int_radix_params params,
bool allocate_gpu_memory,
uint32_t num_inputs) {
uint64_t size_tracker = 0;
*mem_ptr = new int_trivium_buffer<Torus>(streams, params, allocate_gpu_memory,
num_inputs, size_tracker);
return size_tracker;
}
#endif

View File

@@ -2511,6 +2511,42 @@ unsafe extern "C" {
mem_ptr_void: *mut *mut i8,
);
}
unsafe extern "C" {
pub fn scratch_cuda_trivium_64(
streams: CudaStreamsFFI,
mem_ptr: *mut *mut i8,
glwe_dimension: u32,
polynomial_size: u32,
lwe_dimension: u32,
ks_level: u32,
ks_base_log: u32,
pbs_level: u32,
pbs_base_log: u32,
grouping_factor: u32,
message_modulus: u32,
carry_modulus: u32,
pbs_type: PBS_TYPE,
allocate_gpu_memory: bool,
noise_reduction_type: PBS_MS_REDUCTION_T,
num_inputs: u32,
) -> u64;
}
unsafe extern "C" {
pub fn cuda_trivium_generate_keystream_64(
streams: CudaStreamsFFI,
keystream_output: *mut CudaRadixCiphertextFFI,
key: *const CudaRadixCiphertextFFI,
iv: *const CudaRadixCiphertextFFI,
num_inputs: u32,
num_steps: u32,
mem_ptr: *mut i8,
bsks: *const *mut ffi::c_void,
ksks: *const *mut ffi::c_void,
);
}
unsafe extern "C" {
pub fn cleanup_cuda_trivium_64(streams: CudaStreamsFFI, mem_ptr_void: *mut *mut i8);
}
pub const KS_TYPE_BIG_TO_SMALL: KS_TYPE = 0;
pub const KS_TYPE_SMALL_TO_BIG: KS_TYPE = 1;
pub type KS_TYPE = ffi::c_uint;

View File

@@ -4,6 +4,7 @@
#include "cuda/include/integer/integer.h"
#include "cuda/include/integer/rerand.h"
#include "cuda/include/aes/aes.h"
#include "cuda/include/trivium/trivium.h"
#include "cuda/include/zk/zk.h"
#include "cuda/include/keyswitch/keyswitch.h"
#include "cuda/include/keyswitch/ks_enums.h"

View File

@@ -133,6 +133,12 @@ path = "benches/integer/aes.rs"
harness = false
required-features = ["integer", "internal-keycache"]
[[bench]]
name = "integer-trivium"
path = "benches/integer/trivium.rs"
harness = false
required-features = ["integer", "internal-keycache"]
[[bench]]
name = "integer-aes256"
path = "benches/integer/aes256.rs"

View File

@@ -3,6 +3,7 @@
mod aes;
mod aes256;
mod oprf;
mod trivium;
mod vector_find;
mod rerand;

View File

@@ -0,0 +1,89 @@
use criterion::Criterion;
#[cfg(feature = "gpu")]
pub mod cuda {
use benchmark::params_aliases::BENCH_PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128;
use benchmark::utilities::{write_to_json, OperatorType};
use criterion::{black_box, criterion_group, Criterion};
use tfhe::core_crypto::gpu::CudaStreams;
use tfhe::integer::gpu::ciphertext::CudaUnsignedRadixCiphertext;
use tfhe::integer::gpu::CudaServerKey;
use tfhe::integer::keycache::KEY_CACHE;
use tfhe::integer::{IntegerKeyKind, RadixCiphertext, RadixClientKey};
use tfhe::keycache::NamedParam;
use tfhe::shortint::AtomicPatternParameters;
fn encrypt_bit_stream(cks: &RadixClientKey, bits: &[u64]) -> RadixCiphertext {
RadixCiphertext::from(
bits.iter()
.map(|&bit| cks.encrypt_one_block(bit))
.collect::<Vec<_>>(),
)
}
pub fn cuda_trivium(c: &mut Criterion) {
let bench_name = "integer::cuda::trivium";
let mut bench_group = c.benchmark_group(bench_name);
bench_group
.sample_size(15)
.measurement_time(std::time::Duration::from_secs(60))
.warm_up_time(std::time::Duration::from_secs(5));
let param = BENCH_PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128;
let atomic_param: AtomicPatternParameters = param.into();
let key_bits = vec![0u64; 80];
let iv_bits = vec![0u64; 80];
let param_name = param.name();
let streams = CudaStreams::new_multi_gpu();
let (cpu_cks, _) = KEY_CACHE.get_from_params(atomic_param, IntegerKeyKind::Radix);
let sks = CudaServerKey::new(&cpu_cks, &streams);
let cks = RadixClientKey::from((cpu_cks, 1));
let ct_key = encrypt_bit_stream(&cks, &key_bits);
let ct_iv = encrypt_bit_stream(&cks, &iv_bits);
let d_key = CudaUnsignedRadixCiphertext::from_radix_ciphertext(&ct_key, &streams);
let d_iv = CudaUnsignedRadixCiphertext::from_radix_ciphertext(&ct_iv, &streams);
for num_steps in [64, 512] {
let bench_id = format!("{bench_name}::{param_name}::generate_{num_steps}_bits");
bench_group.bench_function(&bench_id, |b| {
b.iter(|| {
black_box(
sks.trivium_generate_keystream(&d_key, &d_iv, num_steps, &streams)
.unwrap(),
);
})
});
write_to_json::<u64, _>(
&bench_id,
atomic_param,
param.name(),
&format!("trivium_generation_{}_bits", num_steps),
&OperatorType::Atomic,
80,
vec![atomic_param.message_modulus().0.ilog2(); 80],
);
}
bench_group.finish();
}
criterion_group!(gpu_trivium, cuda_trivium);
}
#[cfg(feature = "gpu")]
use cuda::gpu_trivium;
fn main() {
#[cfg(feature = "gpu")]
gpu_trivium();
Criterion::default().configure_from_args().final_summary();
}

File diff suppressed because one or more lines are too long

After

Width:  |  Height:  |  Size: 181 KiB

View File

@@ -19,11 +19,13 @@ The overall process to write an homomorphic program is the same for all types. T
This library has different modules, with different levels of abstraction.
There is the **core\_crypto** module, which is the lowest level API with the primitive functions and types of the TFHE scheme.
There is the [core\_crypto](../core-crypto-api/presentation.md) module, which is the lowest level API with the primitive functions and types of the TFHE scheme.
Above the core\_crypto module, there are the **Boolean**, **shortint**, and **integer** modules, which contain easy to use APIs enabling evaluation of Boolean, short integer, and integer circuits.
Above the core\_crypto module, there are the [Boolean](boolean/README.md), [shortint](shortint/README.md), and [integer](integer/README.md) modules, which contain easy to use APIs enabling evaluation of Boolean, short integer, and integer circuits.
Finally, there is the high-level module built on top of the Boolean, shortint, integer modules. This module is meant to abstract cryptographic complexities: no cryptographical knowledge is required to start developing an FHE application. Another benefit of the high-level module is the drastically simplified development process compared to lower level modules.
Finally, there is the high-level module built on top of the shortint and integer modules. This module is meant to abstract cryptographic complexities: no cryptographical knowledge is required to start developing an FHE application. Another benefit of the high-level module is the drastically simplified development process compared to lower level modules.
![API levels diagram](../../.gitbook/assets/api-levels.svg)
#### high-level API

View File

@@ -1084,18 +1084,10 @@ pub fn multi_bit_programmable_bootstrap_lwe_ciphertext<
);
assert_eq!(
input.ciphertext_modulus(),
output.ciphertext_modulus(),
"Mismatched CiphertextModulus between input ({:?}) and output ({:?})",
input.ciphertext_modulus(),
output.ciphertext_modulus(),
);
assert_eq!(
input.ciphertext_modulus(),
accumulator.ciphertext_modulus(),
"Mismatched CiphertextModulus between input ({:?}) and accumulator ({:?})",
input.ciphertext_modulus(),
"Mismatched CiphertextModulus between output ({:?}) and accumulator ({:?})",
output.ciphertext_modulus(),
accumulator.ciphertext_modulus(),
);
@@ -1795,18 +1787,10 @@ pub fn std_multi_bit_programmable_bootstrap_lwe_ciphertext<
);
assert_eq!(
input.ciphertext_modulus(),
output.ciphertext_modulus(),
"Mismatched CiphertextModulus between input ({:?}) and output ({:?})",
input.ciphertext_modulus(),
output.ciphertext_modulus(),
);
assert_eq!(
input.ciphertext_modulus(),
accumulator.ciphertext_modulus(),
"Mismatched CiphertextModulus between input ({:?}) and accumulator ({:?})",
input.ciphertext_modulus(),
"Mismatched CiphertextModulus between output ({:?}) and accumulator ({:?})",
output.ciphertext_modulus(),
accumulator.ciphertext_modulus(),
);
@@ -1868,14 +1852,6 @@ pub fn std_multi_bit_f128_blind_rotate_assign<Scalar, InputCont, OutputCont, Key
multi_bit_bsk.input_lwe_dimension(),
);
assert_eq!(
input.ciphertext_modulus(),
accumulator.ciphertext_modulus(),
"Mismatched CiphertextModulus between input ({:?}) and accumulator ({:?})",
input.ciphertext_modulus(),
accumulator.ciphertext_modulus(),
);
let grouping_factor = multi_bit_bsk.grouping_factor();
let lut_poly_size = accumulator.polynomial_size();
@@ -2213,18 +2189,10 @@ pub fn std_multi_bit_programmable_bootstrap_f128_lwe_ciphertext<
);
assert_eq!(
input.ciphertext_modulus(),
output.ciphertext_modulus(),
"Mismatched CiphertextModulus between input ({:?}) and output ({:?})",
input.ciphertext_modulus(),
output.ciphertext_modulus(),
);
assert_eq!(
input.ciphertext_modulus(),
accumulator.ciphertext_modulus(),
"Mismatched CiphertextModulus between input ({:?}) and accumulator ({:?})",
input.ciphertext_modulus(),
"Mismatched CiphertextModulus between output ({:?}) and accumulator ({:?})",
output.ciphertext_modulus(),
accumulator.ciphertext_modulus(),
);
@@ -2718,3 +2686,224 @@ pub fn multi_bit_programmable_bootstrap_f128_lwe_ciphertext<
extract_lwe_sample_from_glwe_ciphertext(&local_accumulator, output, MonomialDegree(0));
}
// ============== Noise measurement trait implementations ============== //
use crate::core_crypto::commons::noise_formulas::noise_simulation::traits::{
AllocateMultiBitModSwitchResult, LweMultiBitFft128BlindRotate, LweMultiBitFft128Bootstrap,
LweMultiBitFftBlindRotate, LweMultiBitFftBootstrap, MultiBitModSwitch,
};
impl<
Scalar: UnsignedInteger + CastInto<usize> + CastFrom<usize>,
C: Container<Element = Scalar> + Sync,
> AllocateMultiBitModSwitchResult for LweCiphertext<C>
{
type Output = StandardMultiBitModulusSwitchedCt<Scalar, Vec<Scalar>>;
type SideResources = ();
fn allocate_multi_bit_mod_switch_result(
&self,
_side_resources: &mut Self::SideResources,
) -> Self::Output {
// We will mod switch but we keep the current modulus as the noise is interesting in the
// context of the input modulus
Self::Output {
input: LweCiphertextOwned::from_container(
self.as_ref().to_vec(),
self.ciphertext_modulus(),
),
// Placeholder values for those as they will be filled during mod switch
// Choose defaults that should crash things
grouping_factor: LweBskGroupingFactor(usize::MAX),
log_modulus: CiphertextModulusLog(usize::MAX),
}
}
}
impl<
Scalar: UnsignedInteger + CastInto<usize> + CastFrom<usize>,
C: Container<Element = Scalar>,
OutCont: ContainerMut<Element = Scalar> + Sync,
> MultiBitModSwitch<StandardMultiBitModulusSwitchedCt<Scalar, OutCont>> for LweCiphertext<C>
{
type SideResources = ();
fn multi_bit_mod_switch(
&self,
grouping_factor: LweBskGroupingFactor,
output_modulus_log: CiphertextModulusLog,
output: &mut StandardMultiBitModulusSwitchedCt<Scalar, OutCont>,
_side_resources: &mut Self::SideResources,
) {
assert_eq!(output.input.ciphertext_modulus(), self.ciphertext_modulus());
let StandardMultiBitModulusSwitchedCt {
input: output_input,
grouping_factor: output_grouping_factor,
log_modulus: output_log_modulus,
} = output;
*output_grouping_factor = grouping_factor;
*output_log_modulus = output_modulus_log;
output_input.as_mut().copy_from_slice(self.as_ref());
// Nothing to do given the mod switch will be lazily evaluated by the following multi bit
// blind rotation
}
}
impl<
InputScalar: UnsignedInteger + CastInto<usize> + CastFrom<usize>,
InputCont: Container<Element = InputScalar> + Sync,
OutputScalar: UnsignedTorus + Sync,
OutputCont: ContainerMut<Element = OutputScalar>,
AccCont: Container<Element = OutputScalar>,
KeyCont: Container<Element = c64> + Sync,
>
LweMultiBitFftBlindRotate<
StandardMultiBitModulusSwitchedCt<InputScalar, InputCont>,
LweCiphertext<OutputCont>,
GlweCiphertext<AccCont>,
> for FourierLweMultiBitBootstrapKey<KeyCont>
{
type SideResources = (ThreadCount, bool);
fn lwe_multi_bit_fft_blind_rotate(
&self,
input: &StandardMultiBitModulusSwitchedCt<InputScalar, InputCont>,
output: &mut LweCiphertext<OutputCont>,
accumulator: &GlweCiphertext<AccCont>,
side_resources: &mut Self::SideResources,
) {
let (thread_count, deterministic_execution) = *side_resources;
let mut local_accumulator = GlweCiphertext::from_container(
accumulator.as_ref().to_vec(),
accumulator.polynomial_size(),
accumulator.ciphertext_modulus(),
);
multi_bit_blind_rotate_assign(
input,
&mut local_accumulator,
self,
thread_count,
deterministic_execution,
);
extract_lwe_sample_from_glwe_ciphertext(&local_accumulator, output, MonomialDegree(0));
}
}
impl<
InputScalar: UnsignedInteger + CastInto<usize> + CastFrom<usize>,
InputCont: Container<Element = InputScalar> + Sync,
OutputScalar: UnsignedTorus + Sync,
OutputCont: ContainerMut<Element = OutputScalar>,
AccCont: Container<Element = OutputScalar>,
KeyCont: Container<Element = f64> + Sync + Split,
>
LweMultiBitFft128BlindRotate<
StandardMultiBitModulusSwitchedCt<InputScalar, InputCont>,
LweCiphertext<OutputCont>,
GlweCiphertext<AccCont>,
> for Fourier128LweMultiBitBootstrapKey<KeyCont>
{
type SideResources = ThreadCount;
fn lwe_multi_bit_fft_128_blind_rotate(
&self,
input: &StandardMultiBitModulusSwitchedCt<InputScalar, InputCont>,
output: &mut LweCiphertext<OutputCont>,
accumulator: &GlweCiphertext<AccCont>,
side_resources: &mut Self::SideResources,
) {
let thread_count = *side_resources;
let mut local_accumulator = GlweCiphertext::from_container(
accumulator.as_ref().to_vec(),
accumulator.polynomial_size(),
accumulator.ciphertext_modulus(),
);
multi_bit_f128_deterministic_blind_rotate_assign(
input,
&mut local_accumulator,
self,
thread_count,
);
extract_lwe_sample_from_glwe_ciphertext(&local_accumulator, output, MonomialDegree(0));
}
}
impl<
Scalar: UnsignedTorus + CastInto<usize> + CastFrom<usize>,
InputCont: Container<Element = Scalar> + Sync,
OutputCont: ContainerMut<Element = Scalar>,
AccCont: Container<Element = Scalar>,
KeyCont: Container<Element = c64> + Sync,
>
LweMultiBitFftBootstrap<
LweCiphertext<InputCont>,
LweCiphertext<OutputCont>,
GlweCiphertext<AccCont>,
> for FourierLweMultiBitBootstrapKey<KeyCont>
{
type SideResources = (ThreadCount, bool);
fn lwe_multi_bit_fft_bootstrap(
&self,
input: &LweCiphertext<InputCont>,
output: &mut LweCiphertext<OutputCont>,
accumulator: &GlweCiphertext<AccCont>,
side_resources: &mut Self::SideResources,
) {
let (thread_count, deterministic_execution) = *side_resources;
multi_bit_programmable_bootstrap_lwe_ciphertext(
input,
output,
accumulator,
self,
thread_count,
deterministic_execution,
);
}
}
impl<
InputScalar: UnsignedTorus + CastInto<usize> + CastFrom<usize>,
InputCont: Container<Element = InputScalar> + Sync,
OutputScalar: UnsignedTorus + Sync,
OutputCont: ContainerMut<Element = OutputScalar>,
AccCont: Container<Element = OutputScalar>,
KeyCont: Container<Element = f64> + Sync + Split,
>
LweMultiBitFft128Bootstrap<
LweCiphertext<InputCont>,
LweCiphertext<OutputCont>,
GlweCiphertext<AccCont>,
> for Fourier128LweMultiBitBootstrapKey<KeyCont>
{
type SideResources = (ThreadCount, bool);
fn lwe_multi_bit_fft_128_bootstrap(
&self,
input: &LweCiphertext<InputCont>,
output: &mut LweCiphertext<OutputCont>,
accumulator: &GlweCiphertext<AccCont>,
side_resources: &mut Self::SideResources,
) {
let (thread_count, deterministic_execution) = *side_resources;
multi_bit_programmable_bootstrap_f128_lwe_ciphertext(
input,
output,
accumulator,
self,
thread_count,
deterministic_execution,
);
}
}

View File

@@ -77,7 +77,9 @@ where
}
// ============== Noise measurement trait implementations ============== //
use crate::core_crypto::commons::noise_formulas::noise_simulation::traits::AllocateLweBootstrapResult;
use crate::core_crypto::commons::noise_formulas::noise_simulation::traits::{
AllocateLweBootstrapResult, AllocateLweMultiBitBlindRotateResult,
};
impl<Scalar: UnsignedInteger, AccCont: Container<Element = Scalar>> AllocateLweBootstrapResult
for GlweCiphertext<AccCont>
@@ -100,3 +102,25 @@ impl<Scalar: UnsignedInteger, AccCont: Container<Element = Scalar>> AllocateLweB
)
}
}
impl<Scalar: UnsignedInteger, AccCont: Container<Element = Scalar>>
AllocateLweMultiBitBlindRotateResult for GlweCiphertext<AccCont>
{
type Output = LweCiphertextOwned<Scalar>;
type SideResources = ();
fn allocate_lwe_multi_bit_blind_rotate_result(
&self,
_side_resources: &mut Self::SideResources,
) -> Self::Output {
let glwe_dim = self.glwe_size().to_glwe_dimension();
let poly_size = self.polynomial_size();
let equivalent_lwe_dim = glwe_dim.to_equivalent_lwe_dimension(poly_size);
LweCiphertext::new(
Scalar::ZERO,
equivalent_lwe_dim.to_lwe_size(),
self.ciphertext_modulus(),
)
}
}

View File

@@ -3,6 +3,7 @@ use crate::core_crypto::commons::noise_formulas::lwe_multi_bit_programmable_boot
multi_bit_pbs_variance_132_bits_security_gaussian_gf_3_fft_mul,
multi_bit_pbs_variance_132_bits_security_tuniform_gf_3_fft_mul,
};
use crate::core_crypto::commons::noise_formulas::noise_simulation::PBS_FFT_64_MANTISSA_SIZE;
use crate::core_crypto::commons::noise_formulas::secure_noise::{
minimal_lwe_variance_for_132_bits_security_gaussian,
minimal_lwe_variance_for_132_bits_security_tuniform,
@@ -48,6 +49,7 @@ where
polynomial_size,
pbs_decomposition_base_log,
pbs_decomposition_level_count,
PBS_FFT_64_MANTISSA_SIZE,
modulus_as_f64,
)
}
@@ -58,6 +60,7 @@ where
polynomial_size,
pbs_decomposition_base_log,
pbs_decomposition_level_count,
PBS_FFT_64_MANTISSA_SIZE,
modulus_as_f64,
)
}

View File

@@ -1,5 +1,6 @@
use super::*;
use crate::core_crypto::commons::noise_formulas::lwe_programmable_bootstrap::pbs_variance_132_bits_security_gaussian_fft_mul;
use crate::core_crypto::commons::noise_formulas::noise_simulation::PBS_FFT_64_MANTISSA_SIZE;
use crate::core_crypto::commons::noise_formulas::secure_noise::minimal_lwe_variance_for_132_bits_security_gaussian;
use crate::core_crypto::commons::test_tools::{torus_modular_diff, variance};
use rayon::prelude::*;
@@ -38,6 +39,7 @@ where
polynomial_size,
pbs_decomposition_base_log,
pbs_decomposition_level_count,
PBS_FFT_64_MANTISSA_SIZE,
modulus_as_f64,
);

View File

@@ -16,6 +16,7 @@ pub fn multi_bit_pbs_variance_132_bits_security_gaussian_gf_2_fft_mul(
output_polynomial_size: PolynomialSize,
decomposition_base_log: DecompositionBaseLog,
decomposition_level_count: DecompositionLevelCount,
mantissa_size: f64,
modulus: f64,
) -> Variance {
Variance(
@@ -25,6 +26,7 @@ pub fn multi_bit_pbs_variance_132_bits_security_gaussian_gf_2_fft_mul(
output_polynomial_size.0 as f64,
2.0f64.powi(decomposition_base_log.0 as i32),
decomposition_level_count.0 as f64,
mantissa_size,
modulus,
),
)
@@ -40,16 +42,17 @@ pub fn multi_bit_pbs_variance_132_bits_security_gaussian_gf_2_fft_mul_impl(
output_polynomial_size: f64,
decomposition_base: f64,
decomposition_level_count: f64,
mantissa_size: f64,
modulus: f64,
) -> f64 {
(1_f64 / 2.0)
* input_lwe_dimension
* (0.0022
* (2.0
* if (core::f64::consts::LOG2_E * modulus.ln() - 53.0 <= 0.0) {
* if (1.0 * mantissa_size - core::f64::consts::LOG2_E * modulus.ln() >= 0.0) {
0.0
} else {
core::f64::consts::LOG2_E * modulus.ln() - 53.0
-1.0 * mantissa_size + core::f64::consts::LOG2_E * modulus.ln()
}
+ 2.88539008177793 * decomposition_base.ln()
- 2.88539008177793 * modulus.ln())
@@ -86,6 +89,7 @@ pub fn multi_bit_pbs_variance_132_bits_security_gaussian_gf_3_fft_mul(
output_polynomial_size: PolynomialSize,
decomposition_base_log: DecompositionBaseLog,
decomposition_level_count: DecompositionLevelCount,
mantissa_size: f64,
modulus: f64,
) -> Variance {
Variance(
@@ -95,6 +99,7 @@ pub fn multi_bit_pbs_variance_132_bits_security_gaussian_gf_3_fft_mul(
output_polynomial_size.0 as f64,
2.0f64.powi(decomposition_base_log.0 as i32),
decomposition_level_count.0 as f64,
mantissa_size,
modulus,
),
)
@@ -110,16 +115,17 @@ pub fn multi_bit_pbs_variance_132_bits_security_gaussian_gf_3_fft_mul_impl(
output_polynomial_size: f64,
decomposition_base: f64,
decomposition_level_count: f64,
mantissa_size: f64,
modulus: f64,
) -> f64 {
(1_f64 / 3.0)
* input_lwe_dimension
* (0.00492
* (2.0
* if (core::f64::consts::LOG2_E * modulus.ln() - 53.0 <= 0.0) {
* if (1.0 * mantissa_size - core::f64::consts::LOG2_E * modulus.ln() >= 0.0) {
0.0
} else {
core::f64::consts::LOG2_E * modulus.ln() - 53.0
-1.0 * mantissa_size + core::f64::consts::LOG2_E * modulus.ln()
}
+ 2.88539008177793 * decomposition_base.ln()
- 2.88539008177793 * modulus.ln())
@@ -156,6 +162,7 @@ pub fn multi_bit_pbs_variance_132_bits_security_gaussian_gf_4_fft_mul(
output_polynomial_size: PolynomialSize,
decomposition_base_log: DecompositionBaseLog,
decomposition_level_count: DecompositionLevelCount,
mantissa_size: f64,
modulus: f64,
) -> Variance {
Variance(
@@ -165,6 +172,7 @@ pub fn multi_bit_pbs_variance_132_bits_security_gaussian_gf_4_fft_mul(
output_polynomial_size.0 as f64,
2.0f64.powi(decomposition_base_log.0 as i32),
decomposition_level_count.0 as f64,
mantissa_size,
modulus,
),
)
@@ -180,16 +188,17 @@ pub fn multi_bit_pbs_variance_132_bits_security_gaussian_gf_4_fft_mul_impl(
output_polynomial_size: f64,
decomposition_base: f64,
decomposition_level_count: f64,
mantissa_size: f64,
modulus: f64,
) -> f64 {
(1_f64 / 4.0)
* input_lwe_dimension
* (0.00855
* (2.0
* if (core::f64::consts::LOG2_E * modulus.ln() - 53.0 <= 0.0) {
* if (1.0 * mantissa_size - core::f64::consts::LOG2_E * modulus.ln() >= 0.0) {
0.0
} else {
core::f64::consts::LOG2_E * modulus.ln() - 53.0
-1.0 * mantissa_size + core::f64::consts::LOG2_E * modulus.ln()
}
+ 2.88539008177793 * decomposition_base.ln()
- 2.88539008177793 * modulus.ln())
@@ -226,6 +235,7 @@ pub fn multi_bit_pbs_variance_132_bits_security_tuniform_gf_2_fft_mul(
output_polynomial_size: PolynomialSize,
decomposition_base_log: DecompositionBaseLog,
decomposition_level_count: DecompositionLevelCount,
mantissa_size: f64,
modulus: f64,
) -> Variance {
Variance(
@@ -235,6 +245,7 @@ pub fn multi_bit_pbs_variance_132_bits_security_tuniform_gf_2_fft_mul(
output_polynomial_size.0 as f64,
2.0f64.powi(decomposition_base_log.0 as i32),
decomposition_level_count.0 as f64,
mantissa_size,
modulus,
),
)
@@ -250,16 +261,17 @@ pub fn multi_bit_pbs_variance_132_bits_security_tuniform_gf_2_fft_mul_impl(
output_polynomial_size: f64,
decomposition_base: f64,
decomposition_level_count: f64,
mantissa_size: f64,
modulus: f64,
) -> f64 {
(1_f64 / 2.0)
* input_lwe_dimension
* (0.0022
* (2.0
* if (core::f64::consts::LOG2_E * modulus.ln() - 53.0 <= 0.0) {
* if (1.0 * mantissa_size - core::f64::consts::LOG2_E * modulus.ln() >= 0.0) {
0.0
} else {
core::f64::consts::LOG2_E * modulus.ln() - 53.0
-1.0 * mantissa_size + core::f64::consts::LOG2_E * modulus.ln()
}
+ 2.88539008177793 * decomposition_base.ln()
- 2.88539008177793 * modulus.ln())
@@ -302,6 +314,7 @@ pub fn multi_bit_pbs_variance_132_bits_security_tuniform_gf_3_fft_mul(
output_polynomial_size: PolynomialSize,
decomposition_base_log: DecompositionBaseLog,
decomposition_level_count: DecompositionLevelCount,
mantissa_size: f64,
modulus: f64,
) -> Variance {
Variance(
@@ -311,6 +324,7 @@ pub fn multi_bit_pbs_variance_132_bits_security_tuniform_gf_3_fft_mul(
output_polynomial_size.0 as f64,
2.0f64.powi(decomposition_base_log.0 as i32),
decomposition_level_count.0 as f64,
mantissa_size,
modulus,
),
)
@@ -326,16 +340,17 @@ pub fn multi_bit_pbs_variance_132_bits_security_tuniform_gf_3_fft_mul_impl(
output_polynomial_size: f64,
decomposition_base: f64,
decomposition_level_count: f64,
mantissa_size: f64,
modulus: f64,
) -> f64 {
(1_f64 / 3.0)
* input_lwe_dimension
* (0.00492
* (2.0
* if (core::f64::consts::LOG2_E * modulus.ln() - 53.0 <= 0.0) {
* if (1.0 * mantissa_size - core::f64::consts::LOG2_E * modulus.ln() >= 0.0) {
0.0
} else {
core::f64::consts::LOG2_E * modulus.ln() - 53.0
-1.0 * mantissa_size + core::f64::consts::LOG2_E * modulus.ln()
}
+ 2.88539008177793 * decomposition_base.ln()
- 2.88539008177793 * modulus.ln())
@@ -378,6 +393,7 @@ pub fn multi_bit_pbs_variance_132_bits_security_tuniform_gf_4_fft_mul(
output_polynomial_size: PolynomialSize,
decomposition_base_log: DecompositionBaseLog,
decomposition_level_count: DecompositionLevelCount,
mantissa_size: f64,
modulus: f64,
) -> Variance {
Variance(
@@ -387,6 +403,7 @@ pub fn multi_bit_pbs_variance_132_bits_security_tuniform_gf_4_fft_mul(
output_polynomial_size.0 as f64,
2.0f64.powi(decomposition_base_log.0 as i32),
decomposition_level_count.0 as f64,
mantissa_size,
modulus,
),
)
@@ -402,16 +419,17 @@ pub fn multi_bit_pbs_variance_132_bits_security_tuniform_gf_4_fft_mul_impl(
output_polynomial_size: f64,
decomposition_base: f64,
decomposition_level_count: f64,
mantissa_size: f64,
modulus: f64,
) -> f64 {
(1_f64 / 4.0)
* input_lwe_dimension
* (0.00855
* (2.0
* if (core::f64::consts::LOG2_E * modulus.ln() - 53.0 <= 0.0) {
* if (1.0 * mantissa_size - core::f64::consts::LOG2_E * modulus.ln() >= 0.0) {
0.0
} else {
core::f64::consts::LOG2_E * modulus.ln() - 53.0
-1.0 * mantissa_size + core::f64::consts::LOG2_E * modulus.ln()
}
+ 2.88539008177793 * decomposition_base.ln()
- 2.88539008177793 * modulus.ln())

View File

@@ -16,6 +16,7 @@ pub fn pbs_variance_132_bits_security_gaussian_fft_mul(
output_polynomial_size: PolynomialSize,
decomposition_base_log: DecompositionBaseLog,
decomposition_level_count: DecompositionLevelCount,
mantissa_size: f64,
modulus: f64,
) -> Variance {
Variance(pbs_variance_132_bits_security_gaussian_fft_mul_impl(
@@ -24,6 +25,7 @@ pub fn pbs_variance_132_bits_security_gaussian_fft_mul(
output_polynomial_size.0 as f64,
2.0f64.powi(decomposition_base_log.0 as i32),
decomposition_level_count.0 as f64,
mantissa_size,
modulus,
))
}
@@ -38,15 +40,16 @@ pub fn pbs_variance_132_bits_security_gaussian_fft_mul_impl(
output_polynomial_size: f64,
decomposition_base: f64,
decomposition_level_count: f64,
mantissa_size: f64,
modulus: f64,
) -> f64 {
input_lwe_dimension
* (0.00705
* (2.0
* if (core::f64::consts::LOG2_E * modulus.ln() - 53.0 <= 0.0) {
* if (1.0 * mantissa_size - core::f64::consts::LOG2_E * modulus.ln() >= 0.0) {
0.0
} else {
core::f64::consts::LOG2_E * modulus.ln() - 53.0
-1.0 * mantissa_size + core::f64::consts::LOG2_E * modulus.ln()
}
+ 2.88539008177793 * decomposition_base.ln()
- 2.88539008177793 * modulus.ln())
@@ -83,6 +86,7 @@ pub fn pbs_variance_132_bits_security_tuniform_fft_mul(
output_polynomial_size: PolynomialSize,
decomposition_base_log: DecompositionBaseLog,
decomposition_level_count: DecompositionLevelCount,
mantissa_size: f64,
modulus: f64,
) -> Variance {
Variance(pbs_variance_132_bits_security_tuniform_fft_mul_impl(
@@ -91,6 +95,7 @@ pub fn pbs_variance_132_bits_security_tuniform_fft_mul(
output_polynomial_size.0 as f64,
2.0f64.powi(decomposition_base_log.0 as i32),
decomposition_level_count.0 as f64,
mantissa_size,
modulus,
))
}
@@ -105,15 +110,16 @@ pub fn pbs_variance_132_bits_security_tuniform_fft_mul_impl(
output_polynomial_size: f64,
decomposition_base: f64,
decomposition_level_count: f64,
mantissa_size: f64,
modulus: f64,
) -> f64 {
input_lwe_dimension
* (0.00705
* (2.0
* if (core::f64::consts::LOG2_E * modulus.ln() - 53.0 <= 0.0) {
* if (1.0 * mantissa_size - core::f64::consts::LOG2_E * modulus.ln() >= 0.0) {
0.0
} else {
core::f64::consts::LOG2_E * modulus.ln() - 53.0
-1.0 * mantissa_size + core::f64::consts::LOG2_E * modulus.ln()
}
+ 2.88539008177793 * decomposition_base.ln()
- 2.88539008177793 * modulus.ln())

View File

@@ -1,153 +0,0 @@
// This file was autogenerated, do not modify by hand.
#![allow(unused_parens)]
#![allow(clippy::neg_multiply)]
#![allow(clippy::suspicious_operation_groupings)]
use crate::core_crypto::commons::dispersion::Variance;
use crate::core_crypto::commons::parameters::*;
/// This formula is only valid if the proper noise distributions are used and
/// if the keys used are encrypted using secure noise given by the
/// [`minimal_glwe_variance`](`super::secure_noise`)
/// and [`minimal_lwe_variance`](`super::secure_noise`) family of functions.
pub fn pbs_128_variance_132_bits_security_gaussian_fft_mul(
input_lwe_dimension: LweDimension,
output_glwe_dimension: GlweDimension,
output_polynomial_size: PolynomialSize,
decomposition_base_log: DecompositionBaseLog,
decomposition_level_count: DecompositionLevelCount,
mantissa_size: f64,
modulus: f64,
) -> Variance {
Variance(pbs_128_variance_132_bits_security_gaussian_fft_mul_impl(
input_lwe_dimension.0 as f64,
output_glwe_dimension.0 as f64,
output_polynomial_size.0 as f64,
2.0f64.powi(decomposition_base_log.0 as i32),
decomposition_level_count.0 as f64,
mantissa_size,
modulus,
))
}
/// This formula is only valid if the proper noise distributions are used and
/// if the keys used are encrypted using secure noise given by the
/// [`minimal_glwe_variance`](`super::secure_noise`)
/// and [`minimal_lwe_variance`](`super::secure_noise`) family of functions.
pub fn pbs_128_variance_132_bits_security_gaussian_fft_mul_impl(
input_lwe_dimension: f64,
output_glwe_dimension: f64,
output_polynomial_size: f64,
decomposition_base: f64,
decomposition_level_count: f64,
mantissa_size: f64,
modulus: f64,
) -> f64 {
input_lwe_dimension
* (0.00705
* (2.0
* if (1.0 * mantissa_size - core::f64::consts::LOG2_E * modulus.ln() >= 0.0) {
0.0
} else {
-1.0 * mantissa_size + core::f64::consts::LOG2_E * modulus.ln()
}
+ 2.88539008177793 * decomposition_base.ln()
- 2.88539008177793 * modulus.ln())
.exp2()
* decomposition_level_count.powf(1.01827)
* output_glwe_dimension.powf(1.22003)
* output_polynomial_size.powf(2.22003)
* (output_glwe_dimension + 1.0).powf(1.01827)
+ decomposition_level_count
* output_polynomial_size
* ((4.0 - 2.88539008177793 * modulus.ln()).exp2()
+ (-0.0497829131652661 * output_glwe_dimension * output_polynomial_size
+ 5.31469187675068)
.exp2())
* ((1_f64 / 12.0) * decomposition_base.powf(2.0) + 0.166666666666667)
* (output_glwe_dimension + 1.0)
- 1_f64 / 24.0 * modulus.powf(-2.0)
+ (1_f64 / 2.0)
* output_glwe_dimension
* output_polynomial_size
* (0.0208333333333333 * modulus.powf(-2.0)
+ 0.0416666666666667
* decomposition_base.powf(-2.0 * decomposition_level_count))
+ (1_f64 / 24.0) * decomposition_base.powf(-2.0 * decomposition_level_count))
}
/// This formula is only valid if the proper noise distributions are used and
/// if the keys used are encrypted using secure noise given by the
/// [`minimal_glwe_variance`](`super::secure_noise`)
/// and [`minimal_lwe_variance`](`super::secure_noise`) family of functions.
pub fn pbs_128_variance_132_bits_security_tuniform_fft_mul(
input_lwe_dimension: LweDimension,
output_glwe_dimension: GlweDimension,
output_polynomial_size: PolynomialSize,
decomposition_base_log: DecompositionBaseLog,
decomposition_level_count: DecompositionLevelCount,
mantissa_size: f64,
modulus: f64,
) -> Variance {
Variance(pbs_128_variance_132_bits_security_tuniform_fft_mul_impl(
input_lwe_dimension.0 as f64,
output_glwe_dimension.0 as f64,
output_polynomial_size.0 as f64,
2.0f64.powi(decomposition_base_log.0 as i32),
decomposition_level_count.0 as f64,
mantissa_size,
modulus,
))
}
/// This formula is only valid if the proper noise distributions are used and
/// if the keys used are encrypted using secure noise given by the
/// [`minimal_glwe_variance`](`super::secure_noise`)
/// and [`minimal_lwe_variance`](`super::secure_noise`) family of functions.
pub fn pbs_128_variance_132_bits_security_tuniform_fft_mul_impl(
input_lwe_dimension: f64,
output_glwe_dimension: f64,
output_polynomial_size: f64,
decomposition_base: f64,
decomposition_level_count: f64,
mantissa_size: f64,
modulus: f64,
) -> f64 {
input_lwe_dimension
* (0.00705
* (2.0
* if (1.0 * mantissa_size - core::f64::consts::LOG2_E * modulus.ln() >= 0.0) {
0.0
} else {
-1.0 * mantissa_size + core::f64::consts::LOG2_E * modulus.ln()
}
+ 2.88539008177793 * decomposition_base.ln()
- 2.88539008177793 * modulus.ln())
.exp2()
* decomposition_level_count.powf(1.01827)
* output_glwe_dimension.powf(1.22003)
* output_polynomial_size.powf(2.22003)
* (output_glwe_dimension + 1.0).powf(1.01827)
+ decomposition_level_count
* output_polynomial_size
* ((4.44 - 2.88539008177793 * modulus.ln()).exp2()
+ (1_f64 / 3.0)
* modulus.powf(-2.0)
* ((2.0
* (-0.025167785 * output_glwe_dimension * output_polynomial_size
+ core::f64::consts::LOG2_E * modulus.ln()
+ 4.10067100000001)
.ceil())
.exp2()
+ 0.5))
* ((1_f64 / 12.0) * decomposition_base.powf(2.0) + 0.166666666666667)
* (output_glwe_dimension + 1.0)
- 1_f64 / 24.0 * modulus.powf(-2.0)
+ (1_f64 / 2.0)
* output_glwe_dimension
* output_polynomial_size
* (0.0208333333333333 * modulus.powf(-2.0)
+ 0.0416666666666667
* decomposition_base.powf(-2.0 * decomposition_level_count))
+ (1_f64 / 24.0) * decomposition_base.powf(-2.0 * decomposition_level_count))
}

View File

@@ -5,7 +5,6 @@ pub mod lwe_keyswitch;
pub mod lwe_multi_bit_programmable_bootstrap;
pub mod lwe_packing_keyswitch;
pub mod lwe_programmable_bootstrap;
pub mod lwe_programmable_bootstrap_128;
pub mod modulus_switch;
pub mod multi_bit_modulus_switch;
pub mod secure_noise;

View File

@@ -7,16 +7,22 @@ use crate::core_crypto::commons::noise_formulas::lwe_multi_bit_programmable_boot
multi_bit_pbs_variance_132_bits_security_tuniform_gf_3_fft_mul,
multi_bit_pbs_variance_132_bits_security_tuniform_gf_4_fft_mul,
};
use crate::core_crypto::commons::noise_formulas::noise_simulation::traits::LweMultiBitFftBlindRotate;
use crate::core_crypto::commons::noise_formulas::noise_simulation::traits::{
LweMultiBitFft128BlindRotate, LweMultiBitFft128Bootstrap, LweMultiBitFftBlindRotate,
LweMultiBitFftBootstrap,
};
use crate::core_crypto::commons::noise_formulas::noise_simulation::{
NoiseSimulationGlwe, NoiseSimulationLwe, NoiseSimulationModulus,
NoiseSimulationGlwe, NoiseSimulationLwe, NoiseSimulationModulus, PBS_FFT_128_MANTISSA_SIZE,
PBS_FFT_64_MANTISSA_SIZE,
};
use crate::core_crypto::commons::parameters::{
DecompositionBaseLog, DecompositionLevelCount, DynamicDistribution, GlweSize,
LweBskGroupingFactor, LweDimension, PolynomialSize,
};
use crate::core_crypto::commons::traits::container::Container;
use crate::core_crypto::entities::lwe_multi_bit_bootstrap_key::FourierLweMultiBitBootstrapKey;
use crate::core_crypto::entities::lwe_multi_bit_bootstrap_key::{
Fourier128LweMultiBitBootstrapKey, FourierLweMultiBitBootstrapKey,
};
use crate::core_crypto::fft_impl::fft64::c64;
#[derive(Clone, Copy)]
@@ -116,6 +122,11 @@ impl NoiseSimulationLweMultiBitFourierBsk {
pub fn modulus(&self) -> NoiseSimulationModulus {
self.modulus
}
pub fn mantissa_size(&self) -> f64 {
let _ = self;
PBS_FFT_64_MANTISSA_SIZE
}
}
impl LweMultiBitFftBlindRotate<NoiseSimulationLwe, NoiseSimulationLwe, NoiseSimulationGlwe>
@@ -147,6 +158,7 @@ impl LweMultiBitFftBlindRotate<NoiseSimulationLwe, NoiseSimulationLwe, NoiseSimu
self.output_polynomial_size(),
self.decomp_base_log(),
self.decomp_level_count(),
self.mantissa_size(),
self.modulus().as_f64(),
),
3 => multi_bit_pbs_variance_132_bits_security_gaussian_gf_3_fft_mul(
@@ -155,6 +167,7 @@ impl LweMultiBitFftBlindRotate<NoiseSimulationLwe, NoiseSimulationLwe, NoiseSimu
self.output_polynomial_size(),
self.decomp_base_log(),
self.decomp_level_count(),
self.mantissa_size(),
self.modulus().as_f64(),
),
4 => multi_bit_pbs_variance_132_bits_security_gaussian_gf_4_fft_mul(
@@ -163,6 +176,7 @@ impl LweMultiBitFftBlindRotate<NoiseSimulationLwe, NoiseSimulationLwe, NoiseSimu
self.output_polynomial_size(),
self.decomp_base_log(),
self.decomp_level_count(),
self.mantissa_size(),
self.modulus().as_f64(),
),
gf => panic!("Unsupported grouping factor: {gf}"),
@@ -174,6 +188,7 @@ impl LweMultiBitFftBlindRotate<NoiseSimulationLwe, NoiseSimulationLwe, NoiseSimu
self.output_polynomial_size(),
self.decomp_base_log(),
self.decomp_level_count(),
self.mantissa_size(),
self.modulus().as_f64(),
),
3 => multi_bit_pbs_variance_132_bits_security_tuniform_gf_3_fft_mul(
@@ -182,6 +197,7 @@ impl LweMultiBitFftBlindRotate<NoiseSimulationLwe, NoiseSimulationLwe, NoiseSimu
self.output_polynomial_size(),
self.decomp_base_log(),
self.decomp_level_count(),
self.mantissa_size(),
self.modulus().as_f64(),
),
4 => multi_bit_pbs_variance_132_bits_security_tuniform_gf_4_fft_mul(
@@ -190,6 +206,7 @@ impl LweMultiBitFftBlindRotate<NoiseSimulationLwe, NoiseSimulationLwe, NoiseSimu
self.output_polynomial_size(),
self.decomp_base_log(),
self.decomp_level_count(),
self.mantissa_size(),
self.modulus().as_f64(),
),
gf => panic!("Unsupported grouping factor: {gf}"),
@@ -208,3 +225,238 @@ impl LweMultiBitFftBlindRotate<NoiseSimulationLwe, NoiseSimulationLwe, NoiseSimu
);
}
}
impl LweMultiBitFftBootstrap<NoiseSimulationLwe, NoiseSimulationLwe, NoiseSimulationGlwe>
for NoiseSimulationLweMultiBitFourierBsk
{
type SideResources = ();
fn lwe_multi_bit_fft_bootstrap(
&self,
input: &NoiseSimulationLwe,
output: &mut NoiseSimulationLwe,
accumulator: &NoiseSimulationGlwe,
side_resources: &mut Self::SideResources,
) {
// Noise-wise it is the same
self.lwe_multi_bit_fft_blind_rotate(input, output, accumulator, side_resources);
}
}
#[derive(Clone, Copy)]
pub struct NoiseSimulationLweMultiBitFourier128Bsk {
input_lwe_dimension: LweDimension,
output_glwe_size: GlweSize,
output_polynomial_size: PolynomialSize,
decomp_base_log: DecompositionBaseLog,
decomp_level_count: DecompositionLevelCount,
grouping_factor: LweBskGroupingFactor,
noise_distribution: DynamicDistribution<u128>,
modulus: NoiseSimulationModulus,
}
impl NoiseSimulationLweMultiBitFourier128Bsk {
#[allow(clippy::too_many_arguments)]
pub fn new(
input_lwe_dimension: LweDimension,
output_glwe_size: GlweSize,
output_polynomial_size: PolynomialSize,
decomp_base_log: DecompositionBaseLog,
decomp_level_count: DecompositionLevelCount,
grouping_factor: LweBskGroupingFactor,
noise_distribution: DynamicDistribution<u128>,
modulus: NoiseSimulationModulus,
) -> Self {
Self {
input_lwe_dimension,
output_glwe_size,
output_polynomial_size,
decomp_base_log,
decomp_level_count,
grouping_factor,
noise_distribution,
modulus,
}
}
pub fn matches_actual_bsk<C: Container<Element = f64>>(
&self,
lwe_bsk: &Fourier128LweMultiBitBootstrapKey<C>,
) -> bool {
let Self {
input_lwe_dimension,
output_glwe_size: glwe_size,
output_polynomial_size: polynomial_size,
decomp_base_log,
decomp_level_count,
grouping_factor,
noise_distribution: _,
modulus: _,
} = *self;
let bsk_input_lwe_dimension = lwe_bsk.input_lwe_dimension();
let bsk_glwe_size = lwe_bsk.glwe_size();
let bsk_polynomial_size = lwe_bsk.polynomial_size();
let bsk_decomp_base_log = lwe_bsk.decomposition_base_log();
let bsk_decomp_level_count = lwe_bsk.decomposition_level_count();
let bsk_grouping_factor = lwe_bsk.grouping_factor();
input_lwe_dimension == bsk_input_lwe_dimension
&& glwe_size == bsk_glwe_size
&& polynomial_size == bsk_polynomial_size
&& decomp_base_log == bsk_decomp_base_log
&& decomp_level_count == bsk_decomp_level_count
&& grouping_factor == bsk_grouping_factor
}
pub fn input_lwe_dimension(&self) -> LweDimension {
self.input_lwe_dimension
}
pub fn output_glwe_size(&self) -> GlweSize {
self.output_glwe_size
}
pub fn output_polynomial_size(&self) -> PolynomialSize {
self.output_polynomial_size
}
pub fn decomp_base_log(&self) -> DecompositionBaseLog {
self.decomp_base_log
}
pub fn decomp_level_count(&self) -> DecompositionLevelCount {
self.decomp_level_count
}
pub fn grouping_factor(&self) -> LweBskGroupingFactor {
self.grouping_factor
}
pub fn noise_distribution(&self) -> DynamicDistribution<u128> {
self.noise_distribution
}
pub fn modulus(&self) -> NoiseSimulationModulus {
self.modulus
}
pub fn mantissa_size(&self) -> f64 {
let _ = self;
PBS_FFT_128_MANTISSA_SIZE
}
}
impl LweMultiBitFft128BlindRotate<NoiseSimulationLwe, NoiseSimulationLwe, NoiseSimulationGlwe>
for NoiseSimulationLweMultiBitFourier128Bsk
{
type SideResources = ();
fn lwe_multi_bit_fft_128_blind_rotate(
&self,
input: &NoiseSimulationLwe,
output: &mut NoiseSimulationLwe,
accumulator: &NoiseSimulationGlwe,
_side_resources: &mut Self::SideResources,
) {
assert_eq!(self.input_lwe_dimension(), input.lwe_dimension());
assert_eq!(
self.output_glwe_size(),
accumulator.glwe_dimension().to_glwe_size()
);
assert_eq!(self.output_polynomial_size(), accumulator.polynomial_size());
assert_eq!(self.modulus(), accumulator.modulus());
let grouping_factor = self.grouping_factor();
let br_additive_variance = match self.noise_distribution() {
DynamicDistribution::Gaussian(_) => match grouping_factor.0 {
2 => multi_bit_pbs_variance_132_bits_security_gaussian_gf_2_fft_mul(
self.input_lwe_dimension(),
self.output_glwe_size().to_glwe_dimension(),
self.output_polynomial_size(),
self.decomp_base_log(),
self.decomp_level_count(),
self.mantissa_size(),
self.modulus().as_f64(),
),
3 => multi_bit_pbs_variance_132_bits_security_gaussian_gf_3_fft_mul(
self.input_lwe_dimension(),
self.output_glwe_size().to_glwe_dimension(),
self.output_polynomial_size(),
self.decomp_base_log(),
self.decomp_level_count(),
self.mantissa_size(),
self.modulus().as_f64(),
),
4 => multi_bit_pbs_variance_132_bits_security_gaussian_gf_4_fft_mul(
self.input_lwe_dimension(),
self.output_glwe_size().to_glwe_dimension(),
self.output_polynomial_size(),
self.decomp_base_log(),
self.decomp_level_count(),
self.mantissa_size(),
self.modulus().as_f64(),
),
gf => panic!("Unsupported grouping factor: {gf}"),
},
DynamicDistribution::TUniform(_) => match grouping_factor.0 {
2 => multi_bit_pbs_variance_132_bits_security_tuniform_gf_2_fft_mul(
self.input_lwe_dimension(),
self.output_glwe_size().to_glwe_dimension(),
self.output_polynomial_size(),
self.decomp_base_log(),
self.decomp_level_count(),
self.mantissa_size(),
self.modulus().as_f64(),
),
3 => multi_bit_pbs_variance_132_bits_security_tuniform_gf_3_fft_mul(
self.input_lwe_dimension(),
self.output_glwe_size().to_glwe_dimension(),
self.output_polynomial_size(),
self.decomp_base_log(),
self.decomp_level_count(),
self.mantissa_size(),
self.modulus().as_f64(),
),
4 => multi_bit_pbs_variance_132_bits_security_tuniform_gf_4_fft_mul(
self.input_lwe_dimension(),
self.output_glwe_size().to_glwe_dimension(),
self.output_polynomial_size(),
self.decomp_base_log(),
self.decomp_level_count(),
self.mantissa_size(),
self.modulus().as_f64(),
),
gf => panic!("Unsupported grouping factor: {gf}"),
},
};
let output_lwe_dimension = self
.output_glwe_size()
.to_glwe_dimension()
.to_equivalent_lwe_dimension(self.output_polynomial_size());
*output = NoiseSimulationLwe::new(
output_lwe_dimension,
Variance(accumulator.variance_per_occupied_slot().0 + br_additive_variance.0),
accumulator.modulus(),
);
}
}
impl LweMultiBitFft128Bootstrap<NoiseSimulationLwe, NoiseSimulationLwe, NoiseSimulationGlwe>
for NoiseSimulationLweMultiBitFourier128Bsk
{
type SideResources = ();
fn lwe_multi_bit_fft_128_bootstrap(
&self,
input: &NoiseSimulationLwe,
output: &mut NoiseSimulationLwe,
accumulator: &NoiseSimulationGlwe,
side_resources: &mut Self::SideResources,
) {
// Noise-wise it is the same
self.lwe_multi_bit_fft_128_blind_rotate(input, output, accumulator, side_resources);
}
}

View File

@@ -3,15 +3,12 @@ use crate::core_crypto::commons::noise_formulas::lwe_programmable_bootstrap::{
pbs_variance_132_bits_security_gaussian_fft_mul,
pbs_variance_132_bits_security_tuniform_fft_mul,
};
use crate::core_crypto::commons::noise_formulas::lwe_programmable_bootstrap_128::{
pbs_128_variance_132_bits_security_gaussian_fft_mul,
pbs_128_variance_132_bits_security_tuniform_fft_mul,
};
use crate::core_crypto::commons::noise_formulas::noise_simulation::traits::{
LweClassicFft128Bootstrap, LweClassicFftBootstrap,
};
use crate::core_crypto::commons::noise_formulas::noise_simulation::{
NoiseSimulationGlwe, NoiseSimulationLwe, NoiseSimulationModulus,
NoiseSimulationGlwe, NoiseSimulationLwe, NoiseSimulationModulus, PBS_FFT_128_MANTISSA_SIZE,
PBS_FFT_64_MANTISSA_SIZE,
};
use crate::core_crypto::commons::parameters::{
DecompositionBaseLog, DecompositionLevelCount, DynamicDistribution, GlweSize, LweDimension,
@@ -108,6 +105,11 @@ impl NoiseSimulationLweFourierBsk {
pub fn modulus(&self) -> NoiseSimulationModulus {
self.modulus
}
pub fn mantissa_size(&self) -> f64 {
let _ = self;
PBS_FFT_64_MANTISSA_SIZE
}
}
impl LweClassicFftBootstrap<NoiseSimulationLwe, NoiseSimulationLwe, NoiseSimulationGlwe>
@@ -137,6 +139,7 @@ impl LweClassicFftBootstrap<NoiseSimulationLwe, NoiseSimulationLwe, NoiseSimulat
self.output_polynomial_size(),
self.decomp_base_log(),
self.decomp_level_count(),
self.mantissa_size(),
self.modulus().as_f64(),
),
DynamicDistribution::TUniform(_) => pbs_variance_132_bits_security_tuniform_fft_mul(
@@ -145,6 +148,7 @@ impl LweClassicFftBootstrap<NoiseSimulationLwe, NoiseSimulationLwe, NoiseSimulat
self.output_polynomial_size(),
self.decomp_base_log(),
self.decomp_level_count(),
self.mantissa_size(),
self.modulus().as_f64(),
),
};
@@ -248,6 +252,11 @@ impl NoiseSimulationLweFourier128Bsk {
pub fn modulus(&self) -> NoiseSimulationModulus {
self.modulus
}
pub fn mantissa_size(&self) -> f64 {
let _ = self;
PBS_FFT_128_MANTISSA_SIZE
}
}
impl LweClassicFft128Bootstrap<NoiseSimulationLwe, NoiseSimulationLwe, NoiseSimulationGlwe>
@@ -271,30 +280,24 @@ impl LweClassicFft128Bootstrap<NoiseSimulationLwe, NoiseSimulationLwe, NoiseSimu
assert_eq!(self.modulus(), accumulator.modulus());
let br_additive_variance = match self.noise_distribution() {
DynamicDistribution::Gaussian(_) => {
pbs_128_variance_132_bits_security_gaussian_fft_mul(
self.input_lwe_dimension(),
self.output_glwe_size().to_glwe_dimension(),
self.output_polynomial_size(),
self.decomp_base_log(),
self.decomp_level_count(),
// Current PBS 128 implem has 104 bits of equivalent mantissa
104.0f64,
self.modulus().as_f64(),
)
}
DynamicDistribution::TUniform(_) => {
pbs_128_variance_132_bits_security_tuniform_fft_mul(
self.input_lwe_dimension(),
self.output_glwe_size().to_glwe_dimension(),
self.output_polynomial_size(),
self.decomp_base_log(),
self.decomp_level_count(),
// Current PBS 128 implem has 104 bits of equivalent mantissa
104.0f64,
self.modulus().as_f64(),
)
}
DynamicDistribution::Gaussian(_) => pbs_variance_132_bits_security_gaussian_fft_mul(
self.input_lwe_dimension(),
self.output_glwe_size().to_glwe_dimension(),
self.output_polynomial_size(),
self.decomp_base_log(),
self.decomp_level_count(),
self.mantissa_size(),
self.modulus().as_f64(),
),
DynamicDistribution::TUniform(_) => pbs_variance_132_bits_security_tuniform_fft_mul(
self.input_lwe_dimension(),
self.output_glwe_size().to_glwe_dimension(),
self.output_polynomial_size(),
self.decomp_base_log(),
self.decomp_level_count(),
self.mantissa_size(),
self.modulus().as_f64(),
),
};
let output_lwe_dimension = self

View File

@@ -6,6 +6,9 @@ pub mod modulus_switch;
pub mod traits;
pub use lwe_keyswitch::NoiseSimulationLweKeyswitchKey;
pub use lwe_multi_bit_programmable_bootstrap::{
NoiseSimulationLweMultiBitFourier128Bsk, NoiseSimulationLweMultiBitFourierBsk,
};
pub use lwe_packing_keyswitch::NoiseSimulationLwePackingKeyswitchKey;
pub use lwe_programmable_bootstrap::{
NoiseSimulationLweFourier128Bsk, NoiseSimulationLweFourierBsk,
@@ -23,6 +26,9 @@ use crate::core_crypto::commons::parameters::{
CiphertextModulusLog, GlweDimension, GlweSize, LweDimension, PolynomialSize,
};
pub const PBS_FFT_64_MANTISSA_SIZE: f64 = 53.;
pub const PBS_FFT_128_MANTISSA_SIZE: f64 = 104.;
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum NoiseSimulationModulus {
NativeU128,

View File

@@ -31,3 +31,49 @@ pub trait LweMultiBitFft128BlindRotate<Input, Output, Accumulator> {
side_resources: &mut Self::SideResources,
);
}
pub trait AllocateLweMultiBitBootstrapResult {
type Output;
type SideResources;
fn allocate_lwe_multi_bit_bootstrap_result(
&self,
side_resources: &mut Self::SideResources,
) -> Self::Output;
}
impl<T: AllocateLweMultiBitBlindRotateResult> AllocateLweMultiBitBootstrapResult for T {
type Output = <T as AllocateLweMultiBitBlindRotateResult>::Output;
type SideResources = <T as AllocateLweMultiBitBlindRotateResult>::SideResources;
fn allocate_lwe_multi_bit_bootstrap_result(
&self,
side_resources: &mut Self::SideResources,
) -> Self::Output {
self.allocate_lwe_multi_bit_blind_rotate_result(side_resources)
}
}
pub trait LweMultiBitFftBootstrap<Input, Output, Accumulator> {
type SideResources;
fn lwe_multi_bit_fft_bootstrap(
&self,
input: &Input,
output: &mut Output,
accumulator: &Accumulator,
side_resources: &mut Self::SideResources,
);
}
pub trait LweMultiBitFft128Bootstrap<Input, Output, Accumulator> {
type SideResources;
fn lwe_multi_bit_fft_128_bootstrap(
&self,
input: &Input,
output: &mut Output,
accumulator: &Accumulator,
side_resources: &mut Self::SideResources,
);
}

View File

@@ -9,7 +9,9 @@ pub mod scalar_mul;
pub use add_sub::{LweUncorrelatedAdd, LweUncorrelatedSub};
pub use lwe_keyswitch::{AllocateLweKeyswitchResult, LweKeyswitch};
pub use lwe_multi_bit_programmable_bootstrap::{
AllocateLweMultiBitBlindRotateResult, LweMultiBitFft128BlindRotate, LweMultiBitFftBlindRotate,
AllocateLweMultiBitBlindRotateResult, AllocateLweMultiBitBootstrapResult,
LweMultiBitFft128BlindRotate, LweMultiBitFft128Bootstrap, LweMultiBitFftBlindRotate,
LweMultiBitFftBootstrap,
};
pub use lwe_packing_keyswitch::{AllocateLwePackingKeyswitchResult, LwePackingKeyswitch};
pub use lwe_programmable_bootstrap::{

View File

@@ -3,6 +3,7 @@ use crate::core_crypto::commons::noise_formulas::lwe_multi_bit_programmable_boot
multi_bit_pbs_variance_132_bits_security_gaussian_gf_3_fft_mul,
multi_bit_pbs_variance_132_bits_security_tuniform_gf_3_fft_mul,
};
use crate::core_crypto::commons::noise_formulas::noise_simulation::PBS_FFT_64_MANTISSA_SIZE;
use crate::core_crypto::commons::noise_formulas::secure_noise::{
minimal_lwe_variance_for_132_bits_security_gaussian,
minimal_lwe_variance_for_132_bits_security_tuniform,
@@ -58,6 +59,7 @@ where
polynomial_size,
pbs_decomposition_base_log,
pbs_decomposition_level_count,
PBS_FFT_64_MANTISSA_SIZE,
modulus_as_f64,
)
}
@@ -68,6 +70,7 @@ where
polynomial_size,
pbs_decomposition_base_log,
pbs_decomposition_level_count,
PBS_FFT_64_MANTISSA_SIZE,
modulus_as_f64,
)
}

View File

@@ -1,5 +1,6 @@
use super::*;
use crate::core_crypto::commons::noise_formulas::lwe_programmable_bootstrap::pbs_variance_132_bits_security_gaussian_fft_mul;
use crate::core_crypto::commons::noise_formulas::noise_simulation::PBS_FFT_64_MANTISSA_SIZE;
use crate::core_crypto::commons::noise_formulas::secure_noise::minimal_lwe_variance_for_132_bits_security_gaussian;
use crate::core_crypto::commons::test_tools::{torus_modular_diff, variance};
use crate::core_crypto::gpu::glwe_ciphertext_list::CudaGlweCiphertextList;
@@ -49,6 +50,7 @@ where
polynomial_size,
pbs_decomposition_base_log,
pbs_decomposition_level_count,
PBS_FFT_64_MANTISSA_SIZE,
modulus_as_f64,
);

View File

@@ -10424,6 +10424,7 @@ pub unsafe fn unchecked_small_scalar_mul_integer_async(
carry_modulus.0 as u32,
);
}
#[allow(clippy::too_many_arguments)]
/// # Safety
///
@@ -10465,3 +10466,98 @@ pub unsafe fn extract_glwe_async<T: UnsignedInteger>(
panic!("Unsupported integer size for CUDA GLWE extraction");
}
}
#[allow(clippy::too_many_arguments)]
/// # Safety
///
/// - The data must not be moved or dropped while being used by the CUDA kernel.
/// - This function assumes exclusive access to the passed data; violating this may lead to
/// undefined behavior.
pub(crate) unsafe fn cuda_backend_trivium_generate_keystream<T: UnsignedInteger, B: Numeric>(
streams: &CudaStreams,
keystream_output: &mut CudaRadixCiphertext,
key: &CudaRadixCiphertext,
iv: &CudaRadixCiphertext,
bootstrapping_key: &CudaVec<B>,
keyswitch_key: &CudaVec<T>,
message_modulus: MessageModulus,
carry_modulus: CarryModulus,
glwe_dimension: GlweDimension,
polynomial_size: PolynomialSize,
lwe_dimension: LweDimension,
ks_level: DecompositionLevelCount,
ks_base_log: DecompositionBaseLog,
pbs_level: DecompositionLevelCount,
pbs_base_log: DecompositionBaseLog,
grouping_factor: LweBskGroupingFactor,
pbs_type: PBSType,
ms_noise_reduction_configuration: Option<&CudaModulusSwitchNoiseReductionConfiguration>,
num_steps: u32,
) {
let mut keystream_degrees = keystream_output
.info
.blocks
.iter()
.map(|b| b.degree.0)
.collect();
let mut keystream_noise_levels = keystream_output
.info
.blocks
.iter()
.map(|b| b.noise_level.0)
.collect();
let mut cuda_ffi_keystream = prepare_cuda_radix_ffi(
keystream_output,
&mut keystream_degrees,
&mut keystream_noise_levels,
);
let mut key_degrees = key.info.blocks.iter().map(|b| b.degree.0).collect();
let mut key_noise_levels = key.info.blocks.iter().map(|b| b.noise_level.0).collect();
let cuda_ffi_key = prepare_cuda_radix_ffi(key, &mut key_degrees, &mut key_noise_levels);
let mut iv_degrees = iv.info.blocks.iter().map(|b| b.degree.0).collect();
let mut iv_noise_levels = iv.info.blocks.iter().map(|b| b.noise_level.0).collect();
let cuda_ffi_iv = prepare_cuda_radix_ffi(iv, &mut iv_degrees, &mut iv_noise_levels);
let num_inputs = (key.info.blocks.len() / 80) as u32;
let noise_reduction_type = resolve_ms_noise_reduction_config(ms_noise_reduction_configuration);
let mut mem_ptr: *mut i8 = std::ptr::null_mut();
scratch_cuda_trivium_64(
streams.ffi(),
std::ptr::addr_of_mut!(mem_ptr),
glwe_dimension.0 as u32,
polynomial_size.0 as u32,
lwe_dimension.0 as u32,
ks_level.0 as u32,
ks_base_log.0 as u32,
pbs_level.0 as u32,
pbs_base_log.0 as u32,
grouping_factor.0 as u32,
message_modulus.0 as u32,
carry_modulus.0 as u32,
pbs_type as u32,
true,
noise_reduction_type as u32,
num_inputs,
);
cuda_trivium_generate_keystream_64(
streams.ffi(),
&raw mut cuda_ffi_keystream,
&raw const cuda_ffi_key,
&raw const cuda_ffi_iv,
num_inputs,
num_steps,
mem_ptr,
bootstrapping_key.ptr.as_ptr(),
keyswitch_key.ptr.as_ptr(),
);
cleanup_cuda_trivium_64(streams.ffi(), std::ptr::addr_of_mut!(mem_ptr));
update_noise_degree(keystream_output, &cuda_ffi_keystream);
}

View File

@@ -66,6 +66,7 @@ mod tests_noise_distribution;
mod tests_signed;
#[cfg(test)]
mod tests_unsigned;
mod trivium;
impl CudaServerKey {
/// Create a trivial ciphertext filled with zeros on the GPU.

View File

@@ -26,7 +26,7 @@ use crate::shortint::parameters::{AtomicPatternParameters, MetaParameters, Varia
use crate::shortint::server_key::tests::noise_distribution::br_dp_ks_ms::br_dp_ks_any_ms;
use crate::shortint::server_key::tests::noise_distribution::should_use_single_key_debug;
use crate::shortint::server_key::tests::noise_distribution::utils::noise_simulation::{
NoiseSimulationGlwe, NoiseSimulationLwe, NoiseSimulationLweFourierBsk,
NoiseSimulationGenericBootstrapKey, NoiseSimulationGlwe, NoiseSimulationLwe,
NoiseSimulationLweKeyswitchKey, NoiseSimulationModulus,
};
use crate::shortint::server_key::tests::noise_distribution::utils::{
@@ -360,7 +360,7 @@ fn noise_check_encrypt_br_dp_ks_ms_noise(params: MetaParameters) {
let noise_simulation_modulus_switch_config =
NoiseSimulationModulusSwitchConfig::new_from_atomic_pattern_parameters(params);
let noise_simulation_bsk =
NoiseSimulationLweFourierBsk::new_from_atomic_pattern_parameters(params);
NoiseSimulationGenericBootstrapKey::new_from_atomic_pattern_parameters(params);
let gpu_index = 0;
let streams = CudaStreams::new_single_gpu(GpuIndex::new(gpu_index));

View File

@@ -17,7 +17,7 @@ use crate::shortint::parameters::test_params::TEST_META_PARAM_CPU_2_2_KS_PBS_PKE
use crate::shortint::parameters::{CompressionParameters, MetaParameters, Variance};
use crate::shortint::server_key::tests::noise_distribution::br_dp_packingks_ms::br_dp_packing_ks_ms;
use crate::shortint::server_key::tests::noise_distribution::utils::noise_simulation::{
NoiseSimulationGlwe, NoiseSimulationLwe, NoiseSimulationLweFourierBsk,
NoiseSimulationGenericBootstrapKey, NoiseSimulationGlwe, NoiseSimulationLwe,
NoiseSimulationLwePackingKeyswitchKey, NoiseSimulationModulus,
};
use crate::shortint::server_key::tests::noise_distribution::utils::{
@@ -412,7 +412,7 @@ fn noise_check_encrypt_br_dp_packing_ks_ms_noise_gpu(meta_params: MetaParameters
let cuda_compression_key = compressed_compression_key.decompress_to_cuda(&streams);
let noise_simulation_bsk =
NoiseSimulationLweFourierBsk::new_from_atomic_pattern_parameters(params);
NoiseSimulationGenericBootstrapKey::new_from_atomic_pattern_parameters(params);
let noise_simulation_packing_key =
NoiseSimulationLwePackingKeyswitchKey::new_from_comp_parameters(params, comp_params);
@@ -591,7 +591,7 @@ fn noise_check_encrypt_br_dp_packing_ks_ms_pfail_gpu(meta_params: MetaParameters
assert_eq!(original_carry_modulus.0, 4);
let noise_simulation_bsk =
NoiseSimulationLweFourierBsk::new_from_atomic_pattern_parameters(params);
NoiseSimulationGenericBootstrapKey::new_from_atomic_pattern_parameters(params);
let noise_simulation_packing_key =
NoiseSimulationLwePackingKeyswitchKey::new_from_comp_parameters(params, comp_params);

View File

@@ -1,7 +1,5 @@
use super::utils::noise_simulation::{CudaDynLwe, CudaSideResources};
use crate::core_crypto::commons::noise_formulas::noise_simulation::{
NoiseSimulationLweFourier128Bsk, NoiseSimulationLwePackingKeyswitchKey,
};
use crate::core_crypto::commons::noise_formulas::noise_simulation::NoiseSimulationLwePackingKeyswitchKey;
use crate::core_crypto::gpu::glwe_ciphertext_list::CudaGlweCiphertextList;
use crate::core_crypto::gpu::CudaStreams;
use crate::core_crypto::prelude::{GlweCiphertext, LweCiphertextCount};
@@ -28,8 +26,9 @@ use crate::shortint::server_key::tests::noise_distribution::dp_ks_pbs128_packing
};
use crate::shortint::server_key::tests::noise_distribution::should_use_single_key_debug;
use crate::shortint::server_key::tests::noise_distribution::utils::noise_simulation::{
NoiseSimulationGlwe, NoiseSimulationLwe, NoiseSimulationLweFourierBsk,
NoiseSimulationLweKeyswitchKey, NoiseSimulationModulusSwitchConfig,
NoiseSimulationGenericBootstrapKey128, NoiseSimulationGlwe, NoiseSimulationLwe,
NoiseSimulationLweFourierBsk, NoiseSimulationLweKeyswitchKey,
NoiseSimulationModulusSwitchConfig,
};
use crate::shortint::server_key::tests::noise_distribution::utils::{
mean_and_variance_check, DecryptionAndNoiseResult, NoiseSample,
@@ -718,8 +717,10 @@ fn noise_check_encrypt_dp_ks_standard_pbs128_packing_ks_noise_gpu(meta_params: M
NoiseSimulationLweFourierBsk::new_from_atomic_pattern_parameters(atomic_params);
let noise_simulation_modulus_switch_config =
NoiseSimulationModulusSwitchConfig::new_from_atomic_pattern_parameters(atomic_params);
let noise_simulation_bsk128 =
NoiseSimulationLweFourier128Bsk::new_from_parameters(atomic_params, noise_squashing_params);
let noise_simulation_bsk128 = NoiseSimulationGenericBootstrapKey128::new_from_parameters(
atomic_params,
noise_squashing_params,
);
let noise_simulation_packing_key =
NoiseSimulationLwePackingKeyswitchKey::new_from_noise_squashing_parameters(
noise_squashing_params,

View File

@@ -1,9 +1,10 @@
use crate::core_crypto::commons::noise_formulas::noise_simulation::traits::{
AllocateCenteredBinaryShiftedStandardModSwitchResult,
AllocateDriftTechniqueStandardModSwitchResult, AllocateLweBootstrapResult,
AllocateLweKeyswitchResult, AllocateLwePackingKeyswitchResult, AllocateStandardModSwitchResult,
CenteredBinaryShiftedStandardModSwitch, DriftTechniqueStandardModSwitch,
LweClassicFftBootstrap, LweKeyswitch, ScalarMul, StandardModSwitch,
AllocateLweKeyswitchResult, AllocateLwePackingKeyswitchResult, AllocateMultiBitModSwitchResult,
AllocateStandardModSwitchResult, CenteredBinaryShiftedStandardModSwitch,
DriftTechniqueStandardModSwitch, LweClassicFft128Bootstrap, LweClassicFftBootstrap,
LweKeyswitch, MultiBitModSwitch, ScalarMul, StandardModSwitch,
};
use crate::core_crypto::commons::noise_formulas::noise_simulation::{
NoiseSimulationLweFourier128Bsk, NoiseSimulationLweFourierBsk,
@@ -25,8 +26,12 @@ use crate::integer::gpu::server_key::{
use crate::integer::gpu::{
cuda_centered_modulus_switch_64, unchecked_small_scalar_mul_integer_async, CudaStreams,
};
use crate::shortint::server_key::tests::noise_distribution::utils::noise_simulation::NoiseSimulationModulusSwitchConfig;
use crate::shortint::server_key::tests::noise_distribution::utils::traits::LwePackingKeyswitch;
use crate::shortint::server_key::tests::noise_distribution::utils::noise_simulation::{
NoiseSimulationGenericBootstrapKey, NoiseSimulationModulusSwitchConfig,
};
use crate::shortint::server_key::tests::noise_distribution::utils::traits::{
LweGenericBlindRotate128, LweGenericBootstrap, LwePackingKeyswitch,
};
/// Side resources for CUDA operations in noise simulation
#[derive(Clone)]
pub struct CudaSideResources {
@@ -228,19 +233,22 @@ impl NoiseSimulationLweFourierBsk {
&& decomp_base_log == bsk_decomp_base_log
&& decomp_level_count == bsk_decomp_level_count
}
CudaBootstrappingKey::MultiBit(cuda_mb_bsk) => {
let bsk_input_lwe_dimension = cuda_mb_bsk.input_lwe_dimension();
let bsk_glwe_size = cuda_mb_bsk.glwe_dimension().to_glwe_size();
let bsk_polynomial_size = cuda_mb_bsk.polynomial_size();
let bsk_decomp_base_log = cuda_mb_bsk.decomp_base_log();
let bsk_decomp_level_count = cuda_mb_bsk.decomp_level_count();
// MultiBit key cannot match classic key
CudaBootstrappingKey::MultiBit(_) => false,
}
}
}
input_lwe_dimension == bsk_input_lwe_dimension
&& glwe_size == bsk_glwe_size
&& polynomial_size == bsk_polynomial_size
&& decomp_base_log == bsk_decomp_base_log
&& decomp_level_count == bsk_decomp_level_count
// Extensions for NoiseSimulationGenericBootstrapKey to support GPU operations
impl NoiseSimulationGenericBootstrapKey {
pub fn matches_actual_bsk_gpu(&self, lwe_bsk: &CudaBootstrappingKey<u64>) -> bool {
match self {
Self::Classic(noise_simulation_lwe_fourier_bsk) => {
noise_simulation_lwe_fourier_bsk.matches_actual_bsk_gpu(lwe_bsk)
}
Self::MultiBit(_) => todo!(
"Implement the matching for NoiseSimulationLweMultiBitFourierBsk and forward here"
),
}
}
}
@@ -743,12 +751,8 @@ impl AllocateLweBootstrapResult for CudaGlweCiphertextList<u128> {
}
// Implement LweClassicFft128Bootstrap for CudaNoiseSquashingKey using 128-bit PBS CUDA function
impl
crate::core_crypto::commons::noise_formulas::noise_simulation::traits::LweClassicFft128Bootstrap<
CudaDynLwe,
CudaDynLwe,
CudaGlweCiphertextList<u128>,
> for crate::integer::gpu::noise_squashing::keys::CudaNoiseSquashingKey
impl LweClassicFft128Bootstrap<CudaDynLwe, CudaDynLwe, CudaGlweCiphertextList<u128>>
for CudaNoiseSquashingKey
{
type SideResources = CudaSideResources;
@@ -931,3 +935,83 @@ impl LwePackingKeyswitch<[&CudaDynLwe], CudaGlweCiphertextList<u128>>
);
}
}
// Multi bit and generic extensions
impl LweGenericBootstrap<CudaDynLwe, CudaDynLwe, CudaGlweCiphertextList<u64>> for CudaServerKey {
type SideResources = CudaSideResources;
fn lwe_generic_bootstrap(
&self,
input: &CudaDynLwe,
output: &mut CudaDynLwe,
accumulator: &CudaGlweCiphertextList<u64>,
side_resources: &mut Self::SideResources,
) {
match self.bootstrapping_key {
CudaBootstrappingKey::Classic(_) => {
self.lwe_classic_fft_pbs(input, output, accumulator, side_resources);
}
CudaBootstrappingKey::MultiBit(_) => {
todo!("TODO: this currently only manages classic PBS")
}
}
}
}
impl AllocateMultiBitModSwitchResult for CudaDynLwe {
type Output = Self;
type SideResources = CudaSideResources;
fn allocate_multi_bit_mod_switch_result(
&self,
_side_resources: &mut Self::SideResources,
) -> Self::Output {
todo!(
"TODO: the output type likely needs to be a specialized enum CudaDynModSwitchedLwe\n\
See shortint CPU impls, the standard mod switch results likely \
need an update for the output type"
)
}
}
impl MultiBitModSwitch<Self> for CudaDynLwe {
type SideResources = CudaSideResources;
fn multi_bit_mod_switch(
&self,
_grouping_factor: LweBskGroupingFactor,
_output_modulus_log: CiphertextModulusLog,
_output: &mut Self,
_side_resources: &mut Self::SideResources,
) {
todo!(
"TODO: the output type likely needs to be a specialized enum CudaDynModSwitchedLwe\n\
See shortint CPU impls, the standard mod switch results likely \
need an update for the output type"
)
}
}
impl LweGenericBlindRotate128<CudaDynLwe, CudaDynLwe, CudaGlweCiphertextList<u128>>
for CudaNoiseSquashingKey
{
type SideResources = CudaSideResources;
fn lwe_generic_blind_rotate_128(
&self,
input: &CudaDynLwe,
output: &mut CudaDynLwe,
accumulator: &CudaGlweCiphertextList<u128>,
side_resources: &mut Self::SideResources,
) {
match self.bootstrapping_key {
CudaBootstrappingKey::Classic(_) => {
self.lwe_classic_fft_128_pbs(input, output, accumulator, side_resources)
}
CudaBootstrappingKey::MultiBit(_) => todo!(
"CPU manages this by taking a modswitched type to be able to apply \
the blind rotate correctly without redoing the modswitch, to adapt for the GPU case"
),
}
}
}

View File

@@ -20,6 +20,7 @@ pub(crate) mod test_scalar_shift;
pub(crate) mod test_scalar_sub;
pub(crate) mod test_shift;
pub(crate) mod test_sub;
pub(crate) mod test_trivium;
pub(crate) mod test_vector_comparisons;
pub(crate) mod test_vector_find;
@@ -85,6 +86,50 @@ impl<F> GpuFunctionExecutor<F> {
}
}
impl<'a, F>
FunctionExecutor<
(&'a RadixCiphertext, &'a RadixCiphertext, usize),
crate::Result<RadixCiphertext>,
> for GpuFunctionExecutor<F>
where
F: Fn(
&CudaServerKey,
&CudaUnsignedRadixCiphertext,
&CudaUnsignedRadixCiphertext,
usize,
&CudaStreams,
) -> crate::Result<CudaUnsignedRadixCiphertext>,
{
fn setup(&mut self, cks: &RadixClientKey, sks: Arc<ServerKey>) {
self.setup_from_keys(cks, &sks);
}
fn execute(
&mut self,
input: (&'a RadixCiphertext, &'a RadixCiphertext, usize),
) -> crate::Result<RadixCiphertext> {
let context = self
.context
.as_ref()
.expect("setup was not properly called");
let d_ctxt_1 =
CudaUnsignedRadixCiphertext::from_radix_ciphertext(input.0, &context.streams);
let d_ctxt_2 =
CudaUnsignedRadixCiphertext::from_radix_ciphertext(input.1, &context.streams);
let gpu_result = (self.func)(
&context.sks,
&d_ctxt_1,
&d_ctxt_2,
input.2,
&context.streams,
)?;
Ok(gpu_result.to_radix_ciphertext(&context.streams))
}
}
impl<'a, F>
FunctionExecutor<
(&'a RadixCiphertext, &'a RadixCiphertext, u128, usize, usize),

View File

@@ -0,0 +1,71 @@
use crate::integer::gpu::server_key::radix::tests_unsigned::{
create_gpu_parameterized_test, GpuFunctionExecutor,
};
use crate::integer::gpu::CudaServerKey;
use crate::integer::server_key::radix_parallel::tests_unsigned::test_trivium::{
trivium_comparison_test, trivium_test_vector_1_test, trivium_test_vector_2_test,
trivium_test_vector_3_test, trivium_test_vector_4_test,
};
use crate::shortint::parameters::{
TestParameters, PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128,
};
create_gpu_parameterized_test!(integer_trivium_test_vector_1 {
PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128
});
create_gpu_parameterized_test!(integer_trivium_test_vector_2 {
PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128
});
create_gpu_parameterized_test!(integer_trivium_test_vector_3 {
PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128
});
create_gpu_parameterized_test!(integer_trivium_test_vector_4 {
PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128
});
create_gpu_parameterized_test!(integer_trivium_comparison {
PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128
});
fn integer_trivium_test_vector_1<P>(param: P)
where
P: Into<TestParameters>,
{
let executor = GpuFunctionExecutor::new(&CudaServerKey::trivium_generate_keystream);
trivium_test_vector_1_test(param, executor);
}
fn integer_trivium_test_vector_2<P>(param: P)
where
P: Into<TestParameters>,
{
let executor = GpuFunctionExecutor::new(&CudaServerKey::trivium_generate_keystream);
trivium_test_vector_2_test(param, executor);
}
fn integer_trivium_test_vector_3<P>(param: P)
where
P: Into<TestParameters>,
{
let executor = GpuFunctionExecutor::new(&CudaServerKey::trivium_generate_keystream);
trivium_test_vector_3_test(param, executor);
}
fn integer_trivium_test_vector_4<P>(param: P)
where
P: Into<TestParameters>,
{
let executor = GpuFunctionExecutor::new(&CudaServerKey::trivium_generate_keystream);
trivium_test_vector_4_test(param, executor);
}
fn integer_trivium_comparison<P>(param: P)
where
P: Into<TestParameters>,
{
let executor = GpuFunctionExecutor::new(&CudaServerKey::trivium_generate_keystream);
trivium_comparison_test(param, executor);
}

View File

@@ -0,0 +1,111 @@
use crate::core_crypto::gpu::CudaStreams;
use crate::integer::gpu::ciphertext::{CudaIntegerRadixCiphertext, CudaUnsignedRadixCiphertext};
use crate::integer::gpu::server_key::{
CudaBootstrappingKey, CudaDynamicKeyswitchingKey, CudaServerKey,
};
use crate::integer::gpu::{cuda_backend_trivium_generate_keystream, LweBskGroupingFactor, PBSType};
impl CudaServerKey {
/// Generates a Trivium keystream homomorphically on the GPU.
///
/// # Arguments
/// * `key` - The encrypted secret key.
/// * `iv` - The encrypted initialization vector.
/// * `num_steps` - The number of keystream bits to generate per input.
/// * `streams` - The CUDA streams to use for execution.
pub fn trivium_generate_keystream(
&self,
key: &CudaUnsignedRadixCiphertext,
iv: &CudaUnsignedRadixCiphertext,
num_steps: usize,
streams: &CudaStreams,
) -> crate::Result<CudaUnsignedRadixCiphertext> {
let num_key_bits = 80;
let num_iv_bits = 80;
let batch_size = 64;
if key.as_ref().d_blocks.lwe_ciphertext_count().0 != num_key_bits {
return Err(format!(
"Input key must contain {} encrypted bits, but contains {}",
num_key_bits,
key.as_ref().d_blocks.lwe_ciphertext_count().0
)
.into());
}
if iv.as_ref().d_blocks.lwe_ciphertext_count().0 != num_iv_bits {
return Err(format!(
"Input IV must contain {} encrypted bits, but contains {}",
num_iv_bits,
iv.as_ref().d_blocks.lwe_ciphertext_count().0
)
.into());
}
if !num_steps.is_multiple_of(batch_size) {
return Err(format!(
"The number of steps must be a multiple of {batch_size}, but is {num_steps}"
)
.into());
}
let num_output_bits = num_steps;
let mut keystream: CudaUnsignedRadixCiphertext =
self.create_trivial_zero_radix(num_output_bits, streams);
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
panic!("Only the standard atomic pattern is supported on GPU")
};
unsafe {
match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => {
cuda_backend_trivium_generate_keystream(
streams,
keystream.as_mut(),
key.as_ref(),
iv.as_ref(),
&d_bsk.d_vec,
&computing_ks_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
d_bsk.input_lwe_dimension,
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_bsk.decomp_level_count,
d_bsk.decomp_base_log,
LweBskGroupingFactor(0),
PBSType::Classical,
d_bsk.ms_noise_reduction_configuration.as_ref(),
num_steps as u32,
);
}
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
cuda_backend_trivium_generate_keystream(
streams,
keystream.as_mut(),
key.as_ref(),
iv.as_ref(),
&d_multibit_bsk.d_vec,
&computing_ks_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_multibit_bsk.glwe_dimension,
d_multibit_bsk.polynomial_size,
d_multibit_bsk.input_lwe_dimension,
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count,
d_multibit_bsk.decomp_base_log,
d_multibit_bsk.grouping_factor,
PBSType::MultiBit,
None,
num_steps as u32,
);
}
}
}
Ok(keystream)
}
}

View File

@@ -28,6 +28,8 @@ pub(crate) mod test_shift;
pub(crate) mod test_slice;
pub(crate) mod test_sub;
pub(crate) mod test_sum;
#[cfg(feature = "gpu")]
pub(crate) mod test_trivium;
pub(crate) mod test_vector_comparisons;
pub(crate) mod test_vector_find;

View File

@@ -0,0 +1,306 @@
use crate::integer::keycache::KEY_CACHE;
use crate::integer::server_key::radix_parallel::tests_cases_unsigned::FunctionExecutor;
use crate::integer::{IntegerKeyKind, RadixCiphertext, RadixClientKey};
use crate::shortint::parameters::TestParameters;
use rand::Rng;
use std::sync::Arc;
fn encrypt_bits(cks: &RadixClientKey, bits: &[u64]) -> RadixCiphertext {
RadixCiphertext::from(
bits.iter()
.map(|&bit| cks.encrypt_one_block(bit))
.collect::<Vec<_>>(),
)
}
fn decrypt_bits(cks: &RadixClientKey, ct: &RadixCiphertext) -> Vec<u8> {
ct.blocks
.iter()
.map(|block| cks.decrypt_one_block(block) as u8)
.collect()
}
struct TriviumRef {
a: Vec<u8>,
b: Vec<u8>,
c: Vec<u8>,
}
impl TriviumRef {
fn new(key: &[u8], iv: &[u8]) -> Self {
let mut a = vec![0u8; 93];
let mut b = vec![0u8; 84];
let mut c = vec![0u8; 111];
for i in 0..80 {
a[i] = key[79 - i];
b[i] = iv[79 - i];
}
c[108] = 1;
c[109] = 1;
c[110] = 1;
let mut triv = Self { a, b, c };
for _ in 0..(18 * 64) {
triv.next();
}
triv
}
fn next(&mut self) -> u8 {
let t1 = self.a[65] ^ self.a[92];
let t2 = self.b[68] ^ self.b[83];
let t3 = self.c[65] ^ self.c[110];
let out = t1 ^ t2 ^ t3;
let a_in = t3 ^ self.a[68] ^ (self.c[108] & self.c[109]);
let b_in = t1 ^ self.b[77] ^ (self.a[90] & self.a[91]);
let c_in = t2 ^ self.c[86] ^ (self.b[81] & self.b[82]);
self.a.pop();
self.a.insert(0, a_in);
self.b.pop();
self.b.insert(0, b_in);
self.c.pop();
self.c.insert(0, c_in);
out
}
}
#[test]
fn test_trivium_ref_consistency() {
let key = vec![0u8; 80];
let iv = vec![0u8; 80];
let expected_hex = "FBE0BF265859051B";
let mut trivium = TriviumRef::new(&key, &iv);
let mut output_bits = Vec::new();
for _ in 0..64 {
output_bits.push(trivium.next());
}
let packed = get_hexadecimal_string_from_lsb_first_stream(&output_bits);
assert_eq!(&packed[0..16], expected_hex);
}
fn get_hexadecimal_string_from_lsb_first_stream(a: &[u8]) -> String {
assert!(a.len().is_multiple_of(8));
let mut hexadecimal = String::new();
for test in a.chunks(8) {
let to_hex = |chunk: &[u8]| -> char {
let mut val = 0u8;
if chunk[0] == 1 {
val |= 1;
}
if chunk[1] == 1 {
val |= 2;
}
if chunk[2] == 1 {
val |= 4;
}
if chunk[3] == 1 {
val |= 8;
}
match val {
0..=9 => (val + b'0') as char,
10..=15 => (val - 10 + b'A') as char,
_ => unreachable!(),
}
};
hexadecimal.push(to_hex(&test[4..8]));
hexadecimal.push(to_hex(&test[0..4]));
}
hexadecimal
}
pub fn trivium_test_vector_1_test<P, E>(param: P, mut executor: E)
where
P: Into<TestParameters>,
E: for<'a> FunctionExecutor<
(&'a RadixCiphertext, &'a RadixCiphertext, usize),
crate::Result<RadixCiphertext>,
>,
{
let param = param.into();
let (cks, sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
let cks = RadixClientKey::from((cks, 1));
let sks = Arc::new(sks);
executor.setup(&cks, sks);
let key = vec![0u64; 80];
let iv = vec![0u64; 80];
let expected_output_0_63 = "FBE0BF265859051B517A2E4E239FC97F563203161907CF2DE7A8790FA1B2E9CDF75292030268B7382B4C1A759AA2599A285549986E74805903801A4CB5A5D4F2";
let ct_key = encrypt_bits(&cks, &key);
let ct_iv = encrypt_bits(&cks, &iv);
let num_steps = 512;
let output_radix = executor.execute((&ct_key, &ct_iv, num_steps)).unwrap();
let decrypted_bits = decrypt_bits(&cks, &output_radix);
let hex_string = get_hexadecimal_string_from_lsb_first_stream(&decrypted_bits);
assert_eq!(expected_output_0_63, &hex_string[0..64 * 2]);
}
pub fn trivium_test_vector_2_test<P, E>(param: P, mut executor: E)
where
P: Into<TestParameters>,
E: for<'a> FunctionExecutor<
(&'a RadixCiphertext, &'a RadixCiphertext, usize),
crate::Result<RadixCiphertext>,
>,
{
let param = param.into();
let (cks, sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
let cks = RadixClientKey::from((cks, 1));
let sks = Arc::new(sks);
executor.setup(&cks, sks);
let mut key = vec![0u64; 80];
let iv = vec![0u64; 80];
key[7] = 1;
let expected_output_0_63 = "38EB86FF730D7A9CAF8DF13A4420540DBB7B651464C87501552041C249F29A64D2FBF515610921EBE06C8F92CECF7F8098FF20CCCC6A62B97BE8EF7454FC80F9";
let ct_key = encrypt_bits(&cks, &key);
let ct_iv = encrypt_bits(&cks, &iv);
let num_steps = 512;
let output_radix = executor.execute((&ct_key, &ct_iv, num_steps)).unwrap();
let decrypted_bits = decrypt_bits(&cks, &output_radix);
let hex_string = get_hexadecimal_string_from_lsb_first_stream(&decrypted_bits);
assert_eq!(expected_output_0_63, &hex_string[0..64 * 2]);
}
pub fn trivium_test_vector_3_test<P, E>(param: P, mut executor: E)
where
P: Into<TestParameters>,
E: for<'a> FunctionExecutor<
(&'a RadixCiphertext, &'a RadixCiphertext, usize),
crate::Result<RadixCiphertext>,
>,
{
let param = param.into();
let (cks, sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
let cks = RadixClientKey::from((cks, 1));
let sks = Arc::new(sks);
executor.setup(&cks, sks);
let key = vec![0u64; 80];
let mut iv = vec![0u64; 80];
iv[7] = 1;
let expected_output_0_63 = "F8901736640549E3BA7D42EA2D07B9F49233C18D773008BD755585B1A8CBAB86C1E9A9B91F1AD33483FD6EE3696D659C9374260456A36AAE11F033A519CBD5D7";
let ct_key = encrypt_bits(&cks, &key);
let ct_iv = encrypt_bits(&cks, &iv);
let num_steps = 512;
let output_radix = executor.execute((&ct_key, &ct_iv, num_steps)).unwrap();
let decrypted_bits = decrypt_bits(&cks, &output_radix);
let hex_string = get_hexadecimal_string_from_lsb_first_stream(&decrypted_bits);
assert_eq!(expected_output_0_63, &hex_string[0..64 * 2]);
}
pub fn trivium_test_vector_4_test<P, E>(param: P, mut executor: E)
where
P: Into<TestParameters>,
E: for<'a> FunctionExecutor<
(&'a RadixCiphertext, &'a RadixCiphertext, usize),
crate::Result<RadixCiphertext>,
>,
{
let param = param.into();
let (cks, sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
let cks = RadixClientKey::from((cks, 1));
let sks = Arc::new(sks);
executor.setup(&cks, sks);
let key_string = "0053A6F94C9FF24598EB";
let mut key = vec![0u64; 80];
for i in (0..key_string.len()).step_by(2) {
let mut val = u8::from_str_radix(&key_string[i..i + 2], 16).unwrap();
for j in 0..8 {
key[8 * (i >> 1) + j] = (val % 2) as u64;
val >>= 1;
}
}
let iv_string = "0D74DB42A91077DE45AC";
let mut iv = vec![0u64; 80];
for i in (0..iv_string.len()).step_by(2) {
let mut val = u8::from_str_radix(&iv_string[i..i + 2], 16).unwrap();
for j in 0..8 {
iv[8 * (i >> 1) + j] = (val % 2) as u64;
val >>= 1;
}
}
let expected_output_0_63 = "F4CD954A717F26A7D6930830C4E7CF0819F80E03F25F342C64ADC66ABA7F8A8E6EAA49F23632AE3CD41A7BD290A0132F81C6D4043B6E397D7388F3A03B5FE358";
let ct_key = encrypt_bits(&cks, &key);
let ct_iv = encrypt_bits(&cks, &iv);
let num_steps = 512;
let output_radix = executor.execute((&ct_key, &ct_iv, num_steps)).unwrap();
let decrypted_bits = decrypt_bits(&cks, &output_radix);
let hex_string = get_hexadecimal_string_from_lsb_first_stream(&decrypted_bits);
assert_eq!(expected_output_0_63, &hex_string[0..64 * 2]);
}
pub fn trivium_comparison_test<P, E>(param: P, mut executor: E)
where
P: Into<TestParameters>,
E: for<'a> FunctionExecutor<
(&'a RadixCiphertext, &'a RadixCiphertext, usize),
crate::Result<RadixCiphertext>,
>,
{
let param = param.into();
let (cks, sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
let cks = RadixClientKey::from((cks, 1));
let sks = Arc::new(sks);
executor.setup(&cks, sks);
let num_runs = 5;
let num_steps = 512;
for i in 0..num_runs {
let mut rng = rand::thread_rng();
let plain_key: Vec<u8> = (0..80).map(|_| rng.gen_range(0..=1)).collect();
let plain_iv: Vec<u8> = (0..80).map(|_| rng.gen_range(0..=1)).collect();
let key_bits_u64: Vec<u64> = plain_key.iter().map(|&x| x as u64).collect();
let iv_bits_u64: Vec<u64> = plain_iv.iter().map(|&x| x as u64).collect();
let ct_key = encrypt_bits(&cks, &key_bits_u64);
let ct_iv = encrypt_bits(&cks, &iv_bits_u64);
let mut cpu_trivium = TriviumRef::new(&plain_key, &plain_iv);
let mut cpu_output = Vec::with_capacity(num_steps);
for _ in 0..num_steps {
cpu_output.push(cpu_trivium.next());
}
let output_radix = executor.execute((&ct_key, &ct_iv, num_steps)).unwrap();
let fhe_output = decrypt_bits(&cks, &output_radix);
assert_eq!(cpu_output.len(), fhe_output.len());
assert_eq!(cpu_output, fhe_output, "Mismatch at iteration {i}");
}
}

View File

@@ -624,6 +624,24 @@ impl AtomicPatternParameters {
}
}
}
pub fn set_deterministic_execution(&mut self, do_it: bool) {
match self {
Self::Standard(pbsparameters) => match pbsparameters {
PBSParameters::PBS(_) => (),
PBSParameters::MultiBitPBS(multi_bit_pbsparameters) => {
multi_bit_pbsparameters.deterministic_execution = do_it
}
},
Self::KeySwitch32(_) => (),
}
}
pub fn with_deterministic_execution(mut self) -> Self {
self.set_deterministic_execution(true);
self
}
}
impl ParameterSetConformant for AtomicPatternServerKey {

View File

@@ -36,6 +36,7 @@ pub(crate) fn compute_delta<Scalar: UnsignedInteger + CastFrom<u64>>(
}
}
#[derive(Clone, Copy)]
pub(crate) struct ShortintEncoding<Scalar: UnsignedInteger> {
pub(crate) ciphertext_modulus: CoreCiphertextModulus<Scalar>,
pub(crate) message_modulus: MessageModulus,

View File

@@ -99,6 +99,19 @@ impl NoiseSquashingParameters {
}
}
}
pub fn set_deterministic_execution(&mut self, do_it: bool) {
match self {
Self::Classic(_) => (),
Self::MultiBit(noise_squashing_multi_bit_parameters) => {
noise_squashing_multi_bit_parameters.deterministic_execution = do_it
}
}
}
pub fn with_deterministic_execution(mut self) -> Self {
self.set_deterministic_execution(true);
self
}
}
#[derive(Copy, Clone, Debug, PartialEq, Serialize, Deserialize, Versionize)]

View File

@@ -1,4 +1,5 @@
use super::current_params::meta::cpu::*;
use super::current_params::meta::gpu::*;
use super::current_params::*;
use super::{
AtomicPatternParameters, ClassicPBSParameters, CompactPublicKeyEncryptionParameters,
@@ -229,6 +230,7 @@ pub const TEST_PARAM_NOISE_SQUASHING_COMP_PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFOR
pub const TEST_PARAM_NOISE_SQUASHING_PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128:
NoiseSquashingParameters = V1_5_NOISE_SQUASHING_PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128;
// Meta params
pub const TEST_META_PARAM_CPU_2_2_KS_PBS_PKE_TO_SMALL_ZKV2_TUNIFORM_2M128: MetaParameters =
V1_5_META_PARAM_CPU_2_2_KS_PBS_PKE_TO_SMALL_ZKV2_TUNIFORM_2M128;
@@ -237,3 +239,8 @@ pub const TEST_META_PARAM_CPU_2_2_KS32_PBS_PKE_TO_SMALL_ZKV2_TUNIFORM_2M128: Met
pub const TEST_META_PARAM_CPU_2_2_KS_PBS_GAUSSIAN_2M128: MetaParameters =
V1_5_META_PARAM_CPU_2_2_KS_PBS_GAUSSIAN_2M128;
// GPU params we want to check for full scenario
pub const TEST_META_PARAM_GPU_2_2_MULTI_BIT_GROUP_4_KS_PBS_PKE_TO_SMALL_ZKV2_TUNIFORM_2M128:
MetaParameters =
V1_5_META_PARAM_GPU_2_2_MULTI_BIT_GROUP_4_KS_PBS_PKE_TO_SMALL_ZKV2_TUNIFORM_2M128;

View File

@@ -1,6 +1,6 @@
use super::dp_ks_ms::dp_ks_any_ms;
use super::utils::noise_simulation::{
NoiseSimulationGlwe, NoiseSimulationLwe, NoiseSimulationLweFourierBsk,
NoiseSimulationGenericBootstrapKey, NoiseSimulationGlwe, NoiseSimulationLwe,
NoiseSimulationLweKeyswitchKey, NoiseSimulationModulusSwitchConfig,
};
use super::utils::traits::*;
@@ -13,20 +13,19 @@ use crate::core_crypto::commons::dispersion::Variance;
use crate::core_crypto::commons::parameters::CiphertextModulusLog;
use crate::shortint::atomic_pattern::AtomicPattern;
use crate::shortint::ciphertext::NoiseLevel;
use crate::shortint::client_key::atomic_pattern::AtomicPatternClientKey;
use crate::shortint::client_key::ClientKey;
use crate::shortint::encoding::ShortintEncoding;
use crate::shortint::engine::ShortintEngine;
use crate::shortint::parameters::test_params::{
TEST_META_PARAM_CPU_2_2_KS32_PBS_PKE_TO_SMALL_ZKV2_TUNIFORM_2M128,
TEST_META_PARAM_CPU_2_2_KS_PBS_GAUSSIAN_2M128,
TEST_META_PARAM_CPU_2_2_KS_PBS_PKE_TO_SMALL_ZKV2_TUNIFORM_2M128,
TEST_META_PARAM_GPU_2_2_MULTI_BIT_GROUP_4_KS_PBS_PKE_TO_SMALL_ZKV2_TUNIFORM_2M128,
};
use crate::shortint::parameters::{AtomicPatternParameters, CarryModulus, MetaParameters};
use crate::shortint::server_key::tests::noise_distribution::utils::noise_simulation::NoiseSimulationModulus;
use crate::shortint::server_key::tests::parameterized_test::create_parameterized_test;
use crate::shortint::server_key::ServerKey;
use crate::shortint::{Ciphertext, PaddingBit};
use crate::shortint::Ciphertext;
use rayon::prelude::*;
#[allow(clippy::too_many_arguments)]
@@ -63,7 +62,7 @@ pub fn br_dp_ks_any_ms<
where
// We need to be able to allocate the result and bootstrap the Input
Accumulator: AllocateLweBootstrapResult<Output = PBSResult, SideResources = Resources>,
PBSKey: LweClassicFftBootstrap<InputCt, PBSResult, Accumulator, SideResources = Resources>,
PBSKey: LweGenericBootstrap<InputCt, PBSResult, Accumulator, SideResources = Resources>,
// Result of the PBS/Blind rotate needs to be multipliable by the scalar
PBSResult: ScalarMul<DPScalar, Output = ScalarMulResult, SideResources = Resources>,
// We need to be able to allocate the result and keyswitch the result of the ScalarMul
@@ -74,7 +73,9 @@ where
+ AllocateCenteredBinaryShiftedStandardModSwitchResult<
Output = MsResult,
SideResources = Resources,
> + CenteredBinaryShiftedStandardModSwitch<MsResult, SideResources = Resources>,
> + CenteredBinaryShiftedStandardModSwitch<MsResult, SideResources = Resources>
+ AllocateMultiBitModSwitchResult<Output = MsResult, SideResources = Resources>
+ MultiBitModSwitch<MsResult, SideResources = Resources>,
// We need to be able to allocate the result and apply drift technique + mod switch it
DriftKey: AllocateDriftTechniqueStandardModSwitchResult<
AfterDriftOutput = DriftTechniqueResult,
@@ -88,7 +89,7 @@ where
>,
{
let mut pbs_result = accumulator.allocate_lwe_bootstrap_result(side_resources);
bsk.lwe_classic_fft_pbs(&input, &mut pbs_result, accumulator, side_resources);
bsk.lwe_generic_bootstrap(&input, &mut pbs_result, accumulator, side_resources);
let (pbs_result, after_dp, ks_result, drift_technique_result, ms_result) = dp_ks_any_ms(
pbs_result,
scalar,
@@ -111,7 +112,9 @@ where
/// Test function to verify that the noise checking tools match the actual atomic patterns
/// implemented in shortint
fn sanity_check_encrypt_br_dp_ks_pbs(meta_params: MetaParameters) {
let params = meta_params.compute_parameters;
let params = meta_params
.compute_parameters
.with_deterministic_execution();
let cks = ClientKey::new(params);
let sks = ServerKey::new(&cks);
@@ -139,7 +142,8 @@ fn sanity_check_encrypt_br_dp_ks_pbs(meta_params: MetaParameters) {
// Complete the AP by computing the PBS to match shortint
let mut pbs_result = id_lut.allocate_lwe_bootstrap_result(&mut ());
sks.lwe_classic_fft_pbs(&ms_result, &mut pbs_result, &id_lut, &mut ());
sks.apply_generic_blind_rotation(&ms_result, &mut pbs_result, &id_lut);
// Shortint APIs are not granular enough to compare ciphertexts at the MS level
// and inject arbitrary LWEs as input to the blind rotate step of the PBS.
@@ -169,6 +173,7 @@ create_parameterized_test!(sanity_check_encrypt_br_dp_ks_pbs {
TEST_META_PARAM_CPU_2_2_KS_PBS_GAUSSIAN_2M128,
TEST_META_PARAM_CPU_2_2_KS_PBS_PKE_TO_SMALL_ZKV2_TUNIFORM_2M128,
TEST_META_PARAM_CPU_2_2_KS32_PBS_PKE_TO_SMALL_ZKV2_TUNIFORM_2M128,
TEST_META_PARAM_GPU_2_2_MULTI_BIT_GROUP_4_KS_PBS_PKE_TO_SMALL_ZKV2_TUNIFORM_2M128,
});
fn encrypt_br_dp_ks_any_ms_inner_helper(
@@ -217,113 +222,21 @@ fn encrypt_br_dp_ks_any_ms_inner_helper(
let before_ms = after_drift.as_ref().unwrap_or(&after_ks);
match &cks.atomic_pattern {
AtomicPatternClientKey::Standard(standard_atomic_pattern_client_key) => {
let params = standard_atomic_pattern_client_key.parameters;
let output_encoding = ShortintEncoding {
ciphertext_modulus: params.ciphertext_modulus(),
message_modulus: params.message_modulus(),
carry_modulus: params.carry_modulus(),
padding_bit: PaddingBit::Yes,
};
let large_lwe_secret_key = standard_atomic_pattern_client_key.large_lwe_secret_key();
let small_lwe_secret_key = standard_atomic_pattern_client_key.small_lwe_secret_key();
(
DecryptionAndNoiseResult::new_from_lwe(
&input.as_lwe_64(),
&small_lwe_secret_key,
msg,
&output_encoding,
),
DecryptionAndNoiseResult::new_from_lwe(
&after_br.as_lwe_64(),
&large_lwe_secret_key,
msg,
&output_encoding,
),
DecryptionAndNoiseResult::new_from_lwe(
&after_dp.as_lwe_64(),
&large_lwe_secret_key,
msg,
&output_encoding,
),
DecryptionAndNoiseResult::new_from_lwe(
&after_ks.as_lwe_64(),
&small_lwe_secret_key,
msg,
&output_encoding,
),
DecryptionAndNoiseResult::new_from_lwe(
&before_ms.as_lwe_64(),
&small_lwe_secret_key,
msg,
&output_encoding,
),
DecryptionAndNoiseResult::new_from_lwe(
&after_ms.as_lwe_64(),
&small_lwe_secret_key,
msg,
&output_encoding,
),
)
}
AtomicPatternClientKey::KeySwitch32(ks32_atomic_pattern_client_key) => {
let msg_u32: u32 = msg.try_into().unwrap();
let params = ks32_atomic_pattern_client_key.parameters;
let small_key_encoding = ShortintEncoding {
ciphertext_modulus: params.post_keyswitch_ciphertext_modulus(),
message_modulus: params.message_modulus(),
carry_modulus: params.carry_modulus(),
padding_bit: PaddingBit::Yes,
};
let big_key_encoding = ShortintEncoding {
ciphertext_modulus: params.ciphertext_modulus(),
message_modulus: params.message_modulus(),
carry_modulus: params.carry_modulus(),
padding_bit: PaddingBit::Yes,
};
let large_lwe_secret_key = ks32_atomic_pattern_client_key.large_lwe_secret_key();
let small_lwe_secret_key = ks32_atomic_pattern_client_key.small_lwe_secret_key();
(
DecryptionAndNoiseResult::new_from_lwe(
&input.as_lwe_32(),
&small_lwe_secret_key,
msg_u32,
&small_key_encoding,
),
DecryptionAndNoiseResult::new_from_lwe(
&after_br.as_lwe_64(),
&large_lwe_secret_key,
msg,
&big_key_encoding,
),
DecryptionAndNoiseResult::new_from_lwe(
&after_dp.as_lwe_64(),
&large_lwe_secret_key,
msg,
&big_key_encoding,
),
DecryptionAndNoiseResult::new_from_lwe(
&after_ks.as_lwe_32(),
&small_lwe_secret_key,
msg_u32,
&small_key_encoding,
),
DecryptionAndNoiseResult::new_from_lwe(
&before_ms.as_lwe_32(),
&small_lwe_secret_key,
msg_u32,
&small_key_encoding,
),
DecryptionAndNoiseResult::new_from_lwe(
&after_ms.as_lwe_32(),
&small_lwe_secret_key,
msg_u32,
&small_key_encoding,
),
)
}
}
let large_lwe_secret_key_dyn = cks.large_lwe_secret_key_as_dyn();
let small_lwe_secret_key_dyn = cks.small_lwe_secret_key_as_dyn();
(
DecryptionAndNoiseResult::new_from_dyn_lwe(&input, &small_lwe_secret_key_dyn, msg),
DecryptionAndNoiseResult::new_from_dyn_lwe(&after_br, &large_lwe_secret_key_dyn, msg),
DecryptionAndNoiseResult::new_from_dyn_lwe(&after_dp, &large_lwe_secret_key_dyn, msg),
DecryptionAndNoiseResult::new_from_dyn_lwe(&after_ks, &small_lwe_secret_key_dyn, msg),
DecryptionAndNoiseResult::new_from_dyn_lwe(before_ms, &small_lwe_secret_key_dyn, msg),
DecryptionAndNoiseResult::new_from_dyn_modswitched_lwe(
&after_ms,
&small_lwe_secret_key_dyn,
msg,
),
)
}
fn encrypt_br_dp_ks_any_ms_noise_helper(
@@ -391,7 +304,9 @@ fn encrypt_br_dp_ks_any_ms_pfail_helper(
}
fn noise_check_encrypt_br_dp_ks_ms_noise(meta_params: MetaParameters) {
let params = meta_params.compute_parameters;
let params = meta_params
.compute_parameters
.with_deterministic_execution();
let cks = ClientKey::new(params);
let sks = ServerKey::new(&cks);
@@ -400,7 +315,7 @@ fn noise_check_encrypt_br_dp_ks_ms_noise(meta_params: MetaParameters) {
let noise_simulation_modulus_switch_config =
NoiseSimulationModulusSwitchConfig::new_from_atomic_pattern_parameters(params);
let noise_simulation_bsk =
NoiseSimulationLweFourierBsk::new_from_atomic_pattern_parameters(params);
NoiseSimulationGenericBootstrapKey::new_from_atomic_pattern_parameters(params);
let modulus_switch_config = sks.noise_simulation_modulus_switch_config();
let br_input_modulus_log = sks.br_input_modulus_log();
@@ -502,11 +417,14 @@ create_parameterized_test!(noise_check_encrypt_br_dp_ks_ms_noise {
TEST_META_PARAM_CPU_2_2_KS_PBS_GAUSSIAN_2M128,
TEST_META_PARAM_CPU_2_2_KS_PBS_PKE_TO_SMALL_ZKV2_TUNIFORM_2M128,
TEST_META_PARAM_CPU_2_2_KS32_PBS_PKE_TO_SMALL_ZKV2_TUNIFORM_2M128,
TEST_META_PARAM_GPU_2_2_MULTI_BIT_GROUP_4_KS_PBS_PKE_TO_SMALL_ZKV2_TUNIFORM_2M128,
});
fn noise_check_encrypt_br_dp_ks_ms_pfail(meta_params: MetaParameters) {
let (pfail_test_meta, params) = {
let mut ap_params = meta_params.compute_parameters;
let mut ap_params = meta_params
.compute_parameters
.with_deterministic_execution();
let original_message_modulus = ap_params.message_modulus();
let original_carry_modulus = ap_params.carry_modulus();
@@ -568,4 +486,5 @@ create_parameterized_test!(noise_check_encrypt_br_dp_ks_ms_pfail {
TEST_META_PARAM_CPU_2_2_KS_PBS_GAUSSIAN_2M128,
TEST_META_PARAM_CPU_2_2_KS_PBS_PKE_TO_SMALL_ZKV2_TUNIFORM_2M128,
TEST_META_PARAM_CPU_2_2_KS32_PBS_PKE_TO_SMALL_ZKV2_TUNIFORM_2M128,
TEST_META_PARAM_GPU_2_2_MULTI_BIT_GROUP_4_KS_PBS_PKE_TO_SMALL_ZKV2_TUNIFORM_2M128,
});

View File

@@ -8,13 +8,13 @@ use super::utils::{
use super::{should_run_short_pfail_tests_debug, should_use_single_key_debug};
use crate::shortint::atomic_pattern::AtomicPattern;
use crate::shortint::ciphertext::{Ciphertext, Degree, NoiseLevel};
use crate::shortint::client_key::atomic_pattern::AtomicPatternClientKey;
use crate::shortint::client_key::ClientKey;
use crate::shortint::engine::ShortintEngine;
use crate::shortint::list_compression::{CompressionKey, CompressionPrivateKeys};
use crate::shortint::parameters::test_params::{
TEST_META_PARAM_CPU_2_2_KS32_PBS_PKE_TO_SMALL_ZKV2_TUNIFORM_2M128,
TEST_META_PARAM_CPU_2_2_KS_PBS_PKE_TO_SMALL_ZKV2_TUNIFORM_2M128,
TEST_META_PARAM_GPU_2_2_MULTI_BIT_GROUP_4_KS_PBS_PKE_TO_SMALL_ZKV2_TUNIFORM_2M128,
};
use crate::shortint::parameters::{
AtomicPatternParameters, CarryModulus, CiphertextModulusLog, CompressionParameters,
@@ -53,8 +53,7 @@ pub fn br_dp_packing_ks_ms<
)
where
Accumulator: AllocateLweBootstrapResult<Output = PBSResult, SideResources = Resources> + Sync,
PBSKey:
LweClassicFftBootstrap<InputCt, PBSResult, Accumulator, SideResources = Resources> + Sync,
PBSKey: LweGenericBootstrap<InputCt, PBSResult, Accumulator, SideResources = Resources> + Sync,
PBSResult: ScalarMul<DPScalar, Output = DPResult, SideResources = Resources> + Send,
PackingKsk: AllocateLwePackingKeyswitchResult<Output = PackingKsResult, SideResources = Resources>
+ for<'a> LwePackingKeyswitch<[&'a DPResult], PackingKsResult, SideResources = Resources>,
@@ -70,7 +69,7 @@ where
.zip(side_resources.par_iter_mut())
.map(|(input, side_resources)| {
let mut pbs_result = accumulator.allocate_lwe_bootstrap_result(side_resources);
bsk.lwe_classic_fft_pbs(&input, &mut pbs_result, accumulator, side_resources);
bsk.lwe_generic_bootstrap(&input, &mut pbs_result, accumulator, side_resources);
let after_dp = pbs_result.scalar_mul(scalar, side_resources);
(input, pbs_result, after_dp)
@@ -98,7 +97,9 @@ where
fn sanity_check_encrypt_br_dp_packing_ks_ms(meta_params: MetaParameters) {
let (params, comp_params) = (
meta_params.compute_parameters,
meta_params
.compute_parameters
.with_deterministic_execution(),
meta_params.compression_parameters.unwrap(),
);
let cks = ClientKey::new(params);
@@ -161,6 +162,7 @@ fn sanity_check_encrypt_br_dp_packing_ks_ms(meta_params: MetaParameters) {
create_parameterized_test!(sanity_check_encrypt_br_dp_packing_ks_ms {
TEST_META_PARAM_CPU_2_2_KS_PBS_PKE_TO_SMALL_ZKV2_TUNIFORM_2M128,
TEST_META_PARAM_CPU_2_2_KS32_PBS_PKE_TO_SMALL_ZKV2_TUNIFORM_2M128,
TEST_META_PARAM_GPU_2_2_MULTI_BIT_GROUP_4_KS_PBS_PKE_TO_SMALL_ZKV2_TUNIFORM_2M128,
});
#[allow(clippy::type_complexity)]
@@ -237,57 +239,41 @@ fn encrypt_br_dp_packing_ks_ms_inner_helper(
&mut side_resources,
);
let compute_large_lwe_secret_key = cks.encryption_key();
let compression_glwe_secret_key = &compression_private_key.post_packing_ks_key;
let compute_encoding = sks.encoding(PaddingBit::Yes);
let compression_encoding = ShortintEncoding {
carry_modulus: CarryModulus(1),
..compute_encoding
};
let compute_large_lwe_secret_key = cks.encryption_key();
let compute_large_lwe_secret_key_with_compression_encoding_dyn = DynLweSecretKeyView::U64 {
key: compute_large_lwe_secret_key,
encoding: compression_encoding,
};
let compression_glwe_secret_key = &compression_private_key.post_packing_ks_key;
(
before_packing
.into_iter()
.map(|(input, pbs_result, dp_result)| {
(
match &cks.atomic_pattern {
AtomicPatternClientKey::Standard(standard_atomic_pattern_client_key) => {
DecryptionAndNoiseResult::new_from_lwe(
input.as_ref_64(),
&standard_atomic_pattern_client_key.lwe_secret_key,
msg,
&compute_encoding,
)
}
AtomicPatternClientKey::KeySwitch32(ks32_atomic_pattern_client_key) => {
let ks32_params = ks32_atomic_pattern_client_key.parameters;
let compute_encoding_32 = ShortintEncoding {
ciphertext_modulus: ks32_params.post_keyswitch_ciphertext_modulus,
message_modulus: ks32_params.message_modulus,
carry_modulus: ks32_params.carry_modulus,
padding_bit: PaddingBit::Yes,
};
let compute_large_lwe_secret_key_dyn = cks.large_lwe_secret_key_as_dyn();
let compute_small_lwe_secret_key_dyn = cks.small_lwe_secret_key_as_dyn();
DecryptionAndNoiseResult::new_from_lwe(
input.as_ref_32(),
&ks32_atomic_pattern_client_key.lwe_secret_key,
msg.try_into().unwrap(),
&compute_encoding_32,
)
}
},
DecryptionAndNoiseResult::new_from_lwe(
pbs_result.as_ref_64(),
&compute_large_lwe_secret_key,
(
DecryptionAndNoiseResult::new_from_dyn_lwe(
&input,
&compute_small_lwe_secret_key_dyn,
msg,
&compute_encoding,
),
DecryptionAndNoiseResult::new_from_lwe(
dp_result.as_ref_64(),
&compute_large_lwe_secret_key,
DecryptionAndNoiseResult::new_from_dyn_lwe(
&pbs_result,
&compute_large_lwe_secret_key_dyn,
msg,
),
DecryptionAndNoiseResult::new_from_dyn_lwe(
&dp_result,
&compute_large_lwe_secret_key_with_compression_encoding_dyn,
msg,
&compression_encoding,
),
)
})
@@ -392,7 +378,9 @@ fn encrypt_br_dp_packing_ks_ms_pfail_helper(
fn noise_check_encrypt_br_dp_packing_ks_ms_noise(meta_params: MetaParameters) {
let (params, comp_params) = (
meta_params.compute_parameters,
meta_params
.compute_parameters
.with_deterministic_execution(),
meta_params.compression_parameters.unwrap(),
);
let cks = ClientKey::new(params);
@@ -401,7 +389,7 @@ fn noise_check_encrypt_br_dp_packing_ks_ms_noise(meta_params: MetaParameters) {
let compression_key = cks.new_compression_key(&compression_private_key);
let noise_simulation_bsk =
NoiseSimulationLweFourierBsk::new_from_atomic_pattern_parameters(params);
NoiseSimulationGenericBootstrapKey::new_from_atomic_pattern_parameters(params);
let noise_simulation_packing_key =
NoiseSimulationLwePackingKeyswitchKey::new_from_comp_parameters(params, comp_params);
@@ -520,12 +508,15 @@ fn noise_check_encrypt_br_dp_packing_ks_ms_noise(meta_params: MetaParameters) {
create_parameterized_test!(noise_check_encrypt_br_dp_packing_ks_ms_noise {
TEST_META_PARAM_CPU_2_2_KS_PBS_PKE_TO_SMALL_ZKV2_TUNIFORM_2M128,
TEST_META_PARAM_CPU_2_2_KS32_PBS_PKE_TO_SMALL_ZKV2_TUNIFORM_2M128,
TEST_META_PARAM_GPU_2_2_MULTI_BIT_GROUP_4_KS_PBS_PKE_TO_SMALL_ZKV2_TUNIFORM_2M128,
});
fn noise_check_encrypt_br_dp_packing_ks_ms_pfail(meta_params: MetaParameters) {
let (pfail_test_meta, params, comp_params) = {
let (mut params, comp_params) = (
meta_params.compute_parameters,
meta_params
.compute_parameters
.with_deterministic_execution(),
meta_params.compression_parameters.unwrap(),
);
@@ -537,7 +528,7 @@ fn noise_check_encrypt_br_dp_packing_ks_ms_pfail(meta_params: MetaParameters) {
assert_eq!(original_carry_modulus.0, 4);
let noise_simulation_bsk =
NoiseSimulationLweFourierBsk::new_from_atomic_pattern_parameters(params);
NoiseSimulationGenericBootstrapKey::new_from_atomic_pattern_parameters(params);
let noise_simulation_packing_key =
NoiseSimulationLwePackingKeyswitchKey::new_from_comp_parameters(params, comp_params);
@@ -685,4 +676,5 @@ fn noise_check_encrypt_br_dp_packing_ks_ms_pfail(meta_params: MetaParameters) {
create_parameterized_test!(noise_check_encrypt_br_dp_packing_ks_ms_pfail {
TEST_META_PARAM_CPU_2_2_KS_PBS_PKE_TO_SMALL_ZKV2_TUNIFORM_2M128,
TEST_META_PARAM_CPU_2_2_KS32_PBS_PKE_TO_SMALL_ZKV2_TUNIFORM_2M128,
TEST_META_PARAM_GPU_2_2_MULTI_BIT_GROUP_4_KS_PBS_PKE_TO_SMALL_ZKV2_TUNIFORM_2M128,
});

View File

@@ -1,7 +1,7 @@
use super::dp_ks_ms::any_ms;
use super::utils::noise_simulation::{
DynLwe, NoiseSimulationGlwe, NoiseSimulationLwe, NoiseSimulationLweFourierBsk,
NoiseSimulationLweKeyswitchKey, NoiseSimulationModulusSwitchConfig,
DynLwe, DynLweSecretKeyView, NoiseSimulationGenericBootstrapKey, NoiseSimulationGlwe,
NoiseSimulationLwe, NoiseSimulationLweKeyswitchKey, NoiseSimulationModulusSwitchConfig,
};
use super::utils::traits::*;
use super::utils::{
@@ -22,7 +22,6 @@ use crate::shortint::atomic_pattern::AtomicPattern;
use crate::shortint::ciphertext::{
CompressedCiphertextList, CompressedCiphertextListMeta, ReRandomizationSeed,
};
use crate::shortint::client_key::atomic_pattern::AtomicPatternClientKey;
use crate::shortint::client_key::ClientKey;
use crate::shortint::encoding::ShortintEncoding;
use crate::shortint::engine::ShortintEngine;
@@ -31,6 +30,7 @@ use crate::shortint::list_compression::{CompressionPrivateKeys, DecompressionKey
use crate::shortint::parameters::test_params::{
TEST_META_PARAM_CPU_2_2_KS32_PBS_PKE_TO_SMALL_ZKV2_TUNIFORM_2M128,
TEST_META_PARAM_CPU_2_2_KS_PBS_PKE_TO_SMALL_ZKV2_TUNIFORM_2M128,
TEST_META_PARAM_GPU_2_2_MULTI_BIT_GROUP_4_KS_PBS_PKE_TO_SMALL_ZKV2_TUNIFORM_2M128,
};
use crate::shortint::parameters::{
AtomicPatternParameters, CarryModulus, CompactCiphertextListExpansionKind,
@@ -85,8 +85,7 @@ pub fn br_rerand_dp_ks_any_ms<
)
where
Accumulator: AllocateLweBootstrapResult<Output = PBSResult, SideResources = Resources>,
DecompPBSKey:
LweClassicFftBootstrap<InputCt, PBSResult, Accumulator, SideResources = Resources>,
DecompPBSKey: LweGenericBootstrap<InputCt, PBSResult, Accumulator, SideResources = Resources>,
KsKeyRerand: AllocateLweKeyswitchResult<Output = KsedZeroReRand, SideResources = Resources>
+ LweKeyswitch<InputZeroRerand, KsedZeroReRand, SideResources = Resources>,
PBSResult: for<'a> LweUncorrelatedAdd<
@@ -102,7 +101,9 @@ where
+ AllocateCenteredBinaryShiftedStandardModSwitchResult<
Output = MsResult,
SideResources = Resources,
> + CenteredBinaryShiftedStandardModSwitch<MsResult, SideResources = Resources>,
> + CenteredBinaryShiftedStandardModSwitch<MsResult, SideResources = Resources>
+ AllocateMultiBitModSwitchResult<Output = MsResult, SideResources = Resources>
+ MultiBitModSwitch<MsResult, SideResources = Resources>,
DriftKey: AllocateDriftTechniqueStandardModSwitchResult<
AfterDriftOutput = DriftTechniqueResult,
AfterMsOutput = MsResult,
@@ -116,7 +117,7 @@ where
{
// BR to decomp
let mut br_result = decomp_accumulator.allocate_lwe_bootstrap_result(side_resources);
decomp_bsk.lwe_classic_fft_pbs(&input, &mut br_result, decomp_accumulator, side_resources);
decomp_bsk.lwe_generic_bootstrap(&input, &mut br_result, decomp_accumulator, side_resources);
// Ks the CPK encryption of 0 to be added to BR result
let mut ksed_zero_rerand = ksk_rerand.allocate_lwe_keyswitch_result(side_resources);
@@ -278,185 +279,76 @@ fn encrypt_decomp_br_rerand_dp_ks_any_ms_inner_helper(
let before_ms = after_drift.as_ref().unwrap_or(&after_ks);
match &cks.atomic_pattern {
AtomicPatternClientKey::Standard(standard_atomic_pattern_client_key) => {
let params = standard_atomic_pattern_client_key.parameters;
let comp_encoding = ShortintEncoding {
ciphertext_modulus: params.ciphertext_modulus(),
message_modulus: params.message_modulus(),
// Adapt to the compression which has no carry bits
carry_modulus: CarryModulus(1),
padding_bit: PaddingBit::Yes,
};
let compute_encoding = ShortintEncoding {
ciphertext_modulus: params.ciphertext_modulus(),
message_modulus: params.message_modulus(),
carry_modulus: params.carry_modulus(),
padding_bit: PaddingBit::Yes,
};
let params = cks.parameters();
let compute_encoding = ShortintEncoding {
ciphertext_modulus: params.ciphertext_modulus(),
message_modulus: params.message_modulus(),
carry_modulus: params.carry_modulus(),
padding_bit: PaddingBit::Yes,
};
let comp_encoding = ShortintEncoding {
// Adapt to the compression which has no carry bits
carry_modulus: CarryModulus(1),
..compute_encoding
};
let cpk_lwe_secret_key = cpk_private_key.key();
let comp_lwe_secret_key = comp_private_key.post_packing_ks_key.as_lwe_secret_key();
let cpk_lwe_secret_key_dyn = cpk_private_key.lwe_secret_key_as_dyn();
let comp_lwe_secret_key = comp_private_key.post_packing_ks_key.as_lwe_secret_key();
let comp_lwe_secret_key_dyn = DynLweSecretKeyView::U64 {
key: comp_lwe_secret_key,
encoding: comp_encoding,
};
let large_compute_lwe_secret_key =
standard_atomic_pattern_client_key.large_lwe_secret_key();
let small_compute_lwe_secret_key =
standard_atomic_pattern_client_key.small_lwe_secret_key();
(
(
DecryptionAndNoiseResult::new_from_lwe(
&input.as_lwe_64(),
&comp_lwe_secret_key,
msg,
&comp_encoding,
),
DecryptionAndNoiseResult::new_from_lwe(
&after_br.as_lwe_64(),
&large_compute_lwe_secret_key,
msg,
&compute_encoding,
),
),
(
DecryptionAndNoiseResult::new_from_lwe(
&input_zero_rerand.as_lwe_64(),
&cpk_lwe_secret_key,
msg,
&compute_encoding,
),
DecryptionAndNoiseResult::new_from_lwe(
&after_ksed_zero_rerand.as_lwe_64(),
&large_compute_lwe_secret_key,
msg,
&compute_encoding,
),
),
DecryptionAndNoiseResult::new_from_lwe(
&after_rerand.as_lwe_64(),
&large_compute_lwe_secret_key,
msg,
&compute_encoding,
),
DecryptionAndNoiseResult::new_from_lwe(
&after_dp.as_lwe_64(),
&large_compute_lwe_secret_key,
msg,
&compute_encoding,
),
DecryptionAndNoiseResult::new_from_lwe(
&after_ks.as_lwe_64(),
&small_compute_lwe_secret_key,
msg,
&compute_encoding,
),
DecryptionAndNoiseResult::new_from_lwe(
&before_ms.as_lwe_64(),
&small_compute_lwe_secret_key,
msg,
&compute_encoding,
),
DecryptionAndNoiseResult::new_from_lwe(
&after_ms.as_lwe_64(),
&small_compute_lwe_secret_key,
msg,
&compute_encoding,
),
)
}
AtomicPatternClientKey::KeySwitch32(ks32_atomic_pattern_client_key) => {
let params = ks32_atomic_pattern_client_key.parameters;
let comp_encoding = ShortintEncoding {
ciphertext_modulus: params.ciphertext_modulus(),
message_modulus: params.message_modulus(),
// Adapt to the compression which has no carry bits
carry_modulus: CarryModulus(1),
padding_bit: PaddingBit::Yes,
};
let compute_encoding_u32 = ShortintEncoding {
ciphertext_modulus: params.post_keyswitch_ciphertext_modulus(),
message_modulus: params.message_modulus(),
carry_modulus: params.carry_modulus(),
padding_bit: PaddingBit::Yes,
};
let compute_encoding_u64 = ShortintEncoding {
ciphertext_modulus: params.ciphertext_modulus(),
message_modulus: params.message_modulus(),
carry_modulus: params.carry_modulus(),
padding_bit: PaddingBit::Yes,
};
let large_compute_lwe_secret_key_dyn = cks.large_lwe_secret_key_as_dyn();
let small_compute_lwe_secret_key_dyn = cks.small_lwe_secret_key_as_dyn();
let cpk_lwe_secret_key = cpk_private_key.key();
let comp_lwe_secret_key = comp_private_key.post_packing_ks_key.as_lwe_secret_key();
let large_compute_lwe_secret_key =
ks32_atomic_pattern_client_key.large_lwe_secret_key();
let small_compute_lwe_secret_key =
ks32_atomic_pattern_client_key.small_lwe_secret_key();
let msg_u32: u32 = msg.try_into().unwrap();
(
(
DecryptionAndNoiseResult::new_from_lwe(
&input.as_lwe_64(),
&comp_lwe_secret_key,
msg,
&comp_encoding,
),
DecryptionAndNoiseResult::new_from_lwe(
&after_br.as_lwe_64(),
&large_compute_lwe_secret_key,
msg,
&compute_encoding_u64,
),
),
(
DecryptionAndNoiseResult::new_from_lwe(
&input_zero_rerand.as_lwe_64(),
&cpk_lwe_secret_key,
msg,
&compute_encoding_u64,
),
DecryptionAndNoiseResult::new_from_lwe(
&after_ksed_zero_rerand.as_lwe_64(),
&large_compute_lwe_secret_key,
msg,
&compute_encoding_u64,
),
),
DecryptionAndNoiseResult::new_from_lwe(
&after_rerand.as_lwe_64(),
&large_compute_lwe_secret_key,
msg,
&compute_encoding_u64,
),
DecryptionAndNoiseResult::new_from_lwe(
&after_dp.as_lwe_64(),
&large_compute_lwe_secret_key,
msg,
&compute_encoding_u64,
),
DecryptionAndNoiseResult::new_from_lwe(
&after_ks.as_lwe_32(),
&small_compute_lwe_secret_key,
msg_u32,
&compute_encoding_u32,
),
DecryptionAndNoiseResult::new_from_lwe(
&before_ms.as_lwe_32(),
&small_compute_lwe_secret_key,
msg_u32,
&compute_encoding_u32,
),
DecryptionAndNoiseResult::new_from_lwe(
&after_ms.as_lwe_32(),
&small_compute_lwe_secret_key,
msg_u32,
&compute_encoding_u32,
),
)
}
}
(
(
DecryptionAndNoiseResult::new_from_dyn_lwe(&input, &comp_lwe_secret_key_dyn, msg),
DecryptionAndNoiseResult::new_from_dyn_lwe(
&after_br,
&large_compute_lwe_secret_key_dyn,
msg,
),
),
(
DecryptionAndNoiseResult::new_from_dyn_lwe(
&input_zero_rerand,
&cpk_lwe_secret_key_dyn,
msg,
),
DecryptionAndNoiseResult::new_from_dyn_lwe(
&after_ksed_zero_rerand,
&large_compute_lwe_secret_key_dyn,
msg,
),
),
DecryptionAndNoiseResult::new_from_dyn_lwe(
&after_rerand,
&large_compute_lwe_secret_key_dyn,
msg,
),
DecryptionAndNoiseResult::new_from_dyn_lwe(
&after_dp,
&large_compute_lwe_secret_key_dyn,
msg,
),
DecryptionAndNoiseResult::new_from_dyn_lwe(
&after_ks,
&small_compute_lwe_secret_key_dyn,
msg,
),
DecryptionAndNoiseResult::new_from_dyn_lwe(
before_ms,
&small_compute_lwe_secret_key_dyn,
msg,
),
DecryptionAndNoiseResult::new_from_dyn_modswitched_lwe(
&after_ms,
&small_compute_lwe_secret_key_dyn,
msg,
),
)
}
#[allow(clippy::too_many_arguments)]
@@ -589,7 +481,9 @@ fn encrypt_br_rerand_dp_ks_any_ms_pfail_helper(
fn noise_check_encrypt_br_rerand_dp_ks_ms_noise(meta_params: MetaParameters) {
let (params, cpk_params, rerand_ksk_params, compression_params) = {
let compute_params = meta_params.compute_parameters;
let compute_params = meta_params
.compute_parameters
.with_deterministic_execution();
let dedicated_cpk_params = meta_params.dedicated_compact_public_key_parameters.unwrap();
// To avoid the expand logic of shortint which would force a keyswitch + LUT eval after
// expand
@@ -627,7 +521,7 @@ fn noise_check_encrypt_br_rerand_dp_ks_ms_noise(meta_params: MetaParameters) {
let noise_simulation_modulus_switch_config =
NoiseSimulationModulusSwitchConfig::new_from_atomic_pattern_parameters(params);
let noise_simulation_decomp_bsk =
NoiseSimulationLweFourierBsk::new_from_comp_parameters(params, compression_params);
NoiseSimulationGenericBootstrapKey::new_from_comp_parameters(params, compression_params);
let modulus_switch_config = sks.noise_simulation_modulus_switch_config();
let compute_br_input_modulus_log = sks.br_input_modulus_log();
@@ -791,11 +685,14 @@ fn noise_check_encrypt_br_rerand_dp_ks_ms_noise(meta_params: MetaParameters) {
create_parameterized_test!(noise_check_encrypt_br_rerand_dp_ks_ms_noise {
TEST_META_PARAM_CPU_2_2_KS_PBS_PKE_TO_SMALL_ZKV2_TUNIFORM_2M128,
TEST_META_PARAM_CPU_2_2_KS32_PBS_PKE_TO_SMALL_ZKV2_TUNIFORM_2M128,
TEST_META_PARAM_GPU_2_2_MULTI_BIT_GROUP_4_KS_PBS_PKE_TO_SMALL_ZKV2_TUNIFORM_2M128,
});
fn noise_check_encrypt_br_rerand_dp_ks_ms_pfail(meta_params: MetaParameters) {
let (params, cpk_params, rerand_ksk_params, compression_params) = {
let compute_params = meta_params.compute_parameters;
let compute_params = meta_params
.compute_parameters
.with_deterministic_execution();
let dedicated_cpk_params = meta_params.dedicated_compact_public_key_parameters.unwrap();
// To avoid the expand logic of shortint which would force a keyswitch + LUT eval after
// expand
@@ -898,11 +795,14 @@ fn noise_check_encrypt_br_rerand_dp_ks_ms_pfail(meta_params: MetaParameters) {
create_parameterized_test!(noise_check_encrypt_br_rerand_dp_ks_ms_pfail {
TEST_META_PARAM_CPU_2_2_KS_PBS_PKE_TO_SMALL_ZKV2_TUNIFORM_2M128,
TEST_META_PARAM_CPU_2_2_KS32_PBS_PKE_TO_SMALL_ZKV2_TUNIFORM_2M128,
TEST_META_PARAM_GPU_2_2_MULTI_BIT_GROUP_4_KS_PBS_PKE_TO_SMALL_ZKV2_TUNIFORM_2M128,
});
fn sanity_check_encrypt_br_rerand_dp_ks_ms_pbs(meta_params: MetaParameters) {
let (params, cpk_params, rerand_ksk_params, compression_params) = {
let compute_params = meta_params.compute_parameters;
let compute_params = meta_params
.compute_parameters
.with_deterministic_execution();
let dedicated_cpk_params = meta_params.dedicated_compact_public_key_parameters.unwrap();
// To avoid the expand logic of shortint which would force a keyswitch + LUT eval after
// expand
@@ -1050,7 +950,7 @@ fn sanity_check_encrypt_br_rerand_dp_ks_ms_pbs(meta_params: MetaParameters) {
// Complete the AP by computing the PBS to match shortint
let mut pbs_result = id_lut.allocate_lwe_bootstrap_result(&mut ());
sks.lwe_classic_fft_pbs(&after_ms, &mut pbs_result, &id_lut, &mut ());
sks.apply_generic_blind_rotation(&after_ms, &mut pbs_result, &id_lut);
assert_eq!(pbs_result.as_lwe_64(), shortint_res.ct.as_view());
}
@@ -1059,4 +959,5 @@ fn sanity_check_encrypt_br_rerand_dp_ks_ms_pbs(meta_params: MetaParameters) {
create_parameterized_test!(sanity_check_encrypt_br_rerand_dp_ks_ms_pbs {
TEST_META_PARAM_CPU_2_2_KS_PBS_PKE_TO_SMALL_ZKV2_TUNIFORM_2M128,
TEST_META_PARAM_CPU_2_2_KS32_PBS_PKE_TO_SMALL_ZKV2_TUNIFORM_2M128,
TEST_META_PARAM_GPU_2_2_MULTI_BIT_GROUP_4_KS_PBS_PKE_TO_SMALL_ZKV2_TUNIFORM_2M128,
});

View File

@@ -9,14 +9,13 @@ use super::utils::{
};
use super::{should_run_short_pfail_tests_debug, should_use_single_key_debug};
use crate::core_crypto::commons::parameters::CiphertextModulusLog;
use crate::shortint::client_key::atomic_pattern::AtomicPatternClientKey;
use crate::shortint::client_key::ClientKey;
use crate::shortint::encoding::ShortintEncoding;
use crate::shortint::engine::ShortintEngine;
use crate::shortint::key_switching_key::{KeySwitchingKeyBuildHelper, KeySwitchingKeyView};
use crate::shortint::parameters::test_params::{
TEST_META_PARAM_CPU_2_2_KS32_PBS_PKE_TO_SMALL_ZKV2_TUNIFORM_2M128,
TEST_META_PARAM_CPU_2_2_KS_PBS_PKE_TO_SMALL_ZKV2_TUNIFORM_2M128,
TEST_META_PARAM_GPU_2_2_MULTI_BIT_GROUP_4_KS_PBS_PKE_TO_SMALL_ZKV2_TUNIFORM_2M128,
};
use crate::shortint::parameters::{
AtomicPatternParameters, CarryModulus, CompactCiphertextListExpansionKind,
@@ -26,7 +25,6 @@ use crate::shortint::parameters::{
use crate::shortint::public_key::compact::{CompactPrivateKey, CompactPublicKey};
use crate::shortint::server_key::tests::parameterized_test::create_parameterized_test;
use crate::shortint::server_key::ServerKey;
use crate::shortint::PaddingBit;
use rayon::prelude::*;
#[allow(clippy::too_many_arguments)]
@@ -54,7 +52,9 @@ where
+ AllocateCenteredBinaryShiftedStandardModSwitchResult<
Output = MsResult,
SideResources = Resources,
> + CenteredBinaryShiftedStandardModSwitch<MsResult, SideResources = Resources>,
> + CenteredBinaryShiftedStandardModSwitch<MsResult, SideResources = Resources>
+ AllocateMultiBitModSwitchResult<Output = MsResult, SideResources = Resources>
+ MultiBitModSwitch<MsResult, SideResources = Resources>,
DriftKey: AllocateDriftTechniqueStandardModSwitchResult<
AfterDriftOutput = DriftTechniqueResult,
AfterMsOutput = MsResult,
@@ -163,97 +163,19 @@ fn cpk_ks_any_ms_inner_helper(
let before_ms = after_drift.as_ref().unwrap_or(&after_ks_ds);
match &cks.atomic_pattern {
AtomicPatternClientKey::Standard(standard_atomic_pattern_client_key) => {
let params = standard_atomic_pattern_client_key.parameters;
let encoding = ShortintEncoding {
ciphertext_modulus: params.ciphertext_modulus(),
message_modulus: params.message_modulus(),
carry_modulus: params.carry_modulus(),
padding_bit: PaddingBit::Yes,
};
let cpk_lwe_secret_key_dyn = cpk_private_key.lwe_secret_key_as_dyn();
let small_lwe_secret_key_dyn = cks.small_lwe_secret_key_as_dyn();
let cpk_lwe_secret_key = cpk_private_key.key();
let small_compute_lwe_secret_key =
standard_atomic_pattern_client_key.small_lwe_secret_key();
(
DecryptionAndNoiseResult::new_from_lwe(
&input.as_lwe_64(),
&cpk_lwe_secret_key,
msg,
&encoding,
),
DecryptionAndNoiseResult::new_from_lwe(
&after_ks_ds.as_lwe_64(),
&small_compute_lwe_secret_key,
msg,
&encoding,
),
DecryptionAndNoiseResult::new_from_lwe(
&before_ms.as_lwe_64(),
&small_compute_lwe_secret_key,
msg,
&encoding,
),
DecryptionAndNoiseResult::new_from_lwe(
&after_ms.as_lwe_64(),
&small_compute_lwe_secret_key,
msg,
&encoding,
),
)
}
AtomicPatternClientKey::KeySwitch32(ks32_atomic_pattern_client_key) => {
let params = ks32_atomic_pattern_client_key.parameters;
let compute_encoding_u32 = ShortintEncoding {
ciphertext_modulus: params.post_keyswitch_ciphertext_modulus(),
message_modulus: params.message_modulus(),
carry_modulus: params.carry_modulus(),
padding_bit: PaddingBit::Yes,
};
let compute_encoding_u64 = ShortintEncoding {
ciphertext_modulus: params.ciphertext_modulus(),
message_modulus: params.message_modulus(),
carry_modulus: params.carry_modulus(),
padding_bit: PaddingBit::Yes,
};
let cpk_lwe_secret_key = cpk_private_key.key();
let small_compute_lwe_secret_key =
ks32_atomic_pattern_client_key.small_lwe_secret_key();
let msg_u32: u32 = msg.try_into().unwrap();
(
DecryptionAndNoiseResult::new_from_lwe(
&input.as_lwe_64(),
&cpk_lwe_secret_key,
msg,
&compute_encoding_u64,
),
DecryptionAndNoiseResult::new_from_lwe(
&after_ks_ds.as_lwe_32(),
&small_compute_lwe_secret_key,
msg_u32,
&compute_encoding_u32,
),
DecryptionAndNoiseResult::new_from_lwe(
&before_ms.as_lwe_32(),
&small_compute_lwe_secret_key,
msg_u32,
&compute_encoding_u32,
),
DecryptionAndNoiseResult::new_from_lwe(
&after_ms.as_lwe_32(),
&small_compute_lwe_secret_key,
msg_u32,
&compute_encoding_u32,
),
)
}
}
(
DecryptionAndNoiseResult::new_from_dyn_lwe(&input, &cpk_lwe_secret_key_dyn, msg),
DecryptionAndNoiseResult::new_from_dyn_lwe(&after_ks_ds, &small_lwe_secret_key_dyn, msg),
DecryptionAndNoiseResult::new_from_dyn_lwe(before_ms, &small_lwe_secret_key_dyn, msg),
DecryptionAndNoiseResult::new_from_dyn_modswitched_lwe(
&after_ms,
&small_lwe_secret_key_dyn,
msg,
),
)
}
#[allow(clippy::too_many_arguments)]
@@ -327,7 +249,9 @@ fn cpk_ks_any_ms_pfail_helper(
fn noise_check_encrypt_cpk_ks_ms_noise(meta_params: MetaParameters) {
let (params, cpk_params, ksk_ds_params) = {
let compute_params = meta_params.compute_parameters;
let compute_params = meta_params
.compute_parameters
.with_deterministic_execution();
let dedicated_cpk_params = meta_params.dedicated_compact_public_key_parameters.unwrap();
// To avoid the expand logic of shortint which would force a keyswitch + LUT eval after
// expand
@@ -454,11 +378,14 @@ fn noise_check_encrypt_cpk_ks_ms_noise(meta_params: MetaParameters) {
create_parameterized_test!(noise_check_encrypt_cpk_ks_ms_noise {
TEST_META_PARAM_CPU_2_2_KS_PBS_PKE_TO_SMALL_ZKV2_TUNIFORM_2M128,
TEST_META_PARAM_CPU_2_2_KS32_PBS_PKE_TO_SMALL_ZKV2_TUNIFORM_2M128,
TEST_META_PARAM_GPU_2_2_MULTI_BIT_GROUP_4_KS_PBS_PKE_TO_SMALL_ZKV2_TUNIFORM_2M128,
});
fn noise_check_encrypt_cpk_ks_ms_pfail(meta_params: MetaParameters) {
let (params, cpk_params, ksk_ds_params) = {
let compute_params = meta_params.compute_parameters;
let compute_params = meta_params
.compute_parameters
.with_deterministic_execution();
let dedicated_cpk_params = meta_params.dedicated_compact_public_key_parameters.unwrap();
// To avoid the expand logic of shortint which would force a keyswitch + LUT eval after
// expand
@@ -548,11 +475,14 @@ fn noise_check_encrypt_cpk_ks_ms_pfail(meta_params: MetaParameters) {
create_parameterized_test!(noise_check_encrypt_cpk_ks_ms_pfail {
TEST_META_PARAM_CPU_2_2_KS_PBS_PKE_TO_SMALL_ZKV2_TUNIFORM_2M128,
TEST_META_PARAM_CPU_2_2_KS32_PBS_PKE_TO_SMALL_ZKV2_TUNIFORM_2M128,
TEST_META_PARAM_GPU_2_2_MULTI_BIT_GROUP_4_KS_PBS_PKE_TO_SMALL_ZKV2_TUNIFORM_2M128,
});
fn sanity_check_encrypt_cpk_ks_ms_pbs(meta_params: MetaParameters) {
let (params, cpk_params, ksk_ds_params, orig_cast_mode) = {
let compute_params = meta_params.compute_parameters;
let compute_params = meta_params
.compute_parameters
.with_deterministic_execution();
let dedicated_cpk_params = meta_params.dedicated_compact_public_key_parameters.unwrap();
// To avoid the expand logic of shortint which would force a keyswitch + LUT eval after
// expand
@@ -635,7 +565,7 @@ fn sanity_check_encrypt_cpk_ks_ms_pbs(meta_params: MetaParameters) {
// Complete the AP by computing the PBS to match shortint
let mut pbs_result = id_lut.allocate_lwe_bootstrap_result(&mut ());
sks.lwe_classic_fft_pbs(&after_ms, &mut pbs_result, &id_lut, &mut ());
sks.apply_generic_blind_rotation(&after_ms, &mut pbs_result, &id_lut);
assert_eq!(pbs_result.as_lwe_64(), shortint_res.ct.as_view());
}
@@ -644,4 +574,5 @@ fn sanity_check_encrypt_cpk_ks_ms_pbs(meta_params: MetaParameters) {
create_parameterized_test!(sanity_check_encrypt_cpk_ks_ms_pbs {
TEST_META_PARAM_CPU_2_2_KS_PBS_PKE_TO_SMALL_ZKV2_TUNIFORM_2M128,
TEST_META_PARAM_CPU_2_2_KS32_PBS_PKE_TO_SMALL_ZKV2_TUNIFORM_2M128,
TEST_META_PARAM_GPU_2_2_MULTI_BIT_GROUP_4_KS_PBS_PKE_TO_SMALL_ZKV2_TUNIFORM_2M128,
});

View File

@@ -6,14 +6,13 @@ use super::utils::{
};
use super::{should_run_short_pfail_tests_debug, should_use_single_key_debug};
use crate::core_crypto::commons::parameters::CiphertextModulusLog;
use crate::shortint::client_key::atomic_pattern::AtomicPatternClientKey;
use crate::shortint::client_key::ClientKey;
use crate::shortint::encoding::{PaddingBit, ShortintEncoding};
use crate::shortint::engine::ShortintEngine;
use crate::shortint::parameters::test_params::{
TEST_META_PARAM_CPU_2_2_KS32_PBS_PKE_TO_SMALL_ZKV2_TUNIFORM_2M128,
TEST_META_PARAM_CPU_2_2_KS_PBS_GAUSSIAN_2M128,
TEST_META_PARAM_CPU_2_2_KS_PBS_PKE_TO_SMALL_ZKV2_TUNIFORM_2M128,
TEST_META_PARAM_GPU_2_2_MULTI_BIT_GROUP_4_KS_PBS_PKE_TO_SMALL_ZKV2_TUNIFORM_2M128,
};
use crate::shortint::parameters::{AtomicPatternParameters, CarryModulus, MetaParameters};
use crate::shortint::server_key::tests::parameterized_test::create_parameterized_test;
@@ -32,7 +31,9 @@ where
+ AllocateCenteredBinaryShiftedStandardModSwitchResult<
Output = MsResult,
SideResources = Resources,
> + CenteredBinaryShiftedStandardModSwitch<MsResult, SideResources = Resources>,
> + CenteredBinaryShiftedStandardModSwitch<MsResult, SideResources = Resources>
+ AllocateMultiBitModSwitchResult<Output = MsResult, SideResources = Resources>
+ MultiBitModSwitch<MsResult, SideResources = Resources>,
// We need to be able to allocate the result and apply drift technique + mod switch it
DriftKey: AllocateDriftTechniqueStandardModSwitchResult<
AfterDriftOutput = DriftTechniqueResult,
@@ -76,6 +77,17 @@ where
side_resources,
);
(None, ms_result)
}
NoiseSimulationModulusSwitchConfig::MultiBit(grouping_factor) => {
let mut ms_result = input.allocate_multi_bit_mod_switch_result(side_resources);
input.multi_bit_mod_switch(
grouping_factor,
br_input_modulus_log,
&mut ms_result,
side_resources,
);
(None, ms_result)
}
}
@@ -116,7 +128,9 @@ where
+ AllocateCenteredBinaryShiftedStandardModSwitchResult<
Output = MsResult,
SideResources = Resources,
> + CenteredBinaryShiftedStandardModSwitch<MsResult, SideResources = Resources>,
> + CenteredBinaryShiftedStandardModSwitch<MsResult, SideResources = Resources>
+ AllocateMultiBitModSwitchResult<Output = MsResult, SideResources = Resources>
+ MultiBitModSwitch<MsResult, SideResources = Resources>,
// We need to be able to allocate the result and apply drift technique + mod switch it
DriftKey: AllocateDriftTechniqueStandardModSwitchResult<
AfterDriftOutput = DriftTechniqueResult,
@@ -152,7 +166,9 @@ where
/// Test function to verify that the noise checking tools match the actual atomic patterns
/// implemented in shortint
fn sanity_check_encrypt_dp_ks_pbs(meta_params: MetaParameters) {
let params = meta_params.compute_parameters;
let params = meta_params
.compute_parameters
.with_deterministic_execution();
let cks = ClientKey::new(params);
let sks = ServerKey::new(&cks);
@@ -178,7 +194,7 @@ fn sanity_check_encrypt_dp_ks_pbs(meta_params: MetaParameters) {
// Complete the AP by computing the PBS to match shortint
let mut pbs_result = id_lut.allocate_lwe_bootstrap_result(&mut ());
sks.lwe_classic_fft_pbs(&after_ms, &mut pbs_result, &id_lut, &mut ());
sks.apply_generic_blind_rotation(&after_ms, &mut pbs_result, &id_lut);
let mut shortint_res =
sks.unchecked_scalar_mul(&input_zero, max_scalar_mul.try_into().unwrap());
@@ -192,6 +208,7 @@ create_parameterized_test!(sanity_check_encrypt_dp_ks_pbs {
TEST_META_PARAM_CPU_2_2_KS_PBS_GAUSSIAN_2M128,
TEST_META_PARAM_CPU_2_2_KS_PBS_PKE_TO_SMALL_ZKV2_TUNIFORM_2M128,
TEST_META_PARAM_CPU_2_2_KS32_PBS_PKE_TO_SMALL_ZKV2_TUNIFORM_2M128,
TEST_META_PARAM_GPU_2_2_MULTI_BIT_GROUP_4_KS_PBS_PKE_TO_SMALL_ZKV2_TUNIFORM_2M128,
});
fn encrypt_dp_ks_any_ms_inner_helper(
@@ -235,100 +252,20 @@ fn encrypt_dp_ks_any_ms_inner_helper(
let before_ms = after_drift.as_ref().unwrap_or(&after_ks);
match &cks.atomic_pattern {
AtomicPatternClientKey::Standard(standard_atomic_pattern_client_key) => {
let output_encoding = ShortintEncoding {
ciphertext_modulus: params.ciphertext_modulus(),
message_modulus: params.message_modulus(),
carry_modulus: params.carry_modulus(),
padding_bit: PaddingBit::Yes,
};
let large_lwe_secret_key = standard_atomic_pattern_client_key.large_lwe_secret_key();
let small_lwe_secret_key = standard_atomic_pattern_client_key.small_lwe_secret_key();
(
DecryptionAndNoiseResult::new_from_lwe(
&input.as_lwe_64(),
&large_lwe_secret_key,
msg,
&output_encoding,
),
DecryptionAndNoiseResult::new_from_lwe(
&after_dp.as_lwe_64(),
&large_lwe_secret_key,
msg,
&output_encoding,
),
DecryptionAndNoiseResult::new_from_lwe(
&after_ks.as_lwe_64(),
&small_lwe_secret_key,
msg,
&output_encoding,
),
DecryptionAndNoiseResult::new_from_lwe(
&before_ms.as_lwe_64(),
&small_lwe_secret_key,
msg,
&output_encoding,
),
DecryptionAndNoiseResult::new_from_lwe(
&after_ms.as_lwe_64(),
&small_lwe_secret_key,
msg,
&output_encoding,
),
)
}
AtomicPatternClientKey::KeySwitch32(ks32_atomic_pattern_client_key) => {
let msg_u32: u32 = msg.try_into().unwrap();
let params = ks32_atomic_pattern_client_key.parameters;
let small_key_encoding = ShortintEncoding {
ciphertext_modulus: params.post_keyswitch_ciphertext_modulus(),
message_modulus: params.message_modulus(),
carry_modulus: params.carry_modulus(),
padding_bit: PaddingBit::Yes,
};
let big_key_encoding = ShortintEncoding {
ciphertext_modulus: params.ciphertext_modulus(),
message_modulus: params.message_modulus(),
carry_modulus: params.carry_modulus(),
padding_bit: PaddingBit::Yes,
};
let large_lwe_secret_key = ks32_atomic_pattern_client_key.large_lwe_secret_key();
let small_lwe_secret_key = ks32_atomic_pattern_client_key.small_lwe_secret_key();
(
DecryptionAndNoiseResult::new_from_lwe(
&input.as_lwe_64(),
&large_lwe_secret_key,
msg,
&big_key_encoding,
),
DecryptionAndNoiseResult::new_from_lwe(
&after_dp.as_lwe_64(),
&large_lwe_secret_key,
msg,
&big_key_encoding,
),
DecryptionAndNoiseResult::new_from_lwe(
&after_ks.as_lwe_32(),
&small_lwe_secret_key,
msg_u32,
&small_key_encoding,
),
DecryptionAndNoiseResult::new_from_lwe(
&before_ms.as_lwe_32(),
&small_lwe_secret_key,
msg_u32,
&small_key_encoding,
),
DecryptionAndNoiseResult::new_from_lwe(
&after_ms.as_lwe_32(),
&small_lwe_secret_key,
msg_u32,
&small_key_encoding,
),
)
}
}
let large_lwe_secret_key_dyn = cks.large_lwe_secret_key_as_dyn();
let small_lwe_secret_key_dyn = cks.small_lwe_secret_key_as_dyn();
(
DecryptionAndNoiseResult::new_from_dyn_lwe(&input, &large_lwe_secret_key_dyn, msg),
DecryptionAndNoiseResult::new_from_dyn_lwe(&after_dp, &large_lwe_secret_key_dyn, msg),
DecryptionAndNoiseResult::new_from_dyn_lwe(&after_ks, &small_lwe_secret_key_dyn, msg),
DecryptionAndNoiseResult::new_from_dyn_lwe(before_ms, &small_lwe_secret_key_dyn, msg),
DecryptionAndNoiseResult::new_from_dyn_modswitched_lwe(
&after_ms,
&small_lwe_secret_key_dyn,
msg,
),
)
}
fn encrypt_dp_ks_any_ms_noise_helper(
@@ -390,7 +327,9 @@ fn encrypt_dp_ks_any_ms_pfail_helper(
}
fn noise_check_encrypt_dp_ks_ms_noise(meta_params: MetaParameters) {
let params = meta_params.compute_parameters;
let params = meta_params
.compute_parameters
.with_deterministic_execution();
let cks = ClientKey::new(params);
let sks = ServerKey::new(&cks);
@@ -482,11 +421,14 @@ create_parameterized_test!(noise_check_encrypt_dp_ks_ms_noise {
TEST_META_PARAM_CPU_2_2_KS_PBS_GAUSSIAN_2M128,
TEST_META_PARAM_CPU_2_2_KS_PBS_PKE_TO_SMALL_ZKV2_TUNIFORM_2M128,
TEST_META_PARAM_CPU_2_2_KS32_PBS_PKE_TO_SMALL_ZKV2_TUNIFORM_2M128,
TEST_META_PARAM_GPU_2_2_MULTI_BIT_GROUP_4_KS_PBS_PKE_TO_SMALL_ZKV2_TUNIFORM_2M128,
});
fn noise_check_encrypt_dp_ks_ms_pfail(meta_params: MetaParameters) {
let (pfail_test_meta, params) = {
let mut ap_params = meta_params.compute_parameters;
let mut ap_params = meta_params
.compute_parameters
.with_deterministic_execution();
let original_message_modulus = ap_params.message_modulus();
let original_carry_modulus = ap_params.carry_modulus();
@@ -548,4 +490,5 @@ create_parameterized_test!(noise_check_encrypt_dp_ks_ms_pfail {
TEST_META_PARAM_CPU_2_2_KS_PBS_GAUSSIAN_2M128,
TEST_META_PARAM_CPU_2_2_KS_PBS_PKE_TO_SMALL_ZKV2_TUNIFORM_2M128,
TEST_META_PARAM_CPU_2_2_KS32_PBS_PKE_TO_SMALL_ZKV2_TUNIFORM_2M128,
TEST_META_PARAM_GPU_2_2_MULTI_BIT_GROUP_4_KS_PBS_PKE_TO_SMALL_ZKV2_TUNIFORM_2M128,
});

View File

@@ -6,7 +6,6 @@ use super::utils::{mean_and_variance_check, DecryptionAndNoiseResult, NoiseSampl
use crate::core_crypto::algorithms::lwe_programmable_bootstrapping::generate_programmable_bootstrap_glwe_lut;
use crate::core_crypto::commons::dispersion::Variance;
use crate::core_crypto::commons::parameters::CiphertextModulusLog;
use crate::shortint::client_key::atomic_pattern::AtomicPatternClientKey;
use crate::shortint::client_key::ClientKey;
use crate::shortint::encoding::{PaddingBit, ShortintEncoding};
use crate::shortint::engine::ShortintEngine;
@@ -18,6 +17,7 @@ use crate::shortint::parameters::noise_squashing::NoiseSquashingParameters;
use crate::shortint::parameters::test_params::{
TEST_META_PARAM_CPU_2_2_KS32_PBS_PKE_TO_SMALL_ZKV2_TUNIFORM_2M128,
TEST_META_PARAM_CPU_2_2_KS_PBS_PKE_TO_SMALL_ZKV2_TUNIFORM_2M128,
TEST_META_PARAM_GPU_2_2_MULTI_BIT_GROUP_4_KS_PBS_PKE_TO_SMALL_ZKV2_TUNIFORM_2M128,
};
use crate::shortint::parameters::{
AtomicPatternParameters, MetaParameters, NoiseSquashingCompressionParameters,
@@ -68,7 +68,9 @@ where
+ AllocateCenteredBinaryShiftedStandardModSwitchResult<
Output = MsResult,
SideResources = Resources,
> + CenteredBinaryShiftedStandardModSwitch<MsResult, SideResources = Resources>,
> + CenteredBinaryShiftedStandardModSwitch<MsResult, SideResources = Resources>
+ AllocateMultiBitModSwitchResult<Output = MsResult, SideResources = Resources>
+ MultiBitModSwitch<MsResult, SideResources = Resources>,
// We need to be able to allocate the result and apply drift technique + mod switch it
DriftKey: AllocateDriftTechniqueStandardModSwitchResult<
AfterDriftOutput = DriftTechniqueResult,
@@ -84,7 +86,7 @@ where
// one to allocate the blind rotation result
Accumulator: AllocateLweBootstrapResult<Output = PbsResult, SideResources = Resources>,
// We need to be able to apply the PBS
Bsk: LweClassicFft128Bootstrap<MsResult, PbsResult, Accumulator, SideResources = Resources>,
Bsk: LweGenericBlindRotate128<MsResult, PbsResult, Accumulator, SideResources = Resources>,
{
let after_dp = input.scalar_mul(scalar, side_resources);
let mut ks_result = ksk.allocate_lwe_keyswitch_result(side_resources);
@@ -98,7 +100,7 @@ where
);
let mut pbs_result = accumulator.allocate_lwe_bootstrap_result(side_resources);
bsk_128.lwe_classic_fft_128_pbs(&ms_result, &mut pbs_result, accumulator, side_resources);
bsk_128.lwe_generic_blind_rotate_128(&ms_result, &mut pbs_result, accumulator, side_resources);
(
input,
after_dp,
@@ -159,7 +161,9 @@ where
+ AllocateCenteredBinaryShiftedStandardModSwitchResult<
Output = MsResult,
SideResources = Resources,
> + CenteredBinaryShiftedStandardModSwitch<MsResult, SideResources = Resources>,
> + CenteredBinaryShiftedStandardModSwitch<MsResult, SideResources = Resources>
+ AllocateMultiBitModSwitchResult<Output = MsResult, SideResources = Resources>
+ MultiBitModSwitch<MsResult, SideResources = Resources>,
// We need to be able to allocate the result and apply drift technique + mod switch it
DriftKey: AllocateDriftTechniqueStandardModSwitchResult<
AfterDriftOutput = DriftTechniqueResult,
@@ -175,7 +179,7 @@ where
// one to allocate the blind rotation result
Accumulator: AllocateLweBootstrapResult<Output = PbsResult, SideResources = Resources> + Sync,
// We need to be able to apply the PBS
Bsk: LweClassicFft128Bootstrap<MsResult, PbsResult, Accumulator, SideResources = Resources>
Bsk: LweGenericBlindRotate128<MsResult, PbsResult, Accumulator, SideResources = Resources>
+ Sync,
PackingKey: AllocateLwePackingKeyswitchResult<Output = PackingResult, SideResources = Resources>
+ for<'a> LwePackingKeyswitch<[&'a PbsResult], PackingResult, SideResources = Resources>,
@@ -240,8 +244,12 @@ fn sanity_check_encrypt_dp_ks_standard_pbs128_packing_ks(meta_params: MetaParame
let (params, noise_squashing_params, noise_squashing_compression_params) = {
let meta_noise_squashing_params = meta_params.noise_squashing_parameters.unwrap();
(
meta_params.compute_parameters,
meta_noise_squashing_params.parameters,
meta_params
.compute_parameters
.with_deterministic_execution(),
meta_noise_squashing_params
.parameters
.with_deterministic_execution(),
meta_noise_squashing_params.compression_parameters.unwrap(),
)
};
@@ -314,14 +322,14 @@ fn sanity_check_encrypt_dp_ks_standard_pbs128_packing_ks(meta_params: MetaParame
let compressed = noise_squashing_compression_key
.compress_noise_squashed_ciphertexts_into_list(&noise_squashed);
let underlying_glwes = compressed.glwe_ciphertext_list;
let underlying_glwes = &compressed.glwe_ciphertext_list;
assert_eq!(underlying_glwes.len(), 1);
let extracted = underlying_glwes[0].extract();
// Bodies that were not filled are discarded
after_packing.get_mut_body().as_mut()[lwe_per_glwe.0..].fill(0);
after_packing.get_mut_body().as_mut()[compressed.len()..].fill(0);
assert_eq!(after_packing.as_view(), extracted.as_view());
}
@@ -329,6 +337,7 @@ fn sanity_check_encrypt_dp_ks_standard_pbs128_packing_ks(meta_params: MetaParame
create_parameterized_test!(sanity_check_encrypt_dp_ks_standard_pbs128_packing_ks {
TEST_META_PARAM_CPU_2_2_KS_PBS_PKE_TO_SMALL_ZKV2_TUNIFORM_2M128,
TEST_META_PARAM_CPU_2_2_KS32_PBS_PKE_TO_SMALL_ZKV2_TUNIFORM_2M128,
TEST_META_PARAM_GPU_2_2_MULTI_BIT_GROUP_4_KS_PBS_PKE_TO_SMALL_ZKV2_TUNIFORM_2M128,
});
#[allow(clippy::too_many_arguments)]
@@ -451,117 +460,45 @@ fn encrypt_dp_ks_standard_pbs128_packing_ks_inner_helper(
.map(
|(input, after_dp, after_ks, after_drift, after_ms, after_pbs128)| {
let before_ms = after_drift.as_ref().unwrap_or(&after_ks);
match &cks.atomic_pattern {
AtomicPatternClientKey::Standard(standard_atomic_pattern_client_key) => {
let params = standard_atomic_pattern_client_key.parameters;
let u64_encoding = ShortintEncoding {
ciphertext_modulus: params.ciphertext_modulus(),
message_modulus: params.message_modulus(),
carry_modulus: params.carry_modulus(),
padding_bit: PaddingBit::Yes,
};
let large_lwe_secret_key =
standard_atomic_pattern_client_key.large_lwe_secret_key();
let small_lwe_secret_key =
standard_atomic_pattern_client_key.small_lwe_secret_key();
(
DecryptionAndNoiseResult::new_from_lwe(
&input.as_lwe_64(),
&large_lwe_secret_key,
msg,
&u64_encoding,
),
DecryptionAndNoiseResult::new_from_lwe(
&after_dp.as_lwe_64(),
&large_lwe_secret_key,
msg,
&u64_encoding,
),
DecryptionAndNoiseResult::new_from_lwe(
&after_ks.as_lwe_64(),
&small_lwe_secret_key,
msg,
&u64_encoding,
),
DecryptionAndNoiseResult::new_from_lwe(
&before_ms.as_lwe_64(),
&small_lwe_secret_key,
msg,
&u64_encoding,
),
DecryptionAndNoiseResult::new_from_lwe(
&after_ms.as_lwe_64(),
&small_lwe_secret_key,
msg,
&u64_encoding,
),
DecryptionAndNoiseResult::new_from_lwe(
&after_pbs128,
&noise_squashing_private_key.post_noise_squashing_lwe_secret_key(),
msg.into(),
&u128_encoding,
),
)
}
AtomicPatternClientKey::KeySwitch32(ks32_atomic_pattern_client_key) => {
let msg_u32: u32 = msg.try_into().unwrap();
let params = ks32_atomic_pattern_client_key.parameters;
let u32_encoding = ShortintEncoding {
ciphertext_modulus: params.post_keyswitch_ciphertext_modulus(),
message_modulus: params.message_modulus(),
carry_modulus: params.carry_modulus(),
padding_bit: PaddingBit::Yes,
};
let u64_encoding = ShortintEncoding {
ciphertext_modulus: params.ciphertext_modulus(),
message_modulus: params.message_modulus(),
carry_modulus: params.carry_modulus(),
padding_bit: PaddingBit::Yes,
};
let large_lwe_secret_key =
ks32_atomic_pattern_client_key.large_lwe_secret_key();
let small_lwe_secret_key =
ks32_atomic_pattern_client_key.small_lwe_secret_key();
(
DecryptionAndNoiseResult::new_from_lwe(
&input.as_lwe_64(),
&large_lwe_secret_key,
msg,
&u64_encoding,
),
DecryptionAndNoiseResult::new_from_lwe(
&after_dp.as_lwe_64(),
&large_lwe_secret_key,
msg,
&u64_encoding,
),
DecryptionAndNoiseResult::new_from_lwe(
&after_ks.as_lwe_32(),
&small_lwe_secret_key,
msg_u32,
&u32_encoding,
),
DecryptionAndNoiseResult::new_from_lwe(
&before_ms.as_lwe_32(),
&small_lwe_secret_key,
msg_u32,
&u32_encoding,
),
DecryptionAndNoiseResult::new_from_lwe(
&after_ms.as_lwe_32(),
&small_lwe_secret_key,
msg_u32,
&u32_encoding,
),
DecryptionAndNoiseResult::new_from_lwe(
&after_pbs128,
&noise_squashing_private_key.post_noise_squashing_lwe_secret_key(),
msg.into(),
&u128_encoding,
),
)
}
}
let large_lwe_secret_key_dyn = cks.large_lwe_secret_key_as_dyn();
let small_lwe_secret_key_dyn = cks.small_lwe_secret_key_as_dyn();
(
DecryptionAndNoiseResult::new_from_dyn_lwe(
&input,
&large_lwe_secret_key_dyn,
msg,
),
DecryptionAndNoiseResult::new_from_dyn_lwe(
&after_dp,
&large_lwe_secret_key_dyn,
msg,
),
DecryptionAndNoiseResult::new_from_dyn_lwe(
&after_ks,
&small_lwe_secret_key_dyn,
msg,
),
DecryptionAndNoiseResult::new_from_dyn_lwe(
before_ms,
&small_lwe_secret_key_dyn,
msg,
),
DecryptionAndNoiseResult::new_from_dyn_modswitched_lwe(
&after_ms,
&small_lwe_secret_key_dyn,
msg,
),
// This one is a "raw" LWE given today we don't need to manage several unsigned
// integer types after a PBS128
DecryptionAndNoiseResult::new_from_lwe(
&after_pbs128,
&noise_squashing_private_key.post_noise_squashing_lwe_secret_key(),
msg.into(),
&u128_encoding,
),
)
},
)
.collect();
@@ -662,8 +599,12 @@ fn noise_check_encrypt_dp_ks_standard_pbs128_packing_ks_noise(meta_params: MetaP
let (params, noise_squashing_params, noise_squashing_compression_params) = {
let meta_noise_squashing_params = meta_params.noise_squashing_parameters.unwrap();
(
meta_params.compute_parameters,
meta_noise_squashing_params.parameters,
meta_params
.compute_parameters
.with_deterministic_execution(),
meta_noise_squashing_params
.parameters
.with_deterministic_execution(),
meta_noise_squashing_params.compression_parameters.unwrap(),
)
};
@@ -684,7 +625,7 @@ fn noise_check_encrypt_dp_ks_standard_pbs128_packing_ks_noise(meta_params: MetaP
let noise_simulation_modulus_switch_config =
NoiseSimulationModulusSwitchConfig::new_from_atomic_pattern_parameters(params);
let noise_simulation_bsk128 =
NoiseSimulationLweFourier128Bsk::new_from_parameters(params, noise_squashing_params);
NoiseSimulationGenericBootstrapKey128::new_from_parameters(params, noise_squashing_params);
let noise_simulation_packing_key =
NoiseSimulationLwePackingKeyswitchKey::new_from_noise_squashing_parameters(
noise_squashing_params,
@@ -806,4 +747,5 @@ fn noise_check_encrypt_dp_ks_standard_pbs128_packing_ks_noise(meta_params: MetaP
create_parameterized_test!(noise_check_encrypt_dp_ks_standard_pbs128_packing_ks_noise {
TEST_META_PARAM_CPU_2_2_KS_PBS_PKE_TO_SMALL_ZKV2_TUNIFORM_2M128,
TEST_META_PARAM_CPU_2_2_KS32_PBS_PKE_TO_SMALL_ZKV2_TUNIFORM_2M128,
TEST_META_PARAM_GPU_2_2_MULTI_BIT_GROUP_4_KS_PBS_PKE_TO_SMALL_ZKV2_TUNIFORM_2M128,
});

View File

@@ -1,10 +1,13 @@
pub mod noise_simulation;
pub use noise_simulation::traits;
pub mod traits;
use crate::core_crypto::algorithms::glwe_encryption::decrypt_glwe_ciphertext;
use crate::core_crypto::algorithms::lwe_encryption::{
allocate_and_encrypt_new_lwe_ciphertext, decrypt_lwe_ciphertext,
};
use crate::core_crypto::algorithms::lwe_multi_bit_programmable_bootstrapping::{
MultiBitModulusSwitchedLweCiphertext, StandardMultiBitModulusSwitchedCt,
};
use crate::core_crypto::algorithms::misc::torus_modular_diff;
use crate::core_crypto::algorithms::test::round_decode;
use crate::core_crypto::commons::dispersion::{DispersionParameter, Variance};
@@ -14,7 +17,7 @@ use crate::core_crypto::commons::noise_formulas::secure_noise::{
minimal_lwe_variance_for_132_bits_security_gaussian,
minimal_lwe_variance_for_132_bits_security_tuniform,
};
use crate::core_crypto::commons::numeric::{CastFrom, UnsignedInteger};
use crate::core_crypto::commons::numeric::{CastFrom, CastInto, UnsignedInteger};
use crate::core_crypto::commons::parameters::{
CiphertextModulus, DynamicDistribution, LweCiphertextCount, LweDimension, PlaintextCount,
};
@@ -29,11 +32,14 @@ use crate::core_crypto::entities::glwe_ciphertext::GlweCiphertext;
use crate::core_crypto::entities::glwe_secret_key::GlweSecretKey;
use crate::core_crypto::entities::lwe_ciphertext::{LweCiphertext, LweCiphertextOwned};
use crate::core_crypto::entities::lwe_secret_key::LweSecretKey;
use crate::core_crypto::entities::{Cleartext, PlaintextList};
use crate::core_crypto::entities::{Cleartext, Plaintext, PlaintextList};
use crate::shortint::encoding::ShortintEncoding;
use crate::shortint::parameters::{
AtomicPatternParameters, CarryModulus, MessageModulus, PBSParameters,
};
use crate::shortint::server_key::tests::noise_distribution::utils::noise_simulation::{
DynLwe, DynLweSecretKeyView, DynModSwitchedLwe, DynStandardMultiBitModulusSwitchedCt,
};
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct PrecisionWithPadding {
@@ -399,6 +405,40 @@ pub enum DecryptionAndNoiseResult {
}
impl DecryptionAndNoiseResult {
pub fn new_from_plaintext<Scalar: UnsignedInteger + CastFrom<u64>>(
decrypted_plaintext: Plaintext<Scalar>,
expected_msg: Scalar,
encoding: &ShortintEncoding<Scalar>,
) -> Self {
let delta = encoding.delta();
let cleartext_modulus_with_padding = encoding.full_cleartext_space();
// We apply the modulus on the cleartext + the padding bit
let decoded_msg =
round_decode(decrypted_plaintext.0, delta) % cleartext_modulus_with_padding;
let expected_plaintext = expected_msg * delta;
// decrypted_plaintext = expected_plaintext + error
// The order below computes:
// decrypted_plaintext - expected_plaintext in a modular way, which is what we want
// It only changes the average value sign, so that it is more intuitive when comparing to
// theory
let noise = torus_modular_diff(
decrypted_plaintext.0,
expected_plaintext,
encoding.ciphertext_modulus,
);
if decoded_msg == expected_msg {
Self::DecryptionSucceeded {
noise: NoiseSample { value: noise },
}
} else {
Self::DecryptionFailed
}
}
pub fn new_from_lwe<Scalar: UnsignedInteger + CastFrom<u64>, CtCont, KeyCont>(
ct: &LweCiphertext<CtCont>,
secret_key: &LweSecretKey<KeyCont>,
@@ -409,33 +449,85 @@ impl DecryptionAndNoiseResult {
CtCont: Container<Element = Scalar>,
KeyCont: Container<Element = Scalar>,
{
let decrypted_plaintext = decrypt_lwe_ciphertext(secret_key, ct).0;
let decrypted_plaintext = decrypt_lwe_ciphertext(secret_key, ct);
let delta = encoding.delta();
let cleartext_modulus_with_padding = encoding.full_cleartext_space();
Self::new_from_plaintext(decrypted_plaintext, expected_msg, encoding)
}
// We apply the modulus on the cleartext + the padding bit
let decoded_msg = round_decode(decrypted_plaintext, delta) % cleartext_modulus_with_padding;
let expected_plaintext = expected_msg * delta;
// decrypted_plaintext = expected_plaintext + error
// The order below computes:
// decrypted_plaintext - expected_plaintext in a modular way, which is what we want
// It only changes the average value sign, so that it is more intuitive when comparing to
// theory
let noise = torus_modular_diff(
decrypted_plaintext,
expected_plaintext,
ct.ciphertext_modulus(),
);
if decoded_msg == expected_msg {
Self::DecryptionSucceeded {
noise: NoiseSample { value: noise },
pub fn new_from_dyn_lwe(
ct: &DynLwe,
secret_key: &DynLweSecretKeyView<'_>,
expected_msg: u64,
) -> Self {
match (ct, secret_key) {
(DynLwe::U32(lwe_ciphertext), DynLweSecretKeyView::U32 { key, encoding }) => {
Self::new_from_lwe(
lwe_ciphertext,
key,
expected_msg.try_into().unwrap(),
encoding,
)
}
} else {
Self::DecryptionFailed
(DynLwe::U64(lwe_ciphertext), DynLweSecretKeyView::U64 { key, encoding }) => {
Self::new_from_lwe(lwe_ciphertext, key, expected_msg, encoding)
}
_ => panic!("Incompatible types in DecryptionAndNoiseResult::new_from_dyn_lwe"),
}
}
pub fn new_from_dyn_multi_bit_mod_switched_lwe(
ct: &DynStandardMultiBitModulusSwitchedCt,
secret_key: &DynLweSecretKeyView<'_>,
expected_msg: u64,
) -> Self {
match (ct, secret_key) {
(
DynStandardMultiBitModulusSwitchedCt::U32(standard_multi_bit_modulus_switched_ct),
DynLweSecretKeyView::U32 { key, encoding },
) => {
let decrypted_plaintext = decrypt_multi_bit_mod_switched_lwe_ciphertext(
key,
standard_multi_bit_modulus_switched_ct,
);
Self::new_from_plaintext(
decrypted_plaintext,
expected_msg.try_into().unwrap(),
encoding,
)
}
(
DynStandardMultiBitModulusSwitchedCt::U64(standard_multi_bit_modulus_switched_ct),
DynLweSecretKeyView::U64 { key, encoding },
) => {
let decrypted_plaintext = decrypt_multi_bit_mod_switched_lwe_ciphertext(
key,
standard_multi_bit_modulus_switched_ct,
);
Self::new_from_plaintext(decrypted_plaintext, expected_msg, encoding)
}
_ => panic!(
"Incompatible types in \
DecryptionAndNoiseResult::new_from_dyn_multi_bit_mod_switched_lwe"
),
}
}
pub fn new_from_dyn_modswitched_lwe(
ct: &DynModSwitchedLwe,
secret_key: &DynLweSecretKeyView<'_>,
expected_msg: u64,
) -> Self {
match ct {
DynModSwitchedLwe::ModSwitchedLwe(dyn_lwe) => {
Self::new_from_dyn_lwe(dyn_lwe, secret_key, expected_msg)
}
DynModSwitchedLwe::MultiBitModSwitchedLwe(
dyn_standard_multi_bit_modulus_switched_ct,
) => Self::new_from_dyn_multi_bit_mod_switched_lwe(
dyn_standard_multi_bit_modulus_switched_ct,
secret_key,
expected_msg,
),
}
}
@@ -599,6 +691,64 @@ pub fn expected_pfail_for_precision(
statrs::function::erf::erfc(measured_std_score / core::f64::consts::SQRT_2)
}
pub fn decrypt_multi_bit_mod_switched_lwe_ciphertext<Scalar, CtCont, KeyCont>(
lwe_secret_key: &LweSecretKey<KeyCont>,
mod_switched_lwe: &StandardMultiBitModulusSwitchedCt<Scalar, CtCont>,
) -> Plaintext<Scalar>
where
Scalar: UnsignedInteger + CastFrom<usize> + CastInto<usize>,
CtCont: Container<Element = Scalar> + Sync,
KeyCont: Container<Element = Scalar>,
{
let mut result: Scalar = mod_switched_lwe
.switched_modulus_input_lwe_body()
.cast_into();
let log_modulus = mod_switched_lwe.log_modulus;
let grouping_factor = mod_switched_lwe.grouping_factor();
let shift_to_native = Scalar::BITS - log_modulus.0;
result <<= shift_to_native;
for (loop_idx, lwe_key_bits) in lwe_secret_key
.as_ref()
.chunks_exact(grouping_factor.0)
.enumerate()
{
let selector = {
let mut selector = 0usize;
for bit in lwe_key_bits.iter() {
let bit: usize = (*bit).cast_into();
selector <<= 1;
selector |= bit;
}
if selector == 0 {
// We dont generate a mod switched value for selector == 0 it corresponds to key
// bits == 0
None
} else {
// We subtract 1 to be coherent with the fact the first mod switched value is not
// generated
Some(selector - 1)
}
};
if let Some(selector) = selector {
let mod_switched: Scalar = mod_switched_lwe
.switched_modulus_input_mask_per_group(loop_idx)
.nth(selector)
.unwrap()
.cast_into();
// Put in the high bits same as the body to be able to measure the noise in the
// encompassing modulus
let mod_switched = mod_switched << shift_to_native;
result = result.wrapping_sub(mod_switched);
}
}
Plaintext(result)
}
#[test]
fn test_expected_pfail_for_ci_run_filter() {
// Practical check on a compression-like scenario, of interest because pfail is known to be very

View File

@@ -0,0 +1,28 @@
pub use super::noise_simulation::traits::*;
/// Abstracts several bootstrapping implementation in the same way shortint's ServerKey does
pub trait LweGenericBootstrap<Input, Output, Accumulator> {
type SideResources;
fn lwe_generic_bootstrap(
&self,
input: &Input,
output: &mut Output,
accumulator: &Accumulator,
side_resources: &mut Self::SideResources,
);
}
/// Abstracts several blind rotate implementation in the same way shortint's ServerKey does, this
/// one is specific to the PBS 128 blind rotation
pub trait LweGenericBlindRotate128<Input, Output, Accumulator> {
type SideResources;
fn lwe_generic_blind_rotate_128(
&self,
input: &Input,
output: &mut Output,
accumulator: &Accumulator,
side_resources: &mut Self::SideResources,
);
}