mirror of
https://github.com/zama-ai/tfhe-rs.git
synced 2026-01-11 07:38:08 -05:00
Compare commits
3 Commits
pa/fix/siz
...
cm
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d220008757 | ||
|
|
b1b55b6426 | ||
|
|
77bea74ac9 |
@@ -1,6 +1,6 @@
|
||||
[workspace]
|
||||
resolver = "2"
|
||||
members = ["tfhe", "tasks", "apps/trivium", "concrete-csprng"]
|
||||
members = ["tfhe", "tasks", "apps/trivium", "concrete-csprng", "concrete-float"]
|
||||
|
||||
[profile.bench]
|
||||
lto = "fat"
|
||||
|
||||
90
Makefile
90
Makefile
@@ -149,6 +149,11 @@ fix_newline: check_linelint_installed
|
||||
check_newline: check_linelint_installed
|
||||
linelint .
|
||||
|
||||
.PHONY: clippy_float # Run clippy lints on core_crypto with and without experimental features
|
||||
clippy_float: install_rs_check_toolchain
|
||||
RUSTFLAGS="$(RUSTFLAGS)" cargo "$(CARGO_RS_CHECK_TOOLCHAIN)" clippy \
|
||||
-p concrete-float -- --no-deps -D warnings
|
||||
|
||||
.PHONY: clippy_core # Run clippy lints on core_crypto with and without experimental features
|
||||
clippy_core: install_rs_check_toolchain
|
||||
RUSTFLAGS="$(RUSTFLAGS)" cargo "$(CARGO_RS_CHECK_TOOLCHAIN)" clippy \
|
||||
@@ -478,6 +483,59 @@ test_concrete_csprng:
|
||||
RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_BUILD_TOOLCHAIN) test --profile $(CARGO_PROFILE) \
|
||||
--features=$(TARGET_ARCH_FEATURE) -p concrete-csprng
|
||||
|
||||
.PHONY: test_float # Run minifloat bivariate test
|
||||
test_float: test_float_add test_float_sub test_float_mul test_float_div test_float_cos test_float_sin test_float_relu test_float_sigmoid test_minifloat
|
||||
|
||||
.PHONY: test_minifloat # Run minifloat bivariate test
|
||||
test_minifloat:
|
||||
RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_BUILD_TOOLCHAIN) test --profile $(CARGO_PROFILE) \
|
||||
--features=$(TARGET_ARCH_FEATURE),shortint -p tfhe float_wopbs_bivariate -- --nocapture
|
||||
|
||||
.PHONY: test_float_cos # Run floating points cosine test
|
||||
test_float_cos:
|
||||
RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_BUILD_TOOLCHAIN) test --profile $(CARGO_PROFILE) \
|
||||
--features=$(TARGET_ARCH_FEATURE) -p concrete-float "server_key::tests::float_cos" -- --exact --nocapture
|
||||
|
||||
.PHONY: test_float_sin # Run floating points sine test
|
||||
test_float_sin:
|
||||
RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_BUILD_TOOLCHAIN) test --profile $(CARGO_PROFILE) \
|
||||
--features=$(TARGET_ARCH_FEATURE) -p concrete-float "server_key::tests::float_sin" -- --exact --nocapture
|
||||
|
||||
.PHONY: test_float_mul # Run floating points multiplication test
|
||||
test_float_mul:
|
||||
RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_BUILD_TOOLCHAIN) test --profile $(CARGO_PROFILE) \
|
||||
--features=$(TARGET_ARCH_FEATURE) -p concrete-float "server_key::tests::test_float_mul" -- --exact --nocapture
|
||||
|
||||
.PHONY: test_float_add # Run floating points addition test
|
||||
test_float_add:
|
||||
RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_BUILD_TOOLCHAIN) test --profile $(CARGO_PROFILE) \
|
||||
--features=$(TARGET_ARCH_FEATURE) -p concrete-float "server_key::tests::test_float_add" -- --exact --nocapture
|
||||
|
||||
.PHONY: test_float_sub # Run floating points subtraction test
|
||||
test_float_sub:
|
||||
RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_BUILD_TOOLCHAIN) test --profile $(CARGO_PROFILE) \
|
||||
--features=$(TARGET_ARCH_FEATURE) -p concrete-float "server_key::tests::test_float_sub" -- --exact --nocapture
|
||||
|
||||
.PHONY: test_float_div # Run floating points division test
|
||||
test_float_div:
|
||||
RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_BUILD_TOOLCHAIN) test --profile $(CARGO_PROFILE) \
|
||||
--features=$(TARGET_ARCH_FEATURE) -p concrete-float "server_key::tests::test_float_div" -- --exact --nocapture
|
||||
|
||||
.PHONY: test_float_relu # Run floating points relu test
|
||||
test_float_relu:
|
||||
RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_BUILD_TOOLCHAIN) test --profile $(CARGO_PROFILE) \
|
||||
--features=$(TARGET_ARCH_FEATURE) -p concrete-float "server_key::tests::test_float_relu" -- --exact --nocapture
|
||||
|
||||
.PHONY: test_float_sigmoid # Run floating points sigmoid test
|
||||
test_float_sigmoid:
|
||||
RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_BUILD_TOOLCHAIN) test --profile $(CARGO_PROFILE) \
|
||||
--features=$(TARGET_ARCH_FEATURE) -p concrete-float "server_key::tests::test_float_sigmoid" -- --exact --nocapture
|
||||
|
||||
.PHONY: test_float_depth_test # Run floating points depth test
|
||||
test_float_depth_test:
|
||||
RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_BUILD_TOOLCHAIN) test --profile $(CARGO_PROFILE) \
|
||||
--features=$(TARGET_ARCH_FEATURE) -p concrete-float "server_key::tests::depth_test_parallelized" -- --exact --nocapture
|
||||
|
||||
.PHONY: doc # Build rust doc
|
||||
doc: install_rs_check_toolchain
|
||||
RUSTDOCFLAGS="--html-in-header katex-header.html" \
|
||||
@@ -631,6 +689,38 @@ ci_bench_web_js_api_parallel: build_web_js_api_parallel
|
||||
nvm use node && \
|
||||
$(MAKE) -C tfhe/web_wasm_parallel_tests bench-ci
|
||||
|
||||
.PHONY: bench_float # Run benchmarks for the floating points
|
||||
bench_float: install_rs_check_toolchain
|
||||
RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_CHECK_TOOLCHAIN) bench \
|
||||
--bench float-bench
|
||||
|
||||
.PHONY: bench_float_8bit # Run benchmarks for the floating points
|
||||
bench_float_8bit: install_rs_check_toolchain
|
||||
RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_CHECK_TOOLCHAIN) bench \
|
||||
--bench float-bench -- PARAM_8
|
||||
|
||||
|
||||
.PHONY: bench_float_16bit # Run benchmarks for the floating points
|
||||
bench_float_16bit: install_rs_check_toolchain
|
||||
RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_CHECK_TOOLCHAIN) bench \
|
||||
--bench float-bench -- PARAM_16
|
||||
|
||||
|
||||
.PHONY: bench_float_32bit # Run benchmarks for the floating points
|
||||
bench_float_32bit: install_rs_check_toolchain
|
||||
RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_CHECK_TOOLCHAIN) bench \
|
||||
--bench float-bench -- PARAM_32
|
||||
|
||||
.PHONY: bench_float_64bit # Run benchmarks for the floating points
|
||||
bench_float_64bit: install_rs_check_toolchain
|
||||
RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_CHECK_TOOLCHAIN) bench \
|
||||
--bench float-bench -- PARAM_64
|
||||
|
||||
.PHONY: bench_minifloat # Run benchmarks for Wopbs floating points
|
||||
bench_minifloat: install_rs_check_toolchain
|
||||
RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_CHECK_TOOLCHAIN) bench \
|
||||
--bench float-wopbs-bench
|
||||
|
||||
#
|
||||
# Utility tools
|
||||
#
|
||||
|
||||
297
README.md
297
README.md
@@ -1,175 +1,160 @@
|
||||
<p align="center">
|
||||
<!-- product name logo -->
|
||||
<img width=600 src="https://user-images.githubusercontent.com/5758427/231206749-8f146b97-3c5a-4201-8388-3ffa88580415.png">
|
||||
</p>
|
||||
<hr/>
|
||||
<p align="center">
|
||||
<a href="https://docs.zama.ai/tfhe-rs"> 📒 Read documentation</a> | <a href="https://zama.ai/community"> 💛 Community support</a>
|
||||
</p>
|
||||
<p align="center">
|
||||
<!-- Version badge using shields.io -->
|
||||
<a href="https://github.com/zama-ai/tfhe-rs/releases">
|
||||
<img src="https://img.shields.io/github/v/release/zama-ai/tfhe-rs?style=flat-square">
|
||||
</a>
|
||||
<!-- Zama Bounty Program -->
|
||||
<a href="https://github.com/zama-ai/bounty-program">
|
||||
<img src="https://img.shields.io/badge/Contribute-Zama%20Bounty%20Program-yellow?style=flat-square">
|
||||
</a>
|
||||
</p>
|
||||
<hr/>
|
||||
# Artifact:TFHE Gets Real: an Efficient and Flexible Homomorphic Floating-Point Arithmetic
|
||||
|
||||
|
||||
**TFHE-rs** is a pure Rust implementation of TFHE for boolean and integer
|
||||
arithmetics over encrypted data. It includes:
|
||||
- a **Rust** API
|
||||
- a **C** API
|
||||
- and a **client-side WASM** API
|
||||
## Description
|
||||
|
||||
**TFHE-rs** is meant for developers and researchers who want full control over
|
||||
what they can do with TFHE, while not having to worry about the low level
|
||||
implementation. The goal is to have a stable, simple, high-performance, and
|
||||
production-ready library for all the advanced features of TFHE.
|
||||
|
||||
## Getting Started
|
||||
The steps to run a first example are described below.
|
||||
In what follows, we provide instructions on how to run the benchmarks from the paper entitled **TFHE Gets Real: An Efficient and Flexible Homomorphic Floating-Point Arithmetic**.
|
||||
In particular, the benchmarks presented in **Table 5**, **Table 6**, **Table 7**, and the experiments shown in **Table 8** can be easily reproduced using this code. The implementation of the techniques described in the aforementioned paper has been integrated into the **TFHE-rs** library, version 0.5.0. The modified or added source files are organized into two different paths.
|
||||
|
||||
### Cargo.toml configuration
|
||||
To use the latest version of `TFHE-rs` in your project, you first need to add it as a dependency in your `Cargo.toml`:
|
||||
The Minifloats (Section 3.1) are located in *tfhe/src/float-wopbs*
|
||||
- Test files are located in *tfhe/src/float_wopbs/server_key/tests.rs*
|
||||
- Benchmarks are located in *tfhe/benches/float_wopbs/bench.rs*
|
||||
|
||||
+ For x86_64-based machines running Unix-like OSes:
|
||||
|
||||
```toml
|
||||
tfhe = { version = "*", features = ["boolean", "shortint", "integer", "x86_64-unix"] }
|
||||
The homomorphic floating points (Section 3.2) are located in *tfhe/concrete-float/*
|
||||
- Test files are located *tfhe/concrete-float/src/server_key/tests.rs*
|
||||
- Benchmarks are located in *tfhe/concrete-float/benches/bench.rs*
|
||||
|
||||
|
||||
## Dependencies
|
||||
|
||||
Tested on Linux and Mac OS with Rust version >= 1.80 (see [here](https://www.rust-lang.org/tools/install) a guide to install Rust).
|
||||
Complete list of dependencies and a guide on how to install TFHE-rs can be found in the online documentation [here](https://docs.zama.ai/tfhe-rs/0.5-3/getting-started/installation) or in the local file [here](./README_TFHE-rs.md).
|
||||
|
||||
## How to run benchmarks
|
||||
At the root of the project (i.e., in the TFHE-rs folder), enter the following commands to run the benchmarks:
|
||||
|
||||
- ```make bench_minifloat```: returns the timings associated to the Minifloats (**Table 6**).
|
||||
- ```make bench_float```: returns the timings associated to the HFP (**Table 5**, **Table 7**).
|
||||
These benchmarks first launch the parallelized and then the sequential experiments.
|
||||
This outputs the timings depending on the input precision.
|
||||
**This takes more than 6 hours to run**.
|
||||
|
||||
To run benchmarks for a specific precision over homomorphic floating points, here are the dedicated commands:
|
||||
- ```make bench_float_8bit```: Runs benchmarks for only 8-bit floating point *(around 15 min)*.
|
||||
- ```make bench_float_16bit```: Runs benchmarks for only 16-bit floating point *(around 30 min)*.
|
||||
- ```make bench_float_32bit```: Runs benchmarks for only 32-bit floating point *(around 1h40)*.
|
||||
- ```make bench_float_64bit```: Runs benchmarks for only 64-bit floating point *(around 6h30)*.
|
||||
|
||||
|
||||
We recall that the benchmarks were performed on AWS using an **m6i.metal** instance with an Intel Xeon 8375C (Ice Lake) processor running at 3.5 GHz, 128 vCPUs, and 512 GiB of memory.
|
||||
|
||||
### Understanding Benchmark Output (Criterion.rs)
|
||||
|
||||
This project uses [Criterion.rs](https://docs.rs/criterion/latest/criterion/) for benchmarking. Criterion is a powerful and statistically robust benchmarking framework for Rust, and it may produce outputs that are unfamiliar at first glance. This section explains how to interpret them.
|
||||
|
||||
#### Sample Output Structure
|
||||
|
||||
A typical benchmark result looks like this:
|
||||
|
||||
```
|
||||
test_float time: [53.2 µs 54.0 µs 54.8 µs]
|
||||
change: [+0.2% +1.0% +1.8%] (p = 0.002)
|
||||
Found 3 outliers among 100 measurements (3.00%)
|
||||
3 (3.00%) high mild
|
||||
```
|
||||
|
||||
+ For Apple Silicon or aarch64-based machines running Unix-like OSes:
|
||||
**Here's what this means:**
|
||||
|
||||
```toml
|
||||
tfhe = { version = "*", features = ["boolean", "shortint", "integer", "aarch64-unix"] }
|
||||
- `time: [low est. median high est.]`: The estimated execution time of the function.
|
||||
- `change`: The performance change compared to a previous run (if available).
|
||||
- `outliers`: Some runs deviated from the typical time. Criterion detects and accounts for these using statistical methods.
|
||||
|
||||
---
|
||||
|
||||
#### Common Warnings and What They Mean
|
||||
|
||||
##### `Found X outliers among Y measurements`
|
||||
|
||||
Criterion runs each benchmark many times (default: 100) to get statistically significant results.
|
||||
An *outlier* is a run that was significantly faster or slower than the others.
|
||||
|
||||
- **Why does this happen?** Often, it's due to **other processes on the machine** (e.g., background services, OS interrupts, or CPU scheduling) affecting performance temporarily.
|
||||
- **Why it doesn't invalidate results:** Criterion uses statistical techniques to minimize the impact of these outliers when estimating performance.
|
||||
- **Best practice to reduce outliers:** Run the benchmarks on a **freshly rebooted machine**, with as few background processes as possible. Ideally, let the system idle for a minute after boot to stabilize before running benchmarks.
|
||||
|
||||
##### `Unable to complete 100 samples in 5.0s.`
|
||||
|
||||
The benchmark took longer than the expected 5 seconds.
|
||||
This is merely a warning indicating that the full set of 100 samples could not be collected within the default 5-second measurement window.
|
||||
|
||||
- **No action is required**: Criterion will still proceed to run all 100 samples, and the results remain statistically valid.
|
||||
- **Why the warning appears**: It's there to inform you that benchmarking is taking longer than expected and to help you tune settings if needed.
|
||||
- **Optional**: If you're constrained by time (e.g., running in CI), you can:
|
||||
- Reduce the sample size (e.g., to 10 or 20 samples).
|
||||
- Or increase the measurement time using:
|
||||
```bash
|
||||
cargo bench -- --measurement-time 30
|
||||
```
|
||||
|
||||
## How to run the tests
|
||||
### MiniFloats
|
||||
|
||||
To run the tests related to the **minifloats**, run the following command:
|
||||
- ```make test_minifloat```: Runs a bivariate operation between two minifloats.
|
||||
|
||||
|
||||
The **minifloat** test is available in the file *tfhe/src/float_wopbs/server_key/tests.rs*.
|
||||
|
||||
|
||||
|
||||
### Homomorphic Floating Points
|
||||
At the root of the project (i.e., in the TFHE-rs folder), enter the following commands to run the tests per operation on the **homomorphic floating points**:
|
||||
- ```make test_float_add```: Runs a 32-bit floating-point addition with two random inputs.
|
||||
- ```make test_float_sub```: Runs a 32-bit floating-point subtraction with two random inputs.
|
||||
- ```make test_float_mul```: Runs a 32-bit floating-point multiplication with two random inputs.
|
||||
- ```make test_float_div```: Runs a 32-bit floating-point division with two random inputs.
|
||||
- ```make test_float_cos```: Runs the experiment from **Table 8** with a random input value.
|
||||
- ```make test_float_sin```: Runs the experiment from **Table 8** with a random input value.
|
||||
- ```make test_float_relu```: Runs a 32-bit floating-point relu with a random input.
|
||||
- ```make test_float_sigmoid```: Runs a 32-bit floating-point sigmoid with a random input.
|
||||
- ```make test_float```: Runs all previous tests for operations on 32-bit floating-points.
|
||||
- ```make test_float_depth_test```: This command runs the following experiment:
|
||||
- **Step 1**: Create 3 blocks, each composed of a clear 32-bit floating point, a clear 64-bit floating point, and a 32-bit homomorphic floating point.
|
||||
- **Step 2**: Choose two blocks randomly among the 3 blocks and randomly select a parallelized operation (addition, subtraction, or multiplication).
|
||||
- **Step 3**: Compute the selected operation between the two selected blocks and store the result randomly in one of the two selected blocks.
|
||||
(The operation is performed respectively between the two 64-bit floating points, the two 32-bit floating points, and homomorphically between the two 32-bit homomorphic floating points.)
|
||||
- Repeat Steps 2 and 3 for 50 iterations.
|
||||
- To avoid reaching + or - infinity, or **NaN**, when the clear 64-bit floating point reaches a fixed bound, compute a multiplication to rescale the value close to 1.
|
||||
This operation is also performed homomorphically for the encrypted data. This test takes several minutes.
|
||||
|
||||
The tests are located in the file *tfhe/concrete-float/src/server_key/tests.rs*.
|
||||
|
||||
Due to the representation being close to, but not exactly the same as, a given representation, the obtained result is not identical to the one obtained in clear.
|
||||
To consider a test as "passed", we accept a difference of less than 0.1% compared to the 64-bit floating-point clear results.
|
||||
Note that using 8 or 16-bit homomorphic floating points might return errors due to a lack of precision and due to the comparisons with clear 64-bit floating points.
|
||||
|
||||
In each test, the different results are presented in the following format:
|
||||
```
|
||||
--------------------
|
||||
"Name":
|
||||
|
||||
Result :
|
||||
Clear 32-bits:
|
||||
Clear 64-bits:
|
||||
|
||||
--------------------
|
||||
```
|
||||
Note: users with ARM devices must compile `TFHE-rs` using a stable toolchain with version >= 1.72.
|
||||
where ```name``` stands for the name of the ciphertext or the name of the operation, result always corresponds to the decryption of a homomorphic floating point, and Clear ``` 32-bits``` and Clear ``` 64-bits``` correspond to the clear floating-point witness.
|
||||
|
||||
All tests in *tfhe/concrete-float/src/server_key/tests.rs* are conducted for 32-bit floating-point precision, as it provides the best ratio between execution time and precision.
|
||||
To change the parameter set used, the parameters in the following ``` const ``` must be uncommented (lines 79 to 87 in the file *tfhe/concrete-float/src/server_key/tests.rs*).
|
||||
|
||||
|
||||
+ For x86_64-based machines with the [`rdseed instruction`](https://en.wikipedia.org/wiki/RDRAND)
|
||||
running Windows:
|
||||
|
||||
```toml
|
||||
tfhe = { version = "*", features = ["boolean", "shortint", "integer", "x86_64"] }
|
||||
```rust
|
||||
const PARAMS: [(&str, Parameters); 1] =
|
||||
[
|
||||
//named_param!(PARAM_FP_64_BITS),
|
||||
named_param!(PARAM_FP_32_BITS),
|
||||
//named_param!(PARAM_FP_16_BITS),
|
||||
//named_param!(PARAM_FP_8_BITS),
|
||||
];
|
||||
```
|
||||
|
||||
Note: aarch64-based machines are not yet supported for Windows as it's currently missing an entropy source to be able to seed the [CSPRNGs](https://en.wikipedia.org/wiki/Cryptographically_secure_pseudorandom_number_generator) used in TFHE-rs
|
||||
Note that the number in ``` [(\&str, Parameters); 1] ``` should correspond to the number of tested parameters, e.g., if another parameter sets is uncommented, this line becomes: ``` [(\&str, Parameters); 2] ```.
|
||||
The parameter ```PARAM_X``` corresponds to the parameters used in **Table 5**, and ```PARAM_TCHES_X``` corresponds to the parameters used in **Table 7**.
|
||||
|
||||
|
||||
## A simple example
|
||||
|
||||
Here is a full example:
|
||||
|
||||
``` rust
|
||||
use tfhe::prelude::*;
|
||||
use tfhe::{generate_keys, set_server_key, ConfigBuilder, FheUint32, FheUint8};
|
||||
|
||||
fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
// Basic configuration to use homomorphic integers
|
||||
let config = ConfigBuilder::default().build();
|
||||
|
||||
// Key generation
|
||||
let (client_key, server_keys) = generate_keys(config);
|
||||
|
||||
let clear_a = 1344u32;
|
||||
let clear_b = 5u32;
|
||||
let clear_c = 7u8;
|
||||
|
||||
// Encrypting the input data using the (private) client_key
|
||||
// FheUint32: Encrypted equivalent to u32
|
||||
let mut encrypted_a = FheUint32::try_encrypt(clear_a, &client_key)?;
|
||||
let encrypted_b = FheUint32::try_encrypt(clear_b, &client_key)?;
|
||||
|
||||
// FheUint8: Encrypted equivalent to u8
|
||||
let encrypted_c = FheUint8::try_encrypt(clear_c, &client_key)?;
|
||||
|
||||
// On the server side:
|
||||
set_server_key(server_keys);
|
||||
|
||||
// Clear equivalent computations: 1344 * 5 = 6720
|
||||
let encrypted_res_mul = &encrypted_a * &encrypted_b;
|
||||
|
||||
// Clear equivalent computations: 1344 >> 5 = 42
|
||||
encrypted_a = &encrypted_res_mul >> &encrypted_b;
|
||||
|
||||
// Clear equivalent computations: let casted_a = a as u8;
|
||||
let casted_a: FheUint8 = encrypted_a.cast_into();
|
||||
|
||||
// Clear equivalent computations: min(42, 7) = 7
|
||||
let encrypted_res_min = &casted_a.min(&encrypted_c);
|
||||
|
||||
// Operation between clear and encrypted data:
|
||||
// Clear equivalent computations: 7 & 1 = 1
|
||||
let encrypted_res = encrypted_res_min & 1_u8;
|
||||
|
||||
// Decrypting on the client side:
|
||||
let clear_res: u8 = encrypted_res.decrypt(&client_key);
|
||||
assert_eq!(clear_res, 1_u8);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
```
|
||||
|
||||
To run this code, use the following command:
|
||||
<p align="center"> <code> cargo run --release </code> </p>
|
||||
|
||||
Note that when running code that uses `tfhe-rs`, it is highly recommended
|
||||
to run in release mode with cargo's `--release` flag to have the best performances possible,
|
||||
|
||||
|
||||
## Contributing
|
||||
|
||||
There are two ways to contribute to TFHE-rs:
|
||||
|
||||
- you can open issues to report bugs or typos, or to suggest new ideas
|
||||
- you can ask to become an official contributor by emailing [hello@zama.ai](mailto:hello@zama.ai).
|
||||
(becoming an approved contributor involves signing our Contributor License Agreement (CLA))
|
||||
|
||||
Only approved contributors can send pull requests, so please make sure to get in touch before you do!
|
||||
|
||||
## Credits
|
||||
|
||||
This library uses several dependencies and we would like to thank the contributors of those
|
||||
libraries.
|
||||
|
||||
## Need support?
|
||||
<a target="_blank" href="https://community.zama.ai">
|
||||
<img src="https://user-images.githubusercontent.com/5758427/231115030-21195b55-2629-4c01-9809-be5059243999.png">
|
||||
</a>
|
||||
|
||||
## Citing TFHE-rs
|
||||
|
||||
To cite TFHE-rs in academic papers, please use the following entry:
|
||||
|
||||
```text
|
||||
@Misc{TFHE-rs,
|
||||
title={{TFHE-rs: A Pure Rust Implementation of the TFHE Scheme for Boolean and Integer Arithmetics Over Encrypted Data}},
|
||||
author={Zama},
|
||||
year={2022},
|
||||
note={\url{https://github.com/zama-ai/tfhe-rs}},
|
||||
}
|
||||
```
|
||||
|
||||
## License
|
||||
|
||||
This software is distributed under the BSD-3-Clause-Clear license. If you have any questions,
|
||||
please contact us at `hello@zama.ai`.
|
||||
|
||||
## Disclaimers
|
||||
|
||||
### Security Estimation
|
||||
|
||||
Security estimations are done using the
|
||||
[Lattice Estimator](https://github.com/malb/lattice-estimator)
|
||||
with `red_cost_model = reduction.RC.BDGL16`.
|
||||
|
||||
When a new update is published in the Lattice Estimator, we update parameters accordingly.
|
||||
|
||||
### Side-Channel Attacks
|
||||
|
||||
Mitigation for side channel attacks have not yet been implemented in TFHE-rs,
|
||||
and will be released in upcoming versions.
|
||||
|
||||
69
concrete-float/Cargo.toml
Normal file
69
concrete-float/Cargo.toml
Normal file
@@ -0,0 +1,69 @@
|
||||
[package]
|
||||
name = "concrete-float"
|
||||
version = "0.1.0-beta.0"
|
||||
edition = "2018"
|
||||
authors = ["Zama team"]
|
||||
license = "BSD-3-Clause-Clear"
|
||||
description = "Homomorphic Integer circuit interface for the concrete FHE library."
|
||||
homepage = "https://www.zama.ai/concrete-framework"
|
||||
documentation = "https://docs.zama.ai/home/"
|
||||
repository = "https://github.com/zama-ai/concrete"
|
||||
readme = "README.md"
|
||||
keywords = ["fully", "homomorphic", "encryption", "fhe", "cryptography"]
|
||||
|
||||
[dependencies]
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
aligned-vec = { version = "0.5", features = ["serde"] }
|
||||
dyn-stack = { version = "0.9" }
|
||||
rayon = "1.5"
|
||||
|
||||
lazy_static = { version = "1.4.0", optional = true }
|
||||
|
||||
tfhe = { path = "../tfhe", features = ["shortint", "integer"] }
|
||||
|
||||
[target.'cfg(target_arch = "x86_64")'.dependencies]
|
||||
tfhe = { path = "../tfhe", features = ["shortint", "integer", "x86_64-unix"] }
|
||||
|
||||
[target.'cfg(target_arch = "aarch64")'.dependencies]
|
||||
tfhe = { path = "../tfhe", features = ["shortint", "integer", "aarch64-unix"] }
|
||||
|
||||
[features]
|
||||
nightly-avx512 = ["tfhe/nightly-avx512"]
|
||||
seeder_x86_64_rdseed = []
|
||||
seeder_unix = []
|
||||
generator_x86_64_aesni = []
|
||||
generator_fallback = []
|
||||
generator_aarch64_aes = []
|
||||
|
||||
x86_64 = [
|
||||
"seeder_x86_64_rdseed",
|
||||
"generator_x86_64_aesni",
|
||||
"generator_fallback",
|
||||
]
|
||||
x86_64-unix = ["x86_64", "seeder_unix"]
|
||||
aarch64 = [ "generator_aarch64_aes", "generator_fallback"]
|
||||
aarch64-unix = ["aarch64", "seeder_unix"]
|
||||
|
||||
|
||||
[dev-dependencies]
|
||||
criterion = "0.5.1"
|
||||
lazy_static = "1.4.0"
|
||||
bincode = "1.3.3"
|
||||
paste = "1.0.7"
|
||||
rand = "0.8.4"
|
||||
doc-comment = "0.3.3"
|
||||
#concrete-shortint = { path = "../tfhe", features = ["internal-keycache"] }
|
||||
|
||||
#[features]
|
||||
# Keychache used to speed up tests and benches
|
||||
# by not requiring to regererate keys at each launch
|
||||
#internal-keycache = ["lazy_static", "shortint/src/internal-keycache"]
|
||||
|
||||
[package.metadata.docs.rs]
|
||||
rustdoc-args = ["--html-in-header", "katex-header.html"]
|
||||
|
||||
[[bench]]
|
||||
name = "float-bench"
|
||||
path = "benches/bench.rs"
|
||||
harness = false
|
||||
required-features = []
|
||||
32
concrete-float/LICENSE
Normal file
32
concrete-float/LICENSE
Normal file
@@ -0,0 +1,32 @@
|
||||
BSD 3-Clause Clear License
|
||||
|
||||
Copyright © 2022 ZAMA.
|
||||
All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without modification,
|
||||
are permitted provided that the following conditions are met:
|
||||
|
||||
1. Redistributions of source code must retain the above copyright notice, this
|
||||
list of conditions and the following disclaimer.
|
||||
|
||||
2. Redistributions in binary form must reproduce the above copyright notice, this
|
||||
list of conditions and the following disclaimer in the documentation and/or other
|
||||
materials provided with the distribution.
|
||||
|
||||
3. Neither the name of ZAMA nor the names of its contributors may be used to endorse
|
||||
or promote products derived from this software without specific prior written permission.
|
||||
NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY THIS LICENSE*.
|
||||
THIS SOFTWARE IS PROVIDED BY THE ZAMA AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
|
||||
MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL
|
||||
ZAMA OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY,
|
||||
OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
|
||||
OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
|
||||
ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
|
||||
ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
*In addition to the rights carried by this license, ZAMA grants to the user a non-exclusive,
|
||||
free and non-commercial license on all patents filed in its name relating to the open-source
|
||||
code (the "Patents") for the sole purpose of evaluation, development, research, prototyping
|
||||
and experimentation.
|
||||
11
concrete-float/README.md
Normal file
11
concrete-float/README.md
Normal file
@@ -0,0 +1,11 @@
|
||||
# concrete Integer
|
||||
|
||||
`concrete-integer` is a Rust library built on top of `concrete-shortint`, it
|
||||
combines multiple `shortint` to handle encrypted integers of "arbitrary"
|
||||
size.
|
||||
|
||||
## License
|
||||
|
||||
This software is distributed under the BSD-3-Clause-Clear license. If you have any questions,
|
||||
please contact us at `hello@zama.ai`.
|
||||
|
||||
304
concrete-float/benches/bench.rs
Normal file
304
concrete-float/benches/bench.rs
Normal file
@@ -0,0 +1,304 @@
|
||||
#![allow(dead_code)]
|
||||
|
||||
use concrete_float::gen_keys;
|
||||
use criterion::{criterion_group, criterion_main, Criterion};
|
||||
use rand::Rng;
|
||||
|
||||
// Previous Parameters
|
||||
#[allow(unused_imports)]
|
||||
use concrete_float::parameters::{FINAL_PARAM_16,
|
||||
FINAL_PARAM_2_2_32, FINAL_PARAM_32,
|
||||
FINAL_PARAM_64, FINAL_PARAM_8,
|
||||
FINAL_WOP_PARAM_15, FINAL_WOP_PARAM_16,
|
||||
FINAL_WOP_PARAM_2_2_32, FINAL_WOP_PARAM_32,
|
||||
FINAL_WOP_PARAM_64, FINAL_WOP_PARAM_8,
|
||||
FINAL_PARAM_64_TCHESS, FINAL_PARAM_32_TCHESS,
|
||||
FINAL_WOP_PARAM_64_TCHESS, FINAL_WOP_PARAM_32_TCHESS};
|
||||
|
||||
use concrete_float::parameters::{FINAL_PARAM_16_BIS, FINAL_PARAM_32_BIS,
|
||||
FINAL_PARAM_64_BIS, FINAL_PARAM_8_BIS,
|
||||
FINAL_WOP_PARAM_16_BIS, FINAL_WOP_PARAM_32_BIS,
|
||||
FINAL_WOP_PARAM_64_BIS, FINAL_WOP_PARAM_8_BIS};
|
||||
use tfhe::shortint;
|
||||
|
||||
macro_rules! named_param {
|
||||
($param:ident) => {
|
||||
(stringify!($param), $param)
|
||||
};
|
||||
}
|
||||
|
||||
criterion_main!(float_parallelized, float);
|
||||
|
||||
struct Parameters {
|
||||
pbsparameters: shortint::ClassicPBSParameters,
|
||||
wopbsparameters: shortint::WopbsParameters,
|
||||
len_man: usize,
|
||||
len_exp: usize,
|
||||
}
|
||||
|
||||
//Parameter for a Floating point 64-bits equivalent
|
||||
const PARAM_64: Parameters = Parameters {
|
||||
pbsparameters: FINAL_PARAM_64_BIS,
|
||||
wopbsparameters: FINAL_WOP_PARAM_64_BIS,
|
||||
len_man: 27,
|
||||
len_exp: 5,
|
||||
};
|
||||
|
||||
|
||||
//Parameter for a Floating point 32-bits equivalent
|
||||
const PARAM_32: Parameters = Parameters {
|
||||
pbsparameters: FINAL_PARAM_32_BIS,
|
||||
wopbsparameters: FINAL_WOP_PARAM_32_BIS,
|
||||
len_man: 13,
|
||||
len_exp: 4,
|
||||
};
|
||||
|
||||
|
||||
//Parameter for a Floating point 16-bits equivalent
|
||||
const PARAM_16: Parameters = Parameters {
|
||||
pbsparameters: FINAL_PARAM_16_BIS,
|
||||
wopbsparameters: FINAL_WOP_PARAM_16_BIS,
|
||||
len_man: 6,
|
||||
len_exp: 3,
|
||||
};
|
||||
|
||||
|
||||
//Parameter for a Floating point 8-bits equivalent
|
||||
const PARAM_8: Parameters = Parameters {
|
||||
pbsparameters: FINAL_PARAM_8_BIS,
|
||||
wopbsparameters: FINAL_WOP_PARAM_8_BIS,
|
||||
len_man: 3,
|
||||
len_exp: 2,
|
||||
};
|
||||
|
||||
|
||||
//Parameter for a Floating point 64-bits equivalent
|
||||
//With failure probability smaller than PARAM_64
|
||||
const PARAM_TCHESS_64: Parameters = Parameters {
|
||||
pbsparameters: FINAL_PARAM_64_TCHESS,
|
||||
wopbsparameters: FINAL_WOP_PARAM_64_TCHESS,
|
||||
len_man: 27,
|
||||
len_exp: 5,
|
||||
};
|
||||
|
||||
|
||||
//Parameter for a Floating point 32-bits equivalent
|
||||
//With failure probability smaller than PARAM_32
|
||||
const PARAM_TCHESS_32: Parameters = Parameters {
|
||||
pbsparameters: FINAL_PARAM_32_TCHESS,
|
||||
wopbsparameters: FINAL_WOP_PARAM_32_TCHESS,
|
||||
len_man: 13,
|
||||
len_exp: 4,
|
||||
};
|
||||
|
||||
|
||||
const SERVER_KEY_BENCH_PARAMS: [(&str, Parameters);6] =
|
||||
[
|
||||
named_param!(PARAM_8),
|
||||
named_param!(PARAM_16),
|
||||
named_param!(PARAM_32),
|
||||
named_param!(PARAM_64),
|
||||
named_param!(PARAM_TCHESS_32),
|
||||
named_param!(PARAM_TCHESS_64),
|
||||
];
|
||||
|
||||
criterion_group!(
|
||||
float,
|
||||
add,
|
||||
mul,
|
||||
relu,
|
||||
sigmoid,
|
||||
);
|
||||
|
||||
criterion_group!(
|
||||
float_parallelized,
|
||||
add_parallelized,
|
||||
mul_parallelized,
|
||||
div_parallelized,,
|
||||
);
|
||||
|
||||
|
||||
fn relu(c: &mut Criterion) {
|
||||
let mut bench_group = c.benchmark_group("operation");
|
||||
let mut rng = rand::thread_rng();
|
||||
|
||||
for (param_name, param) in SERVER_KEY_BENCH_PARAMS {
|
||||
let (cks, sks) = gen_keys(
|
||||
param.pbsparameters,
|
||||
param.wopbsparameters,
|
||||
param.len_man,
|
||||
param.len_exp,
|
||||
);
|
||||
|
||||
let msg = rng.gen::<f32>() as f64;
|
||||
let ct = cks.encrypt(msg);
|
||||
|
||||
let bench_id = format!("{}::{}", "Relu", param_name);
|
||||
bench_group.bench_function(&bench_id, |b| {
|
||||
b.iter(|| {
|
||||
sks.relu(&ct);
|
||||
})
|
||||
});
|
||||
}
|
||||
bench_group.finish()
|
||||
}
|
||||
|
||||
fn sigmoid(c: &mut Criterion) {
|
||||
let mut bench_group = c.benchmark_group("operation");
|
||||
let mut rng = rand::thread_rng();
|
||||
|
||||
for (param_name, param) in SERVER_KEY_BENCH_PARAMS {
|
||||
let (cks, sks) = gen_keys(
|
||||
param.pbsparameters,
|
||||
param.wopbsparameters,
|
||||
param.len_man,
|
||||
param.len_exp,
|
||||
);
|
||||
|
||||
let msg = rng.gen::<f32>() as f64;
|
||||
let ct = cks.encrypt(msg);
|
||||
|
||||
let bench_id = format!("{}::{}", "sigmoid", param_name);
|
||||
bench_group.bench_function(&bench_id, |b| {
|
||||
b.iter(|| {
|
||||
sks.sigmoid(&ct);
|
||||
})
|
||||
});
|
||||
}
|
||||
bench_group.finish()
|
||||
}
|
||||
|
||||
fn mul(c: &mut Criterion) {
|
||||
let mut bench_group = c.benchmark_group("operation");
|
||||
let mut rng = rand::thread_rng();
|
||||
|
||||
for (param_name, param) in SERVER_KEY_BENCH_PARAMS {
|
||||
let (cks, sks) = gen_keys(
|
||||
param.pbsparameters,
|
||||
param.wopbsparameters,
|
||||
param.len_man,
|
||||
param.len_exp,
|
||||
);
|
||||
|
||||
let msg = rng.gen::<f32>() as f64;
|
||||
let ct1 = cks.encrypt(msg);
|
||||
let msg = rng.gen::<f32>() as f64;
|
||||
let ct2 = cks.encrypt(msg);
|
||||
|
||||
let bench_id = format!("{}::{}", "mul", param_name);
|
||||
bench_group.bench_function(&bench_id, |b| {
|
||||
b.iter(|| {
|
||||
sks.mul_total(&ct1, &ct2);
|
||||
})
|
||||
});
|
||||
}
|
||||
bench_group.finish()
|
||||
}
|
||||
|
||||
fn mul_parallelized(c: &mut Criterion) {
|
||||
let mut bench_group = c.benchmark_group("operation");
|
||||
let mut rng = rand::thread_rng();
|
||||
|
||||
for (param_name, param) in SERVER_KEY_BENCH_PARAMS {
|
||||
let (cks, sks) = gen_keys(
|
||||
param.pbsparameters,
|
||||
param.wopbsparameters,
|
||||
param.len_man,
|
||||
param.len_exp,
|
||||
);
|
||||
|
||||
let msg = rng.gen::<f32>() as f64;
|
||||
let ct1 = cks.encrypt(msg);
|
||||
let msg = rng.gen::<f32>() as f64;
|
||||
let ct2 = cks.encrypt(msg);
|
||||
|
||||
let bench_id = format!("{}::{}", "mul parallelized", param_name);
|
||||
bench_group.bench_function(&bench_id, |b| {
|
||||
b.iter(|| {
|
||||
sks.mul_total_parallelized(&ct1, &ct2);
|
||||
})
|
||||
});
|
||||
}
|
||||
bench_group.finish()
|
||||
}
|
||||
|
||||
fn div_parallelized(c: &mut Criterion) {
|
||||
let mut bench_group = c.benchmark_group("operation");
|
||||
let mut rng = rand::thread_rng();
|
||||
|
||||
for (param_name, param) in SERVER_KEY_BENCH_PARAMS {
|
||||
let (cks, sks) = gen_keys(
|
||||
param.pbsparameters,
|
||||
param.wopbsparameters,
|
||||
param.len_man,
|
||||
param.len_exp,
|
||||
);
|
||||
|
||||
let msg = rng.gen::<f32>() as f64;
|
||||
let ct1 = cks.encrypt(msg);
|
||||
let msg = rng.gen::<f32>() as f64;
|
||||
let ct2 = cks.encrypt(msg);
|
||||
|
||||
let bench_id = format!("{}::{}", "div parallelized", param_name);
|
||||
bench_group.bench_function(&bench_id, |b| {
|
||||
b.iter(|| {
|
||||
sks.division(&ct1, &ct2);
|
||||
})
|
||||
});
|
||||
}
|
||||
bench_group.finish()
|
||||
}
|
||||
|
||||
fn add(c: &mut Criterion) {
|
||||
let mut bench_group = c.benchmark_group("operation");
|
||||
let mut rng = rand::thread_rng();
|
||||
|
||||
for (param_name, param) in SERVER_KEY_BENCH_PARAMS {
|
||||
let (cks, sks) = gen_keys(
|
||||
param.pbsparameters,
|
||||
param.wopbsparameters,
|
||||
param.len_man,
|
||||
param.len_exp,
|
||||
);
|
||||
|
||||
let msg = rng.gen::<f32>() as f64;
|
||||
let ct1 = cks.encrypt(msg);
|
||||
let msg = rng.gen::<f32>() as f64;
|
||||
let ct2 = cks.encrypt(msg);
|
||||
|
||||
let bench_id = format!("{}::{}", "add", param_name);
|
||||
bench_group.bench_function(&bench_id, |b| {
|
||||
b.iter(|| {
|
||||
sks.add_total(&ct1, &ct2);
|
||||
})
|
||||
});
|
||||
}
|
||||
bench_group.finish()
|
||||
}
|
||||
|
||||
fn add_parallelized(c: &mut Criterion) {
|
||||
let mut bench_group = c.benchmark_group("operation");
|
||||
let mut rng = rand::thread_rng();
|
||||
|
||||
for (param_name, param) in SERVER_KEY_BENCH_PARAMS {
|
||||
let (cks, sks) = gen_keys(
|
||||
param.pbsparameters,
|
||||
param.wopbsparameters,
|
||||
param.len_man,
|
||||
param.len_exp,
|
||||
);
|
||||
|
||||
let msg = rng.gen::<f32>() as f64;
|
||||
let ct1 = cks.encrypt(msg);
|
||||
let msg = rng.gen::<f32>() as f64;
|
||||
let ct2 = cks.encrypt(msg);
|
||||
|
||||
let bench_id = format!("{}::{}", "add parallelized", param_name);
|
||||
bench_group.bench_function(&bench_id, |b| {
|
||||
b.iter(|| {
|
||||
sks.add_total_parallelized(&ct1, &ct2);
|
||||
})
|
||||
});
|
||||
}
|
||||
bench_group.finish()
|
||||
}
|
||||
20
concrete-float/docs/SUMMARY.md
Normal file
20
concrete-float/docs/SUMMARY.md
Normal file
@@ -0,0 +1,20 @@
|
||||
# Concrete-Integer User Guide
|
||||
|
||||
[Introduction](introduction.md)
|
||||
|
||||
# Getting Started
|
||||
|
||||
[Installation](getting_started/installation.md)
|
||||
|
||||
[Writing Your First Circuit](getting_started/first_circuit.md)
|
||||
|
||||
[Types Of Operations](getting_started/operation_types.md)
|
||||
|
||||
[List of Operations](getting_started/operation_list.md)
|
||||
|
||||
[Cryptographic Parameters](getting_started/parameters.md)
|
||||
|
||||
|
||||
# How to
|
||||
|
||||
[Serialization / Deserialization](tutorials/serialization.md)
|
||||
105
concrete-float/docs/getting_started/first_circuit.md
Normal file
105
concrete-float/docs/getting_started/first_circuit.md
Normal file
@@ -0,0 +1,105 @@
|
||||
# Writing Your First Circuit
|
||||
|
||||
|
||||
## Key Types
|
||||
|
||||
`concrete-integer` provides 2 basic key types:
|
||||
- `ClientKey`
|
||||
- `ServerKey`
|
||||
|
||||
The `ClientKey` is the key that encrypts and decrypts messages,
|
||||
thus this key is meant to be kept private and should never be shared.
|
||||
This key is created from parameter values that will dictate both the security and efficiency
|
||||
of computations. The parameters also set the maximum number of bits of message encrypted
|
||||
in a ciphertext.
|
||||
|
||||
The `ServerKey` is the key that is used to actually do the FHE computations. It contains (among other things)
|
||||
a bootstrapping key and a keyswitching key.
|
||||
This key is created from a `ClientKey` that needs to be shared to the server, therefore it is not
|
||||
meant to be kept private.
|
||||
A user with a `ServerKey` can compute on the encrypted data sent by the owner of the associated
|
||||
`ClientKey`.
|
||||
|
||||
To reflect that, computation/operation methods are tied to the `ServerKey` type.
|
||||
|
||||
|
||||
## 1. Key Generation
|
||||
|
||||
To generate the keys, a user needs two parameters:
|
||||
- A set of `shortint` cryptographic parameters.
|
||||
- The number of ciphertexts used to encrypt an integer (we call them "shortint blocks").
|
||||
|
||||
|
||||
For this example we are going to build a pair of keys that can encrypt an **8-bit** integer
|
||||
by using **4** shortint blocks that store **2** bits of message each.
|
||||
|
||||
|
||||
```rust
|
||||
use concrete_integer::gen_keys;
|
||||
use concrete_shortint::parameters::PARAM_MESSAGE_2_CARRY_2;
|
||||
|
||||
fn main() {
|
||||
// We generate a set of client/server keys, using the default parameters:
|
||||
let num_block = 4;
|
||||
let (client_key, server_key) = gen_keys(&PARAM_MESSAGE_2_CARRY_2, num_block);
|
||||
}
|
||||
```
|
||||
|
||||
|
||||
|
||||
## 2. Encrypting values
|
||||
|
||||
|
||||
Once we have our keys we can encrypt values:
|
||||
|
||||
```rust
|
||||
use concrete_integer::gen_keys;
|
||||
use concrete_shortint::parameters::PARAM_MESSAGE_2_CARRY_2;
|
||||
|
||||
fn main() {
|
||||
// We generate a set of client/server keys, using the default parameters:
|
||||
let num_block = 4;
|
||||
let (client_key, server_key) = gen_keys(&PARAM_MESSAGE_2_CARRY_2, num_block);
|
||||
|
||||
let msg1 = 128;
|
||||
let msg2 = 13;
|
||||
|
||||
// We use the client key to encrypt two messages:
|
||||
let ct_1 = client_key.encrypt(msg1);
|
||||
let ct_2 = client_key.encrypt(msg2);
|
||||
}
|
||||
```
|
||||
|
||||
## 3. Computing and decrypting
|
||||
|
||||
With our `server_key`, and encrypted values, we can now do an addition
|
||||
and then decrypt the result.
|
||||
|
||||
```rust
|
||||
use concrete_integer::gen_keys;
|
||||
use concrete_shortint::parameters::PARAM_MESSAGE_2_CARRY_2;
|
||||
|
||||
fn main() {
|
||||
// We generate a set of client/server keys, using the default parameters:
|
||||
let num_block = 4;
|
||||
let (client_key, server_key) = gen_keys(&PARAM_MESSAGE_2_CARRY_2, num_block);
|
||||
|
||||
let msg1 = 128;
|
||||
let msg2 = 13;
|
||||
|
||||
// message_modulus^vec_length
|
||||
let modulus = client_key.parameters().message_modulus.0.pow(num_block as u32) as u64;
|
||||
|
||||
// We use the client key to encrypt two messages:
|
||||
let ct_1 = client_key.encrypt(msg1);
|
||||
let ct_2 = client_key.encrypt(msg2);
|
||||
|
||||
// We use the server public key to execute an integer circuit:
|
||||
let ct_3 = server_key.unchecked_add(&ct_1, &ct_2);
|
||||
|
||||
// We use the client key to decrypt the output of the circuit:
|
||||
let output = client_key.decrypt(&ct_3);
|
||||
|
||||
assert_eq!(output, (msg1 + msg2) % modulus);
|
||||
}
|
||||
```
|
||||
49
concrete-float/docs/getting_started/installation.md
Normal file
49
concrete-float/docs/getting_started/installation.md
Normal file
@@ -0,0 +1,49 @@
|
||||
# Installation
|
||||
|
||||
## Cargo.toml
|
||||
|
||||
To use `concrete-integer`, you will need to add it to the list of dependencies
|
||||
of your project, by updating your `Cargo.toml` file.
|
||||
|
||||
```toml
|
||||
concrete-integer = "0.1.0"
|
||||
```
|
||||
|
||||
### Supported platforms
|
||||
|
||||
|
||||
As `concrete-integer` relies on `concrete-shortint`, which in turn relies on `concrete-core`,
|
||||
the support ted platforms supported are:
|
||||
- `x86_64 Linux`
|
||||
- `x86_64 macOS`.
|
||||
|
||||
Windows users can use `concrete-integer` through the `WSL`.
|
||||
|
||||
macOS users which have the newer M1 (`arm64`) devices can use `concrete-integer` by cross-compiling to
|
||||
`x86_64` and run their program with Rosetta.
|
||||
|
||||
First install the needed Rust toolchain:
|
||||
|
||||
```console
|
||||
# Install the macOS x86_64 toolchain (you only need to do this once)
|
||||
rustup toolchain install --force-non-host stable-x86_64-apple-darwin
|
||||
```
|
||||
|
||||
Then you can either:
|
||||
|
||||
- Manually specify the toolchain to use in each of the cargo commands:
|
||||
|
||||
For example:
|
||||
|
||||
```console
|
||||
cargo +stable-x86_64-apple-darwin build
|
||||
cargo +stable-x86_64-apple-darwin test
|
||||
```
|
||||
|
||||
- Or override the toolchain to use for the current project:
|
||||
|
||||
```console
|
||||
rustup override set stable-x86_64-apple-darwin
|
||||
# cargo will use the `stable-x86_64-apple-darwin` toolchain.
|
||||
cargo build
|
||||
```
|
||||
15
concrete-float/docs/getting_started/operation_list.md
Normal file
15
concrete-float/docs/getting_started/operation_list.md
Normal file
@@ -0,0 +1,15 @@
|
||||
# List of available operations
|
||||
|
||||
`concrete-integer` comes with a set of already implemented functions:
|
||||
|
||||
|
||||
- addition between two ciphertexts
|
||||
- addition between a ciphertext and an unencrypted scalar
|
||||
- multiplication of a ciphertext by an unencrypted scalar
|
||||
- bitwise shift `<<`, `>>`
|
||||
- bitwise and, or and xor
|
||||
- multiplication between two ciphertexts
|
||||
- subtraction of a ciphertext by another ciphertext
|
||||
- subtraction of a ciphertext by an unencrypted scalar
|
||||
- negation of a ciphertext
|
||||
|
||||
86
concrete-float/docs/getting_started/operation_types.md
Normal file
86
concrete-float/docs/getting_started/operation_types.md
Normal file
@@ -0,0 +1,86 @@
|
||||
# How Integers are represented
|
||||
|
||||
|
||||
In `concrete-integer`, the encrypted data is split amongst many ciphertexts
|
||||
encrypted using the `concrete-shortint` library.
|
||||
|
||||
This crate implements two ways to represent an integer:
|
||||
- the Radix representation
|
||||
- the CRT (Chinese Reminder Theorem) representation
|
||||
|
||||
## Radix based Integers
|
||||
The first possibility to represent a large integer is to use a radix-based decomposition on the
|
||||
plaintexts. Let $$B \in \mathbb{N}$$ be a basis such that the size of $$B$$ is smaller (or equal)
|
||||
to four bits.
|
||||
Then, an integer $$m \in \mathbb{N}$$ can be written as $$m = m_0 + m_1*B + m_2*B^2 + ... $$, where
|
||||
each $$m_i$$ is strictly smaller than $$B$$. Each $$m_i$$ is then independently encrypted. In
|
||||
the end, an Integer ciphertext is defined as a set of Shortint ciphertexts.
|
||||
|
||||
In practice, the definition of an Integer requires the basis and the number of blocks. This is
|
||||
done at the key creation step.
|
||||
```rust
|
||||
use concrete_integer::gen_keys;
|
||||
use concrete_shortint::parameters::PARAM_MESSAGE_2_CARRY_2;
|
||||
|
||||
fn main() {
|
||||
// We generate a set of client/server keys, using the default parameters:
|
||||
let num_block = 4;
|
||||
let (client_key, server_key) = gen_keys(&PARAM_MESSAGE_2_CARRY_2, num_block);
|
||||
}
|
||||
```
|
||||
|
||||
In this example, the keys are dedicated to Integers decomposed as four blocks using the basis
|
||||
$$B=2^2$$. Otherwise said, they allow to work on Integers modulus $$(2^2)^4 = 2^8$$.
|
||||
|
||||
|
||||
In this representation, the correctness of operations requires to propagate the carries
|
||||
between the ciphertext. This operation is costly since it relies on the computation of many
|
||||
programmable bootstrapping over Shortints.
|
||||
|
||||
|
||||
## CRT based Integers
|
||||
The second approach to represent large integers is based on the Chinese Remainder Theorem.
|
||||
In this cases, the basis $$B$$ is composed of several integers $$b_i$$, such that there are
|
||||
pairwise coprime, and each b_i has a size smaller than four bits. Then, the Integer will be
|
||||
defined modulus $$\prod b_i$$. For an integer $$m$$, its CRT decomposition is simply defined as
|
||||
$$m % b_0, m % b_1, ...$$. Each part is then encrypted as a Shortint ciphertext. In
|
||||
the end, an Integer ciphertext is defined as a set of Shortint ciphertexts.
|
||||
|
||||
An example of such a basis
|
||||
could be $$B = [2, 3, 5]$$. This means that the Integer is defined modulus $$2*3*5 = 30$$.
|
||||
|
||||
This representation has many advantages: no carry propagation is required, so that only cleaning
|
||||
the carry buffer of each ciphertexts is enough. This implies that operations can easily be
|
||||
parallelized. Moreover, it allows to efficiently compute PBS in the case where the function is
|
||||
CRT compliant.
|
||||
|
||||
A variant of the CRT is proposed, where each block might be associated to a different key couple.
|
||||
In the end, a keychain is required to the computations, but performance might be improved.
|
||||
|
||||
|
||||
|
||||
# Types of operations
|
||||
|
||||
|
||||
Much like `concrete-shortint`, the operations available via a `ServerKey` may come in different variants:
|
||||
|
||||
- operations that take their inputs as encrypted values.
|
||||
- scalar operations take at least one non-encrypted value as input.
|
||||
|
||||
For example, the addition has both variants:
|
||||
|
||||
- `ServerKey::unchecked_add` which takes two encrypted values and adds them.
|
||||
- `ServerKey::unchecked_scalar_add` which takes an encrypted value and a clear value (the
|
||||
so-called scalar) and adds them.
|
||||
|
||||
Each operation may come in different 'flavors':
|
||||
|
||||
- `unchecked`: Always does the operation, without checking if the result may exceed the capacity of
|
||||
the plaintext space.
|
||||
- `checked`: Checks are done before computing the operation, returning an error if operation
|
||||
cannot be done safely.
|
||||
- `smart`: Always does the operation, if the operation cannot be computed safely, the smart operation
|
||||
will propagate the carry buffer to make the operation possible.
|
||||
|
||||
Not all operations have these 3 flavors, as some of them are implemented in a way that the operation
|
||||
is always possible without ever exceeding the plaintext space capacity.
|
||||
6
concrete-float/docs/getting_started/parameters.md
Normal file
6
concrete-float/docs/getting_started/parameters.md
Normal file
@@ -0,0 +1,6 @@
|
||||
# Use of parameters
|
||||
|
||||
|
||||
`concrete-integer` does not come with its own set of parameters, instead it uses
|
||||
parameters from the `concrete-shortint` crate. Currently, only the parameters
|
||||
`PARAM_MESSAGE_{X}_CARRY_{X}` with `X` in [1,4] can be used in `concrete-integer`.
|
||||
47
concrete-float/docs/how_to/pbs.md
Normal file
47
concrete-float/docs/how_to/pbs.md
Normal file
@@ -0,0 +1,47 @@
|
||||
# The tree programmable bootstrapping
|
||||
|
||||
In `concrete-integer`, the user can evaluate any function on an encrypted ciphertext. To do so the user must first
|
||||
create a `treepbs key`, choose a function to evaluate and give them as parameters to the `tree programmable bootstrapping`.
|
||||
|
||||
Two versions of the tree pbs are implemented: the `standard` version that computes a result according to every encrypted
|
||||
bit (message and carry), and the `base` version that only takes into account the message bits of each block.
|
||||
|
||||
{% hint style="warning" %}
|
||||
|
||||
The `tree pbs` is quite slow, therefore its use is currently restricted to two and three blocks integer ciphertexts.
|
||||
|
||||
{% endhint %}
|
||||
|
||||
```rust
|
||||
use concrete_integer::gen_keys;
|
||||
use concrete_shortint::parameters::PARAM_MESSAGE_2_CARRY_2;
|
||||
use concrete_integer::treepbs::TreepbsKey;
|
||||
|
||||
fn main() {
|
||||
let num_block = 2;
|
||||
// Generate the client key and the server key:
|
||||
let (cks, sks) = gen_keys(&PARAM_MESSAGE_2_CARRY_2, num_block);
|
||||
|
||||
let msg: u64 = 27;
|
||||
let ct = cks.encrypt(msg);
|
||||
|
||||
// message_modulus^vec_length
|
||||
let modulus = cks.parameters().message_modulus.0.pow(2 as u32) as u64;
|
||||
|
||||
let treepbs_key = TreepbsKey::new(&cks);
|
||||
|
||||
let f = |x: u64| x * x;
|
||||
|
||||
// evaluate f
|
||||
let vec_res = treepbs_key.two_block_pbs(&sks, &ct, f);
|
||||
|
||||
// decryption
|
||||
let res = cks.decrypt(&vec_res);
|
||||
|
||||
let clear = f(msg) % modulus;
|
||||
assert_eq!(res, clear);
|
||||
}
|
||||
```
|
||||
|
||||
# The WOP programmable bootstrapping
|
||||
|
||||
8
concrete-float/docs/introduction.md
Normal file
8
concrete-float/docs/introduction.md
Normal file
@@ -0,0 +1,8 @@
|
||||
# Concrete-integer
|
||||
|
||||
## Introduction
|
||||
|
||||
`concrete-integer` is a Rust library (crate) based on `concrete-shortint`, this crate provides
|
||||
large precision integers by using multiple `shortint` ciphertexts.
|
||||
|
||||
The intended target audience for this library is people who are somewhat familiar with cryptography.
|
||||
120
concrete-float/docs/tutorials/circuit_evaluation.md
Normal file
120
concrete-float/docs/tutorials/circuit_evaluation.md
Normal file
@@ -0,0 +1,120 @@
|
||||
# Circuit evaluation
|
||||
|
||||
Let's try to do a circuit evaluation using the different flavours of operations we already introduced.
|
||||
For a very small circuit, the `unchecked` flavour may be enough to do the computation correctly.
|
||||
Otherwise, the `checked` and `smart` are the best options.
|
||||
|
||||
As an example, let's do a scalar multiplication, a subtraction and an addition.
|
||||
|
||||
|
||||
```rust
|
||||
use concrete_integer::gen_keys;
|
||||
use concrete_shortint::parameters::PARAM_MESSAGE_2_CARRY_2;
|
||||
|
||||
fn main() {
|
||||
let num_block = 4;
|
||||
let (client_key, server_key) = gen_keys(&PARAM_MESSAGE_2_CARRY_2, num_block);
|
||||
|
||||
let msg1 = 12;
|
||||
let msg2 = 11;
|
||||
let msg3 = 9;
|
||||
let scalar = 3;
|
||||
|
||||
// message_modulus^vec_length
|
||||
let modulus = client_key.parameters().message_modulus.0.pow(num_block as u32) as u64;
|
||||
|
||||
// We use the client key to encrypt two messages:
|
||||
let mut ct_1 = client_key.encrypt(msg1);
|
||||
let ct_2 = client_key.encrypt(msg2);
|
||||
let ct_3 = client_key.encrypt(msg2);
|
||||
|
||||
server_key.unchecked_small_scalar_mul_assign(&mut ct_1, scalar);
|
||||
|
||||
server_key.unchecked_sub_assign(&mut ct_1, &ct_2);
|
||||
|
||||
server_key.unchecked_add_assign(&mut ct_1, &ct_3);
|
||||
|
||||
// We use the client key to decrypt the output of the circuit:
|
||||
let output = client_key.decrypt(&ct_1);
|
||||
// The carry buffer has been overflowed, the result is not correct
|
||||
assert_ne!(output, ((msg1 * scalar as u64 - msg2) + msg3) % modulus as u64);
|
||||
}
|
||||
```
|
||||
|
||||
During this computation the carry buffer has been overflowed and as all the operations were `unchecked` the output
|
||||
may be incorrect.
|
||||
|
||||
If we redo this same circuit but using the `checked` flavour, a panic will occur.
|
||||
|
||||
```rust
|
||||
use concrete_integer::gen_keys;
|
||||
use concrete_shortint::parameters::PARAM_MESSAGE_2_CARRY_2;
|
||||
|
||||
fn main() {
|
||||
let num_block = 2;
|
||||
let (client_key, server_key) = gen_keys(&PARAM_MESSAGE_2_CARRY_2, num_block);
|
||||
|
||||
let msg1 = 12;
|
||||
let msg2 = 11;
|
||||
let msg3 = 9;
|
||||
let scalar = 3;
|
||||
|
||||
// message_modulus^vec_length
|
||||
let modulus = client_key.parameters().message_modulus.0.pow(num_block as u32) as u64;
|
||||
|
||||
// We use the client key to encrypt two messages:
|
||||
let mut ct_1 = client_key.encrypt(msg1);
|
||||
let ct_2 = client_key.encrypt(msg2);
|
||||
let ct_3 = client_key.encrypt(msg3);
|
||||
|
||||
let result = server_key.checked_small_scalar_mul_assign(&mut ct_1, scalar);
|
||||
assert!(result.is_ok());
|
||||
|
||||
let result = server_key.checked_sub_assign(&mut ct_1, &ct_2);
|
||||
assert!(result.is_err());
|
||||
|
||||
// We use the client key to decrypt the output of the circuit:
|
||||
// Only the scalar multiplication could be done
|
||||
let output = client_key.decrypt(&ct_1);
|
||||
assert_eq!(output, (msg1 * scalar) % modulus as u64);
|
||||
}
|
||||
```
|
||||
|
||||
Therefore the `checked` flavour permits to manually manage the overflow of the carry buffer
|
||||
by raising an error if the correctness is not guaranteed.
|
||||
|
||||
Lastly, using the `smart` flavour will output the correct result all the time. However, the computation may be slower
|
||||
as the carry buffer may be propagated during the computations.
|
||||
|
||||
```rust
|
||||
use concrete_integer::gen_keys;
|
||||
use concrete_shortint::parameters::PARAM_MESSAGE_2_CARRY_2;
|
||||
|
||||
fn main() {
|
||||
let num_block = 4;
|
||||
let (client_key, server_key) = gen_keys(&PARAM_MESSAGE_2_CARRY_2, num_block);
|
||||
|
||||
let msg1 = 12;
|
||||
let msg2 = 11;
|
||||
let msg3 = 9;
|
||||
let scalar = 3;
|
||||
|
||||
// message_modulus^vec_length
|
||||
let modulus = client_key.parameters().message_modulus.0.pow(num_block as u32) as u64;
|
||||
|
||||
// We use the client key to encrypt two messages:
|
||||
let mut ct_1 = client_key.encrypt(msg1);
|
||||
let mut ct_2 = client_key.encrypt(msg2);
|
||||
let mut ct_3 = client_key.encrypt(msg3);
|
||||
|
||||
server_key.smart_scalar_mul_assign(&mut ct_1, scalar);
|
||||
|
||||
server_key.smart_sub_assign(&mut ct_1, &mut ct_2);
|
||||
|
||||
server_key.smart_add_assign(&mut ct_1, &mut ct_3);
|
||||
|
||||
// We use the client key to decrypt the output of the circuit:
|
||||
let output = client_key.decrypt(&ct_1);
|
||||
assert_eq!(output, ((msg1 * scalar as u64 - msg2) + msg3) % modulus as u64);
|
||||
}
|
||||
```
|
||||
78
concrete-float/docs/tutorials/serialization.md
Normal file
78
concrete-float/docs/tutorials/serialization.md
Normal file
@@ -0,0 +1,78 @@
|
||||
# Serialization / Deserialization
|
||||
|
||||
As explained in the introduction, some types (`Serverkey`, `Ciphertext`) are meant to be shared
|
||||
with the server that does the computations.
|
||||
|
||||
The easiest way to send these data to a server is to use the serialization and deserialization features.
|
||||
concrete-integer uses the serde framework, serde's Serialize and Deserialize are implemented.
|
||||
|
||||
To be able to serialize our data, we need to pick a [data format], for our use case,
|
||||
[bincode] is a good choice, mainly because it is binary format.
|
||||
|
||||
|
||||
```toml
|
||||
# Cargo.toml
|
||||
|
||||
[dependencies]
|
||||
# ...
|
||||
bincode = "1.3.3"
|
||||
```
|
||||
|
||||
|
||||
```rust
|
||||
// main.rs
|
||||
|
||||
use bincode;
|
||||
|
||||
use std::io::Cursor;
|
||||
use concrete_integer::{gen_keys, ServerKey, Ciphertext};
|
||||
use concrete_shortint::parameters::PARAM_MESSAGE_2_CARRY_2;
|
||||
|
||||
|
||||
fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
// We generate a set of client/server keys, using the default parameters:
|
||||
let num_block = 4;
|
||||
let (client_key, server_key) = gen_keys(&PARAM_MESSAGE_2_CARRY_2, num_block);
|
||||
|
||||
let msg1 = 201;
|
||||
let msg2 = 12;
|
||||
|
||||
// message_modulus^vec_length
|
||||
let modulus = client_key.parameters().message_modulus.0.pow(num_block as u32) as u64;
|
||||
|
||||
let ct_1 = client_key.encrypt(msg1);
|
||||
let ct_2 = client_key.encrypt(msg2);
|
||||
|
||||
let mut serialized_data = Vec::new();
|
||||
bincode::serialize_into(&mut serialized_data, &server_key)?;
|
||||
bincode::serialize_into(&mut serialized_data, &ct_1)?;
|
||||
bincode::serialize_into(&mut serialized_data, &ct_2)?;
|
||||
|
||||
// Simulate sending serialized data to a server and getting
|
||||
// back the serialized result
|
||||
let serialized_result = server_function(&serialized_data)?;
|
||||
let result: Ciphertext = bincode::deserialize(&serialized_result)?;
|
||||
|
||||
let output = client_key.decrypt(&result);
|
||||
assert_eq!(output, (msg1 + msg2) % modulus);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
||||
fn server_function(serialized_data: &[u8]) -> Result<Vec<u8>, Box<dyn std::error::Error>> {
|
||||
let mut serialized_data = Cursor::new(serialized_data);
|
||||
let server_key: ServerKey = bincode::deserialize_from(&mut serialized_data)?;
|
||||
let ct_1: Ciphertext = bincode::deserialize_from(&mut serialized_data)?;
|
||||
let ct_2: Ciphertext = bincode::deserialize_from(&mut serialized_data)?;
|
||||
|
||||
let result = server_key.unchecked_add(&ct_1, &ct_2);
|
||||
|
||||
let serialized_result = bincode::serialize(&result)?;
|
||||
|
||||
Ok(serialized_result)
|
||||
}
|
||||
```
|
||||
|
||||
[serde]: https://crates.io/crates/serde
|
||||
[data format]: https://serde.rs/#data-formats
|
||||
[bincode]: https://crates.io/crates/bincode
|
||||
15
concrete-float/katex-header.html
Normal file
15
concrete-float/katex-header.html
Normal file
@@ -0,0 +1,15 @@
|
||||
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/katex@0.10.0/dist/katex.min.css" integrity="sha384-9eLZqc9ds8eNjO3TmqPeYcDj8n+Qfa4nuSiGYa6DjLNcv9BtN69ZIulL9+8CqC9Y" crossorigin="anonymous">
|
||||
<script src="https://cdn.jsdelivr.net/npm/katex@0.10.0/dist/katex.min.js" integrity="sha384-K3vbOmF2BtaVai+Qk37uypf7VrgBubhQreNQe9aGsz9lB63dIFiQVlJbr92dw2Lx" crossorigin="anonymous"></script>
|
||||
<script src="https://cdn.jsdelivr.net/npm/katex@0.10.0/dist/contrib/auto-render.min.js" integrity="sha384-kmZOZB5ObwgQnS/DuDg6TScgOiWWBiVt0plIRkZCmE6rDZGrEOQeHM5PcHi+nyqe" crossorigin="anonymous"></script>
|
||||
<script>
|
||||
document.addEventListener("DOMContentLoaded", function() {
|
||||
renderMathInElement(document.body, {
|
||||
delimiters: [
|
||||
{left: "$$", right: "$$", display: true},
|
||||
{left: "\\(", right: "\\)", display: false},
|
||||
{left: "$", right: "$", display: false},
|
||||
{left: "\\[", right: "\\]", display: true}
|
||||
]
|
||||
});
|
||||
});
|
||||
</script>
|
||||
BIN
concrete-float/long_run
Normal file
BIN
concrete-float/long_run
Normal file
Binary file not shown.
30
concrete-float/src/ciphertext/mod.rs
Normal file
30
concrete-float/src/ciphertext/mod.rs
Normal file
@@ -0,0 +1,30 @@
|
||||
//! This module implements the ciphertext structure containing an encryption of an integer message.
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tfhe::shortint;
|
||||
|
||||
/// Id to recognize the key used to encrypt a block.
|
||||
#[derive(Debug, PartialEq, Eq, Copy, Clone, Serialize, Deserialize)]
|
||||
pub struct KeyId(pub usize);
|
||||
|
||||
#[derive(Serialize, Clone, Deserialize, PartialEq, Eq, Debug)]
|
||||
pub struct Ciphertext {
|
||||
pub ct_vec_mantissa: Vec<shortint::ciphertext::Ciphertext>,
|
||||
pub ct_vec_exponent: Vec<shortint::ciphertext::Ciphertext>,
|
||||
pub ct_sign: shortint::ciphertext::Ciphertext,
|
||||
pub(crate) e_min: i64,
|
||||
}
|
||||
impl Ciphertext {
|
||||
/// Returns the slice of blocks that the ciphertext is composed of.
|
||||
pub fn mantissa_blocks(&self) -> &[shortint::Ciphertext] {
|
||||
&self.ct_vec_mantissa
|
||||
}
|
||||
pub fn exponent_blocks(&self) -> &[shortint::Ciphertext] {
|
||||
&self.ct_vec_exponent
|
||||
}
|
||||
pub fn sign(&self) -> &shortint::Ciphertext {
|
||||
&self.ct_sign
|
||||
}
|
||||
pub fn e_min(&self) -> &i64 {
|
||||
&self.e_min
|
||||
}
|
||||
}
|
||||
265
concrete-float/src/client_key/mod.rs
Normal file
265
concrete-float/src/client_key/mod.rs
Normal file
@@ -0,0 +1,265 @@
|
||||
//! This module implements the generation of the client secret keys, together with the
|
||||
//! encryption and decryption methods.
|
||||
|
||||
pub(crate) mod utils;
|
||||
|
||||
use crate::ciphertext::Ciphertext;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tfhe::shortint;
|
||||
use tfhe::shortint::{ClassicPBSParameters, WopbsParameters};
|
||||
pub use utils::radix_decomposition;
|
||||
|
||||
/// The number of ciphertexts in the vector.
|
||||
#[derive(Debug, PartialEq, Eq, Copy, Clone, Serialize, Deserialize)]
|
||||
pub struct VecLength(pub usize);
|
||||
|
||||
/// A structure containing the client key, which must be kept secret.
|
||||
#[derive(Serialize, Deserialize, PartialEq, Debug, Clone)]
|
||||
pub struct ClientKey {
|
||||
pub(crate) key: shortint::client_key::ClientKey,
|
||||
pub(crate) vector_length_mantissa: VecLength,
|
||||
pub(crate) vector_length_exponent: VecLength,
|
||||
}
|
||||
|
||||
impl ClientKey {
|
||||
/// Allocates and generates a client key.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// use concrete_float::client_key::ClientKey;
|
||||
/// use concrete_float::parameters::{PARAM_MESSAGE_2_CARRY_2_32, WOP_PARAM_MESSAGE_2_CARRY_2_32};
|
||||
/// use concrete_shortint::parameters::PARAM_MESSAGE_2_CARRY_2;
|
||||
///
|
||||
/// // Generate the client key associated to integers over 4 blocks
|
||||
/// // of messages with modulus over 2 bits
|
||||
/// let param = (PARAM_MESSAGE_2_CARRY_2_32, WOP_PARAM_MESSAGE_2_CARRY_2_32);
|
||||
/// let cks = ClientKey::new(param, 4, 1);
|
||||
/// ```
|
||||
pub fn new(
|
||||
parameter_set: (ClassicPBSParameters, WopbsParameters),
|
||||
size_mantissa: usize,
|
||||
size_exponent: usize,
|
||||
) -> Self {
|
||||
let key = shortint::ClientKey::new(parameter_set);
|
||||
Self {
|
||||
key,
|
||||
vector_length_mantissa: VecLength(size_mantissa),
|
||||
vector_length_exponent: VecLength(size_exponent),
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the parameters used by the client key.
|
||||
pub fn parameters(&self) -> shortint::parameters::ShortintParameterSet {
|
||||
self.key.parameters
|
||||
}
|
||||
|
||||
/// Encrypts a float message using the client key.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// use concrete_float::client_key::ClientKey;
|
||||
/// use concrete_float::parameters::{PARAM_MESSAGE_2_CARRY_2_32, WOP_PARAM_MESSAGE_2_CARRY_2_32};
|
||||
///
|
||||
/// let param = (PARAM_MESSAGE_2_CARRY_2_32, WOP_PARAM_MESSAGE_2_CARRY_2_32);
|
||||
/// let mut cks = ClientKey::new(param, 3, 1);
|
||||
///
|
||||
/// let msg = 1844640.;
|
||||
/// // Encryption of one message:
|
||||
/// let ct = cks.encrypt(msg);
|
||||
/// let res = cks.decrypt(&ct);
|
||||
///
|
||||
/// //approximation less than 0.1%
|
||||
/// assert_eq!(res, msg)
|
||||
/// ```
|
||||
pub fn encrypt(&self, message: f64) -> Ciphertext {
|
||||
let ct_sign = self.encrypt_sign(message);
|
||||
|
||||
let log_msg_modulus = f64::log2(self.parameters().message_modulus().0 as f64) as usize;
|
||||
let e_min = -((1 << (self.vector_length_exponent.0 * log_msg_modulus - 1)) as i64)
|
||||
- (self.vector_length_mantissa.0 as i64 - 1);
|
||||
if message == 0. {
|
||||
let exponent = 0;
|
||||
let mantissa = 0.0;
|
||||
let ct_vec_mantissa = self.encrypt_mantissa(mantissa as u64);
|
||||
let ct_vec_exponent = self.encrypt_exponent(exponent as u64);
|
||||
Ciphertext {
|
||||
ct_vec_mantissa,
|
||||
ct_vec_exponent,
|
||||
ct_sign,
|
||||
e_min,
|
||||
}
|
||||
} else {
|
||||
let length_mantissa = self.vector_length_mantissa.0;
|
||||
let log_message_modulus =
|
||||
f64::log2(self.parameters().message_modulus().0 as f64) as usize;
|
||||
|
||||
let value_exponent = log_message_modulus as u64;
|
||||
let mut exponent = e_min.abs();
|
||||
let mut cpy_message = message.abs();
|
||||
while cpy_message < (1_u128 << (length_mantissa * log_message_modulus)) as f64 {
|
||||
cpy_message *= (1 << value_exponent) as f64;
|
||||
exponent -= 1;
|
||||
}
|
||||
while cpy_message >= (1_u128 << (length_mantissa * log_message_modulus)) as f64 {
|
||||
cpy_message /= (1 << value_exponent) as f64;
|
||||
exponent += 1;
|
||||
}
|
||||
//TODO
|
||||
if exponent >= (1 << (log_message_modulus * self.vector_length_exponent.0) as i64) {
|
||||
println!("encrypt overflow");
|
||||
}
|
||||
if exponent < 0 {
|
||||
for _ in 0..exponent.abs() {
|
||||
cpy_message /= (1 << value_exponent) as f64;
|
||||
}
|
||||
exponent = 0;
|
||||
//panic!()
|
||||
}
|
||||
let mantissa = cpy_message.round() as u64;
|
||||
let ct_vec_mantissa = self.encrypt_mantissa(mantissa);
|
||||
let ct_vec_exponent = self.encrypt_exponent(exponent as u64);
|
||||
Ciphertext {
|
||||
ct_vec_mantissa,
|
||||
ct_vec_exponent,
|
||||
ct_sign,
|
||||
e_min,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn encrypt_sign(&self, message: f64) -> shortint::ciphertext::Ciphertext {
|
||||
let sign: u64;
|
||||
if message >= 0. {
|
||||
sign = 0;
|
||||
} else {
|
||||
sign = 1
|
||||
}
|
||||
self.key.encrypt_without_padding(
|
||||
sign * (self.key.parameters.message_modulus().0 * self.key.parameters.carry_modulus().0
|
||||
/ 2) as u64,
|
||||
)
|
||||
}
|
||||
|
||||
fn encrypt_mantissa(&self, mantissa: u64) -> Vec<shortint::Ciphertext> {
|
||||
let mut ct_vec_mantissa: Vec<shortint::ciphertext::Ciphertext> = Vec::new();
|
||||
let mut power = 1_u128;
|
||||
let message_modulus = self.parameters().message_modulus().0 as u128;
|
||||
for _ in 0..self.vector_length_mantissa.0 {
|
||||
let mut decomp = mantissa as u128 & ((message_modulus - 1) * power);
|
||||
decomp /= power;
|
||||
|
||||
// encryption
|
||||
let ct = self.key.encrypt(decomp as u64);
|
||||
ct_vec_mantissa.push(ct);
|
||||
//modulus to the power i
|
||||
power *= message_modulus;
|
||||
}
|
||||
ct_vec_mantissa
|
||||
}
|
||||
|
||||
fn encrypt_exponent(&self, exponent: u64) -> Vec<shortint::Ciphertext> {
|
||||
let mut ct_vec_exponent: Vec<shortint::ciphertext::Ciphertext> = Vec::new();
|
||||
let mut power = 1_u64;
|
||||
let message_modulus = self.parameters().message_modulus().0 as u64;
|
||||
for _ in 0..self.vector_length_exponent.0 {
|
||||
let mut decomp = exponent as u64 & ((message_modulus - 1) * power);
|
||||
decomp /= power;
|
||||
|
||||
// encryption
|
||||
let ct = self.key.encrypt(decomp);
|
||||
ct_vec_exponent.push(ct);
|
||||
//modulus to the power i
|
||||
power *= message_modulus;
|
||||
}
|
||||
ct_vec_exponent
|
||||
}
|
||||
|
||||
/// Decrypts a ciphertext encrypting an float message
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// use concrete_float::client_key::ClientKey;
|
||||
/// use concrete_float::parameters::{PARAM_MESSAGE_2_CARRY_2_32, WOP_PARAM_MESSAGE_2_CARRY_2_32};
|
||||
///
|
||||
/// let param = (PARAM_MESSAGE_2_CARRY_2_32, WOP_PARAM_MESSAGE_2_CARRY_2_32);
|
||||
/// let mut cks = ClientKey::new(param, 3, 1);
|
||||
///
|
||||
/// let msg = 1844640.;
|
||||
/// // Encryption of one message:
|
||||
/// let ct = cks.encrypt(msg);
|
||||
/// let res = cks.decrypt(&ct);
|
||||
///
|
||||
/// //approximation less than 0.1%
|
||||
/// assert_eq!(res, msg)
|
||||
/// ```
|
||||
pub fn decrypt(&self, ctxt: &Ciphertext) -> f64 {
|
||||
let log_message_modulus = f64::log2(self.parameters().message_modulus().0 as f64) as usize;
|
||||
let value_exponent = log_message_modulus as i64;
|
||||
|
||||
let mut mantissa = self.decrypt_mantissa(&ctxt.ct_vec_mantissa) as f64;
|
||||
let mut exponent = self.decrypt_exponent(&ctxt.ct_vec_exponent) as i64;
|
||||
let sign = self.decrypt_sign(&ctxt.ct_sign);
|
||||
|
||||
exponent += ctxt.e_min;
|
||||
if exponent > 0 {
|
||||
for _ in 0..exponent.abs() {
|
||||
mantissa *= (1_u128 << value_exponent) as f64
|
||||
}
|
||||
} else {
|
||||
for _ in 0..exponent.abs() {
|
||||
mantissa /= (1_u128 << value_exponent) as f64
|
||||
}
|
||||
}
|
||||
|
||||
let res;
|
||||
if sign == 1 {
|
||||
res = -mantissa
|
||||
} else {
|
||||
res = mantissa
|
||||
}
|
||||
res
|
||||
}
|
||||
|
||||
pub fn decrypt_mantissa(&self, ctxt: &Vec<shortint::Ciphertext>) -> u128 {
|
||||
let mut result = 0_u128;
|
||||
let mut shift = 1_u128;
|
||||
for c_i in ctxt.iter() {
|
||||
//decrypt the component i of the integer and multiply it by the radix product
|
||||
let tmp = (self.key.decrypt_message_and_carry(c_i) as u128).wrapping_mul(shift);
|
||||
|
||||
// update the result
|
||||
result = result.wrapping_add(tmp as u128);
|
||||
|
||||
// update the shift for the next iteration
|
||||
shift = shift.wrapping_mul(self.parameters().message_modulus().0 as u128);
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
pub fn decrypt_exponent(&self, ctxt: &Vec<shortint::Ciphertext>) -> u64 {
|
||||
let mut result = 0_u64;
|
||||
let mut shift = 1_u64;
|
||||
for c_i in ctxt.iter() {
|
||||
//decrypt the component i of the integer and multiply it by the radix product
|
||||
let tmp = self.key.decrypt_message_and_carry(c_i).wrapping_mul(shift);
|
||||
|
||||
// update the result
|
||||
result = result.wrapping_add(tmp);
|
||||
|
||||
// update the shift for the next iteration
|
||||
shift = shift.wrapping_mul(self.parameters().message_modulus().0 as u64);
|
||||
}
|
||||
result
|
||||
}
|
||||
|
||||
pub fn decrypt_sign(&self, ctxt: &shortint::Ciphertext) -> u64 {
|
||||
let result = self.key.decrypt_message_and_carry_without_padding(ctxt);
|
||||
result
|
||||
/ (self.key.parameters.message_modulus().0 * self.key.parameters.carry_modulus().0 / 2)
|
||||
as u64
|
||||
}
|
||||
}
|
||||
51
concrete-float/src/client_key/utils.rs
Normal file
51
concrete-float/src/client_key/utils.rs
Normal file
@@ -0,0 +1,51 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, Copy, Clone, Serialize, Deserialize)]
|
||||
pub struct RadixDecomposition {
|
||||
pub msg_space: usize,
|
||||
pub block_number: usize,
|
||||
}
|
||||
|
||||
/// Computes possible radix decompositions
|
||||
///
|
||||
/// Takes the number of bit of the message space as input and output a vector containing all the
|
||||
/// correct
|
||||
/// possible block decomposition assuming the same message space for all blocks.
|
||||
/// Lower and upper bounds define the minimal and maximal space to be considered
|
||||
/// Example: 6,2,4 -> [ [2,3], [3,2]] : [msg_space = 2 bits, block_number = 3]
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// use concrete_float::client_key::radix_decomposition;
|
||||
/// let input_space = 16; //
|
||||
/// let min = 2;
|
||||
/// let max = 4;
|
||||
/// let decomp = radix_decomposition(input_space, min, max);
|
||||
///
|
||||
/// // Check that 3 possible radix decompositions are provided
|
||||
/// assert_eq!(decomp.len(), 3);
|
||||
/// ```
|
||||
pub fn radix_decomposition(
|
||||
input_space: usize,
|
||||
min_space: usize,
|
||||
max_space: usize,
|
||||
) -> Vec<RadixDecomposition> {
|
||||
let mut out: Vec<RadixDecomposition> = vec![];
|
||||
let mut max = max_space;
|
||||
if max_space > input_space {
|
||||
max = input_space;
|
||||
}
|
||||
for msg_space in min_space..max + 1 {
|
||||
let mut block_number = input_space / msg_space;
|
||||
//Manual ceil of the division
|
||||
if input_space % msg_space != 0 {
|
||||
block_number += 1;
|
||||
}
|
||||
out.push(RadixDecomposition {
|
||||
msg_space,
|
||||
block_number,
|
||||
})
|
||||
}
|
||||
out
|
||||
}
|
||||
41
concrete-float/src/keycache.rs
Normal file
41
concrete-float/src/keycache.rs
Normal file
@@ -0,0 +1,41 @@
|
||||
use std::fs::File;
|
||||
use std::io::{BufReader, BufWriter};
|
||||
use std::path::Path;
|
||||
use lazy_static::lazy_static;
|
||||
|
||||
use crate::{ClientKey, ServerKey};
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct FloatKeyCache;
|
||||
|
||||
lazy_static! {
|
||||
pub static ref KEY_CACHE: FloatKeyCache = FloatKeyCache::default();
|
||||
}
|
||||
|
||||
pub fn get_sks(str: &str) -> ServerKey {
|
||||
let fiptr = format!("key/sks_key/{}", str);
|
||||
let filepath = Path::new(&fiptr);
|
||||
let file = BufReader::new(File::open(filepath).unwrap());
|
||||
let saved_key: ServerKey = bincode::deserialize_from(file).unwrap();
|
||||
saved_key
|
||||
}
|
||||
|
||||
pub fn get_cks(str: &str) -> ClientKey {
|
||||
let fiptr = format!("key/cks_key/{}", str);
|
||||
let filepath = Path::new(&fiptr);
|
||||
let file = BufReader::new(File::open(filepath).unwrap());
|
||||
let saved_key: ClientKey = bincode::deserialize_from(file).unwrap();
|
||||
saved_key
|
||||
}
|
||||
|
||||
pub fn save_sks(key: ServerKey, str: &str) {
|
||||
let filepath = format!("key/sks_key/{}", str);
|
||||
let file = BufWriter::new(File::create(filepath).unwrap());
|
||||
bincode::serialize_into(file, &key).unwrap();
|
||||
}
|
||||
|
||||
pub fn save_cks(key: ClientKey ,str: &str) {
|
||||
let filepath = format!("key/cks_key/{}", str);
|
||||
let file = BufWriter::new(File::create(filepath).unwrap());
|
||||
bincode::serialize_into(file, &key).unwrap();
|
||||
}
|
||||
92
concrete-float/src/lib.rs
Executable file
92
concrete-float/src/lib.rs
Executable file
@@ -0,0 +1,92 @@
|
||||
/*
|
||||
#![allow(clippy::excessive_precision)]
|
||||
//! Welcome the the `concrete-integer` documentation!
|
||||
//!
|
||||
//! # Description
|
||||
//!
|
||||
//! This library makes it possible to execute modular operations over encrypted integer.
|
||||
//!
|
||||
//! It allows to execute an integer circuit on an untrusted server because both circuit inputs
|
||||
//! outputs are kept private.
|
||||
//!
|
||||
//! Data are encrypted on the client side, before being sent to the server.
|
||||
//! On the server side every computation is performed on ciphertexts
|
||||
//!
|
||||
//! # Quick Example
|
||||
//!
|
||||
//! The following piece of code shows how to generate keys and run a integer circuit
|
||||
//! homomorphically.
|
||||
//!
|
||||
//! ```rust
|
||||
//! use concrete_float::gen_keys;
|
||||
//! use concrete_shortint::parameters::PARAM_MESSAGE_2_CARRY_2;
|
||||
//!
|
||||
//! //4 blocks for the radix decomposition
|
||||
//! let number_of_blocks = 4;
|
||||
//! // Modulus = (2^2)*4 = 2^8 (from the parameters chosen and the number of blocks
|
||||
//! let modulus = 1 << 8;
|
||||
//!
|
||||
//! // Generation of the client/server keys, using the default parameters:
|
||||
//! let (mut client_key, mut server_key) = gen_keys(&PARAM_MESSAGE_2_CARRY_2, number_of_blocks);
|
||||
//!
|
||||
//! let msg1 = 153;
|
||||
//! let msg2 = 125;
|
||||
//!
|
||||
//! // Encryption of two messages using the client key:
|
||||
//! let ct_1 = client_key.encrypt(msg1);
|
||||
//! let ct_2 = client_key.encrypt(msg2);
|
||||
//!
|
||||
//! // Homomorphic evaluation of an integer circuit (here, an addition) using the server key:
|
||||
//! let ct_3 = server_key.unchecked_add(&ct_1, &ct_2);
|
||||
//!
|
||||
//! // Decryption of the ciphertext using the client key:
|
||||
//! let output = client_key.decrypt(&ct_3);
|
||||
//! assert_eq!(output, (msg1 + msg2) % modulus);
|
||||
//! ```
|
||||
//!
|
||||
//! # Warning
|
||||
//! This uses cryptographic parameters from the `concrete-shortint` crates.
|
||||
//! Currently, the radix approach is only compatible with parameter sets such
|
||||
//! that the message and carry buffers have the same size.
|
||||
extern crate core;
|
||||
*/
|
||||
extern crate core;
|
||||
|
||||
pub mod ciphertext;
|
||||
pub mod client_key;
|
||||
pub mod parameters;
|
||||
pub mod server_key;
|
||||
use crate::client_key::ClientKey;
|
||||
use crate::server_key::ServerKey;
|
||||
//pub mod keycache;
|
||||
//pub mod wopbs;
|
||||
#[cfg(doctest)]
|
||||
//mod test_user_docs;
|
||||
use tfhe::shortint;
|
||||
use tfhe::shortint;
|
||||
|
||||
/// Generate a couple of client and server keys with given parameters
|
||||
///
|
||||
/// * the client key is used to encrypt and decrypt and has to be kept secret;
|
||||
/// * the server key is used to perform homomorphic operations on the server side and it is meant to
|
||||
/// be published (the client sends it to the server).
|
||||
///
|
||||
/// ```rust
|
||||
/// use concrete_float::gen_keys;
|
||||
/// use concrete_shortint::parameters::DEFAULT_PARAMETERS;
|
||||
///
|
||||
/// let size_mantissa = 4;
|
||||
/// let size_exponent = 1;
|
||||
/// ```
|
||||
pub fn gen_keys(
|
||||
parameters_set: shortint::ClassicPBSParameters,
|
||||
parameters_set_wopbs: shortint::WopbsParameters,
|
||||
size_mantissa: usize,
|
||||
size_exponent: usize,
|
||||
) -> (ClientKey, ServerKey) {
|
||||
let params = (parameters_set, parameters_set_wopbs);
|
||||
let cks = ClientKey::new(params, size_mantissa, size_exponent);
|
||||
let sks = ServerKey::new(&cks);
|
||||
|
||||
(cks, sks)
|
||||
}
|
||||
1057
concrete-float/src/parameters/mod.rs
Normal file
1057
concrete-float/src/parameters/mod.rs
Normal file
File diff suppressed because it is too large
Load Diff
158
concrete-float/src/server_key/add.rs
Normal file
158
concrete-float/src/server_key/add.rs
Normal file
@@ -0,0 +1,158 @@
|
||||
use crate::server_key::Ciphertext;
|
||||
use crate::ServerKey;
|
||||
use rayon::prelude::*;
|
||||
use tfhe::shortint;
|
||||
|
||||
//use crate::keycache::{get_sks, get_cks};
|
||||
|
||||
impl ServerKey {
|
||||
/// Computes homomorphically an addition between two ciphertexts encrypting integer values.
|
||||
///
|
||||
/// This function computes the operation without checking if it exceeds the capacity of the
|
||||
/// ciphertext.
|
||||
///
|
||||
/// The result is returned as a new ciphertext.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// ```
|
||||
pub fn unchecked_add_mantissa(
|
||||
&self,
|
||||
ct_left: &Ciphertext,
|
||||
ct_right: &Ciphertext,
|
||||
) -> Ciphertext {
|
||||
let mut result = ct_left.clone();
|
||||
self.unchecked_add_mantissa_assign(&mut result, ct_right);
|
||||
result
|
||||
}
|
||||
|
||||
/// Computes homomorphically an addition between two ciphertexts encrypting integer values.
|
||||
///
|
||||
/// This function computes the operation without checking if it exceeds the capacity of the
|
||||
/// ciphertext.
|
||||
///
|
||||
/// The result is assigned to the `ct_left` ciphertext.
|
||||
/// ```rust
|
||||
/// ```
|
||||
pub fn unchecked_add_mantissa_assign(&self, ct_left: &mut Ciphertext, ct_right: &Ciphertext) {
|
||||
for (ct_left_i, ct_right_i) in ct_left
|
||||
.ct_vec_mantissa
|
||||
.iter_mut()
|
||||
.zip(ct_right.ct_vec_mantissa.iter())
|
||||
{
|
||||
self.key.unchecked_add_assign(ct_left_i, ct_right_i);
|
||||
}
|
||||
}
|
||||
|
||||
/// we suppose that the mantissa are align
|
||||
pub fn add_mantissa(&self, ct_left: &mut Ciphertext, ct_right: &mut Ciphertext) {
|
||||
for (ct_left_i, ct_right_i) in ct_left
|
||||
.ct_vec_mantissa
|
||||
.iter_mut()
|
||||
.zip(ct_right.ct_vec_mantissa.iter())
|
||||
{
|
||||
self.key.unchecked_add_assign(ct_left_i, ct_right_i);
|
||||
}
|
||||
}
|
||||
|
||||
/// Verifies if ct1 and ct2 can be added together.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
///```rust
|
||||
/// ```
|
||||
pub fn is_add_possible(
|
||||
&self,
|
||||
ct_left: &[shortint::ciphertext::Ciphertext],
|
||||
ct_right: &[shortint::ciphertext::Ciphertext],
|
||||
) -> bool {
|
||||
for (ct_left_i, ct_right_i) in ct_left.iter().zip(ct_right.iter()) {
|
||||
if self.key.is_add_possible(ct_left_i, ct_right_i).is_err() {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
true
|
||||
}
|
||||
|
||||
pub fn add_total(&self, ct1: &Ciphertext, ct2: &Ciphertext) -> Ciphertext {
|
||||
let res_sign = self.key.unchecked_add(&ct1.ct_sign, &ct2.ct_sign);
|
||||
let (mut ct1_aligned, mut ct2_aligned) = self.align_mantissa(&ct1, &ct2);
|
||||
let ct_sub = self.sub_mantissa(&ct1_aligned, &ct2_aligned);
|
||||
self.add_mantissa(&mut ct1_aligned, &mut ct2_aligned);
|
||||
|
||||
// message space == 0 because the sign is on the padding bit
|
||||
let ggsw = self.ggsw_ks_cbs(&res_sign, 0); // let ggsw = self.wopbs_key.extract_one_bit_cbs(&self.key, &res_sign, 63);
|
||||
let mut res = self.cmuxes_full(&ct1_aligned, &ct_sub, &ggsw);
|
||||
self.clean_degree(&mut res);
|
||||
res
|
||||
}
|
||||
|
||||
/// Computes homomorphically an addition between two ciphertexts encrypting integer values.
|
||||
///
|
||||
/// This function computes the operation without checking if it exceeds the capacity of the
|
||||
/// ciphertext.
|
||||
///
|
||||
/// The result is returned as a new ciphertext.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// ```
|
||||
pub fn unchecked_add_mantissa_parallelized(
|
||||
&self,
|
||||
ct_left: &Ciphertext,
|
||||
ct_right: &Ciphertext,
|
||||
) -> Ciphertext {
|
||||
let mut result = ct_left.clone();
|
||||
self.unchecked_add_mantissa_assign_parallelized(&mut result, ct_right);
|
||||
result
|
||||
}
|
||||
|
||||
/// Computes homomorphically an addition between two ciphertexts encrypting integer values.
|
||||
///
|
||||
/// This function computes the operation without checking if it exceeds the capacity of the
|
||||
/// ciphertext.
|
||||
///
|
||||
/// The result is assigned to the `ct_left` ciphertext.
|
||||
/// ```rust
|
||||
/// ```
|
||||
pub fn unchecked_add_mantissa_assign_parallelized(
|
||||
&self,
|
||||
ct_left: &mut Ciphertext,
|
||||
ct_right: &Ciphertext,
|
||||
) {
|
||||
ct_left
|
||||
.ct_vec_mantissa
|
||||
.par_iter_mut()
|
||||
.zip(ct_right.ct_vec_mantissa.par_iter())
|
||||
.for_each(|(ct_left_i, ct_right_i)| {
|
||||
self.key.unchecked_add_assign(ct_left_i, ct_right_i);
|
||||
});
|
||||
}
|
||||
|
||||
/// we suppose that the mantissa are align
|
||||
pub fn add_mantissa_parallelized(&self, ct_left: &mut Ciphertext, ct_right: &mut Ciphertext) {
|
||||
// The operation is too small to be worth parallelizing
|
||||
ct_left
|
||||
.ct_vec_mantissa
|
||||
.iter_mut()
|
||||
.zip(ct_right.ct_vec_mantissa.iter())
|
||||
.for_each(|(ct_left_i, ct_right_i)| {
|
||||
self.key.unchecked_add_assign(ct_left_i, ct_right_i);
|
||||
});
|
||||
}
|
||||
|
||||
pub fn add_total_parallelized(&self, ct1: &Ciphertext, ct2: &Ciphertext) -> Ciphertext {
|
||||
let res_sign = self.key.unchecked_add(&ct1.ct_sign, &ct2.ct_sign);
|
||||
let (mut ct1_aligned, mut ct2_aligned) = self.align_mantissa_parallelized(&ct1, &ct2);
|
||||
let ct_sub = self.sub_mantissa_parallelized(&ct1_aligned, &ct2_aligned);
|
||||
self.add_mantissa_parallelized(&mut ct1_aligned, &mut ct2_aligned);
|
||||
// message space == 0 because the sign is on the padding bit
|
||||
let ggsw = self.ggsw_ks_cbs_parallelized(&res_sign, 0); // let ggsw = self.wopbs_key.extract_one_bit_cbs(&self.key, &res_sign, 63);
|
||||
let mut res = self.cmuxes_full_parallelized(&ct1_aligned, &ct_sub, &ggsw);
|
||||
self.clean_degree_parallelized(&mut res);
|
||||
|
||||
res
|
||||
}
|
||||
}
|
||||
162
concrete-float/src/server_key/align_mantissa.rs
Normal file
162
concrete-float/src/server_key/align_mantissa.rs
Normal file
@@ -0,0 +1,162 @@
|
||||
use crate::server_key::Ciphertext;
|
||||
use crate::ServerKey;
|
||||
use aligned_vec::ABox;
|
||||
use rayon::prelude::*;
|
||||
use tfhe::core_crypto::fft_impl::fft64::c64;
|
||||
use tfhe::core_crypto::fft_impl::fft64::crypto::ggsw::FourierGgswCiphertext;
|
||||
use tfhe::shortint;
|
||||
|
||||
impl ServerKey {
|
||||
// align the two mantissas of to floating points
|
||||
pub fn align_mantissa(
|
||||
&self,
|
||||
ct_left: &Ciphertext,
|
||||
ct_right: &Ciphertext,
|
||||
) -> (Ciphertext, Ciphertext) {
|
||||
let (ct_res, sign) = self.sub(&ct_left.ct_vec_exponent, &ct_right.ct_vec_exponent);
|
||||
let (vec_ggsw, sign_ggsw) =
|
||||
self.create_vec_ggsw_after_sub(&ct_res, &sign, ct_left.ct_vec_mantissa.len());
|
||||
let mut need_to_be_aligned = self.cmuxes(
|
||||
&ct_left.ct_vec_mantissa,
|
||||
&ct_right.ct_vec_mantissa,
|
||||
&sign_ggsw,
|
||||
);
|
||||
let aligned_exp = self.cmuxes(
|
||||
&ct_right.ct_vec_exponent,
|
||||
&ct_left.ct_vec_exponent,
|
||||
&sign_ggsw,
|
||||
);
|
||||
let aligned = self.cmux_tree_mantissa(&mut need_to_be_aligned, &vec_ggsw);
|
||||
let ct_left_aligned = self.cmuxes(&aligned, &ct_left.ct_vec_mantissa, &sign_ggsw);
|
||||
let ct_right_aligned = self.cmuxes(&ct_right.ct_vec_mantissa, &aligned, &sign_ggsw);
|
||||
let new_left = Ciphertext {
|
||||
ct_vec_mantissa: ct_left_aligned,
|
||||
ct_vec_exponent: aligned_exp.clone(),
|
||||
ct_sign: ct_left.ct_sign.clone(),
|
||||
e_min: ct_left.e_min,
|
||||
};
|
||||
let new_right = Ciphertext {
|
||||
ct_vec_mantissa: ct_right_aligned,
|
||||
ct_vec_exponent: aligned_exp,
|
||||
ct_sign: ct_right.ct_sign.clone(),
|
||||
e_min: ct_right.e_min,
|
||||
};
|
||||
(new_left, new_right)
|
||||
}
|
||||
|
||||
pub fn align_mantissa_parallelized(
|
||||
&self,
|
||||
ct_left: &Ciphertext,
|
||||
ct_right: &Ciphertext,
|
||||
) -> (Ciphertext, Ciphertext) {
|
||||
let (mut ct_res, sign) =
|
||||
self.abs_diff_parallelized(&ct_left.ct_vec_exponent, &ct_right.ct_vec_exponent);
|
||||
|
||||
let (vec_ggsw, sign_ggsw) = self.create_vec_ggsw_after_sub_parallelized(
|
||||
&mut ct_res,
|
||||
&sign,
|
||||
ct_left.ct_vec_mantissa.len(),
|
||||
);
|
||||
|
||||
let (mut need_to_be_aligned, aligned_exp) = rayon::join(
|
||||
|| {
|
||||
self.cmuxes_parallelized(
|
||||
&ct_left.ct_vec_mantissa,
|
||||
&ct_right.ct_vec_mantissa,
|
||||
&sign_ggsw,
|
||||
)
|
||||
},
|
||||
|| {
|
||||
self.cmuxes_parallelized(
|
||||
&ct_right.ct_vec_exponent,
|
||||
&ct_left.ct_vec_exponent,
|
||||
&sign_ggsw,
|
||||
)
|
||||
},
|
||||
);
|
||||
let aligned = self.cmux_tree_mantissa_parallelized(&mut need_to_be_aligned, &vec_ggsw);
|
||||
let (ct_left_aligned, ct_right_aligned) = rayon::join(
|
||||
|| self.cmuxes_parallelized(&aligned, &ct_left.ct_vec_mantissa, &sign_ggsw),
|
||||
|| self.cmuxes_parallelized(&ct_right.ct_vec_mantissa, &aligned, &sign_ggsw),
|
||||
);
|
||||
let new_left = Ciphertext {
|
||||
ct_vec_mantissa: ct_left_aligned,
|
||||
ct_vec_exponent: aligned_exp.clone(),
|
||||
ct_sign: ct_left.ct_sign.clone(),
|
||||
e_min: ct_left.e_min,
|
||||
};
|
||||
let new_right = Ciphertext {
|
||||
ct_vec_mantissa: ct_right_aligned,
|
||||
ct_vec_exponent: aligned_exp,
|
||||
ct_sign: ct_right.ct_sign.clone(),
|
||||
e_min: ct_right.e_min,
|
||||
};
|
||||
(new_left, new_right)
|
||||
}
|
||||
|
||||
pub fn create_vec_ggsw_after_sub(
|
||||
&self,
|
||||
ct_res: &Vec<shortint::ciphertext::Ciphertext>,
|
||||
sign: &shortint::ciphertext::Ciphertext,
|
||||
len_mantissa: usize,
|
||||
) -> (
|
||||
Vec<FourierGgswCiphertext<ABox<[c64]>>>,
|
||||
FourierGgswCiphertext<ABox<[c64]>>,
|
||||
) {
|
||||
let msg_modulus = self.wopbs_key.param.message_modulus.0 as u64;
|
||||
let car_modulus = self.wopbs_key.param.carry_modulus.0 as u64;
|
||||
let msg_space = (msg_modulus * car_modulus) as usize;
|
||||
|
||||
let mut ct_res = ct_res.clone();
|
||||
self.full_propagate_exponent(&mut ct_res);
|
||||
let mut vec_ggsw = Vec::new();
|
||||
for i in 0..ct_res.len() {
|
||||
if len_mantissa < ((f64::log2(msg_modulus as f64) as usize) * i) {
|
||||
let mut ggsw = vec![self.ggsw_pbs_ks_cbs(&ct_res[i], msg_space)];
|
||||
ggsw.append(&mut vec_ggsw);
|
||||
vec_ggsw = ggsw
|
||||
} else {
|
||||
let mut ggsw = self.extract_bit_cbs(&ct_res[i]);
|
||||
ggsw.append(&mut vec_ggsw);
|
||||
vec_ggsw = ggsw;
|
||||
}
|
||||
}
|
||||
// message space == 0 because the sign is on the padding bit
|
||||
let sign_ggsw = self.ggsw_ks_cbs(&sign, 0);
|
||||
(vec_ggsw, sign_ggsw)
|
||||
}
|
||||
|
||||
pub fn create_vec_ggsw_after_sub_parallelized(
|
||||
&self,
|
||||
ct_res: &mut [shortint::ciphertext::Ciphertext],
|
||||
sign: &shortint::ciphertext::Ciphertext,
|
||||
len_mantissa: usize,
|
||||
) -> (
|
||||
Vec<FourierGgswCiphertext<ABox<[c64]>>>,
|
||||
FourierGgswCiphertext<ABox<[c64]>>,
|
||||
) {
|
||||
let msg_modulus = self.wopbs_key.param.message_modulus.0 as u64;
|
||||
let car_modulus = self.wopbs_key.param.carry_modulus.0 as u64;
|
||||
let msg_space = (msg_modulus * car_modulus) as usize;
|
||||
|
||||
self.full_propagate_exponent_parallelized(ct_res);
|
||||
|
||||
let vec_ggsw: Vec<_> = ct_res
|
||||
.par_iter()
|
||||
.enumerate()
|
||||
.rev()
|
||||
.map(|(i, block)| {
|
||||
if (msg_modulus.ilog2() as usize * i) > len_mantissa {
|
||||
vec![self.is_block_non_zero_ggsw_pbs_ks_cbs_parallelized(&block, msg_space)]
|
||||
} else {
|
||||
self.extract_bit_cbs_parallelized(&block)
|
||||
}
|
||||
})
|
||||
.flatten()
|
||||
.collect();
|
||||
|
||||
// message space == 0 because the sign is on the padding bit
|
||||
let sign_ggsw = self.ggsw_ks_cbs_parallelized(&sign, 0);
|
||||
(vec_ggsw, sign_ggsw)
|
||||
}
|
||||
}
|
||||
53
concrete-float/src/server_key/division.rs
Normal file
53
concrete-float/src/server_key/division.rs
Normal file
@@ -0,0 +1,53 @@
|
||||
use crate::ciphertext::Ciphertext;
|
||||
use crate::server_key::ServerKey;
|
||||
use tfhe::integer::ciphertext::RadixCiphertext;
|
||||
use tfhe::integer::IntegerCiphertext;
|
||||
|
||||
impl ServerKey {
|
||||
pub fn division(&self, ct1: &Ciphertext, ct2: &Ciphertext) -> Ciphertext {
|
||||
let msg_modulus = self.wopbs_key.param.message_modulus.0 as u64;
|
||||
let log_msg_modulus = f64::log2(msg_modulus as f64) as u64;
|
||||
let len_vec_exp = ct1.ct_vec_exponent.len();
|
||||
let len_vec_man = ct1.ct_vec_mantissa.len();
|
||||
|
||||
let mut res = self.create_trivial_zero(
|
||||
ct1.ct_vec_mantissa.len(),
|
||||
ct1.ct_vec_exponent.len(),
|
||||
ct1.e_min,
|
||||
);
|
||||
let zero = self.create_trivial_zero(
|
||||
ct1.ct_vec_mantissa.len(),
|
||||
ct1.ct_vec_exponent.len(),
|
||||
ct1.e_min,
|
||||
);
|
||||
res.ct_sign = self.key.unchecked_add(&ct1.ct_sign, &ct2.ct_sign);
|
||||
res.ct_vec_exponent = ct1.ct_vec_exponent.clone();
|
||||
|
||||
let cst = ct1.e_min + len_vec_man as i64 - 1;
|
||||
for i in 0..len_vec_exp {
|
||||
let cst = (cst.abs() as u64) >> (log_msg_modulus * i as u64);
|
||||
self.key.unchecked_scalar_add_assign(
|
||||
&mut res.ct_vec_exponent[i],
|
||||
(cst % msg_modulus) as u8,
|
||||
);
|
||||
}
|
||||
let (res_exp, sign) = self.sub(&res.ct_vec_exponent, &ct2.ct_vec_exponent);
|
||||
res.ct_vec_exponent = res_exp;
|
||||
let mut cct1 = RadixCiphertext::from(ct1.ct_vec_mantissa.clone());
|
||||
let mut cct2 = RadixCiphertext::from(ct2.ct_vec_mantissa.clone());
|
||||
|
||||
let int_key = tfhe::integer::ServerKey::from_shortint_ex(self.key.clone());
|
||||
|
||||
int_key.extend_radix_with_trivial_zero_blocks_lsb_assign(&mut cct1, len_vec_man - 1);
|
||||
int_key.extend_radix_with_trivial_zero_blocks_msb_assign(&mut cct2, len_vec_man - 1);
|
||||
|
||||
let res_mantissa = int_key.unchecked_div_parallelized(&cct1, &cct2);
|
||||
|
||||
// message space == 0 because the sign is on the padding bit
|
||||
let sign_ggsw = self.ggsw_ks_cbs(&sign, 0);
|
||||
|
||||
res.ct_vec_mantissa = res_mantissa.blocks()[..len_vec_man].to_vec();
|
||||
res = self.cmuxes_full(&zero, &res, &sign_ggsw);
|
||||
res
|
||||
}
|
||||
}
|
||||
390
concrete-float/src/server_key/mod.rs
Normal file
390
concrete-float/src/server_key/mod.rs
Normal file
@@ -0,0 +1,390 @@
|
||||
//! Module with the definition of the ServerKey.
|
||||
//!
|
||||
//! This module implements the generation of the server public key, together with all the
|
||||
//! available homomorphic integer operations.
|
||||
mod add;
|
||||
mod align_mantissa;
|
||||
mod division;
|
||||
mod mul;
|
||||
mod relu;
|
||||
mod sigmoid;
|
||||
mod sub;
|
||||
mod tools;
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests;
|
||||
|
||||
use tfhe::shortint;
|
||||
|
||||
use crate::ciphertext::Ciphertext;
|
||||
use crate::client_key::ClientKey;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use shortint::ciphertext::{Degree, MaxDegree};
|
||||
|
||||
/// Error returned when the carry buffer is full.
|
||||
pub use shortint::CheckError;
|
||||
|
||||
/// A structure containing the server public key.
|
||||
///
|
||||
/// The server key is generated by the client and is meant to be published: the client
|
||||
/// sends it to the server so it can compute homomorphic integer circuits.
|
||||
#[derive(Serialize, Deserialize, Clone)]
|
||||
pub struct ServerKey {
|
||||
pub key: shortint::server_key::ServerKey,
|
||||
pub integer_key: tfhe::integer::server_key::ServerKey,
|
||||
pub wopbs_key: shortint::wopbs::WopbsKey,
|
||||
}
|
||||
|
||||
impl ServerKey {
|
||||
/// Generates a server key.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// use concrete_float::parameters::{PARAM_MESSAGE_2_CARRY_2_32, WOP_PARAM_MESSAGE_2_CARRY_2_32};
|
||||
/// use concrete_float::{ClientKey, ServerKey};
|
||||
/// //mantissa and exponent defined over 4 blocks ///
|
||||
/// let size_mantissa = 4;
|
||||
/// let size_exponent = 2;
|
||||
///
|
||||
/// // Generate the client key:
|
||||
/// let param = (PARAM_MESSAGE_2_CARRY_2_32, WOP_PARAM_MESSAGE_2_CARRY_2_32);
|
||||
/// let cks = ClientKey::new(param, size_mantissa, size_exponent);
|
||||
///
|
||||
/// // Generate the server key:
|
||||
/// let sks = ServerKey::new(&cks);
|
||||
/// ```
|
||||
pub fn new(cks: &ClientKey) -> ServerKey {
|
||||
// It should remain just enough space to add a carry
|
||||
let max =
|
||||
(cks.key.parameters.message_modulus().0 - 1) * cks.key.parameters.carry_modulus().0 - 1;
|
||||
let key =
|
||||
shortint::server_key::ServerKey::new_with_max_degree(&cks.key, MaxDegree::new(max));
|
||||
let integer_key = tfhe::integer::server_key::ServerKey::from_shortint_ex(key.clone());
|
||||
let wopbs_key =
|
||||
shortint::wopbs::WopbsKey::new_wopbs_key_only_for_wopbs(&cks.key, &key.clone());
|
||||
ServerKey {
|
||||
key,
|
||||
integer_key,
|
||||
wopbs_key,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a ciphertext filled with zeros
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// use concrete_float::gen_keys;
|
||||
/// use concrete_shortint::parameters::DEFAULT_PARAMETERS;
|
||||
///
|
||||
/// let size_mantissa = 4;
|
||||
/// let size_exponent = 4;
|
||||
/// let e_min = -2;
|
||||
/// // Generate the client key and the server key:
|
||||
/// let (cks, sks) = gen_keys(&DEFAULT_PARAMETERS, size, size);
|
||||
///
|
||||
/// let ctxt = sks.create_trivial_zero(size_mantissa, size_exponent, e_min, vec![]);
|
||||
///
|
||||
/// // Decrypt:
|
||||
/// let dec = cks.decrypt(&ctxt);
|
||||
/// assert_eq!(0, dec);
|
||||
/// ```
|
||||
pub fn create_trivial_zero(
|
||||
&self,
|
||||
size_mantissa: usize,
|
||||
size_exponent: usize,
|
||||
e_min: i64,
|
||||
) -> Ciphertext {
|
||||
let mut vec_res_mantissa = Vec::<shortint::Ciphertext>::with_capacity(size_mantissa);
|
||||
let mut zero = self.key.create_trivial(0_u64);
|
||||
zero.degree = Degree::new(0);
|
||||
for _ in 0..size_mantissa {
|
||||
vec_res_mantissa.push(zero.clone());
|
||||
}
|
||||
|
||||
let mut vec_res_exponent = Vec::<shortint::Ciphertext>::with_capacity(size_exponent);
|
||||
for _ in 0..size_exponent {
|
||||
vec_res_exponent.push(zero.clone());
|
||||
}
|
||||
|
||||
let sign = zero;
|
||||
|
||||
Ciphertext {
|
||||
ct_vec_mantissa: vec_res_mantissa,
|
||||
ct_vec_exponent: vec_res_exponent,
|
||||
ct_sign: sign,
|
||||
e_min,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn create_trivial_zero_from_ct(&self, ctxt: &Ciphertext) -> Ciphertext {
|
||||
self.create_trivial_zero(
|
||||
ctxt.ct_vec_mantissa.len(),
|
||||
ctxt.ct_vec_exponent.len(),
|
||||
ctxt.e_min,
|
||||
)
|
||||
}
|
||||
|
||||
/// Propagate the carry of the 'index' block to the next one.
|
||||
/// if index is equals to the MS LWE, this operation do nothing.
|
||||
/// We want to keep all the information on this LWE ( with this operation we can't create a
|
||||
/// new LWE
|
||||
pub fn propagate_mantissa(&self, ctxt: &mut [shortint::Ciphertext], index: usize) {
|
||||
if index < ctxt.len() - 1 {
|
||||
let carry = self.key.carry_extract(&ctxt[index]);
|
||||
ctxt[index] = self.key.message_extract(&ctxt[index]);
|
||||
self.key.unchecked_add_assign(&mut ctxt[index + 1], &carry);
|
||||
}
|
||||
//TODO maybe just BS to decrease the noise ?
|
||||
}
|
||||
|
||||
/// Propagate all the carries.
|
||||
pub fn full_propagate_mantissa(&self, ctxt: &mut [shortint::Ciphertext]) {
|
||||
let len = ctxt.len();
|
||||
for i in 0..len {
|
||||
self.propagate_mantissa(ctxt, i);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn propagate_exponent(&self, ctxt: &mut Vec<shortint::Ciphertext>, index: usize) {
|
||||
if index < ctxt.len() - 1 {
|
||||
let carry = self.key.carry_extract(&ctxt[index]);
|
||||
ctxt[index] = self.key.message_extract(&ctxt[index]);
|
||||
self.key.unchecked_add_assign(&mut ctxt[index + 1], &carry);
|
||||
} else {
|
||||
ctxt[index] = self.key.message_extract(&ctxt[index]);
|
||||
}
|
||||
}
|
||||
|
||||
/// Propagate all the carries.
|
||||
/// except the msb lwe
|
||||
pub fn partial_propagate(&self, ctxt: &mut Vec<shortint::Ciphertext>) {
|
||||
for i in 0..(ctxt.len() - 1) {
|
||||
self.propagate_exponent(ctxt, i);
|
||||
}
|
||||
}
|
||||
|
||||
/// Propagate all the carries.
|
||||
pub fn full_propagate_exponent(&self, ctxt: &mut Vec<shortint::Ciphertext>) {
|
||||
for i in 0..(ctxt.len()) {
|
||||
self.propagate_exponent(ctxt, i);
|
||||
}
|
||||
}
|
||||
|
||||
/// boolean bootstrapping
|
||||
pub fn reduce_noise_sign(&self, ctxt: &mut Ciphertext) {
|
||||
let msg_modulus = ctxt.ct_sign.message_modulus.0 as u64;
|
||||
let car_modulus = ctxt.ct_sign.carry_modulus.0 as u64;
|
||||
let msg_space = msg_modulus * car_modulus;
|
||||
self.key
|
||||
.unchecked_scalar_add_assign(&mut ctxt.ct_sign, (msg_space / 2) as u8);
|
||||
let accumulator = self
|
||||
.key
|
||||
.generate_lookup_table(|x| (x & (msg_space / 2)).wrapping_neg());
|
||||
//self.key.keyswitch_programmable_bootstrap_assign(&mut ctxt.ct_sign, &accumulator);
|
||||
self.key
|
||||
.apply_lookup_table_assign(&mut ctxt.ct_sign, &accumulator);
|
||||
self.key
|
||||
.unchecked_scalar_add_assign(&mut ctxt.ct_sign, (msg_space / 2) as u8);
|
||||
// We can always add as the sign is managed on the padding bit, the only important thing is
|
||||
// the noise
|
||||
ctxt.ct_sign.degree = Degree::new(0);
|
||||
}
|
||||
|
||||
fn propagate_mantissa_increase_exponent_if_necessary(
|
||||
&self,
|
||||
ctxt: &mut Ciphertext,
|
||||
index: usize,
|
||||
) {
|
||||
if index < ctxt.ct_vec_mantissa.len() - 1 {
|
||||
let carry = self.key.carry_extract(&ctxt.ct_vec_mantissa[index]);
|
||||
ctxt.ct_vec_mantissa[index] = self.key.message_extract(&ctxt.ct_vec_mantissa[index]);
|
||||
self.key
|
||||
.unchecked_add_assign(&mut ctxt.ct_vec_mantissa[index + 1], &carry);
|
||||
} else {
|
||||
self.increase_exponent_if_necessary(ctxt);
|
||||
}
|
||||
}
|
||||
|
||||
fn increase_exponent_if_necessary(&self, ctxt: &mut Ciphertext) {
|
||||
let msg_modulus = self.wopbs_key.param.message_modulus.0 as usize;
|
||||
let car_modulus = self.wopbs_key.param.carry_modulus.0 as usize;
|
||||
let msg_space = f64::log2((msg_modulus * car_modulus) as f64) as usize;
|
||||
let len = ctxt.ct_vec_mantissa.len();
|
||||
let carry = self
|
||||
.key
|
||||
.carry_extract(&ctxt.ct_vec_mantissa.last().unwrap());
|
||||
ctxt.ct_vec_mantissa[len - 1] = self
|
||||
.key
|
||||
.message_extract(&ctxt.ct_vec_mantissa.last().clone().unwrap());
|
||||
let mut tmp = ctxt.clone();
|
||||
tmp.ct_vec_mantissa.push(carry.clone());
|
||||
let _ = tmp.ct_vec_mantissa.remove(0);
|
||||
self.key
|
||||
.unchecked_scalar_add_assign(&mut tmp.ct_vec_exponent[0], 1);
|
||||
let ggsw_carry = self.ggsw_pbs_ks_cbs(&carry, msg_space);
|
||||
let res = self.cmuxes_full(ctxt, &tmp, &ggsw_carry);
|
||||
ctxt.ct_vec_mantissa = res.ct_vec_mantissa;
|
||||
ctxt.ct_vec_exponent = res.ct_vec_exponent;
|
||||
}
|
||||
|
||||
pub fn full_propagate_mantissa_increase_exponent_if_necessary(&self, ctxt: &mut Ciphertext) {
|
||||
let len = ctxt.ct_vec_mantissa.len();
|
||||
for i in 0..len {
|
||||
self.propagate_mantissa_increase_exponent_if_necessary(ctxt, i);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn clean_degree(&self, ctxt: &mut Ciphertext) {
|
||||
self.reduce_noise_sign(ctxt);
|
||||
self.full_propagate_exponent(&mut ctxt.ct_vec_exponent);
|
||||
self.full_propagate_mantissa_increase_exponent_if_necessary(ctxt)
|
||||
}
|
||||
|
||||
/// Propagate the carry of the 'index' block to the next one.
|
||||
/// if index is equals to the MS LWE, this operation do nothing.
|
||||
/// We want to keep all the information on this LWE ( with this operation we can't create a
|
||||
/// new LWE
|
||||
pub fn propagate_mantissa_parallelized(&self, ctxt: &mut [shortint::Ciphertext], index: usize) {
|
||||
// todo!("propagate_mantissa_parallelized");
|
||||
if index < ctxt.len() - 1 {
|
||||
let (carry, msg) = rayon::join(
|
||||
|| self.key.carry_extract(&ctxt[index]),
|
||||
|| self.key.message_extract(&ctxt[index]),
|
||||
);
|
||||
ctxt[index] = msg;
|
||||
self.key.unchecked_add_assign(&mut ctxt[index + 1], &carry);
|
||||
}
|
||||
//TODO maybe just BS to decrease the noise ?
|
||||
}
|
||||
|
||||
/// Propagate all the carries.
|
||||
pub fn full_propagate_mantissa_parallelized(&self, ctxt: &mut [shortint::Ciphertext]) {
|
||||
// todo!("full_propagate_mantissa_parallelized");
|
||||
let len = ctxt.len();
|
||||
for i in 0..len {
|
||||
self.propagate_mantissa_parallelized(ctxt, i);
|
||||
}
|
||||
}
|
||||
|
||||
// TODO use the low latency propagation
|
||||
pub fn propagate_exponent_parallelized(&self, ctxt: &mut [shortint::Ciphertext], index: usize) {
|
||||
if index < ctxt.len() - 1 {
|
||||
let (carry, msg) = rayon::join(
|
||||
|| self.key.carry_extract(&ctxt[index]),
|
||||
|| self.key.message_extract(&ctxt[index]),
|
||||
);
|
||||
ctxt[index] = msg;
|
||||
self.key.unchecked_add_assign(&mut ctxt[index + 1], &carry);
|
||||
} else {
|
||||
self.key.message_extract_assign(&mut ctxt[index]);
|
||||
}
|
||||
}
|
||||
|
||||
/// Propagate all the carries.
|
||||
/// except the msb lwe
|
||||
pub fn partial_propagate_parallelized(&self, ctxt: &mut Vec<shortint::Ciphertext>) {
|
||||
for i in 0..(ctxt.len() - 1) {
|
||||
self.propagate_exponent_parallelized(ctxt, i);
|
||||
}
|
||||
}
|
||||
|
||||
/// Propagate all the carries.
|
||||
pub fn full_propagate_exponent_parallelized(&self, ctxt: &mut [shortint::Ciphertext]) {
|
||||
for i in 0..(ctxt.len()) {
|
||||
self.propagate_exponent_parallelized(ctxt, i);
|
||||
}
|
||||
}
|
||||
|
||||
fn propagate_mantissa_increase_exponent_if_necessary_parallelized(
|
||||
&self,
|
||||
ctxt: &mut Ciphertext,
|
||||
index: usize,
|
||||
) {
|
||||
if index < ctxt.ct_vec_mantissa.len() - 1 {
|
||||
let (carry, msg) = rayon::join(
|
||||
|| self.key.carry_extract(&ctxt.ct_vec_mantissa[index]),
|
||||
|| self.key.message_extract(&ctxt.ct_vec_mantissa[index]),
|
||||
);
|
||||
ctxt.ct_vec_mantissa[index] = msg;
|
||||
self.key
|
||||
.unchecked_add_assign(&mut ctxt.ct_vec_mantissa[index + 1], &carry);
|
||||
} else {
|
||||
self.increase_exponent_if_necessary_parallelized(ctxt);
|
||||
}
|
||||
}
|
||||
|
||||
fn increase_exponent_if_necessary_parallelized(&self, ctxt: &mut Ciphertext) {
|
||||
let msg_modulus = self.wopbs_key.param.message_modulus.0 as usize;
|
||||
let car_modulus = self.wopbs_key.param.carry_modulus.0 as usize;
|
||||
let msg_space = f64::log2((msg_modulus * car_modulus) as f64) as usize;
|
||||
let len = ctxt.ct_vec_mantissa.len();
|
||||
let (carry, msg) = rayon::join(
|
||||
|| {
|
||||
self.key
|
||||
.carry_extract(&ctxt.ct_vec_mantissa.last().unwrap())
|
||||
},
|
||||
|| {
|
||||
self.key
|
||||
.message_extract(&ctxt.ct_vec_mantissa.last().clone().unwrap())
|
||||
},
|
||||
);
|
||||
|
||||
ctxt.ct_vec_mantissa[len - 1] = msg;
|
||||
let mut tmp = ctxt.clone();
|
||||
tmp.ct_vec_mantissa.push(carry.clone());
|
||||
let _ = tmp.ct_vec_mantissa.remove(0);
|
||||
self.key
|
||||
.unchecked_scalar_add_assign(&mut tmp.ct_vec_exponent[0], 1);
|
||||
let ggsw_carry = self.is_block_non_zero_ggsw_pbs_ks_cbs_parallelized(&carry, msg_space);
|
||||
let res = self.cmuxes_full_parallelized(ctxt, &tmp, &ggsw_carry);
|
||||
ctxt.ct_vec_mantissa = res.ct_vec_mantissa;
|
||||
ctxt.ct_vec_exponent = res.ct_vec_exponent;
|
||||
}
|
||||
|
||||
fn increase_exponent_if_necessary_parallelized_carry(
|
||||
&self,
|
||||
ctxt: &mut Ciphertext,
|
||||
mantissa_carry: &shortint::Ciphertext,
|
||||
) {
|
||||
let msg_modulus = self.wopbs_key.param.message_modulus.0 as usize;
|
||||
let car_modulus = self.wopbs_key.param.carry_modulus.0 as usize;
|
||||
let msg_space = (msg_modulus * car_modulus).ilog2() as usize;
|
||||
|
||||
let mut tmp = ctxt.clone();
|
||||
tmp.ct_vec_mantissa.push(mantissa_carry.clone());
|
||||
let _ = tmp.ct_vec_mantissa.remove(0);
|
||||
self.key
|
||||
.unchecked_scalar_add_assign(&mut tmp.ct_vec_exponent[0], 1);
|
||||
let ggsw_carry =
|
||||
self.is_block_non_zero_ggsw_pbs_ks_cbs_parallelized(&mantissa_carry, msg_space);
|
||||
let res = self.cmuxes_full_parallelized(ctxt, &tmp, &ggsw_carry);
|
||||
ctxt.ct_vec_mantissa = res.ct_vec_mantissa;
|
||||
ctxt.ct_vec_exponent = res.ct_vec_exponent;
|
||||
}
|
||||
|
||||
pub fn full_propagate_mantissa_increase_exponent_if_necessary_parallelized(
|
||||
&self,
|
||||
ctxt: &mut Ciphertext,
|
||||
) {
|
||||
let len = ctxt.ct_vec_mantissa.len();
|
||||
for i in 0..len {
|
||||
self.propagate_mantissa_increase_exponent_if_necessary_parallelized(ctxt, i);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn clean_degree_parallelized(&self, ctxt: &mut Ciphertext) {
|
||||
// todo!("clean_degree_parallelized");
|
||||
self.reduce_noise_sign(ctxt);
|
||||
// let now = std::time::Instant::now();
|
||||
self.full_propagate_exponent_parallelized(&mut ctxt.ct_vec_exponent);
|
||||
// let elapsed = now.elapsed();
|
||||
// println!("elapsed exponent propagate: {elapsed:?}");
|
||||
|
||||
// let now = std::time::Instant::now();
|
||||
self.full_propagate_mantissa_increase_exponent_if_necessary_parallelized(ctxt);
|
||||
// let elapsed = now.elapsed();
|
||||
// println!("elapsed mantissa propagate: {elapsed:?}");
|
||||
}
|
||||
}
|
||||
373
concrete-float/src/server_key/mul.rs
Normal file
373
concrete-float/src/server_key/mul.rs
Normal file
@@ -0,0 +1,373 @@
|
||||
use crate::server_key::Ciphertext;
|
||||
use crate::ServerKey;
|
||||
use std::cmp::{max, min};
|
||||
use tfhe::shortint;
|
||||
|
||||
impl ServerKey {
|
||||
pub fn mul(&self, ct1: &mut Ciphertext, ct2: &mut Ciphertext) -> Ciphertext {
|
||||
// carry need to be empty
|
||||
for ct in ct1.ct_vec_mantissa.iter_mut() {
|
||||
if ct.degree.get() > self.wopbs_key.param.message_modulus.0 {
|
||||
self.full_propagate_mantissa(&mut ct1.ct_vec_mantissa);
|
||||
break;
|
||||
}
|
||||
}
|
||||
for ct in ct2.ct_vec_mantissa.iter() {
|
||||
if ct.degree.get() > self.wopbs_key.param.message_modulus.0 {
|
||||
self.full_propagate_mantissa(&mut ct2.ct_vec_mantissa);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
let mut res = self.mul_mantissa(ct1, ct2);
|
||||
res = self.add_exponent_for_mul(&mut res.clone(), ct2);
|
||||
res.ct_sign = self.add_sign_for_mul(ct1, ct2);
|
||||
res
|
||||
}
|
||||
|
||||
pub fn mul_parallelized(
|
||||
&self,
|
||||
ct1: &mut Ciphertext,
|
||||
ct2: &mut Ciphertext,
|
||||
) -> (Ciphertext, shortint::Ciphertext) {
|
||||
// let now = std::time::Instant::now();
|
||||
let (mut res, mantissa_carry) = self.mul_mantissa_parallelized(ct1, ct2);
|
||||
// let elapsed = now.elapsed();
|
||||
// println!("mul_mantissa: {elapsed:?}");
|
||||
|
||||
res = self.add_exponent_for_mul_parallelized(&mut res.clone(), ct2, &mantissa_carry);
|
||||
res.ct_sign = self.add_sign_for_mul_parallelized(ct1, ct2);
|
||||
(res, mantissa_carry)
|
||||
}
|
||||
|
||||
fn mul_mantissa(&self, ct1: &mut Ciphertext, ct2: &mut Ciphertext) -> Ciphertext {
|
||||
let mantissa_len = ct1.ct_vec_mantissa.len();
|
||||
let value = (mantissa_len - 1) / 2;
|
||||
let mut result = self.create_trivial_zero(
|
||||
2 * mantissa_len - value - 1,
|
||||
ct1.ct_vec_exponent.len(),
|
||||
ct1.e_min,
|
||||
);
|
||||
|
||||
for (i, ct2_i) in ct2.ct_vec_mantissa.iter().enumerate() {
|
||||
let bound = max((value - i) as i64, 0) as usize;
|
||||
let tmp = self.block_mul(
|
||||
&ct1.ct_vec_mantissa[bound..].to_vec(),
|
||||
ct2_i,
|
||||
i,
|
||||
ct1.ct_vec_mantissa.len(),
|
||||
);
|
||||
if !self.is_add_possible(
|
||||
&tmp,
|
||||
&result.ct_vec_mantissa
|
||||
[min(0, (value - i) as i64).abs() as usize..(i + mantissa_len - value)],
|
||||
) {
|
||||
// we propagate only the necessary blocks,
|
||||
// to not loose any information, we propagate one blocks before and one blocks after
|
||||
self.full_propagate_mantissa(
|
||||
&mut result.ct_vec_mantissa[min(0, (value + 1 - i) as i64).abs() as usize
|
||||
..min(i + mantissa_len + 2 - value, 2 * mantissa_len - 1 - value)],
|
||||
);
|
||||
//self.full_propagate_mantissa(&mut result.ct_vec_mantissa);
|
||||
}
|
||||
for (ct_left_j, ct_right_j) in result.ct_vec_mantissa[min(0, (value - i) as i64).abs()
|
||||
as usize
|
||||
..min(i + mantissa_len + 1 - value, 2 * mantissa_len - 1 - value)]
|
||||
.iter_mut()
|
||||
.zip(tmp.iter())
|
||||
{
|
||||
self.key.unchecked_add_assign(ct_left_j, ct_right_j);
|
||||
}
|
||||
}
|
||||
|
||||
// the (log_msg_modulus * mantissa.len()) most significant bit of a multiplication are
|
||||
// include either in the [mantissa_len, 2*mantissa_len] or in [mantissa_len - 1,
|
||||
// 2*mantissa_len - 1] we choose the first one if the block 2*mantissa_len is not
|
||||
// empty otherwise we choose the first one
|
||||
let mut result_trunc = self.create_trivial_zero_from_ct(ct1);
|
||||
result_trunc.ct_vec_mantissa =
|
||||
result.ct_vec_mantissa[(mantissa_len - 1 - value)..].to_vec();
|
||||
result_trunc.ct_vec_exponent = ct1.ct_vec_exponent.clone();
|
||||
|
||||
result_trunc
|
||||
}
|
||||
|
||||
// Return the float ciphertext and the mantissa carry
|
||||
fn mul_mantissa_parallelized(
|
||||
&self,
|
||||
ct1: &Ciphertext,
|
||||
ct2: &Ciphertext,
|
||||
) -> (Ciphertext, shortint::Ciphertext) {
|
||||
use tfhe::integer::{IntegerCiphertext, IntegerRadixCiphertext, RadixCiphertext};
|
||||
|
||||
let mantissa_len = ct1.ct_vec_mantissa.len();
|
||||
let mantissa_len_for_mul_with_carry = mantissa_len * 2;
|
||||
let mut ct1_mantissa = ct1.ct_vec_mantissa.to_vec();
|
||||
ct1_mantissa.resize(mantissa_len_for_mul_with_carry, self.key.create_trivial(0));
|
||||
let mut ct2_mantissa = ct2.ct_vec_mantissa.to_vec();
|
||||
ct2_mantissa.resize(mantissa_len_for_mul_with_carry, self.key.create_trivial(0));
|
||||
let ct1_mantissa_as_integer = RadixCiphertext::from_blocks(ct1_mantissa);
|
||||
let ct2_mantissa_as_integer = RadixCiphertext::from_blocks(ct2_mantissa);
|
||||
|
||||
// println!("ct1_len = {}", ct1_mantissa_as_integer.blocks().len());
|
||||
// println!("ct2_len = {}", ct2_mantissa_as_integer.blocks().len());
|
||||
|
||||
// let now = std::time::Instant::now();
|
||||
let mul_result = self
|
||||
.integer_key
|
||||
.mul_parallelized(&ct1_mantissa_as_integer, &ct2_mantissa_as_integer);
|
||||
// let elapsed = now.elapsed();
|
||||
// println!("integer mul: {elapsed:?}");
|
||||
|
||||
let mut mul_result_blocks = mul_result.into_blocks();
|
||||
let carry_block = mul_result_blocks.pop().unwrap();
|
||||
let mantissa = mul_result_blocks[mantissa_len - 1..].to_vec();
|
||||
assert_eq!(mantissa.len(), ct1.ct_vec_mantissa.len());
|
||||
let mut result_trunc = self.create_trivial_zero_from_ct(ct1);
|
||||
result_trunc.ct_vec_mantissa = mantissa;
|
||||
result_trunc.ct_vec_exponent = ct1.ct_vec_exponent.clone();
|
||||
|
||||
(result_trunc, carry_block)
|
||||
}
|
||||
|
||||
// multiply one block of a mantissa by each block of another mantissa and create a mantissa of
|
||||
// this mul
|
||||
fn block_mul(
|
||||
&self,
|
||||
ct1: &Vec<shortint::ciphertext::Ciphertext>,
|
||||
ct2: &shortint::ciphertext::Ciphertext,
|
||||
index: usize,
|
||||
len_man: usize,
|
||||
) -> Vec<shortint::ciphertext::Ciphertext> {
|
||||
let zero = self.key.create_trivial(0);
|
||||
let mut result = vec![zero.clone()];
|
||||
let mut result_lsb = ct1.clone();
|
||||
let mut result_msb = ct1.clone();
|
||||
if index != len_man - 1 {
|
||||
for (ct_lsb_i, ct_msb_i) in result_lsb.iter_mut().zip(result_msb.iter_mut()) {
|
||||
self.key.unchecked_mul_msb_assign(ct_msb_i, ct2);
|
||||
self.key.unchecked_mul_lsb_assign(ct_lsb_i, ct2);
|
||||
}
|
||||
result_lsb.push(zero.clone());
|
||||
result.append(&mut result_msb.clone());
|
||||
} else {
|
||||
for (ct_lsb_i, ct_msb_i) in result_lsb[..len_man - 1]
|
||||
.iter_mut()
|
||||
.zip(result_msb[..len_man - 1].iter_mut())
|
||||
{
|
||||
self.key.unchecked_mul_msb_assign(ct_msb_i, ct2);
|
||||
self.key.unchecked_mul_lsb_assign(ct_lsb_i, ct2);
|
||||
}
|
||||
|
||||
let msg_mod = self.key.message_modulus.0 as u64;
|
||||
let tmp = self.key.unchecked_scalar_mul(ct2, msg_mod as u8);
|
||||
self.key
|
||||
.unchecked_add_assign(result_lsb.last_mut().unwrap(), &tmp);
|
||||
|
||||
// Generate the accumulator for the multiplication
|
||||
let acc = self
|
||||
.key
|
||||
.generate_lookup_table(|x| (x / msg_mod) * (x % msg_mod));
|
||||
self.key
|
||||
.apply_lookup_table_assign(result_lsb.last_mut().unwrap(), &acc);
|
||||
|
||||
result.append(&mut result_msb.clone());
|
||||
result.pop();
|
||||
}
|
||||
|
||||
for (ct1_i, ct2_i) in result.iter_mut().zip(result_lsb.iter()) {
|
||||
self.key.unchecked_add_assign(ct1_i, ct2_i)
|
||||
}
|
||||
result
|
||||
}
|
||||
|
||||
//sum the two sign for the mul
|
||||
fn add_sign_for_mul(
|
||||
&self,
|
||||
ct1: &mut Ciphertext,
|
||||
ct2: &mut Ciphertext,
|
||||
) -> shortint::ciphertext::Ciphertext {
|
||||
if self
|
||||
.key
|
||||
.is_add_possible(&ct1.ct_sign, &ct2.ct_sign)
|
||||
.is_err()
|
||||
{
|
||||
self.reduce_noise_sign(ct1);
|
||||
self.reduce_noise_sign(ct2);
|
||||
}
|
||||
self.key.unchecked_add(&ct1.ct_sign, &ct2.ct_sign)
|
||||
}
|
||||
|
||||
fn add_sign_for_mul_parallelized(
|
||||
&self,
|
||||
ct1: &mut Ciphertext,
|
||||
ct2: &mut Ciphertext,
|
||||
) -> shortint::ciphertext::Ciphertext {
|
||||
if self
|
||||
.key
|
||||
.is_add_possible(&ct1.ct_sign, &ct2.ct_sign)
|
||||
.is_err()
|
||||
{
|
||||
rayon::join(
|
||||
|| self.reduce_noise_sign(ct1),
|
||||
|| self.reduce_noise_sign(ct2),
|
||||
);
|
||||
}
|
||||
self.key.unchecked_add(&ct1.ct_sign, &ct2.ct_sign)
|
||||
}
|
||||
|
||||
// add the two exponent and subtract the value e_min and the shift on the MSB blocks
|
||||
fn add_exponent_for_mul(&self, ct1: &mut Ciphertext, ct2: &mut Ciphertext) -> Ciphertext {
|
||||
let msg_modulus = self.wopbs_key.param.message_modulus.0 as u64;
|
||||
let carry_modulus = self.wopbs_key.param.carry_modulus.0 as u64;
|
||||
let log_msg_modulus = f64::log2(msg_modulus as f64) as u64;
|
||||
let log_msg_space = f64::log2((carry_modulus * msg_modulus) as f64) as usize;
|
||||
let len_vec_exp = ct1.ct_vec_exponent.len();
|
||||
|
||||
if !self.is_add_possible(&ct1.ct_vec_exponent, &ct2.ct_vec_exponent) {
|
||||
self.partial_propagate(&mut ct1.ct_vec_exponent);
|
||||
self.partial_propagate(&mut ct2.ct_vec_exponent);
|
||||
}
|
||||
let mut res = ct1.clone();
|
||||
for (ct_left_j, ct_right_j) in res
|
||||
.ct_vec_exponent
|
||||
.iter_mut()
|
||||
.zip(ct2.ct_vec_exponent.iter())
|
||||
{
|
||||
self.key.unchecked_add_assign(ct_left_j, ct_right_j);
|
||||
}
|
||||
let cst = ct1.e_min + ct1.ct_vec_mantissa.len() as i64 - 1;
|
||||
let cst = (cst.abs() as u64) >> (log_msg_modulus * (len_vec_exp - 1) as u64);
|
||||
|
||||
//check if the exponent is big enough (return 1 if e is to small, 0 otherwise)
|
||||
let accumulator = self.key.generate_lookup_table(|x| ((x < cst) as u64));
|
||||
let mut ct_sign = self
|
||||
.key
|
||||
.apply_lookup_table(&mut res.ct_vec_exponent.last().unwrap(), &accumulator);
|
||||
|
||||
//check if the mantissa is not equals to zero (return 1 if ms_lwe== 0, 0 otherwise)
|
||||
let accumulator = self.key.generate_lookup_table(|x| ((x == 0) as u64));
|
||||
let ms_lwe = self
|
||||
.key
|
||||
.apply_lookup_table(&mut ct1.ct_vec_mantissa.last().unwrap(), &accumulator);
|
||||
self.key.unchecked_add_assign(&mut ct_sign, &ms_lwe);
|
||||
|
||||
let accumulator = self.key.generate_lookup_table(|x| ((x > 0) as u64));
|
||||
let ct_sign = self.key.apply_lookup_table(&mut ct_sign, &accumulator);
|
||||
|
||||
let sign_ggsw = self.ggsw_ks_cbs(&ct_sign, log_msg_space);
|
||||
let zero = self.create_trivial_zero_from_ct(ct1);
|
||||
|
||||
let accumulator = self.key.generate_lookup_table(|x| (x - cst) % msg_modulus);
|
||||
self.key
|
||||
.apply_lookup_table_assign(&mut res.ct_vec_exponent[len_vec_exp - 1], &accumulator);
|
||||
res = self.cmuxes_full(&res, &zero, &sign_ggsw);
|
||||
res
|
||||
}
|
||||
|
||||
// add the two exponent and subtract the value e_min and the shift on the MSB blocks
|
||||
fn add_exponent_for_mul_parallelized(
|
||||
&self,
|
||||
ct1: &mut Ciphertext,
|
||||
ct2: &mut Ciphertext,
|
||||
mantissa_carry: &shortint::Ciphertext,
|
||||
) -> Ciphertext {
|
||||
let msg_modulus = self.wopbs_key.param.message_modulus.0 as u64;
|
||||
let carry_modulus = self.wopbs_key.param.carry_modulus.0 as u64;
|
||||
let log_msg_modulus = msg_modulus.ilog2() as u64;
|
||||
let log_msg_space = (carry_modulus * msg_modulus).ilog2() as usize;
|
||||
let len_vec_exp = ct1.ct_vec_exponent.len();
|
||||
|
||||
if !self.is_add_possible(&ct1.ct_vec_exponent, &ct2.ct_vec_exponent) {
|
||||
rayon::join(
|
||||
|| self.partial_propagate(&mut ct1.ct_vec_exponent),
|
||||
|| self.partial_propagate(&mut ct2.ct_vec_exponent),
|
||||
);
|
||||
}
|
||||
let mut res = ct1.clone();
|
||||
for (ct_left_j, ct_right_j) in res
|
||||
.ct_vec_exponent
|
||||
.iter_mut()
|
||||
.zip(ct2.ct_vec_exponent.iter())
|
||||
{
|
||||
self.key.unchecked_add_assign(ct_left_j, ct_right_j);
|
||||
}
|
||||
let cst = ct1.e_min + ct1.ct_vec_mantissa.len() as i64 - 1;
|
||||
let cst = (cst.abs() as u64) >> (log_msg_modulus * (len_vec_exp - 1) as u64);
|
||||
|
||||
let (mut ct_sign, ms_lwe) = rayon::join(
|
||||
|| {
|
||||
//check if the exponent is big enough (return 1 if e is to small, 0 otherwise)
|
||||
let accumulator = self.key.generate_lookup_table(|x| ((x < cst) as u64));
|
||||
self.key
|
||||
.apply_lookup_table(&mut res.ct_vec_exponent.last().unwrap(), &accumulator)
|
||||
},
|
||||
|| {
|
||||
//check if the mantissa is not equals to zero (return 1 if ms_lwe== 0, 0 otherwise)
|
||||
let accumulator = self.key.generate_lookup_table(|x| ((x == 0) as u64));
|
||||
let mut last_mantissa_block = ct1.ct_vec_mantissa.last().unwrap().clone();
|
||||
// We recreate a mantissa block containing the msg + carry as we only want to know
|
||||
// if it was 0
|
||||
self.key
|
||||
.unchecked_add_assign(&mut last_mantissa_block, &mantissa_carry);
|
||||
self.key
|
||||
.apply_lookup_table(&last_mantissa_block, &accumulator)
|
||||
},
|
||||
);
|
||||
|
||||
self.key.unchecked_add_assign(&mut ct_sign, &ms_lwe);
|
||||
|
||||
rayon::join(
|
||||
|| {
|
||||
let accumulator = self.key.generate_lookup_table(|x| ((x > 0) as u64));
|
||||
self.key
|
||||
.apply_lookup_table_assign(&mut ct_sign, &accumulator);
|
||||
},
|
||||
|| {
|
||||
let accumulator = self.key.generate_lookup_table(|x| (x - cst) % msg_modulus);
|
||||
self.key.apply_lookup_table_assign(
|
||||
&mut res.ct_vec_exponent[len_vec_exp - 1],
|
||||
&accumulator,
|
||||
);
|
||||
},
|
||||
);
|
||||
|
||||
let sign_ggsw = self.ggsw_ks_cbs_parallelized(&ct_sign, log_msg_space);
|
||||
|
||||
let zero = self.create_trivial_zero_from_ct(ct1);
|
||||
res = self.cmuxes_full_parallelized(&res, &zero, &sign_ggsw);
|
||||
res
|
||||
}
|
||||
|
||||
pub fn mul_total(&self, ct1: &Ciphertext, ct2: &Ciphertext) -> Ciphertext {
|
||||
let mut res = self.mul(&mut ct1.clone(), &mut ct2.clone());
|
||||
self.clean_degree(&mut res);
|
||||
res
|
||||
}
|
||||
|
||||
pub fn mul_total_parallelized(&self, ct1: &Ciphertext, ct2: &Ciphertext) -> Ciphertext {
|
||||
// let now = std::time::Instant::now();
|
||||
let (mut res, mantissa_carry) = self.mul_parallelized(&mut ct1.clone(), &mut ct2.clone());
|
||||
// let elapsed = now.elapsed();
|
||||
// println!("mul_parallelized: {elapsed:?}");
|
||||
|
||||
// self.clean_degree_parallelized(&mut res);
|
||||
|
||||
self.reduce_noise_sign(&mut res);
|
||||
// let now = std::time::Instant::now();
|
||||
self.full_propagate_exponent_parallelized(&mut res.ct_vec_exponent);
|
||||
// let elapsed = now.elapsed();
|
||||
// println!("elapsed exponent propagate: {elapsed:?}");
|
||||
|
||||
// let now = std::time::Instant::now();
|
||||
// No need to propagate the mantissa it is clean after the integer mul parallelized
|
||||
// self.full_propagate_mantissa_increase_exponent_if_necessary_parallelized(&mut res);
|
||||
|
||||
// TODO change the management of the carry
|
||||
self.increase_exponent_if_necessary_parallelized_carry(&mut res, &mantissa_carry);
|
||||
// let elapsed = now.elapsed();
|
||||
// println!("elapsed mantissa propagate: {elapsed:?}");
|
||||
|
||||
res
|
||||
}
|
||||
}
|
||||
10
concrete-float/src/server_key/relu.rs
Normal file
10
concrete-float/src/server_key/relu.rs
Normal file
@@ -0,0 +1,10 @@
|
||||
use crate::server_key::Ciphertext;
|
||||
use crate::ServerKey;
|
||||
|
||||
impl ServerKey {
|
||||
pub fn relu(&self, ct: &Ciphertext) -> Ciphertext {
|
||||
let zero = self.create_trivial_zero_from_ct(ct);
|
||||
let ggsw = self.ggsw_ks_cbs(&ct.ct_sign, 0);
|
||||
self.cmuxes_full(&ct, &zero, &ggsw)
|
||||
}
|
||||
}
|
||||
43
concrete-float/src/server_key/sigmoid.rs
Normal file
43
concrete-float/src/server_key/sigmoid.rs
Normal file
@@ -0,0 +1,43 @@
|
||||
use crate::ciphertext::Ciphertext;
|
||||
use crate::server_key::ServerKey;
|
||||
|
||||
impl ServerKey {
|
||||
pub fn sigmoid(&self, ct: &Ciphertext) -> Ciphertext {
|
||||
let msg_modulus = self.wopbs_key.param.message_modulus.0 as u64;
|
||||
let carry_modulus = self.wopbs_key.param.carry_modulus.0 as u64;
|
||||
let log_msg_modulus = f64::log2(msg_modulus as f64) as u64;
|
||||
let log_carry_modulus = f64::log2(carry_modulus as f64) as u64;
|
||||
let cst = ct.e_min + ct.ct_vec_mantissa.len() as i64 - 1;
|
||||
let cst = (cst.abs() as u64) >> (log_msg_modulus * (ct.ct_vec_exponent.len() - 1) as u64);
|
||||
|
||||
let mut one = self.create_trivial_zero_from_ct(ct);
|
||||
self.key
|
||||
.unchecked_scalar_add_assign(&mut one.ct_vec_mantissa.last_mut().unwrap(), 1 as u8);
|
||||
self.key
|
||||
.unchecked_scalar_add_assign(&mut one.ct_vec_exponent.last_mut().unwrap(), cst as u8);
|
||||
|
||||
let mut minus_one = one.clone();
|
||||
self.change_sign_assign(&mut minus_one);
|
||||
let ggsw = self.ggsw_ks_cbs(&ct.ct_sign, 0);
|
||||
let tmp = self.cmuxes_full(&one, &minus_one, &ggsw);
|
||||
|
||||
let value = msg_modulus / 2;
|
||||
let accumulator = self.key.generate_lookup_table(|x| (x > value) as u64);
|
||||
let ct_last = self
|
||||
.key
|
||||
.apply_lookup_table(&mut ct.ct_vec_mantissa.last().unwrap(), &accumulator);
|
||||
|
||||
//check if the exponent is big enough (return 1 if e is to small, 0 otherwise)
|
||||
let accumulator = self.key.generate_lookup_table(|x| ((x < cst) as u64));
|
||||
let mut ct_sign = self
|
||||
.key
|
||||
.apply_lookup_table(&mut ct.ct_vec_exponent.last().unwrap(), &accumulator);
|
||||
|
||||
self.key.unchecked_add_assign(&mut ct_sign, &ct_last);
|
||||
let accumulator = self.key.generate_lookup_table(|x| ((x > 0) as u64));
|
||||
let ct_sign = self.key.apply_lookup_table(&mut ct_sign, &accumulator);
|
||||
|
||||
let ggsw = self.ggsw_ks_cbs(&ct_sign, (log_carry_modulus + log_msg_modulus) as usize);
|
||||
self.cmuxes_full(&tmp, &ct, &ggsw)
|
||||
}
|
||||
}
|
||||
448
concrete-float/src/server_key/sub.rs
Normal file
448
concrete-float/src/server_key/sub.rs
Normal file
@@ -0,0 +1,448 @@
|
||||
use crate::ciphertext::Ciphertext;
|
||||
use crate::ServerKey;
|
||||
use rayon::prelude::*;
|
||||
use shortint::ciphertext::Degree;
|
||||
use std::cmp::max;
|
||||
use tfhe::core_crypto::prelude::{Cleartext, Plaintext};
|
||||
use tfhe::shortint;
|
||||
|
||||
impl ServerKey {
|
||||
// This operation return |a - b| and sing(a-b)
|
||||
// after sub all the blocks have the smallest degree except the most significant block
|
||||
pub fn sub(
|
||||
&self,
|
||||
ctxt_left: &Vec<shortint::Ciphertext>,
|
||||
ctxt_right: &Vec<shortint::Ciphertext>,
|
||||
) -> (Vec<shortint::Ciphertext>, shortint::Ciphertext) {
|
||||
let mut ct_tmp: Vec<shortint::Ciphertext> = Vec::new();
|
||||
let msg_modulus = self.wopbs_key.param.message_modulus.0 as u64;
|
||||
let car_modulus = self.wopbs_key.param.carry_modulus.0 as u64;
|
||||
let msg_space = (msg_modulus * car_modulus) as u64;
|
||||
let size_ct = ctxt_left.len();
|
||||
for ct in ctxt_left.iter() {
|
||||
ct_tmp.push(
|
||||
self.key
|
||||
.unchecked_scalar_add(ct, ((msg_space / 2) - car_modulus / 2) as u8),
|
||||
);
|
||||
}
|
||||
|
||||
self.key
|
||||
.unchecked_scalar_add_assign(&mut ct_tmp[0], (car_modulus / 2) as u8);
|
||||
let cpy_right = ctxt_right.clone();
|
||||
for (c_left, c_right) in ct_tmp.iter_mut().zip(cpy_right.iter()) {
|
||||
tfhe::core_crypto::algorithms::lwe_ciphertext_sub_assign(&mut c_left.ct, &c_right.ct);
|
||||
let noise_level = c_left.noise_level() + c_right.noise_level();
|
||||
c_left.set_noise_level(noise_level);
|
||||
}
|
||||
self.partial_propagate(&mut ct_tmp);
|
||||
//extract the sign (the first value add on the most significant block)
|
||||
let accumulator = self.key.generate_lookup_table(|x| (x & (msg_space / 2)));
|
||||
let mut sign = self
|
||||
.key
|
||||
.apply_lookup_table(ct_tmp.last_mut().unwrap(), &accumulator);
|
||||
// the value sign encrypt only 1 or 0 so the degree is 1
|
||||
|
||||
// We can always add as the sign is managed on the padding bit, the only important thing is
|
||||
// the noise
|
||||
sign.degree = Degree::new(0);
|
||||
|
||||
// add the sign on each block
|
||||
for i in 0..(size_ct - 1) {
|
||||
self.key.unchecked_add_assign(&mut ct_tmp[i], &sign);
|
||||
}
|
||||
|
||||
// if the sign on each block ==0, we take the opposite, otherwise we return the value.
|
||||
// to find the opposite we perform the same idea than the subtraction (but only with pbs as
|
||||
// we know one value ) opposite = (1 << (len * precision)) - x
|
||||
for (i, ct) in ct_tmp.iter_mut().enumerate() {
|
||||
if i == 0 {
|
||||
let accumulator = self.key.generate_lookup_table(|x| {
|
||||
(((x - (msg_space / 2)) - (msg_modulus - x))
|
||||
* ((x & (msg_space / 2)) / (msg_space / 2)))
|
||||
+ (msg_modulus - x)
|
||||
});
|
||||
self.key.apply_lookup_table_assign(ct, &accumulator);
|
||||
ct.degree = Degree::new(msg_modulus as usize)
|
||||
} else if i == size_ct - 1 {
|
||||
let accumulator = self.key.generate_lookup_table(|x| {
|
||||
(((x - (msg_space / 2)) - (msg_space / 2 - x - 1))
|
||||
* ((x & (msg_space / 2)) / (msg_space / 2)))
|
||||
+ (msg_space / 2 - x - 1)
|
||||
});
|
||||
self.key.apply_lookup_table_assign(ct, &accumulator);
|
||||
ct.degree = Degree::new(max(
|
||||
(msg_space as usize / 2) - ct.degree.get(),
|
||||
ct.degree.get(),
|
||||
));
|
||||
} else {
|
||||
let accumulator = self.key.generate_lookup_table(|x| {
|
||||
(((x - (msg_space / 2)) - (msg_modulus - x - 1))
|
||||
* ((x & (msg_space / 2)) / (msg_space / 2)))
|
||||
+ (msg_modulus - x - 1)
|
||||
});
|
||||
self.key.apply_lookup_table_assign(ct, &accumulator);
|
||||
ct.degree = Degree::new(msg_modulus as usize)
|
||||
}
|
||||
}
|
||||
// move the sign bit on the msb
|
||||
// uncheck add, we juste create the sign
|
||||
tfhe::core_crypto::algorithms::lwe_ciphertext_cleartext_mul_assign(
|
||||
&mut sign.ct,
|
||||
Cleartext(2),
|
||||
);
|
||||
//self.key.unchecked_scalar_mul_assign(&mut sign, 2);
|
||||
(ct_tmp, sign)
|
||||
}
|
||||
|
||||
// subtract the two mantissas
|
||||
// after the subtraction put the msb of the result on the mst significant block
|
||||
// if exponent == 0 and the first block == 0, the result is 0
|
||||
pub fn sub_mantissa(&self, ctxt_left: &Ciphertext, ctxt_right: &Ciphertext) -> Ciphertext {
|
||||
let msg_modulus = self.wopbs_key.param.message_modulus.0 as u64;
|
||||
let car_modulus = self.wopbs_key.param.carry_modulus.0 as u64;
|
||||
let msg_space = (msg_modulus * car_modulus) as usize;
|
||||
let (res, sign) = self.sub(&ctxt_left.ct_vec_mantissa, &ctxt_right.ct_vec_mantissa);
|
||||
|
||||
let mut new = self.create_trivial_zero_from_ct(ctxt_left);
|
||||
new.ct_vec_mantissa = res;
|
||||
new.ct_vec_exponent = ctxt_left.ct_vec_exponent.clone();
|
||||
// if sign == 0 => need to change the sign of the operation
|
||||
// if sign == 1 we want to keep the same sign
|
||||
// new_s = old_s + sign + 1
|
||||
new.ct_sign = self.key.unchecked_add(&sign, ctxt_left.sign());
|
||||
self.key
|
||||
.unchecked_scalar_add_assign(&mut new.ct_sign, msg_space as u8);
|
||||
|
||||
new = self.realign_sub(&new);
|
||||
new
|
||||
}
|
||||
|
||||
// move the msb on the most significant block.
|
||||
// if e = 0 and the first block is empty, return zero
|
||||
// (no subnormal value)
|
||||
pub fn realign_sub(&self, ct0: &Ciphertext) -> Ciphertext {
|
||||
let msg_modulus = self.wopbs_key.param.message_modulus.0 as usize;
|
||||
let car_modulus = self.wopbs_key.param.carry_modulus.0 as usize;
|
||||
let msg_space = f64::log2((msg_modulus * car_modulus) as f64) as usize;
|
||||
let size_mantissa = ct0.ct_vec_mantissa.len();
|
||||
|
||||
let zero = self.create_trivial_zero_from_ct(ct0);
|
||||
let mut res = zero.clone();
|
||||
res.ct_vec_mantissa = ct0.ct_vec_mantissa.clone();
|
||||
res.ct_sign = ct0.ct_sign.clone();
|
||||
let mut msb_mantissa_ggsw =
|
||||
self.ggsw_pbs_ks_cbs(&res.ct_vec_mantissa[size_mantissa - 1], msg_space);
|
||||
for i in 0..size_mantissa {
|
||||
let mut tmp = zero.clone();
|
||||
tmp.ct_vec_mantissa = zero.ct_vec_mantissa.clone();
|
||||
for j in 0..(size_mantissa - 1) {
|
||||
tmp.ct_vec_mantissa[j + 1] = res.ct_vec_mantissa[j].clone();
|
||||
}
|
||||
for (k, ct_exp_i) in tmp.ct_vec_exponent.iter_mut().enumerate() {
|
||||
self.key.unchecked_scalar_add_assign(
|
||||
ct_exp_i,
|
||||
(((i + 1) >> (f64::log2(msg_modulus as f64) as usize * (k))) % msg_modulus)
|
||||
as u8,
|
||||
);
|
||||
}
|
||||
|
||||
// return tmp if ggsw == 0; res otherwise
|
||||
res.ct_vec_mantissa = self.cmuxes(
|
||||
&tmp.ct_vec_mantissa,
|
||||
&res.ct_vec_mantissa,
|
||||
&msb_mantissa_ggsw,
|
||||
);
|
||||
res.ct_vec_exponent = self.cmuxes(
|
||||
&tmp.ct_vec_exponent,
|
||||
&res.ct_vec_exponent,
|
||||
&msb_mantissa_ggsw,
|
||||
);
|
||||
|
||||
if i < size_mantissa - 1 {
|
||||
msb_mantissa_ggsw =
|
||||
self.ggsw_pbs_ks_cbs(&res.ct_vec_mantissa[size_mantissa - 1], msg_space);
|
||||
}
|
||||
}
|
||||
|
||||
let (mut diff_exp, sub_exp_sign) = self.sub(&ct0.ct_vec_exponent, &res.ct_vec_exponent);
|
||||
|
||||
// message space == 0 because the sign is on the padding bit
|
||||
let sign_ggsw = self.ggsw_ks_cbs(&sub_exp_sign, 0); //let sign_ggsw = self.wopbs_key.extract_one_bit_cbs(&self.key, &sub_exp_sign, 63);
|
||||
diff_exp = self.cmuxes(&zero.ct_vec_exponent, &diff_exp, &msb_mantissa_ggsw);
|
||||
res.ct_vec_exponent = self.cmuxes(&zero.ct_vec_exponent, &diff_exp, &sign_ggsw);
|
||||
res.ct_vec_mantissa = self.cmuxes(&zero.ct_vec_mantissa, &res.ct_vec_mantissa, &sign_ggsw);
|
||||
res.ct_sign = res.ct_sign;
|
||||
res
|
||||
}
|
||||
|
||||
// change the sign
|
||||
pub fn change_sign_assign(&self, ct0: &mut Ciphertext) {
|
||||
tfhe::core_crypto::algorithms::lwe_ciphertext_plaintext_add_assign(
|
||||
&mut ct0.ct_sign.ct,
|
||||
Plaintext(1 << 63),
|
||||
);
|
||||
}
|
||||
|
||||
pub fn change_sign(&self, ct0: &Ciphertext) -> Ciphertext {
|
||||
let mut ct = ct0.clone();
|
||||
self.change_sign_assign(&mut ct);
|
||||
ct
|
||||
}
|
||||
|
||||
pub fn sub_total(&self, ct1: &Ciphertext, ct2: &Ciphertext) -> Ciphertext {
|
||||
let ct2 = self.change_sign(ct2);
|
||||
self.add_total(&ct1, &ct2)
|
||||
}
|
||||
|
||||
// This operation return |a - b| and sing(a-b)
|
||||
// after sub all the blocks have the smallest degree except the most significant block
|
||||
// TODO: would the overflowing_sub from integer (with some slight adaptations perhaps) do the
|
||||
// trick ?
|
||||
pub fn abs_diff_parallelized(
|
||||
&self,
|
||||
ctxt_left: &Vec<shortint::Ciphertext>,
|
||||
ctxt_right: &Vec<shortint::Ciphertext>,
|
||||
) -> (Vec<shortint::Ciphertext>, shortint::Ciphertext) {
|
||||
let mut ct_tmp: Vec<shortint::Ciphertext> = Vec::with_capacity(ctxt_left.len());
|
||||
let msg_modulus = self.wopbs_key.param.message_modulus.0 as u64;
|
||||
let car_modulus = self.wopbs_key.param.carry_modulus.0 as u64;
|
||||
let msg_space = (msg_modulus * car_modulus) as u64;
|
||||
let size_ct = ctxt_left.len();
|
||||
for ct in ctxt_left.iter() {
|
||||
ct_tmp.push(
|
||||
self.key
|
||||
.unchecked_scalar_add(ct, ((msg_space / 2) - car_modulus / 2) as u8),
|
||||
);
|
||||
}
|
||||
|
||||
self.key
|
||||
.unchecked_scalar_add_assign(&mut ct_tmp[0], (car_modulus / 2) as u8);
|
||||
let cpy_right = ctxt_right.clone();
|
||||
// The operation is too small to be worth parallelizing
|
||||
ct_tmp
|
||||
.iter_mut()
|
||||
.zip(cpy_right.iter())
|
||||
.for_each(|(c_left, c_right)| {
|
||||
tfhe::core_crypto::algorithms::lwe_ciphertext_sub_assign(
|
||||
&mut c_left.ct,
|
||||
&c_right.ct,
|
||||
);
|
||||
let noise_level = c_left.noise_level() + c_right.noise_level();
|
||||
c_left.set_noise_level(noise_level);
|
||||
});
|
||||
|
||||
self.partial_propagate_parallelized(&mut ct_tmp);
|
||||
//extract the sign (the first value add on the most significant block)
|
||||
let accumulator = self.key.generate_lookup_table(|x| (x & (msg_space / 2)));
|
||||
let mut sign = self
|
||||
.key
|
||||
.apply_lookup_table(ct_tmp.last_mut().unwrap(), &accumulator);
|
||||
// the value sign encrypt only 1 or 0 so the degree is 1
|
||||
|
||||
// We can always add as the sign is managed on the padding bit, the only important thing is
|
||||
// the noise
|
||||
sign.degree = Degree::new(0);
|
||||
|
||||
// add the sign on each block, except the last one
|
||||
// Operation is too small to be worth parallelizing
|
||||
ct_tmp[0..(size_ct - 1)]
|
||||
.iter_mut()
|
||||
.for_each(|tmp_block| self.key.unchecked_add_assign(tmp_block, &sign));
|
||||
|
||||
// if the sign on each block ==0, we take the opposite, otherwise we return the value.
|
||||
// to find the opposite we perform the same idea than the subtraction (but only with pbs as
|
||||
// we know one value ) opposite = (1 << (len * precision)) - x
|
||||
ct_tmp.par_iter_mut().enumerate().for_each(|(i, ct)| {
|
||||
if i == 0 {
|
||||
let accumulator = self.key.generate_lookup_table(|x| {
|
||||
(((x - (msg_space / 2)) - (msg_modulus - x))
|
||||
* ((x & (msg_space / 2)) / (msg_space / 2)))
|
||||
+ (msg_modulus - x)
|
||||
});
|
||||
self.key.apply_lookup_table_assign(ct, &accumulator);
|
||||
} else if i == size_ct - 1 {
|
||||
let accumulator = self.key.generate_lookup_table(|x| {
|
||||
(((x - (msg_space / 2)) - (msg_space / 2 - x - 1))
|
||||
* ((x & (msg_space / 2)) / (msg_space / 2)))
|
||||
+ (msg_space / 2 - x - 1)
|
||||
});
|
||||
self.key.apply_lookup_table_assign(ct, &accumulator);
|
||||
} else {
|
||||
let accumulator = self.key.generate_lookup_table(|x| {
|
||||
(((x - (msg_space / 2)) - (msg_modulus - x - 1))
|
||||
* ((x & (msg_space / 2)) / (msg_space / 2)))
|
||||
+ (msg_modulus - x - 1)
|
||||
});
|
||||
self.key.apply_lookup_table_assign(ct, &accumulator);
|
||||
}
|
||||
});
|
||||
|
||||
// move the sign bit on the msb
|
||||
// uncheck add, we juste create the sign
|
||||
tfhe::core_crypto::algorithms::lwe_ciphertext_cleartext_mul_assign(
|
||||
&mut sign.ct,
|
||||
Cleartext(2),
|
||||
);
|
||||
//self.key.unchecked_scalar_mul_assign(&mut sign, 2);
|
||||
(ct_tmp, sign)
|
||||
}
|
||||
|
||||
// subtract the two mantissas
|
||||
// after the subtraction put the msb of the result on the mst significant block
|
||||
// if exponent == 0 and the first block == 0, the result is 0
|
||||
pub fn sub_mantissa_parallelized(
|
||||
&self,
|
||||
ctxt_left: &Ciphertext,
|
||||
ctxt_right: &Ciphertext,
|
||||
) -> Ciphertext {
|
||||
// todo!("sub_mantissa_parallelized");
|
||||
let msg_modulus = self.wopbs_key.param.message_modulus.0 as u64;
|
||||
let car_modulus = self.wopbs_key.param.carry_modulus.0 as u64;
|
||||
let msg_space = (msg_modulus * car_modulus) as usize;
|
||||
// let now = std::time::Instant::now();
|
||||
let (res, sign) =
|
||||
self.abs_diff_parallelized(&ctxt_left.ct_vec_mantissa, &ctxt_right.ct_vec_mantissa);
|
||||
// let elapsed = now.elapsed();
|
||||
// println!("sub_mantissa_parallelized::sub_parallelized: {elapsed:?}");
|
||||
|
||||
let mut new = self.create_trivial_zero_from_ct(ctxt_left);
|
||||
new.ct_vec_mantissa = res;
|
||||
new.ct_vec_exponent = ctxt_left.ct_vec_exponent.clone();
|
||||
// if sign == 0 => need to change the sign of the operation
|
||||
// if sign == 1 we want to keep the same sign
|
||||
// new_s = old_s + sign + 1
|
||||
new.ct_sign = self.key.unchecked_add(&sign, ctxt_left.sign());
|
||||
self.key
|
||||
.unchecked_scalar_add_assign(&mut new.ct_sign, msg_space as u8);
|
||||
|
||||
// let now = std::time::Instant::now();
|
||||
new = self.realign_sub_parallelized(&new);
|
||||
// let elapsed = now.elapsed();
|
||||
// println!("sub_mantissa_parallelized::realign_sub_parallelized: {elapsed:?}");
|
||||
|
||||
new
|
||||
}
|
||||
|
||||
// move the msb on the most significant block.
|
||||
// if e = 0 and the first block is empty, return zero
|
||||
// (no subnormal value)
|
||||
pub fn realign_sub_parallelized(&self, ct0: &Ciphertext) -> Ciphertext {
|
||||
// todo!("realign_sub_parallelized");
|
||||
let msg_modulus = self.wopbs_key.param.message_modulus.0 as usize;
|
||||
let car_modulus = self.wopbs_key.param.carry_modulus.0 as usize;
|
||||
let msg_space = (msg_modulus * car_modulus).ilog2() as usize;
|
||||
let size_mantissa = ct0.ct_vec_mantissa.len();
|
||||
|
||||
let zero = self.create_trivial_zero_from_ct(ct0);
|
||||
|
||||
let cmux_tree_size = if size_mantissa.is_power_of_two() {
|
||||
size_mantissa
|
||||
} else {
|
||||
size_mantissa.next_power_of_two()
|
||||
};
|
||||
|
||||
let mut ciphertexts_to_cmux: Vec<Ciphertext> = Vec::with_capacity(cmux_tree_size);
|
||||
let mut cmux_outputs: Vec<Ciphertext> = Vec::with_capacity(cmux_tree_size / 2);
|
||||
|
||||
(0..cmux_tree_size)
|
||||
.into_par_iter()
|
||||
.map(|ciphertext_idx| {
|
||||
if ciphertext_idx < size_mantissa {
|
||||
let mut ciphertext = zero.clone();
|
||||
|
||||
for (k, ct_exp_i) in ciphertext.ct_vec_exponent.iter_mut().enumerate() {
|
||||
self.key.unchecked_scalar_add_assign(
|
||||
ct_exp_i,
|
||||
((ciphertext_idx >> (msg_modulus.ilog2() as usize * (k))) % msg_modulus)
|
||||
as u8,
|
||||
);
|
||||
}
|
||||
|
||||
let exponent_block_count = size_mantissa - ciphertext_idx;
|
||||
ciphertext.ct_vec_mantissa[ciphertext_idx..]
|
||||
.clone_from_slice(&ct0.ct_vec_mantissa[..exponent_block_count]);
|
||||
|
||||
ciphertext
|
||||
} else {
|
||||
zero.clone()
|
||||
}
|
||||
})
|
||||
.collect_into_vec(&mut ciphertexts_to_cmux);
|
||||
|
||||
while ciphertexts_to_cmux.len() > 1 {
|
||||
ciphertexts_to_cmux
|
||||
.par_chunks_exact(2)
|
||||
.map(|chunk| {
|
||||
let less_modified_exponent = &chunk[0];
|
||||
let more_modified_exponent = &chunk[1];
|
||||
|
||||
let msb_mantissa_ggsw = self.is_block_non_zero_ggsw_pbs_ks_cbs_parallelized(
|
||||
&less_modified_exponent.ct_vec_mantissa[size_mantissa - 1],
|
||||
msg_space,
|
||||
);
|
||||
|
||||
// return tmp if ggsw == 0; res otherwise
|
||||
let (mantissa, exponent) = rayon::join(
|
||||
|| {
|
||||
self.cmuxes_parallelized(
|
||||
&more_modified_exponent.ct_vec_mantissa,
|
||||
&less_modified_exponent.ct_vec_mantissa,
|
||||
&msb_mantissa_ggsw,
|
||||
)
|
||||
},
|
||||
|| {
|
||||
self.cmuxes_parallelized(
|
||||
&more_modified_exponent.ct_vec_exponent,
|
||||
&less_modified_exponent.ct_vec_exponent,
|
||||
&msb_mantissa_ggsw,
|
||||
)
|
||||
},
|
||||
);
|
||||
|
||||
let mut res = zero.clone();
|
||||
res.ct_vec_exponent = exponent;
|
||||
res.ct_vec_mantissa = mantissa;
|
||||
|
||||
res
|
||||
})
|
||||
.collect_into_vec(&mut cmux_outputs);
|
||||
|
||||
std::mem::swap(&mut ciphertexts_to_cmux, &mut cmux_outputs);
|
||||
}
|
||||
|
||||
let mut res = ciphertexts_to_cmux.into_iter().next().unwrap();
|
||||
|
||||
let (mut diff_exp, sub_exp_sign) =
|
||||
self.abs_diff_parallelized(&ct0.ct_vec_exponent, &res.ct_vec_exponent);
|
||||
|
||||
// message space == 0 because the sign is on the padding bit
|
||||
let (sign_ggsw, msb_mantissa_ggsw) = rayon::join(
|
||||
|| self.ggsw_ks_cbs_parallelized(&sub_exp_sign, 0),
|
||||
|| {
|
||||
self.is_block_non_zero_ggsw_pbs_ks_cbs_parallelized(
|
||||
&res.ct_vec_mantissa[size_mantissa - 1],
|
||||
msg_space,
|
||||
)
|
||||
},
|
||||
);
|
||||
|
||||
let (exponent, mantissa) = rayon::join(
|
||||
|| {
|
||||
diff_exp =
|
||||
self.cmuxes_parallelized(&zero.ct_vec_exponent, &diff_exp, &msb_mantissa_ggsw);
|
||||
self.cmuxes_parallelized(&zero.ct_vec_exponent, &diff_exp, &sign_ggsw)
|
||||
},
|
||||
|| self.cmuxes_parallelized(&zero.ct_vec_mantissa, &res.ct_vec_mantissa, &sign_ggsw),
|
||||
);
|
||||
|
||||
res.ct_vec_exponent = exponent;
|
||||
res.ct_vec_mantissa = mantissa;
|
||||
res.ct_sign = ct0.ct_sign.clone();
|
||||
res
|
||||
}
|
||||
|
||||
pub fn sub_total_parallelized(&self, ct1: &Ciphertext, ct2: &Ciphertext) -> Ciphertext {
|
||||
let ct2 = self.change_sign(ct2);
|
||||
self.add_total_parallelized(&ct1, &ct2)
|
||||
}
|
||||
}
|
||||
835
concrete-float/src/server_key/tests.rs
Normal file
835
concrete-float/src/server_key/tests.rs
Normal file
@@ -0,0 +1,835 @@
|
||||
#![allow(dead_code)]
|
||||
use std::cmp::{max, min};
|
||||
use rand::Rng;
|
||||
use tfhe::shortint;
|
||||
|
||||
#[allow(unused_imports)]
|
||||
use crate::parameters::{PARAM_SAM_32, WOP_PARAM_SAM_32, PARAM_MESSAGE_2_CARRY_2_32,
|
||||
PARAM_MESSAGE_2_CARRY_2_64, WOP_PARAM_MESSAGE_2_CARRY_2_32,
|
||||
WOP_PARAM_MESSAGE_2_CARRY_2_64, FINAL_WOP_PARAM_2_2_32, FINAL_PARAM_2_2_32,
|
||||
FINAL_WOP_PARAM_8, FINAL_PARAM_8, FINAL_PARAM_15,
|
||||
FINAL_WOP_PARAM_15, FINAL_PARAM_16, FINAL_WOP_PARAM_16, FINAL_PARAM_32,
|
||||
FINAL_WOP_PARAM_32, FINAL_PARAM_64, FINAL_WOP_PARAM_64,
|
||||
FINAL_PARAM_64_BIS, FINAL_WOP_PARAM_64_BIS,
|
||||
FINAL_PARAM_32_BIS, FINAL_WOP_PARAM_32_BIS, FINAL_PARAM_16_BIS,
|
||||
FINAL_WOP_PARAM_16_BIS, FINAL_PARAM_15_BIS, FINAL_WOP_PARAM_15_BIS,
|
||||
FINAL_PARAM_8_BIS, FINAL_WOP_PARAM_8_BIS, FINAL_PARAM_32_TCHESS, FINAL_WOP_PARAM_32_TCHESS
|
||||
};
|
||||
use crate::server_key::*;
|
||||
use crate::{gen_keys, ClientKey};
|
||||
|
||||
const NB_OPE: i32 = 50;
|
||||
const LEN_MAN: usize = 13; //13;
|
||||
const LEN_EXP: usize = 4; //4;
|
||||
|
||||
|
||||
const LEN_MAN8: usize = 2;
|
||||
const LEN_EXP8: usize = 2;
|
||||
|
||||
const LEN_MAN16: usize = 6;
|
||||
const LEN_EXP16: usize = 3;
|
||||
|
||||
const LEN_MAN32: usize = 13;
|
||||
const LEN_EXP32: usize = 4;
|
||||
|
||||
const LEN_MAN64: usize = 27;
|
||||
const LEN_EXP64: usize = 5;
|
||||
|
||||
macro_rules! named_param {
|
||||
($param:ident) => {
|
||||
(stringify!($param), $param)
|
||||
};
|
||||
}
|
||||
|
||||
struct Parameters {
|
||||
pbsparameters: shortint::ClassicPBSParameters,
|
||||
wopbsparameters: shortint::WopbsParameters,
|
||||
len_man: usize,
|
||||
len_exp: usize,
|
||||
}
|
||||
|
||||
const PARAM_FP_64_BITS: Parameters = Parameters {
|
||||
pbsparameters: FINAL_PARAM_64_BIS,
|
||||
wopbsparameters: FINAL_WOP_PARAM_64_BIS,
|
||||
len_man: LEN_MAN64,
|
||||
len_exp: LEN_EXP64,
|
||||
};
|
||||
|
||||
const PARAM_FP_32_BITS: Parameters = Parameters {
|
||||
pbsparameters: FINAL_PARAM_32_BIS,
|
||||
wopbsparameters: FINAL_WOP_PARAM_32_BIS,
|
||||
len_man: LEN_MAN32,
|
||||
len_exp: LEN_EXP32,
|
||||
};
|
||||
|
||||
const PARAM_FP_16_BITS: Parameters = Parameters {
|
||||
pbsparameters: FINAL_PARAM_16_BIS,
|
||||
wopbsparameters: FINAL_WOP_PARAM_16_BIS,
|
||||
len_man: LEN_MAN16,
|
||||
len_exp: LEN_EXP16,
|
||||
};
|
||||
|
||||
const PARAM_FP_8_BITS: Parameters = Parameters {
|
||||
pbsparameters: FINAL_PARAM_8_BIS,
|
||||
wopbsparameters: FINAL_WOP_PARAM_8_BIS,
|
||||
len_man: LEN_MAN8,
|
||||
len_exp: LEN_EXP8,
|
||||
};
|
||||
|
||||
const PARAMS: [(&str, Parameters); 1] =
|
||||
[
|
||||
//named_param!(PARAM_FP_64_BITS),
|
||||
named_param!(PARAM_FP_32_BITS),
|
||||
//named_param!(PARAM_FP_16_BITS),
|
||||
//named_param!(PARAM_FP_8_BITS),
|
||||
];
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_float_encrypt() {
|
||||
for (_, param) in PARAMS {
|
||||
let (cks, sks) = gen_keys(
|
||||
param.pbsparameters,
|
||||
param.wopbsparameters,
|
||||
param.len_man,
|
||||
param.len_exp,
|
||||
);
|
||||
|
||||
print_info(&cks);
|
||||
println!("parameters :: {:?}", cks.key.parameters);
|
||||
let msg = 1.;
|
||||
|
||||
// Encryption of one message:
|
||||
let mut ct = cks.encrypt(msg);
|
||||
print_res(&cks, &ct, "decrypt", msg as f32, msg);
|
||||
sks.clean_degree(&mut ct);
|
||||
print_res(&cks, &ct, "decrypt", msg as f32, msg);
|
||||
let res = cks.decrypt(&ct);
|
||||
|
||||
assert_eq!(res, msg);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
pub fn test_float_mul() {
|
||||
let mut rng = rand::thread_rng();
|
||||
for (name_parameters, param) in PARAMS {
|
||||
let (cks, sks) = gen_keys(
|
||||
param.pbsparameters,
|
||||
param.wopbsparameters,
|
||||
param.len_man,
|
||||
param.len_exp,
|
||||
);
|
||||
let msg1 = rng.gen::<f32>() as f64;
|
||||
let msg2 = rng.gen::<f32>() as f64;
|
||||
|
||||
let ct1 = cks.encrypt(msg1);
|
||||
let ct2 = cks.encrypt(msg2);
|
||||
|
||||
println!("--------------------------");
|
||||
println!("---- {name_parameters} ----");
|
||||
println!("--------------------------");
|
||||
|
||||
print_res(&cks, &ct1, "ct 1", msg1 as f32, msg1);
|
||||
print_res(&cks, &ct2, "ct 2", msg2 as f32, msg2);
|
||||
|
||||
let res = sks.mul_total_parallelized(&mut ct1.clone(), &mut ct2.clone());
|
||||
print_res(&cks, &res, "Multiplication", (msg2 * msg1) as f32, msg2 * msg1);
|
||||
|
||||
let res = cks.decrypt(&res);
|
||||
assert!(res.abs() < ((msg1 * msg2) * 1.001).abs());
|
||||
assert!(res.abs() > ((msg1 * msg2) * 0.999).abs());
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
pub fn test_float_div() {
|
||||
let mut rng = rand::thread_rng();
|
||||
for (name_parameters, param) in PARAMS {
|
||||
let (cks, sks) = gen_keys(
|
||||
param.pbsparameters,
|
||||
param.wopbsparameters,
|
||||
param.len_man,
|
||||
param.len_exp,
|
||||
);
|
||||
|
||||
let msg2 = rng.gen::<f32>() as f64;
|
||||
let msg1 = -rng.gen::<f32>() as f64;
|
||||
|
||||
let ct1 = cks.encrypt(msg1);
|
||||
let ct2 = cks.encrypt(msg2);
|
||||
|
||||
println!("--------------------------");
|
||||
println!("---- {name_parameters} ----");
|
||||
println!("--------------------------");
|
||||
|
||||
print_res(&cks, &ct1, "ct1", (msg1) as f32, msg1);
|
||||
print_res(&cks, &ct2, "ct2", (msg2) as f32, msg2);
|
||||
|
||||
let mut res = sks.division(&ct1, &ct2);
|
||||
print_res(&cks, &res, "Division", (msg1 / msg2) as f32, msg1 / msg2);
|
||||
sks.clean_degree(&mut res);
|
||||
let res = cks.decrypt(&res);
|
||||
|
||||
assert!(res.abs() < ((msg1 / msg2) * 1.001).abs());
|
||||
assert!(res.abs() > ((msg1 / msg2) * 0.999).abs());
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
pub fn float_cos() {
|
||||
let mut rng = rand::thread_rng();
|
||||
for (name_parameters, param) in PARAMS {
|
||||
let (cks, sks) = gen_keys(
|
||||
param.pbsparameters,
|
||||
param.wopbsparameters,
|
||||
param.len_man,
|
||||
param.len_exp,
|
||||
);
|
||||
|
||||
println!("--------------------------");
|
||||
println!("---- {name_parameters} ----");
|
||||
println!("--------------------------");
|
||||
let msg1 = rng.gen::<f32>() as f64;
|
||||
let ct1 = cks.encrypt(msg1);
|
||||
|
||||
let one = cks.encrypt(1.); //should be in trivial encrypt
|
||||
let one_div_by_2 = cks.encrypt(1. / 2.); //should be in trivial encrypt
|
||||
let one_div_by_24 = cks.encrypt(1. / 24.); //should be in trivial encrypt
|
||||
|
||||
print_res(&cks, &one, "one", 1 as f32, 1.);
|
||||
print_res(&cks, &one_div_by_2, "oneDivBy2", (1. / 2.) as f32, 1. / 2.);
|
||||
print_res(&cks, &one_div_by_24, "oneDivBy24", (1. / 24.) as f32, 1. / 24.);
|
||||
print_res(&cks, &ct1, "ct1", msg1 as f32, msg1);
|
||||
|
||||
|
||||
let ct1_square = sks.mul_total_parallelized(&ct1, &ct1);
|
||||
print_res(&cks, &ct1_square, "ct1_square", (msg1 * msg1) as f32, msg1 * msg1);
|
||||
|
||||
let ct1_square_square = sks.mul_total_parallelized(&ct1_square, &ct1_square);
|
||||
print_res(&cks, &ct1_square_square, "ct1_square_square", (msg1 * msg1 * msg1 * msg1) as f32, msg1 * msg1 * msg1 * msg1);
|
||||
|
||||
let ct1_square_time_one_div_by_2 = sks.mul_total_parallelized(&ct1_square, &one_div_by_2);
|
||||
print_res(&cks, &ct1_square_time_one_div_by_2, "ct1_square_time_1DivBy2", (msg1 * msg1 / 2.) as f32, msg1 * msg1 / 2.);
|
||||
|
||||
let ct1_square_square_time_one_div_by_24 = sks.mul_total_parallelized(&ct1_square_square, &one_div_by_24);
|
||||
print_res(&cks, &ct1_square_square_time_one_div_by_24, "ct1_square_square_time_1DivBy24", (msg1 * msg1 * msg1 * msg1 / 24.) as f32, msg1 * msg1 * msg1 * msg1 / 24.);
|
||||
|
||||
let res = sks.add_total_parallelized(&one, &ct1_square_square_time_one_div_by_24);
|
||||
print_res(&cks, &res, "first res", (1. + msg1 * msg1 * msg1 * msg1 / 24.) as f32, 1. + msg1 * msg1 * msg1 * msg1 / 24.);
|
||||
|
||||
|
||||
let res = sks.sub_total_parallelized(&res, &ct1_square_time_one_div_by_2);
|
||||
println!("Cosine, exact result : {:?}", msg1.cos());
|
||||
let approximation = 1. + msg1 * msg1 * msg1 * msg1 / 24. - msg1 * msg1 / 2.;
|
||||
print_res(&cks, &res, "Cosine approximation", approximation as f32, approximation);
|
||||
|
||||
let res = cks.decrypt(&res);
|
||||
assert!(res < (approximation * 1.001).abs());
|
||||
assert!(res > (approximation * 0.999).abs());
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
pub fn float_sin() {
|
||||
let mut rng = rand::thread_rng();
|
||||
for (name_parameters, param) in PARAMS {
|
||||
let (cks, sks) = gen_keys(
|
||||
param.pbsparameters,
|
||||
param.wopbsparameters,
|
||||
param.len_man,
|
||||
param.len_exp,
|
||||
);
|
||||
|
||||
println!("--------------------------");
|
||||
println!("---- {name_parameters} ----");
|
||||
println!("--------------------------");
|
||||
|
||||
let msg1 = rng.gen::<f32>() as f64;
|
||||
let ct1 = cks.encrypt(msg1);
|
||||
|
||||
print_res(&cks, &ct1, "ct1", msg1 as f32, msg1);
|
||||
|
||||
let one_div_by_6 = cks.encrypt(1. / 6.); //should be in trivial encrypt
|
||||
let one_div_by_120 = cks.encrypt(1. / 120.); //should be in trivial encrypt
|
||||
|
||||
let ct1_square = sks.mul_total_parallelized(&ct1, &ct1);
|
||||
print_res(&cks, &ct1_square, "ct1_square", (msg1 * msg1) as f32, msg1 * msg1);
|
||||
|
||||
let ct1_cube = sks.mul_total_parallelized(&ct1_square, &ct1);
|
||||
print_res(&cks, &ct1_cube, "ct1_cube", (msg1 * msg1 * msg1) as f32, msg1 * msg1 * msg1);
|
||||
|
||||
let ct1_power_five = sks.mul_total_parallelized(&ct1_square, &ct1_cube);
|
||||
print_res(&cks, &ct1_power_five, "ct1_power_five", (msg1 * msg1 * msg1 * msg1 * msg1) as f32, msg1 * msg1 * msg1 * msg1 * msg1);
|
||||
|
||||
|
||||
let ct1_cube_time_one_div_by_6 = sks.mul_total_parallelized(&ct1_cube, &one_div_by_6);
|
||||
print_res(&cks, &ct1_cube_time_one_div_by_6, "ct1_cube_time_one_div_by_6", (msg1 * msg1 * msg1 / 6.) as f32, msg1 * msg1 * msg1 / 6.);
|
||||
|
||||
let ct1_power_five_time_one_div_by_120 = sks.mul_total_parallelized(&ct1_power_five, &one_div_by_120);
|
||||
print_res(&cks, &ct1_power_five_time_one_div_by_120, "ct1_power_five_time_one_div_by_120", (msg1 * msg1 * msg1 * msg1 * msg1 / 120.) as f32, msg1 * msg1 * msg1 * msg1 * msg1 / 120.);
|
||||
|
||||
|
||||
let res = sks.add_total_parallelized(&ct1, &ct1_power_five_time_one_div_by_120);
|
||||
print_res(&cks, &ct1_power_five_time_one_div_by_120, "res_1", (msg1 * msg1 * msg1 * msg1 * msg1 / 120.) as f32, msg1 * msg1 * msg1 * msg1 * msg1 / 120.);
|
||||
|
||||
let res = sks.sub_total_parallelized(&res, &ct1_cube_time_one_div_by_6);
|
||||
|
||||
println!("Sine, exact result : {:?}", msg1.sin());
|
||||
let approximation = msg1 + msg1 * msg1 * msg1 * msg1 * msg1 / 120. - msg1 * msg1 * msg1 / 6.;
|
||||
print_res(&cks, &res, "Sine approximation", approximation as f32, approximation);
|
||||
|
||||
let res = cks.decrypt(&res);
|
||||
assert!(res < (approximation * 1.001).abs());
|
||||
assert!(res > (approximation * 0.999).abs());
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
pub fn test_float_add() {
|
||||
let mut rng = rand::thread_rng();
|
||||
for (name_parameters, param) in PARAMS {
|
||||
let (cks, sks) = gen_keys(
|
||||
param.pbsparameters,
|
||||
param.wopbsparameters,
|
||||
param.len_man,
|
||||
param.len_exp,
|
||||
);
|
||||
|
||||
println!("--------------------------");
|
||||
println!("---- {name_parameters} ----");
|
||||
println!("--------------------------");
|
||||
|
||||
let msg2 = rng.gen::<f32>() as f64;
|
||||
let msg1 = rng.gen::<f32>() as f64;
|
||||
|
||||
let ct1 = cks.encrypt(msg1);
|
||||
let ct2 = cks.encrypt(msg2);
|
||||
|
||||
print_res(&cks, &ct1, "ct 1", msg1 as f32, msg1);
|
||||
print_res(&cks, &ct2, "ct 2", msg2 as f32, msg2);
|
||||
|
||||
let res = sks.add_total_parallelized(&ct1, &ct2);
|
||||
|
||||
print_res(&cks, &res, "Addition", (msg1 + msg2) as f32, msg1 + msg2);
|
||||
|
||||
let res = cks.decrypt(&res);
|
||||
assert!(res.abs() < ((msg1 + msg2) * 1.001).abs());
|
||||
assert!(res.abs() > ((msg1 + msg2) * 0.999).abs());
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
pub fn test_float_sub() {
|
||||
let mut rng = rand::thread_rng();
|
||||
for (name_parameters, param) in PARAMS {
|
||||
let (cks, sks) = gen_keys(
|
||||
param.pbsparameters,
|
||||
param.wopbsparameters,
|
||||
param.len_man,
|
||||
param.len_exp,
|
||||
);
|
||||
|
||||
println!("--------------------------");
|
||||
println!("---- {name_parameters} ----");
|
||||
println!("--------------------------");
|
||||
|
||||
|
||||
let msg1 = rng.gen::<f32>() as f64;
|
||||
let msg2 = rng.gen::<f32>() as f64;
|
||||
|
||||
let ct1 = cks.encrypt(msg1);
|
||||
let ct2 = cks.encrypt(msg2);
|
||||
|
||||
|
||||
print_res(&cks, &ct1, "ct 1", msg1 as f32, msg1);
|
||||
print_res(&cks, &ct2, "ct 2", msg2 as f32, msg2);
|
||||
let res = sks.sub_total_parallelized(&ct1, &ct2);
|
||||
|
||||
print_res(&cks, &res, "Subtraction", (msg1 - msg2) as f32, msg1 - msg2);
|
||||
|
||||
let res = cks.decrypt(&res);
|
||||
assert!(res.abs() < ((msg1 - msg2) * 1.001).abs());
|
||||
assert!(res.abs() > ((msg1 - msg2) * 0.999).abs());
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
pub fn depth_test_parallelized() {
|
||||
let mut rng = rand::thread_rng();
|
||||
for (name_parameters, param) in PARAMS {
|
||||
let (cks, sks) = gen_keys(
|
||||
param.pbsparameters,
|
||||
param.wopbsparameters,
|
||||
param.len_man,
|
||||
param.len_exp,
|
||||
);
|
||||
|
||||
println!("--------------------------");
|
||||
println!("---- {name_parameters} ----");
|
||||
println!("--------------------------");
|
||||
|
||||
let max_i = 1_000.;
|
||||
|
||||
let mut vec_float_32 = vec![];
|
||||
let mut vec_float_64 = vec![];
|
||||
let mut vec_hom_float = vec![];
|
||||
let mut vec_deep = vec![];
|
||||
let mut vec_nb_operation = vec![];
|
||||
|
||||
let len_vec = 3 as u16;
|
||||
for i in 0..len_vec {
|
||||
let msg = rng.gen::<f32>() as f64;
|
||||
println!("msg_{:?}: {:?}", i, msg);
|
||||
let ct = cks.encrypt(msg);
|
||||
print_res(&cks, &ct, "encrypt/decrypt", msg as f32, msg);
|
||||
|
||||
vec_float_64.push(msg);
|
||||
vec_float_32.push(msg as f32);
|
||||
vec_hom_float.push(ct);
|
||||
vec_deep.push(0);
|
||||
vec_nb_operation.push(0);
|
||||
}
|
||||
|
||||
for i in 0..NB_OPE {
|
||||
println!("\n----Round {:?}----", i);
|
||||
let r_ope = rng.gen::<u16>() % 3;
|
||||
let r_value_1 = (rng.gen::<u16>() % len_vec) as usize;
|
||||
let mut r_value_2 = (rng.gen::<u16>() % len_vec) as usize;
|
||||
let mut r_place = (rng.gen::<u16>() % 2) as usize;
|
||||
while r_value_1 == r_value_2 {
|
||||
r_value_2 = (rng.gen::<u16>() % len_vec) as usize;
|
||||
}
|
||||
if r_place == 0 {
|
||||
r_place = r_value_1
|
||||
} else {
|
||||
r_place = r_value_2
|
||||
}
|
||||
|
||||
vec_deep[r_place] = min(vec_deep[r_value_1], vec_deep[r_value_2]) + 1;
|
||||
vec_nb_operation[r_place] =
|
||||
max(vec_nb_operation[r_value_1], vec_nb_operation[r_value_2]) + 1;
|
||||
if r_ope == 0 {
|
||||
println!(
|
||||
"block {:?} * block {:?} -> block{:?}\n",
|
||||
r_value_1, r_value_2, r_place
|
||||
);
|
||||
println!(
|
||||
"expected: {:?} * {:?} = {:?}",
|
||||
vec_float_64[r_value_1],
|
||||
vec_float_64[r_value_2],
|
||||
vec_float_64[r_value_1] * vec_float_64[r_value_2]
|
||||
);
|
||||
vec_hom_float[r_place] = sks.mul_total_parallelized(
|
||||
&mut vec_hom_float[r_value_1].clone(),
|
||||
&mut vec_hom_float[r_value_2].clone(),
|
||||
);
|
||||
vec_float_32[r_place] = vec_float_32[r_value_1] * vec_float_32[r_value_2];
|
||||
vec_float_64[r_place] = vec_float_64[r_value_1] * vec_float_64[r_value_2];
|
||||
|
||||
print_res(
|
||||
&cks,
|
||||
&vec_hom_float[r_place],
|
||||
"res mul",
|
||||
vec_float_32[r_place],
|
||||
vec_float_64[r_place],
|
||||
);
|
||||
} else if r_ope == 1 {
|
||||
println!(
|
||||
"block {:?} + block {:?} -> block{:?}\n",
|
||||
r_value_1, r_value_2, r_place
|
||||
);
|
||||
println!(
|
||||
"expected: {:?} + {:?} = {:?}",
|
||||
vec_float_64[r_value_1],
|
||||
vec_float_64[r_value_2],
|
||||
vec_float_64[r_value_1] + vec_float_64[r_value_2]
|
||||
);
|
||||
|
||||
vec_hom_float[r_place] =
|
||||
sks.add_total_parallelized(&vec_hom_float[r_value_1], &vec_hom_float[r_value_2]);
|
||||
vec_float_32[r_place] = vec_float_32[r_value_1] + vec_float_32[r_value_2];
|
||||
vec_float_64[r_place] = vec_float_64[r_value_1] + vec_float_64[r_value_2];
|
||||
print_res(
|
||||
&cks,
|
||||
&vec_hom_float[r_place],
|
||||
"res add",
|
||||
vec_float_32[r_place],
|
||||
vec_float_64[r_place],
|
||||
);
|
||||
} else {
|
||||
println!(
|
||||
"block {:?} - block {:?} -> block{:?}\n",
|
||||
r_value_1, r_value_2, r_place
|
||||
);
|
||||
println!(
|
||||
"expected: {:?} - {:?} = {:?}",
|
||||
vec_float_64[r_value_1],
|
||||
vec_float_64[r_value_2],
|
||||
vec_float_64[r_value_1] - vec_float_64[r_value_2]
|
||||
);
|
||||
|
||||
vec_hom_float[r_place] =
|
||||
sks.sub_total_parallelized(&vec_hom_float[r_value_1], &vec_hom_float[r_value_2]);
|
||||
vec_float_32[r_place] = vec_float_32[r_value_1] - vec_float_32[r_value_2];
|
||||
vec_float_64[r_place] = vec_float_64[r_value_1] - vec_float_64[r_value_2];
|
||||
print_res(
|
||||
&cks,
|
||||
&vec_hom_float[r_place],
|
||||
"res sub",
|
||||
vec_float_32[r_place],
|
||||
vec_float_64[r_place],
|
||||
);
|
||||
}
|
||||
if vec_float_64[r_value_1].abs() > max_i {
|
||||
let msg_tmp = (1. / max_i) * rng.gen::<f32>() as f64; // 1. / (vec_float_64[r_value_1].abs() + vec_float_64[r_value_2].clone().abs() );
|
||||
let mut ct_tmp = cks.encrypt(msg_tmp);
|
||||
|
||||
println!(
|
||||
"block {:?} * {:?} -> block{:?}\n",
|
||||
r_value_1, msg_tmp, r_value_1
|
||||
);
|
||||
println!(
|
||||
"expected: {:?} * {:?} = {:?}",
|
||||
vec_float_64[r_value_1],
|
||||
msg_tmp,
|
||||
vec_float_64[r_value_1] * msg_tmp
|
||||
);
|
||||
|
||||
vec_hom_float[r_place] =
|
||||
sks.mul_total_parallelized(&mut vec_hom_float[r_value_1].clone(), &mut ct_tmp);
|
||||
vec_float_32[r_value_1] = vec_float_32[r_value_1] * msg_tmp as f32;
|
||||
vec_float_64[r_value_1] = vec_float_64[r_value_1] * msg_tmp;
|
||||
vec_nb_operation[r_value_1] += 1;
|
||||
|
||||
print_res(
|
||||
&cks,
|
||||
&vec_hom_float[r_place],
|
||||
"res mul",
|
||||
vec_float_32[r_place],
|
||||
vec_float_64[r_place],
|
||||
);
|
||||
}
|
||||
if vec_float_64[r_value_2].abs() > max_i {
|
||||
let msg_tmp = (1. / max_i) * rng.gen::<f32>() as f64; // 1. / (vec_float_64[r_value_1].abs() + vec_float_64[r_value_2].clone().abs() );
|
||||
let mut ct_tmp = cks.encrypt(msg_tmp);
|
||||
|
||||
println!(
|
||||
"block {:?} * {:?} -> block{:?}\n",
|
||||
r_value_2, msg_tmp, r_value_2
|
||||
);
|
||||
println!(
|
||||
"expected: {:?} * {:?} = {:?}",
|
||||
vec_float_64[r_value_1],
|
||||
msg_tmp,
|
||||
vec_float_64[r_value_1] * msg_tmp
|
||||
);
|
||||
|
||||
vec_hom_float[r_value_2] =
|
||||
sks.mul_total_parallelized(&mut vec_hom_float[r_value_2].clone(), &mut ct_tmp);
|
||||
vec_float_32[r_value_2] = vec_float_32[r_value_2] * msg_tmp as f32;
|
||||
vec_float_64[r_value_2] = vec_float_64[r_value_2] * msg_tmp;
|
||||
vec_nb_operation[r_value_2] += 1;
|
||||
|
||||
print_res(
|
||||
&cks,
|
||||
&vec_hom_float[r_place],
|
||||
"res mul",
|
||||
vec_float_32[r_place],
|
||||
vec_float_64[r_place],
|
||||
);
|
||||
}
|
||||
|
||||
if vec_float_64[r_value_1].abs() < 1. / max_i {
|
||||
let msg_tmp = max_i * rng.gen::<f32>() as f64; // 1. / (vec_float_64[r_value_1].abs() + vec_float_64[r_value_2].clone().abs() );
|
||||
let mut ct_tmp = cks.encrypt(msg_tmp);
|
||||
|
||||
println!(
|
||||
"block {:?} * {:?} -> block{:?}\n",
|
||||
r_value_1, msg_tmp, r_value_1
|
||||
);
|
||||
println!(
|
||||
"expected: {:?} * {:?} = {:?}",
|
||||
vec_float_64[r_value_1],
|
||||
msg_tmp,
|
||||
vec_float_64[r_value_1] * msg_tmp
|
||||
);
|
||||
|
||||
vec_hom_float[r_value_1] =
|
||||
sks.mul_total_parallelized(&mut vec_hom_float[r_value_1].clone(), &mut ct_tmp);
|
||||
vec_float_32[r_value_1] = vec_float_32[r_value_1] * msg_tmp as f32;
|
||||
vec_float_64[r_value_1] = vec_float_64[r_value_1] * msg_tmp;
|
||||
vec_nb_operation[r_value_1] += 1;
|
||||
|
||||
print_res(
|
||||
&cks,
|
||||
&vec_hom_float[r_place],
|
||||
"res mul",
|
||||
vec_float_32[r_place],
|
||||
vec_float_64[r_place],
|
||||
);
|
||||
}
|
||||
if vec_float_64[r_value_2].abs() < 1. / max_i {
|
||||
let msg_tmp = max_i * rng.gen::<f32>() as f64; // 1. / (vec_float_64[r_value_1].abs() + vec_float_64[r_value_2].clone().abs() );
|
||||
let mut ct_tmp = cks.encrypt(msg_tmp);
|
||||
|
||||
println!(
|
||||
"block {:?} * {:?} -> block{:?}\n",
|
||||
r_value_2, msg_tmp, r_value_2
|
||||
);
|
||||
println!(
|
||||
"expected: {:?} * {:?} = {:?}",
|
||||
vec_float_64[r_value_1],
|
||||
msg_tmp,
|
||||
vec_float_64[r_value_1] * msg_tmp
|
||||
);
|
||||
|
||||
vec_hom_float[r_value_2] =
|
||||
sks.mul_total_parallelized(&mut vec_hom_float[r_value_2].clone(), &mut ct_tmp);
|
||||
vec_float_32[r_value_2] = vec_float_32[r_value_2] * msg_tmp as f32;
|
||||
vec_float_64[r_value_2] = vec_float_64[r_value_2] * msg_tmp;
|
||||
vec_nb_operation[r_value_2] += 1;
|
||||
|
||||
print_res(
|
||||
&cks,
|
||||
&vec_hom_float[r_place],
|
||||
"res mul",
|
||||
vec_float_32[r_place],
|
||||
vec_float_64[r_place],
|
||||
);
|
||||
}
|
||||
println!("----End Round {:?}----", i);
|
||||
println!("--------------------");
|
||||
println!("--------------------");
|
||||
println!("--------------------");
|
||||
}
|
||||
|
||||
for i in 0..len_vec as usize {
|
||||
println!("------");
|
||||
print_res(
|
||||
&cks,
|
||||
&vec_hom_float[i],
|
||||
"Final result",
|
||||
vec_float_32[i],
|
||||
vec_float_64[i],
|
||||
);
|
||||
//println!("Deep : {:?}", vec_deep[i]);
|
||||
//println!("Ope : {:?}", vec_nb_operation[i]);
|
||||
|
||||
let res = cks.decrypt(&vec_hom_float[i]);
|
||||
assert!(res.abs() < (vec_float_64[i] * 1.001).abs());
|
||||
assert!(res.abs() > (vec_float_64[i] * 0.999).abs());
|
||||
//println!("------");
|
||||
}
|
||||
//println!("Info :");
|
||||
//println!("len mantissa : {:?}", LEN_MAN);
|
||||
//println!("len exponent : {:?}", LEN_EXP);
|
||||
//println!("number operations : {:?}", NB_OPE);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
pub fn float_same_as_ls_22_32() {
|
||||
let (cks, sks) = gen_keys(
|
||||
PARAM_MESSAGE_2_CARRY_2_32,
|
||||
WOP_PARAM_MESSAGE_2_CARRY_2_32,
|
||||
LEN_MAN32,
|
||||
LEN_EXP32,
|
||||
);
|
||||
print_info(&cks);
|
||||
let msg1 = -2.7914999921796382_e-15;
|
||||
let ct1 = cks.encrypt(msg1);
|
||||
print_res(&cks, &ct1, "Encrypt/Decrypt", msg1 as f32, msg1);
|
||||
|
||||
let msg2 = 8.3867001884896375_e-12;
|
||||
let ct2 = cks.encrypt(msg2);
|
||||
print_res(&cks, &ct2, "Encrypt/Decrypt", msg2 as f32, msg2);
|
||||
|
||||
let msg3 = 1.82634005135360_e14;
|
||||
let ct3 = cks.encrypt(msg3);
|
||||
print_res(&cks, &ct3, "Encrypt/Decrypt", msg3 as f32, msg3);
|
||||
|
||||
let msg4 = -6.278269952_e9;
|
||||
let ct4 = cks.encrypt(msg4);
|
||||
print_res(&cks, &ct4, "Encrypt/Decrypt", msg4 as f32, msg4);
|
||||
|
||||
let res_1 = sks.add_total_parallelized(&ct1, &ct2);
|
||||
print_res(
|
||||
&cks,
|
||||
&res_1,
|
||||
"res add",
|
||||
msg1 as f32 + msg2 as f32,
|
||||
msg1 + msg2,
|
||||
);
|
||||
|
||||
let res_2 = sks.sub_total_parallelized(&ct3, &ct4);
|
||||
print_res(
|
||||
&cks,
|
||||
&res_2,
|
||||
"res add",
|
||||
msg3 as f32 - msg4 as f32,
|
||||
msg3 - msg4,
|
||||
);
|
||||
|
||||
let mut witness32 = (msg3 as f32 - msg4 as f32) * (msg1 as f32 + msg2 as f32);
|
||||
let mut witness64 = (msg3 - msg4) * (msg1 + msg2);
|
||||
let res = sks.mul_total_parallelized(&res_1, &res_2);
|
||||
print_res(&cks, &res, "res mul", witness32, witness64);
|
||||
|
||||
let res = sks.mul_total_parallelized(&res, &res);
|
||||
witness32 *= witness32;
|
||||
witness64 *= witness64;
|
||||
print_res(&cks, &res, "res mul", witness32, witness64);
|
||||
let res = cks.decrypt(&res);
|
||||
|
||||
assert!(res.abs() < ((witness32 * 1.001 as f32) as f64).abs());
|
||||
assert!(res.abs() > ((witness32 * 0.999 as f32) as f64).abs());
|
||||
}
|
||||
|
||||
#[test]
|
||||
pub fn float_same_as_ls_22_64() {
|
||||
let (cks, sks) = gen_keys(
|
||||
PARAM_MESSAGE_2_CARRY_2_64,
|
||||
WOP_PARAM_MESSAGE_2_CARRY_2_64,
|
||||
LEN_MAN64,
|
||||
LEN_EXP64,
|
||||
);
|
||||
|
||||
print_info(&cks);
|
||||
let msg1 = -9.1763514236254290_e-32;
|
||||
let ct1 = cks.encrypt(msg1);
|
||||
print_res(&cks, &ct1, "Encrypt/Decrypt", msg1 as f32, msg1);
|
||||
|
||||
let msg2 = 6.2467247246375865_e-24;
|
||||
let ct2 = cks.encrypt(msg2);
|
||||
print_res(&cks, &ct2, "Encrypt/Decrypt", msg2 as f32, msg2);
|
||||
|
||||
let msg3 = 2.4523526872362373_e22;
|
||||
let ct3 = cks.encrypt(msg3);
|
||||
print_res(&cks, &ct3, "Encrypt/Decrypt", msg3 as f32, msg3);
|
||||
|
||||
let msg4 = -5.4324663335297274_e17;
|
||||
let ct4 = cks.encrypt(msg4);
|
||||
print_res(&cks, &ct4, "Encrypt/Decrypt", msg4 as f32, msg4);
|
||||
|
||||
let res_1 = sks.add_total_parallelized(&ct1, &ct2);
|
||||
print_res(
|
||||
&cks,
|
||||
&res_1,
|
||||
"res add",
|
||||
msg1 as f32 + msg2 as f32,
|
||||
msg1 + msg2,
|
||||
);
|
||||
|
||||
let res_2 = sks.sub_total_parallelized(&ct3, &ct4);
|
||||
print_res(
|
||||
&cks,
|
||||
&res_2,
|
||||
"res add",
|
||||
msg3 as f32 - msg4 as f32,
|
||||
msg3 - msg4,
|
||||
);
|
||||
|
||||
let mut witness32 = (msg3 as f32 - msg4 as f32) * (msg1 as f32 + msg2 as f32);
|
||||
let mut witness64 = (msg3 - msg4) * (msg1 + msg2);
|
||||
let res = sks.mul_total_parallelized(&res_1, &res_2);
|
||||
print_res(&cks, &res, "res mul", witness32, witness64);
|
||||
|
||||
let res = sks.mul_total_parallelized(&res, &res);
|
||||
witness32 *= witness32;
|
||||
witness64 *= witness64;
|
||||
print_res(&cks, &res, "res mul", witness32, witness64);
|
||||
let res = cks.decrypt(&res);
|
||||
|
||||
assert!(res.abs() < (witness64 * 1.001).abs());
|
||||
assert!(res.abs() > (witness64 * 0.999).abs());
|
||||
}
|
||||
|
||||
#[test]
|
||||
pub fn test_float_relu() {
|
||||
let mut rng = rand::thread_rng();
|
||||
for (name_parameters, param) in PARAMS {
|
||||
let (cks, sks) = gen_keys(
|
||||
param.pbsparameters,
|
||||
param.wopbsparameters,
|
||||
param.len_man,
|
||||
param.len_exp,
|
||||
);
|
||||
|
||||
println!("--------------------------");
|
||||
println!("---- {name_parameters} ----");
|
||||
println!("--------------------------");
|
||||
|
||||
let msg = rng.gen::<f32>() as f64 - rng.gen::<f32>() as f64;
|
||||
let ct = cks.encrypt(msg);
|
||||
print_res(&cks, &ct, "decrypt", msg as f32, msg);
|
||||
let res = sks.relu(&ct);
|
||||
print_res(&cks, &res, "relu", 0.0_f32.max(msg as f32), msg.max(0.));
|
||||
let res = cks.decrypt(&res);
|
||||
assert_eq!(res, msg.max(0.));
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
#[test]
|
||||
pub fn test_float_sigmoid() {
|
||||
let mut rng = rand::thread_rng();
|
||||
for (name_parameters, param) in PARAMS {
|
||||
let (cks, sks) = gen_keys(
|
||||
param.pbsparameters,
|
||||
param.wopbsparameters,
|
||||
param.len_man,
|
||||
param.len_exp,
|
||||
);
|
||||
|
||||
println!("--------------------------");
|
||||
println!("---- {name_parameters} ----");
|
||||
println!("--------------------------");
|
||||
|
||||
|
||||
let msg = (rng.gen::<f32>() as f64 + 0.4).abs();
|
||||
let ct = cks.encrypt(msg);
|
||||
print_res(&cks, &ct, "ct", msg as f32, msg);
|
||||
let res = sks.sigmoid(&ct);
|
||||
print_res(&cks, &res, "approx sigmoid", 1.0_f32.min(msg as f32), msg.min(1.));
|
||||
let res = cks.decrypt(&res);
|
||||
|
||||
assert!(res > msg.min(1.) * 0.999);
|
||||
assert!(res < msg.min(1.) * 1.001);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn print_res(
|
||||
cks: &ClientKey,
|
||||
ct: &Ciphertext,
|
||||
operation: &str,
|
||||
witness32: f32,
|
||||
witness64: f64,
|
||||
) {
|
||||
println!("\n--------------------",);
|
||||
println!("{:?}:\n", operation);
|
||||
println!("Result : {:?}", cks.decrypt(&ct));
|
||||
println!("Clear 32-bits: {:?}", witness32);
|
||||
println!("Clear 64-bits: {:?}\n", witness64);
|
||||
println!("--------------------");
|
||||
}
|
||||
|
||||
pub fn print_info(cks: &ClientKey) {
|
||||
println!("\n-----Info-----");
|
||||
println!("length exp {:?}", cks.vector_length_exponent);
|
||||
println!("length man {:?}", cks.vector_length_mantissa);
|
||||
let msg_modulus = cks.parameters().message_modulus().0;
|
||||
let car_modulus = cks.parameters().carry_modulus().0;
|
||||
println!("msg modulus {:?}, 0b{:b}", msg_modulus, msg_modulus);
|
||||
println!("car modulus {:?}, 0b{:b}", car_modulus, car_modulus);
|
||||
println!(
|
||||
"total space {:?}, 0b{:b}",
|
||||
car_modulus * msg_modulus,
|
||||
car_modulus * msg_modulus
|
||||
);
|
||||
let log_msg_modulus = f64::log2(msg_modulus as f64) as usize;
|
||||
let bias = -((1 << (cks.vector_length_exponent.0 * log_msg_modulus - 1)) as i64)
|
||||
- (cks.vector_length_mantissa.0 as i64 - 1);
|
||||
println!("Bias {:?}", bias);
|
||||
println!("--------------\n");
|
||||
}
|
||||
551
concrete-float/src/server_key/tools.rs
Normal file
551
concrete-float/src/server_key/tools.rs
Normal file
@@ -0,0 +1,551 @@
|
||||
use crate::server_key::Ciphertext;
|
||||
use crate::ServerKey;
|
||||
use shortint::ciphertext::{Ciphertext as ShortintCiphertext, Degree};
|
||||
|
||||
use std::cmp::{max, min};
|
||||
use tfhe::core_crypto::algorithms::{
|
||||
cmux_assign, extract_lwe_sample_from_glwe_ciphertext, keyswitch_lwe_ciphertext,
|
||||
par_keyswitch_lwe_ciphertext,
|
||||
};
|
||||
|
||||
use aligned_vec::ABox;
|
||||
use dyn_stack::{GlobalPodBuffer, PodStack, ReborrowMut, StackReq};
|
||||
use tfhe::core_crypto::commons::parameters::*;
|
||||
use tfhe::core_crypto::entities::*;
|
||||
use tfhe::core_crypto::fft_impl::fft64::c64;
|
||||
use tfhe::core_crypto::fft_impl::fft64::crypto::ggsw::fill_with_forward_fourier_scratch;
|
||||
use tfhe::core_crypto::fft_impl::fft64::crypto::wop_pbs::{
|
||||
circuit_bootstrap_boolean, circuit_bootstrap_boolean_parallelized,
|
||||
circuit_bootstrap_boolean_scratch, extract_bits, extract_bits_parallelized,
|
||||
extract_bits_scratch,
|
||||
};
|
||||
use tfhe::core_crypto::fft_impl::fft64::math::fft::par_convert_polynomials_list_to_fourier;
|
||||
use tfhe::core_crypto::prelude::{ContiguousEntityContainer, Fft};
|
||||
use tfhe::shortint;
|
||||
use tfhe::shortint::ciphertext::NoiseLevel;
|
||||
|
||||
use rayon::prelude::*;
|
||||
|
||||
impl ServerKey {
|
||||
pub fn ggsw_pbs_ks_cbs(
|
||||
&self,
|
||||
ct1: &ShortintCiphertext,
|
||||
message_space: usize,
|
||||
) -> FourierGgswCiphertext<ABox<[c64]>> {
|
||||
let accumulator = self.key.generate_lookup_table(|x| min(1, x) as u64);
|
||||
let res = self.key.apply_lookup_table(&ct1, &accumulator);
|
||||
self.ggsw_ks_cbs(&res, message_space)
|
||||
}
|
||||
|
||||
/// return ggsw(0) if ct1 = 0, return ggsw(1) otherwise
|
||||
pub fn ggsw_ks_cbs(
|
||||
&self,
|
||||
ct1: &ShortintCiphertext,
|
||||
message_space: usize,
|
||||
) -> FourierGgswCiphertext<ABox<[c64]>> {
|
||||
let ciphertext_modulus = ct1.ct.ciphertext_modulus();
|
||||
|
||||
let mut res_ks = LweCiphertext::new(
|
||||
0u64,
|
||||
LweSize(self.wopbs_key.param.lwe_dimension.to_lwe_size().0),
|
||||
ciphertext_modulus,
|
||||
);
|
||||
keyswitch_lwe_ciphertext(&self.key.key_switching_key, &ct1.ct, &mut res_ks);
|
||||
self.ggsw_cbs(&res_ks.as_view(), message_space)
|
||||
}
|
||||
|
||||
/// return ggsw(0) if ct1 = 0, return ggsw(1) otherwise
|
||||
pub fn ggsw_cbs(
|
||||
&self,
|
||||
ct: &LweCiphertext<&[u64]>,
|
||||
message_space: usize,
|
||||
) -> FourierGgswCiphertext<ABox<[c64]>> {
|
||||
let glwe_dimension = self.wopbs_key.param.glwe_dimension;
|
||||
let polynomial_size = self.wopbs_key.param.polynomial_size;
|
||||
let base_log_cbs = self.wopbs_key.param.cbs_base_log;
|
||||
let level_count_cbs = self.wopbs_key.param.cbs_level;
|
||||
let ciphertext_modulus = ct.ciphertext_modulus();
|
||||
|
||||
let fourier_bsk = match &self.wopbs_key.wopbs_server_key.bootstrapping_key {
|
||||
shortint::server_key::ShortintBootstrappingKey::Classic(fbsk) => fbsk.as_view(),
|
||||
_ => unreachable!(),
|
||||
};
|
||||
let fft = Fft::new(polynomial_size);
|
||||
let fft = fft.as_view();
|
||||
let mut cbs_res = GgswCiphertext::new(
|
||||
0u64,
|
||||
glwe_dimension.to_glwe_size(),
|
||||
polynomial_size,
|
||||
base_log_cbs,
|
||||
level_count_cbs,
|
||||
ciphertext_modulus,
|
||||
);
|
||||
let mut ggsw = FourierGgswCiphertext::new(
|
||||
glwe_dimension.to_glwe_size(),
|
||||
polynomial_size,
|
||||
base_log_cbs,
|
||||
level_count_cbs,
|
||||
);
|
||||
|
||||
let mut mem = GlobalPodBuffer::new(
|
||||
circuit_bootstrap_boolean_scratch::<u64>(
|
||||
ct.lwe_size(),
|
||||
fourier_bsk.output_lwe_dimension().to_lwe_size(),
|
||||
glwe_dimension.to_glwe_size(),
|
||||
polynomial_size,
|
||||
fft,
|
||||
)
|
||||
.unwrap(),
|
||||
);
|
||||
let mut stack = PodStack::new(&mut mem);
|
||||
circuit_bootstrap_boolean(
|
||||
fourier_bsk,
|
||||
ct.as_view(),
|
||||
cbs_res.as_mut_view(),
|
||||
DeltaLog(63 - message_space),
|
||||
self.wopbs_key.cbs_pfpksk.as_view(),
|
||||
fft,
|
||||
stack.rb_mut(),
|
||||
);
|
||||
|
||||
let mut mem = GlobalPodBuffer::new(fill_with_forward_fourier_scratch(fft).unwrap());
|
||||
let mut stack = PodStack::new(&mut mem);
|
||||
ggsw.as_mut_view()
|
||||
.fill_with_forward_fourier(cbs_res.as_view(), fft, stack.rb_mut());
|
||||
ggsw
|
||||
}
|
||||
|
||||
pub fn extract_bit_cbs(
|
||||
&self,
|
||||
ct1: &ShortintCiphertext,
|
||||
) -> Vec<FourierGgswCiphertext<ABox<[c64]>>> {
|
||||
let glwe_dimension = self.wopbs_key.param.glwe_dimension;
|
||||
let polynomial_size = self.wopbs_key.param.polynomial_size;
|
||||
let lwe_dimension = self.wopbs_key.param.lwe_dimension;
|
||||
let message_modulus = self.wopbs_key.param.message_modulus;
|
||||
let log_message_modulus = f64::log2(message_modulus.0 as f64) as usize;
|
||||
let log_carry_modulus = f64::log2(self.wopbs_key.param.carry_modulus.0 as f64) as usize;
|
||||
let ciphertext_modulus = ct1.ct.ciphertext_modulus();
|
||||
|
||||
let ksk = &self.key.key_switching_key;
|
||||
let delta_log = 63 - log_message_modulus * log_carry_modulus;
|
||||
let fft = Fft::new(polynomial_size);
|
||||
let fft = fft.as_view();
|
||||
let req = || {
|
||||
StackReq::try_any_of([
|
||||
fill_with_forward_fourier_scratch(fft)?,
|
||||
extract_bits_scratch::<u64>(
|
||||
lwe_dimension,
|
||||
LweDimension(polynomial_size.0 * glwe_dimension.0 + 1),
|
||||
glwe_dimension.to_glwe_size(),
|
||||
polynomial_size,
|
||||
fft,
|
||||
)?,
|
||||
])
|
||||
};
|
||||
let req = req().unwrap();
|
||||
let mut mem = GlobalPodBuffer::new(req);
|
||||
let stack = PodStack::new(&mut mem);
|
||||
let fourier_bsk = match &self.wopbs_key.wopbs_server_key.bootstrapping_key {
|
||||
shortint::server_key::ShortintBootstrappingKey::Classic(fbsk) => fbsk.as_view(),
|
||||
_ => unreachable!(),
|
||||
};
|
||||
let mut lwe_out_list = LweCiphertextList::new(
|
||||
0u64,
|
||||
ksk.output_lwe_size(),
|
||||
LweCiphertextCount(log_message_modulus),
|
||||
ciphertext_modulus,
|
||||
);
|
||||
extract_bits(
|
||||
lwe_out_list.as_mut_view(),
|
||||
ct1.ct.as_view(),
|
||||
ksk.as_view(),
|
||||
fourier_bsk,
|
||||
DeltaLog(delta_log),
|
||||
ExtractedBitsCount(log_message_modulus),
|
||||
fft,
|
||||
stack,
|
||||
);
|
||||
let mut out_vec_ggsw: Vec<FourierGgswCiphertext<ABox<[c64]>>> = Vec::new();
|
||||
for lwe in lwe_out_list.iter() {
|
||||
let ggsw = self.ggsw_cbs_parallelized(&lwe, 0);
|
||||
out_vec_ggsw.append(&mut vec![ggsw]);
|
||||
}
|
||||
out_vec_ggsw
|
||||
}
|
||||
|
||||
//return ct0 if we have ggsw(0)
|
||||
//return ct1 if we have ggsw(1)
|
||||
//with cti a ShortintCiphertext
|
||||
pub fn cmux(
|
||||
&self,
|
||||
ct0: &ShortintCiphertext,
|
||||
ct1: &ShortintCiphertext,
|
||||
ggsw: &FourierGgswCiphertext<ABox<[c64]>>,
|
||||
) -> ShortintCiphertext {
|
||||
let polynomial_size = self.wopbs_key.param.polynomial_size;
|
||||
let glwe_dim = self.wopbs_key.param.glwe_dimension;
|
||||
let mut vec_0 = vec![0u64; polynomial_size.0 * (glwe_dim.0 + 1)];
|
||||
let mut vec_1 = vec![0u64; polynomial_size.0 * (glwe_dim.0 + 1)];
|
||||
for (i, (ct_i_0, ct_i_1)) in ct0
|
||||
.ct
|
||||
.as_ref()
|
||||
.iter()
|
||||
.zip(ct1.ct.as_ref().iter())
|
||||
.enumerate()
|
||||
{
|
||||
if i % polynomial_size.0 == 0 {
|
||||
vec_0[i] = *ct_i_0;
|
||||
vec_1[i] = *ct_i_1;
|
||||
} else {
|
||||
let index =
|
||||
(i / polynomial_size.0 + 1) * polynomial_size.0 - (i % polynomial_size.0);
|
||||
vec_0[index] = 0 - (*ct_i_0);
|
||||
vec_1[index] = 0 - (*ct_i_1);
|
||||
}
|
||||
}
|
||||
let mut rlwe_0 =
|
||||
GlweCiphertext::from_container(vec_0, polynomial_size, self.key.ciphertext_modulus);
|
||||
let mut rlwe_1 =
|
||||
GlweCiphertext::from_container(vec_1, polynomial_size, self.key.ciphertext_modulus);
|
||||
|
||||
cmux_assign(&mut rlwe_0, &mut rlwe_1, ggsw);
|
||||
|
||||
let mut output = LweCiphertext::new(
|
||||
0_u64,
|
||||
LweSize(polynomial_size.0 * glwe_dim.0 + 1),
|
||||
self.key.ciphertext_modulus,
|
||||
);
|
||||
extract_lwe_sample_from_glwe_ciphertext(&rlwe_0, &mut output, MonomialDegree(0));
|
||||
let ct_out = shortint::Ciphertext::new(
|
||||
output,
|
||||
Degree::new(max(ct0.degree.get(), ct1.degree.get())),
|
||||
NoiseLevel::NOMINAL, // TODO: check this is valid in the context of floats
|
||||
ct0.message_modulus,
|
||||
ct0.carry_modulus,
|
||||
PBSOrder::KeyswitchBootstrap,
|
||||
);
|
||||
ct_out
|
||||
}
|
||||
|
||||
//return ct0 in ct0 if we have ggsw(0)
|
||||
//return ct1 in ct0 if we have ggsw(1)
|
||||
//with cti = [Ciphertext]
|
||||
pub fn cmuxes(
|
||||
&self,
|
||||
ct0: &[ShortintCiphertext],
|
||||
ct1: &[ShortintCiphertext],
|
||||
ggsw: &FourierGgswCiphertext<ABox<[c64]>>,
|
||||
) -> Vec<shortint::Ciphertext> {
|
||||
let mut vec_output: Vec<ShortintCiphertext> = Vec::new();
|
||||
for (ct_0, ct_1) in ct0.iter().zip(ct1.iter()) {
|
||||
let output = self.cmux(ct_0, ct_1, ggsw);
|
||||
vec_output.push(output);
|
||||
}
|
||||
vec_output
|
||||
}
|
||||
|
||||
//return ct0 in a nwe ct if we have ggsw(0)
|
||||
//return ct1 in a new ct if we have ggsw(1)
|
||||
//with cti a fp
|
||||
pub fn cmuxes_full(
|
||||
&self,
|
||||
ct0: &Ciphertext,
|
||||
ct1: &Ciphertext,
|
||||
ggsw: &FourierGgswCiphertext<ABox<[c64]>>,
|
||||
) -> Ciphertext {
|
||||
let res_man = self.cmuxes(&ct0.ct_vec_mantissa, &ct1.ct_vec_mantissa, &ggsw);
|
||||
let res_exp = self.cmuxes(&ct0.ct_vec_exponent, &ct1.ct_vec_exponent, &ggsw);
|
||||
let res_sig = self.cmux(&ct0.ct_sign, &ct1.ct_sign, &ggsw);
|
||||
let mut new = self.create_trivial_zero_from_ct(ct0);
|
||||
new.ct_vec_mantissa = res_man;
|
||||
new.ct_vec_exponent = res_exp;
|
||||
new.ct_sign = res_sig;
|
||||
new
|
||||
}
|
||||
|
||||
pub fn cmux_tree_mantissa(
|
||||
&self,
|
||||
vec_mantissa: &Vec<shortint::Ciphertext>,
|
||||
vec_ggsw: &[FourierGgswCiphertext<ABox<[c64]>>],
|
||||
) -> Vec<shortint::Ciphertext> {
|
||||
let zero = self.key.create_trivial(0_u64);
|
||||
let mut cpy = vec_mantissa.clone();
|
||||
let mut vec_fp = Vec::new();
|
||||
for _ in 0..(vec_mantissa.len() + 1) {
|
||||
vec_fp.push(cpy.clone());
|
||||
cpy.push(zero.clone());
|
||||
let _ = cpy.remove(0);
|
||||
}
|
||||
let vec_zero = cpy;
|
||||
for ggsw in vec_ggsw.iter().rev() {
|
||||
if vec_fp.len() == 1 {
|
||||
vec_fp[0] = self.cmuxes(&mut vec_fp[0], &vec_zero, ggsw);
|
||||
} else {
|
||||
if vec_fp.len() % 2 == 0 {
|
||||
for i in 0..vec_fp.len() / 2 {
|
||||
let ct_0 = vec_fp.get_mut(2 * i).unwrap().clone();
|
||||
let ct_1 = vec_fp.get_mut(2 * i + 1).unwrap().clone();
|
||||
vec_fp[i] = self.cmuxes(&ct_0, &ct_1, ggsw);
|
||||
}
|
||||
vec_fp.truncate(vec_fp.len() / 2);
|
||||
} else {
|
||||
for i in 0..vec_fp.len() / 2 {
|
||||
let ct_0 = vec_fp.get_mut(2 * i).unwrap().clone();
|
||||
let ct_1 = vec_fp.get_mut(2 * i + 1).unwrap().clone();
|
||||
vec_fp[i] = self.cmuxes(&ct_0, &ct_1, ggsw);
|
||||
}
|
||||
let last = vec_fp.len();
|
||||
let ct_0 = vec_fp.last().unwrap().clone();
|
||||
let ct_1 = &vec_zero;
|
||||
vec_fp[last / 2] = self.cmuxes(&ct_0, &ct_1, ggsw);
|
||||
vec_fp.truncate((vec_fp.len() + 1) / 2);
|
||||
}
|
||||
}
|
||||
}
|
||||
vec_fp[0].clone()
|
||||
}
|
||||
|
||||
pub fn is_block_non_zero_ggsw_pbs_ks_cbs_parallelized(
|
||||
&self,
|
||||
ct1: &ShortintCiphertext,
|
||||
message_space: usize,
|
||||
) -> FourierGgswCiphertext<ABox<[c64]>> {
|
||||
let accumulator = self.key.generate_lookup_table(|x| u64::from(x != 0));
|
||||
let res = self.key.apply_lookup_table(&ct1, &accumulator);
|
||||
self.ggsw_ks_cbs_parallelized(&res, message_space)
|
||||
}
|
||||
|
||||
/// return ggsw(0) if ct1 = 0, return ggsw(1) otherwise
|
||||
pub fn ggsw_ks_cbs_parallelized(
|
||||
&self,
|
||||
ct1: &ShortintCiphertext,
|
||||
message_space: usize,
|
||||
) -> FourierGgswCiphertext<ABox<[c64]>> {
|
||||
let ciphertext_modulus = ct1.ct.ciphertext_modulus();
|
||||
|
||||
let mut res_ks = LweCiphertext::new(
|
||||
0u64,
|
||||
LweSize(self.wopbs_key.param.lwe_dimension.to_lwe_size().0),
|
||||
ciphertext_modulus,
|
||||
);
|
||||
par_keyswitch_lwe_ciphertext(&self.key.key_switching_key, &ct1.ct, &mut res_ks);
|
||||
self.ggsw_cbs_parallelized(&res_ks.as_view(), message_space)
|
||||
}
|
||||
|
||||
/// return ggsw(0) if ct1 = 0, return ggsw(1) otherwise
|
||||
pub fn ggsw_cbs_parallelized(
|
||||
&self,
|
||||
ct: &LweCiphertext<&[u64]>,
|
||||
message_space: usize,
|
||||
) -> FourierGgswCiphertext<ABox<[c64]>> {
|
||||
// todo!("ggsw_cbs_parallelized");
|
||||
let glwe_dimension = self.wopbs_key.param.glwe_dimension;
|
||||
let polynomial_size = self.wopbs_key.param.polynomial_size;
|
||||
let base_log_cbs = self.wopbs_key.param.cbs_base_log;
|
||||
let level_count_cbs = self.wopbs_key.param.cbs_level;
|
||||
let ciphertext_modulus = ct.ciphertext_modulus();
|
||||
|
||||
let fourier_bsk = match &self.wopbs_key.wopbs_server_key.bootstrapping_key {
|
||||
shortint::server_key::ShortintBootstrappingKey::Classic(fbsk) => fbsk.as_view(),
|
||||
_ => unreachable!(),
|
||||
};
|
||||
let fft = Fft::new(polynomial_size);
|
||||
let fft = fft.as_view();
|
||||
let mut cbs_res = GgswCiphertext::new(
|
||||
0u64,
|
||||
glwe_dimension.to_glwe_size(),
|
||||
polynomial_size,
|
||||
base_log_cbs,
|
||||
level_count_cbs,
|
||||
ciphertext_modulus,
|
||||
);
|
||||
let mut ggsw = FourierGgswCiphertext::new(
|
||||
glwe_dimension.to_glwe_size(),
|
||||
polynomial_size,
|
||||
base_log_cbs,
|
||||
level_count_cbs,
|
||||
);
|
||||
|
||||
let mut mem = GlobalPodBuffer::new(
|
||||
circuit_bootstrap_boolean_scratch::<u64>(
|
||||
ct.lwe_size(),
|
||||
fourier_bsk.output_lwe_dimension().to_lwe_size(),
|
||||
glwe_dimension.to_glwe_size(),
|
||||
polynomial_size,
|
||||
fft,
|
||||
)
|
||||
.unwrap(),
|
||||
);
|
||||
let mut stack = PodStack::new(&mut mem);
|
||||
circuit_bootstrap_boolean_parallelized(
|
||||
fourier_bsk,
|
||||
ct.as_view(),
|
||||
cbs_res.as_mut_view(),
|
||||
DeltaLog(63 - message_space),
|
||||
self.wopbs_key.cbs_pfpksk.as_view(),
|
||||
fft,
|
||||
stack.rb_mut(),
|
||||
);
|
||||
|
||||
let mut mem = GlobalPodBuffer::new(fill_with_forward_fourier_scratch(fft).unwrap());
|
||||
let mut _stack = PodStack::new(&mut mem);
|
||||
|
||||
par_convert_polynomials_list_to_fourier(
|
||||
ggsw.as_mut_view().data(),
|
||||
cbs_res.as_ref(),
|
||||
polynomial_size,
|
||||
fft,
|
||||
);
|
||||
// ggsw.as_mut_view()
|
||||
// .fill_with_forward_fourier(cbs_res.as_view(), fft, stack.rb_mut());
|
||||
ggsw
|
||||
}
|
||||
|
||||
pub fn extract_bit_cbs_parallelized(
|
||||
&self,
|
||||
ct1: &ShortintCiphertext,
|
||||
) -> Vec<FourierGgswCiphertext<ABox<[c64]>>> {
|
||||
// todo!("extract_bit_cbs_parallelized");
|
||||
let glwe_dimension = self.wopbs_key.param.glwe_dimension;
|
||||
let polynomial_size = self.wopbs_key.param.polynomial_size;
|
||||
let lwe_dimension = self.wopbs_key.param.lwe_dimension;
|
||||
let message_modulus = self.wopbs_key.param.message_modulus;
|
||||
let log_message_modulus = f64::log2(message_modulus.0 as f64) as usize;
|
||||
let log_carry_modulus = f64::log2(self.wopbs_key.param.carry_modulus.0 as f64) as usize;
|
||||
let ciphertext_modulus = ct1.ct.ciphertext_modulus();
|
||||
|
||||
let ksk = &self.key.key_switching_key;
|
||||
let delta_log = 63 - log_message_modulus * log_carry_modulus;
|
||||
let fft = Fft::new(polynomial_size);
|
||||
let fft = fft.as_view();
|
||||
let req = || {
|
||||
StackReq::try_any_of([
|
||||
fill_with_forward_fourier_scratch(fft)?,
|
||||
extract_bits_scratch::<u64>(
|
||||
lwe_dimension,
|
||||
LweDimension(polynomial_size.0 * glwe_dimension.0 + 1),
|
||||
glwe_dimension.to_glwe_size(),
|
||||
polynomial_size,
|
||||
fft,
|
||||
)?,
|
||||
])
|
||||
};
|
||||
let req = req().unwrap();
|
||||
let mut mem = GlobalPodBuffer::new(req);
|
||||
let stack = PodStack::new(&mut mem);
|
||||
let fourier_bsk = match &self.wopbs_key.wopbs_server_key.bootstrapping_key {
|
||||
shortint::server_key::ShortintBootstrappingKey::Classic(fbsk) => fbsk.as_view(),
|
||||
_ => unreachable!(),
|
||||
};
|
||||
let mut lwe_out_list = LweCiphertextList::new(
|
||||
0u64,
|
||||
ksk.output_lwe_size(),
|
||||
LweCiphertextCount(log_message_modulus),
|
||||
ciphertext_modulus,
|
||||
);
|
||||
extract_bits_parallelized(
|
||||
lwe_out_list.as_mut_view(),
|
||||
ct1.ct.as_view(),
|
||||
ksk.as_view(),
|
||||
fourier_bsk,
|
||||
DeltaLog(delta_log),
|
||||
ExtractedBitsCount(log_message_modulus),
|
||||
fft,
|
||||
stack,
|
||||
);
|
||||
let mut out_vec_ggsw: Vec<FourierGgswCiphertext<ABox<[c64]>>> = Vec::new();
|
||||
for lwe in lwe_out_list.iter() {
|
||||
let ggsw = self.ggsw_cbs(&lwe, 0);
|
||||
out_vec_ggsw.append(&mut vec![ggsw]);
|
||||
}
|
||||
out_vec_ggsw
|
||||
}
|
||||
|
||||
//return ct0 in ct0 if we have ggsw(0)
|
||||
//return ct1 in ct0 if we have ggsw(1)
|
||||
//with cti = [Ciphertext]
|
||||
pub fn cmuxes_parallelized(
|
||||
&self,
|
||||
ct0: &[ShortintCiphertext],
|
||||
ct1: &[ShortintCiphertext],
|
||||
ggsw: &FourierGgswCiphertext<ABox<[c64]>>,
|
||||
) -> Vec<shortint::Ciphertext> {
|
||||
assert_eq!(ct0.len(), ct1.len());
|
||||
let len = ct0.len();
|
||||
let mut vec_output: Vec<ShortintCiphertext> = Vec::with_capacity(len);
|
||||
|
||||
ct0.par_iter()
|
||||
.zip(ct1.par_iter())
|
||||
.map(|(ct_0_i, ct_1_i)| self.cmux(ct_0_i, ct_1_i, ggsw))
|
||||
.collect_into_vec(&mut vec_output);
|
||||
|
||||
vec_output
|
||||
}
|
||||
|
||||
//return ct0 in a nwe ct if we have ggsw(0)
|
||||
//return ct1 in a new ct if we have ggsw(1)
|
||||
//with cti a fp
|
||||
pub fn cmuxes_full_parallelized(
|
||||
&self,
|
||||
ct0: &Ciphertext,
|
||||
ct1: &Ciphertext,
|
||||
ggsw: &FourierGgswCiphertext<ABox<[c64]>>,
|
||||
) -> Ciphertext {
|
||||
// todo!("cmuxes_full_parallelized");
|
||||
let (res_man, res_exp) = rayon::join(
|
||||
|| self.cmuxes_parallelized(&ct0.ct_vec_mantissa, &ct1.ct_vec_mantissa, &ggsw),
|
||||
|| self.cmuxes_parallelized(&ct0.ct_vec_exponent, &ct1.ct_vec_exponent, &ggsw),
|
||||
);
|
||||
let res_sig = self.cmux(&ct0.ct_sign, &ct1.ct_sign, &ggsw);
|
||||
let mut new = self.create_trivial_zero_from_ct(ct0);
|
||||
new.ct_vec_mantissa = res_man;
|
||||
new.ct_vec_exponent = res_exp;
|
||||
new.ct_sign = res_sig;
|
||||
new
|
||||
}
|
||||
|
||||
pub fn cmux_tree_mantissa_parallelized(
|
||||
&self,
|
||||
vec_mantissa: &Vec<shortint::Ciphertext>,
|
||||
vec_ggsw: &[FourierGgswCiphertext<ABox<[c64]>>],
|
||||
) -> Vec<shortint::Ciphertext> {
|
||||
// todo!("cmux_tree_mantissa_parallelized");
|
||||
let zero = self.key.create_trivial(0_u64);
|
||||
let mut cpy = vec_mantissa.clone();
|
||||
let mut vec_fp = Vec::new();
|
||||
for _ in 0..(vec_mantissa.len() + 1) {
|
||||
vec_fp.push(cpy.clone());
|
||||
cpy.push(zero.clone());
|
||||
let _ = cpy.remove(0);
|
||||
}
|
||||
let vec_zero = cpy;
|
||||
// TODO cmux tree in parallel
|
||||
for ggsw in vec_ggsw.iter().rev() {
|
||||
if vec_fp.len() == 1 {
|
||||
vec_fp[0] = self.cmuxes_parallelized(&mut vec_fp[0], &vec_zero, ggsw);
|
||||
} else {
|
||||
if vec_fp.len() % 2 == 0 {
|
||||
for i in 0..vec_fp.len() / 2 {
|
||||
let ct_0 = vec_fp.get_mut(2 * i).unwrap().clone();
|
||||
let ct_1 = vec_fp.get_mut(2 * i + 1).unwrap().clone();
|
||||
vec_fp[i] = self.cmuxes_parallelized(&ct_0, &ct_1, ggsw);
|
||||
}
|
||||
vec_fp.truncate(vec_fp.len() / 2);
|
||||
} else {
|
||||
for i in 0..vec_fp.len() / 2 {
|
||||
let ct_0 = vec_fp.get_mut(2 * i).unwrap().clone();
|
||||
let ct_1 = vec_fp.get_mut(2 * i + 1).unwrap().clone();
|
||||
vec_fp[i] = self.cmuxes_parallelized(&ct_0, &ct_1, ggsw);
|
||||
}
|
||||
let last = vec_fp.len();
|
||||
let ct_0 = vec_fp.last().unwrap().clone();
|
||||
let ct_1 = &vec_zero;
|
||||
vec_fp[last / 2] = self.cmuxes_parallelized(&ct_0, &ct_1, ggsw);
|
||||
vec_fp.truncate((vec_fp.len() + 1) / 2);
|
||||
}
|
||||
}
|
||||
}
|
||||
vec_fp[0].clone()
|
||||
}
|
||||
}
|
||||
9
concrete-float/src/test_user_docs.rs
Normal file
9
concrete-float/src/test_user_docs.rs
Normal file
@@ -0,0 +1,9 @@
|
||||
use doc_comment::doctest;
|
||||
|
||||
doctest!("../docs/getting_started/first_circuit.md", first_circuit);
|
||||
doctest!("../docs/tutorials/serialization.md", serialization_tuto);
|
||||
doctest!(
|
||||
"../docs/tutorials/circuit_evaluation.md",
|
||||
circuit_evaluation
|
||||
);
|
||||
doctest!("../docs/how_to/pbs.md", pbs);
|
||||
@@ -55,7 +55,7 @@ concrete-csprng = { version = "0.4.0", path = "../concrete-csprng", features = [
|
||||
] }
|
||||
lazy_static = { version = "1.4.0", optional = true }
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
rayon = { version = "1.5.0" }
|
||||
rayon = { version = "1.5" }
|
||||
bincode = { version = "1.3.3", optional = true }
|
||||
concrete-fft = { version = "0.3.0", features = ["serde", "fft128"] }
|
||||
pulp = "0.13"
|
||||
@@ -82,6 +82,7 @@ bytemuck = "1.13.1"
|
||||
boolean = ["dep:paste"]
|
||||
shortint = ["dep:paste"]
|
||||
integer = ["shortint", "dep:paste"]
|
||||
float_wopbs = ["shortint", "dep:paste"]
|
||||
internal-keycache = ["dep:lazy_static", "dep:fs2", "dep:bincode", "dep:paste"]
|
||||
safe-deserialization = ["dep:bincode"]
|
||||
|
||||
@@ -103,7 +104,6 @@ __wasm_api = [
|
||||
"dep:console_error_panic_hook",
|
||||
"dep:serde-wasm-bindgen",
|
||||
"dep:getrandom",
|
||||
"getrandom/js",
|
||||
"dep:bincode",
|
||||
"safe-deserialization",
|
||||
]
|
||||
@@ -210,6 +210,12 @@ path = "benches/utilities.rs"
|
||||
harness = false
|
||||
required-features = ["boolean", "shortint", "integer", "internal-keycache"]
|
||||
|
||||
[[bench]]
|
||||
name = "float-wopbs-bench"
|
||||
path = "benches/float_wopbs/bench.rs"
|
||||
harness = false
|
||||
required-features = []
|
||||
|
||||
# Examples used as tools
|
||||
|
||||
[[example]]
|
||||
@@ -263,3 +269,7 @@ required-features = ["boolean"]
|
||||
|
||||
[lib]
|
||||
crate-type = ["lib", "staticlib", "cdylib"]
|
||||
|
||||
|
||||
[lints.rust]
|
||||
unexpected_cfgs = { level = "warn", check-cfg = ['cfg(bench)'] }
|
||||
|
||||
90
tfhe/benches/float_wopbs/bench.rs
Normal file
90
tfhe/benches/float_wopbs/bench.rs
Normal file
@@ -0,0 +1,90 @@
|
||||
#![allow(dead_code)]
|
||||
|
||||
use criterion::{criterion_group, criterion_main, Criterion};
|
||||
use tfhe::float_wopbs::gen_keys;
|
||||
|
||||
#[allow(unused_imports)]
|
||||
use tfhe::float_wopbs::parameters::{
|
||||
PARAM_MESSAGE_2_16_BITS, PARAM_MESSAGE_4_16_BITS, PARAM_MESSAGE_8_16_BITS,
|
||||
};
|
||||
use tfhe::float_wopbs::parameters::{ PARAM_MESSAGE_2_4_8_BITS_BIV, PARAM_MESSAGE_4_2_8_BITS_BIV};
|
||||
use tfhe::shortint::WopbsParameters;
|
||||
|
||||
macro_rules! named_param {
|
||||
($param:ident) => {
|
||||
(stringify!($param), $param)
|
||||
};
|
||||
}
|
||||
|
||||
struct Parameters {
|
||||
parameters: WopbsParameters,
|
||||
bit_mantissa: usize,
|
||||
bit_exponent: usize,
|
||||
}
|
||||
|
||||
|
||||
const PARAM_4_BIT_LWE_8_BITS: Parameters = Parameters {
|
||||
parameters: PARAM_MESSAGE_2_4_8_BITS_BIV,
|
||||
bit_mantissa: 4,
|
||||
bit_exponent: 3,
|
||||
};
|
||||
|
||||
const PARAM_2_BIT_LWE_8_BITS: Parameters = Parameters {
|
||||
parameters: PARAM_MESSAGE_4_2_8_BITS_BIV,
|
||||
bit_mantissa: 4,
|
||||
bit_exponent: 3,
|
||||
};
|
||||
|
||||
|
||||
const SERVER_KEY_BENCH_PARAMS: [(&str, Parameters); 2] =
|
||||
[ named_param!(PARAM_4_BIT_LWE_8_BITS),
|
||||
named_param!(PARAM_2_BIT_LWE_8_BITS)];
|
||||
|
||||
criterion_main!(float);
|
||||
|
||||
criterion_group!(float, float_wopbs_bivariate);
|
||||
|
||||
pub fn float_wopbs_mut_eval(c: &mut Criterion) {
|
||||
for name_param in SERVER_KEY_BENCH_PARAMS {
|
||||
let (cks, sks) = gen_keys(name_param.1.parameters);
|
||||
let bit_mantissa = &name_param.1.bit_mantissa;
|
||||
let bit_exponent = &name_param.1.bit_exponent;
|
||||
let e_min = -2;
|
||||
let msg_1 = 0.375;
|
||||
|
||||
// Encryption:
|
||||
let mut ct_1 = cks.encrypt(msg_1, e_min, *bit_mantissa, *bit_exponent);
|
||||
|
||||
let lut = sks.create_lut(&mut ct_1, |x| x);
|
||||
let bench_id = format!("8-bit floats WoP-PBS lut eval::{}", name_param.0);
|
||||
c.bench_function(&bench_id, |b| {
|
||||
b.iter(|| {
|
||||
sks.wop_pbs(&sks, &mut ct_1, &lut);
|
||||
})
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
pub fn float_wopbs_bivariate(c: &mut Criterion) {
|
||||
for name_param in SERVER_KEY_BENCH_PARAMS {
|
||||
let (cks, sks) = gen_keys(name_param.1.parameters);
|
||||
let bit_mantissa = &name_param.1.bit_mantissa;
|
||||
let bit_exponent = &name_param.1.bit_exponent;
|
||||
|
||||
let e_min = -2;
|
||||
let msg_1 = 0.375;
|
||||
|
||||
// Encryption:
|
||||
let mut ct_1 = cks.encrypt(msg_1, e_min, *bit_mantissa, *bit_exponent);
|
||||
let msg_2 = -44.;
|
||||
let mut ct_2 = cks.encrypt(msg_2, e_min, *bit_mantissa, *bit_exponent);
|
||||
|
||||
let lut = sks.create_bivariate_lut(&mut ct_1, |x, y| y * x);
|
||||
let bench_id = format!("8-bit floats WoP-PBS bivariate::{}", name_param.0);
|
||||
c.bench_function(&bench_id, |b| {
|
||||
b.iter(|| {
|
||||
sks.wop_pbs_bivariate(&sks, &mut ct_1, &mut ct_2, &lut);
|
||||
})
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -355,7 +355,7 @@ pub fn add_external_product_assign_scratch<Scalar>(
|
||||
substack0.try_and(StackReq::try_all_of([fourier_scratch; 4])?)
|
||||
}
|
||||
|
||||
#[cfg_attr(__profiling, inline(never))]
|
||||
#[cfg_attr(feature = "__profiling", inline(never))]
|
||||
pub fn add_external_product_assign<Scalar, ContOut, ContGgsw, ContGlwe>(
|
||||
out: &mut GlweCiphertext<ContOut>,
|
||||
ggsw: &Fourier128GgswCiphertext<ContGgsw>,
|
||||
|
||||
@@ -13,7 +13,7 @@ use crate::core_crypto::prelude::{Container, ContainerMut, SignedDecomposer};
|
||||
use aligned_vec::CACHELINE_ALIGN;
|
||||
use dyn_stack::{PodStack, ReborrowMut};
|
||||
|
||||
#[cfg_attr(__profiling, inline(never))]
|
||||
#[cfg_attr(feature = "__profiling", inline(never))]
|
||||
pub fn add_external_product_assign_split<ContOutLo, ContOutHi, ContGgsw, ContGlweLo, ContGlweHi>(
|
||||
out_lo: &mut GlweCiphertext<ContOutLo>,
|
||||
out_hi: &mut GlweCiphertext<ContOutHi>,
|
||||
|
||||
@@ -473,7 +473,7 @@ pub fn add_external_product_assign_scratch<Scalar>(
|
||||
}
|
||||
|
||||
/// Perform the external product of `ggsw` and `glwe`, and adds the result to `out`.
|
||||
#[cfg_attr(__profiling, inline(never))]
|
||||
#[cfg_attr(feature = "__profiling", inline(never))]
|
||||
pub fn add_external_product_assign<Scalar>(
|
||||
mut out: GlweCiphertextMutView<'_, Scalar>,
|
||||
ggsw: FourierGgswCiphertextView<'_>,
|
||||
@@ -597,7 +597,7 @@ pub fn add_external_product_assign<Scalar>(
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg_attr(__profiling, inline(never))]
|
||||
#[cfg_attr(feature = "__profiling", inline(never))]
|
||||
fn collect_next_term<'a, Scalar: UnsignedTorus>(
|
||||
decomposition: &mut TensorSignedDecompositionLendingIter<'_, Scalar>,
|
||||
substack1: &'a mut PodStack,
|
||||
@@ -612,7 +612,7 @@ fn collect_next_term<'a, Scalar: UnsignedTorus>(
|
||||
(glwe_level, glwe_decomp_term, substack2)
|
||||
}
|
||||
|
||||
#[cfg_attr(__profiling, inline(never))]
|
||||
#[cfg_attr(feature = "__profiling", inline(never))]
|
||||
pub(crate) fn update_with_fmadd(
|
||||
output_fft_buffer: &mut [c64],
|
||||
lhs_polynomial_list: &[c64],
|
||||
|
||||
@@ -12,6 +12,7 @@ use super::ggsw::{
|
||||
};
|
||||
use crate::core_crypto::algorithms::polynomial_algorithms::*;
|
||||
use crate::core_crypto::algorithms::*;
|
||||
use crate::core_crypto::commons::computation_buffers::ComputationBuffers;
|
||||
use crate::core_crypto::commons::math::decomposition::DecompositionLevel;
|
||||
use crate::core_crypto::commons::numeric::CastInto;
|
||||
use crate::core_crypto::commons::parameters::*;
|
||||
@@ -19,6 +20,8 @@ use crate::core_crypto::commons::traits::*;
|
||||
use crate::core_crypto::commons::utils::izip;
|
||||
use crate::core_crypto::entities::*;
|
||||
|
||||
use rayon::prelude::*;
|
||||
|
||||
use concrete_fft::c64;
|
||||
|
||||
pub fn extract_bits_scratch<Scalar>(
|
||||
@@ -224,6 +227,167 @@ pub fn extract_bits<Scalar: UnsignedTorus + CastInto<usize>>(
|
||||
}
|
||||
}
|
||||
|
||||
pub fn extract_bits_parallelized<Scalar: UnsignedTorus + CastInto<usize>>(
|
||||
mut lwe_list_out: LweCiphertextList<&'_ mut [Scalar]>,
|
||||
lwe_in: LweCiphertext<&'_ [Scalar]>,
|
||||
ksk: LweKeyswitchKey<&'_ [Scalar]>,
|
||||
fourier_bsk: FourierLweBootstrapKeyView<'_>,
|
||||
delta_log: DeltaLog,
|
||||
number_of_bits_to_extract: ExtractedBitsCount,
|
||||
fft: FftView<'_>,
|
||||
stack: PodStack<'_>,
|
||||
) {
|
||||
debug_assert!(lwe_list_out.ciphertext_modulus() == lwe_in.ciphertext_modulus());
|
||||
debug_assert!(lwe_in.ciphertext_modulus() == ksk.ciphertext_modulus());
|
||||
debug_assert!(
|
||||
ksk.ciphertext_modulus().is_native_modulus(),
|
||||
"This operation only supports native moduli"
|
||||
);
|
||||
|
||||
let ciphertext_n_bits = Scalar::BITS;
|
||||
let number_of_bits_to_extract = number_of_bits_to_extract.0;
|
||||
|
||||
debug_assert!(
|
||||
ciphertext_n_bits >= number_of_bits_to_extract + delta_log.0,
|
||||
"Tried to extract {} bits, while the maximum number of extractable bits for {} bits
|
||||
ciphertexts and a scaling factor of 2^{} is {}",
|
||||
number_of_bits_to_extract,
|
||||
ciphertext_n_bits,
|
||||
delta_log.0,
|
||||
ciphertext_n_bits - delta_log.0,
|
||||
);
|
||||
debug_assert!(
|
||||
lwe_list_out.lwe_size().to_lwe_dimension() == ksk.output_key_lwe_dimension(),
|
||||
"lwe_list_out needs to have an lwe_size of {}, got {}",
|
||||
ksk.output_key_lwe_dimension().0,
|
||||
lwe_list_out.lwe_size().to_lwe_dimension().0,
|
||||
);
|
||||
debug_assert!(
|
||||
lwe_list_out.lwe_ciphertext_count().0 == number_of_bits_to_extract,
|
||||
"lwe_list_out needs to have a ciphertext count of {}, got {}",
|
||||
number_of_bits_to_extract,
|
||||
lwe_list_out.lwe_ciphertext_count().0,
|
||||
);
|
||||
debug_assert!(
|
||||
lwe_in.lwe_size() == fourier_bsk.output_lwe_dimension().to_lwe_size(),
|
||||
"lwe_in needs to have an LWE dimension of {}, got {}",
|
||||
fourier_bsk.output_lwe_dimension().to_lwe_size().0,
|
||||
lwe_in.lwe_size().0,
|
||||
);
|
||||
debug_assert!(
|
||||
ksk.output_key_lwe_dimension() == fourier_bsk.input_lwe_dimension(),
|
||||
"ksk needs to have an output LWE dimension of {}, got {}",
|
||||
fourier_bsk.input_lwe_dimension().0,
|
||||
ksk.output_key_lwe_dimension().0,
|
||||
);
|
||||
debug_assert!(lwe_list_out.ciphertext_modulus() == lwe_in.ciphertext_modulus());
|
||||
debug_assert!(lwe_in.ciphertext_modulus() == ksk.ciphertext_modulus());
|
||||
|
||||
let polynomial_size = fourier_bsk.polynomial_size();
|
||||
let glwe_size = fourier_bsk.glwe_size();
|
||||
let glwe_dimension = glwe_size.to_glwe_dimension();
|
||||
let ciphertext_modulus = lwe_in.ciphertext_modulus();
|
||||
|
||||
let align = CACHELINE_ALIGN;
|
||||
|
||||
let (mut lwe_in_buffer_data, stack) =
|
||||
stack.collect_aligned(align, lwe_in.as_ref().iter().copied());
|
||||
let mut lwe_in_buffer =
|
||||
LweCiphertext::from_container(&mut *lwe_in_buffer_data, lwe_in.ciphertext_modulus());
|
||||
|
||||
let (mut lwe_out_ks_buffer_data, stack) =
|
||||
stack.make_aligned_with(ksk.output_lwe_size().0, align, |_| Scalar::ZERO);
|
||||
let mut lwe_out_ks_buffer =
|
||||
LweCiphertext::from_container(&mut *lwe_out_ks_buffer_data, ksk.ciphertext_modulus());
|
||||
|
||||
let (mut pbs_accumulator_data, stack) =
|
||||
stack.make_aligned_with(glwe_size.0 * polynomial_size.0, align, |_| Scalar::ZERO);
|
||||
let mut pbs_accumulator = GlweCiphertextMutView::from_container(
|
||||
&mut *pbs_accumulator_data,
|
||||
polynomial_size,
|
||||
ciphertext_modulus,
|
||||
);
|
||||
|
||||
let lwe_size = glwe_dimension
|
||||
.to_equivalent_lwe_dimension(polynomial_size)
|
||||
.to_lwe_size();
|
||||
let (mut lwe_out_pbs_buffer_data, mut stack) =
|
||||
stack.make_aligned_with(lwe_size.0, align, |_| Scalar::ZERO);
|
||||
let mut lwe_out_pbs_buffer = LweCiphertext::from_container(
|
||||
&mut *lwe_out_pbs_buffer_data,
|
||||
lwe_list_out.ciphertext_modulus(),
|
||||
);
|
||||
|
||||
// We iterate on the list in reverse as we want to store the extracted MSB at index 0
|
||||
for (bit_idx, mut output_ct) in lwe_list_out.iter_mut().rev().enumerate() {
|
||||
// Shift on padding bit
|
||||
let (lwe_bit_left_shift_buffer_data, _) = stack.rb_mut().collect_aligned(
|
||||
align,
|
||||
lwe_in_buffer
|
||||
.as_ref()
|
||||
.iter()
|
||||
.map(|s| *s << (ciphertext_n_bits - delta_log.0 - bit_idx - 1)),
|
||||
);
|
||||
|
||||
// Key switch to input PBS key
|
||||
par_keyswitch_lwe_ciphertext(
|
||||
&ksk,
|
||||
&LweCiphertext::from_container(
|
||||
&*lwe_bit_left_shift_buffer_data,
|
||||
lwe_in.ciphertext_modulus(),
|
||||
),
|
||||
&mut lwe_out_ks_buffer,
|
||||
);
|
||||
|
||||
drop(lwe_bit_left_shift_buffer_data);
|
||||
|
||||
// Store the keyswitch output unmodified to the output list (as we need to to do other
|
||||
// computations on the output of the keyswitch)
|
||||
output_ct
|
||||
.as_mut()
|
||||
.copy_from_slice(lwe_out_ks_buffer.as_ref());
|
||||
|
||||
// If this was the last extracted bit, break
|
||||
// we subtract 1 because if the number_of_bits_to_extract is 1 we want to stop right away
|
||||
if bit_idx == number_of_bits_to_extract - 1 {
|
||||
break;
|
||||
}
|
||||
|
||||
// Add q/4 to center the error while computing a negacyclic LUT
|
||||
let out_ks_body = lwe_out_ks_buffer.get_mut_body().data;
|
||||
*out_ks_body = (*out_ks_body).wrapping_add(Scalar::ONE << (ciphertext_n_bits - 2));
|
||||
|
||||
// Fill lut for the current bit (equivalent to trivial encryption as mask is 0s)
|
||||
// The LUT is filled with -alpha in each coefficient where alpha = delta*2^{bit_idx-1}
|
||||
for poly_coeff in &mut pbs_accumulator
|
||||
.as_mut_view()
|
||||
.get_mut_body()
|
||||
.as_mut_polynomial()
|
||||
.iter_mut()
|
||||
{
|
||||
*poly_coeff = Scalar::ZERO.wrapping_sub(Scalar::ONE << (delta_log.0 - 1 + bit_idx));
|
||||
}
|
||||
|
||||
fourier_bsk.bootstrap(
|
||||
lwe_out_pbs_buffer.as_mut_view(),
|
||||
lwe_out_ks_buffer.as_view(),
|
||||
pbs_accumulator.as_view(),
|
||||
fft,
|
||||
stack.rb_mut(),
|
||||
);
|
||||
|
||||
// Add alpha where alpha = delta*2^{bit_idx-1} to end up with an encryption of 0 if the
|
||||
// extracted bit was 0 and 1 in the other case
|
||||
let out_pbs_body = lwe_out_pbs_buffer.get_mut_body().data;
|
||||
|
||||
*out_pbs_body = (*out_pbs_body).wrapping_add(Scalar::ONE << (delta_log.0 + bit_idx - 1));
|
||||
|
||||
// Remove the extracted bit from the initial LWE to get a 0 at the extracted bit location.
|
||||
izip!(lwe_in_buffer.as_mut(), lwe_out_pbs_buffer.as_ref())
|
||||
.for_each(|(out, inp)| *out = (*out).wrapping_sub(*inp));
|
||||
}
|
||||
}
|
||||
|
||||
pub fn circuit_bootstrap_boolean_scratch<Scalar>(
|
||||
lwe_in_size: LweSize,
|
||||
bsk_output_lwe_size: LweSize,
|
||||
@@ -343,6 +507,137 @@ pub fn circuit_bootstrap_boolean<Scalar: UnsignedTorus + CastInto<usize>>(
|
||||
}
|
||||
}
|
||||
|
||||
pub fn circuit_bootstrap_boolean_parallelized<Scalar: UnsignedTorus + CastInto<usize>>(
|
||||
fourier_bsk: FourierLweBootstrapKeyView<'_>,
|
||||
lwe_in: LweCiphertext<&[Scalar]>,
|
||||
mut ggsw_out: GgswCiphertext<&mut [Scalar]>,
|
||||
delta_log: DeltaLog,
|
||||
pfpksk_list: LwePrivateFunctionalPackingKeyswitchKeyList<&[Scalar]>,
|
||||
fft: FftView<'_>,
|
||||
_stack: PodStack<'_>,
|
||||
) {
|
||||
debug_assert!(lwe_in.ciphertext_modulus() == ggsw_out.ciphertext_modulus());
|
||||
debug_assert!(ggsw_out.ciphertext_modulus() == pfpksk_list.ciphertext_modulus());
|
||||
|
||||
debug_assert!(
|
||||
pfpksk_list.ciphertext_modulus().is_native_modulus(),
|
||||
"This operation currently only supports native moduli"
|
||||
);
|
||||
|
||||
let level_cbs = ggsw_out.decomposition_level_count();
|
||||
let base_log_cbs = ggsw_out.decomposition_base_log();
|
||||
|
||||
debug_assert!(
|
||||
level_cbs.0 >= 1,
|
||||
"level_cbs needs to be >= 1, got {}",
|
||||
level_cbs.0
|
||||
);
|
||||
debug_assert!(
|
||||
base_log_cbs.0 >= 1,
|
||||
"base_log_cbs needs to be >= 1, got {}",
|
||||
base_log_cbs.0
|
||||
);
|
||||
|
||||
let fpksk_input_lwe_key_dimension = pfpksk_list.input_key_lwe_dimension();
|
||||
let fourier_bsk_output_lwe_dimension = fourier_bsk.output_lwe_dimension();
|
||||
|
||||
debug_assert!(
|
||||
fpksk_input_lwe_key_dimension == fourier_bsk_output_lwe_dimension,
|
||||
"The fourier_bsk output_lwe_dimension, got {}, must be equal to the fpksk \
|
||||
input_lwe_key_dimension, got {}",
|
||||
fourier_bsk_output_lwe_dimension.0,
|
||||
fpksk_input_lwe_key_dimension.0
|
||||
);
|
||||
|
||||
let fpksk_output_polynomial_size = pfpksk_list.output_polynomial_size();
|
||||
let fpksk_output_glwe_key_dimension = pfpksk_list.output_key_glwe_dimension();
|
||||
|
||||
debug_assert!(
|
||||
ggsw_out.polynomial_size() == fpksk_output_polynomial_size,
|
||||
"The output GGSW ciphertext needs to have the same polynomial size as the fpksks, \
|
||||
got {}, expected {}",
|
||||
ggsw_out.polynomial_size().0,
|
||||
fpksk_output_polynomial_size.0
|
||||
);
|
||||
|
||||
debug_assert!(
|
||||
ggsw_out.glwe_size().to_glwe_dimension() == fpksk_output_glwe_key_dimension,
|
||||
"The output GGSW ciphertext needs to have the same GLWE dimension as the fpksks, \
|
||||
got {}, expected {}",
|
||||
ggsw_out.glwe_size().to_glwe_dimension().0,
|
||||
fpksk_output_glwe_key_dimension.0
|
||||
);
|
||||
|
||||
debug_assert!(
|
||||
ggsw_out.glwe_size().0 == pfpksk_list.lwe_pfpksk_count().0,
|
||||
"The input vector of pfpksk_list needs to have {} ggsw.glwe_size elements got {}",
|
||||
ggsw_out.glwe_size().0,
|
||||
pfpksk_list.lwe_pfpksk_count().0,
|
||||
);
|
||||
|
||||
// // Output for every bootstrapping
|
||||
// let (mut lwe_out_bs_buffer_data, mut stack) = stack.make_aligned_with(
|
||||
// fourier_bsk_output_lwe_dimension.to_lwe_size().0,
|
||||
// CACHELINE_ALIGN,
|
||||
// |_| Scalar::ZERO,
|
||||
// );
|
||||
// let mut lwe_out_bs_buffer =
|
||||
// LweCiphertext::from_container(&mut *lwe_out_bs_buffer_data, lwe_in.ciphertext_modulus());
|
||||
|
||||
ggsw_out.par_iter_mut().enumerate().for_each(
|
||||
|(decomposition_level_minus_one, mut ggsw_level_matrix)| {
|
||||
let mut computation_buffers = ComputationBuffers::new();
|
||||
computation_buffers.resize(
|
||||
circuit_bootstrap_boolean_scratch::<u64>(
|
||||
lwe_in.lwe_size(),
|
||||
fourier_bsk.output_lwe_dimension().to_lwe_size(),
|
||||
fourier_bsk.glwe_size(),
|
||||
fourier_bsk.polynomial_size(),
|
||||
fft,
|
||||
)
|
||||
.unwrap()
|
||||
.try_unaligned_bytes_required()
|
||||
.unwrap(),
|
||||
);
|
||||
|
||||
let stack = computation_buffers.stack();
|
||||
|
||||
let (mut lwe_out_bs_buffer_data, mut stack) = stack.make_aligned_with(
|
||||
fourier_bsk_output_lwe_dimension.to_lwe_size().0,
|
||||
CACHELINE_ALIGN,
|
||||
|_| Scalar::ZERO,
|
||||
);
|
||||
let mut lwe_out_bs_buffer = LweCiphertext::from_container(
|
||||
&mut *lwe_out_bs_buffer_data,
|
||||
lwe_in.ciphertext_modulus(),
|
||||
);
|
||||
|
||||
let decomposition_level = DecompositionLevel(decomposition_level_minus_one + 1);
|
||||
homomorphic_shift_boolean(
|
||||
fourier_bsk,
|
||||
lwe_out_bs_buffer.as_mut_view(),
|
||||
lwe_in.as_view(),
|
||||
decomposition_level,
|
||||
base_log_cbs,
|
||||
delta_log,
|
||||
fft,
|
||||
stack.rb_mut(),
|
||||
);
|
||||
|
||||
pfpksk_list
|
||||
.par_iter()
|
||||
.zip(ggsw_level_matrix.as_mut_glwe_list().par_iter_mut())
|
||||
.for_each(|(pfpksk, mut glwe_out)| {
|
||||
par_private_functional_keyswitch_lwe_ciphertext_into_glwe_ciphertext(
|
||||
&pfpksk,
|
||||
&mut glwe_out,
|
||||
&lwe_out_bs_buffer,
|
||||
);
|
||||
});
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
pub fn homomorphic_shift_boolean_scratch<Scalar>(
|
||||
lwe_in_size: LweSize,
|
||||
glwe_size: GlweSize,
|
||||
|
||||
@@ -193,7 +193,7 @@ impl Fft {
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg_attr(__profiling, inline(never))]
|
||||
#[cfg_attr(feature = "__profiling", inline(never))]
|
||||
fn convert_forward_torus<Scalar: UnsignedTorus>(
|
||||
out: &mut [c64],
|
||||
in_re: &[Scalar],
|
||||
@@ -238,7 +238,7 @@ fn convert_forward_integer_scalar<Scalar: UnsignedTorus>(
|
||||
);
|
||||
}
|
||||
|
||||
#[cfg_attr(__profiling, inline(never))]
|
||||
#[cfg_attr(feature = "__profiling", inline(never))]
|
||||
fn convert_forward_integer<Scalar: UnsignedTorus>(
|
||||
out: &mut [c64],
|
||||
in_re: &[Scalar],
|
||||
@@ -260,7 +260,7 @@ fn convert_forward_integer<Scalar: UnsignedTorus>(
|
||||
convert_forward_integer_scalar::<Scalar>(out, in_re, in_im, twisties);
|
||||
}
|
||||
|
||||
#[cfg_attr(__profiling, inline(never))]
|
||||
#[cfg_attr(feature = "__profiling", inline(never))]
|
||||
fn convert_backward_torus<Scalar: UnsignedTorus>(
|
||||
out_re: &mut [Scalar],
|
||||
out_im: &mut [Scalar],
|
||||
@@ -303,7 +303,7 @@ fn convert_add_backward_torus_scalar<Scalar: UnsignedTorus>(
|
||||
);
|
||||
}
|
||||
|
||||
#[cfg_attr(__profiling, inline(never))]
|
||||
#[cfg_attr(feature = "__profiling", inline(never))]
|
||||
fn convert_add_backward_torus<Scalar: UnsignedTorus>(
|
||||
out_re: &mut [Scalar],
|
||||
out_im: &mut [Scalar],
|
||||
|
||||
33
tfhe/src/float_wopbs/ciphertext/mod.rs
Normal file
33
tfhe/src/float_wopbs/ciphertext/mod.rs
Normal file
@@ -0,0 +1,33 @@
|
||||
//! This module implements the ciphertext structure containing an encryption of an integer message.
|
||||
use crate::shortint::ciphertext::Ciphertext as ShortintCiphertext;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Id to recognize the key used to encrypt a block.
|
||||
#[derive(Debug, PartialEq, Eq, Copy, Clone, Serialize, Deserialize)]
|
||||
pub struct KeyId(pub usize);
|
||||
|
||||
#[derive(Serialize, Clone, Deserialize)]
|
||||
pub struct Ciphertext {
|
||||
pub(crate) ct_vec_float: Vec<ShortintCiphertext>,
|
||||
pub(crate) nb_bit_mantissa: usize,
|
||||
pub(crate) nb_bit_exponent: usize,
|
||||
pub(crate) e_min: i64,
|
||||
|
||||
pub(crate) key_id_vec: Vec<KeyId>,
|
||||
}
|
||||
|
||||
impl Ciphertext {
|
||||
/// Returns the slice of blocks that the ciphertext is composed of.
|
||||
pub fn ct_vec_float(&self) -> &[ShortintCiphertext] {
|
||||
&self.ct_vec_float
|
||||
}
|
||||
pub fn nb_bit_mantissa(&self) -> &usize {
|
||||
&self.nb_bit_mantissa
|
||||
}
|
||||
pub fn nb_bit_exponent(&self) -> &usize {
|
||||
&self.nb_bit_exponent
|
||||
}
|
||||
pub fn e_min(&self) -> &i64 {
|
||||
&self.e_min
|
||||
}
|
||||
}
|
||||
165
tfhe/src/float_wopbs/client_key/mod.rs
Normal file
165
tfhe/src/float_wopbs/client_key/mod.rs
Normal file
@@ -0,0 +1,165 @@
|
||||
//! This module implements the generation of the client secret keys, together with the
|
||||
//! encryption and decryption methods.
|
||||
|
||||
pub(crate) mod utils;
|
||||
|
||||
use crate::float_wopbs::ciphertext::Ciphertext;
|
||||
use serde::{Deserialize, Serialize};
|
||||
pub use utils::radix_decomposition;
|
||||
use crate::shortint::ClassicPBSParameters;
|
||||
use crate::shortint::WopbsParameters;
|
||||
|
||||
/// The number of ciphertexts in the vector.
|
||||
#[derive(Debug, PartialEq, Eq, Copy, Clone, Serialize, Deserialize)]
|
||||
pub struct VecLength(pub usize);
|
||||
|
||||
/// A structure containing the client key, which must be kept secret.
|
||||
#[derive(Serialize, Deserialize, PartialEq, Debug, Clone)]
|
||||
pub struct ClientKey {
|
||||
pub(crate) key: crate::shortint::client_key::ClientKey,
|
||||
}
|
||||
|
||||
impl ClientKey {
|
||||
pub fn new(parameter_set: (ClassicPBSParameters, WopbsParameters))-> Self {
|
||||
Self {
|
||||
key: crate::shortint::ClientKey::new(parameter_set),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn from_shortint(key: crate::shortint::client_key::ClientKey) -> Self {
|
||||
Self { key }
|
||||
}
|
||||
|
||||
/// Returns the parameters used by the client key.
|
||||
pub fn parameters(&self) -> crate::shortint::parameters::ShortintParameterSet {
|
||||
self.key.parameters
|
||||
}
|
||||
|
||||
pub fn encrypt(
|
||||
&self,
|
||||
message: f64,
|
||||
e_min: i64,
|
||||
nb_bit_mantissa: usize,
|
||||
nb_bit_exponent: usize,
|
||||
) -> Ciphertext {
|
||||
let uint = float_to_uint(message, e_min, nb_bit_mantissa, nb_bit_exponent);
|
||||
let mut ct_vec_float: Vec<crate::shortint::ciphertext::Ciphertext> = Vec::new();
|
||||
let message_modulus = f64::log2((self.parameters().message_modulus().0) as f64) as usize;
|
||||
let mut vector_length = (nb_bit_mantissa + nb_bit_exponent + 1) / message_modulus;
|
||||
if vector_length * message_modulus != nb_bit_mantissa + nb_bit_exponent + 1 {
|
||||
vector_length += 1;
|
||||
}
|
||||
let mut power = 1_u64;
|
||||
for _ in 0..vector_length {
|
||||
let mut decomp = uint & (((1 << message_modulus) - 1) * power);
|
||||
decomp /= power;
|
||||
|
||||
// encryption
|
||||
let ct = self.key.encrypt(decomp);
|
||||
ct_vec_float.push(ct);
|
||||
//modulus to the power i
|
||||
power *= 1 << message_modulus;
|
||||
}
|
||||
|
||||
Ciphertext {
|
||||
ct_vec_float,
|
||||
nb_bit_mantissa,
|
||||
nb_bit_exponent,
|
||||
e_min,
|
||||
key_id_vec: vec![],
|
||||
}
|
||||
}
|
||||
|
||||
//decrypt function for the all wopbs representation
|
||||
pub fn decrypt(&self, ctxt: &Ciphertext) -> f64 {
|
||||
let integer_result = self.decrypt_(ctxt);
|
||||
uint_to_float(
|
||||
integer_result,
|
||||
ctxt.e_min,
|
||||
ctxt.nb_bit_mantissa,
|
||||
ctxt.nb_bit_exponent,
|
||||
)
|
||||
}
|
||||
|
||||
pub fn decrypt_(&self, ctxt: &Ciphertext) -> u64 {
|
||||
let mut result = 0_u64;
|
||||
let mut shift = 1_u64;
|
||||
for c_i in ctxt.ct_vec_float.iter() {
|
||||
//decrypt the component i of the integer and multiply it by the radix product
|
||||
let tmp = self.key.decrypt_message_and_carry(c_i).wrapping_mul(shift);
|
||||
|
||||
// update the result
|
||||
result = result.wrapping_add(tmp);
|
||||
|
||||
// update the shift for the next iteration
|
||||
shift = shift.wrapping_mul(self.parameters().message_modulus().0 as u64);
|
||||
}
|
||||
result
|
||||
}
|
||||
}
|
||||
|
||||
pub fn float_to_uint(
|
||||
mut float: f64,
|
||||
e_min: i64,
|
||||
nb_bit_mantissa: usize,
|
||||
nb_bit_exponent: usize,
|
||||
) -> u64 {
|
||||
let min = 2.0_f64.powi(e_min as i32);
|
||||
let max = (2.0_f64.powi(nb_bit_mantissa as i32) - 1.) * (2.0_f64.powi(nb_bit_exponent as i32 + e_min as i32) );
|
||||
|
||||
let sign: u64;
|
||||
if float > 0. {
|
||||
sign = 0;
|
||||
} else {
|
||||
sign = 1;
|
||||
float *= -1.;
|
||||
}
|
||||
|
||||
let mut exponent = 0;
|
||||
let mut mantissa;
|
||||
if float == 0. {
|
||||
exponent = 0;
|
||||
mantissa = 0;
|
||||
} else if float >= max {
|
||||
//infinity
|
||||
exponent = (1 << nb_bit_exponent) - 1;
|
||||
mantissa = 0;
|
||||
} else if float < min {
|
||||
//subnormal values
|
||||
exponent = 0;
|
||||
mantissa =
|
||||
(float * ((1 << (nb_bit_mantissa - 1)) * (1 << (e_min.abs() - 1))) as f64) as u64;
|
||||
} else {
|
||||
//Normalized values
|
||||
while float < 1. {
|
||||
float *= 2.;
|
||||
exponent -= 1
|
||||
}
|
||||
while float > 1. {
|
||||
float /= 2.;
|
||||
exponent += 1
|
||||
}
|
||||
mantissa = (float * (1 << (nb_bit_mantissa)) as f64).round() as u64;
|
||||
if mantissa >= 1 << (nb_bit_mantissa){
|
||||
mantissa = mantissa >> 1;
|
||||
exponent += 1
|
||||
}
|
||||
exponent -= e_min;
|
||||
}
|
||||
let mantissa = mantissa & ((1 << nb_bit_mantissa) - 1);
|
||||
|
||||
let exponent = (exponent as u64) & ((1 << nb_bit_exponent) - 1);
|
||||
(sign << (nb_bit_mantissa + nb_bit_exponent)) + (exponent << nb_bit_mantissa) + mantissa
|
||||
}
|
||||
|
||||
pub fn uint_to_float(int: u64, e_min: i64, nb_bit_mantissa: usize, nb_bit_exponent: usize) -> f64 {
|
||||
let mantissa = (int % (1 << nb_bit_mantissa)) as f64 / (1 << (nb_bit_mantissa)) as f64;
|
||||
let exponent = ((int >> nb_bit_mantissa) % (1 << nb_bit_exponent)) as i64;
|
||||
let sign = int >> (nb_bit_exponent + nb_bit_mantissa);
|
||||
let value = mantissa * 2_f64.powi((exponent + e_min) as i32);
|
||||
if sign == 0 {
|
||||
value
|
||||
} else {
|
||||
-value
|
||||
}
|
||||
}
|
||||
51
tfhe/src/float_wopbs/client_key/utils.rs
Normal file
51
tfhe/src/float_wopbs/client_key/utils.rs
Normal file
@@ -0,0 +1,51 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, Copy, Clone, Serialize, Deserialize)]
|
||||
pub struct RadixDecomposition {
|
||||
pub msg_space: usize,
|
||||
pub block_number: usize,
|
||||
}
|
||||
|
||||
/// Computes possible radix decompositions
|
||||
///
|
||||
/// Takes the number of bit of the message space as input and output a vector containing all the
|
||||
/// correct
|
||||
/// possible block decomposition assuming the same message space for all blocks.
|
||||
/// Lower and upper bounds define the minimal and maximal space to be considered
|
||||
/// Example: 6,2,4 -> \[\[2,3\], \[3,2\]\] : \[msg_space = 2 bits, block_number = 3\]
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// use tfhe::float::client_key::radix_decomposition;
|
||||
/// let input_space = 16; //
|
||||
/// let min = 2;
|
||||
/// let max = 4;
|
||||
/// let decomp = radix_decomposition(input_space, min, max);
|
||||
///
|
||||
/// // Check that 3 possible radix decompositions are provided
|
||||
/// assert_eq!(decomp.len(), 3);
|
||||
/// ```
|
||||
pub fn radix_decomposition(
|
||||
input_space: usize,
|
||||
min_space: usize,
|
||||
max_space: usize,
|
||||
) -> Vec<RadixDecomposition> {
|
||||
let mut out: Vec<RadixDecomposition> = vec![];
|
||||
let mut max = max_space;
|
||||
if max_space > input_space {
|
||||
max = input_space;
|
||||
}
|
||||
for msg_space in min_space..max + 1 {
|
||||
let mut block_number = input_space / msg_space;
|
||||
//Manual ceil of the division
|
||||
if input_space % msg_space != 0 {
|
||||
block_number += 1;
|
||||
}
|
||||
out.push(RadixDecomposition {
|
||||
msg_space,
|
||||
block_number,
|
||||
})
|
||||
}
|
||||
out
|
||||
}
|
||||
86
tfhe/src/float_wopbs/keycache.rs
Normal file
86
tfhe/src/float_wopbs/keycache.rs
Normal file
@@ -0,0 +1,86 @@
|
||||
use crate::float_wopbs::{ClientKey, ServerKey};
|
||||
use crate::shortint::WopbsParameters;
|
||||
use lazy_static::lazy_static;
|
||||
use std::fs::File;
|
||||
use std::io::{BufReader, BufWriter};
|
||||
use std::path::Path;
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct FloatKeyCache;
|
||||
|
||||
const FLOAT_KEY_DIR: &str = "../keys/float/";
|
||||
|
||||
impl FloatKeyCache {
|
||||
pub fn get_from_params(&self, wopbs_params: WopbsParameters) -> (ClientKey, ServerKey) {
|
||||
let pbs_params = crate::shortint::parameters::ClassicPBSParameters {
|
||||
lwe_dimension: wopbs_params.lwe_dimension,
|
||||
glwe_dimension: wopbs_params.glwe_dimension,
|
||||
polynomial_size: wopbs_params.polynomial_size,
|
||||
lwe_modular_std_dev: wopbs_params.lwe_modular_std_dev,
|
||||
glwe_modular_std_dev: wopbs_params.glwe_modular_std_dev,
|
||||
pbs_base_log: wopbs_params.pbs_base_log,
|
||||
pbs_level: wopbs_params.pbs_level,
|
||||
ks_base_log: wopbs_params.ks_base_log,
|
||||
ks_level: wopbs_params.ks_level,
|
||||
message_modulus: wopbs_params.message_modulus,
|
||||
carry_modulus: wopbs_params.carry_modulus,
|
||||
ciphertext_modulus: wopbs_params.ciphertext_modulus,
|
||||
encryption_key_choice: wopbs_params.encryption_key_choice,
|
||||
};
|
||||
|
||||
let params = (pbs_params, wopbs_params);
|
||||
|
||||
let keys = crate::shortint::keycache::KEY_CACHE_WOPBS.get_from_param(params);
|
||||
let (client_key, server_key) = (keys.client_key(), keys.server_key());
|
||||
// TODO DANGER
|
||||
let wopbs_key =
|
||||
crate::shortint::wopbs::WopbsKey::new_wopbs_key_only_for_wopbs(client_key, server_key);
|
||||
let client_key = ClientKey::from_shortint(client_key.clone());
|
||||
let server_key = ServerKey::from_shortint(&client_key, server_key.clone(), wopbs_key);
|
||||
(client_key, server_key)
|
||||
}
|
||||
}
|
||||
|
||||
lazy_static! {
|
||||
pub static ref KEY_CACHE: FloatKeyCache = FloatKeyCache::default();
|
||||
}
|
||||
|
||||
pub fn get_sks(str: &str) -> Option<ServerKey> {
|
||||
let fiptr = format!("{}SKS_{}.bin", FLOAT_KEY_DIR,str);
|
||||
let filepath = Path::new(&fiptr);
|
||||
let file = File::open(filepath);
|
||||
let file = match file {
|
||||
Ok(file) => file,
|
||||
Err(_) => return None,
|
||||
};
|
||||
let file = BufReader::new(file);
|
||||
let saved_key: ServerKey = bincode::deserialize_from(file).unwrap();
|
||||
Some(saved_key)
|
||||
}
|
||||
|
||||
pub fn get_cks(str: &str) -> Option<ClientKey> {
|
||||
let fiptr = format!("{}CKS_{}.bin", FLOAT_KEY_DIR,str);
|
||||
let filepath = Path::new(&fiptr);
|
||||
let file = File::open(filepath);
|
||||
let file = match file {
|
||||
Ok(file) => file,
|
||||
Err(_) => return None,
|
||||
};
|
||||
let file = BufReader::new(file);
|
||||
let saved_key: ClientKey = bincode::deserialize_from(file).unwrap();
|
||||
Some(saved_key)
|
||||
}
|
||||
|
||||
pub fn save_sks(key: &ServerKey, str: &str) {
|
||||
let filepath = format!("{}SKS_{}.bin", FLOAT_KEY_DIR,str);
|
||||
std::fs::create_dir_all(FLOAT_KEY_DIR).unwrap();
|
||||
let file = BufWriter::new(File::create(filepath).unwrap());
|
||||
bincode::serialize_into(file, key).unwrap();
|
||||
}
|
||||
|
||||
pub fn save_cks(key: &ClientKey, str: &str) {
|
||||
let filepath = format!("{}CKS_{}.bin", FLOAT_KEY_DIR,str);
|
||||
std::fs::create_dir_all(FLOAT_KEY_DIR).unwrap();
|
||||
let file = BufWriter::new(File::create(filepath).unwrap());
|
||||
bincode::serialize_into(file, key).unwrap();
|
||||
}
|
||||
55
tfhe/src/float_wopbs/mod.rs
Executable file
55
tfhe/src/float_wopbs/mod.rs
Executable file
@@ -0,0 +1,55 @@
|
||||
//! Welcome the the `concrete-float` documentation!
|
||||
//!
|
||||
//! # Description
|
||||
//!
|
||||
//! This library makes it possible to execute floating point operations.
|
||||
//!
|
||||
//! It allows to execute a floating point circuit on an untrusted server because both circuit inputs
|
||||
//! and outputs are kept private.
|
||||
//!
|
||||
//! Data is encrypted on the client side, before being sent to the server.
|
||||
//! On the server side every computation is performed on ciphertexts
|
||||
extern crate core;
|
||||
|
||||
pub mod ciphertext;
|
||||
pub mod client_key;
|
||||
#[cfg(any(test, doctest, feature = "internal-keycache"))]
|
||||
pub mod keycache;
|
||||
pub mod parameters;
|
||||
pub mod server_key;
|
||||
|
||||
pub use ciphertext::Ciphertext;
|
||||
pub use client_key::ClientKey;
|
||||
pub use server_key::{CheckError, ServerKey};
|
||||
|
||||
|
||||
|
||||
/// Generate a couple of client and server keys with given parameters
|
||||
///
|
||||
/// * the client key is used to encrypt and decrypt and has to be kept secret;
|
||||
/// * the server key is used to perform homomorphic operations on the server side and it is meant to
|
||||
/// be published (the client sends it to the server).
|
||||
///
|
||||
pub fn gen_keys(
|
||||
parameters_set: crate::shortint::parameters::WopbsParameters,
|
||||
) -> (ClientKey, ServerKey) {
|
||||
let pbs_params = crate::shortint::parameters::ClassicPBSParameters {
|
||||
lwe_dimension: parameters_set.lwe_dimension,
|
||||
glwe_dimension: parameters_set.glwe_dimension,
|
||||
polynomial_size: parameters_set.polynomial_size,
|
||||
lwe_modular_std_dev: parameters_set.lwe_modular_std_dev,
|
||||
glwe_modular_std_dev: parameters_set.glwe_modular_std_dev,
|
||||
pbs_base_log: parameters_set.pbs_base_log,
|
||||
pbs_level: parameters_set.pbs_level,
|
||||
ks_base_log: parameters_set.ks_base_log,
|
||||
ks_level: parameters_set.ks_level,
|
||||
message_modulus: parameters_set.message_modulus,
|
||||
carry_modulus: parameters_set.carry_modulus,
|
||||
ciphertext_modulus: parameters_set.ciphertext_modulus,
|
||||
encryption_key_choice: parameters_set.encryption_key_choice,
|
||||
};
|
||||
let params = (pbs_params, parameters_set);
|
||||
let cks = ClientKey::new(params);
|
||||
let sks = ServerKey::new(&cks);
|
||||
(cks,sks)
|
||||
}
|
||||
213
tfhe/src/float_wopbs/parameters/mod.rs
Normal file
213
tfhe/src/float_wopbs/parameters/mod.rs
Normal file
@@ -0,0 +1,213 @@
|
||||
#![allow(clippy::excessive_precision)]
|
||||
pub use crate::shortint::parameters::{EncryptionKeyChoice, WopbsParameters};
|
||||
use crate::shortint::CiphertextModulus;
|
||||
|
||||
use crate::shortint::parameters::{CarryModulus, MessageModulus};
|
||||
pub use crate::shortint::parameters::{
|
||||
DecompositionBaseLog, DecompositionLevelCount, DispersionParameter, GlweDimension,
|
||||
LweDimension, PolynomialSize, StandardDev,
|
||||
};
|
||||
|
||||
pub const ALL_PARAMETER_VEC_INTEGER_16_BITS: [WopbsParameters; 4] = [
|
||||
PARAM_MESSAGE_8_16_BITS,
|
||||
PARAM_MESSAGE_4_16_BITS,
|
||||
PARAM_MESSAGE_2_16_BITS,
|
||||
PARAM_TEST_WOP,
|
||||
];
|
||||
|
||||
//TODO toy parameters
|
||||
// /!\ unsecure
|
||||
pub const PARAM_TEST_WOP: WopbsParameters = WopbsParameters {
|
||||
lwe_dimension: LweDimension(10),
|
||||
glwe_dimension: GlweDimension(1),
|
||||
polynomial_size: PolynomialSize(1024),
|
||||
lwe_modular_std_dev: StandardDev(0.0000000000000000000004168323308734758),
|
||||
glwe_modular_std_dev: StandardDev(0.00000000000000000000000000000000000004905643852600863),
|
||||
pbs_base_log: DecompositionBaseLog(7),
|
||||
pbs_level: DecompositionLevelCount(6),
|
||||
ks_base_log: DecompositionBaseLog(1),
|
||||
ks_level: DecompositionLevelCount(14),
|
||||
pfks_level: DecompositionLevelCount(6),
|
||||
pfks_base_log: DecompositionBaseLog(7),
|
||||
pfks_modular_std_dev: StandardDev(0.000000000000000000000000000000000000004905643852600863),
|
||||
cbs_level: DecompositionLevelCount(7),
|
||||
cbs_base_log: DecompositionBaseLog(4),
|
||||
message_modulus: MessageModulus(16),
|
||||
carry_modulus: CarryModulus(1),
|
||||
encryption_key_choice: EncryptionKeyChoice::Big,
|
||||
ciphertext_modulus: CiphertextModulus::new_native(),
|
||||
};
|
||||
|
||||
|
||||
pub const PARAM_MESSAGE_2_4_8_BITS_BIV: WopbsParameters = WopbsParameters {
|
||||
lwe_dimension: LweDimension(592),
|
||||
glwe_dimension: GlweDimension(2),
|
||||
polynomial_size: PolynomialSize(1024),
|
||||
lwe_modular_std_dev: StandardDev(0.00014316832876365714),
|
||||
glwe_modular_std_dev: StandardDev(0.0000000000000003162026630747649),
|
||||
pbs_base_log: DecompositionBaseLog(9),
|
||||
pbs_level: DecompositionLevelCount(4),
|
||||
ks_base_log: DecompositionBaseLog(2),
|
||||
ks_level: DecompositionLevelCount(5),
|
||||
pfks_level: DecompositionLevelCount(2),
|
||||
pfks_base_log: DecompositionBaseLog(17),
|
||||
pfks_modular_std_dev: StandardDev(0.0000000000000003162026630747649),
|
||||
cbs_level: DecompositionLevelCount(1),
|
||||
cbs_base_log: DecompositionBaseLog(14),
|
||||
message_modulus: MessageModulus(16),
|
||||
carry_modulus: CarryModulus(1),
|
||||
encryption_key_choice: EncryptionKeyChoice::Big,
|
||||
ciphertext_modulus: CiphertextModulus::new_native(),
|
||||
};
|
||||
|
||||
pub const PARAM_MESSAGE_4_2_8_BITS_BIV: WopbsParameters = WopbsParameters {
|
||||
lwe_dimension: LweDimension(564),
|
||||
glwe_dimension: GlweDimension(2),
|
||||
polynomial_size: PolynomialSize(1024),
|
||||
lwe_modular_std_dev: StandardDev(0.00024077946887044908),
|
||||
glwe_modular_std_dev: StandardDev(0.0000000000000003162026630747649),
|
||||
pbs_base_log: DecompositionBaseLog(12),
|
||||
pbs_level: DecompositionLevelCount(3),
|
||||
ks_base_log: DecompositionBaseLog(2),
|
||||
ks_level: DecompositionLevelCount(5),
|
||||
pfks_level: DecompositionLevelCount(2),
|
||||
pfks_base_log: DecompositionBaseLog(17),
|
||||
pfks_modular_std_dev: StandardDev(0.0000000000000003162026630747649),
|
||||
cbs_level: DecompositionLevelCount(1),
|
||||
cbs_base_log: DecompositionBaseLog(13),
|
||||
message_modulus: MessageModulus(4),
|
||||
carry_modulus: CarryModulus(1),
|
||||
encryption_key_choice: EncryptionKeyChoice::Big,
|
||||
ciphertext_modulus: CiphertextModulus::new_native(),
|
||||
};
|
||||
|
||||
|
||||
pub const PARAM_MESSAGE_5_2_8_BITS_BIV: WopbsParameters = WopbsParameters {
|
||||
lwe_dimension: LweDimension(635),
|
||||
glwe_dimension: GlweDimension(4),
|
||||
polynomial_size: PolynomialSize(512),
|
||||
lwe_modular_std_dev: StandardDev(-13.91),
|
||||
glwe_modular_std_dev: StandardDev(-51.49),
|
||||
pbs_base_log: DecompositionBaseLog(8),
|
||||
pbs_level: DecompositionLevelCount(5),
|
||||
ks_base_log: DecompositionBaseLog(2),
|
||||
ks_level: DecompositionLevelCount(6),
|
||||
pfks_level: DecompositionLevelCount(2),
|
||||
pfks_base_log: DecompositionBaseLog(17),
|
||||
pfks_modular_std_dev: StandardDev(-51.49),
|
||||
cbs_level: DecompositionLevelCount(1),
|
||||
cbs_base_log: DecompositionBaseLog(14),
|
||||
message_modulus: MessageModulus(32),
|
||||
carry_modulus: CarryModulus(1),
|
||||
encryption_key_choice: EncryptionKeyChoice::Big,
|
||||
ciphertext_modulus: CiphertextModulus::new_native(),
|
||||
};
|
||||
|
||||
|
||||
pub const PARAM_MESSAGE_2_4_8_BITS_TRI: WopbsParameters = WopbsParameters {
|
||||
lwe_dimension: LweDimension(589),
|
||||
glwe_dimension: GlweDimension(1),
|
||||
polynomial_size: PolynomialSize(2048),
|
||||
lwe_modular_std_dev: StandardDev(0.00015133150634020836),
|
||||
glwe_modular_std_dev: StandardDev(0.0000000000000003162026630747649),
|
||||
pbs_base_log: DecompositionBaseLog(8),
|
||||
pbs_level: DecompositionLevelCount(5),
|
||||
ks_base_log: DecompositionBaseLog(2),
|
||||
ks_level: DecompositionLevelCount(5),
|
||||
pfks_level: DecompositionLevelCount(2),
|
||||
pfks_base_log: DecompositionBaseLog(17),
|
||||
pfks_modular_std_dev: StandardDev(0.0000000000000003162026630747649),
|
||||
cbs_level: DecompositionLevelCount(1),
|
||||
cbs_base_log: DecompositionBaseLog(14),
|
||||
message_modulus: MessageModulus(16),
|
||||
carry_modulus: CarryModulus(1),
|
||||
encryption_key_choice: EncryptionKeyChoice::Big,
|
||||
ciphertext_modulus: CiphertextModulus::new_native(),
|
||||
};
|
||||
|
||||
|
||||
pub const PARAM_MESSAGE_4_2_8_BITS_TRI: WopbsParameters = WopbsParameters {
|
||||
lwe_dimension: LweDimension(573),
|
||||
glwe_dimension: GlweDimension(1),
|
||||
polynomial_size: PolynomialSize(2048),
|
||||
lwe_modular_std_dev: StandardDev(0.00020387888657919176),
|
||||
glwe_modular_std_dev: StandardDev(0.0000000000000003162026630747649),
|
||||
pbs_base_log: DecompositionBaseLog(11),
|
||||
pbs_level: DecompositionLevelCount(3),
|
||||
ks_base_log: DecompositionBaseLog(2),
|
||||
ks_level: DecompositionLevelCount(5),
|
||||
pfks_level: DecompositionLevelCount(2),
|
||||
pfks_base_log: DecompositionBaseLog(17),
|
||||
pfks_modular_std_dev: StandardDev(0.0000000000000003162026630747649),
|
||||
cbs_level: DecompositionLevelCount(1),
|
||||
cbs_base_log: DecompositionBaseLog(12),
|
||||
message_modulus: MessageModulus(4),
|
||||
carry_modulus: CarryModulus(1),
|
||||
ciphertext_modulus: CiphertextModulus::new_native(),
|
||||
encryption_key_choice: EncryptionKeyChoice::Big,
|
||||
};
|
||||
|
||||
|
||||
pub const PARAM_MESSAGE_2_16_BITS: WopbsParameters = WopbsParameters {
|
||||
lwe_dimension: LweDimension(493),
|
||||
glwe_dimension: GlweDimension(1),
|
||||
polynomial_size: PolynomialSize(2048),
|
||||
lwe_modular_std_dev: StandardDev(0.00049144710341316649172),
|
||||
glwe_modular_std_dev: StandardDev(0.00000000000000022148688116005568513645324585951),
|
||||
pbs_base_log: DecompositionBaseLog(16),
|
||||
pbs_level: DecompositionLevelCount(2),
|
||||
ks_level: DecompositionLevelCount(5),
|
||||
ks_base_log: DecompositionBaseLog(2),
|
||||
pfks_level: DecompositionLevelCount(2),
|
||||
pfks_base_log: DecompositionBaseLog(16),
|
||||
pfks_modular_std_dev: StandardDev(0.00000000000000022148688116005568513645324585951),
|
||||
cbs_level: DecompositionLevelCount(6),
|
||||
cbs_base_log: DecompositionBaseLog(3),
|
||||
message_modulus: MessageModulus(4),
|
||||
carry_modulus: CarryModulus(1),
|
||||
encryption_key_choice: EncryptionKeyChoice::Big,
|
||||
ciphertext_modulus: CiphertextModulus::new_native(),
|
||||
};
|
||||
|
||||
|
||||
pub const PARAM_MESSAGE_8_16_BITS: WopbsParameters = WopbsParameters {
|
||||
lwe_dimension: LweDimension(481),
|
||||
glwe_dimension: GlweDimension(1),
|
||||
polynomial_size: PolynomialSize(2048),
|
||||
lwe_modular_std_dev: StandardDev(0.00061200133780220371345),
|
||||
glwe_modular_std_dev: StandardDev(0.00000000000000022148688116005568513645324585951),
|
||||
pbs_base_log: DecompositionBaseLog(9),
|
||||
pbs_level: DecompositionLevelCount(4),
|
||||
ks_level: DecompositionLevelCount(9),
|
||||
ks_base_log: DecompositionBaseLog(1),
|
||||
pfks_level: DecompositionLevelCount(4),
|
||||
pfks_base_log: DecompositionBaseLog(9),
|
||||
pfks_modular_std_dev: StandardDev(0.00000000000000022148688116005568513645324585951),
|
||||
cbs_level: DecompositionLevelCount(4),
|
||||
cbs_base_log: DecompositionBaseLog(6),
|
||||
message_modulus: MessageModulus(256),
|
||||
carry_modulus: CarryModulus(1),
|
||||
encryption_key_choice: EncryptionKeyChoice::Big,
|
||||
ciphertext_modulus: CiphertextModulus::new_native(),
|
||||
};
|
||||
|
||||
pub const PARAM_MESSAGE_4_16_BITS: WopbsParameters = WopbsParameters {
|
||||
lwe_dimension: LweDimension(493),
|
||||
glwe_dimension: GlweDimension(1),
|
||||
polynomial_size: PolynomialSize(2048),
|
||||
lwe_modular_std_dev: StandardDev(0.00049144710341316649172),
|
||||
glwe_modular_std_dev: StandardDev(0.00000000000000022148688116005568513645324585951),
|
||||
pbs_base_log: DecompositionBaseLog(16),
|
||||
pbs_level: DecompositionLevelCount(2),
|
||||
ks_level: DecompositionLevelCount(5),
|
||||
ks_base_log: DecompositionBaseLog(2),
|
||||
pfks_level: DecompositionLevelCount(2),
|
||||
pfks_base_log: DecompositionBaseLog(16),
|
||||
pfks_modular_std_dev: StandardDev(0.00000000000000022148688116005568513645324585951),
|
||||
cbs_level: DecompositionLevelCount(6),
|
||||
cbs_base_log: DecompositionBaseLog(3),
|
||||
message_modulus: MessageModulus(16),
|
||||
carry_modulus: CarryModulus(1),
|
||||
encryption_key_choice: EncryptionKeyChoice::Big,
|
||||
ciphertext_modulus: CiphertextModulus::new_native(),
|
||||
};
|
||||
290
tfhe/src/float_wopbs/server_key/mod.rs
Normal file
290
tfhe/src/float_wopbs/server_key/mod.rs
Normal file
@@ -0,0 +1,290 @@
|
||||
//! Module with the definition of the ServerKey.
|
||||
//!
|
||||
//! This module implements the generation of the server public key, together with all the
|
||||
//! available homomorphic integer operations.
|
||||
#[cfg(test)]
|
||||
mod tests;
|
||||
|
||||
use crate::core_crypto::commons::parameters::{CiphertextCount, DeltaLog, LweCiphertextCount};
|
||||
use crate::core_crypto::commons::traits::*;
|
||||
use crate::core_crypto::entities::*;
|
||||
use crate::float_wopbs::ciphertext::Ciphertext;
|
||||
use crate::float_wopbs::client_key::{float_to_uint, uint_to_float, ClientKey};
|
||||
use crate::shortint::ciphertext::Degree;
|
||||
use crate::shortint::wopbs::WopbsLUTBase;
|
||||
use crate::core_crypto::commons::parameters;
|
||||
use crate::shortint::ciphertext::MaxDegree;
|
||||
use crate::shortint::ciphertext::NoiseLevel;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use crate::shortint;
|
||||
|
||||
|
||||
/// Error returned when the carry buffer is full.
|
||||
pub use crate::shortint::CheckError;
|
||||
|
||||
/// A structure containing the server public key.
|
||||
///
|
||||
/// The server key is generated by the client and is meant to be published: the client
|
||||
/// sends it to the server so it can compute homomorphic integer circuits.
|
||||
#[derive(Serialize, Deserialize, Clone)]
|
||||
pub struct ServerKey {
|
||||
pub key: shortint::server_key::ServerKey,
|
||||
pub wopbs_key: shortint::wopbs::WopbsKey,
|
||||
}
|
||||
|
||||
impl ServerKey {
|
||||
/// Generates a server key.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// use tfhe::float::parameters::PARAM_MESSAGE_4_16_BITS;
|
||||
/// use tfhe::float::{ClientKey, ServerKey};
|
||||
///
|
||||
/// // Generate the client key:
|
||||
/// let cks = ClientKey::new(PARAM_MESSAGE_4_16_BITS);
|
||||
///
|
||||
/// // Generate the server key:
|
||||
/// let sks = ServerKey::new(&cks);
|
||||
/// ```
|
||||
pub fn new(cks: &ClientKey) -> ServerKey {
|
||||
// It should remain just enough space to add a carry
|
||||
let max =
|
||||
(cks.key.parameters.message_modulus().0 - 1) * cks.key.parameters.carry_modulus().0 - 1;
|
||||
let sks =
|
||||
shortint::server_key::ServerKey::new_with_max_degree(&cks.key, MaxDegree(max));
|
||||
// TODO DANGER
|
||||
let wopbs_key =
|
||||
shortint::wopbs::WopbsKey::new_wopbs_key_only_for_wopbs(&cks.key, &sks);
|
||||
ServerKey {
|
||||
key: sks,
|
||||
wopbs_key,
|
||||
}
|
||||
}
|
||||
|
||||
/// Creates a ServerKey from an already generated shortint::ServerKey.
|
||||
pub fn from_shortint(
|
||||
cks: &ClientKey,
|
||||
mut key: crate::shortint::server_key::ServerKey,
|
||||
wopbs_key: crate::shortint::wopbs::WopbsKey,
|
||||
) -> ServerKey {
|
||||
// It should remain just enough space add a carry
|
||||
let max =
|
||||
(cks.key.parameters.message_modulus().0 - 1) * cks.key.parameters.carry_modulus().0 - 1;
|
||||
key.max_degree = MaxDegree(max);
|
||||
ServerKey { key, wopbs_key }
|
||||
}
|
||||
|
||||
pub fn create_lut<F: Fn(f64) -> f64>(&self, ct: &mut Ciphertext, f: F) -> Vec<Vec<u64>> {
|
||||
let log_msg_mod = f64::log2(self.wopbs_key.param.message_modulus.0 as f64).ceil() as u64;
|
||||
let bit_mantissa = ct.nb_bit_mantissa;
|
||||
let bit_exponent = ct.nb_bit_exponent;
|
||||
let len_vec = ct.ct_vec_float.len();
|
||||
|
||||
let e_min = ct.e_min;
|
||||
let total_bit = (bit_exponent + bit_mantissa) as u64 + 1;
|
||||
let mut lut_len = 1 << total_bit;
|
||||
if 1 << total_bit < self.wopbs_key.param.polynomial_size.0 as u64 {
|
||||
lut_len = self.wopbs_key.param.polynomial_size.0;
|
||||
}
|
||||
let mut vec_lut = vec![vec![0; lut_len]; len_vec];
|
||||
|
||||
for value in 0..1 << total_bit {
|
||||
let float = uint_to_float(value, e_min, bit_mantissa, bit_exponent);
|
||||
let mut encoded_float = float_to_uint(f(float), e_min, bit_mantissa, bit_exponent);
|
||||
for lut in vec_lut.iter_mut() {
|
||||
lut[value as usize] =
|
||||
(encoded_float % (1 << log_msg_mod)) * ((1_u64 << 63) / (1 << log_msg_mod));
|
||||
encoded_float >>= log_msg_mod;
|
||||
}
|
||||
}
|
||||
vec_lut
|
||||
}
|
||||
|
||||
pub fn create_bivariate_lut<F: Fn(f64, f64) -> f64>(
|
||||
&self,
|
||||
ct: &mut Ciphertext,
|
||||
f: F,
|
||||
) -> Vec<Vec<u64>> {
|
||||
let log_msg_mod = f64::log2(self.wopbs_key.param.message_modulus.0 as f64).ceil() as u64;
|
||||
let bit_mantissa = ct.nb_bit_mantissa;
|
||||
let bit_exponent = ct.nb_bit_exponent;
|
||||
let len_vec = ct.ct_vec_float.len();
|
||||
|
||||
let e_min = ct.e_min;
|
||||
let total_bit = 2 * ((bit_exponent + bit_mantissa) as u64 + 1);
|
||||
let mut lut_len = 1 << total_bit;
|
||||
if 1 << total_bit < self.wopbs_key.param.polynomial_size.0 as u64 {
|
||||
lut_len = self.wopbs_key.param.polynomial_size.0;
|
||||
}
|
||||
let mut vec_lut = vec![vec![0; lut_len]; len_vec];
|
||||
|
||||
for value in 0..1 << total_bit {
|
||||
let value_1 = value % (1 << (total_bit / 2));
|
||||
let value_2 = value >> (total_bit / 2);
|
||||
let float_1 = uint_to_float(value_1, e_min, bit_mantissa, bit_exponent);
|
||||
let float_2 = uint_to_float(value_2, e_min, bit_mantissa, bit_exponent);
|
||||
let float = f(float_1, float_2);
|
||||
let mut encoded_float = float_to_uint(float, e_min, bit_mantissa, bit_exponent);
|
||||
for lut in vec_lut.iter_mut() {
|
||||
lut[value as usize] =
|
||||
(encoded_float % (1 << log_msg_mod)) * ((1_u64 << 63) / (1 << log_msg_mod));
|
||||
encoded_float >>= log_msg_mod;
|
||||
}
|
||||
}
|
||||
vec_lut
|
||||
}
|
||||
|
||||
pub fn create_trivariate_lut<F: Fn(f64, f64, f64) -> f64>(
|
||||
&self,
|
||||
ct: &mut Ciphertext,
|
||||
f: F,
|
||||
) -> Vec<Vec<u64>> {
|
||||
let log_msg_mod = f64::log2(self.wopbs_key.param.message_modulus.0 as f64).ceil() as u64;
|
||||
let bit_mantissa = ct.nb_bit_mantissa;
|
||||
let bit_exponent = ct.nb_bit_exponent;
|
||||
let len_vec = ct.ct_vec_float.len();
|
||||
|
||||
let e_min = ct.e_min;
|
||||
let total_bit = 3 * ((bit_exponent + bit_mantissa) as u64 + 1);
|
||||
|
||||
let mut lut_len = 1 << total_bit;
|
||||
if 1 << total_bit < self.wopbs_key.param.polynomial_size.0 as u64 {
|
||||
lut_len = self.wopbs_key.param.polynomial_size.0;
|
||||
}
|
||||
let mut vec_lut = vec![vec![0; lut_len]; len_vec];
|
||||
|
||||
for value in 0..1 << total_bit {
|
||||
let value_1 = value % (1 << (total_bit / 3));
|
||||
let value_2 = (value >> (total_bit / 3)) % (1 << (total_bit / 3));
|
||||
let value_3 = (value >> (2 * total_bit / 3)) % (1 << (total_bit / 3));
|
||||
let float_1 = uint_to_float(value_1, e_min, bit_mantissa, bit_exponent);
|
||||
let float_2 = uint_to_float(value_2, e_min, bit_mantissa, bit_exponent);
|
||||
let float_3 = uint_to_float(value_3, e_min, bit_mantissa, bit_exponent);
|
||||
let float = f(float_1, float_2, float_3);
|
||||
let mut encoded_float = float_to_uint(float, e_min, bit_mantissa, bit_exponent);
|
||||
for lut in vec_lut.iter_mut() {
|
||||
lut[value as usize] =
|
||||
(encoded_float % (1 << log_msg_mod)) * ((1_u64 << 63) / (1 << log_msg_mod));
|
||||
encoded_float >>= log_msg_mod;
|
||||
}
|
||||
}
|
||||
vec_lut
|
||||
}
|
||||
|
||||
pub fn wop_pbs_bivariate(
|
||||
&self,
|
||||
sks: &ServerKey,
|
||||
ct_in_1: &mut Ciphertext,
|
||||
ct_in_2: &mut Ciphertext,
|
||||
lut: &[Vec<u64>],
|
||||
) -> Ciphertext {
|
||||
let mut vec_ct = vec![ct_in_1, ct_in_2];
|
||||
self.wop(sks, &mut vec_ct, lut)
|
||||
}
|
||||
|
||||
pub fn wop_pbs_trivariate(
|
||||
&self,
|
||||
sks: &ServerKey,
|
||||
ct_in_1: &mut Ciphertext,
|
||||
ct_in_2: &mut Ciphertext,
|
||||
ct_in_3: &mut Ciphertext,
|
||||
lut: &[Vec<u64>],
|
||||
) -> Ciphertext {
|
||||
let mut vec_ct = vec![ct_in_1, ct_in_2, ct_in_3];
|
||||
self.wop(sks, &mut vec_ct, lut)
|
||||
}
|
||||
|
||||
pub fn wop_pbs(&self, sks: &ServerKey, ct_in: &mut Ciphertext, lut: &[Vec<u64>]) -> Ciphertext {
|
||||
let mut vec_ct = vec![ct_in];
|
||||
self.wop(sks, &mut vec_ct, lut)
|
||||
}
|
||||
|
||||
fn wop(
|
||||
&self,
|
||||
// TODO DANGER
|
||||
_sks: &ServerKey,
|
||||
vec_ct_in: &mut [&mut Ciphertext],
|
||||
lut: &[Vec<u64>],
|
||||
) -> Ciphertext {
|
||||
let total_bits_extracted = vec_ct_in.iter().fold(0usize, |acc, ct_in| {
|
||||
acc + ct_in.nb_bit_mantissa + ct_in.nb_bit_exponent + 1
|
||||
});
|
||||
|
||||
let extract_bits_output_lwe_size = self
|
||||
.wopbs_key
|
||||
.wopbs_server_key
|
||||
.key_switching_key
|
||||
.output_key_lwe_dimension()
|
||||
.to_lwe_size();
|
||||
|
||||
let mut vec_lwe = LweCiphertextList::new(
|
||||
0u64,
|
||||
extract_bits_output_lwe_size,
|
||||
LweCiphertextCount(total_bits_extracted),
|
||||
self.wopbs_key.param.ciphertext_modulus,
|
||||
);
|
||||
|
||||
let mut bits_extracted_so_far = 0;
|
||||
|
||||
// Extraction of each bit for each block
|
||||
for ct_in in vec_ct_in.iter_mut() {
|
||||
let mut remain_bit_to_extract = ct_in.nb_bit_mantissa + ct_in.nb_bit_exponent + 1;
|
||||
let mut nb_bit_to_extract = f64::log2((self.key.message_modulus.0) as f64) as usize;
|
||||
for block in ct_in.ct_vec_float.iter_mut() {
|
||||
let delta = (1_usize << 63) / (block.message_modulus.0 * block.carry_modulus.0);
|
||||
let delta_log = DeltaLog(f64::log2(delta as f64) as usize);
|
||||
if nb_bit_to_extract > remain_bit_to_extract {
|
||||
nb_bit_to_extract = remain_bit_to_extract;
|
||||
remain_bit_to_extract = 0;
|
||||
} else {
|
||||
remain_bit_to_extract -= nb_bit_to_extract
|
||||
}
|
||||
|
||||
// Fill in reverse to have the proper order for float ops
|
||||
bits_extracted_so_far += nb_bit_to_extract;
|
||||
let extract_from_bit = total_bits_extracted - bits_extracted_so_far;
|
||||
let extract_to_bit = extract_from_bit + nb_bit_to_extract;
|
||||
|
||||
// TODO DANGER
|
||||
let tmp = self
|
||||
.wopbs_key
|
||||
.extract_bits(delta_log, block, parameters::ExtractedBitsCount(nb_bit_to_extract));
|
||||
|
||||
let mut lwe_sub_list = vec_lwe.get_sub_mut(extract_from_bit..extract_to_bit);
|
||||
lwe_sub_list.as_mut().copy_from_slice(tmp.as_ref());
|
||||
}
|
||||
}
|
||||
let out_ct_count = lut.len();
|
||||
let lut = WopbsLUTBase::from_vec(
|
||||
lut.iter().flatten().cloned().collect(),
|
||||
CiphertextCount(out_ct_count),
|
||||
);
|
||||
// TODO DANGER
|
||||
let mut vec_ct_out = self
|
||||
.wopbs_key
|
||||
.circuit_bootstrapping_vertical_packing(&lut, &vec_lwe);
|
||||
|
||||
let mut ct_vec_out_float: Vec<crate::shortint::Ciphertext> = vec![];
|
||||
for (block_mantissa, result_ct) in
|
||||
vec_ct_in[0].ct_vec_float.iter().zip(vec_ct_out.drain(..))
|
||||
{
|
||||
ct_vec_out_float.push(crate::shortint::Ciphertext::new(
|
||||
result_ct,
|
||||
Degree(block_mantissa.message_modulus.0 - 1),
|
||||
NoiseLevel::NOMINAL,
|
||||
block_mantissa.message_modulus,
|
||||
block_mantissa.carry_modulus,
|
||||
crate::shortint::PBSOrder::KeyswitchBootstrap,
|
||||
));
|
||||
}
|
||||
|
||||
Ciphertext {
|
||||
ct_vec_float: ct_vec_out_float,
|
||||
nb_bit_mantissa: vec_ct_in[0].nb_bit_mantissa,
|
||||
nb_bit_exponent: vec_ct_in[0].nb_bit_exponent,
|
||||
key_id_vec: vec_ct_in[0].key_id_vec.clone(),
|
||||
e_min: vec_ct_in[0].e_min,
|
||||
}
|
||||
}
|
||||
}
|
||||
233
tfhe/src/float_wopbs/server_key/tests.rs
Normal file
233
tfhe/src/float_wopbs/server_key/tests.rs
Normal file
@@ -0,0 +1,233 @@
|
||||
use crate::float_wopbs::client_key::{float_to_uint, uint_to_float};
|
||||
use crate::float_wopbs::gen_keys;
|
||||
use crate::float_wopbs::keycache::{get_cks, get_sks, save_cks, save_sks};
|
||||
use crate::float_wopbs::parameters::*;
|
||||
use rand::Rng;
|
||||
use std::time::Instant;
|
||||
|
||||
#[test]
|
||||
pub fn float_wopbs_encode() {
|
||||
let mut rng = rand::thread_rng();
|
||||
let bit_mantissa = 3_usize;
|
||||
let bit_exponent = 4_usize;
|
||||
let e_min = -5;
|
||||
|
||||
let (cks, _) = gen_keys(PARAM_TEST_WOP);
|
||||
|
||||
let msg = rng.gen::<f32>() as f64;
|
||||
// Encryption of one message:
|
||||
let ct = cks.encrypt(msg, e_min, bit_mantissa, bit_exponent);
|
||||
|
||||
let clear = uint_to_float(
|
||||
float_to_uint(msg, e_min, bit_mantissa, bit_exponent),
|
||||
e_min,
|
||||
bit_mantissa,
|
||||
bit_exponent,
|
||||
);
|
||||
println!("///////////////////////////////////////////////");
|
||||
let res = cks.decrypt(&ct);
|
||||
println!("clear : {res:?}");
|
||||
println!("result: {res:?}");
|
||||
println!("///////////////////////////////////////////////");
|
||||
assert_eq!(res, clear);
|
||||
}
|
||||
|
||||
#[test]
|
||||
pub fn float_wopbs_lut() {
|
||||
let mut rng = rand::thread_rng();
|
||||
let bit_mantissa = 3_usize;
|
||||
let bit_exponent = 4_usize;
|
||||
let e_min = -5;
|
||||
|
||||
let param_set = "PARAM_MESSAGE_2_16_BITS";
|
||||
|
||||
let cks = get_cks(param_set);
|
||||
let sks = get_sks(param_set);
|
||||
|
||||
let (cks, sks) = match (cks, sks) {
|
||||
(Some(cks), Some(sks)) => (cks, sks),
|
||||
_ => {
|
||||
// Generate the client key and the server key:
|
||||
let (cks, sks) = gen_keys(PARAM_MESSAGE_2_16_BITS);
|
||||
save_cks(&cks, param_set);
|
||||
save_sks(&sks, param_set);
|
||||
(cks, sks)
|
||||
}
|
||||
};
|
||||
|
||||
let msg = rng.gen::<f32>() as f64;
|
||||
|
||||
// Encryption of one message:
|
||||
let mut ct = cks.encrypt(msg, e_min, bit_mantissa, bit_exponent);
|
||||
|
||||
println!("///////////////////////////////////////////////");
|
||||
|
||||
let res = cks.decrypt(&ct);
|
||||
println!("res_1 {res:?}");
|
||||
|
||||
let lut = sks.create_lut(&mut ct, |x| x);
|
||||
|
||||
let now = Instant::now();
|
||||
let ct = sks.wop_pbs(&sks, &mut ct, &lut);
|
||||
let res_wop = cks.decrypt(&ct);
|
||||
println!("res_wop {res_wop:?}");
|
||||
|
||||
let elapsed = now.elapsed();
|
||||
println!(
|
||||
"sks param modulus {:?} time : {elapsed:.2?}",
|
||||
sks.key.message_modulus
|
||||
);
|
||||
|
||||
println!("///////////////////////////////////////////////");
|
||||
assert_eq!(res, res_wop);
|
||||
// panic!()
|
||||
}
|
||||
|
||||
#[test]
|
||||
pub fn float_wopbs_bivariate() {
|
||||
let bit_mantissa = 4_usize;
|
||||
let bit_exponent = 3_usize;
|
||||
let e_min = -5;
|
||||
let param_set = "PARAM_2_BIT_LWE_8_BITS";
|
||||
|
||||
//generate secret keys
|
||||
let cks = get_cks(param_set);
|
||||
let sks = get_sks(param_set);
|
||||
let (cks, sks) = match (cks, sks) {
|
||||
(Some(cks), Some(sks)) => (cks, sks),
|
||||
_ => {
|
||||
// Generate the client key and the server key:
|
||||
let (cks, sks) = gen_keys(PARAM_MESSAGE_2_16_BITS);
|
||||
save_cks(&cks, param_set);
|
||||
save_sks(&sks, param_set);
|
||||
(cks, sks)
|
||||
}
|
||||
};
|
||||
|
||||
// take two random messages
|
||||
let mut rng = rand::thread_rng();
|
||||
let msg_1 = rng.gen::<f32>() as f64;
|
||||
let msg_2 = -rng.gen::<f32>() as f64;
|
||||
|
||||
// convert 64 bits floating point in 8 bits floating point
|
||||
let msg_1_round = uint_to_float(
|
||||
float_to_uint(msg_1, e_min, bit_mantissa, bit_exponent),
|
||||
e_min,
|
||||
bit_mantissa,
|
||||
bit_exponent,
|
||||
);
|
||||
|
||||
let msg_2_round = uint_to_float(
|
||||
float_to_uint(msg_2, e_min, bit_mantissa, bit_exponent),
|
||||
e_min,
|
||||
bit_mantissa,
|
||||
bit_exponent,
|
||||
);
|
||||
println!("\nmessage 1 (8 bits floating point): {:?}", msg_1_round);
|
||||
println!("message 2 (8 bits floating point): {:?} \n", msg_2_round);
|
||||
|
||||
let mut ct_1 = cks.encrypt(msg_1, e_min, bit_mantissa, bit_exponent);
|
||||
let res = cks.decrypt(&ct_1);
|
||||
println!("encrypt/decrypt ct_1 (8 bits): {res:?}");
|
||||
let mut ct_2 = cks.encrypt(msg_2, e_min, bit_mantissa, bit_exponent);
|
||||
let res = cks.decrypt(&ct_2);
|
||||
println!("encrypt/decrypt ct_2 (8 bits): {res:?}");
|
||||
let lut = sks.create_bivariate_lut(&mut ct_1, |x, y| x + y);
|
||||
|
||||
let ct = sks.wop_pbs_bivariate(&sks, &mut ct_1, &mut ct_2, &lut);
|
||||
let res = cks.decrypt(&ct);
|
||||
|
||||
// Clear operation done on 64 bits floating point
|
||||
let exact = msg_1_round + msg_2_round;
|
||||
// Convert result on 8 bits floating points
|
||||
let exact = uint_to_float(
|
||||
float_to_uint(exact, e_min, bit_mantissa, bit_exponent),
|
||||
e_min,
|
||||
bit_mantissa,
|
||||
bit_exponent,
|
||||
);
|
||||
println!("\n//////////////////////////////////////////");
|
||||
println!("Clear result :{exact:?}");
|
||||
println!("Decrypted result (WoPBS-based) :{res:?}");
|
||||
println!("///////////////////////////////////////////////\n");
|
||||
assert_eq!(res, exact);
|
||||
}
|
||||
|
||||
#[test]
|
||||
pub fn float_wopbs_trivariate() {
|
||||
let bit_mantissa = 3_usize;
|
||||
let bit_exponent = 4_usize;
|
||||
let e_min = -5;
|
||||
|
||||
let param_set = "PARAM_MESSAGE_2_16_BITS";
|
||||
|
||||
let cks = get_cks(param_set);
|
||||
let sks = get_sks(param_set);
|
||||
|
||||
let (cks, sks) = match (cks, sks) {
|
||||
(Some(cks), Some(sks)) => (cks, sks),
|
||||
_ => {
|
||||
// Generate the client key and the server key:
|
||||
let (cks, sks) = gen_keys(PARAM_MESSAGE_2_4_8_BITS_TRI);
|
||||
save_cks(&cks, param_set);
|
||||
save_sks(&sks, param_set);
|
||||
(cks, sks)
|
||||
}
|
||||
};
|
||||
|
||||
let mut rng = rand::thread_rng();
|
||||
let msg_1 = rng.gen::<f32>() as f64;
|
||||
let msg_2 = -rng.gen::<f32>() as f64;
|
||||
let msg_3 = rng.gen::<f32>() as f64;
|
||||
println!("message 1 (64 bits): {:?}", msg_1);
|
||||
println!("message 2 (64 bits): {:?}", msg_2);
|
||||
println!("message 3 (64 bits): {:?}", msg_3);
|
||||
|
||||
let mut ct_1 = cks.encrypt(msg_1, e_min, bit_mantissa, bit_exponent);
|
||||
let mut ct_2 = cks.encrypt(msg_2, e_min, bit_mantissa, bit_exponent);
|
||||
let mut ct_3 = cks.encrypt(msg_3, e_min, bit_mantissa, bit_exponent);
|
||||
println!("encrypt/decrypt ct_1 (8 bits): {:?}", cks.decrypt(&ct_1));
|
||||
println!("encrypt/decrypt ct_2 (8 bits): {:?}", cks.decrypt(&ct_2));
|
||||
println!("encrypt/decrypt ct_3 (8 bits): {:?}", cks.decrypt(&ct_3));
|
||||
|
||||
let lut = sks.create_trivariate_lut(&mut ct_1, |x, y, z| x + y - z);
|
||||
let ct = sks.wop_pbs_trivariate(&sks, &mut ct_1, &mut ct_2, &mut ct_3, &lut);
|
||||
let res = cks.decrypt(&ct);
|
||||
let exact = msg_1 + msg_2 - msg_3;
|
||||
let msg_1_round = uint_to_float(
|
||||
float_to_uint(msg_1, e_min, bit_mantissa, bit_exponent),
|
||||
e_min,
|
||||
bit_mantissa,
|
||||
bit_exponent,
|
||||
);
|
||||
let msg_2_round = uint_to_float(
|
||||
float_to_uint(msg_2, e_min, bit_mantissa, bit_exponent),
|
||||
e_min,
|
||||
bit_mantissa,
|
||||
bit_exponent,
|
||||
);
|
||||
let msg_3_round = uint_to_float(
|
||||
float_to_uint(msg_3, e_min, bit_mantissa, bit_exponent),
|
||||
e_min,
|
||||
bit_mantissa,
|
||||
bit_exponent,
|
||||
);
|
||||
|
||||
let exact_round = uint_to_float(
|
||||
float_to_uint(
|
||||
msg_1_round + msg_2_round - msg_3_round,
|
||||
e_min,
|
||||
bit_mantissa,
|
||||
bit_exponent,
|
||||
),
|
||||
e_min,
|
||||
bit_mantissa,
|
||||
bit_exponent,
|
||||
);
|
||||
println!("\n///////////////////////////////////////////////");
|
||||
println!("Clear result (64 bits) :{exact:?}");
|
||||
println!("Clear result (8 bits) :{exact_round:?}");
|
||||
println!("Decrypted result (WoPBS-based) :{res:?}");
|
||||
println!("///////////////////////////////////////////////\n");
|
||||
assert_eq!(res, exact_round);
|
||||
}
|
||||
@@ -1,5 +1,6 @@
|
||||
/// Meant to be implemented on the inner server key
|
||||
/// eg the crate::integer::ServerKey
|
||||
#[allow(dead_code)]
|
||||
pub trait EvaluationIntegerKey<ClientKey> {
|
||||
fn new(client_key: &ClientKey) -> Self;
|
||||
|
||||
|
||||
@@ -10,6 +10,7 @@ pub(crate) trait DecryptionKey<CiphertextType, ClearType> {
|
||||
fn decrypt(&self, ciphertext: &CiphertextType) -> ClearType;
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub trait TypeIdentifier {
|
||||
fn type_variant(&self) -> crate::high_level_api::errors::Type;
|
||||
}
|
||||
|
||||
@@ -61,6 +61,18 @@ impl MaxDegree {
|
||||
}
|
||||
}
|
||||
|
||||
fn integer_server_key_max_degree_ex(
|
||||
msg_modulus: MessageModulus,
|
||||
carry_modulus: CarryModulus,
|
||||
) -> MaxDegree {
|
||||
let full_max_degree = msg_modulus.0 * carry_modulus.0 - 1;
|
||||
|
||||
let carry_max_degree = carry_modulus.0 - 1;
|
||||
|
||||
// We want to be have a margin to add a carry from another block
|
||||
MaxDegree::new(full_max_degree - carry_max_degree)
|
||||
}
|
||||
|
||||
impl ServerKey {
|
||||
/// Generates a server key.
|
||||
///
|
||||
@@ -167,6 +179,15 @@ impl ServerKey {
|
||||
mut key: crate::shortint::server_key::ServerKey,
|
||||
) -> Self {
|
||||
key.max_degree = MaxDegree::integer_crt_server_key(cks.key.parameters);
|
||||
|
||||
Self { key }
|
||||
}
|
||||
|
||||
pub fn from_shortint_ex(mut key: crate::shortint::server_key::ServerKey) -> Self {
|
||||
// It should remain just enough space add a carry
|
||||
let max_degree = integer_server_key_max_degree_ex(key.message_modulus, key.carry_modulus);
|
||||
|
||||
key.max_degree = max_degree;
|
||||
Self { key }
|
||||
}
|
||||
|
||||
|
||||
@@ -122,6 +122,50 @@ impl ServerKey {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn partial_propagate_parallelized_w_carry<T>(
|
||||
&self,
|
||||
ctxt: &mut T,
|
||||
start_index: usize,
|
||||
) -> crate::shortint::Ciphertext
|
||||
where
|
||||
T: IntegerRadixCiphertext,
|
||||
{
|
||||
if self.is_eligible_for_parallel_single_carry_propagation(ctxt) {
|
||||
let num_blocks = ctxt.blocks().len();
|
||||
|
||||
let (mut message_blocks, carry_blocks) = rayon::join(
|
||||
|| {
|
||||
ctxt.blocks()[start_index..]
|
||||
.par_iter()
|
||||
.map(|block| self.key.message_extract(block))
|
||||
.collect::<Vec<_>>()
|
||||
},
|
||||
|| {
|
||||
let mut carry_blocks = Vec::with_capacity(num_blocks);
|
||||
// No need to compute the carry of the last block, we would just throw it away
|
||||
ctxt.blocks()[start_index..num_blocks - 1]
|
||||
.par_iter()
|
||||
.map(|block| self.key.carry_extract(block))
|
||||
.collect_into_vec(&mut carry_blocks);
|
||||
carry_blocks.insert(0, self.key.create_trivial(0));
|
||||
carry_blocks
|
||||
},
|
||||
);
|
||||
|
||||
ctxt.blocks_mut()[start_index..].swap_with_slice(&mut message_blocks);
|
||||
let carries = T::from_blocks(carry_blocks);
|
||||
self.unchecked_add_assign_parallelized_low_latency(ctxt, &carries)
|
||||
} else {
|
||||
let len = ctxt.blocks().len();
|
||||
let mut carry = self.key.create_trivial(0);
|
||||
for i in start_index..len {
|
||||
carry = self.propagate_parallelized(ctxt, i);
|
||||
}
|
||||
|
||||
carry
|
||||
}
|
||||
}
|
||||
|
||||
/// Propagate all the carries.
|
||||
///
|
||||
/// # Example
|
||||
|
||||
@@ -46,6 +46,7 @@
|
||||
// End allowed pedantic lints
|
||||
|
||||
// Nursery lints
|
||||
#![allow(unknown_lints)]
|
||||
#![warn(clippy::nursery)]
|
||||
// The following lints have been temporarily allowed
|
||||
// They are expected to be fixed progressively
|
||||
@@ -61,6 +62,8 @@
|
||||
#![cfg_attr(all(doc, not(doctest)), feature(doc_auto_cfg))]
|
||||
#![cfg_attr(all(doc, not(doctest)), feature(doc_cfg))]
|
||||
#![warn(rustdoc::broken_intra_doc_links)]
|
||||
#![allow(elided_named_lifetimes)]
|
||||
#![allow(unstable_name_collisions)]
|
||||
|
||||
#[cfg(feature = "__c_api")]
|
||||
pub mod c_api;
|
||||
@@ -78,6 +81,12 @@ pub mod boolean;
|
||||
/// cbindgen:ignore
|
||||
pub mod core_crypto;
|
||||
|
||||
/// Welcome to the TFHE-rs [`float_wopbs`](`crate::float_wopbs`) module documentation!
|
||||
///
|
||||
/// # Special module attributes
|
||||
/// cbindgen:ignore
|
||||
pub mod float_wopbs;
|
||||
|
||||
#[cfg(feature = "integer")]
|
||||
/// Welcome to the TFHE-rs [`integer`](`crate::integer`) module documentation!
|
||||
///
|
||||
|
||||
@@ -112,7 +112,7 @@ impl std::ops::Mul<usize> for NoiseLevel {
|
||||
|
||||
/// Maximum value that the degree can reach.
|
||||
#[derive(Debug, PartialEq, Eq, Copy, Clone, Serialize, Deserialize)]
|
||||
pub struct MaxDegree(usize);
|
||||
pub struct MaxDegree(pub usize);
|
||||
|
||||
impl MaxDegree {
|
||||
pub fn new(value: usize) -> Self {
|
||||
@@ -143,7 +143,7 @@ impl MaxDegree {
|
||||
|
||||
/// This tracks the number of operations that has been done.
|
||||
#[derive(Debug, PartialEq, Eq, Copy, Clone, Serialize, Deserialize)]
|
||||
pub struct Degree(usize);
|
||||
pub struct Degree(pub usize);
|
||||
|
||||
impl Degree {
|
||||
pub fn new(degree: usize) -> Self {
|
||||
|
||||
@@ -36,7 +36,7 @@ pub struct MessageModulus(pub usize);
|
||||
|
||||
/// The number of bits on which the carry will be encoded.
|
||||
#[derive(Debug, PartialEq, Eq, Copy, Clone, Serialize, Deserialize)]
|
||||
pub struct CarryModulus(pub usize);
|
||||
pub struct CarryModulus(pub usize);
|
||||
|
||||
/// Determines in what ring computations are made
|
||||
pub type CiphertextModulus = CoreCiphertextModulus<u64>;
|
||||
|
||||
@@ -1 +1 @@
|
||||
nightly-2023-11-30
|
||||
nightly-2024-07-05
|
||||
|
||||
Reference in New Issue
Block a user