Compare commits

..

1 Commits

Author SHA1 Message Date
Thomas Montaigu
88241fee21 feat: add bitonic_shuffle
Bitonic shuffle allows shuffling a vec of homomorphic integers
by running a bitonic sort on randomly generated keys
2026-04-22 15:50:02 +02:00
29 changed files with 796 additions and 124 deletions

View File

@@ -97,6 +97,12 @@ path = "benches/high_level_api/erc7984.rs"
harness = false
required-features = ["integer", "internal-keycache"]
[[bench]]
name = "hlapi-bitonic-shuffle"
path = "benches/high_level_api/bitonic_shuffle.rs"
harness = false
required-features = ["integer", "internal-keycache"]
[[bench]]
name = "hlapi-dex"
path = "benches/high_level_api/dex.rs"

View File

@@ -0,0 +1,53 @@
use benchmark::params_aliases::BENCH_PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128;
use criterion::{black_box, criterion_group, Criterion};
use rand::Rng;
use tfhe::core_crypto::seeders::UnixSeeder;
use tfhe::prelude::*;
use tfhe::{bitonic_shuffle, set_server_key, ClientKey, ConfigBuilder, FheUint64, ServerKey};
/// Registers one criterion benchmark per input size (16 and 32 elements)
/// measuring `bitonic_shuffle` on vectors of encrypted `u64`.
///
/// * `c` - criterion instance the benchmark group is attached to
/// * `bench_name` - name of the group, also used as prefix of each benchmark id
/// * `cks` - client key used to encrypt the random inputs
///
/// The inputs are encrypted once per size, outside of the timed closure; the
/// timed closure clones them, builds a seeder and runs the shuffle.
fn bitonic_shuffle_bench(c: &mut Criterion, bench_name: &str, cks: &ClientKey) {
    // Every run uses sort keys made of 16 blocks.
    let key_size = tfhe::integer::server_key::BitonicShuffleKeySize::NumBlocks(16);

    let mut rng = rand::thread_rng();
    let mut group = c.benchmark_group(bench_name);
    group.sample_size(10);

    for num_elements in [16, 32] {
        let clear_inputs: Vec<u64> = (0..num_elements).map(|_| rng.gen()).collect();
        let mut encrypted: Vec<FheUint64> = Vec::with_capacity(clear_inputs.len());
        for &clear in &clear_inputs {
            encrypted.push(FheUint64::encrypt(clear, cks));
        }

        let bench_id = format!("{bench_name}::n_{num_elements}");
        group.bench_function(&bench_id, |b| {
            b.iter(|| {
                let mut seeder = UnixSeeder::new(0);
                let shuffled = bitonic_shuffle(encrypted.clone(), key_size, &mut seeder)
                    .expect("shuffle failed");
                black_box(shuffled);
            })
        });
    }

    group.finish();
}
/// Criterion entry point: generates a key pair for the benchmark parameter set,
/// installs the server key and runs the bitonic shuffle benchmarks on CPU.
pub fn bitonic_shuffle_cpu(c: &mut Criterion) {
    let config =
        ConfigBuilder::with_custom_parameters(BENCH_PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128)
            .build();
    let client_key = ClientKey::generate(config);
    let server_key = ServerKey::new(&client_key);

    // Install a copy of the server key on each thread of the rayon pool first,
    // then hand the original over to the current thread.
    rayon::broadcast(|_| set_server_key(server_key.clone()));
    set_server_key(server_key);

    bitonic_shuffle_bench(c, "hlapi::bitonic_shuffle_cpu", &client_key);
}
criterion_group!(bitonic_shuffle_group, bitonic_shuffle_cpu);
criterion::criterion_main!(bitonic_shuffle_group);

View File

@@ -1,6 +1,6 @@
[package]
name = "tfhe"
version = "1.6.1"
version = "1.6.0"
edition = "2021"
readme = "../README.md"
keywords = ["fully", "homomorphic", "encryption", "fhe", "cryptography"]
@@ -64,7 +64,7 @@ tfhe-fft = { version = "0.10.1", path = "../tfhe-fft", features = [
"serde",
"fft128",
] }
tfhe-ntt = { version = "0.7.1", path = "../tfhe-ntt" }
tfhe-ntt = { version = "0.7.0", path = "../tfhe-ntt" }
pulp = { workspace = true, features = ["default"] }
tfhe-cuda-backend = { version = "0.14.0", path = "../backends/tfhe-cuda-backend", optional = true }
aligned-vec = { workspace = true, features = ["default", "serde"] }

View File

@@ -75,11 +75,11 @@
<text dominant-baseline="middle" text-anchor="middle" font-family="Arial" font-size="14" font-weight="normal" fill="black" x="594.0" y="420.0">121 ms</text>
<text dominant-baseline="middle" text-anchor="middle" font-family="Arial" font-size="14" font-weight="normal" fill="black" x="678.0" y="420.0">165 ms</text>
<text dominant-baseline="middle" text-anchor="start" font-family="Arial" font-size="14" font-weight="normal" fill="black" x="6" y="460.0">Leading / Trailing zeros/ones</text>
<text dominant-baseline="middle" text-anchor="middle" font-family="Arial" font-size="14" font-weight="normal" fill="black" x="342.0" y="460.0">67.2 ms</text>
<text dominant-baseline="middle" text-anchor="middle" font-family="Arial" font-size="14" font-weight="normal" fill="black" x="426.0" y="460.0">70.6 ms</text>
<text dominant-baseline="middle" text-anchor="middle" font-family="Arial" font-size="14" font-weight="normal" fill="black" x="510.0" y="460.0">89.8 ms</text>
<text dominant-baseline="middle" text-anchor="middle" font-family="Arial" font-size="14" font-weight="normal" fill="black" x="594.0" y="460.0">92.6 ms</text>
<text dominant-baseline="middle" text-anchor="middle" font-family="Arial" font-size="14" font-weight="normal" fill="black" x="678.0" y="460.0">113 ms</text>
<text dominant-baseline="middle" text-anchor="middle" font-family="Arial" font-size="14" font-weight="normal" fill="black" x="342.0" y="460.0">88.4 ms</text>
<text dominant-baseline="middle" text-anchor="middle" font-family="Arial" font-size="14" font-weight="normal" fill="black" x="426.0" y="460.0">148 ms</text>
<text dominant-baseline="middle" text-anchor="middle" font-family="Arial" font-size="14" font-weight="normal" fill="black" x="510.0" y="460.0">169 ms</text>
<text dominant-baseline="middle" text-anchor="middle" font-family="Arial" font-size="14" font-weight="normal" fill="black" x="594.0" y="460.0">222 ms</text>
<text dominant-baseline="middle" text-anchor="middle" font-family="Arial" font-size="14" font-weight="normal" fill="black" x="678.0" y="460.0">275 ms</text>
<text dominant-baseline="middle" text-anchor="start" font-family="Arial" font-size="14" font-weight="normal" fill="black" x="6" y="500.0">Log2</text>
<text dominant-baseline="middle" text-anchor="middle" font-family="Arial" font-size="14" font-weight="normal" fill="black" x="342.0" y="500.0">110 ms</text>
<text dominant-baseline="middle" text-anchor="middle" font-family="Arial" font-size="14" font-weight="normal" fill="black" x="426.0" y="500.0">163 ms</text>

