mirror of
https://github.com/zama-ai/tfhe-rs.git
synced 2026-04-28 03:01:21 -04:00
Compare commits
1 Commits
tfhe-rs-1.
...
tm/shuffle
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
88241fee21 |
@@ -97,6 +97,12 @@ path = "benches/high_level_api/erc7984.rs"
|
||||
harness = false
|
||||
required-features = ["integer", "internal-keycache"]
|
||||
|
||||
[[bench]]
|
||||
name = "hlapi-bitonic-shuffle"
|
||||
path = "benches/high_level_api/bitonic_shuffle.rs"
|
||||
harness = false
|
||||
required-features = ["integer", "internal-keycache"]
|
||||
|
||||
[[bench]]
|
||||
name = "hlapi-dex"
|
||||
path = "benches/high_level_api/dex.rs"
|
||||
|
||||
53
tfhe-benchmark/benches/high_level_api/bitonic_shuffle.rs
Normal file
53
tfhe-benchmark/benches/high_level_api/bitonic_shuffle.rs
Normal file
@@ -0,0 +1,53 @@
|
||||
use benchmark::params_aliases::BENCH_PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128;
|
||||
use criterion::{black_box, criterion_group, Criterion};
|
||||
use rand::Rng;
|
||||
use tfhe::core_crypto::seeders::UnixSeeder;
|
||||
use tfhe::prelude::*;
|
||||
use tfhe::{bitonic_shuffle, set_server_key, ClientKey, ConfigBuilder, FheUint64, ServerKey};
|
||||
|
||||
fn bitonic_shuffle_bench(c: &mut Criterion, bench_name: &str, cks: &ClientKey) {
|
||||
let mut bench_group = c.benchmark_group(bench_name);
|
||||
bench_group.sample_size(10);
|
||||
|
||||
let mut rng = rand::thread_rng();
|
||||
|
||||
for num_elements in [16, 32] {
|
||||
let clear_values: Vec<u64> = (0..num_elements).map(|_| rng.gen()).collect();
|
||||
let encrypted: Vec<FheUint64> = clear_values
|
||||
.iter()
|
||||
.map(|&v| FheUint64::encrypt(v, cks))
|
||||
.collect();
|
||||
|
||||
let bench_id = format!("{bench_name}::n_{num_elements}");
|
||||
|
||||
bench_group.bench_function(&bench_id, |b| {
|
||||
b.iter(|| {
|
||||
let mut seeder = UnixSeeder::new(0);
|
||||
let result = bitonic_shuffle(
|
||||
encrypted.clone(),
|
||||
tfhe::integer::server_key::BitonicShuffleKeySize::NumBlocks(16),
|
||||
&mut seeder,
|
||||
)
|
||||
.expect("shuffle failed");
|
||||
black_box(result);
|
||||
})
|
||||
});
|
||||
}
|
||||
|
||||
bench_group.finish();
|
||||
}
|
||||
|
||||
pub fn bitonic_shuffle_cpu(c: &mut Criterion) {
|
||||
let param = BENCH_PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128;
|
||||
let config = ConfigBuilder::with_custom_parameters(param).build();
|
||||
let cks = ClientKey::generate(config);
|
||||
let sks = ServerKey::new(&cks);
|
||||
|
||||
rayon::broadcast(|_| set_server_key(sks.clone()));
|
||||
set_server_key(sks);
|
||||
|
||||
bitonic_shuffle_bench(c, "hlapi::bitonic_shuffle_cpu", &cks);
|
||||
}
|
||||
|
||||
// Declares the Criterion group `bitonic_shuffle_group`, whose only target is
// `bitonic_shuffle_cpu`.
criterion_group!(bitonic_shuffle_group, bitonic_shuffle_cpu);
|
||||
// Expands to the benchmark binary's `main`, which runs `bitonic_shuffle_group`.
criterion::criterion_main!(bitonic_shuffle_group);
|
||||
@@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "tfhe"
|
||||
version = "1.6.1"
|
||||
version = "1.6.0"
|
||||
edition = "2021"
|
||||
readme = "../README.md"
|
||||
keywords = ["fully", "homomorphic", "encryption", "fhe", "cryptography"]
|
||||
@@ -64,7 +64,7 @@ tfhe-fft = { version = "0.10.1", path = "../tfhe-fft", features = [
|
||||
"serde",
|
||||
"fft128",
|
||||
] }
|
||||
tfhe-ntt = { version = "0.7.1", path = "../tfhe-ntt" }
|
||||
tfhe-ntt = { version = "0.7.0", path = "../tfhe-ntt" }
|
||||
pulp = { workspace = true, features = ["default"] }
|
||||
tfhe-cuda-backend = { version = "0.14.0", path = "../backends/tfhe-cuda-backend", optional = true }
|
||||
aligned-vec = { workspace = true, features = ["default", "serde"] }
|
||||
|
||||
@@ -75,11 +75,11 @@
|
||||
<text dominant-baseline="middle" text-anchor="middle" font-family="Arial" font-size="14" font-weight="normal" fill="black" x="594.0" y="420.0">121 ms</text>
|
||||
<text dominant-baseline="middle" text-anchor="middle" font-family="Arial" font-size="14" font-weight="normal" fill="black" x="678.0" y="420.0">165 ms</text>
|
||||
<text dominant-baseline="middle" text-anchor="start" font-family="Arial" font-size="14" font-weight="normal" fill="black" x="6" y="460.0">Leading / Trailing zeros/ones</text>
|
||||
<text dominant-baseline="middle" text-anchor="middle" font-family="Arial" font-size="14" font-weight="normal" fill="black" x="342.0" y="460.0">67.2 ms</text>
|
||||
<text dominant-baseline="middle" text-anchor="middle" font-family="Arial" font-size="14" font-weight="normal" fill="black" x="426.0" y="460.0">70.6 ms</text>
|
||||
<text dominant-baseline="middle" text-anchor="middle" font-family="Arial" font-size="14" font-weight="normal" fill="black" x="510.0" y="460.0">89.8 ms</text>
|
||||
<text dominant-baseline="middle" text-anchor="middle" font-family="Arial" font-size="14" font-weight="normal" fill="black" x="594.0" y="460.0">92.6 ms</text>
|
||||
<text dominant-baseline="middle" text-anchor="middle" font-family="Arial" font-size="14" font-weight="normal" fill="black" x="678.0" y="460.0">113 ms</text>
|
||||
<text dominant-baseline="middle" text-anchor="middle" font-family="Arial" font-size="14" font-weight="normal" fill="black" x="342.0" y="460.0">88.4 ms</text>
|
||||
<text dominant-baseline="middle" text-anchor="middle" font-family="Arial" font-size="14" font-weight="normal" fill="black" x="426.0" y="460.0">148 ms</text>
|
||||
<text dominant-baseline="middle" text-anchor="middle" font-family="Arial" font-size="14" font-weight="normal" fill="black" x="510.0" y="460.0">169 ms</text>
|
||||
<text dominant-baseline="middle" text-anchor="middle" font-family="Arial" font-size="14" font-weight="normal" fill="black" x="594.0" y="460.0">222 ms</text>
|
||||
<text dominant-baseline="middle" text-anchor="middle" font-family="Arial" font-size="14" font-weight="normal" fill="black" x="678.0" y="460.0">275 ms</text>
|
||||
<text dominant-baseline="middle" text-anchor="start" font-family="Arial" font-size="14" font-weight="normal" fill="black" x="6" y="500.0">Log2</text>
|
||||
<text dominant-baseline="middle" text-anchor="middle" font-family="Arial" font-size="14" font-weight="normal" fill="black" x="342.0" y="500.0">110 ms</text>
|
||||
<text dominant-baseline="middle" text-anchor="middle" font-family="Arial" font-size="14" font-weight="normal" fill="black" x="426.0" y="500.0">163 ms</text>
|
||||
|
||||
|
Before Width: | Height: | Size: 16 KiB After Width: | Height: | Size: 16 KiB |
@@ -75,11 +75,11 @@
|
||||
<text dominant-baseline="middle" text-anchor="middle" font-family="Arial" font-size="14" font-weight="normal" fill="black" x="594.0" y="420.0">32.5 ops/s</text>
|
||||
<text dominant-baseline="middle" text-anchor="middle" font-family="Arial" font-size="14" font-weight="normal" fill="black" x="678.0" y="420.0">14.0 ops/s</text>
|
||||
<text dominant-baseline="middle" text-anchor="start" font-family="Arial" font-size="14" font-weight="normal" fill="black" x="6" y="460.0">Leading / Trailing zeros/ones</text>
|
||||
<text dominant-baseline="middle" text-anchor="middle" font-family="Arial" font-size="14" font-weight="normal" fill="black" x="342.0" y="460.0">824 ops/s</text>
|
||||
<text dominant-baseline="middle" text-anchor="middle" font-family="Arial" font-size="14" font-weight="normal" fill="black" x="426.0" y="460.0">487 ops/s</text>
|
||||
<text dominant-baseline="middle" text-anchor="middle" font-family="Arial" font-size="14" font-weight="normal" fill="black" x="510.0" y="460.0">222 ops/s</text>
|
||||
<text dominant-baseline="middle" text-anchor="middle" font-family="Arial" font-size="14" font-weight="normal" fill="black" x="594.0" y="460.0">119 ops/s</text>
|
||||
<text dominant-baseline="middle" text-anchor="middle" font-family="Arial" font-size="14" font-weight="normal" fill="black" x="678.0" y="460.0">57.8 ops/s</text>
|
||||
<text dominant-baseline="middle" text-anchor="middle" font-family="Arial" font-size="14" font-weight="normal" fill="black" x="342.0" y="460.0">625 ops/s</text>
|
||||
<text dominant-baseline="middle" text-anchor="middle" font-family="Arial" font-size="14" font-weight="normal" fill="black" x="426.0" y="460.0">247 ops/s</text>
|
||||
<text dominant-baseline="middle" text-anchor="middle" font-family="Arial" font-size="14" font-weight="normal" fill="black" x="510.0" y="460.0">108 ops/s</text>
|
||||
<text dominant-baseline="middle" text-anchor="middle" font-family="Arial" font-size="14" font-weight="normal" fill="black" x="594.0" y="460.0">44.1 ops/s</text>
|
||||
<text dominant-baseline="middle" text-anchor="middle" font-family="Arial" font-size="14" font-weight="normal" fill="black" x="678.0" y="460.0">19.0 ops/s</text>
|
||||
<text dominant-baseline="middle" text-anchor="start" font-family="Arial" font-size="14" font-weight="normal" fill="black" x="6" y="500.0">Log2</text>
|
||||
<text dominant-baseline="middle" text-anchor="middle" font-family="Arial" font-size="14" font-weight="normal" fill="black" x="342.0" y="500.0">542 ops/s</text>
|
||||
<text dominant-baseline="middle" text-anchor="middle" font-family="Arial" font-size="14" font-weight="normal" fill="black" x="426.0" y="500.0">220 ops/s</text>
|
||||
|
||||
|
Before Width: | Height: | Size: 16 KiB After Width: | Height: | Size: 16 KiB |
@@ -74,7 +74,7 @@ To compile and execute GPU TFHE-rs programs, make sure your system has the follo
|
||||
To use the **TFHE-rs** GPU backend in your project, add the following dependency in your `Cargo.toml`.
|
||||
|
||||
```toml
|
||||
tfhe = { version = "~1.6.1", features = ["boolean", "shortint", "integer", "gpu"] }
|
||||
tfhe = { version = "~1.6.0", features = ["boolean", "shortint", "integer", "gpu"] }
|
||||
```
|
||||
|
||||
If none of the supported backends is configured in `Cargo.toml`, the CPU backend is used.
|
||||
|
||||
@@ -17,7 +17,7 @@ This guide explains how to update your existing program to leverage HPU accelera
|
||||
To use the **TFHE-rs** HPU backend in your project, add the following dependency in your `Cargo.toml`.
|
||||
|
||||
```toml
|
||||
tfhe = { version = "~1.6.1", features = ["integer", "hpu-v80"] }
|
||||
tfhe = { version = "~1.6.0", features = ["integer", "hpu-v80"] }
|
||||
```
|
||||
|
||||
{% hint style="success" %}
|
||||
|
||||
@@ -16,7 +16,7 @@ You can load serialized data with the `unversionize` function, even in newer ver
|
||||
|
||||
[dependencies]
|
||||
# ...
|
||||
tfhe = { version = "~1.6.1", features = ["integer"] }
|
||||
tfhe = { version = "~1.6.0", features = ["integer"] }
|
||||
tfhe-versionable = "0.6.0"
|
||||
bincode = "1.3.3"
|
||||
```
|
||||
|
||||
@@ -161,7 +161,7 @@ In the following example, we use [bincode](https://crates.io/crates/bincode) for
|
||||
|
||||
[dependencies]
|
||||
# ...
|
||||
tfhe = { version = "~1.6.1", features = ["integer"] }
|
||||
tfhe = { version = "~1.6.0", features = ["integer"] }
|
||||
bincode = "1.3.3"
|
||||
```
|
||||
|
||||
|
||||
@@ -19,7 +19,7 @@ The following example shows a complete workflow of working with encrypted arrays
|
||||
# Cargo.toml
|
||||
|
||||
[dependencies]
|
||||
tfhe = { version = "~1.6.1", features = ["integer"] }
|
||||
tfhe = { version = "~1.6.0", features = ["integer"] }
|
||||
```
|
||||
|
||||
```rust
|
||||
|
||||
@@ -36,7 +36,7 @@ To serialize a `KVStore`, it must first be compressed.
|
||||
# Cargo.toml
|
||||
|
||||
[dependencies]
|
||||
tfhe = { version = "~1.6.1", features = ["integer"] }
|
||||
tfhe = { version = "~1.6.0", features = ["integer"] }
|
||||
```
|
||||
|
||||
```rust
|
||||
|
||||
@@ -29,7 +29,7 @@ Here is an example:
|
||||
# Cargo.toml
|
||||
|
||||
[dependencies]
|
||||
tfhe = { version = "~1.6.1", features = ["integer", "strings"] }
|
||||
tfhe = { version = "~1.6.0", features = ["integer", "strings"] }
|
||||
```
|
||||
|
||||
```rust
|
||||
|
||||
@@ -7,7 +7,7 @@ This document provides instructions to set up **TFHE-rs** in your project.
|
||||
First, add **TFHE-rs** as a dependency in your `Cargo.toml`.
|
||||
|
||||
```toml
|
||||
tfhe = { version = "~1.6.1", features = ["boolean", "shortint", "integer"] }
|
||||
tfhe = { version = "~1.6.0", features = ["boolean", "shortint", "integer"] }
|
||||
```
|
||||
|
||||
{% hint style="info" %}
|
||||
@@ -35,7 +35,7 @@ By default, **TFHE-rs** makes the assumption that hardware AES features are enab
|
||||
To add support for older CPUs, import **TFHE-rs** with the `software-prng` feature in your `Cargo.toml`:
|
||||
|
||||
```toml
|
||||
tfhe = { version = "~1.6.1", features = ["boolean", "shortint", "integer", "software-prng"] }
|
||||
tfhe = { version = "~1.6.0", features = ["boolean", "shortint", "integer", "software-prng"] }
|
||||
```
|
||||
|
||||
## Hardware acceleration
|
||||
|
||||
@@ -59,7 +59,7 @@ edition = "2021"
|
||||
Then add the following configuration to include **TFHE-rs**:
|
||||
|
||||
```toml
|
||||
tfhe = { version = "~1.6.1", features = ["integer"] }
|
||||
tfhe = { version = "~1.6.0", features = ["integer"] }
|
||||
```
|
||||
|
||||
Your updated `Cargo.toml` file should look like this:
|
||||
@@ -71,7 +71,7 @@ version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
[dependencies]
|
||||
tfhe = { version = "~1.6.1", features = ["integer"] }
|
||||
tfhe = { version = "~1.6.0", features = ["integer"] }
|
||||
```
|
||||
|
||||
If you are on a different platform please refer to the [installation documentation](installation.md) for configuration options of other supported platforms.
|
||||
|
||||
@@ -9,7 +9,7 @@ Welcome to this tutorial about `TFHE-rs` `core_crypto` module.
|
||||
To use `TFHE-rs`, it first has to be added as a dependency in the `Cargo.toml`:
|
||||
|
||||
```toml
|
||||
tfhe = { version = "~1.6.1" }
|
||||
tfhe = { version = "~1.6.0" }
|
||||
```
|
||||
|
||||
### Commented code to double a 2-bit message in a leveled fashion and using a PBS with the `core_crypto` module.
|
||||
|
||||
@@ -28,7 +28,7 @@ To use the `FheUint8` type, enable the `integer` feature:
|
||||
# Cargo.toml
|
||||
|
||||
[dependencies]
|
||||
tfhe = { version = "~1.6.1", features = ["integer"] }
|
||||
tfhe = { version = "~1.6.0", features = ["integer"] }
|
||||
```
|
||||
|
||||
The `MyFheString::encrypt` function performs data validation to ensure the input string contains only ASCII characters.
|
||||
@@ -167,7 +167,7 @@ First, add the feature in your `Cargo.toml`
|
||||
# Cargo.toml
|
||||
|
||||
[dependencies]
|
||||
tfhe = { version = "~1.6.1", features = ["strings"] }
|
||||
tfhe = { version = "~1.6.0", features = ["strings"] }
|
||||
```
|
||||
|
||||
The `FheAsciiString` type makes it simple to change the case of encrypted strings homomorphically (and much more!):
|
||||
|
||||
@@ -17,7 +17,7 @@ This function returns a Boolean (`true` or `false`) so that the total count of `
|
||||
```toml
|
||||
# Cargo.toml
|
||||
|
||||
tfhe = { version = "~1.6.1", features = ["integer"] }
|
||||
tfhe = { version = "~1.6.0", features = ["integer"] }
|
||||
```
|
||||
|
||||
First, define the verification function.
|
||||
|
||||
@@ -49,6 +49,7 @@ pub use signed::{CompressedFheInt, FheInt, FheIntId, SquashedNoiseFheInt};
|
||||
pub use unsigned::{CompressedFheUint, FheUint, FheUintId, SquashedNoiseFheUint};
|
||||
|
||||
pub mod oprf;
|
||||
pub mod shuffle;
|
||||
pub(super) mod signed;
|
||||
pub(super) mod unsigned;
|
||||
|
||||
|
||||
121
tfhe/src/high_level_api/integers/shuffle.rs
Normal file
121
tfhe/src/high_level_api/integers/shuffle.rs
Normal file
@@ -0,0 +1,121 @@
|
||||
use crate::high_level_api::global_state;
|
||||
use crate::high_level_api::integers::FheIntegerType;
|
||||
use crate::high_level_api::keys::InternalServerKey;
|
||||
use crate::high_level_api::re_randomization::ReRandomizationMetadata;
|
||||
use crate::integer::server_key::radix_parallel::bitonic_shuffle::BitonicShuffleKeySize;
|
||||
use tfhe_csprng::seeders::Seeder;
|
||||
|
||||
/// Shuffles `data` into a uniformly random permutation using a bitonic
|
||||
/// sorting network with OPRF-generated random keys.
|
||||
///
|
||||
/// `key_size` controls the bit-width of the random sort keys used internally,
|
||||
/// either by specifying a target collision probability or by passing a raw
|
||||
/// block count. Larger keys reduce collision probability (improving shuffle
|
||||
/// uniformity) at the cost of more computation per comparison.
|
||||
///
|
||||
/// The re-randomization metadata of the input elements is not preserved
|
||||
/// through the shuffle.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns an error if the resolved key block count is 0, or if the Cuda/Hpu
|
||||
/// backend is active (not yet supported).
|
||||
pub fn bitonic_shuffle<T, S>(
|
||||
data: Vec<T>,
|
||||
key_size: BitonicShuffleKeySize,
|
||||
seeder: &mut S,
|
||||
) -> Result<Vec<T>, crate::Error>
|
||||
where
|
||||
T: FheIntegerType,
|
||||
S: Seeder,
|
||||
{
|
||||
global_state::with_internal_keys(|key| match key {
|
||||
InternalServerKey::Cpu(cpu_key) => {
|
||||
let inner = data.into_iter().map(|v| v.into_cpu()).collect();
|
||||
let result =
|
||||
cpu_key
|
||||
.pbs_key()
|
||||
.bitonic_shuffle(&cpu_key.oprf_key(), inner, key_size, seeder)?;
|
||||
Ok(result
|
||||
.into_iter()
|
||||
.map(|ct| T::from_cpu(ct, cpu_key.tag.clone(), ReRandomizationMetadata::default()))
|
||||
.collect())
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(_) => Err(crate::Error::new(
|
||||
"bitonic_shuffle is not supported on Cuda".to_string(),
|
||||
)),
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_) => Err(crate::Error::new(
|
||||
"bitonic_shuffle is not supported on Hpu".to_string(),
|
||||
)),
|
||||
})
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod test {
    use super::{bitonic_shuffle, BitonicShuffleKeySize};
    use crate::core_crypto::commons::generators::DeterministicSeeder;
    use crate::core_crypto::prelude::new_seeder;
    use crate::high_level_api::prelude::*;
    use crate::high_level_api::tests::setup_default_cpu;
    use crate::{FheInt8, FheUint8};
    use rand::Rng;
    use tfhe_csprng::generators::DefaultRandomGenerator;

    /// Shuffling 15 unsigned values must yield the same multiset of values.
    #[test]
    fn test_bitonic_shuffle_fheuint() {
        let cks = setup_default_cpu();
        let mut rng = rand::thread_rng();

        let mut expected: Vec<u8> = (0..15).map(|_| rng.gen()).collect();
        let mut inputs: Vec<FheUint8> = Vec::with_capacity(expected.len());
        for &clear in &expected {
            inputs.push(FheUint8::try_encrypt(clear, &cks).unwrap());
        }

        // The seed is printed so that a failing run can be replayed.
        let seed = new_seeder().seed();
        println!("seed: {seed:?}");
        let mut seeder = DeterministicSeeder::<DefaultRandomGenerator>::new(seed);

        let key_size = BitonicShuffleKeySize::num_blocks(16);
        let shuffled = bitonic_shuffle(inputs, key_size, &mut seeder).unwrap();

        let mut actual: Vec<u8> = shuffled.iter().map(|ct| ct.decrypt(&cks)).collect();

        // A permutation keeps the content: compare both sides once sorted.
        expected.sort_unstable();
        actual.sort_unstable();
        assert_eq!(actual, expected);
    }

    /// Shuffling 15 signed values must yield the same multiset of values.
    #[test]
    fn test_bitonic_shuffle_fheint() {
        let cks = setup_default_cpu();
        let mut rng = rand::thread_rng();

        let mut expected: Vec<i8> = (0..15).map(|_| rng.gen()).collect();
        let mut inputs: Vec<FheInt8> = Vec::with_capacity(expected.len());
        for &clear in &expected {
            inputs.push(FheInt8::try_encrypt(clear, &cks).unwrap());
        }

        // The seed is printed so that a failing run can be replayed.
        let seed = new_seeder().seed();
        println!("seed: {seed:?}");
        let mut seeder = DeterministicSeeder::<DefaultRandomGenerator>::new(seed);

        let key_size = BitonicShuffleKeySize::num_blocks(16);
        let shuffled = bitonic_shuffle(inputs, key_size, &mut seeder).unwrap();

        let mut actual: Vec<i8> = shuffled.iter().map(|ct| ct.decrypt(&cks)).collect();

        // A permutation keeps the content: compare both sides once sorted.
        expected.sort_unstable();
        actual.sort_unstable();
        assert_eq!(actual, expected);
    }
}
|
||||
@@ -47,8 +47,9 @@ macro_rules! export_concrete_array_types {
|
||||
};
|
||||
}
|
||||
|
||||
pub use crate::core_crypto::commons::math::random::{Seed, XofSeed};
|
||||
pub use crate::core_crypto::commons::math::random::{Seed, Seeder, XofSeed};
|
||||
pub use crate::high_level_api::integers::oprf::RangeForRandom;
|
||||
pub use crate::high_level_api::integers::shuffle::bitonic_shuffle;
|
||||
pub use crate::integer::server_key::MatchValues;
|
||||
pub use crate::shortint::OprfSeed;
|
||||
use crate::{error, Error, Versionize};
|
||||
@@ -70,8 +71,8 @@ pub use integers::{
|
||||
pub use keys::CudaServerKey;
|
||||
pub use keys::{
|
||||
generate_keys, ClientKey, CompactPublicKey, CompressedCompactPublicKey, CompressedPublicKey,
|
||||
CompressedReRandomizationKey, CompressedReRandomizationKeySwitchingKey, CompressedServerKey,
|
||||
KeySwitchingKey, PublicKey, ReRandomizationKey, ReRandomizationKeySwitchingKey, ServerKey,
|
||||
CompressedReRandomizationKeySwitchingKey, CompressedServerKey, KeySwitchingKey, PublicKey,
|
||||
ReRandomizationKeySwitchingKey, ServerKey,
|
||||
};
|
||||
use strum::FromRepr;
|
||||
|
||||
|
||||
@@ -20,7 +20,9 @@ pub use crate::shortint::CheckError;
|
||||
use crate::shortint::{CarryModulus, MessageModulus};
|
||||
pub use radix::scalar_mul::ScalarMultiplier;
|
||||
pub use radix::scalar_sub::TwosComplementNegation;
|
||||
pub use radix_parallel::{MatchValues, MiniUnsignedInteger, Reciprocable};
|
||||
pub use radix_parallel::{
|
||||
BitonicShuffleKeySize, CollisionProbability, MatchValues, MiniUnsignedInteger, Reciprocable,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tfhe_versionable::Versionize;
|
||||
|
||||
|
||||
@@ -247,14 +247,15 @@ impl ServerKey {
|
||||
/// let res: u64 = cks.decrypt(&ct1);
|
||||
/// assert_eq!(7, res);
|
||||
/// ```
|
||||
pub fn extend_radix_with_trivial_zero_blocks_msb_assign(
|
||||
&self,
|
||||
ct: &mut RadixCiphertext,
|
||||
num_blocks: usize,
|
||||
) {
|
||||
pub fn extend_radix_with_trivial_zero_blocks_msb_assign<T>(&self, ct: &mut T, num_blocks: usize)
|
||||
where
|
||||
T: IntegerRadixCiphertext,
|
||||
{
|
||||
// Swap out blocks via a no-alloc empty sentinel (blocks_mut is a slice, can't resize)
|
||||
let mut blocks = std::mem::replace(ct, T::from_blocks(vec![])).into_blocks();
|
||||
let block_trivial_zero = self.key.create_trivial(0);
|
||||
ct.blocks
|
||||
.resize(ct.blocks.len() + num_blocks, block_trivial_zero);
|
||||
blocks.resize(blocks.len() + num_blocks, block_trivial_zero);
|
||||
*ct = T::from_blocks(blocks);
|
||||
}
|
||||
|
||||
/// Append trivial zero MSB blocks to an existing [`RadixCiphertext`] and returns the result as
|
||||
|
||||
291
tfhe/src/integer/server_key/radix_parallel/bitonic_shuffle.rs
Normal file
291
tfhe/src/integer/server_key/radix_parallel/bitonic_shuffle.rs
Normal file
@@ -0,0 +1,291 @@
|
||||
//! Bitonic sorting network generator.
|
||||
//!
|
||||
//! A bitonic sorting network for n=2^k elements has k*(k+1)/2 stages,
|
||||
//! each with n/2 comparators. It sorts any input sequence.
|
||||
use crate::core_crypto::prelude::Container;
|
||||
use crate::integer::oprf::GenericOprfServerKey;
|
||||
use crate::integer::{IntegerRadixCiphertext, RadixCiphertext, ServerKey};
|
||||
use crate::shortint::MessageModulus;
|
||||
use rayon::prelude::*;
|
||||
use tfhe_csprng::seeders::Seeder;
|
||||
use tfhe_fft::c64;
|
||||
|
||||
/// Builds the comparator schedule of a bitonic sorting network over `n` elements.
///
/// The result is a list of stages. Every stage is a list of `(i, j, ascending)` triples
/// with `i < j`, and no index appears twice within a stage. A triple is one
/// compare-and-swap: when `ascending` is true the smaller element must end at `i`,
/// otherwise the larger element must end at `i`.
///
/// # Panics
///
/// Panics if `n` is not a power of 2 or is smaller than 2.
pub(crate) fn bitonic_network(n: usize) -> Vec<Vec<(usize, usize, bool)>> {
    assert!(
        n.is_power_of_two() && n >= 2,
        "bitonic_network requires n to be a power of 2 and >= 2, got n={n}"
    );
    let num_phases = n.trailing_zeros() as usize;

    (0..num_phases)
        // Phase `p` runs the steps p, p-1, ..., 0; each (phase, step) pair is one stage.
        .flat_map(|phase| (0..=phase).rev().map(move |step| (phase, step)))
        .map(|(phase, step)| {
            let partner_bit = 1usize << step;
            (0..n)
                // Keep the lower index of each pair: the one whose `step` bit is clear.
                .filter(|lo| lo & partner_bit == 0)
                .map(|lo| {
                    let hi = lo | partner_bit;
                    // The direction alternates between blocks of 2^(phase + 1) elements.
                    let ascending = (lo >> (phase + 1)) & 1 == 0;
                    (lo, hi, ascending)
                })
                .collect::<Vec<_>>()
        })
        .collect()
}
|
||||
|
||||
/// Rounds `n` up to a power of 2; a value that already is a power of 2 is returned
/// unchanged.
fn padded_size(n: usize) -> usize {
    usize::next_power_of_two(n)
}
|
||||
|
||||
/// A probability that is guaranteed to lie strictly between 0 and 1.
#[derive(Copy, Clone, PartialEq, Debug)]
pub struct CollisionProbability(f64);

impl CollisionProbability {
    /// Wraps `proba` when `0.0 < proba < 1.0`, and returns `None` otherwise
    /// (this includes NaN, for which both comparisons are false).
    pub fn try_new(proba: f64) -> Option<Self> {
        let strictly_inside_unit_interval = proba > 0.0 && proba < 1.0;
        strictly_inside_unit_interval.then_some(Self(proba))
    }

    /// Same as [`Self::try_new`], but panics on an invalid value.
    ///
    /// # Panics
    ///
    /// Panics if `proba` is not strictly between 0 and 1.
    pub fn new(proba: f64) -> Self {
        let checked = Self::try_new(proba);
        checked.expect("Invalid probability, it must be in ]0, 1.0[")
    }
}
|
||||
|
||||
#[derive(Copy, Clone, PartialEq, Debug)]
|
||||
pub enum BitonicShuffleKeySize {
|
||||
CollisionProbability(CollisionProbability),
|
||||
NumBlocks(u32),
|
||||
}
|
||||
|
||||
impl BitonicShuffleKeySize {
|
||||
pub fn try_collision_probability(proba: f64) -> Option<Self> {
|
||||
CollisionProbability::try_new(proba).map(Self::CollisionProbability)
|
||||
}
|
||||
|
||||
pub fn collision_probability(proba: f64) -> Self {
|
||||
Self::CollisionProbability(CollisionProbability::new(proba))
|
||||
}
|
||||
|
||||
pub fn num_blocks(num_blocks: u32) -> Self {
|
||||
Self::NumBlocks(num_blocks)
|
||||
}
|
||||
|
||||
fn num_blocks_of_keys(&self, num_elements: usize, msg_mod: MessageModulus) -> u32 {
|
||||
match self {
|
||||
Self::CollisionProbability(CollisionProbability(proba)) => {
|
||||
let n_squared = (num_elements * num_elements) as f64;
|
||||
let bits = (n_squared / (2.0 * proba)).log2().ceil() as u32;
|
||||
bits.div_ceil(msg_mod.0.ilog2())
|
||||
}
|
||||
Self::NumBlocks(n) => *n,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ServerKey {
|
||||
/// Shuffles `data` into a uniformly random permutation using a bitonic sorting network
|
||||
/// with random sort keys.
|
||||
///
|
||||
/// `key_size` controls the bit-width of the random sort keys used internally, either
|
||||
/// by specifying a target collision probability or by passing a raw block count.
|
||||
/// Larger keys reduce collision probability — and thus improve shuffle uniformity —
|
||||
/// at the cost of more computation per comparison/swap.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns an error if the resolved key block count is 0.
|
||||
pub fn bitonic_shuffle<T, S, C>(
|
||||
&self,
|
||||
oprf_key: &GenericOprfServerKey<C>,
|
||||
data: Vec<T>,
|
||||
key_size: BitonicShuffleKeySize,
|
||||
seeder: &mut S,
|
||||
) -> Result<Vec<T>, crate::Error>
|
||||
where
|
||||
T: IntegerRadixCiphertext,
|
||||
S: Seeder,
|
||||
C: Container<Element = c64> + Sync,
|
||||
{
|
||||
let key_num_blocks = key_size.num_blocks_of_keys(data.len(), self.message_modulus()) as u64;
|
||||
|
||||
if key_num_blocks == 0 {
|
||||
return Err(crate::Error::new(
|
||||
"key_num_blocks must be at least 1".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
if data.len() <= 1 {
|
||||
return Ok(data);
|
||||
}
|
||||
|
||||
let keys: Vec<_> = (0..data.len())
|
||||
.map(|_| {
|
||||
oprf_key.par_generate_oblivious_pseudo_random_unsigned_integer(
|
||||
seeder.seed(),
|
||||
key_num_blocks,
|
||||
self,
|
||||
)
|
||||
})
|
||||
.collect();
|
||||
|
||||
self.bitonic_shuffle_with_keys(data, keys)
|
||||
}
|
||||
|
||||
/// Shuffles `data` using a bitonic sorting network keyed by `keys`.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns an error if `data` and `keys` have different lengths, or if
|
||||
/// elements within `data` (or within `keys`) have inconsistent block counts.
|
||||
pub fn bitonic_shuffle_with_keys<T>(
|
||||
&self,
|
||||
mut data: Vec<T>,
|
||||
mut keys: Vec<RadixCiphertext>,
|
||||
) -> Result<Vec<T>, crate::Error>
|
||||
where
|
||||
T: IntegerRadixCiphertext,
|
||||
{
|
||||
if data.len() != keys.len() {
|
||||
return Err(crate::Error::new(format!(
|
||||
"data and keys must have the same length, got {} and {}",
|
||||
data.len(),
|
||||
keys.len()
|
||||
)));
|
||||
}
|
||||
|
||||
if data.len() <= 1 {
|
||||
return Ok(data);
|
||||
}
|
||||
|
||||
let data_num_blocks = data[0].blocks().len();
|
||||
if data[1..]
|
||||
.iter()
|
||||
.any(|d| d.blocks().len() != data_num_blocks)
|
||||
{
|
||||
return Err(crate::Error::new(
|
||||
"all data elements must have the same number of blocks".to_string(),
|
||||
));
|
||||
}
|
||||
let key_num_blocks = keys[0].blocks.len();
|
||||
if keys[1..].iter().any(|k| k.blocks.len() != key_num_blocks) {
|
||||
return Err(crate::Error::new(
|
||||
"all keys must have the same number of blocks".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
rayon::join(
|
||||
|| {
|
||||
data.par_iter_mut()
|
||||
.for_each(|value| self.clean_inplace_for_default_op(value));
|
||||
},
|
||||
|| {
|
||||
keys.par_iter_mut()
|
||||
.for_each(|value| self.clean_inplace_for_default_op(value));
|
||||
},
|
||||
);
|
||||
|
||||
let mut data = self.unchecked_bitonic_shuffle_with_keys(data, keys);
|
||||
|
||||
data.par_iter_mut().for_each(|radix| {
|
||||
radix
|
||||
.blocks_mut()
|
||||
.par_iter_mut()
|
||||
.for_each(|block| self.key.message_extract_assign(block))
|
||||
});
|
||||
|
||||
Ok(data)
|
||||
}
|
||||
|
||||
/// Performs a bitonic shuffle without cleaning inputs or outputs.
|
||||
///
|
||||
/// # Preconditions
|
||||
///
|
||||
/// * `data` and `keys` must have the same length and consistent block counts.
|
||||
/// * Data blocks must have no carries and noise budget for `unchecked_flip_parallelized`.
|
||||
/// * Key blocks must have no carries and noise budget for `unchecked_lt/gt`.
|
||||
///
|
||||
/// Output blocks have no carries but non-nominal noise level.
|
||||
pub fn unchecked_bitonic_shuffle_with_keys<T>(
|
||||
&self,
|
||||
mut data: Vec<T>,
|
||||
mut keys: Vec<RadixCiphertext>,
|
||||
) -> Vec<T>
|
||||
where
|
||||
T: IntegerRadixCiphertext,
|
||||
{
|
||||
assert_eq!(
|
||||
data.len(),
|
||||
keys.len(),
|
||||
"data.len()={} != keys.len()={}",
|
||||
data.len(),
|
||||
keys.len()
|
||||
);
|
||||
let n = data.len();
|
||||
if n <= 1 {
|
||||
return data;
|
||||
}
|
||||
|
||||
let padded_n = padded_size(n);
|
||||
let network = bitonic_network(padded_n);
|
||||
|
||||
let mut key_num_blocks = keys[0].blocks.len();
|
||||
let data_num_blocks = data[0].blocks().len();
|
||||
|
||||
let pad = padded_n - n;
|
||||
if pad > 0 {
|
||||
// We need to pad with some trivial (key=MAX, data=0)
|
||||
// However it could be that a key is already=MAX, so to protect us from that case
|
||||
// we add an extra block to the keys
|
||||
key_num_blocks += 1;
|
||||
for key in &mut keys {
|
||||
self.extend_radix_with_trivial_zero_blocks_msb_assign(key, 1);
|
||||
}
|
||||
|
||||
for _ in 0..pad {
|
||||
keys.push(self.create_trivial_max_radix(key_num_blocks));
|
||||
data.push(self.create_trivial_zero_radix(data_num_blocks));
|
||||
}
|
||||
}
|
||||
|
||||
let mut stage_results = Vec::with_capacity(padded_n / 2);
|
||||
for stage in network {
|
||||
stage
|
||||
.into_par_iter()
|
||||
.map(|(i, j, ascending)| {
|
||||
let cmp = if ascending {
|
||||
self.unchecked_gt_parallelized(&keys[i], &keys[j])
|
||||
} else {
|
||||
self.unchecked_lt_parallelized(&keys[i], &keys[j])
|
||||
};
|
||||
let ((new_ki, new_kj), (new_di, new_dj)) = rayon::join(
|
||||
|| self.unchecked_flip_parallelized(&cmp, &keys[i], &keys[j]),
|
||||
|| self.unchecked_flip_parallelized(&cmp, &data[i], &data[j]),
|
||||
);
|
||||
|
||||
(i, j, new_ki, new_kj, new_di, new_dj)
|
||||
})
|
||||
.collect_into_vec(&mut stage_results);
|
||||
|
||||
for (i, j, new_ki, new_kj, new_di, new_dj) in stage_results.drain(..) {
|
||||
keys[i] = new_ki;
|
||||
keys[j] = new_kj;
|
||||
data[i] = new_di;
|
||||
data[j] = new_dj;
|
||||
}
|
||||
}
|
||||
|
||||
data.truncate(n);
|
||||
data
|
||||
}
|
||||
}
|
||||
@@ -121,96 +121,17 @@ where
|
||||
"Inputs must have the same number of blocks"
|
||||
);
|
||||
|
||||
// To make use if many_lut, we require 1 bit, 1 more bit is required to pack
|
||||
// the condition. Thus 2 bits of carry are required.
|
||||
//
|
||||
// Otherwise we call if_then_else twice, which is less efficient.
|
||||
if self.carry_modulus().0 < (1 << 2) {
|
||||
return rayon::join(
|
||||
|| self.if_then_else_parallelized(condition, b, a),
|
||||
|| self.if_then_else_parallelized(condition, a, b),
|
||||
);
|
||||
}
|
||||
|
||||
let (a, b) = rayon::join(
|
||||
|| self.clean_for_default_op(a),
|
||||
|| self.clean_for_default_op(b),
|
||||
);
|
||||
let (mut a, mut b) = self.unchecked_flip_parallelized(condition, &*a, &*b);
|
||||
a.blocks_mut()
|
||||
.par_iter_mut()
|
||||
.chain(b.blocks_mut().par_iter_mut())
|
||||
.for_each(|block| self.key.message_extract_assign(block));
|
||||
|
||||
let zero_out_if_true_fn = |packed| {
|
||||
let condition = (packed / self.message_modulus().0) & 1;
|
||||
let value = packed % self.message_modulus().0;
|
||||
(1 - condition) * value
|
||||
};
|
||||
|
||||
let zero_out_if_false_fn = |packed| {
|
||||
let condition = (packed / self.message_modulus().0) & 1;
|
||||
let value = packed % self.message_modulus().0;
|
||||
condition * value
|
||||
};
|
||||
|
||||
let lut = self
|
||||
.key
|
||||
.generate_many_lookup_table(&[&zero_out_if_true_fn, &zero_out_if_false_fn]);
|
||||
|
||||
let scaled_condition = self
|
||||
.key
|
||||
.unchecked_scalar_mul(&condition.0, self.message_modulus().0 as u8);
|
||||
|
||||
let map_condition_lut_on_blocks =
|
||||
|blocks: &[Ciphertext]| -> (Vec<Ciphertext>, Vec<Ciphertext>) {
|
||||
let mut left = Vec::with_capacity(blocks.len());
|
||||
let mut right = Vec::with_capacity(blocks.len());
|
||||
blocks
|
||||
.par_iter()
|
||||
.map(|block| {
|
||||
let block = self.key.unchecked_add(block, &scaled_condition);
|
||||
let mut resulting_blocks = self.key.apply_many_lookup_table(&block, &lut);
|
||||
|
||||
let second_result = resulting_blocks.pop().unwrap();
|
||||
let first_result = resulting_blocks.pop().unwrap();
|
||||
|
||||
(first_result, second_result)
|
||||
})
|
||||
.unzip_into_vecs(&mut left, &mut right);
|
||||
(left, right)
|
||||
};
|
||||
|
||||
let (
|
||||
(mut a_blocks_if_cond, mut a_blocks_if_not_cond),
|
||||
(b_blocks_if_cond, b_blocks_if_not_cond),
|
||||
) = rayon::join(
|
||||
|| map_condition_lut_on_blocks(a.blocks()),
|
||||
|| map_condition_lut_on_blocks(b.blocks()),
|
||||
);
|
||||
|
||||
let clean_lut = self
|
||||
.key
|
||||
.generate_lookup_table(|x| x % self.message_modulus().0);
|
||||
|
||||
let inplace_add_then_clean_blocks =
|
||||
|lhs_blocks: &mut [Ciphertext], rhs_blocks: &[Ciphertext]| {
|
||||
lhs_blocks
|
||||
.par_iter_mut()
|
||||
.zip(rhs_blocks.par_iter())
|
||||
.for_each(|(lhs, rhs)| {
|
||||
self.key.unchecked_add_assign(lhs, rhs);
|
||||
self.key.apply_lookup_table_assign(lhs, &clean_lut);
|
||||
});
|
||||
};
|
||||
rayon::join(
|
||||
|| {
|
||||
inplace_add_then_clean_blocks(&mut a_blocks_if_cond, &b_blocks_if_not_cond);
|
||||
},
|
||||
|| {
|
||||
inplace_add_then_clean_blocks(&mut a_blocks_if_not_cond, &b_blocks_if_cond);
|
||||
},
|
||||
);
|
||||
|
||||
(
|
||||
T::from_blocks(a_blocks_if_cond),
|
||||
T::from_blocks(a_blocks_if_not_cond),
|
||||
)
|
||||
(a, b)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -987,6 +908,122 @@ impl ServerKey {
|
||||
});
|
||||
}
|
||||
|
||||
/// Performs `if condition { (b, a) } else { (a, b) }` in fhe
|
||||
///
|
||||
/// * Blocks of both `a` and `b` must have no carries and noise_level <= 2
|
||||
/// * Outputs will have no carries but noise_level == 2
|
||||
pub(crate) fn unchecked_flip_parallelized<T>(
|
||||
&self,
|
||||
condition: &BooleanBlock,
|
||||
a: &T,
|
||||
b: &T,
|
||||
) -> (T, T)
|
||||
where
|
||||
T: IntegerRadixCiphertext,
|
||||
{
|
||||
assert_eq!(
|
||||
a.blocks().len(),
|
||||
b.blocks().len(),
|
||||
"Inputs must have the same number of blocks"
|
||||
);
|
||||
|
||||
// To make use if many_lut, we require 1 bit, 1 more bit is required to pack
|
||||
// the condition. Thus 2 bits of carry are required.
|
||||
//
|
||||
// Otherwise we call if_then_else twice, which is less efficient.
|
||||
if self.carry_modulus().0 < (1 << 2) {
|
||||
let unchecked_if_then_else_parallelized_no_cleanup =
|
||||
|condition: &BooleanBlock, true_ct, false_ct| {
|
||||
let condition_block = &condition.0;
|
||||
let do_clean_message = false;
|
||||
self.unchecked_programmable_if_then_else_parallelized(
|
||||
condition_block,
|
||||
true_ct,
|
||||
false_ct,
|
||||
|x| x == 1,
|
||||
do_clean_message,
|
||||
)
|
||||
};
|
||||
return rayon::join(
|
||||
|| unchecked_if_then_else_parallelized_no_cleanup(condition, b, a),
|
||||
|| unchecked_if_then_else_parallelized_no_cleanup(condition, a, b),
|
||||
);
|
||||
}
|
||||
|
||||
// We move the block message by one bit: it is the most flexible option (at least for 2_2)
|
||||
// as it allows the block to have noise_level <= 2 or the condition block to have
|
||||
// noise_level <= 2.
|
||||
//
|
||||
// A drawback is that everyblock needs to be cloned to be shifted, as opposed to
|
||||
// only the condition block, but it's an ok drawback
|
||||
assert!(a.blocks().iter().chain(b.blocks().iter()).all(|block| {
|
||||
self.key
|
||||
.max_noise_level
|
||||
.validate(block.noise_level() * 2 + condition.0.noise_level())
|
||||
.is_ok()
|
||||
}));
|
||||
|
||||
let zero_out_if_true_fn = |packed| {
|
||||
let condition = packed & 1;
|
||||
let value = packed / 2;
|
||||
(1 - condition) * value
|
||||
};
|
||||
|
||||
let zero_out_if_false_fn = |packed| {
|
||||
let condition = packed & 1;
|
||||
let value = packed / 2;
|
||||
condition * value
|
||||
};
|
||||
|
||||
let lut = self
|
||||
.key
|
||||
.generate_many_lookup_table(&[&zero_out_if_true_fn, &zero_out_if_false_fn]);
|
||||
|
||||
let map_condition_lut_on_blocks =
|
||||
|blocks: &[Ciphertext]| -> (Vec<Ciphertext>, Vec<Ciphertext>) {
|
||||
let mut left = Vec::with_capacity(blocks.len());
|
||||
let mut right = Vec::with_capacity(blocks.len());
|
||||
blocks
|
||||
.par_iter()
|
||||
.map(|block| {
|
||||
let mut block = self.key.unchecked_scalar_mul(block, 2);
|
||||
self.key.unchecked_add_assign(&mut block, &condition.0);
|
||||
let mut resulting_blocks = self.key.apply_many_lookup_table(&block, &lut);
|
||||
|
||||
let second_result = resulting_blocks.pop().unwrap();
|
||||
let first_result = resulting_blocks.pop().unwrap();
|
||||
|
||||
(first_result, second_result)
|
||||
})
|
||||
.unzip_into_vecs(&mut left, &mut right);
|
||||
(left, right)
|
||||
};
|
||||
|
||||
let (
|
||||
(mut a_blocks_if_cond, mut a_blocks_if_not_cond),
|
||||
(b_blocks_if_cond, b_blocks_if_not_cond),
|
||||
) = rayon::join(
|
||||
|| map_condition_lut_on_blocks(a.blocks()),
|
||||
|| map_condition_lut_on_blocks(b.blocks()),
|
||||
);
|
||||
|
||||
for (a, b) in a_blocks_if_cond.iter_mut().zip(b_blocks_if_not_cond.iter()) {
|
||||
self.key.unchecked_add_assign(a, b);
|
||||
// By construction, one of the two input encrypts only zeros
|
||||
a.degree.0 = self.message_modulus().0 - 1;
|
||||
}
|
||||
for (a, b) in a_blocks_if_not_cond.iter_mut().zip(b_blocks_if_cond.iter()) {
|
||||
self.key.unchecked_add_assign(a, b);
|
||||
// By construction, one of the two input encrypts only zeros
|
||||
a.degree.0 = self.message_modulus().0 - 1;
|
||||
}
|
||||
|
||||
(
|
||||
T::from_blocks(a_blocks_if_cond),
|
||||
T::from_blocks(a_blocks_if_not_cond),
|
||||
)
|
||||
}
|
||||
|
||||
fn scalar_flip_parallelized<T, Scalar>(
|
||||
&self,
|
||||
condition: &BooleanBlock,
|
||||
|
||||
@@ -112,6 +112,7 @@ impl ServerKey {
|
||||
/// * inputs must have the same number of blocks
|
||||
/// * block carries of both inputs must be empty
|
||||
/// * carry modulus == message modulus
|
||||
/// * blocks of the inputs a and b must have a noise_level such that a[i] - b[i] is possible
|
||||
fn compare<T>(&self, a: &T, b: &T, compare: ComparisonKind) -> BooleanBlock
|
||||
where
|
||||
T: IntegerRadixCiphertext,
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
mod abs;
|
||||
mod add;
|
||||
mod bit_extractor;
|
||||
pub(crate) mod bitonic_shuffle;
|
||||
mod bitwise_op;
|
||||
mod block_shift;
|
||||
pub(crate) mod cmux;
|
||||
@@ -47,6 +48,7 @@ use crate::integer::ciphertext::IntegerRadixCiphertext;
|
||||
use crate::integer::RadixCiphertext;
|
||||
use crate::shortint::ciphertext::{Ciphertext, NoiseLevel};
|
||||
pub(crate) use add::OutputFlag;
|
||||
pub use bitonic_shuffle::{BitonicShuffleKeySize, CollisionProbability};
|
||||
use rayon::prelude::*;
|
||||
pub use scalar_div_mod::{MiniUnsignedInteger, Reciprocable};
|
||||
pub use vector_find::MatchValues;
|
||||
@@ -323,4 +325,18 @@ impl ServerKey {
|
||||
Cow::Borrowed(ct)
|
||||
}
|
||||
}
|
||||
|
||||
/// Cleans the input inplace so that it is ready to be used in a default ops
|
||||
pub(crate) fn clean_inplace_for_default_op<T>(&self, ct: &mut T)
|
||||
where
|
||||
T: IntegerRadixCiphertext,
|
||||
{
|
||||
if ct
|
||||
.blocks()
|
||||
.iter()
|
||||
.any(|block| !block.carry_is_empty() || block.noise_level() != NoiseLevel::NOMINAL)
|
||||
{
|
||||
self.full_propagate_parallelized(ct);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ mod modulus_switch_compression;
|
||||
pub(crate) mod test_add;
|
||||
pub(crate) mod test_aes;
|
||||
pub(crate) mod test_aes256;
|
||||
pub(crate) mod test_bitonic_shuffle;
|
||||
pub(crate) mod test_bitwise_op;
|
||||
mod test_block_rotate;
|
||||
mod test_block_shift;
|
||||
|
||||
@@ -0,0 +1,141 @@
|
||||
use crate::integer::keycache::KEY_CACHE;
|
||||
use crate::integer::server_key::radix_parallel::bitonic_shuffle::bitonic_network;
|
||||
use crate::integer::server_key::radix_parallel::tests_cases_unsigned::FunctionExecutor;
|
||||
use crate::integer::server_key::radix_parallel::tests_unsigned::CpuFunctionExecutor;
|
||||
use crate::integer::tests::create_parameterized_test;
|
||||
use crate::integer::{IntegerKeyKind, RadixCiphertext, RadixClientKey, ServerKey};
|
||||
#[cfg(tarpaulin)]
|
||||
use crate::shortint::parameters::coverage_parameters::*;
|
||||
use crate::shortint::parameters::test_params::*;
|
||||
use crate::shortint::parameters::*;
|
||||
use rand::Rng;
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Clear reference: sorts data by keys using the same bitonic network
|
||||
/// as `ServerKey::bitonic_shuffle_with_keys`.
|
||||
fn clear_bitonic_shuffle_with_keys(data: &[u32], keys: &[u32]) -> Vec<u32> {
|
||||
assert_eq!(data.len(), keys.len());
|
||||
let n = data.len();
|
||||
if n <= 1 {
|
||||
return data.to_vec();
|
||||
}
|
||||
|
||||
let padded_n = n.next_power_of_two();
|
||||
let mut keys = keys.iter().copied().map(u64::from).collect::<Vec<_>>();
|
||||
let mut data = data.to_vec();
|
||||
|
||||
// Pad with MAX keys and zero data, so padding sorts to the end and is truncated
|
||||
for _ in 0..(padded_n - n) {
|
||||
keys.push(u64::MAX);
|
||||
data.push(0u32);
|
||||
}
|
||||
|
||||
let network = bitonic_network(padded_n);
|
||||
for stage in &network {
|
||||
let swaps: Vec<_> = stage
|
||||
.iter()
|
||||
.map(|&(i, j, ascending)| {
|
||||
let should_swap = if ascending {
|
||||
keys[i] > keys[j]
|
||||
} else {
|
||||
keys[i] < keys[j]
|
||||
};
|
||||
(i, j, should_swap)
|
||||
})
|
||||
.collect();
|
||||
|
||||
for (i, j, should_swap) in swaps {
|
||||
if should_swap {
|
||||
keys.swap(i, j);
|
||||
data.swap(i, j);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
data.truncate(n);
|
||||
data
|
||||
}
|
||||
|
||||
create_parameterized_test!(integer_bitonic_shuffle_with_keys);
|
||||
|
||||
fn integer_bitonic_shuffle_with_keys<P>(param: P)
|
||||
where
|
||||
P: Into<TestParameters>,
|
||||
{
|
||||
let executor =
|
||||
CpuFunctionExecutor::new(&ServerKey::bitonic_shuffle_with_keys::<RadixCiphertext>);
|
||||
bitonic_shuffle_with_keys_test(param, executor);
|
||||
}
|
||||
|
||||
pub(crate) fn bitonic_shuffle_with_keys_test<P, T>(param: P, mut executor: T)
|
||||
where
|
||||
P: Into<TestParameters>,
|
||||
T: for<'a> FunctionExecutor<
|
||||
(Vec<RadixCiphertext>, Vec<RadixCiphertext>),
|
||||
Result<Vec<RadixCiphertext>, crate::Error>,
|
||||
>,
|
||||
{
|
||||
let param = param.into();
|
||||
let (cks, sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
|
||||
let num_blocks = 32usize.div_ceil(cks.parameters().message_modulus().0.ilog2() as usize);
|
||||
let cks = RadixClientKey::from((cks, num_blocks));
|
||||
let sks = Arc::new(sks);
|
||||
|
||||
executor.setup(&cks, sks);
|
||||
|
||||
let mut rng = rand::thread_rng();
|
||||
|
||||
for _ in 0..4 {
|
||||
let len = rng.gen_range(1..=16usize).next_power_of_two();
|
||||
let clear_keys: Vec<u32> = (0..len).map(|_| rng.gen::<u32>()).collect();
|
||||
let mut clear_data: Vec<u32> = (0..len).map(|_| rng.gen::<u32>()).collect();
|
||||
println!("clear_keys: {clear_keys:?}, clear_data: {clear_data:?}");
|
||||
|
||||
let enc_keys = clear_keys.iter().map(|&v| cks.encrypt(v as u64)).collect();
|
||||
let enc_data = clear_data.iter().map(|&v| cks.encrypt(v as u64)).collect();
|
||||
|
||||
let result = executor.execute((enc_data, enc_keys)).unwrap();
|
||||
|
||||
let mut decrypted: Vec<u32> = result
|
||||
.iter()
|
||||
.map(|ct| cks.decrypt::<u64>(ct) as u32)
|
||||
.collect();
|
||||
|
||||
// Check the encrypted implementation matches the clear one
|
||||
let expected = clear_bitonic_shuffle_with_keys(&clear_data, &clear_keys);
|
||||
assert_eq!(decrypted, expected);
|
||||
|
||||
// Check that the permutation did not lose any of the data
|
||||
decrypted.sort_unstable();
|
||||
clear_data.sort_unstable();
|
||||
assert_eq!(decrypted, clear_data, "permutation lost data");
|
||||
}
|
||||
|
||||
{
|
||||
let len = 17;
|
||||
let mut clear_keys: Vec<u32> = (0..len).map(|_| rng.gen::<u32>()).collect();
|
||||
let mut clear_data: Vec<u32> = (0..len).map(|_| rng.gen::<u32>()).collect();
|
||||
clear_keys[3] = u32::MAX;
|
||||
assert!(!clear_keys.len().is_power_of_two());
|
||||
println!("clear_keys: {clear_keys:?}, clear_data: {clear_data:?}");
|
||||
|
||||
let enc_keys = clear_keys.iter().map(|&v| cks.encrypt(v as u64)).collect();
|
||||
let enc_data = clear_data.iter().map(|&v| cks.encrypt(v as u64)).collect();
|
||||
|
||||
let result = executor.execute((enc_data, enc_keys)).unwrap();
|
||||
|
||||
let mut decrypted: Vec<u32> = result
|
||||
.iter()
|
||||
.map(|ct| cks.decrypt::<u64>(ct) as u32)
|
||||
.collect();
|
||||
|
||||
// Check the encrypted implementation matches the clear one
|
||||
let expected = clear_bitonic_shuffle_with_keys(&clear_data, &clear_keys);
|
||||
assert_eq!(decrypted, expected);
|
||||
|
||||
// Check that the permutation did not lose any of the data
|
||||
decrypted.sort_unstable();
|
||||
clear_data.sort_unstable();
|
||||
assert_eq!(decrypted, clear_data, "permutation lost data");
|
||||
}
|
||||
}
|
||||
@@ -976,7 +976,7 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "tfhe"
|
||||
version = "1.6.1"
|
||||
version = "1.6.0"
|
||||
dependencies = [
|
||||
"aligned-vec",
|
||||
"bincode",
|
||||
|
||||
Reference in New Issue
Block a user