Before

Width:  |  Height:  |  Size: 16 KiB

After

Width:  |  Height:  |  Size: 16 KiB

View File

@@ -75,11 +75,11 @@
<text dominant-baseline="middle" text-anchor="middle" font-family="Arial" font-size="14" font-weight="normal" fill="black" x="594.0" y="420.0">32.5 ops/s</text>
<text dominant-baseline="middle" text-anchor="middle" font-family="Arial" font-size="14" font-weight="normal" fill="black" x="678.0" y="420.0">14.0 ops/s</text>
<text dominant-baseline="middle" text-anchor="start" font-family="Arial" font-size="14" font-weight="normal" fill="black" x="6" y="460.0">Leading / Trailing zeros/ones</text>
<text dominant-baseline="middle" text-anchor="middle" font-family="Arial" font-size="14" font-weight="normal" fill="black" x="342.0" y="460.0">824 ops/s</text>
<text dominant-baseline="middle" text-anchor="middle" font-family="Arial" font-size="14" font-weight="normal" fill="black" x="426.0" y="460.0">487 ops/s</text>
<text dominant-baseline="middle" text-anchor="middle" font-family="Arial" font-size="14" font-weight="normal" fill="black" x="510.0" y="460.0">222 ops/s</text>
<text dominant-baseline="middle" text-anchor="middle" font-family="Arial" font-size="14" font-weight="normal" fill="black" x="594.0" y="460.0">119 ops/s</text>
<text dominant-baseline="middle" text-anchor="middle" font-family="Arial" font-size="14" font-weight="normal" fill="black" x="678.0" y="460.0">57.8 ops/s</text>
<text dominant-baseline="middle" text-anchor="middle" font-family="Arial" font-size="14" font-weight="normal" fill="black" x="342.0" y="460.0">625 ops/s</text>
<text dominant-baseline="middle" text-anchor="middle" font-family="Arial" font-size="14" font-weight="normal" fill="black" x="426.0" y="460.0">247 ops/s</text>
<text dominant-baseline="middle" text-anchor="middle" font-family="Arial" font-size="14" font-weight="normal" fill="black" x="510.0" y="460.0">108 ops/s</text>
<text dominant-baseline="middle" text-anchor="middle" font-family="Arial" font-size="14" font-weight="normal" fill="black" x="594.0" y="460.0">44.1 ops/s</text>
<text dominant-baseline="middle" text-anchor="middle" font-family="Arial" font-size="14" font-weight="normal" fill="black" x="678.0" y="460.0">19.0 ops/s</text>
<text dominant-baseline="middle" text-anchor="start" font-family="Arial" font-size="14" font-weight="normal" fill="black" x="6" y="500.0">Log2</text>
<text dominant-baseline="middle" text-anchor="middle" font-family="Arial" font-size="14" font-weight="normal" fill="black" x="342.0" y="500.0">542 ops/s</text>
<text dominant-baseline="middle" text-anchor="middle" font-family="Arial" font-size="14" font-weight="normal" fill="black" x="426.0" y="500.0">220 ops/s</text>

Before

Width:  |  Height:  |  Size: 16 KiB

After

Width:  |  Height:  |  Size: 16 KiB

View File

@@ -74,7 +74,7 @@ To compile and execute GPU TFHE-rs programs, make sure your system has the follo
To use the **TFHE-rs** GPU backend in your project, add the following dependency in your `Cargo.toml`.
```toml
tfhe = { version = "~1.6.1", features = ["boolean", "shortint", "integer", "gpu"] }
tfhe = { version = "~1.6.0", features = ["boolean", "shortint", "integer", "gpu"] }
```
If none of the supported backends is configured in `Cargo.toml`, the CPU backend is used.

View File

@@ -17,7 +17,7 @@ This guide explains how to update your existing program to leverage HPU accelera
To use the **TFHE-rs** HPU backend in your project, add the following dependency in your `Cargo.toml`.
```toml
tfhe = { version = "~1.6.1", features = ["integer", "hpu-v80"] }
tfhe = { version = "~1.6.0", features = ["integer", "hpu-v80"] }
```
{% hint style="success" %}

View File

@@ -16,7 +16,7 @@ You can load serialized data with the `unversionize` function, even in newer ver
[dependencies]
# ...
tfhe = { version = "~1.6.1", features = ["integer"] }
tfhe = { version = "~1.6.0", features = ["integer"] }
tfhe-versionable = "0.6.0"
bincode = "1.3.3"
```

View File

@@ -161,7 +161,7 @@ In the following example, we use [bincode](https://crates.io/crates/bincode) for
[dependencies]
# ...
tfhe = { version = "~1.6.1", features = ["integer"] }
tfhe = { version = "~1.6.0", features = ["integer"] }
bincode = "1.3.3"
```

View File

@@ -19,7 +19,7 @@ The following example shows a complete workflow of working with encrypted arrays
# Cargo.toml
[dependencies]
tfhe = { version = "~1.6.1", features = ["integer"] }
tfhe = { version = "~1.6.0", features = ["integer"] }
```
```rust

View File

@@ -36,7 +36,7 @@ To serialize a `KVStore`, it must first be compressed.
# Cargo.toml
[dependencies]
tfhe = { version = "~1.6.1", features = ["integer"] }
tfhe = { version = "~1.6.0", features = ["integer"] }
```
```rust

View File

@@ -29,7 +29,7 @@ Here is an example:
# Cargo.toml
[dependencies]
tfhe = { version = "~1.6.1", features = ["integer", "strings"] }
tfhe = { version = "~1.6.0", features = ["integer", "strings"] }
```
```rust

View File

@@ -7,7 +7,7 @@ This document provides instructions to set up **TFHE-rs** in your project.
First, add **TFHE-rs** as a dependency in your `Cargo.toml`.
```toml
tfhe = { version = "~1.6.1", features = ["boolean", "shortint", "integer"] }
tfhe = { version = "~1.6.0", features = ["boolean", "shortint", "integer"] }
```
{% hint style="info" %}
@@ -35,7 +35,7 @@ By default, **TFHE-rs** makes the assumption that hardware AES features are enab
To add support for older CPU, import **TFHE-rs** with the `software-prng` feature in your `Cargo.toml`:
```toml
tfhe = { version = "~1.6.1", features = ["boolean", "shortint", "integer", "software-prng"] }
tfhe = { version = "~1.6.0", features = ["boolean", "shortint", "integer", "software-prng"] }
```
## Hardware acceleration

View File

@@ -59,7 +59,7 @@ edition = "2021"
Then add the following configuration to include **TFHE-rs**:
```toml
tfhe = { version = "~1.6.1", features = ["integer"] }
tfhe = { version = "~1.6.0", features = ["integer"] }
```
Your updated `Cargo.toml` file should look like this:
@@ -71,7 +71,7 @@ version = "0.1.0"
edition = "2021"
[dependencies]
tfhe = { version = "~1.6.1", features = ["integer"] }
tfhe = { version = "~1.6.0", features = ["integer"] }
```
If you are on a different platform please refer to the [installation documentation](installation.md) for configuration options of other supported platforms.

View File

@@ -9,7 +9,7 @@ Welcome to this tutorial about `TFHE-rs` `core_crypto` module.
To use `TFHE-rs`, it first has to be added as a dependency in the `Cargo.toml`:
```toml
tfhe = { version = "~1.6.1" }
tfhe = { version = "~1.6.0" }
```
### Commented code to double a 2-bit message in a leveled fashion and using a PBS with the `core_crypto` module.

View File

@@ -28,7 +28,7 @@ To use the `FheUint8` type, enable the `integer` feature:
# Cargo.toml
[dependencies]
tfhe = { version = "~1.6.1", features = ["integer"] }
tfhe = { version = "~1.6.0", features = ["integer"] }
```
The `MyFheString::encrypt` function performs data validation to ensure the input string contains only ASCII characters.
@@ -167,7 +167,7 @@ First, add the feature in your `Cargo.toml`
# Cargo.toml
[dependencies]
tfhe = { version = "~1.6.1", features = ["strings"] }
tfhe = { version = "~1.6.0", features = ["strings"] }
```
The `FheAsciiString` type allows to simply do homomorphic case changing of encrypted strings (and much more!):

View File

@@ -17,7 +17,7 @@ This function returns a Boolean (`true` or `false`) so that the total count of `
```toml
# Cargo.toml
tfhe = { version = "~1.6.1", features = ["integer"] }
tfhe = { version = "~1.6.0", features = ["integer"] }
```
First, define the verification function.

View File

@@ -49,6 +49,7 @@ pub use signed::{CompressedFheInt, FheInt, FheIntId, SquashedNoiseFheInt};
pub use unsigned::{CompressedFheUint, FheUint, FheUintId, SquashedNoiseFheUint};
pub mod oprf;
pub mod shuffle;
pub(super) mod signed;
pub(super) mod unsigned;

View File

@@ -0,0 +1,121 @@
use crate::high_level_api::global_state;
use crate::high_level_api::integers::FheIntegerType;
use crate::high_level_api::keys::InternalServerKey;
use crate::high_level_api::re_randomization::ReRandomizationMetadata;
use crate::integer::server_key::radix_parallel::bitonic_shuffle::BitonicShuffleKeySize;
use tfhe_csprng::seeders::Seeder;
/// Randomly permutes `data` by sorting it with a bitonic network whose sort
/// keys are random integers generated through the OPRF key of the server key
/// currently set.
///
/// `key_size` selects how wide those random sort keys are: either indirectly,
/// through a target collision probability, or directly as a number of blocks.
/// Wider keys collide less often (which makes the shuffle closer to uniform)
/// but make every comparison more expensive.
///
/// `seeder` provides the seeds from which the random keys are derived.
///
/// The returned ciphertexts carry the tag of the server key and a default
/// re-randomization metadata: the metadata of the inputs is not kept.
///
/// # Errors
///
/// Returns an error if the resolved key block count is 0, or if the active
/// backend is Cuda or Hpu (not supported yet).
pub fn bitonic_shuffle<T, S>(
    data: Vec<T>,
    key_size: BitonicShuffleKeySize,
    seeder: &mut S,
) -> Result<Vec<T>, crate::Error>
where
    T: FheIntegerType,
    S: Seeder,
{
    global_state::with_internal_keys(|key| match key {
        InternalServerKey::Cpu(cpu_key) => {
            // Unwrap the high-level values into their integer-level ciphertexts.
            let radix_inputs = data.into_iter().map(|value| value.into_cpu()).collect();

            let shuffled = cpu_key.pbs_key().bitonic_shuffle(
                &cpu_key.oprf_key(),
                radix_inputs,
                key_size,
                seeder,
            )?;

            // Wrap the results back, with the key's tag and fresh metadata.
            let rewrapped = shuffled.into_iter().map(|radix| {
                T::from_cpu(
                    radix,
                    cpu_key.tag.clone(),
                    ReRandomizationMetadata::default(),
                )
            });
            Ok(rewrapped.collect())
        }
        #[cfg(feature = "gpu")]
        InternalServerKey::Cuda(_) => Err(crate::Error::new(
            "bitonic_shuffle is not supported on Cuda".to_string(),
        )),
        #[cfg(feature = "hpu")]
        InternalServerKey::Hpu(_) => Err(crate::Error::new(
            "bitonic_shuffle is not supported on Hpu".to_string(),
        )),
    })
}
#[cfg(test)]
mod test {
    use super::{bitonic_shuffle, BitonicShuffleKeySize};
    use crate::core_crypto::commons::generators::DeterministicSeeder;
    use crate::core_crypto::prelude::new_seeder;
    use crate::high_level_api::prelude::*;
    use crate::high_level_api::tests::setup_default_cpu;
    use crate::{FheInt8, FheUint8};
    use rand::Rng;
    use tfhe_csprng::generators::DefaultRandomGenerator;

    /// Shuffling unsigned values must yield a permutation of the input:
    /// once both sides are sorted, the decrypted output equals the clear input.
    #[test]
    fn test_bitonic_shuffle_fheuint() {
        let client_key = setup_default_cpu();
        let mut rng = rand::thread_rng();

        // 15 elements: not a power of two.
        let mut expected: Vec<u8> = (0..15).map(|_| rng.gen()).collect();
        let inputs: Vec<FheUint8> = expected
            .iter()
            .map(|&clear| FheUint8::try_encrypt(clear, &client_key).unwrap())
            .collect();

        // Log the seed that feeds the shuffle's seeder.
        let seed = new_seeder().seed();
        println!("seed: {seed:?}");
        let mut seeder = DeterministicSeeder::<DefaultRandomGenerator>::new(seed);

        let key_size = BitonicShuffleKeySize::num_blocks(16);
        let shuffled = bitonic_shuffle(inputs, key_size, &mut seeder).unwrap();

        let mut actual: Vec<u8> = shuffled
            .iter()
            .map(|ct| ct.decrypt(&client_key))
            .collect();
        expected.sort_unstable();
        actual.sort_unstable();
        assert_eq!(actual, expected);
    }

    /// Same permutation check as above, on signed values.
    #[test]
    fn test_bitonic_shuffle_fheint() {
        let client_key = setup_default_cpu();
        let mut rng = rand::thread_rng();

        // 15 elements: not a power of two.
        let mut expected: Vec<i8> = (0..15).map(|_| rng.gen()).collect();
        let inputs: Vec<FheInt8> = expected
            .iter()
            .map(|&clear| FheInt8::try_encrypt(clear, &client_key).unwrap())
            .collect();

        // Log the seed that feeds the shuffle's seeder.
        let seed = new_seeder().seed();
        println!("seed: {seed:?}");
        let mut seeder = DeterministicSeeder::<DefaultRandomGenerator>::new(seed);

        let key_size = BitonicShuffleKeySize::num_blocks(16);
        let shuffled = bitonic_shuffle(inputs, key_size, &mut seeder).unwrap();

        let mut actual: Vec<i8> = shuffled
            .iter()
            .map(|ct| ct.decrypt(&client_key))
            .collect();
        expected.sort_unstable();
        actual.sort_unstable();
        assert_eq!(actual, expected);
    }
}

View File

@@ -47,8 +47,9 @@ macro_rules! export_concrete_array_types {
};
}
pub use crate::core_crypto::commons::math::random::{Seed, XofSeed};
pub use crate::core_crypto::commons::math::random::{Seed, Seeder, XofSeed};
pub use crate::high_level_api::integers::oprf::RangeForRandom;
pub use crate::high_level_api::integers::shuffle::bitonic_shuffle;
pub use crate::integer::server_key::MatchValues;
pub use crate::shortint::OprfSeed;
use crate::{error, Error, Versionize};
@@ -70,8 +71,8 @@ pub use integers::{
pub use keys::CudaServerKey;
pub use keys::{
generate_keys, ClientKey, CompactPublicKey, CompressedCompactPublicKey, CompressedPublicKey,
CompressedReRandomizationKey, CompressedReRandomizationKeySwitchingKey, CompressedServerKey,
KeySwitchingKey, PublicKey, ReRandomizationKey, ReRandomizationKeySwitchingKey, ServerKey,
CompressedReRandomizationKeySwitchingKey, CompressedServerKey, KeySwitchingKey, PublicKey,
ReRandomizationKeySwitchingKey, ServerKey,
};
use strum::FromRepr;

View File

@@ -20,7 +20,9 @@ pub use crate::shortint::CheckError;
use crate::shortint::{CarryModulus, MessageModulus};
pub use radix::scalar_mul::ScalarMultiplier;
pub use radix::scalar_sub::TwosComplementNegation;
pub use radix_parallel::{MatchValues, MiniUnsignedInteger, Reciprocable};
pub use radix_parallel::{
BitonicShuffleKeySize, CollisionProbability, MatchValues, MiniUnsignedInteger, Reciprocable,
};
use serde::{Deserialize, Serialize};
use tfhe_versionable::Versionize;

View File

@@ -247,14 +247,15 @@ impl ServerKey {
/// let res: u64 = cks.decrypt(&ct1);
/// assert_eq!(7, res);
/// ```
pub fn extend_radix_with_trivial_zero_blocks_msb_assign(
&self,
ct: &mut RadixCiphertext,
num_blocks: usize,
) {
pub fn extend_radix_with_trivial_zero_blocks_msb_assign<T>(&self, ct: &mut T, num_blocks: usize)
where
T: IntegerRadixCiphertext,
{
// Swap out blocks via a no-alloc empty sentinel (blocks_mut is a slice, can't resize)
let mut blocks = std::mem::replace(ct, T::from_blocks(vec![])).into_blocks();
let block_trivial_zero = self.key.create_trivial(0);
ct.blocks
.resize(ct.blocks.len() + num_blocks, block_trivial_zero);
blocks.resize(blocks.len() + num_blocks, block_trivial_zero);
*ct = T::from_blocks(blocks);
}
/// Append trivial zero MSB blocks to an existing [`RadixCiphertext`] and returns the result as

View File

@@ -0,0 +1,291 @@
//! Bitonic sorting network generator.
//!
//! A bitonic sorting network for n=2^k elements has k*(k+1)/2 stages,
//! each with n/2 comparators. It sorts any input sequence.
use crate::core_crypto::prelude::Container;
use crate::integer::oprf::GenericOprfServerKey;
use crate::integer::{IntegerRadixCiphertext, RadixCiphertext, ServerKey};
use crate::shortint::MessageModulus;
use rayon::prelude::*;
use tfhe_csprng::seeders::Seeder;
use tfhe_fft::c64;
/// Builds the comparator schedule of a bitonic sorting network over `n` slots.
///
/// The result is a list of stages, each stage being a list of
/// `(i, j, ascending)` comparators with `i < j`. For a comparator, `ascending`
/// tells which element must end up at index `i`: the smaller one when `true`,
/// the larger one when `false`.
///
/// Stages are emitted phase by phase (`phase` in `0..log2(n)`), and within a
/// phase by decreasing `step` (from `phase` down to `0`); a comparator pairs
/// `i` with `i ^ (1 << step)`.
///
/// # Panics
///
/// Panics if `n` is not a power of 2 or is smaller than 2.
pub(crate) fn bitonic_network(n: usize) -> Vec<Vec<(usize, usize, bool)>> {
    assert!(
        n.is_power_of_two() && n >= 2,
        "bitonic_network requires n to be a power of 2 and >= 2, got n={n}"
    );

    let log_n = n.trailing_zeros() as usize;

    (0..log_n)
        .flat_map(|phase| (0..=phase).rev().map(move |step| (phase, step)))
        .map(|(phase, step)| {
            let stride = 1usize << step;
            (0..n)
                // Keep the lower index of each pair only: the one whose `step` bit is 0,
                // so that its partner `low | stride` is the strictly larger index.
                .filter(|&low| low & stride == 0)
                .map(|low| {
                    let ascending = (low >> (phase + 1)) & 1 == 0;
                    (low, low | stride, ascending)
                })
                .collect()
        })
        .collect()
}
/// Rounds `n` up to a power of 2: a power of 2 is returned unchanged, any
/// other value (including 0, which gives 1) is bumped to the next one.
fn padded_size(n: usize) -> usize {
    if n.is_power_of_two() {
        n
    } else {
        n.next_power_of_two()
    }
}
/// A probability lying strictly between 0 and 1 (both bounds excluded).
///
/// Instances can only be built through [`CollisionProbability::try_new`] or
/// [`CollisionProbability::new`], which enforce that range.
#[derive(Copy, Clone, PartialEq, Debug)]
pub struct CollisionProbability(f64);

impl CollisionProbability {
    /// Wraps `proba` if `0 < proba < 1`, returns `None` otherwise
    /// (this includes `NaN`, for which both comparisons are false).
    pub fn try_new(proba: f64) -> Option<Self> {
        let in_open_unit_interval = proba > 0.0 && proba < 1.0;
        in_open_unit_interval.then_some(Self(proba))
    }

    /// Same as [`Self::try_new`] but panics instead of returning `None`.
    ///
    /// # Panics
    ///
    /// Panics if `proba` is not strictly between 0 and 1.
    pub fn new(proba: f64) -> Self {
        let checked = Self::try_new(proba);
        checked.expect("Invalid probability, it must be in ]0, 1.0[")
    }
}
/// Selects how wide the random sort keys of a bitonic shuffle are.
#[derive(Copy, Clone, PartialEq, Debug)]
pub enum BitonicShuffleKeySize {
    /// Derive the key width from a target probability that random keys collide.
    CollisionProbability(CollisionProbability),
    /// Use exactly this number of blocks for each key.
    NumBlocks(u32),
}

impl BitonicShuffleKeySize {
    /// Builds a [`Self::CollisionProbability`] key size, or returns `None` if
    /// `proba` is not strictly between 0 and 1.
    pub fn try_collision_probability(proba: f64) -> Option<Self> {
        CollisionProbability::try_new(proba).map(Self::CollisionProbability)
    }

    /// Builds a [`Self::CollisionProbability`] key size.
    ///
    /// # Panics
    ///
    /// Panics if `proba` is not strictly between 0 and 1.
    pub fn collision_probability(proba: f64) -> Self {
        Self::CollisionProbability(CollisionProbability::new(proba))
    }

    /// Builds a [`Self::NumBlocks`] key size with an explicit block count.
    pub fn num_blocks(num_blocks: u32) -> Self {
        Self::NumBlocks(num_blocks)
    }

    /// Resolves this key size into a number of blocks per key.
    ///
    /// * `num_elements` - number of elements that are going to be shuffled
    /// * `msg_mod` - message modulus of the blocks, each block holding
    ///   `log2(msg_mod)` bits of key
    ///
    /// For [`Self::NumBlocks`] the stored count is returned as is (it may be 0).
    ///
    /// For [`Self::CollisionProbability`] the key gets
    /// `ceil(log2(num_elements^2 / (2 * proba)))` bits, with a minimum of one
    /// bit, rounded up to a whole number of blocks.
    fn num_blocks_of_keys(&self, num_elements: usize, msg_mod: MessageModulus) -> u32 {
        match self {
            Self::CollisionProbability(CollisionProbability(proba)) => {
                // Square in floating point: squaring in `usize` could overflow
                // for very large element counts.
                let n = num_elements as f64;
                let n_squared = n * n;
                // With 0 or 1 element the ratio below can be <= 1, making the
                // logarithm <= 0 (or -inf), which the saturating cast turns into
                // 0 bits. Always ask for at least one bit so that a probability
                // based size never resolves to 0 blocks.
                let bits = ((n_squared / (2.0 * proba)).log2().ceil() as u32).max(1);
                bits.div_ceil(msg_mod.0.ilog2())
            }
            Self::NumBlocks(n) => *n,
        }
    }
}
impl ServerKey {
/// Randomly permutes `data` by sorting it with a bitonic network driven by
/// freshly generated random sort keys (one key per element).
///
/// * `oprf_key` - key used to generate the oblivious pseudo random sort keys
/// * `data` - elements to shuffle
/// * `key_size` - width of the random sort keys, given either as a target
///   collision probability or as a raw block count. Wider keys collide less
///   often — hence a more uniform shuffle — but each comparison/swap costs more.
/// * `seeder` - source of the seeds, one seed is drawn per generated key
///
/// Inputs with at most one element are returned unchanged (after the key size
/// check).
///
/// # Errors
///
/// Returns an error if the resolved key block count is 0. Errors of
/// [`Self::bitonic_shuffle_with_keys`] are forwarded.
pub fn bitonic_shuffle<T, S, C>(
    &self,
    oprf_key: &GenericOprfServerKey<C>,
    data: Vec<T>,
    key_size: BitonicShuffleKeySize,
    seeder: &mut S,
) -> Result<Vec<T>, crate::Error>
where
    T: IntegerRadixCiphertext,
    S: Seeder,
    C: Container<Element = c64> + Sync,
{
    let num_elements = data.len();

    let key_num_blocks =
        u64::from(key_size.num_blocks_of_keys(num_elements, self.message_modulus()));
    if key_num_blocks == 0 {
        return Err(crate::Error::new(
            "key_num_blocks must be at least 1".to_string(),
        ));
    }

    // Nothing to permute.
    if num_elements <= 1 {
        return Ok(data);
    }

    // Keys are generated one after the other, each from its own seed.
    let mut keys = Vec::with_capacity(num_elements);
    for _ in 0..num_elements {
        let key = oprf_key.par_generate_oblivious_pseudo_random_unsigned_integer(
            seeder.seed(),
            key_num_blocks,
            self,
        );
        keys.push(key);
    }

    self.bitonic_shuffle_with_keys(data, keys)
}
/// Permutes `data` by sorting it according to the caller-provided `keys`
/// (`keys[i]` is the sort key of `data[i]`), using a bitonic sorting network.
///
/// Inputs are cleaned before the shuffle and the message of every output
/// block is extracted afterwards. Inputs with at most one element are
/// returned unchanged.
///
/// # Errors
///
/// Returns an error if `data` and `keys` do not have the same length, if the
/// elements of `data` do not all have the same number of blocks, or if the
/// elements of `keys` do not all have the same number of blocks.
pub fn bitonic_shuffle_with_keys<T>(
    &self,
    mut data: Vec<T>,
    mut keys: Vec<RadixCiphertext>,
) -> Result<Vec<T>, crate::Error>
where
    T: IntegerRadixCiphertext,
{
    let (num_data, num_keys) = (data.len(), keys.len());
    if num_data != num_keys {
        return Err(crate::Error::new(format!(
            "data and keys must have the same length, got {} and {}",
            num_data, num_keys
        )));
    }

    if num_data <= 1 {
        return Ok(data);
    }

    // Every element must be as wide as the first one.
    let data_num_blocks = data[0].blocks().len();
    let data_is_uniform = data
        .iter()
        .skip(1)
        .all(|d| d.blocks().len() == data_num_blocks);
    if !data_is_uniform {
        return Err(crate::Error::new(
            "all data elements must have the same number of blocks".to_string(),
        ));
    }

    // Same requirement for the keys.
    let key_num_blocks = keys[0].blocks.len();
    let keys_are_uniform = keys
        .iter()
        .skip(1)
        .all(|k| k.blocks.len() == key_num_blocks);
    if !keys_are_uniform {
        return Err(crate::Error::new(
            "all keys must have the same number of blocks".to_string(),
        ));
    }

    // Clean data and keys concurrently.
    rayon::join(
        || {
            data.par_iter_mut()
                .for_each(|value| self.clean_inplace_for_default_op(value))
        },
        || {
            keys.par_iter_mut()
                .for_each(|value| self.clean_inplace_for_default_op(value))
        },
    );

    let mut shuffled = self.unchecked_bitonic_shuffle_with_keys(data, keys);

    // Final pass on every block of every shuffled element.
    shuffled.par_iter_mut().for_each(|radix| {
        radix
            .blocks_mut()
            .par_iter_mut()
            .for_each(|block| self.key.message_extract_assign(block))
    });

    Ok(shuffled)
}
/// Performs a bitonic shuffle without cleaning inputs or outputs.
///
/// `data` is permuted by running a bitonic sorting network on `keys`
/// (`keys[i]` being the sort key of `data[i]`): each comparator compares two
/// keys and conditionally swaps both the keys and the matching data elements.
///
/// Inputs with at most one element are returned unchanged. When the number of
/// elements is not a power of 2, the inputs are padded with trivial
/// (key=MAX, data=0) entries which are dropped before returning.
///
/// # Preconditions
///
/// * `data` and `keys` must have the same length and consistent block counts.
/// * Data blocks must have no carries and noise budget for `unchecked_flip_parallelized`.
/// * Key blocks must have no carries and noise budget for `unchecked_lt/gt`.
///
/// Output blocks have no carries but non-nominal noise level.
///
/// # Panics
///
/// Panics if `data` and `keys` do not have the same length.
pub fn unchecked_bitonic_shuffle_with_keys<T>(
&self,
mut data: Vec<T>,
mut keys: Vec<RadixCiphertext>,
) -> Vec<T>
where
T: IntegerRadixCiphertext,
{
assert_eq!(
data.len(),
keys.len(),
"data.len()={} != keys.len()={}",
data.len(),
keys.len()
);
let n = data.len();
if n <= 1 {
return data;
}
// The network generator only accepts power of 2 sizes, so the network is
// built for the padded size.
let padded_n = padded_size(n);
let network = bitonic_network(padded_n);
let mut key_num_blocks = keys[0].blocks.len();
let data_num_blocks = data[0].blocks().len();
let pad = padded_n - n;
if pad > 0 {
// We need to pad with some trivial (key=MAX, data=0)
// However it could be that a key is already=MAX, so to protect us from that case
// we add an extra block to the keys
key_num_blocks += 1;
for key in &mut keys {
self.extend_radix_with_trivial_zero_blocks_msb_assign(key, 1);
}
for _ in 0..pad {
keys.push(self.create_trivial_max_radix(key_num_blocks));
data.push(self.create_trivial_zero_radix(data_num_blocks));
}
}
// Buffer reused by every stage; a stage of the padded network holds
// padded_n / 2 comparators.
let mut stage_results = Vec::with_capacity(padded_n / 2);
for stage in network {
// The comparators of a stage are evaluated in parallel, reading from the
// current `keys`/`data`; results are only written back once the whole
// stage has been computed.
stage
.into_par_iter()
.map(|(i, j, ascending)| {
// `cmp` is the swap condition: for an ascending comparator, swap when
// keys[i] > keys[j] (smaller key goes to i); for a descending one, swap
// when keys[i] < keys[j] (larger key goes to i).
let cmp = if ascending {
self.unchecked_gt_parallelized(&keys[i], &keys[j])
} else {
self.unchecked_lt_parallelized(&keys[i], &keys[j])
};
// Keys and data are swapped under the same condition so that each data
// element keeps following its key.
let ((new_ki, new_kj), (new_di, new_dj)) = rayon::join(
|| self.unchecked_flip_parallelized(&cmp, &keys[i], &keys[j]),
|| self.unchecked_flip_parallelized(&cmp, &data[i], &data[j]),
);
(i, j, new_ki, new_kj, new_di, new_dj)
})
.collect_into_vec(&mut stage_results);
for (i, j, new_ki, new_kj, new_di, new_dj) in stage_results.drain(..) {
keys[i] = new_ki;
keys[j] = new_kj;
data[i] = new_di;
data[j] = new_dj;
}
}
// Drop the padding: padded entries carry the largest keys, so they are
// expected to sit in the last `pad` slots once the network has run.
data.truncate(n);
data
}
}

View File

@@ -121,96 +121,17 @@ where
"Inputs must have the same number of blocks"
);
// To make use if many_lut, we require 1 bit, 1 more bit is required to pack
// the condition. Thus 2 bits of carry are required.
//
// Otherwise we call if_then_else twice, which is less efficient.
if self.carry_modulus().0 < (1 << 2) {
return rayon::join(
|| self.if_then_else_parallelized(condition, b, a),
|| self.if_then_else_parallelized(condition, a, b),
);
}
let (a, b) = rayon::join(
|| self.clean_for_default_op(a),
|| self.clean_for_default_op(b),
);
let (mut a, mut b) = self.unchecked_flip_parallelized(condition, &*a, &*b);
a.blocks_mut()
.par_iter_mut()
.chain(b.blocks_mut().par_iter_mut())
.for_each(|block| self.key.message_extract_assign(block));
let zero_out_if_true_fn = |packed| {
let condition = (packed / self.message_modulus().0) & 1;
let value = packed % self.message_modulus().0;
(1 - condition) * value
};
let zero_out_if_false_fn = |packed| {
let condition = (packed / self.message_modulus().0) & 1;
let value = packed % self.message_modulus().0;
condition * value
};
let lut = self
.key
.generate_many_lookup_table(&[&zero_out_if_true_fn, &zero_out_if_false_fn]);
let scaled_condition = self
.key
.unchecked_scalar_mul(&condition.0, self.message_modulus().0 as u8);
let map_condition_lut_on_blocks =
|blocks: &[Ciphertext]| -> (Vec<Ciphertext>, Vec<Ciphertext>) {
let mut left = Vec::with_capacity(blocks.len());
let mut right = Vec::with_capacity(blocks.len());
blocks
.par_iter()
.map(|block| {
let block = self.key.unchecked_add(block, &scaled_condition);
let mut resulting_blocks = self.key.apply_many_lookup_table(&block, &lut);
let second_result = resulting_blocks.pop().unwrap();
let first_result = resulting_blocks.pop().unwrap();
(first_result, second_result)
})
.unzip_into_vecs(&mut left, &mut right);
(left, right)
};
let (
(mut a_blocks_if_cond, mut a_blocks_if_not_cond),
(b_blocks_if_cond, b_blocks_if_not_cond),
) = rayon::join(
|| map_condition_lut_on_blocks(a.blocks()),
|| map_condition_lut_on_blocks(b.blocks()),
);
let clean_lut = self
.key
.generate_lookup_table(|x| x % self.message_modulus().0);
let inplace_add_then_clean_blocks =
|lhs_blocks: &mut [Ciphertext], rhs_blocks: &[Ciphertext]| {
lhs_blocks
.par_iter_mut()
.zip(rhs_blocks.par_iter())
.for_each(|(lhs, rhs)| {
self.key.unchecked_add_assign(lhs, rhs);
self.key.apply_lookup_table_assign(lhs, &clean_lut);
});
};
rayon::join(
|| {
inplace_add_then_clean_blocks(&mut a_blocks_if_cond, &b_blocks_if_not_cond);
},
|| {
inplace_add_then_clean_blocks(&mut a_blocks_if_not_cond, &b_blocks_if_cond);
},
);
(
T::from_blocks(a_blocks_if_cond),
T::from_blocks(a_blocks_if_not_cond),
)
(a, b)
}
}
@@ -987,6 +908,122 @@ impl ServerKey {
});
}
/// Performs `if condition { (b, a) } else { (a, b) }` in fhe
///
/// * Blocks of both `a` and `b` must have no carries and noise_level <= 2
/// * Outputs will have no carries but noise_level == 2
///
/// # Panics
///
/// Panics if `a` and `b` do not have the same number of blocks, or (when the
/// many-lut path is taken) if the noise level of the inputs combined with the
/// one of the condition exceeds the maximum noise level of the key.
pub(crate) fn unchecked_flip_parallelized<T>(
&self,
condition: &BooleanBlock,
a: &T,
b: &T,
) -> (T, T)
where
T: IntegerRadixCiphertext,
{
assert_eq!(
a.blocks().len(),
b.blocks().len(),
"Inputs must have the same number of blocks"
);
// To make use of many_lut, we require 1 bit, 1 more bit is required to pack
// the condition. Thus 2 bits of carry are required.
//
// Otherwise we call if_then_else twice, which is less efficient.
if self.carry_modulus().0 < (1 << 2) {
// Same selection as an if_then_else, but with `do_clean_message` set to false.
let unchecked_if_then_else_parallelized_no_cleanup =
|condition: &BooleanBlock, true_ct, false_ct| {
let condition_block = &condition.0;
let do_clean_message = false;
self.unchecked_programmable_if_then_else_parallelized(
condition_block,
true_ct,
false_ct,
|x| x == 1,
do_clean_message,
)
};
// First output: b if condition else a. Second output: a if condition else b.
return rayon::join(
|| unchecked_if_then_else_parallelized_no_cleanup(condition, b, a),
|| unchecked_if_then_else_parallelized_no_cleanup(condition, a, b),
);
}
// We move the block message by one bit: it is the most flexible option (at least for 2_2)
// as it allows the block to have noise_level <= 2 or the condition block to have
// noise_level <= 2.
//
// A drawback is that every block needs to be cloned to be shifted, as opposed to
// only the condition block, but it's an ok drawback
//
// Each packed block is computed as `2 * block + condition`, hence the noise check below.
assert!(a.blocks().iter().chain(b.blocks().iter()).all(|block| {
self.key
.max_noise_level
.validate(block.noise_level() * 2 + condition.0.noise_level())
.is_ok()
}));
// Both functions unpack `packed = 2 * value + condition`.
// This one keeps the value when the condition is 0, and yields 0 otherwise.
let zero_out_if_true_fn = |packed| {
let condition = packed & 1;
let value = packed / 2;
(1 - condition) * value
};
// This one keeps the value when the condition is 1, and yields 0 otherwise.
let zero_out_if_false_fn = |packed| {
let condition = packed & 1;
let value = packed / 2;
condition * value
};
let lut = self
.key
.generate_many_lookup_table(&[&zero_out_if_true_fn, &zero_out_if_false_fn]);
// For every block, returns (block zeroed out if condition is true,
// block zeroed out if condition is false), i.e. the two functions of `lut`
// in the order they were given.
let map_condition_lut_on_blocks =
|blocks: &[Ciphertext]| -> (Vec<Ciphertext>, Vec<Ciphertext>) {
let mut left = Vec::with_capacity(blocks.len());
let mut right = Vec::with_capacity(blocks.len());
blocks
.par_iter()
.map(|block| {
let mut block = self.key.unchecked_scalar_mul(block, 2);
self.key.unchecked_add_assign(&mut block, &condition.0);
let mut resulting_blocks = self.key.apply_many_lookup_table(&block, &lut);
// Results are popped in reverse order of the functions given to the lut.
let second_result = resulting_blocks.pop().unwrap();
let first_result = resulting_blocks.pop().unwrap();
(first_result, second_result)
})
.unzip_into_vecs(&mut left, &mut right);
(left, right)
};
let (
(mut a_blocks_if_cond, mut a_blocks_if_not_cond),
(b_blocks_if_cond, b_blocks_if_not_cond),
) = rayon::join(
|| map_condition_lut_on_blocks(a.blocks()),
|| map_condition_lut_on_blocks(b.blocks()),
);
// First output: (a zeroed out if condition is true) + (b zeroed out if condition is false)
for (a, b) in a_blocks_if_cond.iter_mut().zip(b_blocks_if_not_cond.iter()) {
self.key.unchecked_add_assign(a, b);
// By construction, one of the two input encrypts only zeros
a.degree.0 = self.message_modulus().0 - 1;
}
// Second output: (a zeroed out if condition is false) + (b zeroed out if condition is true)
for (a, b) in a_blocks_if_not_cond.iter_mut().zip(b_blocks_if_cond.iter()) {
self.key.unchecked_add_assign(a, b);
// By construction, one of the two input encrypts only zeros
a.degree.0 = self.message_modulus().0 - 1;
}
(
T::from_blocks(a_blocks_if_cond),
T::from_blocks(a_blocks_if_not_cond),
)
}
fn scalar_flip_parallelized<T, Scalar>(
&self,
condition: &BooleanBlock,

View File

@@ -112,6 +112,7 @@ impl ServerKey {
/// * inputs must have the same number of blocks
/// * block carries of both inputs must be empty
/// * carry modulus == message modulus
/// * blocks of the inputs a and b must have a noise_level such that a[i] - b[i] is possible
fn compare<T>(&self, a: &T, b: &T, compare: ComparisonKind) -> BooleanBlock
where
T: IntegerRadixCiphertext,

View File

@@ -1,6 +1,7 @@
mod abs;
mod add;
mod bit_extractor;
pub(crate) mod bitonic_shuffle;
mod bitwise_op;
mod block_shift;
pub(crate) mod cmux;
@@ -47,6 +48,7 @@ use crate::integer::ciphertext::IntegerRadixCiphertext;
use crate::integer::RadixCiphertext;
use crate::shortint::ciphertext::{Ciphertext, NoiseLevel};
pub(crate) use add::OutputFlag;
pub use bitonic_shuffle::{BitonicShuffleKeySize, CollisionProbability};
use rayon::prelude::*;
pub use scalar_div_mod::{MiniUnsignedInteger, Reciprocable};
pub use vector_find::MatchValues;
@@ -323,4 +325,18 @@ impl ServerKey {
Cow::Borrowed(ct)
}
}
/// Cleans `ct` in place so that it is ready to be used as input of a default operation.
///
/// When at least one block has a non-empty carry, or a noise level different from
/// `NoiseLevel::NOMINAL`, a full parallel propagation is run on `ct`.
/// Otherwise `ct` is left untouched.
pub(crate) fn clean_inplace_for_default_op<T>(&self, ct: &mut T)
where
    T: IntegerRadixCiphertext,
{
    // A block is "clean" when its carry is empty AND its noise is nominal.
    let already_clean = ct
        .blocks()
        .iter()
        .all(|block| block.carry_is_empty() && block.noise_level() == NoiseLevel::NOMINAL);

    if !already_clean {
        self.full_propagate_parallelized(ct);
    }
}
}

View File

@@ -2,6 +2,7 @@ mod modulus_switch_compression;
pub(crate) mod test_add;
pub(crate) mod test_aes;
pub(crate) mod test_aes256;
pub(crate) mod test_bitonic_shuffle;
pub(crate) mod test_bitwise_op;
mod test_block_rotate;
mod test_block_shift;

View File

@@ -0,0 +1,141 @@
use crate::integer::keycache::KEY_CACHE;
use crate::integer::server_key::radix_parallel::bitonic_shuffle::bitonic_network;
use crate::integer::server_key::radix_parallel::tests_cases_unsigned::FunctionExecutor;
use crate::integer::server_key::radix_parallel::tests_unsigned::CpuFunctionExecutor;
use crate::integer::tests::create_parameterized_test;
use crate::integer::{IntegerKeyKind, RadixCiphertext, RadixClientKey, ServerKey};
#[cfg(tarpaulin)]
use crate::shortint::parameters::coverage_parameters::*;
use crate::shortint::parameters::test_params::*;
use crate::shortint::parameters::*;
use rand::Rng;
use std::sync::Arc;
/// Clear (non-encrypted) reference for `ServerKey::bitonic_shuffle_with_keys`.
///
/// Reorders `data` according to `keys` by walking the stages returned by
/// `bitonic_network`, and returns the reordered data.
///
/// When the input length is not a power of two, the inputs are padded with
/// `u64::MAX` keys (strictly above any `u32` key once widened to `u64`) and zero
/// data, so that the padding sorts to the end; the result is then truncated back
/// to the original length.
///
/// Panics if `data` and `keys` do not have the same length.
fn clear_bitonic_shuffle_with_keys(data: &[u32], keys: &[u32]) -> Vec<u32> {
    assert_eq!(data.len(), keys.len());

    let original_len = data.len();
    if original_len <= 1 {
        return data.to_vec();
    }

    let padded_len = original_len.next_power_of_two();

    // Keys are widened to u64 so the padding key can be larger than every real key.
    let mut sort_keys: Vec<u64> = keys.iter().map(|&key| u64::from(key)).collect();
    sort_keys.resize(padded_len, u64::MAX);

    let mut values = data.to_vec();
    values.resize(padded_len, 0u32);

    let stages = bitonic_network(padded_len);
    for stage in &stages {
        // All decisions of a stage are taken on the state at the start of the
        // stage, and only then applied.
        let pending_swaps: Vec<_> = stage
            .iter()
            .filter_map(|&(i, j, ascending)| {
                let out_of_order = if ascending {
                    sort_keys[i] > sort_keys[j]
                } else {
                    sort_keys[i] < sort_keys[j]
                };
                out_of_order.then_some((i, j))
            })
            .collect();

        for (i, j) in pending_swaps {
            sort_keys.swap(i, j);
            values.swap(i, j);
        }
    }

    values.truncate(original_len);
    values
}
create_parameterized_test!(integer_bitonic_shuffle_with_keys);

/// Runs the generic bitonic-shuffle-with-keys test with a CPU executor wrapping
/// `ServerKey::bitonic_shuffle_with_keys` on `RadixCiphertext`.
fn integer_bitonic_shuffle_with_keys<P>(param: P)
where
    P: Into<TestParameters>,
{
    bitonic_shuffle_with_keys_test(
        param,
        CpuFunctionExecutor::new(&ServerKey::bitonic_shuffle_with_keys::<RadixCiphertext>),
    );
}
/// Generic test of a "bitonic shuffle with keys" executor.
///
/// `executor` receives `(encrypted data, encrypted keys)` and returns the
/// shuffled encrypted data. For each case the decrypted output is compared with
/// `clear_bitonic_shuffle_with_keys`, and then checked to be a permutation of
/// the input data.
///
/// Cases covered: four random power-of-two lengths in `1..=16`, then one input
/// of length 17 (not a power of two) where one key is forced to `u32::MAX`.
pub(crate) fn bitonic_shuffle_with_keys_test<P, T>(param: P, mut executor: T)
where
    P: Into<TestParameters>,
    T: for<'a> FunctionExecutor<
        (Vec<RadixCiphertext>, Vec<RadixCiphertext>),
        Result<Vec<RadixCiphertext>, crate::Error>,
    >,
{
    let param = param.into();
    let (cks, sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);

    // Use enough blocks to hold 32 bits of message.
    let bits_per_block = cks.parameters().message_modulus().0.ilog2() as usize;
    let num_blocks = 32usize.div_ceil(bits_per_block);
    let cks = RadixClientKey::from((cks, num_blocks));

    executor.setup(&cks, Arc::new(sks));

    let mut rng = rand::thread_rng();

    // Encrypts one clear case, runs the executor and validates the output.
    let mut check_case = |clear_keys: Vec<u32>, mut clear_data: Vec<u32>| {
        println!("clear_keys: {clear_keys:?}, clear_data: {clear_data:?}");

        let enc_keys = clear_keys.iter().map(|&v| cks.encrypt(v as u64)).collect();
        let enc_data = clear_data.iter().map(|&v| cks.encrypt(v as u64)).collect();

        let result = executor.execute((enc_data, enc_keys)).unwrap();
        let mut decrypted: Vec<u32> = result
            .iter()
            .map(|ct| cks.decrypt::<u64>(ct) as u32)
            .collect();

        // The encrypted implementation must match the clear one
        let expected = clear_bitonic_shuffle_with_keys(&clear_data, &clear_keys);
        assert_eq!(decrypted, expected);

        // The output must be a permutation of the input: no data lost
        decrypted.sort_unstable();
        clear_data.sort_unstable();
        assert_eq!(decrypted, clear_data, "permutation lost data");
    };

    // Random inputs whose length is a power of two
    for _ in 0..4 {
        let len = rng.gen_range(1..=16usize).next_power_of_two();
        let clear_keys: Vec<u32> = (0..len).map(|_| rng.gen::<u32>()).collect();
        let clear_data: Vec<u32> = (0..len).map(|_| rng.gen::<u32>()).collect();
        check_case(clear_keys, clear_data);
    }

    // Input whose length is NOT a power of two.
    // NOTE(review): one key is forced to u32::MAX, presumably to check that a
    // real maximal key is not confused with the padding - confirm.
    let len = 17;
    let mut clear_keys: Vec<u32> = (0..len).map(|_| rng.gen::<u32>()).collect();
    let clear_data: Vec<u32> = (0..len).map(|_| rng.gen::<u32>()).collect();
    clear_keys[3] = u32::MAX;
    assert!(!clear_keys.len().is_power_of_two());
    check_case(clear_keys, clear_data);
}

View File

@@ -976,7 +976,7 @@ dependencies = [
[[package]]
name = "tfhe"
version = "1.6.1"
version = "1.6.0"
dependencies = [
"aligned-vec",
"bincode",