mirror of
https://github.com/zama-ai/tfhe-rs.git
synced 2026-01-12 16:18:52 -05:00
Compare commits
67 Commits
mz/refacto
...
cm
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d220008757 | ||
|
|
b1b55b6426 | ||
|
|
77bea74ac9 | ||
|
|
a8b6c72910 | ||
|
|
c2b21ed709 | ||
|
|
ad41fdf5a5 | ||
|
|
eeae19f35f | ||
|
|
798572e58c | ||
|
|
36d375943c | ||
|
|
a1488b10d5 | ||
|
|
b153641280 | ||
|
|
48405959a4 | ||
|
|
d39e73be91 | ||
|
|
71447d845f | ||
|
|
837df59b44 | ||
|
|
303cac2092 | ||
|
|
bb1a969c34 | ||
|
|
9dac9242be | ||
|
|
e97ac815eb | ||
|
|
62feb59722 | ||
|
|
ac8916a30f | ||
|
|
069ea98ad6 | ||
|
|
e587d1835e | ||
|
|
16121e7487 | ||
|
|
3ab566de7b | ||
|
|
2309b07703 | ||
|
|
8755094c38 | ||
|
|
4da10e9dd5 | ||
|
|
cdda260063 | ||
|
|
be413fff50 | ||
|
|
3ed960d255 | ||
|
|
bdadd39a34 | ||
|
|
f03f2f9c6d | ||
|
|
f03ec9bbed | ||
|
|
5137751dd2 | ||
|
|
edc3449dbf | ||
|
|
6068c509de | ||
|
|
30a4348e3a | ||
|
|
3013e02d90 | ||
|
|
c029917c5c | ||
|
|
d23d04021b | ||
|
|
5a5e9e0ac1 | ||
|
|
cc0a3bad8d | ||
|
|
1d12f60849 | ||
|
|
a0db39c86e | ||
|
|
ef4558ac13 | ||
|
|
bfb22b4531 | ||
|
|
88025010e1 | ||
|
|
b1f4f3b330 | ||
|
|
7575a426ab | ||
|
|
1ac57218b1 | ||
|
|
b7c3f16e24 | ||
|
|
bf4f9198fb | ||
|
|
e618e1d05d | ||
|
|
000428d688 | ||
|
|
937c90666b | ||
|
|
b6a6f1b098 | ||
|
|
c2d7f1748c | ||
|
|
e8cd55dee6 | ||
|
|
95aea9dbe8 | ||
|
|
89f701d307 | ||
|
|
224146686f | ||
|
|
b6b5f92220 | ||
|
|
0fec9e252b | ||
|
|
53c9b82824 | ||
|
|
f670a950d6 | ||
|
|
a44970a9a3 |
2
.github/workflows/aws_tfhe_fast_tests.yml
vendored
2
.github/workflows/aws_tfhe_fast_tests.yml
vendored
@@ -80,7 +80,7 @@ jobs:
|
||||
|
||||
- name: Run user docs tests
|
||||
run: |
|
||||
CARGO_PROFILE=release_lto_off make test_user_doc
|
||||
make test_user_doc
|
||||
|
||||
- name: Run js on wasm API tests
|
||||
run: |
|
||||
|
||||
2
.github/workflows/aws_tfhe_tests.yml
vendored
2
.github/workflows/aws_tfhe_tests.yml
vendored
@@ -83,7 +83,7 @@ jobs:
|
||||
|
||||
- name: Run user docs tests
|
||||
run: |
|
||||
CARGO_PROFILE=release_lto_off make test_user_doc
|
||||
make test_user_doc
|
||||
|
||||
- name: Gen Keys if required
|
||||
run: |
|
||||
|
||||
11
.github/workflows/code_coverage.yml
vendored
11
.github/workflows/code_coverage.yml
vendored
@@ -38,6 +38,7 @@ jobs:
|
||||
group: ${{ github.workflow }}_${{ github.ref }}_${{ inputs.instance_image_id }}_${{ inputs.instance_type }}
|
||||
cancel-in-progress: true
|
||||
runs-on: ${{ inputs.runner_name }}
|
||||
timeout-minutes: 1080
|
||||
steps:
|
||||
# Step used for log purpose.
|
||||
- name: Instance configuration used
|
||||
@@ -67,7 +68,7 @@ jobs:
|
||||
|
||||
- name: Check for file changes
|
||||
id: changed-files
|
||||
uses: tj-actions/changed-files@25ef3926d147cd02fc7e931c1ef50772bbb0d25d
|
||||
uses: tj-actions/changed-files@1c938490c880156b746568a518594309cfb3f66b
|
||||
with:
|
||||
files_yaml: |
|
||||
tfhe:
|
||||
@@ -79,6 +80,12 @@ jobs:
|
||||
if: steps.changed-files.outputs.tfhe_any_changed == 'true'
|
||||
run: |
|
||||
make GEN_KEY_CACHE_COVERAGE_ONLY=TRUE gen_key_cache
|
||||
make gen_key_cache_core_crypto
|
||||
|
||||
- name: Run coverage for core_crypto
|
||||
if: steps.changed-files.outputs.tfhe_any_changed == 'true'
|
||||
run: |
|
||||
make test_core_crypto_cov AVX512_SUPPORT=ON
|
||||
|
||||
- name: Run coverage for boolean
|
||||
if: steps.changed-files.outputs.tfhe_any_changed == 'true'
|
||||
@@ -97,7 +104,7 @@ jobs:
|
||||
token: ${{ secrets.CODECOV_TOKEN }}
|
||||
directory: ./coverage/
|
||||
fail_ci_if_error: true
|
||||
files: shortint/cobertura.xml,boolean/cobertura.xml
|
||||
files: shortint/cobertura.xml,boolean/cobertura.xml,core_crypto/cobertura.xml,core_crypto_avx512/cobertura.xml
|
||||
|
||||
- name: Slack Notification
|
||||
if: ${{ failure() }}
|
||||
|
||||
2
.github/workflows/m1_tests.yml
vendored
2
.github/workflows/m1_tests.yml
vendored
@@ -85,7 +85,7 @@ jobs:
|
||||
|
||||
- name: Run user docs tests
|
||||
run: |
|
||||
CARGO_PROFILE=release_lto_off make test_user_doc
|
||||
make test_user_doc
|
||||
|
||||
# JS tests are more easily launched in docker, we won't test that on M1 as docker is pretty
|
||||
# slow on Apple machines due to the virtualization layer.
|
||||
|
||||
2
.github/workflows/start_benchmarks.yml
vendored
2
.github/workflows/start_benchmarks.yml
vendored
@@ -59,7 +59,7 @@ jobs:
|
||||
|
||||
- name: Check for file changes
|
||||
id: changed-files
|
||||
uses: tj-actions/changed-files@25ef3926d147cd02fc7e931c1ef50772bbb0d25d
|
||||
uses: tj-actions/changed-files@1c938490c880156b746568a518594309cfb3f66b
|
||||
with:
|
||||
files_yaml: |
|
||||
common_benches:
|
||||
|
||||
@@ -48,7 +48,7 @@ jobs:
|
||||
Pull Request has been approved :tada:
|
||||
Launching full test suite...
|
||||
@slab-ci cpu_test
|
||||
@slab-ci cpu_integer_test
|
||||
@slab-ci cpu_multi_bit_test
|
||||
@slab-ci cpu_unsigned_integer_test
|
||||
@slab-ci cpu_signed_integer_test
|
||||
@slab-ci cpu_wasm_test
|
||||
@slab-ci csprng_randomness_testing
|
||||
|
||||
6
.gitignore
vendored
6
.gitignore
vendored
@@ -3,9 +3,9 @@ target/
|
||||
.vscode/
|
||||
|
||||
# Path we use for internal-keycache during tests
|
||||
keys/
|
||||
/keys/
|
||||
# In case of symlinked keys
|
||||
keys
|
||||
/keys
|
||||
|
||||
**/Cargo.lock
|
||||
**/*.bin
|
||||
@@ -18,4 +18,4 @@ keys
|
||||
dieharder_run.log
|
||||
|
||||
# Coverage reports
|
||||
./coverage/
|
||||
/coverage/
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[workspace]
|
||||
resolver = "2"
|
||||
members = ["tfhe", "tasks", "apps/trivium", "concrete-csprng"]
|
||||
members = ["tfhe", "tasks", "apps/trivium", "concrete-csprng", "concrete-float"]
|
||||
|
||||
[profile.bench]
|
||||
lto = "fat"
|
||||
|
||||
129
Makefile
129
Makefile
@@ -149,6 +149,11 @@ fix_newline: check_linelint_installed
|
||||
check_newline: check_linelint_installed
|
||||
linelint .
|
||||
|
||||
.PHONY: clippy_float # Run clippy lints on core_crypto with and without experimental features
|
||||
clippy_float: install_rs_check_toolchain
|
||||
RUSTFLAGS="$(RUSTFLAGS)" cargo "$(CARGO_RS_CHECK_TOOLCHAIN)" clippy \
|
||||
-p concrete-float -- --no-deps -D warnings
|
||||
|
||||
.PHONY: clippy_core # Run clippy lints on core_crypto with and without experimental features
|
||||
clippy_core: install_rs_check_toolchain
|
||||
RUSTFLAGS="$(RUSTFLAGS)" cargo "$(CARGO_RS_CHECK_TOOLCHAIN)" clippy \
|
||||
@@ -201,9 +206,8 @@ clippy_tasks:
|
||||
|
||||
.PHONY: clippy_trivium # Run clippy lints on Trivium app
|
||||
clippy_trivium: install_rs_check_toolchain
|
||||
RUSTFLAGS="$(RUSTFLAGS)" cargo "$(CARGO_RS_CHECK_TOOLCHAIN)" clippy -p tfhe-trivium \
|
||||
--features=$(TARGET_ARCH_FEATURE),boolean,shortint,integer \
|
||||
-p $(TFHE_SPEC) -- --no-deps -D warnings
|
||||
RUSTFLAGS="$(RUSTFLAGS)" cargo "$(CARGO_RS_CHECK_TOOLCHAIN)" clippy \
|
||||
-p tfhe-trivium -- --no-deps -D warnings
|
||||
|
||||
.PHONY: clippy_all_targets # Run clippy lints on all targets (benches, examples, etc.)
|
||||
clippy_all_targets:
|
||||
@@ -225,13 +229,6 @@ clippy_js_wasm_api clippy_tasks clippy_core clippy_concrete_csprng clippy_triviu
|
||||
clippy_fast: clippy clippy_all_targets clippy_c_api clippy_js_wasm_api clippy_tasks clippy_core \
|
||||
clippy_concrete_csprng
|
||||
|
||||
.PHONY: gen_key_cache # Run the script to generate keys and cache them for shortint tests
|
||||
gen_key_cache: install_rs_build_toolchain
|
||||
RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_BUILD_TOOLCHAIN) run --profile $(CARGO_PROFILE) \
|
||||
--example generates_test_keys \
|
||||
--features=$(TARGET_ARCH_FEATURE),boolean,shortint,internal-keycache -- \
|
||||
$(MULTI_BIT_ONLY) $(COVERAGE_ONLY)
|
||||
|
||||
.PHONY: build_core # Build core_crypto without experimental features
|
||||
build_core: install_rs_build_toolchain install_rs_check_toolchain
|
||||
RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_BUILD_TOOLCHAIN) build --profile $(CARGO_PROFILE) \
|
||||
@@ -319,6 +316,21 @@ test_core_crypto: install_rs_build_toolchain install_rs_check_toolchain
|
||||
--features=$(TARGET_ARCH_FEATURE),experimental,$(AVX512_FEATURE) -p $(TFHE_SPEC) -- core_crypto::; \
|
||||
fi
|
||||
|
||||
.PHONY: test_core_crypto_cov # Run the tests of the core_crypto module with code coverage
|
||||
test_core_crypto_cov: install_rs_build_toolchain install_rs_check_toolchain install_tarpaulin
|
||||
RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_BUILD_TOOLCHAIN) tarpaulin --profile $(CARGO_PROFILE) \
|
||||
--out xml --output-dir coverage/core_crypto --line --engine llvm --timeout 500 \
|
||||
--implicit-test-threads $(COVERAGE_EXCLUDED_FILES) \
|
||||
--features=$(TARGET_ARCH_FEATURE),experimental,internal-keycache,__coverage \
|
||||
-p $(TFHE_SPEC) -- core_crypto::
|
||||
@if [[ "$(AVX512_SUPPORT)" == "ON" ]]; then \
|
||||
RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_CHECK_TOOLCHAIN) tarpaulin --profile $(CARGO_PROFILE) \
|
||||
--out xml --output-dir coverage/core_crypto_avx512 --line --engine llvm --timeout 500 \
|
||||
--implicit-test-threads $(COVERAGE_EXCLUDED_FILES) \
|
||||
--features=$(TARGET_ARCH_FEATURE),experimental,internal-keycache,__coverage,$(AVX512_FEATURE) \
|
||||
-p $(TFHE_SPEC) -- core_crypto::; \
|
||||
fi
|
||||
|
||||
.PHONY: test_boolean # Run the tests of the boolean module
|
||||
test_boolean: install_rs_build_toolchain
|
||||
RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_BUILD_TOOLCHAIN) test --profile $(CARGO_PROFILE) \
|
||||
@@ -471,6 +483,59 @@ test_concrete_csprng:
|
||||
RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_BUILD_TOOLCHAIN) test --profile $(CARGO_PROFILE) \
|
||||
--features=$(TARGET_ARCH_FEATURE) -p concrete-csprng
|
||||
|
||||
.PHONY: test_float # Run minifloat bivariate test
|
||||
test_float: test_float_add test_float_sub test_float_mul test_float_div test_float_cos test_float_sin test_float_relu test_float_sigmoid test_minifloat
|
||||
|
||||
.PHONY: test_minifloat # Run minifloat bivariate test
|
||||
test_minifloat:
|
||||
RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_BUILD_TOOLCHAIN) test --profile $(CARGO_PROFILE) \
|
||||
--features=$(TARGET_ARCH_FEATURE),shortint -p tfhe float_wopbs_bivariate -- --nocapture
|
||||
|
||||
.PHONY: test_float_cos # Run floating points cosine test
|
||||
test_float_cos:
|
||||
RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_BUILD_TOOLCHAIN) test --profile $(CARGO_PROFILE) \
|
||||
--features=$(TARGET_ARCH_FEATURE) -p concrete-float "server_key::tests::float_cos" -- --exact --nocapture
|
||||
|
||||
.PHONY: test_float_sin # Run floating points sine test
|
||||
test_float_sin:
|
||||
RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_BUILD_TOOLCHAIN) test --profile $(CARGO_PROFILE) \
|
||||
--features=$(TARGET_ARCH_FEATURE) -p concrete-float "server_key::tests::float_sin" -- --exact --nocapture
|
||||
|
||||
.PHONY: test_float_mul # Run floating points multiplication test
|
||||
test_float_mul:
|
||||
RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_BUILD_TOOLCHAIN) test --profile $(CARGO_PROFILE) \
|
||||
--features=$(TARGET_ARCH_FEATURE) -p concrete-float "server_key::tests::test_float_mul" -- --exact --nocapture
|
||||
|
||||
.PHONY: test_float_add # Run floating points addition test
|
||||
test_float_add:
|
||||
RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_BUILD_TOOLCHAIN) test --profile $(CARGO_PROFILE) \
|
||||
--features=$(TARGET_ARCH_FEATURE) -p concrete-float "server_key::tests::test_float_add" -- --exact --nocapture
|
||||
|
||||
.PHONY: test_float_sub # Run floating points subtraction test
|
||||
test_float_sub:
|
||||
RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_BUILD_TOOLCHAIN) test --profile $(CARGO_PROFILE) \
|
||||
--features=$(TARGET_ARCH_FEATURE) -p concrete-float "server_key::tests::test_float_sub" -- --exact --nocapture
|
||||
|
||||
.PHONY: test_float_div # Run floating points division test
|
||||
test_float_div:
|
||||
RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_BUILD_TOOLCHAIN) test --profile $(CARGO_PROFILE) \
|
||||
--features=$(TARGET_ARCH_FEATURE) -p concrete-float "server_key::tests::test_float_div" -- --exact --nocapture
|
||||
|
||||
.PHONY: test_float_relu # Run floating points relu test
|
||||
test_float_relu:
|
||||
RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_BUILD_TOOLCHAIN) test --profile $(CARGO_PROFILE) \
|
||||
--features=$(TARGET_ARCH_FEATURE) -p concrete-float "server_key::tests::test_float_relu" -- --exact --nocapture
|
||||
|
||||
.PHONY: test_float_sigmoid # Run floating points sigmoid test
|
||||
test_float_sigmoid:
|
||||
RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_BUILD_TOOLCHAIN) test --profile $(CARGO_PROFILE) \
|
||||
--features=$(TARGET_ARCH_FEATURE) -p concrete-float "server_key::tests::test_float_sigmoid" -- --exact --nocapture
|
||||
|
||||
.PHONY: test_float_depth_test # Run floating points depth test
|
||||
test_float_depth_test:
|
||||
RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_BUILD_TOOLCHAIN) test --profile $(CARGO_PROFILE) \
|
||||
--features=$(TARGET_ARCH_FEATURE) -p concrete-float "server_key::tests::depth_test_parallelized" -- --exact --nocapture
|
||||
|
||||
.PHONY: doc # Build rust doc
|
||||
doc: install_rs_check_toolchain
|
||||
RUSTDOCFLAGS="--html-in-header katex-header.html" \
|
||||
@@ -624,9 +689,53 @@ ci_bench_web_js_api_parallel: build_web_js_api_parallel
|
||||
nvm use node && \
|
||||
$(MAKE) -C tfhe/web_wasm_parallel_tests bench-ci
|
||||
|
||||
.PHONY: bench_float # Run benchmarks for the floating points
|
||||
bench_float: install_rs_check_toolchain
|
||||
RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_CHECK_TOOLCHAIN) bench \
|
||||
--bench float-bench
|
||||
|
||||
.PHONY: bench_float_8bit # Run benchmarks for the floating points
|
||||
bench_float_8bit: install_rs_check_toolchain
|
||||
RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_CHECK_TOOLCHAIN) bench \
|
||||
--bench float-bench -- PARAM_8
|
||||
|
||||
|
||||
.PHONY: bench_float_16bit # Run benchmarks for the floating points
|
||||
bench_float_16bit: install_rs_check_toolchain
|
||||
RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_CHECK_TOOLCHAIN) bench \
|
||||
--bench float-bench -- PARAM_16
|
||||
|
||||
|
||||
.PHONY: bench_float_32bit # Run benchmarks for the floating points
|
||||
bench_float_32bit: install_rs_check_toolchain
|
||||
RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_CHECK_TOOLCHAIN) bench \
|
||||
--bench float-bench -- PARAM_32
|
||||
|
||||
.PHONY: bench_float_64bit # Run benchmarks for the floating points
|
||||
bench_float_64bit: install_rs_check_toolchain
|
||||
RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_CHECK_TOOLCHAIN) bench \
|
||||
--bench float-bench -- PARAM_64
|
||||
|
||||
.PHONY: bench_minifloat # Run benchmarks for Wopbs floating points
|
||||
bench_minifloat: install_rs_check_toolchain
|
||||
RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_CHECK_TOOLCHAIN) bench \
|
||||
--bench float-wopbs-bench
|
||||
|
||||
#
|
||||
# Utility tools
|
||||
#
|
||||
.PHONY: gen_key_cache # Run the script to generate keys and cache them for shortint tests
|
||||
gen_key_cache: install_rs_build_toolchain
|
||||
RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_BUILD_TOOLCHAIN) run --profile $(CARGO_PROFILE) \
|
||||
--example generates_test_keys \
|
||||
--features=$(TARGET_ARCH_FEATURE),boolean,shortint,internal-keycache -- \
|
||||
$(MULTI_BIT_ONLY) $(COVERAGE_ONLY)
|
||||
|
||||
.PHONY: gen_key_cache_core_crypto # Run function to generate keys and cache them for core_crypto tests
|
||||
gen_key_cache_core_crypto: install_rs_build_toolchain
|
||||
RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_BUILD_TOOLCHAIN) test --tests --profile $(CARGO_PROFILE) \
|
||||
--features=$(TARGET_ARCH_FEATURE),experimental,internal-keycache -p $(TFHE_SPEC) -- --nocapture \
|
||||
core_crypto::keycache::generate_keys
|
||||
|
||||
.PHONY: measure_hlapi_compact_pk_ct_sizes # Measure sizes of public keys and ciphertext for high-level API
|
||||
measure_hlapi_compact_pk_ct_sizes: install_rs_check_toolchain
|
||||
|
||||
297
README.md
297
README.md
@@ -1,175 +1,160 @@
|
||||
<p align="center">
|
||||
<!-- product name logo -->
|
||||
<img width=600 src="https://user-images.githubusercontent.com/5758427/231206749-8f146b97-3c5a-4201-8388-3ffa88580415.png">
|
||||
</p>
|
||||
<hr/>
|
||||
<p align="center">
|
||||
<a href="https://docs.zama.ai/tfhe-rs"> 📒 Read documentation</a> | <a href="https://zama.ai/community"> 💛 Community support</a>
|
||||
</p>
|
||||
<p align="center">
|
||||
<!-- Version badge using shields.io -->
|
||||
<a href="https://github.com/zama-ai/tfhe-rs/releases">
|
||||
<img src="https://img.shields.io/github/v/release/zama-ai/tfhe-rs?style=flat-square">
|
||||
</a>
|
||||
<!-- Zama Bounty Program -->
|
||||
<a href="https://github.com/zama-ai/bounty-program">
|
||||
<img src="https://img.shields.io/badge/Contribute-Zama%20Bounty%20Program-yellow?style=flat-square">
|
||||
</a>
|
||||
</p>
|
||||
<hr/>
|
||||
# Artifact:TFHE Gets Real: an Efficient and Flexible Homomorphic Floating-Point Arithmetic
|
||||
|
||||
|
||||
**TFHE-rs** is a pure Rust implementation of TFHE for boolean and integer
|
||||
arithmetics over encrypted data. It includes:
|
||||
- a **Rust** API
|
||||
- a **C** API
|
||||
- and a **client-side WASM** API
|
||||
## Description
|
||||
|
||||
**TFHE-rs** is meant for developers and researchers who want full control over
|
||||
what they can do with TFHE, while not having to worry about the low level
|
||||
implementation. The goal is to have a stable, simple, high-performance, and
|
||||
production-ready library for all the advanced features of TFHE.
|
||||
|
||||
## Getting Started
|
||||
The steps to run a first example are described below.
|
||||
In what follows, we provide instructions on how to run the benchmarks from the paper entitled **TFHE Gets Real: An Efficient and Flexible Homomorphic Floating-Point Arithmetic**.
|
||||
In particular, the benchmarks presented in **Table 5**, **Table 6**, **Table 7**, and the experiments shown in **Table 8** can be easily reproduced using this code. The implementation of the techniques described in the aforementioned paper has been integrated into the **TFHE-rs** library, version 0.5.0. The modified or added source files are organized into two different paths.
|
||||
|
||||
### Cargo.toml configuration
|
||||
To use the latest version of `TFHE-rs` in your project, you first need to add it as a dependency in your `Cargo.toml`:
|
||||
The Minifloats (Section 3.1) are located in *tfhe/src/float-wopbs*
|
||||
- Test files are located in *tfhe/src/float_wopbs/server_key/tests.rs*
|
||||
- Benchmarks are located in *tfhe/benches/float_wopbs/bench.rs*
|
||||
|
||||
+ For x86_64-based machines running Unix-like OSes:
|
||||
|
||||
```toml
|
||||
tfhe = { version = "*", features = ["boolean", "shortint", "integer", "x86_64-unix"] }
|
||||
The homomorphic floating points (Section 3.2) are located in *tfhe/concrete-float/*
|
||||
- Test files are located *tfhe/concrete-float/src/server_key/tests.rs*
|
||||
- Benchmarks are located in *tfhe/concrete-float/benches/bench.rs*
|
||||
|
||||
|
||||
## Dependencies
|
||||
|
||||
Tested on Linux and Mac OS with Rust version >= 1.80 (see [here](https://www.rust-lang.org/tools/install) a guide to install Rust).
|
||||
Complete list of dependencies and a guide on how to install TFHE-rs can be found in the online documentation [here](https://docs.zama.ai/tfhe-rs/0.5-3/getting-started/installation) or in the local file [here](./README_TFHE-rs.md).
|
||||
|
||||
## How to run benchmarks
|
||||
At the root of the project (i.e., in the TFHE-rs folder), enter the following commands to run the benchmarks:
|
||||
|
||||
- ```make bench_minifloat```: returns the timings associated to the Minifloats (**Table 6**).
|
||||
- ```make bench_float```: returns the timings associated to the HFP (**Table 5**, **Table 7**).
|
||||
These benchmarks first launch the parallelized and then the sequential experiments.
|
||||
This outputs the timings depending on the input precision.
|
||||
**This takes more than 6 hours to run**.
|
||||
|
||||
To run benchmarks for a specific precision over homomorphic floating points, here are the dedicated commands:
|
||||
- ```make bench_float_8bit```: Runs benchmarks for only 8-bit floating point *(around 15 min)*.
|
||||
- ```make bench_float_16bit```: Runs benchmarks for only 16-bit floating point *(around 30 min)*.
|
||||
- ```make bench_float_32bit```: Runs benchmarks for only 32-bit floating point *(around 1h40)*.
|
||||
- ```make bench_float_64bit```: Runs benchmarks for only 64-bit floating point *(around 6h30)*.
|
||||
|
||||
|
||||
We recall that the benchmarks were performed on AWS using an **m6i.metal** instance with an Intel Xeon 8375C (Ice Lake) processor running at 3.5 GHz, 128 vCPUs, and 512 GiB of memory.
|
||||
|
||||
### Understanding Benchmark Output (Criterion.rs)
|
||||
|
||||
This project uses [Criterion.rs](https://docs.rs/criterion/latest/criterion/) for benchmarking. Criterion is a powerful and statistically robust benchmarking framework for Rust, and it may produce outputs that are unfamiliar at first glance. This section explains how to interpret them.
|
||||
|
||||
#### Sample Output Structure
|
||||
|
||||
A typical benchmark result looks like this:
|
||||
|
||||
```
|
||||
test_float time: [53.2 µs 54.0 µs 54.8 µs]
|
||||
change: [+0.2% +1.0% +1.8%] (p = 0.002)
|
||||
Found 3 outliers among 100 measurements (3.00%)
|
||||
3 (3.00%) high mild
|
||||
```
|
||||
|
||||
+ For Apple Silicon or aarch64-based machines running Unix-like OSes:
|
||||
**Here's what this means:**
|
||||
|
||||
```toml
|
||||
tfhe = { version = "*", features = ["boolean", "shortint", "integer", "aarch64-unix"] }
|
||||
- `time: [low est. median high est.]`: The estimated execution time of the function.
|
||||
- `change`: The performance change compared to a previous run (if available).
|
||||
- `outliers`: Some runs deviated from the typical time. Criterion detects and accounts for these using statistical methods.
|
||||
|
||||
---
|
||||
|
||||
#### Common Warnings and What They Mean
|
||||
|
||||
##### `Found X outliers among Y measurements`
|
||||
|
||||
Criterion runs each benchmark many times (default: 100) to get statistically significant results.
|
||||
An *outlier* is a run that was significantly faster or slower than the others.
|
||||
|
||||
- **Why does this happen?** Often, it's due to **other processes on the machine** (e.g., background services, OS interrupts, or CPU scheduling) affecting performance temporarily.
|
||||
- **Why it doesn't invalidate results:** Criterion uses statistical techniques to minimize the impact of these outliers when estimating performance.
|
||||
- **Best practice to reduce outliers:** Run the benchmarks on a **freshly rebooted machine**, with as few background processes as possible. Ideally, let the system idle for a minute after boot to stabilize before running benchmarks.
|
||||
|
||||
##### `Unable to complete 100 samples in 5.0s.`
|
||||
|
||||
The benchmark took longer than the expected 5 seconds.
|
||||
This is merely a warning indicating that the full set of 100 samples could not be collected within the default 5-second measurement window.
|
||||
|
||||
- **No action is required**: Criterion will still proceed to run all 100 samples, and the results remain statistically valid.
|
||||
- **Why the warning appears**: It's there to inform you that benchmarking is taking longer than expected and to help you tune settings if needed.
|
||||
- **Optional**: If you're constrained by time (e.g., running in CI), you can:
|
||||
- Reduce the sample size (e.g., to 10 or 20 samples).
|
||||
- Or increase the measurement time using:
|
||||
```bash
|
||||
cargo bench -- --measurement-time 30
|
||||
```
|
||||
|
||||
## How to run the tests
|
||||
### MiniFloats
|
||||
|
||||
To run the tests related to the **minifloats**, run the following command:
|
||||
- ```make test_minifloat```: Runs a bivariate operation between two minifloats.
|
||||
|
||||
|
||||
The **minifloat** test is available in the file *tfhe/src/float_wopbs/server_key/tests.rs*.
|
||||
|
||||
|
||||
|
||||
### Homomorphic Floating Points
|
||||
At the root of the project (i.e., in the TFHE-rs folder), enter the following commands to run the tests per operation on the **homomorphic floating points**:
|
||||
- ```make test_float_add```: Runs a 32-bit floating-point addition with two random inputs.
|
||||
- ```make test_float_sub```: Runs a 32-bit floating-point subtraction with two random inputs.
|
||||
- ```make test_float_mul```: Runs a 32-bit floating-point multiplication with two random inputs.
|
||||
- ```make test_float_div```: Runs a 32-bit floating-point division with two random inputs.
|
||||
- ```make test_float_cos```: Runs the experiment from **Table 8** with a random input value.
|
||||
- ```make test_float_sin```: Runs the experiment from **Table 8** with a random input value.
|
||||
- ```make test_float_relu```: Runs a 32-bit floating-point relu with a random input.
|
||||
- ```make test_float_sigmoid```: Runs a 32-bit floating-point sigmoid with a random input.
|
||||
- ```make test_float```: Runs all previous tests for operations on 32-bit floating-points.
|
||||
- ```make test_float_depth_test```: This command runs the following experiment:
|
||||
- **Step 1**: Create 3 blocks, each composed of a clear 32-bit floating point, a clear 64-bit floating point, and a 32-bit homomorphic floating point.
|
||||
- **Step 2**: Choose two blocks randomly among the 3 blocks and randomly select a parallelized operation (addition, subtraction, or multiplication).
|
||||
- **Step 3**: Compute the selected operation between the two selected blocks and store the result randomly in one of the two selected blocks.
|
||||
(The operation is performed respectively between the two 64-bit floating points, the two 32-bit floating points, and homomorphically between the two 32-bit homomorphic floating points.)
|
||||
- Repeat Steps 2 and 3 for 50 iterations.
|
||||
- To avoid reaching + or - infinity, or **NaN**, when the clear 64-bit floating point reaches a fixed bound, compute a multiplication to rescale the value close to 1.
|
||||
This operation is also performed homomorphically for the encrypted data. This test takes several minutes.
|
||||
|
||||
The tests are located in the file *tfhe/concrete-float/src/server_key/tests.rs*.
|
||||
|
||||
Due to the representation being close to, but not exactly the same as, a given representation, the obtained result is not identical to the one obtained in clear.
|
||||
To consider a test as "passed", we accept a difference of less than 0.1% compared to the 64-bit floating-point clear results.
|
||||
Note that using 8 or 16-bit homomorphic floating points might return errors due to a lack of precision and due to the comparisons with clear 64-bit floating points.
|
||||
|
||||
In each test, the different results are presented in the following format:
|
||||
```
|
||||
--------------------
|
||||
"Name":
|
||||
|
||||
Result :
|
||||
Clear 32-bits:
|
||||
Clear 64-bits:
|
||||
|
||||
--------------------
|
||||
```
|
||||
Note: users with ARM devices must compile `TFHE-rs` using a stable toolchain with version >= 1.72.
|
||||
where ```name``` stands for the name of the ciphertext or the name of the operation, result always corresponds to the decryption of a homomorphic floating point, and Clear ``` 32-bits``` and Clear ``` 64-bits``` correspond to the clear floating-point witness.
|
||||
|
||||
All tests in *tfhe/concrete-float/src/server_key/tests.rs* are conducted for 32-bit floating-point precision, as it provides the best ratio between execution time and precision.
|
||||
To change the parameter set used, the parameters in the following ``` const ``` must be uncommented (lines 79 to 87 in the file *tfhe/concrete-float/src/server_key/tests.rs*).
|
||||
|
||||
|
||||
+ For x86_64-based machines with the [`rdseed instruction`](https://en.wikipedia.org/wiki/RDRAND)
|
||||
running Windows:
|
||||
|
||||
```toml
|
||||
tfhe = { version = "*", features = ["boolean", "shortint", "integer", "x86_64"] }
|
||||
```rust
|
||||
const PARAMS: [(&str, Parameters); 1] =
|
||||
[
|
||||
//named_param!(PARAM_FP_64_BITS),
|
||||
named_param!(PARAM_FP_32_BITS),
|
||||
//named_param!(PARAM_FP_16_BITS),
|
||||
//named_param!(PARAM_FP_8_BITS),
|
||||
];
|
||||
```
|
||||
|
||||
Note: aarch64-based machines are not yet supported for Windows as it's currently missing an entropy source to be able to seed the [CSPRNGs](https://en.wikipedia.org/wiki/Cryptographically_secure_pseudorandom_number_generator) used in TFHE-rs
|
||||
Note that the number in ``` [(\&str, Parameters); 1] ``` should correspond to the number of tested parameters, e.g., if another parameter sets is uncommented, this line becomes: ``` [(\&str, Parameters); 2] ```.
|
||||
The parameter ```PARAM_X``` corresponds to the parameters used in **Table 5**, and ```PARAM_TCHES_X``` corresponds to the parameters used in **Table 7**.
|
||||
|
||||
|
||||
## A simple example
|
||||
|
||||
Here is a full example:
|
||||
|
||||
``` rust
|
||||
use tfhe::prelude::*;
|
||||
use tfhe::{generate_keys, set_server_key, ConfigBuilder, FheUint32, FheUint8};
|
||||
|
||||
fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
// Basic configuration to use homomorphic integers
|
||||
let config = ConfigBuilder::default().build();
|
||||
|
||||
// Key generation
|
||||
let (client_key, server_keys) = generate_keys(config);
|
||||
|
||||
let clear_a = 1344u32;
|
||||
let clear_b = 5u32;
|
||||
let clear_c = 7u8;
|
||||
|
||||
// Encrypting the input data using the (private) client_key
|
||||
// FheUint32: Encrypted equivalent to u32
|
||||
let mut encrypted_a = FheUint32::try_encrypt(clear_a, &client_key)?;
|
||||
let encrypted_b = FheUint32::try_encrypt(clear_b, &client_key)?;
|
||||
|
||||
// FheUint8: Encrypted equivalent to u8
|
||||
let encrypted_c = FheUint8::try_encrypt(clear_c, &client_key)?;
|
||||
|
||||
// On the server side:
|
||||
set_server_key(server_keys);
|
||||
|
||||
// Clear equivalent computations: 1344 * 5 = 6720
|
||||
let encrypted_res_mul = &encrypted_a * &encrypted_b;
|
||||
|
||||
// Clear equivalent computations: 1344 >> 5 = 42
|
||||
encrypted_a = &encrypted_res_mul >> &encrypted_b;
|
||||
|
||||
// Clear equivalent computations: let casted_a = a as u8;
|
||||
let casted_a: FheUint8 = encrypted_a.cast_into();
|
||||
|
||||
// Clear equivalent computations: min(42, 7) = 7
|
||||
let encrypted_res_min = &casted_a.min(&encrypted_c);
|
||||
|
||||
// Operation between clear and encrypted data:
|
||||
// Clear equivalent computations: 7 & 1 = 1
|
||||
let encrypted_res = encrypted_res_min & 1_u8;
|
||||
|
||||
// Decrypting on the client side:
|
||||
let clear_res: u8 = encrypted_res.decrypt(&client_key);
|
||||
assert_eq!(clear_res, 1_u8);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
```
|
||||
|
||||
To run this code, use the following command:
|
||||
<p align="center"> <code> cargo run --release </code> </p>
|
||||
|
||||
Note that when running code that uses `tfhe-rs`, it is highly recommended
|
||||
to run in release mode with cargo's `--release` flag to have the best performances possible,
|
||||
|
||||
|
||||
## Contributing
|
||||
|
||||
There are two ways to contribute to TFHE-rs:
|
||||
|
||||
- you can open issues to report bugs or typos, or to suggest new ideas
|
||||
- you can ask to become an official contributor by emailing [hello@zama.ai](mailto:hello@zama.ai).
|
||||
(becoming an approved contributor involves signing our Contributor License Agreement (CLA))
|
||||
|
||||
Only approved contributors can send pull requests, so please make sure to get in touch before you do!
|
||||
|
||||
## Credits
|
||||
|
||||
This library uses several dependencies and we would like to thank the contributors of those
|
||||
libraries.
|
||||
|
||||
## Need support?
|
||||
<a target="_blank" href="https://community.zama.ai">
|
||||
<img src="https://user-images.githubusercontent.com/5758427/231115030-21195b55-2629-4c01-9809-be5059243999.png">
|
||||
</a>
|
||||
|
||||
## Citing TFHE-rs
|
||||
|
||||
To cite TFHE-rs in academic papers, please use the following entry:
|
||||
|
||||
```text
|
||||
@Misc{TFHE-rs,
|
||||
title={{TFHE-rs: A Pure Rust Implementation of the TFHE Scheme for Boolean and Integer Arithmetics Over Encrypted Data}},
|
||||
author={Zama},
|
||||
year={2022},
|
||||
note={\url{https://github.com/zama-ai/tfhe-rs}},
|
||||
}
|
||||
```
|
||||
|
||||
## License
|
||||
|
||||
This software is distributed under the BSD-3-Clause-Clear license. If you have any questions,
|
||||
please contact us at `hello@zama.ai`.
|
||||
|
||||
## Disclaimers
|
||||
|
||||
### Security Estimation
|
||||
|
||||
Security estimations are done using the
|
||||
[Lattice Estimator](https://github.com/malb/lattice-estimator)
|
||||
with `red_cost_model = reduction.RC.BDGL16`.
|
||||
|
||||
When a new update is published in the Lattice Estimator, we update parameters accordingly.
|
||||
|
||||
### Side-Channel Attacks
|
||||
|
||||
Mitigation for side channel attacks have not yet been implemented in TFHE-rs,
|
||||
and will be released in upcoming versions.
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
{
|
||||
"m6i.metal": 7.168
|
||||
"m6i.metal": 7.168,
|
||||
"hpc7a.96xlarge": 7.7252
|
||||
}
|
||||
|
||||
15
ci/slab.toml
15
ci/slab.toml
@@ -3,27 +3,32 @@ region = "eu-west-3"
|
||||
image_id = "ami-051942e4055555752"
|
||||
instance_type = "m6i.32xlarge"
|
||||
|
||||
[profile.cpu-big_fallback]
|
||||
region = "us-east-1"
|
||||
image_id = "ami-04e3bb9aebb6786df"
|
||||
instance_type = "m6i.32xlarge"
|
||||
|
||||
[profile.cpu-small]
|
||||
region = "eu-west-3"
|
||||
image_id = "ami-051942e4055555752"
|
||||
instance_type = "m6i.4xlarge"
|
||||
|
||||
[profile.bench]
|
||||
region = "eu-west-3"
|
||||
image_id = "ami-051942e4055555752"
|
||||
instance_type = "m6i.metal"
|
||||
region = "eu-west-1"
|
||||
image_id = "ami-0e88d98b86aff13de"
|
||||
instance_type = "hpc7a.96xlarge"
|
||||
|
||||
[command.cpu_test]
|
||||
workflow = "aws_tfhe_tests.yml"
|
||||
profile = "cpu-big"
|
||||
check_run_name = "CPU AWS Tests"
|
||||
|
||||
[command.cpu_integer_test]
|
||||
[command.cpu_unsigned_integer_test]
|
||||
workflow = "aws_tfhe_integer_tests.yml"
|
||||
profile = "cpu-big"
|
||||
check_run_name = "CPU Unsigned Integer AWS Tests"
|
||||
|
||||
[command.cpu_multi_bit_test]
|
||||
[command.cpu_signed_integer_test]
|
||||
workflow = "aws_tfhe_signed_integer_tests.yml"
|
||||
profile = "cpu-big"
|
||||
check_run_name = "CPU Signed Integer AWS Tests"
|
||||
|
||||
@@ -22,6 +22,9 @@ pub trait Seeder {
|
||||
}
|
||||
|
||||
mod implem;
|
||||
// This import statement can be empty if seeder features are disabled, rustc's behavior changed to
|
||||
// warn of empty modules, we know this can happen, so allow it.
|
||||
#[allow(unused_imports)]
|
||||
pub use implem::*;
|
||||
|
||||
#[cfg(test)]
|
||||
|
||||
69
concrete-float/Cargo.toml
Normal file
69
concrete-float/Cargo.toml
Normal file
@@ -0,0 +1,69 @@
|
||||
[package]
|
||||
name = "concrete-float"
|
||||
version = "0.1.0-beta.0"
|
||||
edition = "2018"
|
||||
authors = ["Zama team"]
|
||||
license = "BSD-3-Clause-Clear"
|
||||
description = "Homomorphic Integer circuit interface for the concrete FHE library."
|
||||
homepage = "https://www.zama.ai/concrete-framework"
|
||||
documentation = "https://docs.zama.ai/home/"
|
||||
repository = "https://github.com/zama-ai/concrete"
|
||||
readme = "README.md"
|
||||
keywords = ["fully", "homomorphic", "encryption", "fhe", "cryptography"]
|
||||
|
||||
[dependencies]
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
aligned-vec = { version = "0.5", features = ["serde"] }
|
||||
dyn-stack = { version = "0.9" }
|
||||
rayon = "1.5"
|
||||
|
||||
lazy_static = { version = "1.4.0", optional = true }
|
||||
|
||||
tfhe = { path = "../tfhe", features = ["shortint", "integer"] }
|
||||
|
||||
[target.'cfg(target_arch = "x86_64")'.dependencies]
|
||||
tfhe = { path = "../tfhe", features = ["shortint", "integer", "x86_64-unix"] }
|
||||
|
||||
[target.'cfg(target_arch = "aarch64")'.dependencies]
|
||||
tfhe = { path = "../tfhe", features = ["shortint", "integer", "aarch64-unix"] }
|
||||
|
||||
[features]
|
||||
nightly-avx512 = ["tfhe/nightly-avx512"]
|
||||
seeder_x86_64_rdseed = []
|
||||
seeder_unix = []
|
||||
generator_x86_64_aesni = []
|
||||
generator_fallback = []
|
||||
generator_aarch64_aes = []
|
||||
|
||||
x86_64 = [
|
||||
"seeder_x86_64_rdseed",
|
||||
"generator_x86_64_aesni",
|
||||
"generator_fallback",
|
||||
]
|
||||
x86_64-unix = ["x86_64", "seeder_unix"]
|
||||
aarch64 = [ "generator_aarch64_aes", "generator_fallback"]
|
||||
aarch64-unix = ["aarch64", "seeder_unix"]
|
||||
|
||||
|
||||
[dev-dependencies]
|
||||
criterion = "0.5.1"
|
||||
lazy_static = "1.4.0"
|
||||
bincode = "1.3.3"
|
||||
paste = "1.0.7"
|
||||
rand = "0.8.4"
|
||||
doc-comment = "0.3.3"
|
||||
#concrete-shortint = { path = "../tfhe", features = ["internal-keycache"] }
|
||||
|
||||
#[features]
|
||||
# Keychache used to speed up tests and benches
|
||||
# by not requiring to regererate keys at each launch
|
||||
#internal-keycache = ["lazy_static", "shortint/src/internal-keycache"]
|
||||
|
||||
[package.metadata.docs.rs]
|
||||
rustdoc-args = ["--html-in-header", "katex-header.html"]
|
||||
|
||||
[[bench]]
|
||||
name = "float-bench"
|
||||
path = "benches/bench.rs"
|
||||
harness = false
|
||||
required-features = []
|
||||
32
concrete-float/LICENSE
Normal file
32
concrete-float/LICENSE
Normal file
@@ -0,0 +1,32 @@
|
||||
BSD 3-Clause Clear License
|
||||
|
||||
Copyright © 2022 ZAMA.
|
||||
All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without modification,
|
||||
are permitted provided that the following conditions are met:
|
||||
|
||||
1. Redistributions of source code must retain the above copyright notice, this
|
||||
list of conditions and the following disclaimer.
|
||||
|
||||
2. Redistributions in binary form must reproduce the above copyright notice, this
|
||||
list of conditions and the following disclaimer in the documentation and/or other
|
||||
materials provided with the distribution.
|
||||
|
||||
3. Neither the name of ZAMA nor the names of its contributors may be used to endorse
|
||||
or promote products derived from this software without specific prior written permission.
|
||||
NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY THIS LICENSE*.
|
||||
THIS SOFTWARE IS PROVIDED BY THE ZAMA AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
|
||||
MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL
|
||||
ZAMA OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY,
|
||||
OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
|
||||
OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
|
||||
ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
|
||||
ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
*In addition to the rights carried by this license, ZAMA grants to the user a non-exclusive,
|
||||
free and non-commercial license on all patents filed in its name relating to the open-source
|
||||
code (the "Patents") for the sole purpose of evaluation, development, research, prototyping
|
||||
and experimentation.
|
||||
11
concrete-float/README.md
Normal file
11
concrete-float/README.md
Normal file
@@ -0,0 +1,11 @@
|
||||
# concrete Integer
|
||||
|
||||
`concrete-integer` is a Rust library built on top of `concrete-shortint`, it
|
||||
combines multiple `shortint` to handle encrypted integers of "arbitrary"
|
||||
size.
|
||||
|
||||
## License
|
||||
|
||||
This software is distributed under the BSD-3-Clause-Clear license. If you have any questions,
|
||||
please contact us at `hello@zama.ai`.
|
||||
|
||||
304
concrete-float/benches/bench.rs
Normal file
304
concrete-float/benches/bench.rs
Normal file
@@ -0,0 +1,304 @@
|
||||
#![allow(dead_code)]
|
||||
|
||||
use concrete_float::gen_keys;
|
||||
use criterion::{criterion_group, criterion_main, Criterion};
|
||||
use rand::Rng;
|
||||
|
||||
// Previous Parameters
|
||||
#[allow(unused_imports)]
|
||||
use concrete_float::parameters::{FINAL_PARAM_16,
|
||||
FINAL_PARAM_2_2_32, FINAL_PARAM_32,
|
||||
FINAL_PARAM_64, FINAL_PARAM_8,
|
||||
FINAL_WOP_PARAM_15, FINAL_WOP_PARAM_16,
|
||||
FINAL_WOP_PARAM_2_2_32, FINAL_WOP_PARAM_32,
|
||||
FINAL_WOP_PARAM_64, FINAL_WOP_PARAM_8,
|
||||
FINAL_PARAM_64_TCHESS, FINAL_PARAM_32_TCHESS,
|
||||
FINAL_WOP_PARAM_64_TCHESS, FINAL_WOP_PARAM_32_TCHESS};
|
||||
|
||||
use concrete_float::parameters::{FINAL_PARAM_16_BIS, FINAL_PARAM_32_BIS,
|
||||
FINAL_PARAM_64_BIS, FINAL_PARAM_8_BIS,
|
||||
FINAL_WOP_PARAM_16_BIS, FINAL_WOP_PARAM_32_BIS,
|
||||
FINAL_WOP_PARAM_64_BIS, FINAL_WOP_PARAM_8_BIS};
|
||||
use tfhe::shortint;
|
||||
|
||||
macro_rules! named_param {
|
||||
($param:ident) => {
|
||||
(stringify!($param), $param)
|
||||
};
|
||||
}
|
||||
|
||||
criterion_main!(float_parallelized, float);
|
||||
|
||||
struct Parameters {
|
||||
pbsparameters: shortint::ClassicPBSParameters,
|
||||
wopbsparameters: shortint::WopbsParameters,
|
||||
len_man: usize,
|
||||
len_exp: usize,
|
||||
}
|
||||
|
||||
//Parameter for a Floating point 64-bits equivalent
|
||||
const PARAM_64: Parameters = Parameters {
|
||||
pbsparameters: FINAL_PARAM_64_BIS,
|
||||
wopbsparameters: FINAL_WOP_PARAM_64_BIS,
|
||||
len_man: 27,
|
||||
len_exp: 5,
|
||||
};
|
||||
|
||||
|
||||
//Parameter for a Floating point 32-bits equivalent
|
||||
const PARAM_32: Parameters = Parameters {
|
||||
pbsparameters: FINAL_PARAM_32_BIS,
|
||||
wopbsparameters: FINAL_WOP_PARAM_32_BIS,
|
||||
len_man: 13,
|
||||
len_exp: 4,
|
||||
};
|
||||
|
||||
|
||||
//Parameter for a Floating point 16-bits equivalent
|
||||
const PARAM_16: Parameters = Parameters {
|
||||
pbsparameters: FINAL_PARAM_16_BIS,
|
||||
wopbsparameters: FINAL_WOP_PARAM_16_BIS,
|
||||
len_man: 6,
|
||||
len_exp: 3,
|
||||
};
|
||||
|
||||
|
||||
//Parameter for a Floating point 8-bits equivalent
|
||||
const PARAM_8: Parameters = Parameters {
|
||||
pbsparameters: FINAL_PARAM_8_BIS,
|
||||
wopbsparameters: FINAL_WOP_PARAM_8_BIS,
|
||||
len_man: 3,
|
||||
len_exp: 2,
|
||||
};
|
||||
|
||||
|
||||
//Parameter for a Floating point 64-bits equivalent
|
||||
//With failure probability smaller than PARAM_64
|
||||
const PARAM_TCHESS_64: Parameters = Parameters {
|
||||
pbsparameters: FINAL_PARAM_64_TCHESS,
|
||||
wopbsparameters: FINAL_WOP_PARAM_64_TCHESS,
|
||||
len_man: 27,
|
||||
len_exp: 5,
|
||||
};
|
||||
|
||||
|
||||
//Parameter for a Floating point 32-bits equivalent
|
||||
//With failure probability smaller than PARAM_32
|
||||
const PARAM_TCHESS_32: Parameters = Parameters {
|
||||
pbsparameters: FINAL_PARAM_32_TCHESS,
|
||||
wopbsparameters: FINAL_WOP_PARAM_32_TCHESS,
|
||||
len_man: 13,
|
||||
len_exp: 4,
|
||||
};
|
||||
|
||||
|
||||
const SERVER_KEY_BENCH_PARAMS: [(&str, Parameters);6] =
|
||||
[
|
||||
named_param!(PARAM_8),
|
||||
named_param!(PARAM_16),
|
||||
named_param!(PARAM_32),
|
||||
named_param!(PARAM_64),
|
||||
named_param!(PARAM_TCHESS_32),
|
||||
named_param!(PARAM_TCHESS_64),
|
||||
];
|
||||
|
||||
criterion_group!(
|
||||
float,
|
||||
add,
|
||||
mul,
|
||||
relu,
|
||||
sigmoid,
|
||||
);
|
||||
|
||||
criterion_group!(
|
||||
float_parallelized,
|
||||
add_parallelized,
|
||||
mul_parallelized,
|
||||
div_parallelized,,
|
||||
);
|
||||
|
||||
|
||||
fn relu(c: &mut Criterion) {
|
||||
let mut bench_group = c.benchmark_group("operation");
|
||||
let mut rng = rand::thread_rng();
|
||||
|
||||
for (param_name, param) in SERVER_KEY_BENCH_PARAMS {
|
||||
let (cks, sks) = gen_keys(
|
||||
param.pbsparameters,
|
||||
param.wopbsparameters,
|
||||
param.len_man,
|
||||
param.len_exp,
|
||||
);
|
||||
|
||||
let msg = rng.gen::<f32>() as f64;
|
||||
let ct = cks.encrypt(msg);
|
||||
|
||||
let bench_id = format!("{}::{}", "Relu", param_name);
|
||||
bench_group.bench_function(&bench_id, |b| {
|
||||
b.iter(|| {
|
||||
sks.relu(&ct);
|
||||
})
|
||||
});
|
||||
}
|
||||
bench_group.finish()
|
||||
}
|
||||
|
||||
fn sigmoid(c: &mut Criterion) {
|
||||
let mut bench_group = c.benchmark_group("operation");
|
||||
let mut rng = rand::thread_rng();
|
||||
|
||||
for (param_name, param) in SERVER_KEY_BENCH_PARAMS {
|
||||
let (cks, sks) = gen_keys(
|
||||
param.pbsparameters,
|
||||
param.wopbsparameters,
|
||||
param.len_man,
|
||||
param.len_exp,
|
||||
);
|
||||
|
||||
let msg = rng.gen::<f32>() as f64;
|
||||
let ct = cks.encrypt(msg);
|
||||
|
||||
let bench_id = format!("{}::{}", "sigmoid", param_name);
|
||||
bench_group.bench_function(&bench_id, |b| {
|
||||
b.iter(|| {
|
||||
sks.sigmoid(&ct);
|
||||
})
|
||||
});
|
||||
}
|
||||
bench_group.finish()
|
||||
}
|
||||
|
||||
fn mul(c: &mut Criterion) {
|
||||
let mut bench_group = c.benchmark_group("operation");
|
||||
let mut rng = rand::thread_rng();
|
||||
|
||||
for (param_name, param) in SERVER_KEY_BENCH_PARAMS {
|
||||
let (cks, sks) = gen_keys(
|
||||
param.pbsparameters,
|
||||
param.wopbsparameters,
|
||||
param.len_man,
|
||||
param.len_exp,
|
||||
);
|
||||
|
||||
let msg = rng.gen::<f32>() as f64;
|
||||
let ct1 = cks.encrypt(msg);
|
||||
let msg = rng.gen::<f32>() as f64;
|
||||
let ct2 = cks.encrypt(msg);
|
||||
|
||||
let bench_id = format!("{}::{}", "mul", param_name);
|
||||
bench_group.bench_function(&bench_id, |b| {
|
||||
b.iter(|| {
|
||||
sks.mul_total(&ct1, &ct2);
|
||||
})
|
||||
});
|
||||
}
|
||||
bench_group.finish()
|
||||
}
|
||||
|
||||
fn mul_parallelized(c: &mut Criterion) {
|
||||
let mut bench_group = c.benchmark_group("operation");
|
||||
let mut rng = rand::thread_rng();
|
||||
|
||||
for (param_name, param) in SERVER_KEY_BENCH_PARAMS {
|
||||
let (cks, sks) = gen_keys(
|
||||
param.pbsparameters,
|
||||
param.wopbsparameters,
|
||||
param.len_man,
|
||||
param.len_exp,
|
||||
);
|
||||
|
||||
let msg = rng.gen::<f32>() as f64;
|
||||
let ct1 = cks.encrypt(msg);
|
||||
let msg = rng.gen::<f32>() as f64;
|
||||
let ct2 = cks.encrypt(msg);
|
||||
|
||||
let bench_id = format!("{}::{}", "mul parallelized", param_name);
|
||||
bench_group.bench_function(&bench_id, |b| {
|
||||
b.iter(|| {
|
||||
sks.mul_total_parallelized(&ct1, &ct2);
|
||||
})
|
||||
});
|
||||
}
|
||||
bench_group.finish()
|
||||
}
|
||||
|
||||
fn div_parallelized(c: &mut Criterion) {
|
||||
let mut bench_group = c.benchmark_group("operation");
|
||||
let mut rng = rand::thread_rng();
|
||||
|
||||
for (param_name, param) in SERVER_KEY_BENCH_PARAMS {
|
||||
let (cks, sks) = gen_keys(
|
||||
param.pbsparameters,
|
||||
param.wopbsparameters,
|
||||
param.len_man,
|
||||
param.len_exp,
|
||||
);
|
||||
|
||||
let msg = rng.gen::<f32>() as f64;
|
||||
let ct1 = cks.encrypt(msg);
|
||||
let msg = rng.gen::<f32>() as f64;
|
||||
let ct2 = cks.encrypt(msg);
|
||||
|
||||
let bench_id = format!("{}::{}", "div parallelized", param_name);
|
||||
bench_group.bench_function(&bench_id, |b| {
|
||||
b.iter(|| {
|
||||
sks.division(&ct1, &ct2);
|
||||
})
|
||||
});
|
||||
}
|
||||
bench_group.finish()
|
||||
}
|
||||
|
||||
fn add(c: &mut Criterion) {
|
||||
let mut bench_group = c.benchmark_group("operation");
|
||||
let mut rng = rand::thread_rng();
|
||||
|
||||
for (param_name, param) in SERVER_KEY_BENCH_PARAMS {
|
||||
let (cks, sks) = gen_keys(
|
||||
param.pbsparameters,
|
||||
param.wopbsparameters,
|
||||
param.len_man,
|
||||
param.len_exp,
|
||||
);
|
||||
|
||||
let msg = rng.gen::<f32>() as f64;
|
||||
let ct1 = cks.encrypt(msg);
|
||||
let msg = rng.gen::<f32>() as f64;
|
||||
let ct2 = cks.encrypt(msg);
|
||||
|
||||
let bench_id = format!("{}::{}", "add", param_name);
|
||||
bench_group.bench_function(&bench_id, |b| {
|
||||
b.iter(|| {
|
||||
sks.add_total(&ct1, &ct2);
|
||||
})
|
||||
});
|
||||
}
|
||||
bench_group.finish()
|
||||
}
|
||||
|
||||
fn add_parallelized(c: &mut Criterion) {
|
||||
let mut bench_group = c.benchmark_group("operation");
|
||||
let mut rng = rand::thread_rng();
|
||||
|
||||
for (param_name, param) in SERVER_KEY_BENCH_PARAMS {
|
||||
let (cks, sks) = gen_keys(
|
||||
param.pbsparameters,
|
||||
param.wopbsparameters,
|
||||
param.len_man,
|
||||
param.len_exp,
|
||||
);
|
||||
|
||||
let msg = rng.gen::<f32>() as f64;
|
||||
let ct1 = cks.encrypt(msg);
|
||||
let msg = rng.gen::<f32>() as f64;
|
||||
let ct2 = cks.encrypt(msg);
|
||||
|
||||
let bench_id = format!("{}::{}", "add parallelized", param_name);
|
||||
bench_group.bench_function(&bench_id, |b| {
|
||||
b.iter(|| {
|
||||
sks.add_total_parallelized(&ct1, &ct2);
|
||||
})
|
||||
});
|
||||
}
|
||||
bench_group.finish()
|
||||
}
|
||||
20
concrete-float/docs/SUMMARY.md
Normal file
20
concrete-float/docs/SUMMARY.md
Normal file
@@ -0,0 +1,20 @@
|
||||
# Concrete-Integer User Guide
|
||||
|
||||
[Introduction](introduction.md)
|
||||
|
||||
# Getting Started
|
||||
|
||||
[Installation](getting_started/installation.md)
|
||||
|
||||
[Writing Your First Circuit](getting_started/first_circuit.md)
|
||||
|
||||
[Types Of Operations](getting_started/operation_types.md)
|
||||
|
||||
[List of Operations](getting_started/operation_list.md)
|
||||
|
||||
[Cryptographic Parameters](getting_started/parameters.md)
|
||||
|
||||
|
||||
# How to
|
||||
|
||||
[Serialization / Deserialization](tutorials/serialization.md)
|
||||
105
concrete-float/docs/getting_started/first_circuit.md
Normal file
105
concrete-float/docs/getting_started/first_circuit.md
Normal file
@@ -0,0 +1,105 @@
|
||||
# Writing Your First Circuit
|
||||
|
||||
|
||||
## Key Types
|
||||
|
||||
`concrete-integer` provides 2 basic key types:
|
||||
- `ClientKey`
|
||||
- `ServerKey`
|
||||
|
||||
The `ClientKey` is the key that encrypts and decrypts messages,
|
||||
thus this key is meant to be kept private and should never be shared.
|
||||
This key is created from parameter values that will dictate both the security and efficiency
|
||||
of computations. The parameters also set the maximum number of bits of message encrypted
|
||||
in a ciphertext.
|
||||
|
||||
The `ServerKey` is the key that is used to actually do the FHE computations. It contains (among other things)
|
||||
a bootstrapping key and a keyswitching key.
|
||||
This key is created from a `ClientKey` that needs to be shared to the server, therefore it is not
|
||||
meant to be kept private.
|
||||
A user with a `ServerKey` can compute on the encrypted data sent by the owner of the associated
|
||||
`ClientKey`.
|
||||
|
||||
To reflect that, computation/operation methods are tied to the `ServerKey` type.
|
||||
|
||||
|
||||
## 1. Key Generation
|
||||
|
||||
To generate the keys, a user needs two parameters:
|
||||
- A set of `shortint` cryptographic parameters.
|
||||
- The number of ciphertexts used to encrypt an integer (we call them "shortint blocks").
|
||||
|
||||
|
||||
For this example we are going to build a pair of keys that can encrypt an **8-bit** integer
|
||||
by using **4** shortint blocks that store **2** bits of message each.
|
||||
|
||||
|
||||
```rust
|
||||
use concrete_integer::gen_keys;
|
||||
use concrete_shortint::parameters::PARAM_MESSAGE_2_CARRY_2;
|
||||
|
||||
fn main() {
|
||||
// We generate a set of client/server keys, using the default parameters:
|
||||
let num_block = 4;
|
||||
let (client_key, server_key) = gen_keys(&PARAM_MESSAGE_2_CARRY_2, num_block);
|
||||
}
|
||||
```
|
||||
|
||||
|
||||
|
||||
## 2. Encrypting values
|
||||
|
||||
|
||||
Once we have our keys we can encrypt values:
|
||||
|
||||
```rust
|
||||
use concrete_integer::gen_keys;
|
||||
use concrete_shortint::parameters::PARAM_MESSAGE_2_CARRY_2;
|
||||
|
||||
fn main() {
|
||||
// We generate a set of client/server keys, using the default parameters:
|
||||
let num_block = 4;
|
||||
let (client_key, server_key) = gen_keys(&PARAM_MESSAGE_2_CARRY_2, num_block);
|
||||
|
||||
let msg1 = 128;
|
||||
let msg2 = 13;
|
||||
|
||||
// We use the client key to encrypt two messages:
|
||||
let ct_1 = client_key.encrypt(msg1);
|
||||
let ct_2 = client_key.encrypt(msg2);
|
||||
}
|
||||
```
|
||||
|
||||
## 3. Computing and decrypting
|
||||
|
||||
With our `server_key`, and encrypted values, we can now do an addition
|
||||
and then decrypt the result.
|
||||
|
||||
```rust
|
||||
use concrete_integer::gen_keys;
|
||||
use concrete_shortint::parameters::PARAM_MESSAGE_2_CARRY_2;
|
||||
|
||||
fn main() {
|
||||
// We generate a set of client/server keys, using the default parameters:
|
||||
let num_block = 4;
|
||||
let (client_key, server_key) = gen_keys(&PARAM_MESSAGE_2_CARRY_2, num_block);
|
||||
|
||||
let msg1 = 128;
|
||||
let msg2 = 13;
|
||||
|
||||
// message_modulus^vec_length
|
||||
let modulus = client_key.parameters().message_modulus.0.pow(num_block as u32) as u64;
|
||||
|
||||
// We use the client key to encrypt two messages:
|
||||
let ct_1 = client_key.encrypt(msg1);
|
||||
let ct_2 = client_key.encrypt(msg2);
|
||||
|
||||
// We use the server public key to execute an integer circuit:
|
||||
let ct_3 = server_key.unchecked_add(&ct_1, &ct_2);
|
||||
|
||||
// We use the client key to decrypt the output of the circuit:
|
||||
let output = client_key.decrypt(&ct_3);
|
||||
|
||||
assert_eq!(output, (msg1 + msg2) % modulus);
|
||||
}
|
||||
```
|
||||
49
concrete-float/docs/getting_started/installation.md
Normal file
49
concrete-float/docs/getting_started/installation.md
Normal file
@@ -0,0 +1,49 @@
|
||||
# Installation
|
||||
|
||||
## Cargo.toml
|
||||
|
||||
To use `concrete-integer`, you will need to add it to the list of dependencies
|
||||
of your project, by updating your `Cargo.toml` file.
|
||||
|
||||
```toml
|
||||
concrete-integer = "0.1.0"
|
||||
```
|
||||
|
||||
### Supported platforms
|
||||
|
||||
|
||||
As `concrete-integer` relies on `concrete-shortint`, which in turn relies on `concrete-core`,
|
||||
the support ted platforms supported are:
|
||||
- `x86_64 Linux`
|
||||
- `x86_64 macOS`.
|
||||
|
||||
Windows users can use `concrete-integer` through the `WSL`.
|
||||
|
||||
macOS users which have the newer M1 (`arm64`) devices can use `concrete-integer` by cross-compiling to
|
||||
`x86_64` and run their program with Rosetta.
|
||||
|
||||
First install the needed Rust toolchain:
|
||||
|
||||
```console
|
||||
# Install the macOS x86_64 toolchain (you only need to do this once)
|
||||
rustup toolchain install --force-non-host stable-x86_64-apple-darwin
|
||||
```
|
||||
|
||||
Then you can either:
|
||||
|
||||
- Manually specify the toolchain to use in each of the cargo commands:
|
||||
|
||||
For example:
|
||||
|
||||
```console
|
||||
cargo +stable-x86_64-apple-darwin build
|
||||
cargo +stable-x86_64-apple-darwin test
|
||||
```
|
||||
|
||||
- Or override the toolchain to use for the current project:
|
||||
|
||||
```console
|
||||
rustup override set stable-x86_64-apple-darwin
|
||||
# cargo will use the `stable-x86_64-apple-darwin` toolchain.
|
||||
cargo build
|
||||
```
|
||||
15
concrete-float/docs/getting_started/operation_list.md
Normal file
15
concrete-float/docs/getting_started/operation_list.md
Normal file
@@ -0,0 +1,15 @@
|
||||
# List of available operations
|
||||
|
||||
`concrete-integer` comes with a set of already implemented functions:
|
||||
|
||||
|
||||
- addition between two ciphertexts
|
||||
- addition between a ciphertext and an unencrypted scalar
|
||||
- multiplication of a ciphertext by an unencrypted scalar
|
||||
- bitwise shift `<<`, `>>`
|
||||
- bitwise and, or and xor
|
||||
- multiplication between two ciphertexts
|
||||
- subtraction of a ciphertext by another ciphertext
|
||||
- subtraction of a ciphertext by an unencrypted scalar
|
||||
- negation of a ciphertext
|
||||
|
||||
86
concrete-float/docs/getting_started/operation_types.md
Normal file
86
concrete-float/docs/getting_started/operation_types.md
Normal file
@@ -0,0 +1,86 @@
|
||||
# How Integers are represented
|
||||
|
||||
|
||||
In `concrete-integer`, the encrypted data is split amongst many ciphertexts
|
||||
encrypted using the `concrete-shortint` library.
|
||||
|
||||
This crate implements two ways to represent an integer:
|
||||
- the Radix representation
|
||||
- the CRT (Chinese Reminder Theorem) representation
|
||||
|
||||
## Radix based Integers
|
||||
The first possibility to represent a large integer is to use a radix-based decomposition on the
|
||||
plaintexts. Let $$B \in \mathbb{N}$$ be a basis such that the size of $$B$$ is smaller (or equal)
|
||||
to four bits.
|
||||
Then, an integer $$m \in \mathbb{N}$$ can be written as $$m = m_0 + m_1*B + m_2*B^2 + ... $$, where
|
||||
each $$m_i$$ is strictly smaller than $$B$$. Each $$m_i$$ is then independently encrypted. In
|
||||
the end, an Integer ciphertext is defined as a set of Shortint ciphertexts.
|
||||
|
||||
In practice, the definition of an Integer requires the basis and the number of blocks. This is
|
||||
done at the key creation step.
|
||||
```rust
|
||||
use concrete_integer::gen_keys;
|
||||
use concrete_shortint::parameters::PARAM_MESSAGE_2_CARRY_2;
|
||||
|
||||
fn main() {
|
||||
// We generate a set of client/server keys, using the default parameters:
|
||||
let num_block = 4;
|
||||
let (client_key, server_key) = gen_keys(&PARAM_MESSAGE_2_CARRY_2, num_block);
|
||||
}
|
||||
```
|
||||
|
||||
In this example, the keys are dedicated to Integers decomposed as four blocks using the basis
|
||||
$$B=2^2$$. Otherwise said, they allow to work on Integers modulus $$(2^2)^4 = 2^8$$.
|
||||
|
||||
|
||||
In this representation, the correctness of operations requires to propagate the carries
|
||||
between the ciphertext. This operation is costly since it relies on the computation of many
|
||||
programmable bootstrapping over Shortints.
|
||||
|
||||
|
||||
## CRT based Integers
|
||||
The second approach to represent large integers is based on the Chinese Remainder Theorem.
|
||||
In this cases, the basis $$B$$ is composed of several integers $$b_i$$, such that there are
|
||||
pairwise coprime, and each b_i has a size smaller than four bits. Then, the Integer will be
|
||||
defined modulus $$\prod b_i$$. For an integer $$m$$, its CRT decomposition is simply defined as
|
||||
$$m % b_0, m % b_1, ...$$. Each part is then encrypted as a Shortint ciphertext. In
|
||||
the end, an Integer ciphertext is defined as a set of Shortint ciphertexts.
|
||||
|
||||
An example of such a basis
|
||||
could be $$B = [2, 3, 5]$$. This means that the Integer is defined modulus $$2*3*5 = 30$$.
|
||||
|
||||
This representation has many advantages: no carry propagation is required, so that only cleaning
|
||||
the carry buffer of each ciphertexts is enough. This implies that operations can easily be
|
||||
parallelized. Moreover, it allows to efficiently compute PBS in the case where the function is
|
||||
CRT compliant.
|
||||
|
||||
A variant of the CRT is proposed, where each block might be associated to a different key couple.
|
||||
In the end, a keychain is required to the computations, but performance might be improved.
|
||||
|
||||
|
||||
|
||||
# Types of operations
|
||||
|
||||
|
||||
Much like `concrete-shortint`, the operations available via a `ServerKey` may come in different variants:
|
||||
|
||||
- operations that take their inputs as encrypted values.
|
||||
- scalar operations take at least one non-encrypted value as input.
|
||||
|
||||
For example, the addition has both variants:
|
||||
|
||||
- `ServerKey::unchecked_add` which takes two encrypted values and adds them.
|
||||
- `ServerKey::unchecked_scalar_add` which takes an encrypted value and a clear value (the
|
||||
so-called scalar) and adds them.
|
||||
|
||||
Each operation may come in different 'flavors':
|
||||
|
||||
- `unchecked`: Always does the operation, without checking if the result may exceed the capacity of
|
||||
the plaintext space.
|
||||
- `checked`: Checks are done before computing the operation, returning an error if operation
|
||||
cannot be done safely.
|
||||
- `smart`: Always does the operation, if the operation cannot be computed safely, the smart operation
|
||||
will propagate the carry buffer to make the operation possible.
|
||||
|
||||
Not all operations have these 3 flavors, as some of them are implemented in a way that the operation
|
||||
is always possible without ever exceeding the plaintext space capacity.
|
||||
6
concrete-float/docs/getting_started/parameters.md
Normal file
6
concrete-float/docs/getting_started/parameters.md
Normal file
@@ -0,0 +1,6 @@
|
||||
# Use of parameters
|
||||
|
||||
|
||||
`concrete-integer` does not come with its own set of parameters, instead it uses
|
||||
parameters from the `concrete-shortint` crate. Currently, only the parameters
|
||||
`PARAM_MESSAGE_{X}_CARRY_{X}` with `X` in [1,4] can be used in `concrete-integer`.
|
||||
47
concrete-float/docs/how_to/pbs.md
Normal file
47
concrete-float/docs/how_to/pbs.md
Normal file
@@ -0,0 +1,47 @@
|
||||
# The tree programmable bootstrapping
|
||||
|
||||
In `concrete-integer`, the user can evaluate any function on an encrypted ciphertext. To do so the user must first
|
||||
create a `treepbs key`, choose a function to evaluate and give them as parameters to the `tree programmable bootstrapping`.
|
||||
|
||||
Two versions of the tree pbs are implemented: the `standard` version that computes a result according to every encrypted
|
||||
bit (message and carry), and the `base` version that only takes into account the message bits of each block.
|
||||
|
||||
{% hint style="warning" %}
|
||||
|
||||
The `tree pbs` is quite slow, therefore its use is currently restricted to two and three blocks integer ciphertexts.
|
||||
|
||||
{% endhint %}
|
||||
|
||||
```rust
|
||||
use concrete_integer::gen_keys;
|
||||
use concrete_shortint::parameters::PARAM_MESSAGE_2_CARRY_2;
|
||||
use concrete_integer::treepbs::TreepbsKey;
|
||||
|
||||
fn main() {
|
||||
let num_block = 2;
|
||||
// Generate the client key and the server key:
|
||||
let (cks, sks) = gen_keys(&PARAM_MESSAGE_2_CARRY_2, num_block);
|
||||
|
||||
let msg: u64 = 27;
|
||||
let ct = cks.encrypt(msg);
|
||||
|
||||
// message_modulus^vec_length
|
||||
let modulus = cks.parameters().message_modulus.0.pow(2 as u32) as u64;
|
||||
|
||||
let treepbs_key = TreepbsKey::new(&cks);
|
||||
|
||||
let f = |x: u64| x * x;
|
||||
|
||||
// evaluate f
|
||||
let vec_res = treepbs_key.two_block_pbs(&sks, &ct, f);
|
||||
|
||||
// decryption
|
||||
let res = cks.decrypt(&vec_res);
|
||||
|
||||
let clear = f(msg) % modulus;
|
||||
assert_eq!(res, clear);
|
||||
}
|
||||
```
|
||||
|
||||
# The WOP programmable bootstrapping
|
||||
|
||||
8
concrete-float/docs/introduction.md
Normal file
8
concrete-float/docs/introduction.md
Normal file
@@ -0,0 +1,8 @@
|
||||
# Concrete-integer
|
||||
|
||||
## Introduction
|
||||
|
||||
`concrete-integer` is a Rust library (crate) based on `concrete-shortint`, this crate provides
|
||||
large precision integers by using multiple `shortint` ciphertexts.
|
||||
|
||||
The intended target audience for this library is people who are somewhat familiar with cryptography.
|
||||
120
concrete-float/docs/tutorials/circuit_evaluation.md
Normal file
120
concrete-float/docs/tutorials/circuit_evaluation.md
Normal file
@@ -0,0 +1,120 @@
|
||||
# Circuit evaluation
|
||||
|
||||
Let's try to do a circuit evaluation using the different flavours of operations we already introduced.
|
||||
For a very small circuit, the `unchecked` flavour may be enough to do the computation correctly.
|
||||
Otherwise, the `checked` and `smart` are the best options.
|
||||
|
||||
As an example, let's do a scalar multiplication, a subtraction and an addition.
|
||||
|
||||
|
||||
```rust
|
||||
use concrete_integer::gen_keys;
|
||||
use concrete_shortint::parameters::PARAM_MESSAGE_2_CARRY_2;
|
||||
|
||||
fn main() {
|
||||
let num_block = 4;
|
||||
let (client_key, server_key) = gen_keys(&PARAM_MESSAGE_2_CARRY_2, num_block);
|
||||
|
||||
let msg1 = 12;
|
||||
let msg2 = 11;
|
||||
let msg3 = 9;
|
||||
let scalar = 3;
|
||||
|
||||
// message_modulus^vec_length
|
||||
let modulus = client_key.parameters().message_modulus.0.pow(num_block as u32) as u64;
|
||||
|
||||
// We use the client key to encrypt two messages:
|
||||
let mut ct_1 = client_key.encrypt(msg1);
|
||||
let ct_2 = client_key.encrypt(msg2);
|
||||
let ct_3 = client_key.encrypt(msg2);
|
||||
|
||||
server_key.unchecked_small_scalar_mul_assign(&mut ct_1, scalar);
|
||||
|
||||
server_key.unchecked_sub_assign(&mut ct_1, &ct_2);
|
||||
|
||||
server_key.unchecked_add_assign(&mut ct_1, &ct_3);
|
||||
|
||||
// We use the client key to decrypt the output of the circuit:
|
||||
let output = client_key.decrypt(&ct_1);
|
||||
// The carry buffer has been overflowed, the result is not correct
|
||||
assert_ne!(output, ((msg1 * scalar as u64 - msg2) + msg3) % modulus as u64);
|
||||
}
|
||||
```
|
||||
|
||||
During this computation the carry buffer has been overflowed and as all the operations were `unchecked` the output
|
||||
may be incorrect.
|
||||
|
||||
If we redo this same circuit but using the `checked` flavour, a panic will occur.
|
||||
|
||||
```rust
|
||||
use concrete_integer::gen_keys;
|
||||
use concrete_shortint::parameters::PARAM_MESSAGE_2_CARRY_2;
|
||||
|
||||
fn main() {
|
||||
let num_block = 2;
|
||||
let (client_key, server_key) = gen_keys(&PARAM_MESSAGE_2_CARRY_2, num_block);
|
||||
|
||||
let msg1 = 12;
|
||||
let msg2 = 11;
|
||||
let msg3 = 9;
|
||||
let scalar = 3;
|
||||
|
||||
// message_modulus^vec_length
|
||||
let modulus = client_key.parameters().message_modulus.0.pow(num_block as u32) as u64;
|
||||
|
||||
// We use the client key to encrypt two messages:
|
||||
let mut ct_1 = client_key.encrypt(msg1);
|
||||
let ct_2 = client_key.encrypt(msg2);
|
||||
let ct_3 = client_key.encrypt(msg3);
|
||||
|
||||
let result = server_key.checked_small_scalar_mul_assign(&mut ct_1, scalar);
|
||||
assert!(result.is_ok());
|
||||
|
||||
let result = server_key.checked_sub_assign(&mut ct_1, &ct_2);
|
||||
assert!(result.is_err());
|
||||
|
||||
// We use the client key to decrypt the output of the circuit:
|
||||
// Only the scalar multiplication could be done
|
||||
let output = client_key.decrypt(&ct_1);
|
||||
assert_eq!(output, (msg1 * scalar) % modulus as u64);
|
||||
}
|
||||
```
|
||||
|
||||
Therefore the `checked` flavour permits to manually manage the overflow of the carry buffer
|
||||
by raising an error if the correctness is not guaranteed.
|
||||
|
||||
Lastly, using the `smart` flavour will output the correct result all the time. However, the computation may be slower
|
||||
as the carry buffer may be propagated during the computations.
|
||||
|
||||
```rust
|
||||
use concrete_integer::gen_keys;
|
||||
use concrete_shortint::parameters::PARAM_MESSAGE_2_CARRY_2;
|
||||
|
||||
fn main() {
|
||||
let num_block = 4;
|
||||
let (client_key, server_key) = gen_keys(&PARAM_MESSAGE_2_CARRY_2, num_block);
|
||||
|
||||
let msg1 = 12;
|
||||
let msg2 = 11;
|
||||
let msg3 = 9;
|
||||
let scalar = 3;
|
||||
|
||||
// message_modulus^vec_length
|
||||
let modulus = client_key.parameters().message_modulus.0.pow(num_block as u32) as u64;
|
||||
|
||||
// We use the client key to encrypt two messages:
|
||||
let mut ct_1 = client_key.encrypt(msg1);
|
||||
let mut ct_2 = client_key.encrypt(msg2);
|
||||
let mut ct_3 = client_key.encrypt(msg3);
|
||||
|
||||
server_key.smart_scalar_mul_assign(&mut ct_1, scalar);
|
||||
|
||||
server_key.smart_sub_assign(&mut ct_1, &mut ct_2);
|
||||
|
||||
server_key.smart_add_assign(&mut ct_1, &mut ct_3);
|
||||
|
||||
// We use the client key to decrypt the output of the circuit:
|
||||
let output = client_key.decrypt(&ct_1);
|
||||
assert_eq!(output, ((msg1 * scalar as u64 - msg2) + msg3) % modulus as u64);
|
||||
}
|
||||
```
|
||||
78
concrete-float/docs/tutorials/serialization.md
Normal file
78
concrete-float/docs/tutorials/serialization.md
Normal file
@@ -0,0 +1,78 @@
|
||||
# Serialization / Deserialization
|
||||
|
||||
As explained in the introduction, some types (`Serverkey`, `Ciphertext`) are meant to be shared
|
||||
with the server that does the computations.
|
||||
|
||||
The easiest way to send these data to a server is to use the serialization and deserialization features.
|
||||
concrete-integer uses the serde framework, serde's Serialize and Deserialize are implemented.
|
||||
|
||||
To be able to serialize our data, we need to pick a [data format], for our use case,
|
||||
[bincode] is a good choice, mainly because it is binary format.
|
||||
|
||||
|
||||
```toml
|
||||
# Cargo.toml
|
||||
|
||||
[dependencies]
|
||||
# ...
|
||||
bincode = "1.3.3"
|
||||
```
|
||||
|
||||
|
||||
```rust
|
||||
// main.rs
|
||||
|
||||
use bincode;
|
||||
|
||||
use std::io::Cursor;
|
||||
use concrete_integer::{gen_keys, ServerKey, Ciphertext};
|
||||
use concrete_shortint::parameters::PARAM_MESSAGE_2_CARRY_2;
|
||||
|
||||
|
||||
fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
// We generate a set of client/server keys, using the default parameters:
|
||||
let num_block = 4;
|
||||
let (client_key, server_key) = gen_keys(&PARAM_MESSAGE_2_CARRY_2, num_block);
|
||||
|
||||
let msg1 = 201;
|
||||
let msg2 = 12;
|
||||
|
||||
// message_modulus^vec_length
|
||||
let modulus = client_key.parameters().message_modulus.0.pow(num_block as u32) as u64;
|
||||
|
||||
let ct_1 = client_key.encrypt(msg1);
|
||||
let ct_2 = client_key.encrypt(msg2);
|
||||
|
||||
let mut serialized_data = Vec::new();
|
||||
bincode::serialize_into(&mut serialized_data, &server_key)?;
|
||||
bincode::serialize_into(&mut serialized_data, &ct_1)?;
|
||||
bincode::serialize_into(&mut serialized_data, &ct_2)?;
|
||||
|
||||
// Simulate sending serialized data to a server and getting
|
||||
// back the serialized result
|
||||
let serialized_result = server_function(&serialized_data)?;
|
||||
let result: Ciphertext = bincode::deserialize(&serialized_result)?;
|
||||
|
||||
let output = client_key.decrypt(&result);
|
||||
assert_eq!(output, (msg1 + msg2) % modulus);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
||||
fn server_function(serialized_data: &[u8]) -> Result<Vec<u8>, Box<dyn std::error::Error>> {
|
||||
let mut serialized_data = Cursor::new(serialized_data);
|
||||
let server_key: ServerKey = bincode::deserialize_from(&mut serialized_data)?;
|
||||
let ct_1: Ciphertext = bincode::deserialize_from(&mut serialized_data)?;
|
||||
let ct_2: Ciphertext = bincode::deserialize_from(&mut serialized_data)?;
|
||||
|
||||
let result = server_key.unchecked_add(&ct_1, &ct_2);
|
||||
|
||||
let serialized_result = bincode::serialize(&result)?;
|
||||
|
||||
Ok(serialized_result)
|
||||
}
|
||||
```
|
||||
|
||||
[serde]: https://crates.io/crates/serde
|
||||
[data format]: https://serde.rs/#data-formats
|
||||
[bincode]: https://crates.io/crates/bincode
|
||||
15
concrete-float/katex-header.html
Normal file
15
concrete-float/katex-header.html
Normal file
@@ -0,0 +1,15 @@
|
||||
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/katex@0.10.0/dist/katex.min.css" integrity="sha384-9eLZqc9ds8eNjO3TmqPeYcDj8n+Qfa4nuSiGYa6DjLNcv9BtN69ZIulL9+8CqC9Y" crossorigin="anonymous">
|
||||
<script src="https://cdn.jsdelivr.net/npm/katex@0.10.0/dist/katex.min.js" integrity="sha384-K3vbOmF2BtaVai+Qk37uypf7VrgBubhQreNQe9aGsz9lB63dIFiQVlJbr92dw2Lx" crossorigin="anonymous"></script>
|
||||
<script src="https://cdn.jsdelivr.net/npm/katex@0.10.0/dist/contrib/auto-render.min.js" integrity="sha384-kmZOZB5ObwgQnS/DuDg6TScgOiWWBiVt0plIRkZCmE6rDZGrEOQeHM5PcHi+nyqe" crossorigin="anonymous"></script>
|
||||
<script>
|
||||
document.addEventListener("DOMContentLoaded", function() {
|
||||
renderMathInElement(document.body, {
|
||||
delimiters: [
|
||||
{left: "$$", right: "$$", display: true},
|
||||
{left: "\\(", right: "\\)", display: false},
|
||||
{left: "$", right: "$", display: false},
|
||||
{left: "\\[", right: "\\]", display: true}
|
||||
]
|
||||
});
|
||||
});
|
||||
</script>
|
||||
BIN
concrete-float/long_run
Normal file
BIN
concrete-float/long_run
Normal file
Binary file not shown.
30
concrete-float/src/ciphertext/mod.rs
Normal file
30
concrete-float/src/ciphertext/mod.rs
Normal file
@@ -0,0 +1,30 @@
|
||||
//! This module implements the ciphertext structure containing an encryption of an integer message.
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tfhe::shortint;
|
||||
|
||||
/// Id to recognize the key used to encrypt a block.
|
||||
#[derive(Debug, PartialEq, Eq, Copy, Clone, Serialize, Deserialize)]
|
||||
pub struct KeyId(pub usize);
|
||||
|
||||
#[derive(Serialize, Clone, Deserialize, PartialEq, Eq, Debug)]
|
||||
pub struct Ciphertext {
|
||||
pub ct_vec_mantissa: Vec<shortint::ciphertext::Ciphertext>,
|
||||
pub ct_vec_exponent: Vec<shortint::ciphertext::Ciphertext>,
|
||||
pub ct_sign: shortint::ciphertext::Ciphertext,
|
||||
pub(crate) e_min: i64,
|
||||
}
|
||||
impl Ciphertext {
|
||||
/// Returns the slice of blocks that the ciphertext is composed of.
|
||||
pub fn mantissa_blocks(&self) -> &[shortint::Ciphertext] {
|
||||
&self.ct_vec_mantissa
|
||||
}
|
||||
pub fn exponent_blocks(&self) -> &[shortint::Ciphertext] {
|
||||
&self.ct_vec_exponent
|
||||
}
|
||||
pub fn sign(&self) -> &shortint::Ciphertext {
|
||||
&self.ct_sign
|
||||
}
|
||||
pub fn e_min(&self) -> &i64 {
|
||||
&self.e_min
|
||||
}
|
||||
}
|
||||
265
concrete-float/src/client_key/mod.rs
Normal file
265
concrete-float/src/client_key/mod.rs
Normal file
@@ -0,0 +1,265 @@
|
||||
//! This module implements the generation of the client secret keys, together with the
|
||||
//! encryption and decryption methods.
|
||||
|
||||
pub(crate) mod utils;
|
||||
|
||||
use crate::ciphertext::Ciphertext;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tfhe::shortint;
|
||||
use tfhe::shortint::{ClassicPBSParameters, WopbsParameters};
|
||||
pub use utils::radix_decomposition;
|
||||
|
||||
/// The number of ciphertexts in the vector.
|
||||
#[derive(Debug, PartialEq, Eq, Copy, Clone, Serialize, Deserialize)]
|
||||
pub struct VecLength(pub usize);
|
||||
|
||||
/// A structure containing the client key, which must be kept secret.
|
||||
#[derive(Serialize, Deserialize, PartialEq, Debug, Clone)]
|
||||
pub struct ClientKey {
|
||||
pub(crate) key: shortint::client_key::ClientKey,
|
||||
pub(crate) vector_length_mantissa: VecLength,
|
||||
pub(crate) vector_length_exponent: VecLength,
|
||||
}
|
||||
|
||||
impl ClientKey {
|
||||
/// Allocates and generates a client key.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// use concrete_float::client_key::ClientKey;
|
||||
/// use concrete_float::parameters::{PARAM_MESSAGE_2_CARRY_2_32, WOP_PARAM_MESSAGE_2_CARRY_2_32};
|
||||
/// use concrete_shortint::parameters::PARAM_MESSAGE_2_CARRY_2;
|
||||
///
|
||||
/// // Generate the client key associated to integers over 4 blocks
|
||||
/// // of messages with modulus over 2 bits
|
||||
/// let param = (PARAM_MESSAGE_2_CARRY_2_32, WOP_PARAM_MESSAGE_2_CARRY_2_32);
|
||||
/// let cks = ClientKey::new(param, 4, 1);
|
||||
/// ```
|
||||
pub fn new(
|
||||
parameter_set: (ClassicPBSParameters, WopbsParameters),
|
||||
size_mantissa: usize,
|
||||
size_exponent: usize,
|
||||
) -> Self {
|
||||
let key = shortint::ClientKey::new(parameter_set);
|
||||
Self {
|
||||
key,
|
||||
vector_length_mantissa: VecLength(size_mantissa),
|
||||
vector_length_exponent: VecLength(size_exponent),
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the parameters used by the client key.
|
||||
pub fn parameters(&self) -> shortint::parameters::ShortintParameterSet {
|
||||
self.key.parameters
|
||||
}
|
||||
|
||||
/// Encrypts a float message using the client key.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// use concrete_float::client_key::ClientKey;
|
||||
/// use concrete_float::parameters::{PARAM_MESSAGE_2_CARRY_2_32, WOP_PARAM_MESSAGE_2_CARRY_2_32};
|
||||
///
|
||||
/// let param = (PARAM_MESSAGE_2_CARRY_2_32, WOP_PARAM_MESSAGE_2_CARRY_2_32);
|
||||
/// let mut cks = ClientKey::new(param, 3, 1);
|
||||
///
|
||||
/// let msg = 1844640.;
|
||||
/// // Encryption of one message:
|
||||
/// let ct = cks.encrypt(msg);
|
||||
/// let res = cks.decrypt(&ct);
|
||||
///
|
||||
/// //approximation less than 0.1%
|
||||
/// assert_eq!(res, msg)
|
||||
/// ```
|
||||
pub fn encrypt(&self, message: f64) -> Ciphertext {
|
||||
let ct_sign = self.encrypt_sign(message);
|
||||
|
||||
let log_msg_modulus = f64::log2(self.parameters().message_modulus().0 as f64) as usize;
|
||||
let e_min = -((1 << (self.vector_length_exponent.0 * log_msg_modulus - 1)) as i64)
|
||||
- (self.vector_length_mantissa.0 as i64 - 1);
|
||||
if message == 0. {
|
||||
let exponent = 0;
|
||||
let mantissa = 0.0;
|
||||
let ct_vec_mantissa = self.encrypt_mantissa(mantissa as u64);
|
||||
let ct_vec_exponent = self.encrypt_exponent(exponent as u64);
|
||||
Ciphertext {
|
||||
ct_vec_mantissa,
|
||||
ct_vec_exponent,
|
||||
ct_sign,
|
||||
e_min,
|
||||
}
|
||||
} else {
|
||||
let length_mantissa = self.vector_length_mantissa.0;
|
||||
let log_message_modulus =
|
||||
f64::log2(self.parameters().message_modulus().0 as f64) as usize;
|
||||
|
||||
let value_exponent = log_message_modulus as u64;
|
||||
let mut exponent = e_min.abs();
|
||||
let mut cpy_message = message.abs();
|
||||
while cpy_message < (1_u128 << (length_mantissa * log_message_modulus)) as f64 {
|
||||
cpy_message *= (1 << value_exponent) as f64;
|
||||
exponent -= 1;
|
||||
}
|
||||
while cpy_message >= (1_u128 << (length_mantissa * log_message_modulus)) as f64 {
|
||||
cpy_message /= (1 << value_exponent) as f64;
|
||||
exponent += 1;
|
||||
}
|
||||
//TODO
|
||||
if exponent >= (1 << (log_message_modulus * self.vector_length_exponent.0) as i64) {
|
||||
println!("encrypt overflow");
|
||||
}
|
||||
if exponent < 0 {
|
||||
for _ in 0..exponent.abs() {
|
||||
cpy_message /= (1 << value_exponent) as f64;
|
||||
}
|
||||
exponent = 0;
|
||||
//panic!()
|
||||
}
|
||||
let mantissa = cpy_message.round() as u64;
|
||||
let ct_vec_mantissa = self.encrypt_mantissa(mantissa);
|
||||
let ct_vec_exponent = self.encrypt_exponent(exponent as u64);
|
||||
Ciphertext {
|
||||
ct_vec_mantissa,
|
||||
ct_vec_exponent,
|
||||
ct_sign,
|
||||
e_min,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn encrypt_sign(&self, message: f64) -> shortint::ciphertext::Ciphertext {
|
||||
let sign: u64;
|
||||
if message >= 0. {
|
||||
sign = 0;
|
||||
} else {
|
||||
sign = 1
|
||||
}
|
||||
self.key.encrypt_without_padding(
|
||||
sign * (self.key.parameters.message_modulus().0 * self.key.parameters.carry_modulus().0
|
||||
/ 2) as u64,
|
||||
)
|
||||
}
|
||||
|
||||
fn encrypt_mantissa(&self, mantissa: u64) -> Vec<shortint::Ciphertext> {
|
||||
let mut ct_vec_mantissa: Vec<shortint::ciphertext::Ciphertext> = Vec::new();
|
||||
let mut power = 1_u128;
|
||||
let message_modulus = self.parameters().message_modulus().0 as u128;
|
||||
for _ in 0..self.vector_length_mantissa.0 {
|
||||
let mut decomp = mantissa as u128 & ((message_modulus - 1) * power);
|
||||
decomp /= power;
|
||||
|
||||
// encryption
|
||||
let ct = self.key.encrypt(decomp as u64);
|
||||
ct_vec_mantissa.push(ct);
|
||||
//modulus to the power i
|
||||
power *= message_modulus;
|
||||
}
|
||||
ct_vec_mantissa
|
||||
}
|
||||
|
||||
fn encrypt_exponent(&self, exponent: u64) -> Vec<shortint::Ciphertext> {
|
||||
let mut ct_vec_exponent: Vec<shortint::ciphertext::Ciphertext> = Vec::new();
|
||||
let mut power = 1_u64;
|
||||
let message_modulus = self.parameters().message_modulus().0 as u64;
|
||||
for _ in 0..self.vector_length_exponent.0 {
|
||||
let mut decomp = exponent as u64 & ((message_modulus - 1) * power);
|
||||
decomp /= power;
|
||||
|
||||
// encryption
|
||||
let ct = self.key.encrypt(decomp);
|
||||
ct_vec_exponent.push(ct);
|
||||
//modulus to the power i
|
||||
power *= message_modulus;
|
||||
}
|
||||
ct_vec_exponent
|
||||
}
|
||||
|
||||
/// Decrypts a ciphertext encrypting an float message
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// use concrete_float::client_key::ClientKey;
|
||||
/// use concrete_float::parameters::{PARAM_MESSAGE_2_CARRY_2_32, WOP_PARAM_MESSAGE_2_CARRY_2_32};
|
||||
///
|
||||
/// let param = (PARAM_MESSAGE_2_CARRY_2_32, WOP_PARAM_MESSAGE_2_CARRY_2_32);
|
||||
/// let mut cks = ClientKey::new(param, 3, 1);
|
||||
///
|
||||
/// let msg = 1844640.;
|
||||
/// // Encryption of one message:
|
||||
/// let ct = cks.encrypt(msg);
|
||||
/// let res = cks.decrypt(&ct);
|
||||
///
|
||||
/// //approximation less than 0.1%
|
||||
/// assert_eq!(res, msg)
|
||||
/// ```
|
||||
pub fn decrypt(&self, ctxt: &Ciphertext) -> f64 {
|
||||
let log_message_modulus = f64::log2(self.parameters().message_modulus().0 as f64) as usize;
|
||||
let value_exponent = log_message_modulus as i64;
|
||||
|
||||
let mut mantissa = self.decrypt_mantissa(&ctxt.ct_vec_mantissa) as f64;
|
||||
let mut exponent = self.decrypt_exponent(&ctxt.ct_vec_exponent) as i64;
|
||||
let sign = self.decrypt_sign(&ctxt.ct_sign);
|
||||
|
||||
exponent += ctxt.e_min;
|
||||
if exponent > 0 {
|
||||
for _ in 0..exponent.abs() {
|
||||
mantissa *= (1_u128 << value_exponent) as f64
|
||||
}
|
||||
} else {
|
||||
for _ in 0..exponent.abs() {
|
||||
mantissa /= (1_u128 << value_exponent) as f64
|
||||
}
|
||||
}
|
||||
|
||||
let res;
|
||||
if sign == 1 {
|
||||
res = -mantissa
|
||||
} else {
|
||||
res = mantissa
|
||||
}
|
||||
res
|
||||
}
|
||||
|
||||
pub fn decrypt_mantissa(&self, ctxt: &Vec<shortint::Ciphertext>) -> u128 {
|
||||
let mut result = 0_u128;
|
||||
let mut shift = 1_u128;
|
||||
for c_i in ctxt.iter() {
|
||||
//decrypt the component i of the integer and multiply it by the radix product
|
||||
let tmp = (self.key.decrypt_message_and_carry(c_i) as u128).wrapping_mul(shift);
|
||||
|
||||
// update the result
|
||||
result = result.wrapping_add(tmp as u128);
|
||||
|
||||
// update the shift for the next iteration
|
||||
shift = shift.wrapping_mul(self.parameters().message_modulus().0 as u128);
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
pub fn decrypt_exponent(&self, ctxt: &Vec<shortint::Ciphertext>) -> u64 {
|
||||
let mut result = 0_u64;
|
||||
let mut shift = 1_u64;
|
||||
for c_i in ctxt.iter() {
|
||||
//decrypt the component i of the integer and multiply it by the radix product
|
||||
let tmp = self.key.decrypt_message_and_carry(c_i).wrapping_mul(shift);
|
||||
|
||||
// update the result
|
||||
result = result.wrapping_add(tmp);
|
||||
|
||||
// update the shift for the next iteration
|
||||
shift = shift.wrapping_mul(self.parameters().message_modulus().0 as u64);
|
||||
}
|
||||
result
|
||||
}
|
||||
|
||||
pub fn decrypt_sign(&self, ctxt: &shortint::Ciphertext) -> u64 {
|
||||
let result = self.key.decrypt_message_and_carry_without_padding(ctxt);
|
||||
result
|
||||
/ (self.key.parameters.message_modulus().0 * self.key.parameters.carry_modulus().0 / 2)
|
||||
as u64
|
||||
}
|
||||
}
|
||||
51
concrete-float/src/client_key/utils.rs
Normal file
51
concrete-float/src/client_key/utils.rs
Normal file
@@ -0,0 +1,51 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, Copy, Clone, Serialize, Deserialize)]
|
||||
pub struct RadixDecomposition {
|
||||
pub msg_space: usize,
|
||||
pub block_number: usize,
|
||||
}
|
||||
|
||||
/// Computes possible radix decompositions
|
||||
///
|
||||
/// Takes the number of bit of the message space as input and output a vector containing all the
|
||||
/// correct
|
||||
/// possible block decomposition assuming the same message space for all blocks.
|
||||
/// Lower and upper bounds define the minimal and maximal space to be considered
|
||||
/// Example: 6,2,4 -> [ [2,3], [3,2]] : [msg_space = 2 bits, block_number = 3]
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// use concrete_float::client_key::radix_decomposition;
|
||||
/// let input_space = 16; //
|
||||
/// let min = 2;
|
||||
/// let max = 4;
|
||||
/// let decomp = radix_decomposition(input_space, min, max);
|
||||
///
|
||||
/// // Check that 3 possible radix decompositions are provided
|
||||
/// assert_eq!(decomp.len(), 3);
|
||||
/// ```
|
||||
pub fn radix_decomposition(
|
||||
input_space: usize,
|
||||
min_space: usize,
|
||||
max_space: usize,
|
||||
) -> Vec<RadixDecomposition> {
|
||||
let mut out: Vec<RadixDecomposition> = vec![];
|
||||
let mut max = max_space;
|
||||
if max_space > input_space {
|
||||
max = input_space;
|
||||
}
|
||||
for msg_space in min_space..max + 1 {
|
||||
let mut block_number = input_space / msg_space;
|
||||
//Manual ceil of the division
|
||||
if input_space % msg_space != 0 {
|
||||
block_number += 1;
|
||||
}
|
||||
out.push(RadixDecomposition {
|
||||
msg_space,
|
||||
block_number,
|
||||
})
|
||||
}
|
||||
out
|
||||
}
|
||||
41
concrete-float/src/keycache.rs
Normal file
41
concrete-float/src/keycache.rs
Normal file
@@ -0,0 +1,41 @@
|
||||
use std::fs::File;
|
||||
use std::io::{BufReader, BufWriter};
|
||||
use std::path::Path;
|
||||
use lazy_static::lazy_static;
|
||||
|
||||
use crate::{ClientKey, ServerKey};
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct FloatKeyCache;
|
||||
|
||||
lazy_static! {
|
||||
pub static ref KEY_CACHE: FloatKeyCache = FloatKeyCache::default();
|
||||
}
|
||||
|
||||
pub fn get_sks(str: &str) -> ServerKey {
|
||||
let fiptr = format!("key/sks_key/{}", str);
|
||||
let filepath = Path::new(&fiptr);
|
||||
let file = BufReader::new(File::open(filepath).unwrap());
|
||||
let saved_key: ServerKey = bincode::deserialize_from(file).unwrap();
|
||||
saved_key
|
||||
}
|
||||
|
||||
pub fn get_cks(str: &str) -> ClientKey {
|
||||
let fiptr = format!("key/cks_key/{}", str);
|
||||
let filepath = Path::new(&fiptr);
|
||||
let file = BufReader::new(File::open(filepath).unwrap());
|
||||
let saved_key: ClientKey = bincode::deserialize_from(file).unwrap();
|
||||
saved_key
|
||||
}
|
||||
|
||||
pub fn save_sks(key: ServerKey, str: &str) {
|
||||
let filepath = format!("key/sks_key/{}", str);
|
||||
let file = BufWriter::new(File::create(filepath).unwrap());
|
||||
bincode::serialize_into(file, &key).unwrap();
|
||||
}
|
||||
|
||||
pub fn save_cks(key: ClientKey ,str: &str) {
|
||||
let filepath = format!("key/cks_key/{}", str);
|
||||
let file = BufWriter::new(File::create(filepath).unwrap());
|
||||
bincode::serialize_into(file, &key).unwrap();
|
||||
}
|
||||
92
concrete-float/src/lib.rs
Executable file
92
concrete-float/src/lib.rs
Executable file
@@ -0,0 +1,92 @@
|
||||
/*
|
||||
#![allow(clippy::excessive_precision)]
|
||||
//! Welcome the the `concrete-integer` documentation!
|
||||
//!
|
||||
//! # Description
|
||||
//!
|
||||
//! This library makes it possible to execute modular operations over encrypted integer.
|
||||
//!
|
||||
//! It allows to execute an integer circuit on an untrusted server because both circuit inputs
|
||||
//! outputs are kept private.
|
||||
//!
|
||||
//! Data are encrypted on the client side, before being sent to the server.
|
||||
//! On the server side every computation is performed on ciphertexts
|
||||
//!
|
||||
//! # Quick Example
|
||||
//!
|
||||
//! The following piece of code shows how to generate keys and run a integer circuit
|
||||
//! homomorphically.
|
||||
//!
|
||||
//! ```rust
|
||||
//! use concrete_float::gen_keys;
|
||||
//! use concrete_shortint::parameters::PARAM_MESSAGE_2_CARRY_2;
|
||||
//!
|
||||
//! //4 blocks for the radix decomposition
|
||||
//! let number_of_blocks = 4;
|
||||
//! // Modulus = (2^2)*4 = 2^8 (from the parameters chosen and the number of blocks
|
||||
//! let modulus = 1 << 8;
|
||||
//!
|
||||
//! // Generation of the client/server keys, using the default parameters:
|
||||
//! let (mut client_key, mut server_key) = gen_keys(&PARAM_MESSAGE_2_CARRY_2, number_of_blocks);
|
||||
//!
|
||||
//! let msg1 = 153;
|
||||
//! let msg2 = 125;
|
||||
//!
|
||||
//! // Encryption of two messages using the client key:
|
||||
//! let ct_1 = client_key.encrypt(msg1);
|
||||
//! let ct_2 = client_key.encrypt(msg2);
|
||||
//!
|
||||
//! // Homomorphic evaluation of an integer circuit (here, an addition) using the server key:
|
||||
//! let ct_3 = server_key.unchecked_add(&ct_1, &ct_2);
|
||||
//!
|
||||
//! // Decryption of the ciphertext using the client key:
|
||||
//! let output = client_key.decrypt(&ct_3);
|
||||
//! assert_eq!(output, (msg1 + msg2) % modulus);
|
||||
//! ```
|
||||
//!
|
||||
//! # Warning
|
||||
//! This uses cryptographic parameters from the `concrete-shortint` crates.
|
||||
//! Currently, the radix approach is only compatible with parameter sets such
|
||||
//! that the message and carry buffers have the same size.
|
||||
extern crate core;
|
||||
*/
|
||||
extern crate core;
|
||||
|
||||
pub mod ciphertext;
|
||||
pub mod client_key;
|
||||
pub mod parameters;
|
||||
pub mod server_key;
|
||||
use crate::client_key::ClientKey;
|
||||
use crate::server_key::ServerKey;
|
||||
//pub mod keycache;
|
||||
//pub mod wopbs;
|
||||
#[cfg(doctest)]
|
||||
//mod test_user_docs;
|
||||
use tfhe::shortint;
|
||||
use tfhe::shortint;
|
||||
|
||||
/// Generate a couple of client and server keys with given parameters
|
||||
///
|
||||
/// * the client key is used to encrypt and decrypt and has to be kept secret;
|
||||
/// * the server key is used to perform homomorphic operations on the server side and it is meant to
|
||||
/// be published (the client sends it to the server).
|
||||
///
|
||||
/// ```rust
|
||||
/// use concrete_float::gen_keys;
|
||||
/// use concrete_shortint::parameters::DEFAULT_PARAMETERS;
|
||||
///
|
||||
/// let size_mantissa = 4;
|
||||
/// let size_exponent = 1;
|
||||
/// ```
|
||||
pub fn gen_keys(
|
||||
parameters_set: shortint::ClassicPBSParameters,
|
||||
parameters_set_wopbs: shortint::WopbsParameters,
|
||||
size_mantissa: usize,
|
||||
size_exponent: usize,
|
||||
) -> (ClientKey, ServerKey) {
|
||||
let params = (parameters_set, parameters_set_wopbs);
|
||||
let cks = ClientKey::new(params, size_mantissa, size_exponent);
|
||||
let sks = ServerKey::new(&cks);
|
||||
|
||||
(cks, sks)
|
||||
}
|
||||
1057
concrete-float/src/parameters/mod.rs
Normal file
1057
concrete-float/src/parameters/mod.rs
Normal file
File diff suppressed because it is too large
Load Diff
158
concrete-float/src/server_key/add.rs
Normal file
158
concrete-float/src/server_key/add.rs
Normal file
@@ -0,0 +1,158 @@
|
||||
use crate::server_key::Ciphertext;
|
||||
use crate::ServerKey;
|
||||
use rayon::prelude::*;
|
||||
use tfhe::shortint;
|
||||
|
||||
//use crate::keycache::{get_sks, get_cks};
|
||||
|
||||
impl ServerKey {
|
||||
/// Computes homomorphically an addition between two ciphertexts encrypting integer values.
|
||||
///
|
||||
/// This function computes the operation without checking if it exceeds the capacity of the
|
||||
/// ciphertext.
|
||||
///
|
||||
/// The result is returned as a new ciphertext.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// ```
|
||||
pub fn unchecked_add_mantissa(
|
||||
&self,
|
||||
ct_left: &Ciphertext,
|
||||
ct_right: &Ciphertext,
|
||||
) -> Ciphertext {
|
||||
let mut result = ct_left.clone();
|
||||
self.unchecked_add_mantissa_assign(&mut result, ct_right);
|
||||
result
|
||||
}
|
||||
|
||||
/// Computes homomorphically an addition between two ciphertexts encrypting integer values.
|
||||
///
|
||||
/// This function computes the operation without checking if it exceeds the capacity of the
|
||||
/// ciphertext.
|
||||
///
|
||||
/// The result is assigned to the `ct_left` ciphertext.
|
||||
/// ```rust
|
||||
/// ```
|
||||
pub fn unchecked_add_mantissa_assign(&self, ct_left: &mut Ciphertext, ct_right: &Ciphertext) {
|
||||
for (ct_left_i, ct_right_i) in ct_left
|
||||
.ct_vec_mantissa
|
||||
.iter_mut()
|
||||
.zip(ct_right.ct_vec_mantissa.iter())
|
||||
{
|
||||
self.key.unchecked_add_assign(ct_left_i, ct_right_i);
|
||||
}
|
||||
}
|
||||
|
||||
/// we suppose that the mantissa are align
|
||||
pub fn add_mantissa(&self, ct_left: &mut Ciphertext, ct_right: &mut Ciphertext) {
|
||||
for (ct_left_i, ct_right_i) in ct_left
|
||||
.ct_vec_mantissa
|
||||
.iter_mut()
|
||||
.zip(ct_right.ct_vec_mantissa.iter())
|
||||
{
|
||||
self.key.unchecked_add_assign(ct_left_i, ct_right_i);
|
||||
}
|
||||
}
|
||||
|
||||
/// Verifies if ct1 and ct2 can be added together.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
///```rust
|
||||
/// ```
|
||||
pub fn is_add_possible(
|
||||
&self,
|
||||
ct_left: &[shortint::ciphertext::Ciphertext],
|
||||
ct_right: &[shortint::ciphertext::Ciphertext],
|
||||
) -> bool {
|
||||
for (ct_left_i, ct_right_i) in ct_left.iter().zip(ct_right.iter()) {
|
||||
if self.key.is_add_possible(ct_left_i, ct_right_i).is_err() {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
true
|
||||
}
|
||||
|
||||
pub fn add_total(&self, ct1: &Ciphertext, ct2: &Ciphertext) -> Ciphertext {
|
||||
let res_sign = self.key.unchecked_add(&ct1.ct_sign, &ct2.ct_sign);
|
||||
let (mut ct1_aligned, mut ct2_aligned) = self.align_mantissa(&ct1, &ct2);
|
||||
let ct_sub = self.sub_mantissa(&ct1_aligned, &ct2_aligned);
|
||||
self.add_mantissa(&mut ct1_aligned, &mut ct2_aligned);
|
||||
|
||||
// message space == 0 because the sign is on the padding bit
|
||||
let ggsw = self.ggsw_ks_cbs(&res_sign, 0); // let ggsw = self.wopbs_key.extract_one_bit_cbs(&self.key, &res_sign, 63);
|
||||
let mut res = self.cmuxes_full(&ct1_aligned, &ct_sub, &ggsw);
|
||||
self.clean_degree(&mut res);
|
||||
res
|
||||
}
|
||||
|
||||
/// Computes homomorphically an addition between two ciphertexts encrypting integer values.
|
||||
///
|
||||
/// This function computes the operation without checking if it exceeds the capacity of the
|
||||
/// ciphertext.
|
||||
///
|
||||
/// The result is returned as a new ciphertext.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// ```
|
||||
pub fn unchecked_add_mantissa_parallelized(
|
||||
&self,
|
||||
ct_left: &Ciphertext,
|
||||
ct_right: &Ciphertext,
|
||||
) -> Ciphertext {
|
||||
let mut result = ct_left.clone();
|
||||
self.unchecked_add_mantissa_assign_parallelized(&mut result, ct_right);
|
||||
result
|
||||
}
|
||||
|
||||
/// Computes homomorphically an addition between two ciphertexts encrypting integer values.
|
||||
///
|
||||
/// This function computes the operation without checking if it exceeds the capacity of the
|
||||
/// ciphertext.
|
||||
///
|
||||
/// The result is assigned to the `ct_left` ciphertext.
|
||||
/// ```rust
|
||||
/// ```
|
||||
pub fn unchecked_add_mantissa_assign_parallelized(
|
||||
&self,
|
||||
ct_left: &mut Ciphertext,
|
||||
ct_right: &Ciphertext,
|
||||
) {
|
||||
ct_left
|
||||
.ct_vec_mantissa
|
||||
.par_iter_mut()
|
||||
.zip(ct_right.ct_vec_mantissa.par_iter())
|
||||
.for_each(|(ct_left_i, ct_right_i)| {
|
||||
self.key.unchecked_add_assign(ct_left_i, ct_right_i);
|
||||
});
|
||||
}
|
||||
|
||||
/// we suppose that the mantissa are align
|
||||
pub fn add_mantissa_parallelized(&self, ct_left: &mut Ciphertext, ct_right: &mut Ciphertext) {
|
||||
// The operation is too small to be worth parallelizing
|
||||
ct_left
|
||||
.ct_vec_mantissa
|
||||
.iter_mut()
|
||||
.zip(ct_right.ct_vec_mantissa.iter())
|
||||
.for_each(|(ct_left_i, ct_right_i)| {
|
||||
self.key.unchecked_add_assign(ct_left_i, ct_right_i);
|
||||
});
|
||||
}
|
||||
|
||||
pub fn add_total_parallelized(&self, ct1: &Ciphertext, ct2: &Ciphertext) -> Ciphertext {
|
||||
let res_sign = self.key.unchecked_add(&ct1.ct_sign, &ct2.ct_sign);
|
||||
let (mut ct1_aligned, mut ct2_aligned) = self.align_mantissa_parallelized(&ct1, &ct2);
|
||||
let ct_sub = self.sub_mantissa_parallelized(&ct1_aligned, &ct2_aligned);
|
||||
self.add_mantissa_parallelized(&mut ct1_aligned, &mut ct2_aligned);
|
||||
// message space == 0 because the sign is on the padding bit
|
||||
let ggsw = self.ggsw_ks_cbs_parallelized(&res_sign, 0); // let ggsw = self.wopbs_key.extract_one_bit_cbs(&self.key, &res_sign, 63);
|
||||
let mut res = self.cmuxes_full_parallelized(&ct1_aligned, &ct_sub, &ggsw);
|
||||
self.clean_degree_parallelized(&mut res);
|
||||
|
||||
res
|
||||
}
|
||||
}
|
||||
162
concrete-float/src/server_key/align_mantissa.rs
Normal file
162
concrete-float/src/server_key/align_mantissa.rs
Normal file
@@ -0,0 +1,162 @@
|
||||
use crate::server_key::Ciphertext;
|
||||
use crate::ServerKey;
|
||||
use aligned_vec::ABox;
|
||||
use rayon::prelude::*;
|
||||
use tfhe::core_crypto::fft_impl::fft64::c64;
|
||||
use tfhe::core_crypto::fft_impl::fft64::crypto::ggsw::FourierGgswCiphertext;
|
||||
use tfhe::shortint;
|
||||
|
||||
impl ServerKey {
|
||||
// align the two mantissas of to floating points
|
||||
pub fn align_mantissa(
|
||||
&self,
|
||||
ct_left: &Ciphertext,
|
||||
ct_right: &Ciphertext,
|
||||
) -> (Ciphertext, Ciphertext) {
|
||||
let (ct_res, sign) = self.sub(&ct_left.ct_vec_exponent, &ct_right.ct_vec_exponent);
|
||||
let (vec_ggsw, sign_ggsw) =
|
||||
self.create_vec_ggsw_after_sub(&ct_res, &sign, ct_left.ct_vec_mantissa.len());
|
||||
let mut need_to_be_aligned = self.cmuxes(
|
||||
&ct_left.ct_vec_mantissa,
|
||||
&ct_right.ct_vec_mantissa,
|
||||
&sign_ggsw,
|
||||
);
|
||||
let aligned_exp = self.cmuxes(
|
||||
&ct_right.ct_vec_exponent,
|
||||
&ct_left.ct_vec_exponent,
|
||||
&sign_ggsw,
|
||||
);
|
||||
let aligned = self.cmux_tree_mantissa(&mut need_to_be_aligned, &vec_ggsw);
|
||||
let ct_left_aligned = self.cmuxes(&aligned, &ct_left.ct_vec_mantissa, &sign_ggsw);
|
||||
let ct_right_aligned = self.cmuxes(&ct_right.ct_vec_mantissa, &aligned, &sign_ggsw);
|
||||
let new_left = Ciphertext {
|
||||
ct_vec_mantissa: ct_left_aligned,
|
||||
ct_vec_exponent: aligned_exp.clone(),
|
||||
ct_sign: ct_left.ct_sign.clone(),
|
||||
e_min: ct_left.e_min,
|
||||
};
|
||||
let new_right = Ciphertext {
|
||||
ct_vec_mantissa: ct_right_aligned,
|
||||
ct_vec_exponent: aligned_exp,
|
||||
ct_sign: ct_right.ct_sign.clone(),
|
||||
e_min: ct_right.e_min,
|
||||
};
|
||||
(new_left, new_right)
|
||||
}
|
||||
|
||||
pub fn align_mantissa_parallelized(
|
||||
&self,
|
||||
ct_left: &Ciphertext,
|
||||
ct_right: &Ciphertext,
|
||||
) -> (Ciphertext, Ciphertext) {
|
||||
let (mut ct_res, sign) =
|
||||
self.abs_diff_parallelized(&ct_left.ct_vec_exponent, &ct_right.ct_vec_exponent);
|
||||
|
||||
let (vec_ggsw, sign_ggsw) = self.create_vec_ggsw_after_sub_parallelized(
|
||||
&mut ct_res,
|
||||
&sign,
|
||||
ct_left.ct_vec_mantissa.len(),
|
||||
);
|
||||
|
||||
let (mut need_to_be_aligned, aligned_exp) = rayon::join(
|
||||
|| {
|
||||
self.cmuxes_parallelized(
|
||||
&ct_left.ct_vec_mantissa,
|
||||
&ct_right.ct_vec_mantissa,
|
||||
&sign_ggsw,
|
||||
)
|
||||
},
|
||||
|| {
|
||||
self.cmuxes_parallelized(
|
||||
&ct_right.ct_vec_exponent,
|
||||
&ct_left.ct_vec_exponent,
|
||||
&sign_ggsw,
|
||||
)
|
||||
},
|
||||
);
|
||||
let aligned = self.cmux_tree_mantissa_parallelized(&mut need_to_be_aligned, &vec_ggsw);
|
||||
let (ct_left_aligned, ct_right_aligned) = rayon::join(
|
||||
|| self.cmuxes_parallelized(&aligned, &ct_left.ct_vec_mantissa, &sign_ggsw),
|
||||
|| self.cmuxes_parallelized(&ct_right.ct_vec_mantissa, &aligned, &sign_ggsw),
|
||||
);
|
||||
let new_left = Ciphertext {
|
||||
ct_vec_mantissa: ct_left_aligned,
|
||||
ct_vec_exponent: aligned_exp.clone(),
|
||||
ct_sign: ct_left.ct_sign.clone(),
|
||||
e_min: ct_left.e_min,
|
||||
};
|
||||
let new_right = Ciphertext {
|
||||
ct_vec_mantissa: ct_right_aligned,
|
||||
ct_vec_exponent: aligned_exp,
|
||||
ct_sign: ct_right.ct_sign.clone(),
|
||||
e_min: ct_right.e_min,
|
||||
};
|
||||
(new_left, new_right)
|
||||
}
|
||||
|
||||
pub fn create_vec_ggsw_after_sub(
|
||||
&self,
|
||||
ct_res: &Vec<shortint::ciphertext::Ciphertext>,
|
||||
sign: &shortint::ciphertext::Ciphertext,
|
||||
len_mantissa: usize,
|
||||
) -> (
|
||||
Vec<FourierGgswCiphertext<ABox<[c64]>>>,
|
||||
FourierGgswCiphertext<ABox<[c64]>>,
|
||||
) {
|
||||
let msg_modulus = self.wopbs_key.param.message_modulus.0 as u64;
|
||||
let car_modulus = self.wopbs_key.param.carry_modulus.0 as u64;
|
||||
let msg_space = (msg_modulus * car_modulus) as usize;
|
||||
|
||||
let mut ct_res = ct_res.clone();
|
||||
self.full_propagate_exponent(&mut ct_res);
|
||||
let mut vec_ggsw = Vec::new();
|
||||
for i in 0..ct_res.len() {
|
||||
if len_mantissa < ((f64::log2(msg_modulus as f64) as usize) * i) {
|
||||
let mut ggsw = vec![self.ggsw_pbs_ks_cbs(&ct_res[i], msg_space)];
|
||||
ggsw.append(&mut vec_ggsw);
|
||||
vec_ggsw = ggsw
|
||||
} else {
|
||||
let mut ggsw = self.extract_bit_cbs(&ct_res[i]);
|
||||
ggsw.append(&mut vec_ggsw);
|
||||
vec_ggsw = ggsw;
|
||||
}
|
||||
}
|
||||
// message space == 0 because the sign is on the padding bit
|
||||
let sign_ggsw = self.ggsw_ks_cbs(&sign, 0);
|
||||
(vec_ggsw, sign_ggsw)
|
||||
}
|
||||
|
||||
pub fn create_vec_ggsw_after_sub_parallelized(
|
||||
&self,
|
||||
ct_res: &mut [shortint::ciphertext::Ciphertext],
|
||||
sign: &shortint::ciphertext::Ciphertext,
|
||||
len_mantissa: usize,
|
||||
) -> (
|
||||
Vec<FourierGgswCiphertext<ABox<[c64]>>>,
|
||||
FourierGgswCiphertext<ABox<[c64]>>,
|
||||
) {
|
||||
let msg_modulus = self.wopbs_key.param.message_modulus.0 as u64;
|
||||
let car_modulus = self.wopbs_key.param.carry_modulus.0 as u64;
|
||||
let msg_space = (msg_modulus * car_modulus) as usize;
|
||||
|
||||
self.full_propagate_exponent_parallelized(ct_res);
|
||||
|
||||
let vec_ggsw: Vec<_> = ct_res
|
||||
.par_iter()
|
||||
.enumerate()
|
||||
.rev()
|
||||
.map(|(i, block)| {
|
||||
if (msg_modulus.ilog2() as usize * i) > len_mantissa {
|
||||
vec![self.is_block_non_zero_ggsw_pbs_ks_cbs_parallelized(&block, msg_space)]
|
||||
} else {
|
||||
self.extract_bit_cbs_parallelized(&block)
|
||||
}
|
||||
})
|
||||
.flatten()
|
||||
.collect();
|
||||
|
||||
// message space == 0 because the sign is on the padding bit
|
||||
let sign_ggsw = self.ggsw_ks_cbs_parallelized(&sign, 0);
|
||||
(vec_ggsw, sign_ggsw)
|
||||
}
|
||||
}
|
||||
53
concrete-float/src/server_key/division.rs
Normal file
53
concrete-float/src/server_key/division.rs
Normal file
@@ -0,0 +1,53 @@
|
||||
use crate::ciphertext::Ciphertext;
|
||||
use crate::server_key::ServerKey;
|
||||
use tfhe::integer::ciphertext::RadixCiphertext;
|
||||
use tfhe::integer::IntegerCiphertext;
|
||||
|
||||
impl ServerKey {
|
||||
pub fn division(&self, ct1: &Ciphertext, ct2: &Ciphertext) -> Ciphertext {
|
||||
let msg_modulus = self.wopbs_key.param.message_modulus.0 as u64;
|
||||
let log_msg_modulus = f64::log2(msg_modulus as f64) as u64;
|
||||
let len_vec_exp = ct1.ct_vec_exponent.len();
|
||||
let len_vec_man = ct1.ct_vec_mantissa.len();
|
||||
|
||||
let mut res = self.create_trivial_zero(
|
||||
ct1.ct_vec_mantissa.len(),
|
||||
ct1.ct_vec_exponent.len(),
|
||||
ct1.e_min,
|
||||
);
|
||||
let zero = self.create_trivial_zero(
|
||||
ct1.ct_vec_mantissa.len(),
|
||||
ct1.ct_vec_exponent.len(),
|
||||
ct1.e_min,
|
||||
);
|
||||
res.ct_sign = self.key.unchecked_add(&ct1.ct_sign, &ct2.ct_sign);
|
||||
res.ct_vec_exponent = ct1.ct_vec_exponent.clone();
|
||||
|
||||
let cst = ct1.e_min + len_vec_man as i64 - 1;
|
||||
for i in 0..len_vec_exp {
|
||||
let cst = (cst.abs() as u64) >> (log_msg_modulus * i as u64);
|
||||
self.key.unchecked_scalar_add_assign(
|
||||
&mut res.ct_vec_exponent[i],
|
||||
(cst % msg_modulus) as u8,
|
||||
);
|
||||
}
|
||||
let (res_exp, sign) = self.sub(&res.ct_vec_exponent, &ct2.ct_vec_exponent);
|
||||
res.ct_vec_exponent = res_exp;
|
||||
let mut cct1 = RadixCiphertext::from(ct1.ct_vec_mantissa.clone());
|
||||
let mut cct2 = RadixCiphertext::from(ct2.ct_vec_mantissa.clone());
|
||||
|
||||
let int_key = tfhe::integer::ServerKey::from_shortint_ex(self.key.clone());
|
||||
|
||||
int_key.extend_radix_with_trivial_zero_blocks_lsb_assign(&mut cct1, len_vec_man - 1);
|
||||
int_key.extend_radix_with_trivial_zero_blocks_msb_assign(&mut cct2, len_vec_man - 1);
|
||||
|
||||
let res_mantissa = int_key.unchecked_div_parallelized(&cct1, &cct2);
|
||||
|
||||
// message space == 0 because the sign is on the padding bit
|
||||
let sign_ggsw = self.ggsw_ks_cbs(&sign, 0);
|
||||
|
||||
res.ct_vec_mantissa = res_mantissa.blocks()[..len_vec_man].to_vec();
|
||||
res = self.cmuxes_full(&zero, &res, &sign_ggsw);
|
||||
res
|
||||
}
|
||||
}
|
||||
390
concrete-float/src/server_key/mod.rs
Normal file
390
concrete-float/src/server_key/mod.rs
Normal file
@@ -0,0 +1,390 @@
|
||||
//! Module with the definition of the ServerKey.
|
||||
//!
|
||||
//! This module implements the generation of the server public key, together with all the
|
||||
//! available homomorphic integer operations.
|
||||
mod add;
|
||||
mod align_mantissa;
|
||||
mod division;
|
||||
mod mul;
|
||||
mod relu;
|
||||
mod sigmoid;
|
||||
mod sub;
|
||||
mod tools;
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests;
|
||||
|
||||
use tfhe::shortint;
|
||||
|
||||
use crate::ciphertext::Ciphertext;
|
||||
use crate::client_key::ClientKey;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use shortint::ciphertext::{Degree, MaxDegree};
|
||||
|
||||
/// Error returned when the carry buffer is full.
|
||||
pub use shortint::CheckError;
|
||||
|
||||
/// A structure containing the server public key.
|
||||
///
|
||||
/// The server key is generated by the client and is meant to be published: the client
|
||||
/// sends it to the server so it can compute homomorphic integer circuits.
|
||||
#[derive(Serialize, Deserialize, Clone)]
|
||||
pub struct ServerKey {
|
||||
pub key: shortint::server_key::ServerKey,
|
||||
pub integer_key: tfhe::integer::server_key::ServerKey,
|
||||
pub wopbs_key: shortint::wopbs::WopbsKey,
|
||||
}
|
||||
|
||||
impl ServerKey {
|
||||
/// Generates a server key.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// use concrete_float::parameters::{PARAM_MESSAGE_2_CARRY_2_32, WOP_PARAM_MESSAGE_2_CARRY_2_32};
|
||||
/// use concrete_float::{ClientKey, ServerKey};
|
||||
/// //mantissa and exponent defined over 4 blocks ///
|
||||
/// let size_mantissa = 4;
|
||||
/// let size_exponent = 2;
|
||||
///
|
||||
/// // Generate the client key:
|
||||
/// let param = (PARAM_MESSAGE_2_CARRY_2_32, WOP_PARAM_MESSAGE_2_CARRY_2_32);
|
||||
/// let cks = ClientKey::new(param, size_mantissa, size_exponent);
|
||||
///
|
||||
/// // Generate the server key:
|
||||
/// let sks = ServerKey::new(&cks);
|
||||
/// ```
|
||||
pub fn new(cks: &ClientKey) -> ServerKey {
|
||||
// It should remain just enough space to add a carry
|
||||
let max =
|
||||
(cks.key.parameters.message_modulus().0 - 1) * cks.key.parameters.carry_modulus().0 - 1;
|
||||
let key =
|
||||
shortint::server_key::ServerKey::new_with_max_degree(&cks.key, MaxDegree::new(max));
|
||||
let integer_key = tfhe::integer::server_key::ServerKey::from_shortint_ex(key.clone());
|
||||
let wopbs_key =
|
||||
shortint::wopbs::WopbsKey::new_wopbs_key_only_for_wopbs(&cks.key, &key.clone());
|
||||
ServerKey {
|
||||
key,
|
||||
integer_key,
|
||||
wopbs_key,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a ciphertext filled with zeros
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// use concrete_float::gen_keys;
|
||||
/// use concrete_shortint::parameters::DEFAULT_PARAMETERS;
|
||||
///
|
||||
/// let size_mantissa = 4;
|
||||
/// let size_exponent = 4;
|
||||
/// let e_min = -2;
|
||||
/// // Generate the client key and the server key:
|
||||
/// let (cks, sks) = gen_keys(&DEFAULT_PARAMETERS, size, size);
|
||||
///
|
||||
/// let ctxt = sks.create_trivial_zero(size_mantissa, size_exponent, e_min, vec![]);
|
||||
///
|
||||
/// // Decrypt:
|
||||
/// let dec = cks.decrypt(&ctxt);
|
||||
/// assert_eq!(0, dec);
|
||||
/// ```
|
||||
pub fn create_trivial_zero(
|
||||
&self,
|
||||
size_mantissa: usize,
|
||||
size_exponent: usize,
|
||||
e_min: i64,
|
||||
) -> Ciphertext {
|
||||
let mut vec_res_mantissa = Vec::<shortint::Ciphertext>::with_capacity(size_mantissa);
|
||||
let mut zero = self.key.create_trivial(0_u64);
|
||||
zero.degree = Degree::new(0);
|
||||
for _ in 0..size_mantissa {
|
||||
vec_res_mantissa.push(zero.clone());
|
||||
}
|
||||
|
||||
let mut vec_res_exponent = Vec::<shortint::Ciphertext>::with_capacity(size_exponent);
|
||||
for _ in 0..size_exponent {
|
||||
vec_res_exponent.push(zero.clone());
|
||||
}
|
||||
|
||||
let sign = zero;
|
||||
|
||||
Ciphertext {
|
||||
ct_vec_mantissa: vec_res_mantissa,
|
||||
ct_vec_exponent: vec_res_exponent,
|
||||
ct_sign: sign,
|
||||
e_min,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn create_trivial_zero_from_ct(&self, ctxt: &Ciphertext) -> Ciphertext {
|
||||
self.create_trivial_zero(
|
||||
ctxt.ct_vec_mantissa.len(),
|
||||
ctxt.ct_vec_exponent.len(),
|
||||
ctxt.e_min,
|
||||
)
|
||||
}
|
||||
|
||||
/// Propagate the carry of the 'index' block to the next one.
|
||||
/// if index is equals to the MS LWE, this operation do nothing.
|
||||
/// We want to keep all the information on this LWE ( with this operation we can't create a
|
||||
/// new LWE
|
||||
pub fn propagate_mantissa(&self, ctxt: &mut [shortint::Ciphertext], index: usize) {
|
||||
if index < ctxt.len() - 1 {
|
||||
let carry = self.key.carry_extract(&ctxt[index]);
|
||||
ctxt[index] = self.key.message_extract(&ctxt[index]);
|
||||
self.key.unchecked_add_assign(&mut ctxt[index + 1], &carry);
|
||||
}
|
||||
//TODO maybe just BS to decrease the noise ?
|
||||
}
|
||||
|
||||
/// Propagate all the carries.
|
||||
pub fn full_propagate_mantissa(&self, ctxt: &mut [shortint::Ciphertext]) {
|
||||
let len = ctxt.len();
|
||||
for i in 0..len {
|
||||
self.propagate_mantissa(ctxt, i);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn propagate_exponent(&self, ctxt: &mut Vec<shortint::Ciphertext>, index: usize) {
|
||||
if index < ctxt.len() - 1 {
|
||||
let carry = self.key.carry_extract(&ctxt[index]);
|
||||
ctxt[index] = self.key.message_extract(&ctxt[index]);
|
||||
self.key.unchecked_add_assign(&mut ctxt[index + 1], &carry);
|
||||
} else {
|
||||
ctxt[index] = self.key.message_extract(&ctxt[index]);
|
||||
}
|
||||
}
|
||||
|
||||
/// Propagate all the carries.
|
||||
/// except the msb lwe
|
||||
pub fn partial_propagate(&self, ctxt: &mut Vec<shortint::Ciphertext>) {
|
||||
for i in 0..(ctxt.len() - 1) {
|
||||
self.propagate_exponent(ctxt, i);
|
||||
}
|
||||
}
|
||||
|
||||
/// Propagate all the carries.
|
||||
pub fn full_propagate_exponent(&self, ctxt: &mut Vec<shortint::Ciphertext>) {
|
||||
for i in 0..(ctxt.len()) {
|
||||
self.propagate_exponent(ctxt, i);
|
||||
}
|
||||
}
|
||||
|
||||
/// boolean bootstrapping
|
||||
pub fn reduce_noise_sign(&self, ctxt: &mut Ciphertext) {
|
||||
let msg_modulus = ctxt.ct_sign.message_modulus.0 as u64;
|
||||
let car_modulus = ctxt.ct_sign.carry_modulus.0 as u64;
|
||||
let msg_space = msg_modulus * car_modulus;
|
||||
self.key
|
||||
.unchecked_scalar_add_assign(&mut ctxt.ct_sign, (msg_space / 2) as u8);
|
||||
let accumulator = self
|
||||
.key
|
||||
.generate_lookup_table(|x| (x & (msg_space / 2)).wrapping_neg());
|
||||
//self.key.keyswitch_programmable_bootstrap_assign(&mut ctxt.ct_sign, &accumulator);
|
||||
self.key
|
||||
.apply_lookup_table_assign(&mut ctxt.ct_sign, &accumulator);
|
||||
self.key
|
||||
.unchecked_scalar_add_assign(&mut ctxt.ct_sign, (msg_space / 2) as u8);
|
||||
// We can always add as the sign is managed on the padding bit, the only important thing is
|
||||
// the noise
|
||||
ctxt.ct_sign.degree = Degree::new(0);
|
||||
}
|
||||
|
||||
fn propagate_mantissa_increase_exponent_if_necessary(
|
||||
&self,
|
||||
ctxt: &mut Ciphertext,
|
||||
index: usize,
|
||||
) {
|
||||
if index < ctxt.ct_vec_mantissa.len() - 1 {
|
||||
let carry = self.key.carry_extract(&ctxt.ct_vec_mantissa[index]);
|
||||
ctxt.ct_vec_mantissa[index] = self.key.message_extract(&ctxt.ct_vec_mantissa[index]);
|
||||
self.key
|
||||
.unchecked_add_assign(&mut ctxt.ct_vec_mantissa[index + 1], &carry);
|
||||
} else {
|
||||
self.increase_exponent_if_necessary(ctxt);
|
||||
}
|
||||
}
|
||||
|
||||
fn increase_exponent_if_necessary(&self, ctxt: &mut Ciphertext) {
|
||||
let msg_modulus = self.wopbs_key.param.message_modulus.0 as usize;
|
||||
let car_modulus = self.wopbs_key.param.carry_modulus.0 as usize;
|
||||
let msg_space = f64::log2((msg_modulus * car_modulus) as f64) as usize;
|
||||
let len = ctxt.ct_vec_mantissa.len();
|
||||
let carry = self
|
||||
.key
|
||||
.carry_extract(&ctxt.ct_vec_mantissa.last().unwrap());
|
||||
ctxt.ct_vec_mantissa[len - 1] = self
|
||||
.key
|
||||
.message_extract(&ctxt.ct_vec_mantissa.last().clone().unwrap());
|
||||
let mut tmp = ctxt.clone();
|
||||
tmp.ct_vec_mantissa.push(carry.clone());
|
||||
let _ = tmp.ct_vec_mantissa.remove(0);
|
||||
self.key
|
||||
.unchecked_scalar_add_assign(&mut tmp.ct_vec_exponent[0], 1);
|
||||
let ggsw_carry = self.ggsw_pbs_ks_cbs(&carry, msg_space);
|
||||
let res = self.cmuxes_full(ctxt, &tmp, &ggsw_carry);
|
||||
ctxt.ct_vec_mantissa = res.ct_vec_mantissa;
|
||||
ctxt.ct_vec_exponent = res.ct_vec_exponent;
|
||||
}
|
||||
|
||||
pub fn full_propagate_mantissa_increase_exponent_if_necessary(&self, ctxt: &mut Ciphertext) {
|
||||
let len = ctxt.ct_vec_mantissa.len();
|
||||
for i in 0..len {
|
||||
self.propagate_mantissa_increase_exponent_if_necessary(ctxt, i);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn clean_degree(&self, ctxt: &mut Ciphertext) {
|
||||
self.reduce_noise_sign(ctxt);
|
||||
self.full_propagate_exponent(&mut ctxt.ct_vec_exponent);
|
||||
self.full_propagate_mantissa_increase_exponent_if_necessary(ctxt)
|
||||
}
|
||||
|
||||
/// Propagate the carry of the 'index' block to the next one.
|
||||
/// if index is equals to the MS LWE, this operation do nothing.
|
||||
/// We want to keep all the information on this LWE ( with this operation we can't create a
|
||||
/// new LWE
|
||||
pub fn propagate_mantissa_parallelized(&self, ctxt: &mut [shortint::Ciphertext], index: usize) {
|
||||
// todo!("propagate_mantissa_parallelized");
|
||||
if index < ctxt.len() - 1 {
|
||||
let (carry, msg) = rayon::join(
|
||||
|| self.key.carry_extract(&ctxt[index]),
|
||||
|| self.key.message_extract(&ctxt[index]),
|
||||
);
|
||||
ctxt[index] = msg;
|
||||
self.key.unchecked_add_assign(&mut ctxt[index + 1], &carry);
|
||||
}
|
||||
//TODO maybe just BS to decrease the noise ?
|
||||
}
|
||||
|
||||
/// Propagate all the carries.
|
||||
pub fn full_propagate_mantissa_parallelized(&self, ctxt: &mut [shortint::Ciphertext]) {
|
||||
// todo!("full_propagate_mantissa_parallelized");
|
||||
let len = ctxt.len();
|
||||
for i in 0..len {
|
||||
self.propagate_mantissa_parallelized(ctxt, i);
|
||||
}
|
||||
}
|
||||
|
||||
// TODO use the low latency propagation
|
||||
pub fn propagate_exponent_parallelized(&self, ctxt: &mut [shortint::Ciphertext], index: usize) {
|
||||
if index < ctxt.len() - 1 {
|
||||
let (carry, msg) = rayon::join(
|
||||
|| self.key.carry_extract(&ctxt[index]),
|
||||
|| self.key.message_extract(&ctxt[index]),
|
||||
);
|
||||
ctxt[index] = msg;
|
||||
self.key.unchecked_add_assign(&mut ctxt[index + 1], &carry);
|
||||
} else {
|
||||
self.key.message_extract_assign(&mut ctxt[index]);
|
||||
}
|
||||
}
|
||||
|
||||
/// Propagate all the carries.
|
||||
/// except the msb lwe
|
||||
pub fn partial_propagate_parallelized(&self, ctxt: &mut Vec<shortint::Ciphertext>) {
|
||||
for i in 0..(ctxt.len() - 1) {
|
||||
self.propagate_exponent_parallelized(ctxt, i);
|
||||
}
|
||||
}
|
||||
|
||||
/// Propagate all the carries.
|
||||
pub fn full_propagate_exponent_parallelized(&self, ctxt: &mut [shortint::Ciphertext]) {
|
||||
for i in 0..(ctxt.len()) {
|
||||
self.propagate_exponent_parallelized(ctxt, i);
|
||||
}
|
||||
}
|
||||
|
||||
fn propagate_mantissa_increase_exponent_if_necessary_parallelized(
|
||||
&self,
|
||||
ctxt: &mut Ciphertext,
|
||||
index: usize,
|
||||
) {
|
||||
if index < ctxt.ct_vec_mantissa.len() - 1 {
|
||||
let (carry, msg) = rayon::join(
|
||||
|| self.key.carry_extract(&ctxt.ct_vec_mantissa[index]),
|
||||
|| self.key.message_extract(&ctxt.ct_vec_mantissa[index]),
|
||||
);
|
||||
ctxt.ct_vec_mantissa[index] = msg;
|
||||
self.key
|
||||
.unchecked_add_assign(&mut ctxt.ct_vec_mantissa[index + 1], &carry);
|
||||
} else {
|
||||
self.increase_exponent_if_necessary_parallelized(ctxt);
|
||||
}
|
||||
}
|
||||
|
||||
fn increase_exponent_if_necessary_parallelized(&self, ctxt: &mut Ciphertext) {
|
||||
let msg_modulus = self.wopbs_key.param.message_modulus.0 as usize;
|
||||
let car_modulus = self.wopbs_key.param.carry_modulus.0 as usize;
|
||||
let msg_space = f64::log2((msg_modulus * car_modulus) as f64) as usize;
|
||||
let len = ctxt.ct_vec_mantissa.len();
|
||||
let (carry, msg) = rayon::join(
|
||||
|| {
|
||||
self.key
|
||||
.carry_extract(&ctxt.ct_vec_mantissa.last().unwrap())
|
||||
},
|
||||
|| {
|
||||
self.key
|
||||
.message_extract(&ctxt.ct_vec_mantissa.last().clone().unwrap())
|
||||
},
|
||||
);
|
||||
|
||||
ctxt.ct_vec_mantissa[len - 1] = msg;
|
||||
let mut tmp = ctxt.clone();
|
||||
tmp.ct_vec_mantissa.push(carry.clone());
|
||||
let _ = tmp.ct_vec_mantissa.remove(0);
|
||||
self.key
|
||||
.unchecked_scalar_add_assign(&mut tmp.ct_vec_exponent[0], 1);
|
||||
let ggsw_carry = self.is_block_non_zero_ggsw_pbs_ks_cbs_parallelized(&carry, msg_space);
|
||||
let res = self.cmuxes_full_parallelized(ctxt, &tmp, &ggsw_carry);
|
||||
ctxt.ct_vec_mantissa = res.ct_vec_mantissa;
|
||||
ctxt.ct_vec_exponent = res.ct_vec_exponent;
|
||||
}
|
||||
|
||||
fn increase_exponent_if_necessary_parallelized_carry(
|
||||
&self,
|
||||
ctxt: &mut Ciphertext,
|
||||
mantissa_carry: &shortint::Ciphertext,
|
||||
) {
|
||||
let msg_modulus = self.wopbs_key.param.message_modulus.0 as usize;
|
||||
let car_modulus = self.wopbs_key.param.carry_modulus.0 as usize;
|
||||
let msg_space = (msg_modulus * car_modulus).ilog2() as usize;
|
||||
|
||||
let mut tmp = ctxt.clone();
|
||||
tmp.ct_vec_mantissa.push(mantissa_carry.clone());
|
||||
let _ = tmp.ct_vec_mantissa.remove(0);
|
||||
self.key
|
||||
.unchecked_scalar_add_assign(&mut tmp.ct_vec_exponent[0], 1);
|
||||
let ggsw_carry =
|
||||
self.is_block_non_zero_ggsw_pbs_ks_cbs_parallelized(&mantissa_carry, msg_space);
|
||||
let res = self.cmuxes_full_parallelized(ctxt, &tmp, &ggsw_carry);
|
||||
ctxt.ct_vec_mantissa = res.ct_vec_mantissa;
|
||||
ctxt.ct_vec_exponent = res.ct_vec_exponent;
|
||||
}
|
||||
|
||||
pub fn full_propagate_mantissa_increase_exponent_if_necessary_parallelized(
|
||||
&self,
|
||||
ctxt: &mut Ciphertext,
|
||||
) {
|
||||
let len = ctxt.ct_vec_mantissa.len();
|
||||
for i in 0..len {
|
||||
self.propagate_mantissa_increase_exponent_if_necessary_parallelized(ctxt, i);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn clean_degree_parallelized(&self, ctxt: &mut Ciphertext) {
|
||||
// todo!("clean_degree_parallelized");
|
||||
self.reduce_noise_sign(ctxt);
|
||||
// let now = std::time::Instant::now();
|
||||
self.full_propagate_exponent_parallelized(&mut ctxt.ct_vec_exponent);
|
||||
// let elapsed = now.elapsed();
|
||||
// println!("elapsed exponent propagate: {elapsed:?}");
|
||||
|
||||
// let now = std::time::Instant::now();
|
||||
self.full_propagate_mantissa_increase_exponent_if_necessary_parallelized(ctxt);
|
||||
// let elapsed = now.elapsed();
|
||||
// println!("elapsed mantissa propagate: {elapsed:?}");
|
||||
}
|
||||
}
|
||||
373
concrete-float/src/server_key/mul.rs
Normal file
373
concrete-float/src/server_key/mul.rs
Normal file
@@ -0,0 +1,373 @@
|
||||
use crate::server_key::Ciphertext;
|
||||
use crate::ServerKey;
|
||||
use std::cmp::{max, min};
|
||||
use tfhe::shortint;
|
||||
|
||||
impl ServerKey {
|
||||
pub fn mul(&self, ct1: &mut Ciphertext, ct2: &mut Ciphertext) -> Ciphertext {
|
||||
// carry need to be empty
|
||||
for ct in ct1.ct_vec_mantissa.iter_mut() {
|
||||
if ct.degree.get() > self.wopbs_key.param.message_modulus.0 {
|
||||
self.full_propagate_mantissa(&mut ct1.ct_vec_mantissa);
|
||||
break;
|
||||
}
|
||||
}
|
||||
for ct in ct2.ct_vec_mantissa.iter() {
|
||||
if ct.degree.get() > self.wopbs_key.param.message_modulus.0 {
|
||||
self.full_propagate_mantissa(&mut ct2.ct_vec_mantissa);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
let mut res = self.mul_mantissa(ct1, ct2);
|
||||
res = self.add_exponent_for_mul(&mut res.clone(), ct2);
|
||||
res.ct_sign = self.add_sign_for_mul(ct1, ct2);
|
||||
res
|
||||
}
|
||||
|
||||
pub fn mul_parallelized(
|
||||
&self,
|
||||
ct1: &mut Ciphertext,
|
||||
ct2: &mut Ciphertext,
|
||||
) -> (Ciphertext, shortint::Ciphertext) {
|
||||
// let now = std::time::Instant::now();
|
||||
let (mut res, mantissa_carry) = self.mul_mantissa_parallelized(ct1, ct2);
|
||||
// let elapsed = now.elapsed();
|
||||
// println!("mul_mantissa: {elapsed:?}");
|
||||
|
||||
res = self.add_exponent_for_mul_parallelized(&mut res.clone(), ct2, &mantissa_carry);
|
||||
res.ct_sign = self.add_sign_for_mul_parallelized(ct1, ct2);
|
||||
(res, mantissa_carry)
|
||||
}
|
||||
|
||||
fn mul_mantissa(&self, ct1: &mut Ciphertext, ct2: &mut Ciphertext) -> Ciphertext {
|
||||
let mantissa_len = ct1.ct_vec_mantissa.len();
|
||||
let value = (mantissa_len - 1) / 2;
|
||||
let mut result = self.create_trivial_zero(
|
||||
2 * mantissa_len - value - 1,
|
||||
ct1.ct_vec_exponent.len(),
|
||||
ct1.e_min,
|
||||
);
|
||||
|
||||
for (i, ct2_i) in ct2.ct_vec_mantissa.iter().enumerate() {
|
||||
let bound = max((value - i) as i64, 0) as usize;
|
||||
let tmp = self.block_mul(
|
||||
&ct1.ct_vec_mantissa[bound..].to_vec(),
|
||||
ct2_i,
|
||||
i,
|
||||
ct1.ct_vec_mantissa.len(),
|
||||
);
|
||||
if !self.is_add_possible(
|
||||
&tmp,
|
||||
&result.ct_vec_mantissa
|
||||
[min(0, (value - i) as i64).abs() as usize..(i + mantissa_len - value)],
|
||||
) {
|
||||
// we propagate only the necessary blocks,
|
||||
// to not loose any information, we propagate one blocks before and one blocks after
|
||||
self.full_propagate_mantissa(
|
||||
&mut result.ct_vec_mantissa[min(0, (value + 1 - i) as i64).abs() as usize
|
||||
..min(i + mantissa_len + 2 - value, 2 * mantissa_len - 1 - value)],
|
||||
);
|
||||
//self.full_propagate_mantissa(&mut result.ct_vec_mantissa);
|
||||
}
|
||||
for (ct_left_j, ct_right_j) in result.ct_vec_mantissa[min(0, (value - i) as i64).abs()
|
||||
as usize
|
||||
..min(i + mantissa_len + 1 - value, 2 * mantissa_len - 1 - value)]
|
||||
.iter_mut()
|
||||
.zip(tmp.iter())
|
||||
{
|
||||
self.key.unchecked_add_assign(ct_left_j, ct_right_j);
|
||||
}
|
||||
}
|
||||
|
||||
// the (log_msg_modulus * mantissa.len()) most significant bit of a multiplication are
|
||||
// include either in the [mantissa_len, 2*mantissa_len] or in [mantissa_len - 1,
|
||||
// 2*mantissa_len - 1] we choose the first one if the block 2*mantissa_len is not
|
||||
// empty otherwise we choose the first one
|
||||
let mut result_trunc = self.create_trivial_zero_from_ct(ct1);
|
||||
result_trunc.ct_vec_mantissa =
|
||||
result.ct_vec_mantissa[(mantissa_len - 1 - value)..].to_vec();
|
||||
result_trunc.ct_vec_exponent = ct1.ct_vec_exponent.clone();
|
||||
|
||||
result_trunc
|
||||
}
|
||||
|
||||
// Return the float ciphertext and the mantissa carry
|
||||
fn mul_mantissa_parallelized(
|
||||
&self,
|
||||
ct1: &Ciphertext,
|
||||
ct2: &Ciphertext,
|
||||
) -> (Ciphertext, shortint::Ciphertext) {
|
||||
use tfhe::integer::{IntegerCiphertext, IntegerRadixCiphertext, RadixCiphertext};
|
||||
|
||||
let mantissa_len = ct1.ct_vec_mantissa.len();
|
||||
let mantissa_len_for_mul_with_carry = mantissa_len * 2;
|
||||
let mut ct1_mantissa = ct1.ct_vec_mantissa.to_vec();
|
||||
ct1_mantissa.resize(mantissa_len_for_mul_with_carry, self.key.create_trivial(0));
|
||||
let mut ct2_mantissa = ct2.ct_vec_mantissa.to_vec();
|
||||
ct2_mantissa.resize(mantissa_len_for_mul_with_carry, self.key.create_trivial(0));
|
||||
let ct1_mantissa_as_integer = RadixCiphertext::from_blocks(ct1_mantissa);
|
||||
let ct2_mantissa_as_integer = RadixCiphertext::from_blocks(ct2_mantissa);
|
||||
|
||||
// println!("ct1_len = {}", ct1_mantissa_as_integer.blocks().len());
|
||||
// println!("ct2_len = {}", ct2_mantissa_as_integer.blocks().len());
|
||||
|
||||
// let now = std::time::Instant::now();
|
||||
let mul_result = self
|
||||
.integer_key
|
||||
.mul_parallelized(&ct1_mantissa_as_integer, &ct2_mantissa_as_integer);
|
||||
// let elapsed = now.elapsed();
|
||||
// println!("integer mul: {elapsed:?}");
|
||||
|
||||
let mut mul_result_blocks = mul_result.into_blocks();
|
||||
let carry_block = mul_result_blocks.pop().unwrap();
|
||||
let mantissa = mul_result_blocks[mantissa_len - 1..].to_vec();
|
||||
assert_eq!(mantissa.len(), ct1.ct_vec_mantissa.len());
|
||||
let mut result_trunc = self.create_trivial_zero_from_ct(ct1);
|
||||
result_trunc.ct_vec_mantissa = mantissa;
|
||||
result_trunc.ct_vec_exponent = ct1.ct_vec_exponent.clone();
|
||||
|
||||
(result_trunc, carry_block)
|
||||
}
|
||||
|
||||
// multiply one block of a mantissa by each block of another mantissa and create a mantissa of
|
||||
// this mul
|
||||
fn block_mul(
|
||||
&self,
|
||||
ct1: &Vec<shortint::ciphertext::Ciphertext>,
|
||||
ct2: &shortint::ciphertext::Ciphertext,
|
||||
index: usize,
|
||||
len_man: usize,
|
||||
) -> Vec<shortint::ciphertext::Ciphertext> {
|
||||
let zero = self.key.create_trivial(0);
|
||||
let mut result = vec![zero.clone()];
|
||||
let mut result_lsb = ct1.clone();
|
||||
let mut result_msb = ct1.clone();
|
||||
if index != len_man - 1 {
|
||||
for (ct_lsb_i, ct_msb_i) in result_lsb.iter_mut().zip(result_msb.iter_mut()) {
|
||||
self.key.unchecked_mul_msb_assign(ct_msb_i, ct2);
|
||||
self.key.unchecked_mul_lsb_assign(ct_lsb_i, ct2);
|
||||
}
|
||||
result_lsb.push(zero.clone());
|
||||
result.append(&mut result_msb.clone());
|
||||
} else {
|
||||
for (ct_lsb_i, ct_msb_i) in result_lsb[..len_man - 1]
|
||||
.iter_mut()
|
||||
.zip(result_msb[..len_man - 1].iter_mut())
|
||||
{
|
||||
self.key.unchecked_mul_msb_assign(ct_msb_i, ct2);
|
||||
self.key.unchecked_mul_lsb_assign(ct_lsb_i, ct2);
|
||||
}
|
||||
|
||||
let msg_mod = self.key.message_modulus.0 as u64;
|
||||
let tmp = self.key.unchecked_scalar_mul(ct2, msg_mod as u8);
|
||||
self.key
|
||||
.unchecked_add_assign(result_lsb.last_mut().unwrap(), &tmp);
|
||||
|
||||
// Generate the accumulator for the multiplication
|
||||
let acc = self
|
||||
.key
|
||||
.generate_lookup_table(|x| (x / msg_mod) * (x % msg_mod));
|
||||
self.key
|
||||
.apply_lookup_table_assign(result_lsb.last_mut().unwrap(), &acc);
|
||||
|
||||
result.append(&mut result_msb.clone());
|
||||
result.pop();
|
||||
}
|
||||
|
||||
for (ct1_i, ct2_i) in result.iter_mut().zip(result_lsb.iter()) {
|
||||
self.key.unchecked_add_assign(ct1_i, ct2_i)
|
||||
}
|
||||
result
|
||||
}
|
||||
|
||||
//sum the two sign for the mul
|
||||
fn add_sign_for_mul(
|
||||
&self,
|
||||
ct1: &mut Ciphertext,
|
||||
ct2: &mut Ciphertext,
|
||||
) -> shortint::ciphertext::Ciphertext {
|
||||
if self
|
||||
.key
|
||||
.is_add_possible(&ct1.ct_sign, &ct2.ct_sign)
|
||||
.is_err()
|
||||
{
|
||||
self.reduce_noise_sign(ct1);
|
||||
self.reduce_noise_sign(ct2);
|
||||
}
|
||||
self.key.unchecked_add(&ct1.ct_sign, &ct2.ct_sign)
|
||||
}
|
||||
|
||||
fn add_sign_for_mul_parallelized(
|
||||
&self,
|
||||
ct1: &mut Ciphertext,
|
||||
ct2: &mut Ciphertext,
|
||||
) -> shortint::ciphertext::Ciphertext {
|
||||
if self
|
||||
.key
|
||||
.is_add_possible(&ct1.ct_sign, &ct2.ct_sign)
|
||||
.is_err()
|
||||
{
|
||||
rayon::join(
|
||||
|| self.reduce_noise_sign(ct1),
|
||||
|| self.reduce_noise_sign(ct2),
|
||||
);
|
||||
}
|
||||
self.key.unchecked_add(&ct1.ct_sign, &ct2.ct_sign)
|
||||
}
|
||||
|
||||
// add the two exponent and subtract the value e_min and the shift on the MSB blocks
|
||||
fn add_exponent_for_mul(&self, ct1: &mut Ciphertext, ct2: &mut Ciphertext) -> Ciphertext {
|
||||
let msg_modulus = self.wopbs_key.param.message_modulus.0 as u64;
|
||||
let carry_modulus = self.wopbs_key.param.carry_modulus.0 as u64;
|
||||
let log_msg_modulus = f64::log2(msg_modulus as f64) as u64;
|
||||
let log_msg_space = f64::log2((carry_modulus * msg_modulus) as f64) as usize;
|
||||
let len_vec_exp = ct1.ct_vec_exponent.len();
|
||||
|
||||
if !self.is_add_possible(&ct1.ct_vec_exponent, &ct2.ct_vec_exponent) {
|
||||
self.partial_propagate(&mut ct1.ct_vec_exponent);
|
||||
self.partial_propagate(&mut ct2.ct_vec_exponent);
|
||||
}
|
||||
let mut res = ct1.clone();
|
||||
for (ct_left_j, ct_right_j) in res
|
||||
.ct_vec_exponent
|
||||
.iter_mut()
|
||||
.zip(ct2.ct_vec_exponent.iter())
|
||||
{
|
||||
self.key.unchecked_add_assign(ct_left_j, ct_right_j);
|
||||
}
|
||||
let cst = ct1.e_min + ct1.ct_vec_mantissa.len() as i64 - 1;
|
||||
let cst = (cst.abs() as u64) >> (log_msg_modulus * (len_vec_exp - 1) as u64);
|
||||
|
||||
//check if the exponent is big enough (return 1 if e is to small, 0 otherwise)
|
||||
let accumulator = self.key.generate_lookup_table(|x| ((x < cst) as u64));
|
||||
let mut ct_sign = self
|
||||
.key
|
||||
.apply_lookup_table(&mut res.ct_vec_exponent.last().unwrap(), &accumulator);
|
||||
|
||||
//check if the mantissa is not equals to zero (return 1 if ms_lwe== 0, 0 otherwise)
|
||||
let accumulator = self.key.generate_lookup_table(|x| ((x == 0) as u64));
|
||||
let ms_lwe = self
|
||||
.key
|
||||
.apply_lookup_table(&mut ct1.ct_vec_mantissa.last().unwrap(), &accumulator);
|
||||
self.key.unchecked_add_assign(&mut ct_sign, &ms_lwe);
|
||||
|
||||
let accumulator = self.key.generate_lookup_table(|x| ((x > 0) as u64));
|
||||
let ct_sign = self.key.apply_lookup_table(&mut ct_sign, &accumulator);
|
||||
|
||||
let sign_ggsw = self.ggsw_ks_cbs(&ct_sign, log_msg_space);
|
||||
let zero = self.create_trivial_zero_from_ct(ct1);
|
||||
|
||||
let accumulator = self.key.generate_lookup_table(|x| (x - cst) % msg_modulus);
|
||||
self.key
|
||||
.apply_lookup_table_assign(&mut res.ct_vec_exponent[len_vec_exp - 1], &accumulator);
|
||||
res = self.cmuxes_full(&res, &zero, &sign_ggsw);
|
||||
res
|
||||
}
|
||||
|
||||
// add the two exponent and subtract the value e_min and the shift on the MSB blocks
|
||||
fn add_exponent_for_mul_parallelized(
|
||||
&self,
|
||||
ct1: &mut Ciphertext,
|
||||
ct2: &mut Ciphertext,
|
||||
mantissa_carry: &shortint::Ciphertext,
|
||||
) -> Ciphertext {
|
||||
let msg_modulus = self.wopbs_key.param.message_modulus.0 as u64;
|
||||
let carry_modulus = self.wopbs_key.param.carry_modulus.0 as u64;
|
||||
let log_msg_modulus = msg_modulus.ilog2() as u64;
|
||||
let log_msg_space = (carry_modulus * msg_modulus).ilog2() as usize;
|
||||
let len_vec_exp = ct1.ct_vec_exponent.len();
|
||||
|
||||
if !self.is_add_possible(&ct1.ct_vec_exponent, &ct2.ct_vec_exponent) {
|
||||
rayon::join(
|
||||
|| self.partial_propagate(&mut ct1.ct_vec_exponent),
|
||||
|| self.partial_propagate(&mut ct2.ct_vec_exponent),
|
||||
);
|
||||
}
|
||||
let mut res = ct1.clone();
|
||||
for (ct_left_j, ct_right_j) in res
|
||||
.ct_vec_exponent
|
||||
.iter_mut()
|
||||
.zip(ct2.ct_vec_exponent.iter())
|
||||
{
|
||||
self.key.unchecked_add_assign(ct_left_j, ct_right_j);
|
||||
}
|
||||
let cst = ct1.e_min + ct1.ct_vec_mantissa.len() as i64 - 1;
|
||||
let cst = (cst.abs() as u64) >> (log_msg_modulus * (len_vec_exp - 1) as u64);
|
||||
|
||||
let (mut ct_sign, ms_lwe) = rayon::join(
|
||||
|| {
|
||||
//check if the exponent is big enough (return 1 if e is to small, 0 otherwise)
|
||||
let accumulator = self.key.generate_lookup_table(|x| ((x < cst) as u64));
|
||||
self.key
|
||||
.apply_lookup_table(&mut res.ct_vec_exponent.last().unwrap(), &accumulator)
|
||||
},
|
||||
|| {
|
||||
//check if the mantissa is not equals to zero (return 1 if ms_lwe== 0, 0 otherwise)
|
||||
let accumulator = self.key.generate_lookup_table(|x| ((x == 0) as u64));
|
||||
let mut last_mantissa_block = ct1.ct_vec_mantissa.last().unwrap().clone();
|
||||
// We recreate a mantissa block containing the msg + carry as we only want to know
|
||||
// if it was 0
|
||||
self.key
|
||||
.unchecked_add_assign(&mut last_mantissa_block, &mantissa_carry);
|
||||
self.key
|
||||
.apply_lookup_table(&last_mantissa_block, &accumulator)
|
||||
},
|
||||
);
|
||||
|
||||
self.key.unchecked_add_assign(&mut ct_sign, &ms_lwe);
|
||||
|
||||
rayon::join(
|
||||
|| {
|
||||
let accumulator = self.key.generate_lookup_table(|x| ((x > 0) as u64));
|
||||
self.key
|
||||
.apply_lookup_table_assign(&mut ct_sign, &accumulator);
|
||||
},
|
||||
|| {
|
||||
let accumulator = self.key.generate_lookup_table(|x| (x - cst) % msg_modulus);
|
||||
self.key.apply_lookup_table_assign(
|
||||
&mut res.ct_vec_exponent[len_vec_exp - 1],
|
||||
&accumulator,
|
||||
);
|
||||
},
|
||||
);
|
||||
|
||||
let sign_ggsw = self.ggsw_ks_cbs_parallelized(&ct_sign, log_msg_space);
|
||||
|
||||
let zero = self.create_trivial_zero_from_ct(ct1);
|
||||
res = self.cmuxes_full_parallelized(&res, &zero, &sign_ggsw);
|
||||
res
|
||||
}
|
||||
|
||||
pub fn mul_total(&self, ct1: &Ciphertext, ct2: &Ciphertext) -> Ciphertext {
|
||||
let mut res = self.mul(&mut ct1.clone(), &mut ct2.clone());
|
||||
self.clean_degree(&mut res);
|
||||
res
|
||||
}
|
||||
|
||||
pub fn mul_total_parallelized(&self, ct1: &Ciphertext, ct2: &Ciphertext) -> Ciphertext {
|
||||
// let now = std::time::Instant::now();
|
||||
let (mut res, mantissa_carry) = self.mul_parallelized(&mut ct1.clone(), &mut ct2.clone());
|
||||
// let elapsed = now.elapsed();
|
||||
// println!("mul_parallelized: {elapsed:?}");
|
||||
|
||||
// self.clean_degree_parallelized(&mut res);
|
||||
|
||||
self.reduce_noise_sign(&mut res);
|
||||
// let now = std::time::Instant::now();
|
||||
self.full_propagate_exponent_parallelized(&mut res.ct_vec_exponent);
|
||||
// let elapsed = now.elapsed();
|
||||
// println!("elapsed exponent propagate: {elapsed:?}");
|
||||
|
||||
// let now = std::time::Instant::now();
|
||||
// No need to propagate the mantissa it is clean after the integer mul parallelized
|
||||
// self.full_propagate_mantissa_increase_exponent_if_necessary_parallelized(&mut res);
|
||||
|
||||
// TODO change the management of the carry
|
||||
self.increase_exponent_if_necessary_parallelized_carry(&mut res, &mantissa_carry);
|
||||
// let elapsed = now.elapsed();
|
||||
// println!("elapsed mantissa propagate: {elapsed:?}");
|
||||
|
||||
res
|
||||
}
|
||||
}
|
||||
10
concrete-float/src/server_key/relu.rs
Normal file
10
concrete-float/src/server_key/relu.rs
Normal file
@@ -0,0 +1,10 @@
|
||||
use crate::server_key::Ciphertext;
|
||||
use crate::ServerKey;
|
||||
|
||||
impl ServerKey {
|
||||
pub fn relu(&self, ct: &Ciphertext) -> Ciphertext {
|
||||
let zero = self.create_trivial_zero_from_ct(ct);
|
||||
let ggsw = self.ggsw_ks_cbs(&ct.ct_sign, 0);
|
||||
self.cmuxes_full(&ct, &zero, &ggsw)
|
||||
}
|
||||
}
|
||||
43
concrete-float/src/server_key/sigmoid.rs
Normal file
43
concrete-float/src/server_key/sigmoid.rs
Normal file
@@ -0,0 +1,43 @@
|
||||
use crate::ciphertext::Ciphertext;
|
||||
use crate::server_key::ServerKey;
|
||||
|
||||
impl ServerKey {
|
||||
pub fn sigmoid(&self, ct: &Ciphertext) -> Ciphertext {
|
||||
let msg_modulus = self.wopbs_key.param.message_modulus.0 as u64;
|
||||
let carry_modulus = self.wopbs_key.param.carry_modulus.0 as u64;
|
||||
let log_msg_modulus = f64::log2(msg_modulus as f64) as u64;
|
||||
let log_carry_modulus = f64::log2(carry_modulus as f64) as u64;
|
||||
let cst = ct.e_min + ct.ct_vec_mantissa.len() as i64 - 1;
|
||||
let cst = (cst.abs() as u64) >> (log_msg_modulus * (ct.ct_vec_exponent.len() - 1) as u64);
|
||||
|
||||
let mut one = self.create_trivial_zero_from_ct(ct);
|
||||
self.key
|
||||
.unchecked_scalar_add_assign(&mut one.ct_vec_mantissa.last_mut().unwrap(), 1 as u8);
|
||||
self.key
|
||||
.unchecked_scalar_add_assign(&mut one.ct_vec_exponent.last_mut().unwrap(), cst as u8);
|
||||
|
||||
let mut minus_one = one.clone();
|
||||
self.change_sign_assign(&mut minus_one);
|
||||
let ggsw = self.ggsw_ks_cbs(&ct.ct_sign, 0);
|
||||
let tmp = self.cmuxes_full(&one, &minus_one, &ggsw);
|
||||
|
||||
let value = msg_modulus / 2;
|
||||
let accumulator = self.key.generate_lookup_table(|x| (x > value) as u64);
|
||||
let ct_last = self
|
||||
.key
|
||||
.apply_lookup_table(&mut ct.ct_vec_mantissa.last().unwrap(), &accumulator);
|
||||
|
||||
//check if the exponent is big enough (return 1 if e is to small, 0 otherwise)
|
||||
let accumulator = self.key.generate_lookup_table(|x| ((x < cst) as u64));
|
||||
let mut ct_sign = self
|
||||
.key
|
||||
.apply_lookup_table(&mut ct.ct_vec_exponent.last().unwrap(), &accumulator);
|
||||
|
||||
self.key.unchecked_add_assign(&mut ct_sign, &ct_last);
|
||||
let accumulator = self.key.generate_lookup_table(|x| ((x > 0) as u64));
|
||||
let ct_sign = self.key.apply_lookup_table(&mut ct_sign, &accumulator);
|
||||
|
||||
let ggsw = self.ggsw_ks_cbs(&ct_sign, (log_carry_modulus + log_msg_modulus) as usize);
|
||||
self.cmuxes_full(&tmp, &ct, &ggsw)
|
||||
}
|
||||
}
|
||||
448
concrete-float/src/server_key/sub.rs
Normal file
448
concrete-float/src/server_key/sub.rs
Normal file
@@ -0,0 +1,448 @@
|
||||
use crate::ciphertext::Ciphertext;
|
||||
use crate::ServerKey;
|
||||
use rayon::prelude::*;
|
||||
use shortint::ciphertext::Degree;
|
||||
use std::cmp::max;
|
||||
use tfhe::core_crypto::prelude::{Cleartext, Plaintext};
|
||||
use tfhe::shortint;
|
||||
|
||||
impl ServerKey {
|
||||
// This operation return |a - b| and sing(a-b)
|
||||
// after sub all the blocks have the smallest degree except the most significant block
|
||||
pub fn sub(
|
||||
&self,
|
||||
ctxt_left: &Vec<shortint::Ciphertext>,
|
||||
ctxt_right: &Vec<shortint::Ciphertext>,
|
||||
) -> (Vec<shortint::Ciphertext>, shortint::Ciphertext) {
|
||||
let mut ct_tmp: Vec<shortint::Ciphertext> = Vec::new();
|
||||
let msg_modulus = self.wopbs_key.param.message_modulus.0 as u64;
|
||||
let car_modulus = self.wopbs_key.param.carry_modulus.0 as u64;
|
||||
let msg_space = (msg_modulus * car_modulus) as u64;
|
||||
let size_ct = ctxt_left.len();
|
||||
for ct in ctxt_left.iter() {
|
||||
ct_tmp.push(
|
||||
self.key
|
||||
.unchecked_scalar_add(ct, ((msg_space / 2) - car_modulus / 2) as u8),
|
||||
);
|
||||
}
|
||||
|
||||
self.key
|
||||
.unchecked_scalar_add_assign(&mut ct_tmp[0], (car_modulus / 2) as u8);
|
||||
let cpy_right = ctxt_right.clone();
|
||||
for (c_left, c_right) in ct_tmp.iter_mut().zip(cpy_right.iter()) {
|
||||
tfhe::core_crypto::algorithms::lwe_ciphertext_sub_assign(&mut c_left.ct, &c_right.ct);
|
||||
let noise_level = c_left.noise_level() + c_right.noise_level();
|
||||
c_left.set_noise_level(noise_level);
|
||||
}
|
||||
self.partial_propagate(&mut ct_tmp);
|
||||
//extract the sign (the first value add on the most significant block)
|
||||
let accumulator = self.key.generate_lookup_table(|x| (x & (msg_space / 2)));
|
||||
let mut sign = self
|
||||
.key
|
||||
.apply_lookup_table(ct_tmp.last_mut().unwrap(), &accumulator);
|
||||
// the value sign encrypt only 1 or 0 so the degree is 1
|
||||
|
||||
// We can always add as the sign is managed on the padding bit, the only important thing is
|
||||
// the noise
|
||||
sign.degree = Degree::new(0);
|
||||
|
||||
// add the sign on each block
|
||||
for i in 0..(size_ct - 1) {
|
||||
self.key.unchecked_add_assign(&mut ct_tmp[i], &sign);
|
||||
}
|
||||
|
||||
// if the sign on each block ==0, we take the opposite, otherwise we return the value.
|
||||
// to find the opposite we perform the same idea than the subtraction (but only with pbs as
|
||||
// we know one value ) opposite = (1 << (len * precision)) - x
|
||||
for (i, ct) in ct_tmp.iter_mut().enumerate() {
|
||||
if i == 0 {
|
||||
let accumulator = self.key.generate_lookup_table(|x| {
|
||||
(((x - (msg_space / 2)) - (msg_modulus - x))
|
||||
* ((x & (msg_space / 2)) / (msg_space / 2)))
|
||||
+ (msg_modulus - x)
|
||||
});
|
||||
self.key.apply_lookup_table_assign(ct, &accumulator);
|
||||
ct.degree = Degree::new(msg_modulus as usize)
|
||||
} else if i == size_ct - 1 {
|
||||
let accumulator = self.key.generate_lookup_table(|x| {
|
||||
(((x - (msg_space / 2)) - (msg_space / 2 - x - 1))
|
||||
* ((x & (msg_space / 2)) / (msg_space / 2)))
|
||||
+ (msg_space / 2 - x - 1)
|
||||
});
|
||||
self.key.apply_lookup_table_assign(ct, &accumulator);
|
||||
ct.degree = Degree::new(max(
|
||||
(msg_space as usize / 2) - ct.degree.get(),
|
||||
ct.degree.get(),
|
||||
));
|
||||
} else {
|
||||
let accumulator = self.key.generate_lookup_table(|x| {
|
||||
(((x - (msg_space / 2)) - (msg_modulus - x - 1))
|
||||
* ((x & (msg_space / 2)) / (msg_space / 2)))
|
||||
+ (msg_modulus - x - 1)
|
||||
});
|
||||
self.key.apply_lookup_table_assign(ct, &accumulator);
|
||||
ct.degree = Degree::new(msg_modulus as usize)
|
||||
}
|
||||
}
|
||||
// move the sign bit on the msb
|
||||
// uncheck add, we juste create the sign
|
||||
tfhe::core_crypto::algorithms::lwe_ciphertext_cleartext_mul_assign(
|
||||
&mut sign.ct,
|
||||
Cleartext(2),
|
||||
);
|
||||
//self.key.unchecked_scalar_mul_assign(&mut sign, 2);
|
||||
(ct_tmp, sign)
|
||||
}
|
||||
|
||||
// subtract the two mantissas
|
||||
// after the subtraction put the msb of the result on the mst significant block
|
||||
// if exponent == 0 and the first block == 0, the result is 0
|
||||
pub fn sub_mantissa(&self, ctxt_left: &Ciphertext, ctxt_right: &Ciphertext) -> Ciphertext {
|
||||
let msg_modulus = self.wopbs_key.param.message_modulus.0 as u64;
|
||||
let car_modulus = self.wopbs_key.param.carry_modulus.0 as u64;
|
||||
let msg_space = (msg_modulus * car_modulus) as usize;
|
||||
let (res, sign) = self.sub(&ctxt_left.ct_vec_mantissa, &ctxt_right.ct_vec_mantissa);
|
||||
|
||||
let mut new = self.create_trivial_zero_from_ct(ctxt_left);
|
||||
new.ct_vec_mantissa = res;
|
||||
new.ct_vec_exponent = ctxt_left.ct_vec_exponent.clone();
|
||||
// if sign == 0 => need to change the sign of the operation
|
||||
// if sign == 1 we want to keep the same sign
|
||||
// new_s = old_s + sign + 1
|
||||
new.ct_sign = self.key.unchecked_add(&sign, ctxt_left.sign());
|
||||
self.key
|
||||
.unchecked_scalar_add_assign(&mut new.ct_sign, msg_space as u8);
|
||||
|
||||
new = self.realign_sub(&new);
|
||||
new
|
||||
}
|
||||
|
||||
// move the msb on the most significant block.
|
||||
// if e = 0 and the first block is empty, return zero
|
||||
// (no subnormal value)
|
||||
pub fn realign_sub(&self, ct0: &Ciphertext) -> Ciphertext {
|
||||
let msg_modulus = self.wopbs_key.param.message_modulus.0 as usize;
|
||||
let car_modulus = self.wopbs_key.param.carry_modulus.0 as usize;
|
||||
let msg_space = f64::log2((msg_modulus * car_modulus) as f64) as usize;
|
||||
let size_mantissa = ct0.ct_vec_mantissa.len();
|
||||
|
||||
let zero = self.create_trivial_zero_from_ct(ct0);
|
||||
let mut res = zero.clone();
|
||||
res.ct_vec_mantissa = ct0.ct_vec_mantissa.clone();
|
||||
res.ct_sign = ct0.ct_sign.clone();
|
||||
let mut msb_mantissa_ggsw =
|
||||
self.ggsw_pbs_ks_cbs(&res.ct_vec_mantissa[size_mantissa - 1], msg_space);
|
||||
for i in 0..size_mantissa {
|
||||
let mut tmp = zero.clone();
|
||||
tmp.ct_vec_mantissa = zero.ct_vec_mantissa.clone();
|
||||
for j in 0..(size_mantissa - 1) {
|
||||
tmp.ct_vec_mantissa[j + 1] = res.ct_vec_mantissa[j].clone();
|
||||
}
|
||||
for (k, ct_exp_i) in tmp.ct_vec_exponent.iter_mut().enumerate() {
|
||||
self.key.unchecked_scalar_add_assign(
|
||||
ct_exp_i,
|
||||
(((i + 1) >> (f64::log2(msg_modulus as f64) as usize * (k))) % msg_modulus)
|
||||
as u8,
|
||||
);
|
||||
}
|
||||
|
||||
// return tmp if ggsw == 0; res otherwise
|
||||
res.ct_vec_mantissa = self.cmuxes(
|
||||
&tmp.ct_vec_mantissa,
|
||||
&res.ct_vec_mantissa,
|
||||
&msb_mantissa_ggsw,
|
||||
);
|
||||
res.ct_vec_exponent = self.cmuxes(
|
||||
&tmp.ct_vec_exponent,
|
||||
&res.ct_vec_exponent,
|
||||
&msb_mantissa_ggsw,
|
||||
);
|
||||
|
||||
if i < size_mantissa - 1 {
|
||||
msb_mantissa_ggsw =
|
||||
self.ggsw_pbs_ks_cbs(&res.ct_vec_mantissa[size_mantissa - 1], msg_space);
|
||||
}
|
||||
}
|
||||
|
||||
let (mut diff_exp, sub_exp_sign) = self.sub(&ct0.ct_vec_exponent, &res.ct_vec_exponent);
|
||||
|
||||
// message space == 0 because the sign is on the padding bit
|
||||
let sign_ggsw = self.ggsw_ks_cbs(&sub_exp_sign, 0); //let sign_ggsw = self.wopbs_key.extract_one_bit_cbs(&self.key, &sub_exp_sign, 63);
|
||||
diff_exp = self.cmuxes(&zero.ct_vec_exponent, &diff_exp, &msb_mantissa_ggsw);
|
||||
res.ct_vec_exponent = self.cmuxes(&zero.ct_vec_exponent, &diff_exp, &sign_ggsw);
|
||||
res.ct_vec_mantissa = self.cmuxes(&zero.ct_vec_mantissa, &res.ct_vec_mantissa, &sign_ggsw);
|
||||
res.ct_sign = res.ct_sign;
|
||||
res
|
||||
}
|
||||
|
||||
// change the sign
|
||||
pub fn change_sign_assign(&self, ct0: &mut Ciphertext) {
|
||||
tfhe::core_crypto::algorithms::lwe_ciphertext_plaintext_add_assign(
|
||||
&mut ct0.ct_sign.ct,
|
||||
Plaintext(1 << 63),
|
||||
);
|
||||
}
|
||||
|
||||
pub fn change_sign(&self, ct0: &Ciphertext) -> Ciphertext {
|
||||
let mut ct = ct0.clone();
|
||||
self.change_sign_assign(&mut ct);
|
||||
ct
|
||||
}
|
||||
|
||||
pub fn sub_total(&self, ct1: &Ciphertext, ct2: &Ciphertext) -> Ciphertext {
|
||||
let ct2 = self.change_sign(ct2);
|
||||
self.add_total(&ct1, &ct2)
|
||||
}
|
||||
|
||||
// This operation return |a - b| and sing(a-b)
|
||||
// after sub all the blocks have the smallest degree except the most significant block
|
||||
// TODO: would the overflowing_sub from integer (with some slight adaptations perhaps) do the
|
||||
// trick ?
|
||||
pub fn abs_diff_parallelized(
|
||||
&self,
|
||||
ctxt_left: &Vec<shortint::Ciphertext>,
|
||||
ctxt_right: &Vec<shortint::Ciphertext>,
|
||||
) -> (Vec<shortint::Ciphertext>, shortint::Ciphertext) {
|
||||
let mut ct_tmp: Vec<shortint::Ciphertext> = Vec::with_capacity(ctxt_left.len());
|
||||
let msg_modulus = self.wopbs_key.param.message_modulus.0 as u64;
|
||||
let car_modulus = self.wopbs_key.param.carry_modulus.0 as u64;
|
||||
let msg_space = (msg_modulus * car_modulus) as u64;
|
||||
let size_ct = ctxt_left.len();
|
||||
for ct in ctxt_left.iter() {
|
||||
ct_tmp.push(
|
||||
self.key
|
||||
.unchecked_scalar_add(ct, ((msg_space / 2) - car_modulus / 2) as u8),
|
||||
);
|
||||
}
|
||||
|
||||
self.key
|
||||
.unchecked_scalar_add_assign(&mut ct_tmp[0], (car_modulus / 2) as u8);
|
||||
let cpy_right = ctxt_right.clone();
|
||||
// The operation is too small to be worth parallelizing
|
||||
ct_tmp
|
||||
.iter_mut()
|
||||
.zip(cpy_right.iter())
|
||||
.for_each(|(c_left, c_right)| {
|
||||
tfhe::core_crypto::algorithms::lwe_ciphertext_sub_assign(
|
||||
&mut c_left.ct,
|
||||
&c_right.ct,
|
||||
);
|
||||
let noise_level = c_left.noise_level() + c_right.noise_level();
|
||||
c_left.set_noise_level(noise_level);
|
||||
});
|
||||
|
||||
self.partial_propagate_parallelized(&mut ct_tmp);
|
||||
//extract the sign (the first value add on the most significant block)
|
||||
let accumulator = self.key.generate_lookup_table(|x| (x & (msg_space / 2)));
|
||||
let mut sign = self
|
||||
.key
|
||||
.apply_lookup_table(ct_tmp.last_mut().unwrap(), &accumulator);
|
||||
// the value sign encrypt only 1 or 0 so the degree is 1
|
||||
|
||||
// We can always add as the sign is managed on the padding bit, the only important thing is
|
||||
// the noise
|
||||
sign.degree = Degree::new(0);
|
||||
|
||||
// add the sign on each block, except the last one
|
||||
// Operation is too small to be worth parallelizing
|
||||
ct_tmp[0..(size_ct - 1)]
|
||||
.iter_mut()
|
||||
.for_each(|tmp_block| self.key.unchecked_add_assign(tmp_block, &sign));
|
||||
|
||||
// if the sign on each block ==0, we take the opposite, otherwise we return the value.
|
||||
// to find the opposite we perform the same idea than the subtraction (but only with pbs as
|
||||
// we know one value ) opposite = (1 << (len * precision)) - x
|
||||
ct_tmp.par_iter_mut().enumerate().for_each(|(i, ct)| {
|
||||
if i == 0 {
|
||||
let accumulator = self.key.generate_lookup_table(|x| {
|
||||
(((x - (msg_space / 2)) - (msg_modulus - x))
|
||||
* ((x & (msg_space / 2)) / (msg_space / 2)))
|
||||
+ (msg_modulus - x)
|
||||
});
|
||||
self.key.apply_lookup_table_assign(ct, &accumulator);
|
||||
} else if i == size_ct - 1 {
|
||||
let accumulator = self.key.generate_lookup_table(|x| {
|
||||
(((x - (msg_space / 2)) - (msg_space / 2 - x - 1))
|
||||
* ((x & (msg_space / 2)) / (msg_space / 2)))
|
||||
+ (msg_space / 2 - x - 1)
|
||||
});
|
||||
self.key.apply_lookup_table_assign(ct, &accumulator);
|
||||
} else {
|
||||
let accumulator = self.key.generate_lookup_table(|x| {
|
||||
(((x - (msg_space / 2)) - (msg_modulus - x - 1))
|
||||
* ((x & (msg_space / 2)) / (msg_space / 2)))
|
||||
+ (msg_modulus - x - 1)
|
||||
});
|
||||
self.key.apply_lookup_table_assign(ct, &accumulator);
|
||||
}
|
||||
});
|
||||
|
||||
// move the sign bit on the msb
|
||||
// uncheck add, we juste create the sign
|
||||
tfhe::core_crypto::algorithms::lwe_ciphertext_cleartext_mul_assign(
|
||||
&mut sign.ct,
|
||||
Cleartext(2),
|
||||
);
|
||||
//self.key.unchecked_scalar_mul_assign(&mut sign, 2);
|
||||
(ct_tmp, sign)
|
||||
}
|
||||
|
||||
// subtract the two mantissas
|
||||
// after the subtraction put the msb of the result on the mst significant block
|
||||
// if exponent == 0 and the first block == 0, the result is 0
|
||||
pub fn sub_mantissa_parallelized(
|
||||
&self,
|
||||
ctxt_left: &Ciphertext,
|
||||
ctxt_right: &Ciphertext,
|
||||
) -> Ciphertext {
|
||||
// todo!("sub_mantissa_parallelized");
|
||||
let msg_modulus = self.wopbs_key.param.message_modulus.0 as u64;
|
||||
let car_modulus = self.wopbs_key.param.carry_modulus.0 as u64;
|
||||
let msg_space = (msg_modulus * car_modulus) as usize;
|
||||
// let now = std::time::Instant::now();
|
||||
let (res, sign) =
|
||||
self.abs_diff_parallelized(&ctxt_left.ct_vec_mantissa, &ctxt_right.ct_vec_mantissa);
|
||||
// let elapsed = now.elapsed();
|
||||
// println!("sub_mantissa_parallelized::sub_parallelized: {elapsed:?}");
|
||||
|
||||
let mut new = self.create_trivial_zero_from_ct(ctxt_left);
|
||||
new.ct_vec_mantissa = res;
|
||||
new.ct_vec_exponent = ctxt_left.ct_vec_exponent.clone();
|
||||
// if sign == 0 => need to change the sign of the operation
|
||||
// if sign == 1 we want to keep the same sign
|
||||
// new_s = old_s + sign + 1
|
||||
new.ct_sign = self.key.unchecked_add(&sign, ctxt_left.sign());
|
||||
self.key
|
||||
.unchecked_scalar_add_assign(&mut new.ct_sign, msg_space as u8);
|
||||
|
||||
// let now = std::time::Instant::now();
|
||||
new = self.realign_sub_parallelized(&new);
|
||||
// let elapsed = now.elapsed();
|
||||
// println!("sub_mantissa_parallelized::realign_sub_parallelized: {elapsed:?}");
|
||||
|
||||
new
|
||||
}
|
||||
|
||||
// move the msb on the most significant block.
|
||||
// if e = 0 and the first block is empty, return zero
|
||||
// (no subnormal value)
|
||||
pub fn realign_sub_parallelized(&self, ct0: &Ciphertext) -> Ciphertext {
|
||||
// todo!("realign_sub_parallelized");
|
||||
let msg_modulus = self.wopbs_key.param.message_modulus.0 as usize;
|
||||
let car_modulus = self.wopbs_key.param.carry_modulus.0 as usize;
|
||||
let msg_space = (msg_modulus * car_modulus).ilog2() as usize;
|
||||
let size_mantissa = ct0.ct_vec_mantissa.len();
|
||||
|
||||
let zero = self.create_trivial_zero_from_ct(ct0);
|
||||
|
||||
let cmux_tree_size = if size_mantissa.is_power_of_two() {
|
||||
size_mantissa
|
||||
} else {
|
||||
size_mantissa.next_power_of_two()
|
||||
};
|
||||
|
||||
let mut ciphertexts_to_cmux: Vec<Ciphertext> = Vec::with_capacity(cmux_tree_size);
|
||||
let mut cmux_outputs: Vec<Ciphertext> = Vec::with_capacity(cmux_tree_size / 2);
|
||||
|
||||
(0..cmux_tree_size)
|
||||
.into_par_iter()
|
||||
.map(|ciphertext_idx| {
|
||||
if ciphertext_idx < size_mantissa {
|
||||
let mut ciphertext = zero.clone();
|
||||
|
||||
for (k, ct_exp_i) in ciphertext.ct_vec_exponent.iter_mut().enumerate() {
|
||||
self.key.unchecked_scalar_add_assign(
|
||||
ct_exp_i,
|
||||
((ciphertext_idx >> (msg_modulus.ilog2() as usize * (k))) % msg_modulus)
|
||||
as u8,
|
||||
);
|
||||
}
|
||||
|
||||
let exponent_block_count = size_mantissa - ciphertext_idx;
|
||||
ciphertext.ct_vec_mantissa[ciphertext_idx..]
|
||||
.clone_from_slice(&ct0.ct_vec_mantissa[..exponent_block_count]);
|
||||
|
||||
ciphertext
|
||||
} else {
|
||||
zero.clone()
|
||||
}
|
||||
})
|
||||
.collect_into_vec(&mut ciphertexts_to_cmux);
|
||||
|
||||
while ciphertexts_to_cmux.len() > 1 {
|
||||
ciphertexts_to_cmux
|
||||
.par_chunks_exact(2)
|
||||
.map(|chunk| {
|
||||
let less_modified_exponent = &chunk[0];
|
||||
let more_modified_exponent = &chunk[1];
|
||||
|
||||
let msb_mantissa_ggsw = self.is_block_non_zero_ggsw_pbs_ks_cbs_parallelized(
|
||||
&less_modified_exponent.ct_vec_mantissa[size_mantissa - 1],
|
||||
msg_space,
|
||||
);
|
||||
|
||||
// return tmp if ggsw == 0; res otherwise
|
||||
let (mantissa, exponent) = rayon::join(
|
||||
|| {
|
||||
self.cmuxes_parallelized(
|
||||
&more_modified_exponent.ct_vec_mantissa,
|
||||
&less_modified_exponent.ct_vec_mantissa,
|
||||
&msb_mantissa_ggsw,
|
||||
)
|
||||
},
|
||||
|| {
|
||||
self.cmuxes_parallelized(
|
||||
&more_modified_exponent.ct_vec_exponent,
|
||||
&less_modified_exponent.ct_vec_exponent,
|
||||
&msb_mantissa_ggsw,
|
||||
)
|
||||
},
|
||||
);
|
||||
|
||||
let mut res = zero.clone();
|
||||
res.ct_vec_exponent = exponent;
|
||||
res.ct_vec_mantissa = mantissa;
|
||||
|
||||
res
|
||||
})
|
||||
.collect_into_vec(&mut cmux_outputs);
|
||||
|
||||
std::mem::swap(&mut ciphertexts_to_cmux, &mut cmux_outputs);
|
||||
}
|
||||
|
||||
let mut res = ciphertexts_to_cmux.into_iter().next().unwrap();
|
||||
|
||||
let (mut diff_exp, sub_exp_sign) =
|
||||
self.abs_diff_parallelized(&ct0.ct_vec_exponent, &res.ct_vec_exponent);
|
||||
|
||||
// message space == 0 because the sign is on the padding bit
|
||||
let (sign_ggsw, msb_mantissa_ggsw) = rayon::join(
|
||||
|| self.ggsw_ks_cbs_parallelized(&sub_exp_sign, 0),
|
||||
|| {
|
||||
self.is_block_non_zero_ggsw_pbs_ks_cbs_parallelized(
|
||||
&res.ct_vec_mantissa[size_mantissa - 1],
|
||||
msg_space,
|
||||
)
|
||||
},
|
||||
);
|
||||
|
||||
let (exponent, mantissa) = rayon::join(
|
||||
|| {
|
||||
diff_exp =
|
||||
self.cmuxes_parallelized(&zero.ct_vec_exponent, &diff_exp, &msb_mantissa_ggsw);
|
||||
self.cmuxes_parallelized(&zero.ct_vec_exponent, &diff_exp, &sign_ggsw)
|
||||
},
|
||||
|| self.cmuxes_parallelized(&zero.ct_vec_mantissa, &res.ct_vec_mantissa, &sign_ggsw),
|
||||
);
|
||||
|
||||
res.ct_vec_exponent = exponent;
|
||||
res.ct_vec_mantissa = mantissa;
|
||||
res.ct_sign = ct0.ct_sign.clone();
|
||||
res
|
||||
}
|
||||
|
||||
pub fn sub_total_parallelized(&self, ct1: &Ciphertext, ct2: &Ciphertext) -> Ciphertext {
|
||||
let ct2 = self.change_sign(ct2);
|
||||
self.add_total_parallelized(&ct1, &ct2)
|
||||
}
|
||||
}
|
||||
835
concrete-float/src/server_key/tests.rs
Normal file
835
concrete-float/src/server_key/tests.rs
Normal file
@@ -0,0 +1,835 @@
|
||||
#![allow(dead_code)]
|
||||
use std::cmp::{max, min};
|
||||
use rand::Rng;
|
||||
use tfhe::shortint;
|
||||
|
||||
#[allow(unused_imports)]
|
||||
use crate::parameters::{PARAM_SAM_32, WOP_PARAM_SAM_32, PARAM_MESSAGE_2_CARRY_2_32,
|
||||
PARAM_MESSAGE_2_CARRY_2_64, WOP_PARAM_MESSAGE_2_CARRY_2_32,
|
||||
WOP_PARAM_MESSAGE_2_CARRY_2_64, FINAL_WOP_PARAM_2_2_32, FINAL_PARAM_2_2_32,
|
||||
FINAL_WOP_PARAM_8, FINAL_PARAM_8, FINAL_PARAM_15,
|
||||
FINAL_WOP_PARAM_15, FINAL_PARAM_16, FINAL_WOP_PARAM_16, FINAL_PARAM_32,
|
||||
FINAL_WOP_PARAM_32, FINAL_PARAM_64, FINAL_WOP_PARAM_64,
|
||||
FINAL_PARAM_64_BIS, FINAL_WOP_PARAM_64_BIS,
|
||||
FINAL_PARAM_32_BIS, FINAL_WOP_PARAM_32_BIS, FINAL_PARAM_16_BIS,
|
||||
FINAL_WOP_PARAM_16_BIS, FINAL_PARAM_15_BIS, FINAL_WOP_PARAM_15_BIS,
|
||||
FINAL_PARAM_8_BIS, FINAL_WOP_PARAM_8_BIS, FINAL_PARAM_32_TCHESS, FINAL_WOP_PARAM_32_TCHESS
|
||||
};
|
||||
use crate::server_key::*;
|
||||
use crate::{gen_keys, ClientKey};
|
||||
|
||||
const NB_OPE: i32 = 50;
|
||||
const LEN_MAN: usize = 13; //13;
|
||||
const LEN_EXP: usize = 4; //4;
|
||||
|
||||
|
||||
const LEN_MAN8: usize = 2;
|
||||
const LEN_EXP8: usize = 2;
|
||||
|
||||
const LEN_MAN16: usize = 6;
|
||||
const LEN_EXP16: usize = 3;
|
||||
|
||||
const LEN_MAN32: usize = 13;
|
||||
const LEN_EXP32: usize = 4;
|
||||
|
||||
const LEN_MAN64: usize = 27;
|
||||
const LEN_EXP64: usize = 5;
|
||||
|
||||
macro_rules! named_param {
|
||||
($param:ident) => {
|
||||
(stringify!($param), $param)
|
||||
};
|
||||
}
|
||||
|
||||
struct Parameters {
|
||||
pbsparameters: shortint::ClassicPBSParameters,
|
||||
wopbsparameters: shortint::WopbsParameters,
|
||||
len_man: usize,
|
||||
len_exp: usize,
|
||||
}
|
||||
|
||||
const PARAM_FP_64_BITS: Parameters = Parameters {
|
||||
pbsparameters: FINAL_PARAM_64_BIS,
|
||||
wopbsparameters: FINAL_WOP_PARAM_64_BIS,
|
||||
len_man: LEN_MAN64,
|
||||
len_exp: LEN_EXP64,
|
||||
};
|
||||
|
||||
const PARAM_FP_32_BITS: Parameters = Parameters {
|
||||
pbsparameters: FINAL_PARAM_32_BIS,
|
||||
wopbsparameters: FINAL_WOP_PARAM_32_BIS,
|
||||
len_man: LEN_MAN32,
|
||||
len_exp: LEN_EXP32,
|
||||
};
|
||||
|
||||
const PARAM_FP_16_BITS: Parameters = Parameters {
|
||||
pbsparameters: FINAL_PARAM_16_BIS,
|
||||
wopbsparameters: FINAL_WOP_PARAM_16_BIS,
|
||||
len_man: LEN_MAN16,
|
||||
len_exp: LEN_EXP16,
|
||||
};
|
||||
|
||||
const PARAM_FP_8_BITS: Parameters = Parameters {
|
||||
pbsparameters: FINAL_PARAM_8_BIS,
|
||||
wopbsparameters: FINAL_WOP_PARAM_8_BIS,
|
||||
len_man: LEN_MAN8,
|
||||
len_exp: LEN_EXP8,
|
||||
};
|
||||
|
||||
const PARAMS: [(&str, Parameters); 1] =
|
||||
[
|
||||
//named_param!(PARAM_FP_64_BITS),
|
||||
named_param!(PARAM_FP_32_BITS),
|
||||
//named_param!(PARAM_FP_16_BITS),
|
||||
//named_param!(PARAM_FP_8_BITS),
|
||||
];
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_float_encrypt() {
|
||||
for (_, param) in PARAMS {
|
||||
let (cks, sks) = gen_keys(
|
||||
param.pbsparameters,
|
||||
param.wopbsparameters,
|
||||
param.len_man,
|
||||
param.len_exp,
|
||||
);
|
||||
|
||||
print_info(&cks);
|
||||
println!("parameters :: {:?}", cks.key.parameters);
|
||||
let msg = 1.;
|
||||
|
||||
// Encryption of one message:
|
||||
let mut ct = cks.encrypt(msg);
|
||||
print_res(&cks, &ct, "decrypt", msg as f32, msg);
|
||||
sks.clean_degree(&mut ct);
|
||||
print_res(&cks, &ct, "decrypt", msg as f32, msg);
|
||||
let res = cks.decrypt(&ct);
|
||||
|
||||
assert_eq!(res, msg);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
pub fn test_float_mul() {
|
||||
let mut rng = rand::thread_rng();
|
||||
for (name_parameters, param) in PARAMS {
|
||||
let (cks, sks) = gen_keys(
|
||||
param.pbsparameters,
|
||||
param.wopbsparameters,
|
||||
param.len_man,
|
||||
param.len_exp,
|
||||
);
|
||||
let msg1 = rng.gen::<f32>() as f64;
|
||||
let msg2 = rng.gen::<f32>() as f64;
|
||||
|
||||
let ct1 = cks.encrypt(msg1);
|
||||
let ct2 = cks.encrypt(msg2);
|
||||
|
||||
println!("--------------------------");
|
||||
println!("---- {name_parameters} ----");
|
||||
println!("--------------------------");
|
||||
|
||||
print_res(&cks, &ct1, "ct 1", msg1 as f32, msg1);
|
||||
print_res(&cks, &ct2, "ct 2", msg2 as f32, msg2);
|
||||
|
||||
let res = sks.mul_total_parallelized(&mut ct1.clone(), &mut ct2.clone());
|
||||
print_res(&cks, &res, "Multiplication", (msg2 * msg1) as f32, msg2 * msg1);
|
||||
|
||||
let res = cks.decrypt(&res);
|
||||
assert!(res.abs() < ((msg1 * msg2) * 1.001).abs());
|
||||
assert!(res.abs() > ((msg1 * msg2) * 0.999).abs());
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
pub fn test_float_div() {
|
||||
let mut rng = rand::thread_rng();
|
||||
for (name_parameters, param) in PARAMS {
|
||||
let (cks, sks) = gen_keys(
|
||||
param.pbsparameters,
|
||||
param.wopbsparameters,
|
||||
param.len_man,
|
||||
param.len_exp,
|
||||
);
|
||||
|
||||
let msg2 = rng.gen::<f32>() as f64;
|
||||
let msg1 = -rng.gen::<f32>() as f64;
|
||||
|
||||
let ct1 = cks.encrypt(msg1);
|
||||
let ct2 = cks.encrypt(msg2);
|
||||
|
||||
println!("--------------------------");
|
||||
println!("---- {name_parameters} ----");
|
||||
println!("--------------------------");
|
||||
|
||||
print_res(&cks, &ct1, "ct1", (msg1) as f32, msg1);
|
||||
print_res(&cks, &ct2, "ct2", (msg2) as f32, msg2);
|
||||
|
||||
let mut res = sks.division(&ct1, &ct2);
|
||||
print_res(&cks, &res, "Division", (msg1 / msg2) as f32, msg1 / msg2);
|
||||
sks.clean_degree(&mut res);
|
||||
let res = cks.decrypt(&res);
|
||||
|
||||
assert!(res.abs() < ((msg1 / msg2) * 1.001).abs());
|
||||
assert!(res.abs() > ((msg1 / msg2) * 0.999).abs());
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
pub fn float_cos() {
|
||||
let mut rng = rand::thread_rng();
|
||||
for (name_parameters, param) in PARAMS {
|
||||
let (cks, sks) = gen_keys(
|
||||
param.pbsparameters,
|
||||
param.wopbsparameters,
|
||||
param.len_man,
|
||||
param.len_exp,
|
||||
);
|
||||
|
||||
println!("--------------------------");
|
||||
println!("---- {name_parameters} ----");
|
||||
println!("--------------------------");
|
||||
let msg1 = rng.gen::<f32>() as f64;
|
||||
let ct1 = cks.encrypt(msg1);
|
||||
|
||||
let one = cks.encrypt(1.); //should be in trivial encrypt
|
||||
let one_div_by_2 = cks.encrypt(1. / 2.); //should be in trivial encrypt
|
||||
let one_div_by_24 = cks.encrypt(1. / 24.); //should be in trivial encrypt
|
||||
|
||||
print_res(&cks, &one, "one", 1 as f32, 1.);
|
||||
print_res(&cks, &one_div_by_2, "oneDivBy2", (1. / 2.) as f32, 1. / 2.);
|
||||
print_res(&cks, &one_div_by_24, "oneDivBy24", (1. / 24.) as f32, 1. / 24.);
|
||||
print_res(&cks, &ct1, "ct1", msg1 as f32, msg1);
|
||||
|
||||
|
||||
let ct1_square = sks.mul_total_parallelized(&ct1, &ct1);
|
||||
print_res(&cks, &ct1_square, "ct1_square", (msg1 * msg1) as f32, msg1 * msg1);
|
||||
|
||||
let ct1_square_square = sks.mul_total_parallelized(&ct1_square, &ct1_square);
|
||||
print_res(&cks, &ct1_square_square, "ct1_square_square", (msg1 * msg1 * msg1 * msg1) as f32, msg1 * msg1 * msg1 * msg1);
|
||||
|
||||
let ct1_square_time_one_div_by_2 = sks.mul_total_parallelized(&ct1_square, &one_div_by_2);
|
||||
print_res(&cks, &ct1_square_time_one_div_by_2, "ct1_square_time_1DivBy2", (msg1 * msg1 / 2.) as f32, msg1 * msg1 / 2.);
|
||||
|
||||
let ct1_square_square_time_one_div_by_24 = sks.mul_total_parallelized(&ct1_square_square, &one_div_by_24);
|
||||
print_res(&cks, &ct1_square_square_time_one_div_by_24, "ct1_square_square_time_1DivBy24", (msg1 * msg1 * msg1 * msg1 / 24.) as f32, msg1 * msg1 * msg1 * msg1 / 24.);
|
||||
|
||||
let res = sks.add_total_parallelized(&one, &ct1_square_square_time_one_div_by_24);
|
||||
print_res(&cks, &res, "first res", (1. + msg1 * msg1 * msg1 * msg1 / 24.) as f32, 1. + msg1 * msg1 * msg1 * msg1 / 24.);
|
||||
|
||||
|
||||
let res = sks.sub_total_parallelized(&res, &ct1_square_time_one_div_by_2);
|
||||
println!("Cosine, exact result : {:?}", msg1.cos());
|
||||
let approximation = 1. + msg1 * msg1 * msg1 * msg1 / 24. - msg1 * msg1 / 2.;
|
||||
print_res(&cks, &res, "Cosine approximation", approximation as f32, approximation);
|
||||
|
||||
let res = cks.decrypt(&res);
|
||||
assert!(res < (approximation * 1.001).abs());
|
||||
assert!(res > (approximation * 0.999).abs());
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
pub fn float_sin() {
|
||||
let mut rng = rand::thread_rng();
|
||||
for (name_parameters, param) in PARAMS {
|
||||
let (cks, sks) = gen_keys(
|
||||
param.pbsparameters,
|
||||
param.wopbsparameters,
|
||||
param.len_man,
|
||||
param.len_exp,
|
||||
);
|
||||
|
||||
println!("--------------------------");
|
||||
println!("---- {name_parameters} ----");
|
||||
println!("--------------------------");
|
||||
|
||||
let msg1 = rng.gen::<f32>() as f64;
|
||||
let ct1 = cks.encrypt(msg1);
|
||||
|
||||
print_res(&cks, &ct1, "ct1", msg1 as f32, msg1);
|
||||
|
||||
let one_div_by_6 = cks.encrypt(1. / 6.); //should be in trivial encrypt
|
||||
let one_div_by_120 = cks.encrypt(1. / 120.); //should be in trivial encrypt
|
||||
|
||||
let ct1_square = sks.mul_total_parallelized(&ct1, &ct1);
|
||||
print_res(&cks, &ct1_square, "ct1_square", (msg1 * msg1) as f32, msg1 * msg1);
|
||||
|
||||
let ct1_cube = sks.mul_total_parallelized(&ct1_square, &ct1);
|
||||
print_res(&cks, &ct1_cube, "ct1_cube", (msg1 * msg1 * msg1) as f32, msg1 * msg1 * msg1);
|
||||
|
||||
let ct1_power_five = sks.mul_total_parallelized(&ct1_square, &ct1_cube);
|
||||
print_res(&cks, &ct1_power_five, "ct1_power_five", (msg1 * msg1 * msg1 * msg1 * msg1) as f32, msg1 * msg1 * msg1 * msg1 * msg1);
|
||||
|
||||
|
||||
let ct1_cube_time_one_div_by_6 = sks.mul_total_parallelized(&ct1_cube, &one_div_by_6);
|
||||
print_res(&cks, &ct1_cube_time_one_div_by_6, "ct1_cube_time_one_div_by_6", (msg1 * msg1 * msg1 / 6.) as f32, msg1 * msg1 * msg1 / 6.);
|
||||
|
||||
let ct1_power_five_time_one_div_by_120 = sks.mul_total_parallelized(&ct1_power_five, &one_div_by_120);
|
||||
print_res(&cks, &ct1_power_five_time_one_div_by_120, "ct1_power_five_time_one_div_by_120", (msg1 * msg1 * msg1 * msg1 * msg1 / 120.) as f32, msg1 * msg1 * msg1 * msg1 * msg1 / 120.);
|
||||
|
||||
|
||||
let res = sks.add_total_parallelized(&ct1, &ct1_power_five_time_one_div_by_120);
|
||||
print_res(&cks, &ct1_power_five_time_one_div_by_120, "res_1", (msg1 * msg1 * msg1 * msg1 * msg1 / 120.) as f32, msg1 * msg1 * msg1 * msg1 * msg1 / 120.);
|
||||
|
||||
let res = sks.sub_total_parallelized(&res, &ct1_cube_time_one_div_by_6);
|
||||
|
||||
println!("Sine, exact result : {:?}", msg1.sin());
|
||||
let approximation = msg1 + msg1 * msg1 * msg1 * msg1 * msg1 / 120. - msg1 * msg1 * msg1 / 6.;
|
||||
print_res(&cks, &res, "Sine approximation", approximation as f32, approximation);
|
||||
|
||||
let res = cks.decrypt(&res);
|
||||
assert!(res < (approximation * 1.001).abs());
|
||||
assert!(res > (approximation * 0.999).abs());
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
pub fn test_float_add() {
|
||||
let mut rng = rand::thread_rng();
|
||||
for (name_parameters, param) in PARAMS {
|
||||
let (cks, sks) = gen_keys(
|
||||
param.pbsparameters,
|
||||
param.wopbsparameters,
|
||||
param.len_man,
|
||||
param.len_exp,
|
||||
);
|
||||
|
||||
println!("--------------------------");
|
||||
println!("---- {name_parameters} ----");
|
||||
println!("--------------------------");
|
||||
|
||||
let msg2 = rng.gen::<f32>() as f64;
|
||||
let msg1 = rng.gen::<f32>() as f64;
|
||||
|
||||
let ct1 = cks.encrypt(msg1);
|
||||
let ct2 = cks.encrypt(msg2);
|
||||
|
||||
print_res(&cks, &ct1, "ct 1", msg1 as f32, msg1);
|
||||
print_res(&cks, &ct2, "ct 2", msg2 as f32, msg2);
|
||||
|
||||
let res = sks.add_total_parallelized(&ct1, &ct2);
|
||||
|
||||
print_res(&cks, &res, "Addition", (msg1 + msg2) as f32, msg1 + msg2);
|
||||
|
||||
let res = cks.decrypt(&res);
|
||||
assert!(res.abs() < ((msg1 + msg2) * 1.001).abs());
|
||||
assert!(res.abs() > ((msg1 + msg2) * 0.999).abs());
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
pub fn test_float_sub() {
|
||||
let mut rng = rand::thread_rng();
|
||||
for (name_parameters, param) in PARAMS {
|
||||
let (cks, sks) = gen_keys(
|
||||
param.pbsparameters,
|
||||
param.wopbsparameters,
|
||||
param.len_man,
|
||||
param.len_exp,
|
||||
);
|
||||
|
||||
println!("--------------------------");
|
||||
println!("---- {name_parameters} ----");
|
||||
println!("--------------------------");
|
||||
|
||||
|
||||
let msg1 = rng.gen::<f32>() as f64;
|
||||
let msg2 = rng.gen::<f32>() as f64;
|
||||
|
||||
let ct1 = cks.encrypt(msg1);
|
||||
let ct2 = cks.encrypt(msg2);
|
||||
|
||||
|
||||
print_res(&cks, &ct1, "ct 1", msg1 as f32, msg1);
|
||||
print_res(&cks, &ct2, "ct 2", msg2 as f32, msg2);
|
||||
let res = sks.sub_total_parallelized(&ct1, &ct2);
|
||||
|
||||
print_res(&cks, &res, "Subtraction", (msg1 - msg2) as f32, msg1 - msg2);
|
||||
|
||||
let res = cks.decrypt(&res);
|
||||
assert!(res.abs() < ((msg1 - msg2) * 1.001).abs());
|
||||
assert!(res.abs() > ((msg1 - msg2) * 0.999).abs());
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
pub fn depth_test_parallelized() {
|
||||
let mut rng = rand::thread_rng();
|
||||
for (name_parameters, param) in PARAMS {
|
||||
let (cks, sks) = gen_keys(
|
||||
param.pbsparameters,
|
||||
param.wopbsparameters,
|
||||
param.len_man,
|
||||
param.len_exp,
|
||||
);
|
||||
|
||||
println!("--------------------------");
|
||||
println!("---- {name_parameters} ----");
|
||||
println!("--------------------------");
|
||||
|
||||
let max_i = 1_000.;
|
||||
|
||||
let mut vec_float_32 = vec![];
|
||||
let mut vec_float_64 = vec![];
|
||||
let mut vec_hom_float = vec![];
|
||||
let mut vec_deep = vec![];
|
||||
let mut vec_nb_operation = vec![];
|
||||
|
||||
let len_vec = 3 as u16;
|
||||
for i in 0..len_vec {
|
||||
let msg = rng.gen::<f32>() as f64;
|
||||
println!("msg_{:?}: {:?}", i, msg);
|
||||
let ct = cks.encrypt(msg);
|
||||
print_res(&cks, &ct, "encrypt/decrypt", msg as f32, msg);
|
||||
|
||||
vec_float_64.push(msg);
|
||||
vec_float_32.push(msg as f32);
|
||||
vec_hom_float.push(ct);
|
||||
vec_deep.push(0);
|
||||
vec_nb_operation.push(0);
|
||||
}
|
||||
|
||||
for i in 0..NB_OPE {
|
||||
println!("\n----Round {:?}----", i);
|
||||
let r_ope = rng.gen::<u16>() % 3;
|
||||
let r_value_1 = (rng.gen::<u16>() % len_vec) as usize;
|
||||
let mut r_value_2 = (rng.gen::<u16>() % len_vec) as usize;
|
||||
let mut r_place = (rng.gen::<u16>() % 2) as usize;
|
||||
while r_value_1 == r_value_2 {
|
||||
r_value_2 = (rng.gen::<u16>() % len_vec) as usize;
|
||||
}
|
||||
if r_place == 0 {
|
||||
r_place = r_value_1
|
||||
} else {
|
||||
r_place = r_value_2
|
||||
}
|
||||
|
||||
vec_deep[r_place] = min(vec_deep[r_value_1], vec_deep[r_value_2]) + 1;
|
||||
vec_nb_operation[r_place] =
|
||||
max(vec_nb_operation[r_value_1], vec_nb_operation[r_value_2]) + 1;
|
||||
if r_ope == 0 {
|
||||
println!(
|
||||
"block {:?} * block {:?} -> block{:?}\n",
|
||||
r_value_1, r_value_2, r_place
|
||||
);
|
||||
println!(
|
||||
"expected: {:?} * {:?} = {:?}",
|
||||
vec_float_64[r_value_1],
|
||||
vec_float_64[r_value_2],
|
||||
vec_float_64[r_value_1] * vec_float_64[r_value_2]
|
||||
);
|
||||
vec_hom_float[r_place] = sks.mul_total_parallelized(
|
||||
&mut vec_hom_float[r_value_1].clone(),
|
||||
&mut vec_hom_float[r_value_2].clone(),
|
||||
);
|
||||
vec_float_32[r_place] = vec_float_32[r_value_1] * vec_float_32[r_value_2];
|
||||
vec_float_64[r_place] = vec_float_64[r_value_1] * vec_float_64[r_value_2];
|
||||
|
||||
print_res(
|
||||
&cks,
|
||||
&vec_hom_float[r_place],
|
||||
"res mul",
|
||||
vec_float_32[r_place],
|
||||
vec_float_64[r_place],
|
||||
);
|
||||
} else if r_ope == 1 {
|
||||
println!(
|
||||
"block {:?} + block {:?} -> block{:?}\n",
|
||||
r_value_1, r_value_2, r_place
|
||||
);
|
||||
println!(
|
||||
"expected: {:?} + {:?} = {:?}",
|
||||
vec_float_64[r_value_1],
|
||||
vec_float_64[r_value_2],
|
||||
vec_float_64[r_value_1] + vec_float_64[r_value_2]
|
||||
);
|
||||
|
||||
vec_hom_float[r_place] =
|
||||
sks.add_total_parallelized(&vec_hom_float[r_value_1], &vec_hom_float[r_value_2]);
|
||||
vec_float_32[r_place] = vec_float_32[r_value_1] + vec_float_32[r_value_2];
|
||||
vec_float_64[r_place] = vec_float_64[r_value_1] + vec_float_64[r_value_2];
|
||||
print_res(
|
||||
&cks,
|
||||
&vec_hom_float[r_place],
|
||||
"res add",
|
||||
vec_float_32[r_place],
|
||||
vec_float_64[r_place],
|
||||
);
|
||||
} else {
|
||||
println!(
|
||||
"block {:?} - block {:?} -> block{:?}\n",
|
||||
r_value_1, r_value_2, r_place
|
||||
);
|
||||
println!(
|
||||
"expected: {:?} - {:?} = {:?}",
|
||||
vec_float_64[r_value_1],
|
||||
vec_float_64[r_value_2],
|
||||
vec_float_64[r_value_1] - vec_float_64[r_value_2]
|
||||
);
|
||||
|
||||
vec_hom_float[r_place] =
|
||||
sks.sub_total_parallelized(&vec_hom_float[r_value_1], &vec_hom_float[r_value_2]);
|
||||
vec_float_32[r_place] = vec_float_32[r_value_1] - vec_float_32[r_value_2];
|
||||
vec_float_64[r_place] = vec_float_64[r_value_1] - vec_float_64[r_value_2];
|
||||
print_res(
|
||||
&cks,
|
||||
&vec_hom_float[r_place],
|
||||
"res sub",
|
||||
vec_float_32[r_place],
|
||||
vec_float_64[r_place],
|
||||
);
|
||||
}
|
||||
if vec_float_64[r_value_1].abs() > max_i {
|
||||
let msg_tmp = (1. / max_i) * rng.gen::<f32>() as f64; // 1. / (vec_float_64[r_value_1].abs() + vec_float_64[r_value_2].clone().abs() );
|
||||
let mut ct_tmp = cks.encrypt(msg_tmp);
|
||||
|
||||
println!(
|
||||
"block {:?} * {:?} -> block{:?}\n",
|
||||
r_value_1, msg_tmp, r_value_1
|
||||
);
|
||||
println!(
|
||||
"expected: {:?} * {:?} = {:?}",
|
||||
vec_float_64[r_value_1],
|
||||
msg_tmp,
|
||||
vec_float_64[r_value_1] * msg_tmp
|
||||
);
|
||||
|
||||
vec_hom_float[r_place] =
|
||||
sks.mul_total_parallelized(&mut vec_hom_float[r_value_1].clone(), &mut ct_tmp);
|
||||
vec_float_32[r_value_1] = vec_float_32[r_value_1] * msg_tmp as f32;
|
||||
vec_float_64[r_value_1] = vec_float_64[r_value_1] * msg_tmp;
|
||||
vec_nb_operation[r_value_1] += 1;
|
||||
|
||||
print_res(
|
||||
&cks,
|
||||
&vec_hom_float[r_place],
|
||||
"res mul",
|
||||
vec_float_32[r_place],
|
||||
vec_float_64[r_place],
|
||||
);
|
||||
}
|
||||
if vec_float_64[r_value_2].abs() > max_i {
|
||||
let msg_tmp = (1. / max_i) * rng.gen::<f32>() as f64; // 1. / (vec_float_64[r_value_1].abs() + vec_float_64[r_value_2].clone().abs() );
|
||||
let mut ct_tmp = cks.encrypt(msg_tmp);
|
||||
|
||||
println!(
|
||||
"block {:?} * {:?} -> block{:?}\n",
|
||||
r_value_2, msg_tmp, r_value_2
|
||||
);
|
||||
println!(
|
||||
"expected: {:?} * {:?} = {:?}",
|
||||
vec_float_64[r_value_1],
|
||||
msg_tmp,
|
||||
vec_float_64[r_value_1] * msg_tmp
|
||||
);
|
||||
|
||||
vec_hom_float[r_value_2] =
|
||||
sks.mul_total_parallelized(&mut vec_hom_float[r_value_2].clone(), &mut ct_tmp);
|
||||
vec_float_32[r_value_2] = vec_float_32[r_value_2] * msg_tmp as f32;
|
||||
vec_float_64[r_value_2] = vec_float_64[r_value_2] * msg_tmp;
|
||||
vec_nb_operation[r_value_2] += 1;
|
||||
|
||||
print_res(
|
||||
&cks,
|
||||
&vec_hom_float[r_place],
|
||||
"res mul",
|
||||
vec_float_32[r_place],
|
||||
vec_float_64[r_place],
|
||||
);
|
||||
}
|
||||
|
||||
if vec_float_64[r_value_1].abs() < 1. / max_i {
|
||||
let msg_tmp = max_i * rng.gen::<f32>() as f64; // 1. / (vec_float_64[r_value_1].abs() + vec_float_64[r_value_2].clone().abs() );
|
||||
let mut ct_tmp = cks.encrypt(msg_tmp);
|
||||
|
||||
println!(
|
||||
"block {:?} * {:?} -> block{:?}\n",
|
||||
r_value_1, msg_tmp, r_value_1
|
||||
);
|
||||
println!(
|
||||
"expected: {:?} * {:?} = {:?}",
|
||||
vec_float_64[r_value_1],
|
||||
msg_tmp,
|
||||
vec_float_64[r_value_1] * msg_tmp
|
||||
);
|
||||
|
||||
vec_hom_float[r_value_1] =
|
||||
sks.mul_total_parallelized(&mut vec_hom_float[r_value_1].clone(), &mut ct_tmp);
|
||||
vec_float_32[r_value_1] = vec_float_32[r_value_1] * msg_tmp as f32;
|
||||
vec_float_64[r_value_1] = vec_float_64[r_value_1] * msg_tmp;
|
||||
vec_nb_operation[r_value_1] += 1;
|
||||
|
||||
print_res(
|
||||
&cks,
|
||||
&vec_hom_float[r_place],
|
||||
"res mul",
|
||||
vec_float_32[r_place],
|
||||
vec_float_64[r_place],
|
||||
);
|
||||
}
|
||||
if vec_float_64[r_value_2].abs() < 1. / max_i {
|
||||
let msg_tmp = max_i * rng.gen::<f32>() as f64; // 1. / (vec_float_64[r_value_1].abs() + vec_float_64[r_value_2].clone().abs() );
|
||||
let mut ct_tmp = cks.encrypt(msg_tmp);
|
||||
|
||||
println!(
|
||||
"block {:?} * {:?} -> block{:?}\n",
|
||||
r_value_2, msg_tmp, r_value_2
|
||||
);
|
||||
println!(
|
||||
"expected: {:?} * {:?} = {:?}",
|
||||
vec_float_64[r_value_1],
|
||||
msg_tmp,
|
||||
vec_float_64[r_value_1] * msg_tmp
|
||||
);
|
||||
|
||||
vec_hom_float[r_value_2] =
|
||||
sks.mul_total_parallelized(&mut vec_hom_float[r_value_2].clone(), &mut ct_tmp);
|
||||
vec_float_32[r_value_2] = vec_float_32[r_value_2] * msg_tmp as f32;
|
||||
vec_float_64[r_value_2] = vec_float_64[r_value_2] * msg_tmp;
|
||||
vec_nb_operation[r_value_2] += 1;
|
||||
|
||||
print_res(
|
||||
&cks,
|
||||
&vec_hom_float[r_place],
|
||||
"res mul",
|
||||
vec_float_32[r_place],
|
||||
vec_float_64[r_place],
|
||||
);
|
||||
}
|
||||
println!("----End Round {:?}----", i);
|
||||
println!("--------------------");
|
||||
println!("--------------------");
|
||||
println!("--------------------");
|
||||
}
|
||||
|
||||
for i in 0..len_vec as usize {
|
||||
println!("------");
|
||||
print_res(
|
||||
&cks,
|
||||
&vec_hom_float[i],
|
||||
"Final result",
|
||||
vec_float_32[i],
|
||||
vec_float_64[i],
|
||||
);
|
||||
//println!("Deep : {:?}", vec_deep[i]);
|
||||
//println!("Ope : {:?}", vec_nb_operation[i]);
|
||||
|
||||
let res = cks.decrypt(&vec_hom_float[i]);
|
||||
assert!(res.abs() < (vec_float_64[i] * 1.001).abs());
|
||||
assert!(res.abs() > (vec_float_64[i] * 0.999).abs());
|
||||
//println!("------");
|
||||
}
|
||||
//println!("Info :");
|
||||
//println!("len mantissa : {:?}", LEN_MAN);
|
||||
//println!("len exponent : {:?}", LEN_EXP);
|
||||
//println!("number operations : {:?}", NB_OPE);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
pub fn float_same_as_ls_22_32() {
|
||||
let (cks, sks) = gen_keys(
|
||||
PARAM_MESSAGE_2_CARRY_2_32,
|
||||
WOP_PARAM_MESSAGE_2_CARRY_2_32,
|
||||
LEN_MAN32,
|
||||
LEN_EXP32,
|
||||
);
|
||||
print_info(&cks);
|
||||
let msg1 = -2.7914999921796382_e-15;
|
||||
let ct1 = cks.encrypt(msg1);
|
||||
print_res(&cks, &ct1, "Encrypt/Decrypt", msg1 as f32, msg1);
|
||||
|
||||
let msg2 = 8.3867001884896375_e-12;
|
||||
let ct2 = cks.encrypt(msg2);
|
||||
print_res(&cks, &ct2, "Encrypt/Decrypt", msg2 as f32, msg2);
|
||||
|
||||
let msg3 = 1.82634005135360_e14;
|
||||
let ct3 = cks.encrypt(msg3);
|
||||
print_res(&cks, &ct3, "Encrypt/Decrypt", msg3 as f32, msg3);
|
||||
|
||||
let msg4 = -6.278269952_e9;
|
||||
let ct4 = cks.encrypt(msg4);
|
||||
print_res(&cks, &ct4, "Encrypt/Decrypt", msg4 as f32, msg4);
|
||||
|
||||
let res_1 = sks.add_total_parallelized(&ct1, &ct2);
|
||||
print_res(
|
||||
&cks,
|
||||
&res_1,
|
||||
"res add",
|
||||
msg1 as f32 + msg2 as f32,
|
||||
msg1 + msg2,
|
||||
);
|
||||
|
||||
let res_2 = sks.sub_total_parallelized(&ct3, &ct4);
|
||||
print_res(
|
||||
&cks,
|
||||
&res_2,
|
||||
"res add",
|
||||
msg3 as f32 - msg4 as f32,
|
||||
msg3 - msg4,
|
||||
);
|
||||
|
||||
let mut witness32 = (msg3 as f32 - msg4 as f32) * (msg1 as f32 + msg2 as f32);
|
||||
let mut witness64 = (msg3 - msg4) * (msg1 + msg2);
|
||||
let res = sks.mul_total_parallelized(&res_1, &res_2);
|
||||
print_res(&cks, &res, "res mul", witness32, witness64);
|
||||
|
||||
let res = sks.mul_total_parallelized(&res, &res);
|
||||
witness32 *= witness32;
|
||||
witness64 *= witness64;
|
||||
print_res(&cks, &res, "res mul", witness32, witness64);
|
||||
let res = cks.decrypt(&res);
|
||||
|
||||
assert!(res.abs() < ((witness32 * 1.001 as f32) as f64).abs());
|
||||
assert!(res.abs() > ((witness32 * 0.999 as f32) as f64).abs());
|
||||
}
|
||||
|
||||
#[test]
|
||||
pub fn float_same_as_ls_22_64() {
|
||||
let (cks, sks) = gen_keys(
|
||||
PARAM_MESSAGE_2_CARRY_2_64,
|
||||
WOP_PARAM_MESSAGE_2_CARRY_2_64,
|
||||
LEN_MAN64,
|
||||
LEN_EXP64,
|
||||
);
|
||||
|
||||
print_info(&cks);
|
||||
let msg1 = -9.1763514236254290_e-32;
|
||||
let ct1 = cks.encrypt(msg1);
|
||||
print_res(&cks, &ct1, "Encrypt/Decrypt", msg1 as f32, msg1);
|
||||
|
||||
let msg2 = 6.2467247246375865_e-24;
|
||||
let ct2 = cks.encrypt(msg2);
|
||||
print_res(&cks, &ct2, "Encrypt/Decrypt", msg2 as f32, msg2);
|
||||
|
||||
let msg3 = 2.4523526872362373_e22;
|
||||
let ct3 = cks.encrypt(msg3);
|
||||
print_res(&cks, &ct3, "Encrypt/Decrypt", msg3 as f32, msg3);
|
||||
|
||||
let msg4 = -5.4324663335297274_e17;
|
||||
let ct4 = cks.encrypt(msg4);
|
||||
print_res(&cks, &ct4, "Encrypt/Decrypt", msg4 as f32, msg4);
|
||||
|
||||
let res_1 = sks.add_total_parallelized(&ct1, &ct2);
|
||||
print_res(
|
||||
&cks,
|
||||
&res_1,
|
||||
"res add",
|
||||
msg1 as f32 + msg2 as f32,
|
||||
msg1 + msg2,
|
||||
);
|
||||
|
||||
let res_2 = sks.sub_total_parallelized(&ct3, &ct4);
|
||||
print_res(
|
||||
&cks,
|
||||
&res_2,
|
||||
"res add",
|
||||
msg3 as f32 - msg4 as f32,
|
||||
msg3 - msg4,
|
||||
);
|
||||
|
||||
let mut witness32 = (msg3 as f32 - msg4 as f32) * (msg1 as f32 + msg2 as f32);
|
||||
let mut witness64 = (msg3 - msg4) * (msg1 + msg2);
|
||||
let res = sks.mul_total_parallelized(&res_1, &res_2);
|
||||
print_res(&cks, &res, "res mul", witness32, witness64);
|
||||
|
||||
let res = sks.mul_total_parallelized(&res, &res);
|
||||
witness32 *= witness32;
|
||||
witness64 *= witness64;
|
||||
print_res(&cks, &res, "res mul", witness32, witness64);
|
||||
let res = cks.decrypt(&res);
|
||||
|
||||
assert!(res.abs() < (witness64 * 1.001).abs());
|
||||
assert!(res.abs() > (witness64 * 0.999).abs());
|
||||
}
|
||||
|
||||
#[test]
|
||||
pub fn test_float_relu() {
|
||||
let mut rng = rand::thread_rng();
|
||||
for (name_parameters, param) in PARAMS {
|
||||
let (cks, sks) = gen_keys(
|
||||
param.pbsparameters,
|
||||
param.wopbsparameters,
|
||||
param.len_man,
|
||||
param.len_exp,
|
||||
);
|
||||
|
||||
println!("--------------------------");
|
||||
println!("---- {name_parameters} ----");
|
||||
println!("--------------------------");
|
||||
|
||||
let msg = rng.gen::<f32>() as f64 - rng.gen::<f32>() as f64;
|
||||
let ct = cks.encrypt(msg);
|
||||
print_res(&cks, &ct, "decrypt", msg as f32, msg);
|
||||
let res = sks.relu(&ct);
|
||||
print_res(&cks, &res, "relu", 0.0_f32.max(msg as f32), msg.max(0.));
|
||||
let res = cks.decrypt(&res);
|
||||
assert_eq!(res, msg.max(0.));
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
#[test]
|
||||
pub fn test_float_sigmoid() {
|
||||
let mut rng = rand::thread_rng();
|
||||
for (name_parameters, param) in PARAMS {
|
||||
let (cks, sks) = gen_keys(
|
||||
param.pbsparameters,
|
||||
param.wopbsparameters,
|
||||
param.len_man,
|
||||
param.len_exp,
|
||||
);
|
||||
|
||||
println!("--------------------------");
|
||||
println!("---- {name_parameters} ----");
|
||||
println!("--------------------------");
|
||||
|
||||
|
||||
let msg = (rng.gen::<f32>() as f64 + 0.4).abs();
|
||||
let ct = cks.encrypt(msg);
|
||||
print_res(&cks, &ct, "ct", msg as f32, msg);
|
||||
let res = sks.sigmoid(&ct);
|
||||
print_res(&cks, &res, "approx sigmoid", 1.0_f32.min(msg as f32), msg.min(1.));
|
||||
let res = cks.decrypt(&res);
|
||||
|
||||
assert!(res > msg.min(1.) * 0.999);
|
||||
assert!(res < msg.min(1.) * 1.001);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn print_res(
|
||||
cks: &ClientKey,
|
||||
ct: &Ciphertext,
|
||||
operation: &str,
|
||||
witness32: f32,
|
||||
witness64: f64,
|
||||
) {
|
||||
println!("\n--------------------",);
|
||||
println!("{:?}:\n", operation);
|
||||
println!("Result : {:?}", cks.decrypt(&ct));
|
||||
println!("Clear 32-bits: {:?}", witness32);
|
||||
println!("Clear 64-bits: {:?}\n", witness64);
|
||||
println!("--------------------");
|
||||
}
|
||||
|
||||
pub fn print_info(cks: &ClientKey) {
|
||||
println!("\n-----Info-----");
|
||||
println!("length exp {:?}", cks.vector_length_exponent);
|
||||
println!("length man {:?}", cks.vector_length_mantissa);
|
||||
let msg_modulus = cks.parameters().message_modulus().0;
|
||||
let car_modulus = cks.parameters().carry_modulus().0;
|
||||
println!("msg modulus {:?}, 0b{:b}", msg_modulus, msg_modulus);
|
||||
println!("car modulus {:?}, 0b{:b}", car_modulus, car_modulus);
|
||||
println!(
|
||||
"total space {:?}, 0b{:b}",
|
||||
car_modulus * msg_modulus,
|
||||
car_modulus * msg_modulus
|
||||
);
|
||||
let log_msg_modulus = f64::log2(msg_modulus as f64) as usize;
|
||||
let bias = -((1 << (cks.vector_length_exponent.0 * log_msg_modulus - 1)) as i64)
|
||||
- (cks.vector_length_mantissa.0 as i64 - 1);
|
||||
println!("Bias {:?}", bias);
|
||||
println!("--------------\n");
|
||||
}
|
||||
551
concrete-float/src/server_key/tools.rs
Normal file
551
concrete-float/src/server_key/tools.rs
Normal file
@@ -0,0 +1,551 @@
|
||||
use crate::server_key::Ciphertext;
|
||||
use crate::ServerKey;
|
||||
use shortint::ciphertext::{Ciphertext as ShortintCiphertext, Degree};
|
||||
|
||||
use std::cmp::{max, min};
|
||||
use tfhe::core_crypto::algorithms::{
|
||||
cmux_assign, extract_lwe_sample_from_glwe_ciphertext, keyswitch_lwe_ciphertext,
|
||||
par_keyswitch_lwe_ciphertext,
|
||||
};
|
||||
|
||||
use aligned_vec::ABox;
|
||||
use dyn_stack::{GlobalPodBuffer, PodStack, ReborrowMut, StackReq};
|
||||
use tfhe::core_crypto::commons::parameters::*;
|
||||
use tfhe::core_crypto::entities::*;
|
||||
use tfhe::core_crypto::fft_impl::fft64::c64;
|
||||
use tfhe::core_crypto::fft_impl::fft64::crypto::ggsw::fill_with_forward_fourier_scratch;
|
||||
use tfhe::core_crypto::fft_impl::fft64::crypto::wop_pbs::{
|
||||
circuit_bootstrap_boolean, circuit_bootstrap_boolean_parallelized,
|
||||
circuit_bootstrap_boolean_scratch, extract_bits, extract_bits_parallelized,
|
||||
extract_bits_scratch,
|
||||
};
|
||||
use tfhe::core_crypto::fft_impl::fft64::math::fft::par_convert_polynomials_list_to_fourier;
|
||||
use tfhe::core_crypto::prelude::{ContiguousEntityContainer, Fft};
|
||||
use tfhe::shortint;
|
||||
use tfhe::shortint::ciphertext::NoiseLevel;
|
||||
|
||||
use rayon::prelude::*;
|
||||
|
||||
impl ServerKey {
|
||||
pub fn ggsw_pbs_ks_cbs(
|
||||
&self,
|
||||
ct1: &ShortintCiphertext,
|
||||
message_space: usize,
|
||||
) -> FourierGgswCiphertext<ABox<[c64]>> {
|
||||
let accumulator = self.key.generate_lookup_table(|x| min(1, x) as u64);
|
||||
let res = self.key.apply_lookup_table(&ct1, &accumulator);
|
||||
self.ggsw_ks_cbs(&res, message_space)
|
||||
}
|
||||
|
||||
/// return ggsw(0) if ct1 = 0, return ggsw(1) otherwise
|
||||
pub fn ggsw_ks_cbs(
|
||||
&self,
|
||||
ct1: &ShortintCiphertext,
|
||||
message_space: usize,
|
||||
) -> FourierGgswCiphertext<ABox<[c64]>> {
|
||||
let ciphertext_modulus = ct1.ct.ciphertext_modulus();
|
||||
|
||||
let mut res_ks = LweCiphertext::new(
|
||||
0u64,
|
||||
LweSize(self.wopbs_key.param.lwe_dimension.to_lwe_size().0),
|
||||
ciphertext_modulus,
|
||||
);
|
||||
keyswitch_lwe_ciphertext(&self.key.key_switching_key, &ct1.ct, &mut res_ks);
|
||||
self.ggsw_cbs(&res_ks.as_view(), message_space)
|
||||
}
|
||||
|
||||
/// return ggsw(0) if ct1 = 0, return ggsw(1) otherwise
|
||||
pub fn ggsw_cbs(
|
||||
&self,
|
||||
ct: &LweCiphertext<&[u64]>,
|
||||
message_space: usize,
|
||||
) -> FourierGgswCiphertext<ABox<[c64]>> {
|
||||
let glwe_dimension = self.wopbs_key.param.glwe_dimension;
|
||||
let polynomial_size = self.wopbs_key.param.polynomial_size;
|
||||
let base_log_cbs = self.wopbs_key.param.cbs_base_log;
|
||||
let level_count_cbs = self.wopbs_key.param.cbs_level;
|
||||
let ciphertext_modulus = ct.ciphertext_modulus();
|
||||
|
||||
let fourier_bsk = match &self.wopbs_key.wopbs_server_key.bootstrapping_key {
|
||||
shortint::server_key::ShortintBootstrappingKey::Classic(fbsk) => fbsk.as_view(),
|
||||
_ => unreachable!(),
|
||||
};
|
||||
let fft = Fft::new(polynomial_size);
|
||||
let fft = fft.as_view();
|
||||
let mut cbs_res = GgswCiphertext::new(
|
||||
0u64,
|
||||
glwe_dimension.to_glwe_size(),
|
||||
polynomial_size,
|
||||
base_log_cbs,
|
||||
level_count_cbs,
|
||||
ciphertext_modulus,
|
||||
);
|
||||
let mut ggsw = FourierGgswCiphertext::new(
|
||||
glwe_dimension.to_glwe_size(),
|
||||
polynomial_size,
|
||||
base_log_cbs,
|
||||
level_count_cbs,
|
||||
);
|
||||
|
||||
let mut mem = GlobalPodBuffer::new(
|
||||
circuit_bootstrap_boolean_scratch::<u64>(
|
||||
ct.lwe_size(),
|
||||
fourier_bsk.output_lwe_dimension().to_lwe_size(),
|
||||
glwe_dimension.to_glwe_size(),
|
||||
polynomial_size,
|
||||
fft,
|
||||
)
|
||||
.unwrap(),
|
||||
);
|
||||
let mut stack = PodStack::new(&mut mem);
|
||||
circuit_bootstrap_boolean(
|
||||
fourier_bsk,
|
||||
ct.as_view(),
|
||||
cbs_res.as_mut_view(),
|
||||
DeltaLog(63 - message_space),
|
||||
self.wopbs_key.cbs_pfpksk.as_view(),
|
||||
fft,
|
||||
stack.rb_mut(),
|
||||
);
|
||||
|
||||
let mut mem = GlobalPodBuffer::new(fill_with_forward_fourier_scratch(fft).unwrap());
|
||||
let mut stack = PodStack::new(&mut mem);
|
||||
ggsw.as_mut_view()
|
||||
.fill_with_forward_fourier(cbs_res.as_view(), fft, stack.rb_mut());
|
||||
ggsw
|
||||
}
|
||||
|
||||
pub fn extract_bit_cbs(
|
||||
&self,
|
||||
ct1: &ShortintCiphertext,
|
||||
) -> Vec<FourierGgswCiphertext<ABox<[c64]>>> {
|
||||
let glwe_dimension = self.wopbs_key.param.glwe_dimension;
|
||||
let polynomial_size = self.wopbs_key.param.polynomial_size;
|
||||
let lwe_dimension = self.wopbs_key.param.lwe_dimension;
|
||||
let message_modulus = self.wopbs_key.param.message_modulus;
|
||||
let log_message_modulus = f64::log2(message_modulus.0 as f64) as usize;
|
||||
let log_carry_modulus = f64::log2(self.wopbs_key.param.carry_modulus.0 as f64) as usize;
|
||||
let ciphertext_modulus = ct1.ct.ciphertext_modulus();
|
||||
|
||||
let ksk = &self.key.key_switching_key;
|
||||
let delta_log = 63 - log_message_modulus * log_carry_modulus;
|
||||
let fft = Fft::new(polynomial_size);
|
||||
let fft = fft.as_view();
|
||||
let req = || {
|
||||
StackReq::try_any_of([
|
||||
fill_with_forward_fourier_scratch(fft)?,
|
||||
extract_bits_scratch::<u64>(
|
||||
lwe_dimension,
|
||||
LweDimension(polynomial_size.0 * glwe_dimension.0 + 1),
|
||||
glwe_dimension.to_glwe_size(),
|
||||
polynomial_size,
|
||||
fft,
|
||||
)?,
|
||||
])
|
||||
};
|
||||
let req = req().unwrap();
|
||||
let mut mem = GlobalPodBuffer::new(req);
|
||||
let stack = PodStack::new(&mut mem);
|
||||
let fourier_bsk = match &self.wopbs_key.wopbs_server_key.bootstrapping_key {
|
||||
shortint::server_key::ShortintBootstrappingKey::Classic(fbsk) => fbsk.as_view(),
|
||||
_ => unreachable!(),
|
||||
};
|
||||
let mut lwe_out_list = LweCiphertextList::new(
|
||||
0u64,
|
||||
ksk.output_lwe_size(),
|
||||
LweCiphertextCount(log_message_modulus),
|
||||
ciphertext_modulus,
|
||||
);
|
||||
extract_bits(
|
||||
lwe_out_list.as_mut_view(),
|
||||
ct1.ct.as_view(),
|
||||
ksk.as_view(),
|
||||
fourier_bsk,
|
||||
DeltaLog(delta_log),
|
||||
ExtractedBitsCount(log_message_modulus),
|
||||
fft,
|
||||
stack,
|
||||
);
|
||||
let mut out_vec_ggsw: Vec<FourierGgswCiphertext<ABox<[c64]>>> = Vec::new();
|
||||
for lwe in lwe_out_list.iter() {
|
||||
let ggsw = self.ggsw_cbs_parallelized(&lwe, 0);
|
||||
out_vec_ggsw.append(&mut vec![ggsw]);
|
||||
}
|
||||
out_vec_ggsw
|
||||
}
|
||||
|
||||
//return ct0 if we have ggsw(0)
|
||||
//return ct1 if we have ggsw(1)
|
||||
//with cti a ShortintCiphertext
|
||||
pub fn cmux(
|
||||
&self,
|
||||
ct0: &ShortintCiphertext,
|
||||
ct1: &ShortintCiphertext,
|
||||
ggsw: &FourierGgswCiphertext<ABox<[c64]>>,
|
||||
) -> ShortintCiphertext {
|
||||
let polynomial_size = self.wopbs_key.param.polynomial_size;
|
||||
let glwe_dim = self.wopbs_key.param.glwe_dimension;
|
||||
let mut vec_0 = vec![0u64; polynomial_size.0 * (glwe_dim.0 + 1)];
|
||||
let mut vec_1 = vec![0u64; polynomial_size.0 * (glwe_dim.0 + 1)];
|
||||
for (i, (ct_i_0, ct_i_1)) in ct0
|
||||
.ct
|
||||
.as_ref()
|
||||
.iter()
|
||||
.zip(ct1.ct.as_ref().iter())
|
||||
.enumerate()
|
||||
{
|
||||
if i % polynomial_size.0 == 0 {
|
||||
vec_0[i] = *ct_i_0;
|
||||
vec_1[i] = *ct_i_1;
|
||||
} else {
|
||||
let index =
|
||||
(i / polynomial_size.0 + 1) * polynomial_size.0 - (i % polynomial_size.0);
|
||||
vec_0[index] = 0 - (*ct_i_0);
|
||||
vec_1[index] = 0 - (*ct_i_1);
|
||||
}
|
||||
}
|
||||
let mut rlwe_0 =
|
||||
GlweCiphertext::from_container(vec_0, polynomial_size, self.key.ciphertext_modulus);
|
||||
let mut rlwe_1 =
|
||||
GlweCiphertext::from_container(vec_1, polynomial_size, self.key.ciphertext_modulus);
|
||||
|
||||
cmux_assign(&mut rlwe_0, &mut rlwe_1, ggsw);
|
||||
|
||||
let mut output = LweCiphertext::new(
|
||||
0_u64,
|
||||
LweSize(polynomial_size.0 * glwe_dim.0 + 1),
|
||||
self.key.ciphertext_modulus,
|
||||
);
|
||||
extract_lwe_sample_from_glwe_ciphertext(&rlwe_0, &mut output, MonomialDegree(0));
|
||||
let ct_out = shortint::Ciphertext::new(
|
||||
output,
|
||||
Degree::new(max(ct0.degree.get(), ct1.degree.get())),
|
||||
NoiseLevel::NOMINAL, // TODO: check this is valid in the context of floats
|
||||
ct0.message_modulus,
|
||||
ct0.carry_modulus,
|
||||
PBSOrder::KeyswitchBootstrap,
|
||||
);
|
||||
ct_out
|
||||
}
|
||||
|
||||
//return ct0 in ct0 if we have ggsw(0)
|
||||
//return ct1 in ct0 if we have ggsw(1)
|
||||
//with cti = [Ciphertext]
|
||||
pub fn cmuxes(
|
||||
&self,
|
||||
ct0: &[ShortintCiphertext],
|
||||
ct1: &[ShortintCiphertext],
|
||||
ggsw: &FourierGgswCiphertext<ABox<[c64]>>,
|
||||
) -> Vec<shortint::Ciphertext> {
|
||||
let mut vec_output: Vec<ShortintCiphertext> = Vec::new();
|
||||
for (ct_0, ct_1) in ct0.iter().zip(ct1.iter()) {
|
||||
let output = self.cmux(ct_0, ct_1, ggsw);
|
||||
vec_output.push(output);
|
||||
}
|
||||
vec_output
|
||||
}
|
||||
|
||||
//return ct0 in a nwe ct if we have ggsw(0)
|
||||
//return ct1 in a new ct if we have ggsw(1)
|
||||
//with cti a fp
|
||||
pub fn cmuxes_full(
|
||||
&self,
|
||||
ct0: &Ciphertext,
|
||||
ct1: &Ciphertext,
|
||||
ggsw: &FourierGgswCiphertext<ABox<[c64]>>,
|
||||
) -> Ciphertext {
|
||||
let res_man = self.cmuxes(&ct0.ct_vec_mantissa, &ct1.ct_vec_mantissa, &ggsw);
|
||||
let res_exp = self.cmuxes(&ct0.ct_vec_exponent, &ct1.ct_vec_exponent, &ggsw);
|
||||
let res_sig = self.cmux(&ct0.ct_sign, &ct1.ct_sign, &ggsw);
|
||||
let mut new = self.create_trivial_zero_from_ct(ct0);
|
||||
new.ct_vec_mantissa = res_man;
|
||||
new.ct_vec_exponent = res_exp;
|
||||
new.ct_sign = res_sig;
|
||||
new
|
||||
}
|
||||
|
||||
pub fn cmux_tree_mantissa(
|
||||
&self,
|
||||
vec_mantissa: &Vec<shortint::Ciphertext>,
|
||||
vec_ggsw: &[FourierGgswCiphertext<ABox<[c64]>>],
|
||||
) -> Vec<shortint::Ciphertext> {
|
||||
let zero = self.key.create_trivial(0_u64);
|
||||
let mut cpy = vec_mantissa.clone();
|
||||
let mut vec_fp = Vec::new();
|
||||
for _ in 0..(vec_mantissa.len() + 1) {
|
||||
vec_fp.push(cpy.clone());
|
||||
cpy.push(zero.clone());
|
||||
let _ = cpy.remove(0);
|
||||
}
|
||||
let vec_zero = cpy;
|
||||
for ggsw in vec_ggsw.iter().rev() {
|
||||
if vec_fp.len() == 1 {
|
||||
vec_fp[0] = self.cmuxes(&mut vec_fp[0], &vec_zero, ggsw);
|
||||
} else {
|
||||
if vec_fp.len() % 2 == 0 {
|
||||
for i in 0..vec_fp.len() / 2 {
|
||||
let ct_0 = vec_fp.get_mut(2 * i).unwrap().clone();
|
||||
let ct_1 = vec_fp.get_mut(2 * i + 1).unwrap().clone();
|
||||
vec_fp[i] = self.cmuxes(&ct_0, &ct_1, ggsw);
|
||||
}
|
||||
vec_fp.truncate(vec_fp.len() / 2);
|
||||
} else {
|
||||
for i in 0..vec_fp.len() / 2 {
|
||||
let ct_0 = vec_fp.get_mut(2 * i).unwrap().clone();
|
||||
let ct_1 = vec_fp.get_mut(2 * i + 1).unwrap().clone();
|
||||
vec_fp[i] = self.cmuxes(&ct_0, &ct_1, ggsw);
|
||||
}
|
||||
let last = vec_fp.len();
|
||||
let ct_0 = vec_fp.last().unwrap().clone();
|
||||
let ct_1 = &vec_zero;
|
||||
vec_fp[last / 2] = self.cmuxes(&ct_0, &ct_1, ggsw);
|
||||
vec_fp.truncate((vec_fp.len() + 1) / 2);
|
||||
}
|
||||
}
|
||||
}
|
||||
vec_fp[0].clone()
|
||||
}
|
||||
|
||||
pub fn is_block_non_zero_ggsw_pbs_ks_cbs_parallelized(
|
||||
&self,
|
||||
ct1: &ShortintCiphertext,
|
||||
message_space: usize,
|
||||
) -> FourierGgswCiphertext<ABox<[c64]>> {
|
||||
let accumulator = self.key.generate_lookup_table(|x| u64::from(x != 0));
|
||||
let res = self.key.apply_lookup_table(&ct1, &accumulator);
|
||||
self.ggsw_ks_cbs_parallelized(&res, message_space)
|
||||
}
|
||||
|
||||
/// return ggsw(0) if ct1 = 0, return ggsw(1) otherwise
|
||||
pub fn ggsw_ks_cbs_parallelized(
|
||||
&self,
|
||||
ct1: &ShortintCiphertext,
|
||||
message_space: usize,
|
||||
) -> FourierGgswCiphertext<ABox<[c64]>> {
|
||||
let ciphertext_modulus = ct1.ct.ciphertext_modulus();
|
||||
|
||||
let mut res_ks = LweCiphertext::new(
|
||||
0u64,
|
||||
LweSize(self.wopbs_key.param.lwe_dimension.to_lwe_size().0),
|
||||
ciphertext_modulus,
|
||||
);
|
||||
par_keyswitch_lwe_ciphertext(&self.key.key_switching_key, &ct1.ct, &mut res_ks);
|
||||
self.ggsw_cbs_parallelized(&res_ks.as_view(), message_space)
|
||||
}
|
||||
|
||||
/// return ggsw(0) if ct1 = 0, return ggsw(1) otherwise
|
||||
pub fn ggsw_cbs_parallelized(
|
||||
&self,
|
||||
ct: &LweCiphertext<&[u64]>,
|
||||
message_space: usize,
|
||||
) -> FourierGgswCiphertext<ABox<[c64]>> {
|
||||
// todo!("ggsw_cbs_parallelized");
|
||||
let glwe_dimension = self.wopbs_key.param.glwe_dimension;
|
||||
let polynomial_size = self.wopbs_key.param.polynomial_size;
|
||||
let base_log_cbs = self.wopbs_key.param.cbs_base_log;
|
||||
let level_count_cbs = self.wopbs_key.param.cbs_level;
|
||||
let ciphertext_modulus = ct.ciphertext_modulus();
|
||||
|
||||
let fourier_bsk = match &self.wopbs_key.wopbs_server_key.bootstrapping_key {
|
||||
shortint::server_key::ShortintBootstrappingKey::Classic(fbsk) => fbsk.as_view(),
|
||||
_ => unreachable!(),
|
||||
};
|
||||
let fft = Fft::new(polynomial_size);
|
||||
let fft = fft.as_view();
|
||||
let mut cbs_res = GgswCiphertext::new(
|
||||
0u64,
|
||||
glwe_dimension.to_glwe_size(),
|
||||
polynomial_size,
|
||||
base_log_cbs,
|
||||
level_count_cbs,
|
||||
ciphertext_modulus,
|
||||
);
|
||||
let mut ggsw = FourierGgswCiphertext::new(
|
||||
glwe_dimension.to_glwe_size(),
|
||||
polynomial_size,
|
||||
base_log_cbs,
|
||||
level_count_cbs,
|
||||
);
|
||||
|
||||
let mut mem = GlobalPodBuffer::new(
|
||||
circuit_bootstrap_boolean_scratch::<u64>(
|
||||
ct.lwe_size(),
|
||||
fourier_bsk.output_lwe_dimension().to_lwe_size(),
|
||||
glwe_dimension.to_glwe_size(),
|
||||
polynomial_size,
|
||||
fft,
|
||||
)
|
||||
.unwrap(),
|
||||
);
|
||||
let mut stack = PodStack::new(&mut mem);
|
||||
circuit_bootstrap_boolean_parallelized(
|
||||
fourier_bsk,
|
||||
ct.as_view(),
|
||||
cbs_res.as_mut_view(),
|
||||
DeltaLog(63 - message_space),
|
||||
self.wopbs_key.cbs_pfpksk.as_view(),
|
||||
fft,
|
||||
stack.rb_mut(),
|
||||
);
|
||||
|
||||
let mut mem = GlobalPodBuffer::new(fill_with_forward_fourier_scratch(fft).unwrap());
|
||||
let mut _stack = PodStack::new(&mut mem);
|
||||
|
||||
par_convert_polynomials_list_to_fourier(
|
||||
ggsw.as_mut_view().data(),
|
||||
cbs_res.as_ref(),
|
||||
polynomial_size,
|
||||
fft,
|
||||
);
|
||||
// ggsw.as_mut_view()
|
||||
// .fill_with_forward_fourier(cbs_res.as_view(), fft, stack.rb_mut());
|
||||
ggsw
|
||||
}
|
||||
|
||||
pub fn extract_bit_cbs_parallelized(
|
||||
&self,
|
||||
ct1: &ShortintCiphertext,
|
||||
) -> Vec<FourierGgswCiphertext<ABox<[c64]>>> {
|
||||
// todo!("extract_bit_cbs_parallelized");
|
||||
let glwe_dimension = self.wopbs_key.param.glwe_dimension;
|
||||
let polynomial_size = self.wopbs_key.param.polynomial_size;
|
||||
let lwe_dimension = self.wopbs_key.param.lwe_dimension;
|
||||
let message_modulus = self.wopbs_key.param.message_modulus;
|
||||
let log_message_modulus = f64::log2(message_modulus.0 as f64) as usize;
|
||||
let log_carry_modulus = f64::log2(self.wopbs_key.param.carry_modulus.0 as f64) as usize;
|
||||
let ciphertext_modulus = ct1.ct.ciphertext_modulus();
|
||||
|
||||
let ksk = &self.key.key_switching_key;
|
||||
let delta_log = 63 - log_message_modulus * log_carry_modulus;
|
||||
let fft = Fft::new(polynomial_size);
|
||||
let fft = fft.as_view();
|
||||
let req = || {
|
||||
StackReq::try_any_of([
|
||||
fill_with_forward_fourier_scratch(fft)?,
|
||||
extract_bits_scratch::<u64>(
|
||||
lwe_dimension,
|
||||
LweDimension(polynomial_size.0 * glwe_dimension.0 + 1),
|
||||
glwe_dimension.to_glwe_size(),
|
||||
polynomial_size,
|
||||
fft,
|
||||
)?,
|
||||
])
|
||||
};
|
||||
let req = req().unwrap();
|
||||
let mut mem = GlobalPodBuffer::new(req);
|
||||
let stack = PodStack::new(&mut mem);
|
||||
let fourier_bsk = match &self.wopbs_key.wopbs_server_key.bootstrapping_key {
|
||||
shortint::server_key::ShortintBootstrappingKey::Classic(fbsk) => fbsk.as_view(),
|
||||
_ => unreachable!(),
|
||||
};
|
||||
let mut lwe_out_list = LweCiphertextList::new(
|
||||
0u64,
|
||||
ksk.output_lwe_size(),
|
||||
LweCiphertextCount(log_message_modulus),
|
||||
ciphertext_modulus,
|
||||
);
|
||||
extract_bits_parallelized(
|
||||
lwe_out_list.as_mut_view(),
|
||||
ct1.ct.as_view(),
|
||||
ksk.as_view(),
|
||||
fourier_bsk,
|
||||
DeltaLog(delta_log),
|
||||
ExtractedBitsCount(log_message_modulus),
|
||||
fft,
|
||||
stack,
|
||||
);
|
||||
let mut out_vec_ggsw: Vec<FourierGgswCiphertext<ABox<[c64]>>> = Vec::new();
|
||||
for lwe in lwe_out_list.iter() {
|
||||
let ggsw = self.ggsw_cbs(&lwe, 0);
|
||||
out_vec_ggsw.append(&mut vec![ggsw]);
|
||||
}
|
||||
out_vec_ggsw
|
||||
}
|
||||
|
||||
//return ct0 in ct0 if we have ggsw(0)
|
||||
//return ct1 in ct0 if we have ggsw(1)
|
||||
//with cti = [Ciphertext]
|
||||
pub fn cmuxes_parallelized(
|
||||
&self,
|
||||
ct0: &[ShortintCiphertext],
|
||||
ct1: &[ShortintCiphertext],
|
||||
ggsw: &FourierGgswCiphertext<ABox<[c64]>>,
|
||||
) -> Vec<shortint::Ciphertext> {
|
||||
assert_eq!(ct0.len(), ct1.len());
|
||||
let len = ct0.len();
|
||||
let mut vec_output: Vec<ShortintCiphertext> = Vec::with_capacity(len);
|
||||
|
||||
ct0.par_iter()
|
||||
.zip(ct1.par_iter())
|
||||
.map(|(ct_0_i, ct_1_i)| self.cmux(ct_0_i, ct_1_i, ggsw))
|
||||
.collect_into_vec(&mut vec_output);
|
||||
|
||||
vec_output
|
||||
}
|
||||
|
||||
//return ct0 in a nwe ct if we have ggsw(0)
|
||||
//return ct1 in a new ct if we have ggsw(1)
|
||||
//with cti a fp
|
||||
pub fn cmuxes_full_parallelized(
|
||||
&self,
|
||||
ct0: &Ciphertext,
|
||||
ct1: &Ciphertext,
|
||||
ggsw: &FourierGgswCiphertext<ABox<[c64]>>,
|
||||
) -> Ciphertext {
|
||||
// todo!("cmuxes_full_parallelized");
|
||||
let (res_man, res_exp) = rayon::join(
|
||||
|| self.cmuxes_parallelized(&ct0.ct_vec_mantissa, &ct1.ct_vec_mantissa, &ggsw),
|
||||
|| self.cmuxes_parallelized(&ct0.ct_vec_exponent, &ct1.ct_vec_exponent, &ggsw),
|
||||
);
|
||||
let res_sig = self.cmux(&ct0.ct_sign, &ct1.ct_sign, &ggsw);
|
||||
let mut new = self.create_trivial_zero_from_ct(ct0);
|
||||
new.ct_vec_mantissa = res_man;
|
||||
new.ct_vec_exponent = res_exp;
|
||||
new.ct_sign = res_sig;
|
||||
new
|
||||
}
|
||||
|
||||
pub fn cmux_tree_mantissa_parallelized(
|
||||
&self,
|
||||
vec_mantissa: &Vec<shortint::Ciphertext>,
|
||||
vec_ggsw: &[FourierGgswCiphertext<ABox<[c64]>>],
|
||||
) -> Vec<shortint::Ciphertext> {
|
||||
// todo!("cmux_tree_mantissa_parallelized");
|
||||
let zero = self.key.create_trivial(0_u64);
|
||||
let mut cpy = vec_mantissa.clone();
|
||||
let mut vec_fp = Vec::new();
|
||||
for _ in 0..(vec_mantissa.len() + 1) {
|
||||
vec_fp.push(cpy.clone());
|
||||
cpy.push(zero.clone());
|
||||
let _ = cpy.remove(0);
|
||||
}
|
||||
let vec_zero = cpy;
|
||||
// TODO cmux tree in parallel
|
||||
for ggsw in vec_ggsw.iter().rev() {
|
||||
if vec_fp.len() == 1 {
|
||||
vec_fp[0] = self.cmuxes_parallelized(&mut vec_fp[0], &vec_zero, ggsw);
|
||||
} else {
|
||||
if vec_fp.len() % 2 == 0 {
|
||||
for i in 0..vec_fp.len() / 2 {
|
||||
let ct_0 = vec_fp.get_mut(2 * i).unwrap().clone();
|
||||
let ct_1 = vec_fp.get_mut(2 * i + 1).unwrap().clone();
|
||||
vec_fp[i] = self.cmuxes_parallelized(&ct_0, &ct_1, ggsw);
|
||||
}
|
||||
vec_fp.truncate(vec_fp.len() / 2);
|
||||
} else {
|
||||
for i in 0..vec_fp.len() / 2 {
|
||||
let ct_0 = vec_fp.get_mut(2 * i).unwrap().clone();
|
||||
let ct_1 = vec_fp.get_mut(2 * i + 1).unwrap().clone();
|
||||
vec_fp[i] = self.cmuxes_parallelized(&ct_0, &ct_1, ggsw);
|
||||
}
|
||||
let last = vec_fp.len();
|
||||
let ct_0 = vec_fp.last().unwrap().clone();
|
||||
let ct_1 = &vec_zero;
|
||||
vec_fp[last / 2] = self.cmuxes_parallelized(&ct_0, &ct_1, ggsw);
|
||||
vec_fp.truncate((vec_fp.len() + 1) / 2);
|
||||
}
|
||||
}
|
||||
}
|
||||
vec_fp[0].clone()
|
||||
}
|
||||
}
|
||||
9
concrete-float/src/test_user_docs.rs
Normal file
9
concrete-float/src/test_user_docs.rs
Normal file
@@ -0,0 +1,9 @@
|
||||
use doc_comment::doctest;
|
||||
|
||||
doctest!("../docs/getting_started/first_circuit.md", first_circuit);
|
||||
doctest!("../docs/tutorials/serialization.md", serialization_tuto);
|
||||
doctest!(
|
||||
"../docs/tutorials/circuit_evaluation.md",
|
||||
circuit_evaluation
|
||||
);
|
||||
doctest!("../docs/how_to/pbs.md", pbs);
|
||||
@@ -22,8 +22,6 @@ not_multi_bit="_multi_bit"
|
||||
signed=""
|
||||
not_signed=""
|
||||
cargo_profile="release"
|
||||
# TODO: revert to release once the bug is properly fixed/identified
|
||||
cargo_profile_doctests="release_lto_off"
|
||||
avx512_feature=""
|
||||
|
||||
while [ -n "$1" ]
|
||||
@@ -163,7 +161,7 @@ cargo "${RUST_TOOLCHAIN}" nextest run \
|
||||
|
||||
if [[ "${multi_bit}" == "" ]]; then
|
||||
cargo "${RUST_TOOLCHAIN}" test \
|
||||
--profile "${cargo_profile_doctests}" \
|
||||
--profile "${cargo_profile}" \
|
||||
--package tfhe \
|
||||
--features="${ARCH_FEATURE}",integer,internal-keycache,"${avx512_feature}" \
|
||||
--doc \
|
||||
|
||||
@@ -55,7 +55,7 @@ concrete-csprng = { version = "0.4.0", path = "../concrete-csprng", features = [
|
||||
] }
|
||||
lazy_static = { version = "1.4.0", optional = true }
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
rayon = { version = "1.5.0" }
|
||||
rayon = { version = "1.5" }
|
||||
bincode = { version = "1.3.3", optional = true }
|
||||
concrete-fft = { version = "0.3.0", features = ["serde", "fft128"] }
|
||||
pulp = "0.13"
|
||||
@@ -82,6 +82,7 @@ bytemuck = "1.13.1"
|
||||
boolean = ["dep:paste"]
|
||||
shortint = ["dep:paste"]
|
||||
integer = ["shortint", "dep:paste"]
|
||||
float_wopbs = ["shortint", "dep:paste"]
|
||||
internal-keycache = ["dep:lazy_static", "dep:fs2", "dep:bincode", "dep:paste"]
|
||||
safe-deserialization = ["dep:bincode"]
|
||||
|
||||
@@ -103,7 +104,6 @@ __wasm_api = [
|
||||
"dep:console_error_panic_hook",
|
||||
"dep:serde-wasm-bindgen",
|
||||
"dep:getrandom",
|
||||
"getrandom/js",
|
||||
"dep:bincode",
|
||||
"safe-deserialization",
|
||||
]
|
||||
@@ -151,6 +151,12 @@ rustdoc-args = ["--html-in-header", "katex-header.html"]
|
||||
# #
|
||||
###########
|
||||
|
||||
[[bench]]
|
||||
name = "ks-bench"
|
||||
path = "benches/core_crypto/ks_bench.rs"
|
||||
harness = false
|
||||
required-features = ["shortint", "internal-keycache"]
|
||||
|
||||
[[bench]]
|
||||
name = "pbs-bench"
|
||||
path = "benches/core_crypto/pbs_bench.rs"
|
||||
@@ -204,6 +210,12 @@ path = "benches/utilities.rs"
|
||||
harness = false
|
||||
required-features = ["boolean", "shortint", "integer", "internal-keycache"]
|
||||
|
||||
[[bench]]
|
||||
name = "float-wopbs-bench"
|
||||
path = "benches/float_wopbs/bench.rs"
|
||||
harness = false
|
||||
required-features = []
|
||||
|
||||
# Examples used as tools
|
||||
|
||||
[[example]]
|
||||
@@ -257,3 +269,7 @@ required-features = ["boolean"]
|
||||
|
||||
[lib]
|
||||
crate-type = ["lib", "staticlib", "cdylib"]
|
||||
|
||||
|
||||
[lints.rust]
|
||||
unexpected_cfgs = { level = "warn", check-cfg = ['cfg(bench)'] }
|
||||
|
||||
91
tfhe/benches/core_crypto/ks_bench.rs
Normal file
91
tfhe/benches/core_crypto/ks_bench.rs
Normal file
@@ -0,0 +1,91 @@
|
||||
use criterion::{criterion_group, criterion_main, Criterion};
|
||||
use tfhe::core_crypto::prelude::*;
|
||||
use tfhe::keycache::NamedParam;
|
||||
use tfhe::shortint::prelude::*;
|
||||
|
||||
fn criterion_bench(criterion: &mut Criterion) {
|
||||
type Scalar = u64;
|
||||
|
||||
let mut bench_group = criterion.benchmark_group("KS");
|
||||
bench_group
|
||||
.sample_size(15)
|
||||
.measurement_time(std::time::Duration::from_secs(60));
|
||||
|
||||
for params in [
|
||||
PARAM_MESSAGE_1_CARRY_1_KS_PBS,
|
||||
PARAM_MESSAGE_2_CARRY_2_KS_PBS,
|
||||
PARAM_MESSAGE_3_CARRY_3_KS_PBS,
|
||||
PARAM_MESSAGE_4_CARRY_4_KS_PBS,
|
||||
]
|
||||
.into_iter()
|
||||
{
|
||||
let lwe_dimension = params.lwe_dimension;
|
||||
let lwe_modular_std_dev = params.lwe_modular_std_dev;
|
||||
let ciphertext_modulus = params.ciphertext_modulus;
|
||||
let encoding_with_padding = if ciphertext_modulus.is_native_modulus() {
|
||||
Scalar::ONE << (Scalar::BITS - 1)
|
||||
} else {
|
||||
Scalar::cast_from(ciphertext_modulus.get_custom_modulus() / 2)
|
||||
};
|
||||
let glwe_dimension = params.glwe_dimension;
|
||||
let polynomial_size = params.polynomial_size;
|
||||
let ks_decomp_base_log = params.ks_base_log;
|
||||
let ks_decomp_level_count = params.ks_level;
|
||||
let msg_modulus: Scalar = params.message_modulus.0.cast_into();
|
||||
let total_modulus: Scalar = (params.message_modulus.0 * params.carry_modulus.0).cast_into();
|
||||
|
||||
let msg = msg_modulus - 1;
|
||||
let delta: Scalar = encoding_with_padding / total_modulus;
|
||||
|
||||
// Create the PRNG
|
||||
let mut seeder = new_seeder();
|
||||
let seeder = seeder.as_mut();
|
||||
let mut encryption_generator =
|
||||
EncryptionRandomGenerator::<ActivatedRandomGenerator>::new(seeder.seed(), seeder);
|
||||
let mut secret_generator =
|
||||
SecretRandomGenerator::<ActivatedRandomGenerator>::new(seeder.seed());
|
||||
|
||||
let lwe_sk =
|
||||
allocate_and_generate_new_binary_lwe_secret_key(lwe_dimension, &mut secret_generator);
|
||||
|
||||
let glwe_sk = allocate_and_generate_new_binary_glwe_secret_key(
|
||||
glwe_dimension,
|
||||
polynomial_size,
|
||||
&mut secret_generator,
|
||||
);
|
||||
let big_lwe_sk = glwe_sk.into_lwe_secret_key();
|
||||
let ksk_big_to_small = allocate_and_generate_new_lwe_keyswitch_key(
|
||||
&big_lwe_sk,
|
||||
&lwe_sk,
|
||||
ks_decomp_base_log,
|
||||
ks_decomp_level_count,
|
||||
lwe_modular_std_dev,
|
||||
ciphertext_modulus,
|
||||
&mut encryption_generator,
|
||||
);
|
||||
|
||||
let plaintext = Plaintext(msg * delta);
|
||||
let ct = allocate_and_encrypt_new_lwe_ciphertext(
|
||||
&big_lwe_sk,
|
||||
plaintext,
|
||||
lwe_modular_std_dev,
|
||||
ciphertext_modulus,
|
||||
&mut encryption_generator,
|
||||
);
|
||||
|
||||
let mut output_ct = LweCiphertext::new(
|
||||
Scalar::ZERO,
|
||||
lwe_sk.lwe_dimension().to_lwe_size(),
|
||||
ciphertext_modulus,
|
||||
);
|
||||
|
||||
bench_group.bench_function(¶ms.name(), |bencher| {
|
||||
bencher.iter(|| {
|
||||
keyswitch_lwe_ciphertext(&ksk_big_to_small, &ct, &mut output_ct);
|
||||
})
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
criterion_group!(benches, criterion_bench);
|
||||
criterion_main!(benches);
|
||||
90
tfhe/benches/float_wopbs/bench.rs
Normal file
90
tfhe/benches/float_wopbs/bench.rs
Normal file
@@ -0,0 +1,90 @@
|
||||
#![allow(dead_code)]
|
||||
|
||||
use criterion::{criterion_group, criterion_main, Criterion};
|
||||
use tfhe::float_wopbs::gen_keys;
|
||||
|
||||
#[allow(unused_imports)]
|
||||
use tfhe::float_wopbs::parameters::{
|
||||
PARAM_MESSAGE_2_16_BITS, PARAM_MESSAGE_4_16_BITS, PARAM_MESSAGE_8_16_BITS,
|
||||
};
|
||||
use tfhe::float_wopbs::parameters::{ PARAM_MESSAGE_2_4_8_BITS_BIV, PARAM_MESSAGE_4_2_8_BITS_BIV};
|
||||
use tfhe::shortint::WopbsParameters;
|
||||
|
||||
macro_rules! named_param {
|
||||
($param:ident) => {
|
||||
(stringify!($param), $param)
|
||||
};
|
||||
}
|
||||
|
||||
struct Parameters {
|
||||
parameters: WopbsParameters,
|
||||
bit_mantissa: usize,
|
||||
bit_exponent: usize,
|
||||
}
|
||||
|
||||
|
||||
const PARAM_4_BIT_LWE_8_BITS: Parameters = Parameters {
|
||||
parameters: PARAM_MESSAGE_2_4_8_BITS_BIV,
|
||||
bit_mantissa: 4,
|
||||
bit_exponent: 3,
|
||||
};
|
||||
|
||||
const PARAM_2_BIT_LWE_8_BITS: Parameters = Parameters {
|
||||
parameters: PARAM_MESSAGE_4_2_8_BITS_BIV,
|
||||
bit_mantissa: 4,
|
||||
bit_exponent: 3,
|
||||
};
|
||||
|
||||
|
||||
const SERVER_KEY_BENCH_PARAMS: [(&str, Parameters); 2] =
|
||||
[ named_param!(PARAM_4_BIT_LWE_8_BITS),
|
||||
named_param!(PARAM_2_BIT_LWE_8_BITS)];
|
||||
|
||||
criterion_main!(float);
|
||||
|
||||
criterion_group!(float, float_wopbs_bivariate);
|
||||
|
||||
pub fn float_wopbs_mut_eval(c: &mut Criterion) {
|
||||
for name_param in SERVER_KEY_BENCH_PARAMS {
|
||||
let (cks, sks) = gen_keys(name_param.1.parameters);
|
||||
let bit_mantissa = &name_param.1.bit_mantissa;
|
||||
let bit_exponent = &name_param.1.bit_exponent;
|
||||
let e_min = -2;
|
||||
let msg_1 = 0.375;
|
||||
|
||||
// Encryption:
|
||||
let mut ct_1 = cks.encrypt(msg_1, e_min, *bit_mantissa, *bit_exponent);
|
||||
|
||||
let lut = sks.create_lut(&mut ct_1, |x| x);
|
||||
let bench_id = format!("8-bit floats WoP-PBS lut eval::{}", name_param.0);
|
||||
c.bench_function(&bench_id, |b| {
|
||||
b.iter(|| {
|
||||
sks.wop_pbs(&sks, &mut ct_1, &lut);
|
||||
})
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
pub fn float_wopbs_bivariate(c: &mut Criterion) {
|
||||
for name_param in SERVER_KEY_BENCH_PARAMS {
|
||||
let (cks, sks) = gen_keys(name_param.1.parameters);
|
||||
let bit_mantissa = &name_param.1.bit_mantissa;
|
||||
let bit_exponent = &name_param.1.bit_exponent;
|
||||
|
||||
let e_min = -2;
|
||||
let msg_1 = 0.375;
|
||||
|
||||
// Encryption:
|
||||
let mut ct_1 = cks.encrypt(msg_1, e_min, *bit_mantissa, *bit_exponent);
|
||||
let msg_2 = -44.;
|
||||
let mut ct_2 = cks.encrypt(msg_2, e_min, *bit_mantissa, *bit_exponent);
|
||||
|
||||
let lut = sks.create_bivariate_lut(&mut ct_1, |x, y| y * x);
|
||||
let bench_id = format!("8-bit floats WoP-PBS bivariate::{}", name_param.0);
|
||||
c.bench_function(&bench_id, |b| {
|
||||
b.iter(|| {
|
||||
sks.wop_pbs_bivariate(&sks, &mut ct_1, &mut ct_2, &lut);
|
||||
})
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -393,7 +393,7 @@ fn _bench_wopbs_param_message_8_norm2_5(c: &mut Criterion) {
|
||||
let mut bench_group = c.benchmark_group("programmable_bootstrap");
|
||||
|
||||
let param = WOPBS_PARAM_MESSAGE_4_NORM2_6_KS_PBS;
|
||||
let param_set: ShortintParameterSet = param.try_into().unwrap();
|
||||
let param_set: ShortintParameterSet = param.into();
|
||||
let pbs_params = param_set.pbs_parameters().unwrap();
|
||||
|
||||
let keys = KEY_CACHE_WOPBS.get_from_param((pbs_params, param));
|
||||
@@ -402,14 +402,14 @@ fn _bench_wopbs_param_message_8_norm2_5(c: &mut Criterion) {
|
||||
let mut rng = rand::thread_rng();
|
||||
|
||||
let clear = rng.gen::<usize>() % param.message_modulus.0;
|
||||
let mut ct = cks.encrypt_without_padding(clear as u64);
|
||||
let ct = cks.encrypt_without_padding(clear as u64);
|
||||
let vec_lut = wopbs_key.generate_lut_native_crt(&ct, |x| x);
|
||||
|
||||
let id = format!("Shortint WOPBS: {param:?}");
|
||||
|
||||
bench_group.bench_function(&id, |b| {
|
||||
b.iter(|| {
|
||||
let _ = wopbs_key.programmable_bootstrapping_native_crt(&mut ct, &vec_lut);
|
||||
let _ = wopbs_key.programmable_bootstrapping_native_crt(&ct, &vec_lut);
|
||||
})
|
||||
});
|
||||
|
||||
|
||||
@@ -162,11 +162,9 @@ fn main() {
|
||||
let ct_2 = client_key.encrypt(msg2);
|
||||
let ct_3 = client_key.encrypt(msg3);
|
||||
|
||||
let result = server_key.checked_small_scalar_mul_assign(&mut ct_1, scalar);
|
||||
assert!(result.is_ok());
|
||||
|
||||
let result = server_key.checked_sub_assign(&mut ct_1, &ct_2);
|
||||
assert!(result.is_ok());
|
||||
server_key.checked_small_scalar_mul_assign(&mut ct_1, scalar).unwrap();
|
||||
|
||||
server_key.checked_sub_assign(&mut ct_1, &ct_2).unwrap();
|
||||
|
||||
let result = server_key.checked_add_assign(&mut ct_1, &ct_3);
|
||||
assert!(result.is_err());
|
||||
|
||||
@@ -19,7 +19,7 @@ pub fn has_match(
|
||||
|
||||
let res = if branches.len() <= 1 {
|
||||
branches
|
||||
.get(0)
|
||||
.first()
|
||||
.map_or(exec.ct_false(), |branch| branch(&mut exec))
|
||||
.0
|
||||
} else {
|
||||
|
||||
@@ -8,7 +8,6 @@ use crate::core_crypto::commons::parameters::{CiphertextModulus, PBSOrder};
|
||||
use crate::core_crypto::entities::*;
|
||||
use crate::core_crypto::fft_impl::fft64::math::fft::Fft;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::error::Error;
|
||||
|
||||
/// Memory used as buffer for the bootstrap
|
||||
///
|
||||
@@ -149,17 +148,14 @@ pub(crate) struct Bootstrapper {
|
||||
impl Bootstrapper {
|
||||
pub fn new(seeder: &mut dyn Seeder) -> Self {
|
||||
Self {
|
||||
memory: Default::default(),
|
||||
memory: Memory::default(),
|
||||
encryption_generator: EncryptionRandomGenerator::<_>::new(seeder.seed(), seeder),
|
||||
computation_buffers: Default::default(),
|
||||
computation_buffers: ComputationBuffers::default(),
|
||||
seeder: DeterministicSeeder::<_>::new(seeder.seed()),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn new_server_key(
|
||||
&mut self,
|
||||
cks: &ClientKey,
|
||||
) -> Result<ServerKey, Box<dyn std::error::Error>> {
|
||||
pub(crate) fn new_server_key(&mut self, cks: &ClientKey) -> ServerKey {
|
||||
let standard_bootstrapping_key: LweBootstrapKeyOwned<u32> =
|
||||
par_allocate_and_generate_new_lwe_bootstrap_key(
|
||||
&cks.lwe_secret_key,
|
||||
@@ -208,17 +204,14 @@ impl Bootstrapper {
|
||||
&mut self.encryption_generator,
|
||||
);
|
||||
|
||||
Ok(ServerKey {
|
||||
ServerKey {
|
||||
bootstrapping_key: fourier_bsk,
|
||||
key_switching_key: ksk,
|
||||
pbs_order: cks.parameters.encryption_key_choice.into(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn new_compressed_server_key(
|
||||
&mut self,
|
||||
cks: &ClientKey,
|
||||
) -> Result<CompressedServerKey, Box<dyn std::error::Error>> {
|
||||
pub(crate) fn new_compressed_server_key(&mut self, cks: &ClientKey) -> CompressedServerKey {
|
||||
#[cfg(not(feature = "__wasm_api"))]
|
||||
let bootstrapping_key = par_allocate_and_generate_new_seeded_lwe_bootstrap_key(
|
||||
&cks.lwe_secret_key,
|
||||
@@ -254,18 +247,18 @@ impl Bootstrapper {
|
||||
&mut self.seeder,
|
||||
);
|
||||
|
||||
Ok(CompressedServerKey {
|
||||
CompressedServerKey {
|
||||
bootstrapping_key,
|
||||
key_switching_key,
|
||||
pbs_order: cks.parameters.encryption_key_choice.into(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn bootstrap(
|
||||
&mut self,
|
||||
input: &LweCiphertextOwned<u32>,
|
||||
server_key: &ServerKey,
|
||||
) -> Result<LweCiphertextOwned<u32>, Box<dyn Error>> {
|
||||
) -> LweCiphertextOwned<u32> {
|
||||
let BuffersRef {
|
||||
lookup_table: accumulator,
|
||||
mut buffer_lwe_after_pbs,
|
||||
@@ -297,37 +290,17 @@ impl Bootstrapper {
|
||||
stack,
|
||||
);
|
||||
|
||||
Ok(LweCiphertext::from_container(
|
||||
LweCiphertext::from_container(
|
||||
buffer_lwe_after_pbs.as_ref().to_owned(),
|
||||
input.ciphertext_modulus(),
|
||||
))
|
||||
}
|
||||
|
||||
pub(crate) fn keyswitch(
|
||||
&mut self,
|
||||
input: &LweCiphertextOwned<u32>,
|
||||
server_key: &ServerKey,
|
||||
) -> Result<LweCiphertextOwned<u32>, Box<dyn Error>> {
|
||||
// Allocate the output of the KS
|
||||
let mut output = LweCiphertext::new(
|
||||
0u32,
|
||||
server_key
|
||||
.bootstrapping_key
|
||||
.input_lwe_dimension()
|
||||
.to_lwe_size(),
|
||||
input.ciphertext_modulus(),
|
||||
);
|
||||
|
||||
keyswitch_lwe_ciphertext(&server_key.key_switching_key, input, &mut output);
|
||||
|
||||
Ok(output)
|
||||
)
|
||||
}
|
||||
|
||||
pub(crate) fn bootstrap_keyswitch(
|
||||
&mut self,
|
||||
mut ciphertext: LweCiphertextOwned<u32>,
|
||||
server_key: &ServerKey,
|
||||
) -> Result<Ciphertext, Box<dyn Error>> {
|
||||
) -> Ciphertext {
|
||||
let BuffersRef {
|
||||
lookup_table,
|
||||
mut buffer_lwe_after_pbs,
|
||||
@@ -367,14 +340,14 @@ impl Bootstrapper {
|
||||
&mut ciphertext,
|
||||
);
|
||||
|
||||
Ok(Ciphertext::Encrypted(ciphertext))
|
||||
Ciphertext::Encrypted(ciphertext)
|
||||
}
|
||||
|
||||
pub(crate) fn keyswitch_bootstrap(
|
||||
&mut self,
|
||||
mut ciphertext: LweCiphertextOwned<u32>,
|
||||
server_key: &ServerKey,
|
||||
) -> Result<Ciphertext, Box<dyn Error>> {
|
||||
) -> Ciphertext {
|
||||
let BuffersRef {
|
||||
lookup_table,
|
||||
mut buffer_lwe_after_ks,
|
||||
@@ -414,13 +387,13 @@ impl Bootstrapper {
|
||||
stack,
|
||||
);
|
||||
|
||||
Ok(Ciphertext::Encrypted(ciphertext))
|
||||
Ciphertext::Encrypted(ciphertext)
|
||||
}
|
||||
pub(crate) fn apply_bootstrapping_pattern(
|
||||
&mut self,
|
||||
ct: LweCiphertextOwned<u32>,
|
||||
server_key: &ServerKey,
|
||||
) -> Result<Ciphertext, Box<dyn Error>> {
|
||||
) -> Ciphertext {
|
||||
match server_key.pbs_order {
|
||||
PBSOrder::KeyswitchBootstrap => self.keyswitch_bootstrap(ct, server_key),
|
||||
PBSOrder::BootstrapKeyswitch => self.bootstrap_keyswitch(ct, server_key),
|
||||
@@ -428,6 +401,21 @@ impl Bootstrapper {
|
||||
}
|
||||
}
|
||||
|
||||
impl ServerKey {
|
||||
pub(crate) fn keyswitch(&self, input: &LweCiphertextOwned<u32>) -> LweCiphertextOwned<u32> {
|
||||
// Allocate the output of the KS
|
||||
let mut output = LweCiphertext::new(
|
||||
0u32,
|
||||
self.bootstrapping_key.input_lwe_dimension().to_lwe_size(),
|
||||
input.ciphertext_modulus(),
|
||||
);
|
||||
|
||||
keyswitch_lwe_ciphertext(&self.key_switching_key, input, &mut output);
|
||||
|
||||
output
|
||||
}
|
||||
}
|
||||
|
||||
impl From<CompressedServerKey> for ServerKey {
|
||||
fn from(compressed_server_key: CompressedServerKey) -> Self {
|
||||
let CompressedServerKey {
|
||||
@@ -458,8 +446,8 @@ impl From<CompressedServerKey> for ServerKey {
|
||||
);
|
||||
|
||||
Self {
|
||||
key_switching_key,
|
||||
bootstrapping_key,
|
||||
key_switching_key,
|
||||
pbs_order,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -103,11 +103,11 @@ impl BooleanEngine {
|
||||
}
|
||||
|
||||
pub fn create_server_key(&mut self, cks: &ClientKey) -> ServerKey {
|
||||
self.bootstrapper.new_server_key(cks).unwrap()
|
||||
self.bootstrapper.new_server_key(cks)
|
||||
}
|
||||
|
||||
pub fn create_compressed_server_key(&mut self, cks: &ClientKey) -> CompressedServerKey {
|
||||
self.bootstrapper.new_compressed_server_key(cks).unwrap()
|
||||
self.bootstrapper.new_compressed_server_key(cks)
|
||||
}
|
||||
|
||||
pub fn create_public_key(&mut self, client_key: &ClientKey) -> PublicKey {
|
||||
@@ -149,7 +149,7 @@ impl BooleanEngine {
|
||||
|
||||
PublicKey {
|
||||
lwe_public_key,
|
||||
parameters: client_key.parameters.to_owned(),
|
||||
parameters: client_key.parameters,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -217,9 +217,8 @@ impl BooleanEngine {
|
||||
}
|
||||
(choice1, choice2) => panic!(
|
||||
"EncryptionKeyChoice of cks1 and cks2 must be the same.\
|
||||
cks1 has {:?}, cks2 has: {:?}
|
||||
",
|
||||
choice1, choice2
|
||||
cks1 has {choice1:?}, cks2 has: {choice2:?}
|
||||
"
|
||||
),
|
||||
};
|
||||
|
||||
@@ -434,7 +433,7 @@ impl BooleanEngine {
|
||||
pub fn replace_thread_local(new_engine: Self) {
|
||||
Self::with_thread_local_mut(|local_engine| {
|
||||
let _ = std::mem::replace(local_engine, new_engine);
|
||||
})
|
||||
});
|
||||
}
|
||||
|
||||
pub fn new() -> Self {
|
||||
@@ -459,42 +458,6 @@ impl BooleanEngine {
|
||||
}
|
||||
}
|
||||
|
||||
/// convert into an actual LWE ciphertext even when trivial
|
||||
fn convert_into_lwe_ciphertext_32(
|
||||
&mut self,
|
||||
ct: &Ciphertext,
|
||||
server_key: &ServerKey,
|
||||
) -> LweCiphertextOwned<u32> {
|
||||
match ct {
|
||||
Ciphertext::Encrypted(ct_ct) => ct_ct.clone(),
|
||||
Ciphertext::Trivial(message) => {
|
||||
// encode the boolean message
|
||||
let plain: Plaintext<u32> = if *message {
|
||||
Plaintext(PLAINTEXT_TRUE)
|
||||
} else {
|
||||
Plaintext(PLAINTEXT_FALSE)
|
||||
};
|
||||
|
||||
let lwe_size = match server_key.pbs_order {
|
||||
PBSOrder::KeyswitchBootstrap => server_key
|
||||
.key_switching_key
|
||||
.input_key_lwe_dimension()
|
||||
.to_lwe_size(),
|
||||
PBSOrder::BootstrapKeyswitch => server_key
|
||||
.bootstrapping_key
|
||||
.input_lwe_dimension()
|
||||
.to_lwe_size(),
|
||||
};
|
||||
|
||||
allocate_and_trivially_encrypt_new_lwe_ciphertext(
|
||||
lwe_size,
|
||||
plain,
|
||||
CiphertextModulus::new_native(),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn mux(
|
||||
&mut self,
|
||||
ct_condition: &Ciphertext,
|
||||
@@ -537,8 +500,8 @@ impl BooleanEngine {
|
||||
}
|
||||
|
||||
// convert inputs into LweCiphertextOwned<u32>
|
||||
let ct_then_ct = self.convert_into_lwe_ciphertext_32(ct_then, server_key);
|
||||
let ct_else_ct = self.convert_into_lwe_ciphertext_32(ct_else, server_key);
|
||||
let ct_then_ct = convert_into_lwe_ciphertext_32(ct_then, server_key);
|
||||
let ct_else_ct = convert_into_lwe_ciphertext_32(ct_else, server_key);
|
||||
|
||||
let mut buffer_lwe_before_pbs_o = LweCiphertext::new(
|
||||
0u32,
|
||||
@@ -565,15 +528,13 @@ impl BooleanEngine {
|
||||
|
||||
match server_key.pbs_order {
|
||||
PBSOrder::KeyswitchBootstrap => {
|
||||
let ct_ks_1 = bootstrapper
|
||||
.keyswitch(buffer_lwe_before_pbs, server_key)
|
||||
.unwrap();
|
||||
let ct_ks_1 = server_key.keyswitch(buffer_lwe_before_pbs);
|
||||
|
||||
// Compute the first programmable bootstrapping with fixed test polynomial:
|
||||
let mut ct_pbs_1 = bootstrapper.bootstrap(&ct_ks_1, server_key).unwrap();
|
||||
let mut ct_pbs_1 = bootstrapper.bootstrap(&ct_ks_1, server_key);
|
||||
|
||||
let ct_ks_2 = bootstrapper.keyswitch(&ct_temp_2, server_key).unwrap();
|
||||
let ct_pbs_2 = bootstrapper.bootstrap(&ct_ks_2, server_key).unwrap();
|
||||
let ct_ks_2 = server_key.keyswitch(&ct_temp_2);
|
||||
let ct_pbs_2 = bootstrapper.bootstrap(&ct_ks_2, server_key);
|
||||
|
||||
// Compute the linear combination to add the two results:
|
||||
// buffer_lwe_pbs + ct_pbs_2 + (0,...,0, +1/8)
|
||||
@@ -586,11 +547,10 @@ impl BooleanEngine {
|
||||
}
|
||||
PBSOrder::BootstrapKeyswitch => {
|
||||
// Compute the first programmable bootstrapping with fixed test polynomial:
|
||||
let mut ct_pbs_1 = bootstrapper
|
||||
.bootstrap(buffer_lwe_before_pbs, server_key)
|
||||
.unwrap();
|
||||
let mut ct_pbs_1 =
|
||||
bootstrapper.bootstrap(buffer_lwe_before_pbs, server_key);
|
||||
|
||||
let ct_pbs_2 = bootstrapper.bootstrap(&ct_temp_2, server_key).unwrap();
|
||||
let ct_pbs_2 = bootstrapper.bootstrap(&ct_temp_2, server_key);
|
||||
|
||||
// Compute the linear combination to add the two results:
|
||||
// buffer_lwe_pbs + ct_pbs_2 + (0,...,0, +1/8)
|
||||
@@ -598,7 +558,7 @@ impl BooleanEngine {
|
||||
let cst = Plaintext(PLAINTEXT_TRUE);
|
||||
lwe_ciphertext_plaintext_add_assign(&mut ct_pbs_1, cst); // + 1/8
|
||||
|
||||
let ct_ks = bootstrapper.keyswitch(&ct_pbs_1, server_key).unwrap();
|
||||
let ct_ks = server_key.keyswitch(&ct_pbs_1);
|
||||
|
||||
// Output the result:
|
||||
Ciphertext::Encrypted(ct_ks)
|
||||
@@ -609,6 +569,41 @@ impl BooleanEngine {
|
||||
}
|
||||
}
|
||||
|
||||
/// convert into an actual LWE ciphertext even when trivial
|
||||
fn convert_into_lwe_ciphertext_32(
|
||||
ct: &Ciphertext,
|
||||
server_key: &ServerKey,
|
||||
) -> LweCiphertextOwned<u32> {
|
||||
match ct {
|
||||
Ciphertext::Encrypted(ct_ct) => ct_ct.clone(),
|
||||
Ciphertext::Trivial(message) => {
|
||||
// encode the boolean message
|
||||
let plain: Plaintext<u32> = if *message {
|
||||
Plaintext(PLAINTEXT_TRUE)
|
||||
} else {
|
||||
Plaintext(PLAINTEXT_FALSE)
|
||||
};
|
||||
|
||||
let lwe_size = match server_key.pbs_order {
|
||||
PBSOrder::KeyswitchBootstrap => server_key
|
||||
.key_switching_key
|
||||
.input_key_lwe_dimension()
|
||||
.to_lwe_size(),
|
||||
PBSOrder::BootstrapKeyswitch => server_key
|
||||
.bootstrapping_key
|
||||
.input_lwe_dimension()
|
||||
.to_lwe_size(),
|
||||
};
|
||||
|
||||
allocate_and_trivially_encrypt_new_lwe_ciphertext(
|
||||
lwe_size,
|
||||
plain,
|
||||
CiphertextModulus::new_native(),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl BinaryGatesEngine<&Ciphertext, &Ciphertext, ServerKey> for BooleanEngine {
|
||||
fn and(
|
||||
&mut self,
|
||||
@@ -643,9 +638,7 @@ impl BinaryGatesEngine<&Ciphertext, &Ciphertext, ServerKey> for BooleanEngine {
|
||||
lwe_ciphertext_plaintext_add_assign(&mut buffer_lwe_before_pbs, cst);
|
||||
|
||||
// compute the bootstrap and the key switch
|
||||
bootstrapper
|
||||
.apply_bootstrapping_pattern(buffer_lwe_before_pbs, server_key)
|
||||
.unwrap()
|
||||
bootstrapper.apply_bootstrapping_pattern(buffer_lwe_before_pbs, server_key)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -683,9 +676,7 @@ impl BinaryGatesEngine<&Ciphertext, &Ciphertext, ServerKey> for BooleanEngine {
|
||||
lwe_ciphertext_plaintext_add_assign(&mut buffer_lwe_before_pbs, cst);
|
||||
|
||||
// compute the bootstrap and the key switch
|
||||
bootstrapper
|
||||
.apply_bootstrapping_pattern(buffer_lwe_before_pbs, server_key)
|
||||
.unwrap()
|
||||
bootstrapper.apply_bootstrapping_pattern(buffer_lwe_before_pbs, server_key)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -724,9 +715,7 @@ impl BinaryGatesEngine<&Ciphertext, &Ciphertext, ServerKey> for BooleanEngine {
|
||||
lwe_ciphertext_plaintext_add_assign(&mut buffer_lwe_before_pbs, cst);
|
||||
|
||||
// compute the bootstrap and the key switch
|
||||
bootstrapper
|
||||
.apply_bootstrapping_pattern(buffer_lwe_before_pbs, server_key)
|
||||
.unwrap()
|
||||
bootstrapper.apply_bootstrapping_pattern(buffer_lwe_before_pbs, server_key)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -763,9 +752,7 @@ impl BinaryGatesEngine<&Ciphertext, &Ciphertext, ServerKey> for BooleanEngine {
|
||||
lwe_ciphertext_plaintext_add_assign(&mut buffer_lwe_before_pbs, cst);
|
||||
|
||||
// compute the bootstrap and the key switch
|
||||
bootstrapper
|
||||
.apply_bootstrapping_pattern(buffer_lwe_before_pbs, server_key)
|
||||
.unwrap()
|
||||
bootstrapper.apply_bootstrapping_pattern(buffer_lwe_before_pbs, server_key)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -805,9 +792,7 @@ impl BinaryGatesEngine<&Ciphertext, &Ciphertext, ServerKey> for BooleanEngine {
|
||||
lwe_ciphertext_cleartext_mul_assign(&mut buffer_lwe_before_pbs, cst_mul);
|
||||
|
||||
// compute the bootstrap and the key switch
|
||||
bootstrapper
|
||||
.apply_bootstrapping_pattern(buffer_lwe_before_pbs, server_key)
|
||||
.unwrap()
|
||||
bootstrapper.apply_bootstrapping_pattern(buffer_lwe_before_pbs, server_key)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -849,9 +834,7 @@ impl BinaryGatesEngine<&Ciphertext, &Ciphertext, ServerKey> for BooleanEngine {
|
||||
lwe_ciphertext_cleartext_mul_assign(&mut buffer_lwe_before_pbs, cst_mul);
|
||||
|
||||
// compute the bootstrap and the key switch
|
||||
bootstrapper
|
||||
.apply_bootstrapping_pattern(buffer_lwe_before_pbs, server_key)
|
||||
.unwrap()
|
||||
bootstrapper.apply_bootstrapping_pattern(buffer_lwe_before_pbs, server_key)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -33,7 +33,7 @@ impl KeySwitchingKey {
|
||||
*ct_dest = Ciphertext::Encrypted(cipher_dest);
|
||||
}
|
||||
Ciphertext::Encrypted(ref mut cipher_dest) => {
|
||||
keyswitch_lwe_ciphertext(&self.key_switching_key, cipher, cipher_dest)
|
||||
keyswitch_lwe_ciphertext(&self.key_switching_key, cipher, cipher_dest);
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
@@ -60,5 +60,5 @@ impl Keycache {
|
||||
}
|
||||
|
||||
lazy_static! {
|
||||
pub static ref KEY_CACHE: Keycache = Default::default();
|
||||
pub static ref KEY_CACHE: Keycache = Keycache::default();
|
||||
}
|
||||
|
||||
@@ -83,8 +83,8 @@ impl BooleanParameters {
|
||||
glwe_modular_std_dev,
|
||||
pbs_base_log,
|
||||
pbs_level,
|
||||
ks_level,
|
||||
ks_base_log,
|
||||
ks_level,
|
||||
encryption_key_choice,
|
||||
}
|
||||
}
|
||||
@@ -108,8 +108,8 @@ impl BooleanKeySwitchingParameters {
|
||||
/// results with 128 bits of security.
|
||||
pub fn new(ks_base_log: DecompositionBaseLog, ks_level: DecompositionLevelCount) -> Self {
|
||||
Self {
|
||||
ks_level,
|
||||
ks_base_log,
|
||||
ks_level,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -95,9 +95,9 @@ mod tests {
|
||||
};
|
||||
use crate::boolean::random_boolean;
|
||||
#[cfg(not(feature = "__coverage"))]
|
||||
const NB_TEST: usize = 32;
|
||||
const NB_TESTS: usize = 32;
|
||||
#[cfg(feature = "__coverage")]
|
||||
const NB_TEST: usize = 1;
|
||||
const NB_TESTS: usize = 1;
|
||||
|
||||
#[test]
|
||||
fn test_compressed_public_key_default_parameters() {
|
||||
@@ -115,7 +115,7 @@ mod tests {
|
||||
let (cks, sks) = (keys.client_key(), keys.server_key());
|
||||
let cpks = CompressedPublicKey::new(cks);
|
||||
|
||||
for _ in 0..NB_TEST {
|
||||
for _ in 0..NB_TESTS {
|
||||
let b1 = random_boolean();
|
||||
let b2 = random_boolean();
|
||||
let expected_result = !(b1 && b2);
|
||||
|
||||
@@ -89,9 +89,9 @@ mod tests {
|
||||
|
||||
use super::PublicKey;
|
||||
#[cfg(not(feature = "__coverage"))]
|
||||
const NB_TEST: usize = 32;
|
||||
const NB_TESTS: usize = 32;
|
||||
#[cfg(feature = "__coverage")]
|
||||
const NB_TEST: usize = 1;
|
||||
const NB_TESTS: usize = 1;
|
||||
|
||||
#[test]
|
||||
fn test_public_key_default_parameters() {
|
||||
@@ -109,7 +109,7 @@ mod tests {
|
||||
let (cks, sks) = (keys.client_key(), keys.server_key());
|
||||
let pks = PublicKey::new(cks);
|
||||
|
||||
for _ in 0..NB_TEST {
|
||||
for _ in 0..NB_TESTS {
|
||||
let b1 = random_boolean();
|
||||
let b2 = random_boolean();
|
||||
let expected_result = !(b1 && b2);
|
||||
@@ -146,7 +146,7 @@ mod tests {
|
||||
let cpks = CompressedPublicKey::new(cks);
|
||||
let pks = PublicKey::from(cpks);
|
||||
|
||||
for _ in 0..NB_TEST {
|
||||
for _ in 0..NB_TESTS {
|
||||
let b1 = random_boolean();
|
||||
let b2 = random_boolean();
|
||||
let expected_result = !(b1 && b2);
|
||||
|
||||
@@ -92,38 +92,38 @@ where
|
||||
{
|
||||
fn and_assign(&self, ct_left: Lhs, ct_right: Rhs) {
|
||||
<Self as DefaultImplementation>::Engine::with_thread_local_mut(|engine| {
|
||||
engine.and_assign(ct_left, ct_right, self)
|
||||
})
|
||||
engine.and_assign(ct_left, ct_right, self);
|
||||
});
|
||||
}
|
||||
|
||||
fn nand_assign(&self, ct_left: Lhs, ct_right: Rhs) {
|
||||
<Self as DefaultImplementation>::Engine::with_thread_local_mut(|engine| {
|
||||
engine.nand_assign(ct_left, ct_right, self)
|
||||
})
|
||||
engine.nand_assign(ct_left, ct_right, self);
|
||||
});
|
||||
}
|
||||
|
||||
fn nor_assign(&self, ct_left: Lhs, ct_right: Rhs) {
|
||||
<Self as DefaultImplementation>::Engine::with_thread_local_mut(|engine| {
|
||||
engine.nor_assign(ct_left, ct_right, self)
|
||||
})
|
||||
engine.nor_assign(ct_left, ct_right, self);
|
||||
});
|
||||
}
|
||||
|
||||
fn or_assign(&self, ct_left: Lhs, ct_right: Rhs) {
|
||||
<Self as DefaultImplementation>::Engine::with_thread_local_mut(|engine| {
|
||||
engine.or_assign(ct_left, ct_right, self)
|
||||
})
|
||||
engine.or_assign(ct_left, ct_right, self);
|
||||
});
|
||||
}
|
||||
|
||||
fn xor_assign(&self, ct_left: Lhs, ct_right: Rhs) {
|
||||
<Self as DefaultImplementation>::Engine::with_thread_local_mut(|engine| {
|
||||
engine.xor_assign(ct_left, ct_right, self)
|
||||
})
|
||||
engine.xor_assign(ct_left, ct_right, self);
|
||||
});
|
||||
}
|
||||
|
||||
fn xnor_assign(&self, ct_left: Lhs, ct_right: Rhs) {
|
||||
<Self as DefaultImplementation>::Engine::with_thread_local_mut(|engine| {
|
||||
engine.xnor_assign(ct_left, ct_right, self)
|
||||
})
|
||||
engine.xnor_assign(ct_left, ct_right, self);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
@@ -141,7 +141,7 @@ impl ServerKey {
|
||||
}
|
||||
|
||||
pub fn not_assign(&self, ct: &mut Ciphertext) {
|
||||
BooleanEngine::with_thread_local_mut(|engine| engine.not_assign(ct))
|
||||
BooleanEngine::with_thread_local_mut(|engine| engine.not_assign(ct));
|
||||
}
|
||||
|
||||
pub fn mux(
|
||||
|
||||
@@ -7,12 +7,12 @@ use crate::boolean::{random_boolean, random_integer};
|
||||
|
||||
/// Number of assert in randomized tests
|
||||
#[cfg(not(feature = "__coverage"))]
|
||||
const NB_TEST: usize = 128;
|
||||
const NB_TESTS: usize = 128;
|
||||
|
||||
// Use lower numbers for coverage to ensure fast tests to counter balance slowdown due to code
|
||||
// instrumentation
|
||||
#[cfg(feature = "__coverage")]
|
||||
const NB_TEST: usize = 1;
|
||||
const NB_TESTS: usize = 1;
|
||||
|
||||
/// Number of ciphertext in the deep circuit test
|
||||
const NB_CT: usize = 8;
|
||||
@@ -261,7 +261,7 @@ fn test_encrypt_decrypt_lwe_secret_key(parameters: BooleanParameters) {
|
||||
let keys = KEY_CACHE.get_from_param(parameters);
|
||||
let (cks, sks) = (keys.client_key(), keys.server_key());
|
||||
|
||||
for _ in 0..NB_TEST {
|
||||
for _ in 0..NB_TESTS {
|
||||
// encryption of false
|
||||
let ct_false = cks.encrypt(false);
|
||||
|
||||
@@ -310,7 +310,7 @@ fn test_and_gate(parameters: BooleanParameters) {
|
||||
let keys = KEY_CACHE.get_from_param(parameters);
|
||||
let (cks, sks) = (keys.client_key(), keys.server_key());
|
||||
|
||||
for _ in 0..NB_TEST {
|
||||
for _ in 0..NB_TESTS {
|
||||
// generation of two random booleans
|
||||
let b1 = random_boolean();
|
||||
let b2 = random_boolean();
|
||||
@@ -385,7 +385,7 @@ fn test_mux_gate(parameters: BooleanParameters) {
|
||||
let keys = KEY_CACHE.get_from_param(parameters);
|
||||
let (cks, sks) = (keys.client_key(), keys.server_key());
|
||||
|
||||
for _ in 0..NB_TEST {
|
||||
for _ in 0..NB_TESTS {
|
||||
// generation of three random booleans
|
||||
let b1 = random_boolean();
|
||||
let b2 = random_boolean();
|
||||
@@ -419,7 +419,7 @@ fn test_nand_gate(parameters: BooleanParameters) {
|
||||
let keys = KEY_CACHE.get_from_param(parameters);
|
||||
let (cks, sks) = (keys.client_key(), keys.server_key());
|
||||
|
||||
for _ in 0..NB_TEST {
|
||||
for _ in 0..NB_TESTS {
|
||||
// generation of two random booleans
|
||||
let b1 = random_boolean();
|
||||
let b2 = random_boolean();
|
||||
@@ -494,7 +494,7 @@ fn test_nor_gate(parameters: BooleanParameters) {
|
||||
let keys = KEY_CACHE.get_from_param(parameters);
|
||||
let (cks, sks) = (keys.client_key(), keys.server_key());
|
||||
|
||||
for _ in 0..NB_TEST {
|
||||
for _ in 0..NB_TESTS {
|
||||
// generation of two random booleans
|
||||
let b1 = random_boolean();
|
||||
let b2 = random_boolean();
|
||||
@@ -569,7 +569,7 @@ fn test_not_gate(parameters: BooleanParameters) {
|
||||
let keys = KEY_CACHE.get_from_param(parameters);
|
||||
let (cks, sks) = (keys.client_key(), keys.server_key());
|
||||
|
||||
for _ in 0..NB_TEST {
|
||||
for _ in 0..NB_TESTS {
|
||||
// generation of one random booleans
|
||||
let b1 = random_boolean();
|
||||
let expected_result = !b1;
|
||||
@@ -602,7 +602,7 @@ fn test_or_gate(parameters: BooleanParameters) {
|
||||
let keys = KEY_CACHE.get_from_param(parameters);
|
||||
let (cks, sks) = (keys.client_key(), keys.server_key());
|
||||
|
||||
for _ in 0..NB_TEST {
|
||||
for _ in 0..NB_TESTS {
|
||||
// generation of two random booleans
|
||||
let b1 = random_boolean();
|
||||
let b2 = random_boolean();
|
||||
@@ -677,7 +677,7 @@ fn test_xnor_gate(parameters: BooleanParameters) {
|
||||
let keys = KEY_CACHE.get_from_param(parameters);
|
||||
let (cks, sks) = (keys.client_key(), keys.server_key());
|
||||
|
||||
for _ in 0..NB_TEST {
|
||||
for _ in 0..NB_TESTS {
|
||||
// generation of two random booleans
|
||||
let b1 = random_boolean();
|
||||
let b2 = random_boolean();
|
||||
@@ -752,7 +752,7 @@ fn test_xor_gate(parameters: BooleanParameters) {
|
||||
let keys = KEY_CACHE.get_from_param(parameters);
|
||||
let (cks, sks) = (keys.client_key(), keys.server_key());
|
||||
|
||||
for _ in 0..NB_TEST {
|
||||
for _ in 0..NB_TESTS {
|
||||
// generation of two random booleans
|
||||
let b1 = random_boolean();
|
||||
let b2 = random_boolean();
|
||||
|
||||
@@ -117,6 +117,6 @@ pub unsafe extern "C" fn core_crypto_par_generate_lwe_multi_bit_bootstrapping_ke
|
||||
&mut bsk,
|
||||
glwe_encryption_std_dev,
|
||||
&mut encryption_random_generator,
|
||||
)
|
||||
);
|
||||
})
|
||||
}
|
||||
|
||||
@@ -42,7 +42,7 @@ pub unsafe extern "C" fn u256_from_little_endian_bytes(
|
||||
let input = std::slice::from_raw_parts(input, len);
|
||||
inner.copy_from_le_byte_slice(input);
|
||||
|
||||
*result = U256::from(inner)
|
||||
*result = U256::from(inner);
|
||||
})
|
||||
}
|
||||
|
||||
@@ -61,7 +61,7 @@ pub unsafe extern "C" fn u256_from_big_endian_bytes(
|
||||
let input = std::slice::from_raw_parts(input, len);
|
||||
inner.copy_from_be_byte_slice(input);
|
||||
|
||||
*result = U256::from(inner)
|
||||
*result = U256::from(inner);
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
use crate::c_api::buffer::*;
|
||||
use crate::c_api::utils::*;
|
||||
use crate::shortint::ciphertext::Degree;
|
||||
use std::os::raw::c_int;
|
||||
|
||||
use crate::shortint;
|
||||
@@ -17,7 +18,7 @@ pub unsafe extern "C" fn shortint_ciphertext_set_degree(
|
||||
|
||||
let inner_ct = &mut ciphertext.0;
|
||||
|
||||
inner_ct.degree.0 = degree;
|
||||
inner_ct.degree = Degree::new(degree);
|
||||
})
|
||||
}
|
||||
|
||||
@@ -33,7 +34,7 @@ pub unsafe extern "C" fn shortint_ciphertext_get_degree(
|
||||
|
||||
let inner_ct = &ciphertext.0;
|
||||
|
||||
*result = inner_ct.degree.0;
|
||||
*result = inner_ct.degree.get();
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -135,14 +135,11 @@ pub unsafe extern "C" fn shortint_server_key_bivariate_programmable_bootstrap(
|
||||
let ct_left = get_mut_checked(ct_left).unwrap();
|
||||
let ct_right = get_mut_checked(ct_right).unwrap();
|
||||
|
||||
let res = crate::shortint::engine::ShortintEngine::with_thread_local_mut(|engine| {
|
||||
engine.smart_apply_lookup_table_bivariate(
|
||||
&server_key.0,
|
||||
&mut ct_left.0,
|
||||
&mut ct_right.0,
|
||||
&lookup_table.0,
|
||||
)
|
||||
});
|
||||
let res = server_key.0.smart_apply_lookup_table_bivariate(
|
||||
&mut ct_left.0,
|
||||
&mut ct_right.0,
|
||||
&lookup_table.0,
|
||||
);
|
||||
|
||||
let heap_allocated_result = Box::new(ShortintCiphertext(res));
|
||||
|
||||
@@ -162,14 +159,10 @@ pub unsafe extern "C" fn shortint_server_key_bivariate_programmable_bootstrap_as
|
||||
let lookup_table = get_ref_checked(lookup_table).unwrap();
|
||||
let ct_left_and_result = get_mut_checked(ct_left_and_result).unwrap();
|
||||
let ct_right = get_mut_checked(ct_right).unwrap();
|
||||
|
||||
crate::shortint::engine::ShortintEngine::with_thread_local_mut(|engine| {
|
||||
engine.smart_apply_lookup_table_bivariate_assign(
|
||||
&server_key.0,
|
||||
&mut ct_left_and_result.0,
|
||||
&mut ct_right.0,
|
||||
&lookup_table.0,
|
||||
)
|
||||
});
|
||||
server_key.0.smart_apply_lookup_table_bivariate_assign(
|
||||
&mut ct_left_and_result.0,
|
||||
&mut ct_right.0,
|
||||
&lookup_table.0,
|
||||
);
|
||||
})
|
||||
}
|
||||
|
||||
@@ -504,7 +504,7 @@ pub fn encrypt_constant_seeded_ggsw_ciphertext<Scalar, KeyCont, OutputCont, Nois
|
||||
encoded,
|
||||
noise_parameters,
|
||||
&mut generator,
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
/// Convenience function to share the core logic of the parallel seeded GGSW encryption between all
|
||||
|
||||
@@ -1067,7 +1067,7 @@ pub fn encrypt_seeded_glwe_ciphertext_list_with_existing_generator<
|
||||
&plaintext_list,
|
||||
noise_parameters,
|
||||
generator,
|
||||
)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -107,8 +107,7 @@ pub fn extract_lwe_sample_from_glwe_ciphertext<Scalar, InputCont, OutputCont>(
|
||||
assert_eq!(
|
||||
in_lwe_dim, out_lwe_dim,
|
||||
"Mismatch between equivalent LweDimension of input ciphertext and output ciphertext. \
|
||||
Got {:?} for input and {:?} for output.",
|
||||
in_lwe_dim, out_lwe_dim,
|
||||
Got {in_lwe_dim:?} for input and {out_lwe_dim:?} for output.",
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
@@ -369,8 +368,7 @@ pub fn par_extract_lwe_sample_from_glwe_ciphertext_with_thread_count<
|
||||
assert_eq!(
|
||||
in_lwe_dim, out_lwe_dim,
|
||||
"Mismatch between equivalent LweDimension of input ciphertext and output ciphertext. \
|
||||
Got {:?} for input and {:?} for output.",
|
||||
in_lwe_dim, out_lwe_dim,
|
||||
Got {in_lwe_dim:?} for input and {out_lwe_dim:?} for output.",
|
||||
);
|
||||
|
||||
assert!(
|
||||
|
||||
@@ -67,5 +67,5 @@ pub fn generate_binary_glwe_secret_key<Scalar, InCont, Gen>(
|
||||
InCont: ContainerMut<Element = Scalar>,
|
||||
Gen: ByteRandomGenerator,
|
||||
{
|
||||
generator.fill_slice_with_random_uniform_binary(glwe_secret_key.as_mut())
|
||||
generator.fill_slice_with_random_uniform_binary(glwe_secret_key.as_mut());
|
||||
}
|
||||
|
||||
@@ -534,7 +534,7 @@ pub fn par_generate_seeded_lwe_bootstrap_key<
|
||||
noise_parameters,
|
||||
&mut generator,
|
||||
);
|
||||
})
|
||||
});
|
||||
}
|
||||
|
||||
/// Parallel variant of [`allocate_and_generate_new_seeded_lwe_bootstrap_key`], it is recommended to
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
|
||||
use crate::core_crypto::algorithms::slice_algorithms::*;
|
||||
use crate::core_crypto::algorithms::*;
|
||||
use crate::core_crypto::commons::ciphertext_modulus::CiphertextModulusKind;
|
||||
use crate::core_crypto::commons::dispersion::DispersionParameter;
|
||||
use crate::core_crypto::commons::generators::{EncryptionRandomGenerator, SecretRandomGenerator};
|
||||
use crate::core_crypto::commons::math::random::{ActivatedRandomGenerator, RandomGenerator};
|
||||
@@ -16,7 +17,56 @@ use rayon::prelude::*;
|
||||
pub fn fill_lwe_mask_and_body_for_encryption<Scalar, KeyCont, OutputCont, Gen>(
|
||||
lwe_secret_key: &LweSecretKey<KeyCont>,
|
||||
output_mask: &mut LweMask<OutputCont>,
|
||||
output_body: LweBodyRefMut<Scalar>,
|
||||
output_body: &mut LweBodyRefMut<Scalar>,
|
||||
encoded: Plaintext<Scalar>,
|
||||
noise_parameters: impl DispersionParameter,
|
||||
generator: &mut EncryptionRandomGenerator<Gen>,
|
||||
) where
|
||||
Scalar: UnsignedTorus,
|
||||
KeyCont: Container<Element = Scalar>,
|
||||
OutputCont: ContainerMut<Element = Scalar>,
|
||||
Gen: ByteRandomGenerator,
|
||||
{
|
||||
assert_eq!(
|
||||
output_mask.ciphertext_modulus(),
|
||||
output_body.ciphertext_modulus(),
|
||||
"Mismatched moduli between mask ({:?}) and body ({:?})",
|
||||
output_mask.ciphertext_modulus(),
|
||||
output_body.ciphertext_modulus()
|
||||
);
|
||||
|
||||
let ciphertext_modulus = output_mask.ciphertext_modulus();
|
||||
|
||||
if ciphertext_modulus.is_compatible_with_native_modulus() {
|
||||
fill_lwe_mask_and_body_for_encryption_native_mod_compatible(
|
||||
lwe_secret_key,
|
||||
output_mask,
|
||||
output_body,
|
||||
encoded,
|
||||
noise_parameters,
|
||||
generator,
|
||||
);
|
||||
} else {
|
||||
fill_lwe_mask_and_body_for_encryption_other_mod(
|
||||
lwe_secret_key,
|
||||
output_mask,
|
||||
output_body,
|
||||
encoded,
|
||||
noise_parameters,
|
||||
generator,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn fill_lwe_mask_and_body_for_encryption_native_mod_compatible<
|
||||
Scalar,
|
||||
KeyCont,
|
||||
OutputCont,
|
||||
Gen,
|
||||
>(
|
||||
lwe_secret_key: &LweSecretKey<KeyCont>,
|
||||
output_mask: &mut LweMask<OutputCont>,
|
||||
output_body: &mut LweBodyRefMut<Scalar>,
|
||||
encoded: Plaintext<Scalar>,
|
||||
noise_parameters: impl DispersionParameter,
|
||||
generator: &mut EncryptionRandomGenerator<Gen>,
|
||||
@@ -38,23 +88,78 @@ pub fn fill_lwe_mask_and_body_for_encryption<Scalar, KeyCont, OutputCont, Gen>(
|
||||
|
||||
assert!(ciphertext_modulus.is_compatible_with_native_modulus());
|
||||
|
||||
// generate a randomly uniform mask
|
||||
generator.fill_slice_with_random_mask_custom_mod(output_mask.as_mut(), ciphertext_modulus);
|
||||
|
||||
// generate an error from the normal distribution described by std_dev
|
||||
*output_body.data = generator.random_noise_custom_mod(noise_parameters, ciphertext_modulus);
|
||||
*output_body.data = (*output_body.data).wrapping_add(encoded.0);
|
||||
|
||||
if !ciphertext_modulus.is_native_modulus() {
|
||||
let torus_scaling = ciphertext_modulus.get_power_of_two_scaling_to_native_torus();
|
||||
slice_wrapping_scalar_mul_assign(output_mask.as_mut(), torus_scaling);
|
||||
*output_body.data = (*output_body.data).wrapping_mul(torus_scaling);
|
||||
}
|
||||
|
||||
let noise = generator.random_noise_custom_mod(noise_parameters, ciphertext_modulus);
|
||||
// compute the multisum between the secret key and the mask
|
||||
*output_body.data = (*output_body.data).wrapping_add(slice_wrapping_dot_product(
|
||||
let mask_key_dot_product = (*output_body.data).wrapping_add(slice_wrapping_dot_product(
|
||||
output_mask.as_ref(),
|
||||
lwe_secret_key.as_ref(),
|
||||
));
|
||||
|
||||
// Store sum(ai * si) + delta * m + e in the body
|
||||
*output_body.data = mask_key_dot_product
|
||||
.wrapping_add(encoded.0)
|
||||
.wrapping_add(noise);
|
||||
|
||||
match ciphertext_modulus.kind() {
|
||||
CiphertextModulusKind::Native => (),
|
||||
CiphertextModulusKind::NonNativePowerOfTwo => {
|
||||
// Manage power of 2 encoding to map to the native case
|
||||
let torus_scaling = ciphertext_modulus.get_power_of_two_scaling_to_native_torus();
|
||||
slice_wrapping_scalar_mul_assign(output_mask.as_mut(), torus_scaling);
|
||||
*output_body.data = (*output_body.data).wrapping_mul(torus_scaling);
|
||||
}
|
||||
CiphertextModulusKind::Other => unreachable!(),
|
||||
};
|
||||
}
|
||||
|
||||
pub fn fill_lwe_mask_and_body_for_encryption_other_mod<Scalar, KeyCont, OutputCont, Gen>(
|
||||
lwe_secret_key: &LweSecretKey<KeyCont>,
|
||||
output_mask: &mut LweMask<OutputCont>,
|
||||
output_body: &mut LweBodyRefMut<Scalar>,
|
||||
encoded: Plaintext<Scalar>,
|
||||
noise_parameters: impl DispersionParameter,
|
||||
generator: &mut EncryptionRandomGenerator<Gen>,
|
||||
) where
|
||||
Scalar: UnsignedTorus,
|
||||
KeyCont: Container<Element = Scalar>,
|
||||
OutputCont: ContainerMut<Element = Scalar>,
|
||||
Gen: ByteRandomGenerator,
|
||||
{
|
||||
assert_eq!(
|
||||
output_mask.ciphertext_modulus(),
|
||||
output_body.ciphertext_modulus(),
|
||||
"Mismatched moduli between mask ({:?}) and body ({:?})",
|
||||
output_mask.ciphertext_modulus(),
|
||||
output_body.ciphertext_modulus()
|
||||
);
|
||||
|
||||
let ciphertext_modulus = output_mask.ciphertext_modulus();
|
||||
|
||||
assert!(!ciphertext_modulus.is_compatible_with_native_modulus());
|
||||
|
||||
// generate a randomly uniform mask
|
||||
generator.fill_slice_with_random_mask_custom_mod(output_mask.as_mut(), ciphertext_modulus);
|
||||
|
||||
// generate an error from the normal distribution described by std_dev
|
||||
let noise = generator.random_noise_custom_mod(noise_parameters, ciphertext_modulus);
|
||||
|
||||
let ciphertext_modulus_as_scalar: Scalar = ciphertext_modulus.get_custom_modulus().cast_into();
|
||||
|
||||
// compute the multisum between the secret key and the mask
|
||||
let mask_key_dot_product = slice_wrapping_dot_product_custom_mod(
|
||||
output_mask.as_ref(),
|
||||
lwe_secret_key.as_ref(),
|
||||
ciphertext_modulus_as_scalar,
|
||||
);
|
||||
|
||||
// Store sum(ai * si) + delta * m + e in the body
|
||||
*output_body.data = mask_key_dot_product
|
||||
.wrapping_add_custom_mod(encoded.0, ciphertext_modulus_as_scalar)
|
||||
.wrapping_add_custom_mod(noise, ciphertext_modulus_as_scalar);
|
||||
}
|
||||
|
||||
/// Encrypt an input plaintext in an output [`LWE ciphertext`](`LweCiphertext`).
|
||||
@@ -135,12 +240,12 @@ pub fn encrypt_lwe_ciphertext<Scalar, KeyCont, OutputCont, Gen>(
|
||||
lwe_secret_key.lwe_dimension()
|
||||
);
|
||||
|
||||
let (mut mask, body) = output.get_mut_mask_and_body();
|
||||
let (mut mask, mut body) = output.get_mut_mask_and_body();
|
||||
|
||||
fill_lwe_mask_and_body_for_encryption(
|
||||
lwe_secret_key,
|
||||
&mut mask,
|
||||
body,
|
||||
&mut body,
|
||||
encoded,
|
||||
noise_parameters,
|
||||
generator,
|
||||
@@ -294,15 +399,15 @@ pub fn trivially_encrypt_lwe_ciphertext<Scalar, OutputCont>(
|
||||
output.get_mut_mask().as_mut().fill(Scalar::ZERO);
|
||||
|
||||
let output_body = output.get_mut_body();
|
||||
|
||||
*output_body.data = encoded.0;
|
||||
|
||||
let ciphertext_modulus = output_body.ciphertext_modulus();
|
||||
assert!(ciphertext_modulus.is_compatible_with_native_modulus());
|
||||
if !ciphertext_modulus.is_native_modulus() {
|
||||
*output_body.data = (*output_body.data)
|
||||
.wrapping_mul(ciphertext_modulus.get_power_of_two_scaling_to_native_torus());
|
||||
}
|
||||
|
||||
*output_body.data = match ciphertext_modulus.kind() {
|
||||
CiphertextModulusKind::Native | CiphertextModulusKind::Other => encoded.0,
|
||||
CiphertextModulusKind::NonNativePowerOfTwo => {
|
||||
// Manage non native power of 2 encoding
|
||||
encoded.0 * ciphertext_modulus.get_power_of_two_scaling_to_native_torus()
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
/// A trivial encryption uses a zero mask and no noise.
|
||||
@@ -373,15 +478,15 @@ where
|
||||
*new_ct.get_mut_body().data = encoded.0;
|
||||
|
||||
let output_body = new_ct.get_mut_body();
|
||||
|
||||
*output_body.data = encoded.0;
|
||||
|
||||
let ciphertext_modulus = output_body.ciphertext_modulus();
|
||||
assert!(ciphertext_modulus.is_compatible_with_native_modulus());
|
||||
if !ciphertext_modulus.is_native_modulus() {
|
||||
*output_body.data = (*output_body.data)
|
||||
.wrapping_mul(ciphertext_modulus.get_power_of_two_scaling_to_native_torus());
|
||||
}
|
||||
|
||||
*output_body.data = match ciphertext_modulus.kind() {
|
||||
CiphertextModulusKind::Native | CiphertextModulusKind::Other => encoded.0,
|
||||
CiphertextModulusKind::NonNativePowerOfTwo => {
|
||||
// Manage non native power of 2 encoding
|
||||
encoded.0 * ciphertext_modulus.get_power_of_two_scaling_to_native_torus()
|
||||
}
|
||||
};
|
||||
|
||||
new_ct
|
||||
}
|
||||
@@ -398,6 +503,24 @@ pub fn decrypt_lwe_ciphertext<Scalar, KeyCont, InputCont>(
|
||||
lwe_secret_key: &LweSecretKey<KeyCont>,
|
||||
lwe_ciphertext: &LweCiphertext<InputCont>,
|
||||
) -> Plaintext<Scalar>
|
||||
where
|
||||
Scalar: UnsignedInteger,
|
||||
KeyCont: Container<Element = Scalar>,
|
||||
InputCont: Container<Element = Scalar>,
|
||||
{
|
||||
let ciphertext_modulus = lwe_ciphertext.ciphertext_modulus();
|
||||
|
||||
if ciphertext_modulus.is_compatible_with_native_modulus() {
|
||||
decrypt_lwe_ciphertext_native_mod_compatible(lwe_secret_key, lwe_ciphertext)
|
||||
} else {
|
||||
decrypt_lwe_ciphertext_other_mod(lwe_secret_key, lwe_ciphertext)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn decrypt_lwe_ciphertext_native_mod_compatible<Scalar, KeyCont, InputCont>(
|
||||
lwe_secret_key: &LweSecretKey<KeyCont>,
|
||||
lwe_ciphertext: &LweCiphertext<InputCont>,
|
||||
) -> Plaintext<Scalar>
|
||||
where
|
||||
Scalar: UnsignedInteger,
|
||||
KeyCont: Container<Element = Scalar>,
|
||||
@@ -417,21 +540,55 @@ where
|
||||
|
||||
let (mask, body) = lwe_ciphertext.get_mask_and_body();
|
||||
|
||||
if ciphertext_modulus.is_native_modulus() {
|
||||
Plaintext((*body.data).wrapping_sub(slice_wrapping_dot_product(
|
||||
let mask_key_dot_product = slice_wrapping_dot_product(mask.as_ref(), lwe_secret_key.as_ref());
|
||||
let plaintext = (*body.data).wrapping_sub(mask_key_dot_product);
|
||||
|
||||
match ciphertext_modulus.kind() {
|
||||
CiphertextModulusKind::Native => Plaintext(plaintext),
|
||||
CiphertextModulusKind::NonNativePowerOfTwo => {
|
||||
// Manage power of 2 encoding
|
||||
Plaintext(
|
||||
plaintext
|
||||
.wrapping_div(ciphertext_modulus.get_power_of_two_scaling_to_native_torus()),
|
||||
)
|
||||
}
|
||||
CiphertextModulusKind::Other => unreachable!(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn decrypt_lwe_ciphertext_other_mod<Scalar, KeyCont, InputCont>(
|
||||
lwe_secret_key: &LweSecretKey<KeyCont>,
|
||||
lwe_ciphertext: &LweCiphertext<InputCont>,
|
||||
) -> Plaintext<Scalar>
|
||||
where
|
||||
Scalar: UnsignedInteger,
|
||||
KeyCont: Container<Element = Scalar>,
|
||||
InputCont: Container<Element = Scalar>,
|
||||
{
|
||||
assert!(
|
||||
lwe_ciphertext.lwe_size().to_lwe_dimension() == lwe_secret_key.lwe_dimension(),
|
||||
"Mismatch between LweDimension of output ciphertext and input secret key. \
|
||||
Got {:?} in output, and {:?} in secret key.",
|
||||
lwe_ciphertext.lwe_size().to_lwe_dimension(),
|
||||
lwe_secret_key.lwe_dimension()
|
||||
);
|
||||
|
||||
let ciphertext_modulus = lwe_ciphertext.ciphertext_modulus();
|
||||
|
||||
assert!(!ciphertext_modulus.is_compatible_with_native_modulus());
|
||||
|
||||
let (mask, body) = lwe_ciphertext.get_mask_and_body();
|
||||
|
||||
let ciphertext_modulus_as_scalar: Scalar = ciphertext_modulus.get_custom_modulus().cast_into();
|
||||
|
||||
Plaintext((*body.data).wrapping_sub_custom_mod(
|
||||
slice_wrapping_dot_product_custom_mod(
|
||||
mask.as_ref(),
|
||||
lwe_secret_key.as_ref(),
|
||||
)))
|
||||
} else {
|
||||
Plaintext(
|
||||
(*body.data)
|
||||
.wrapping_sub(slice_wrapping_dot_product(
|
||||
mask.as_ref(),
|
||||
lwe_secret_key.as_ref(),
|
||||
))
|
||||
.wrapping_div(ciphertext_modulus.get_power_of_two_scaling_to_native_torus()),
|
||||
)
|
||||
}
|
||||
ciphertext_modulus_as_scalar,
|
||||
),
|
||||
ciphertext_modulus_as_scalar,
|
||||
))
|
||||
}
|
||||
|
||||
/// Encrypt an input plaintext list in an output [`LWE ciphertext list`](`LweCiphertextList`).
|
||||
@@ -541,7 +698,7 @@ pub fn encrypt_lwe_ciphertext_list<Scalar, KeyCont, OutputCont, InputCont, Gen>(
|
||||
encoded_plaintext_ref.into(),
|
||||
noise_parameters,
|
||||
&mut loop_generator,
|
||||
)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -652,7 +809,7 @@ pub fn par_encrypt_lwe_ciphertext_list<Scalar, KeyCont, OutputCont, InputCont, G
|
||||
encoded_plaintext_ref.into(),
|
||||
noise_parameters,
|
||||
&mut generator,
|
||||
)
|
||||
);
|
||||
});
|
||||
}
|
||||
|
||||
@@ -781,10 +938,6 @@ pub fn encrypt_lwe_ciphertext_with_public_key<Scalar, KeyCont, OutputCont, Gen>(
|
||||
lwe_public_key.lwe_size().to_lwe_dimension()
|
||||
);
|
||||
|
||||
let ciphertext_modulus = output.ciphertext_modulus();
|
||||
|
||||
assert!(ciphertext_modulus.is_compatible_with_native_modulus());
|
||||
|
||||
output.as_mut().fill(Scalar::ZERO);
|
||||
|
||||
let mut tmp_zero_encryption =
|
||||
@@ -806,17 +959,7 @@ pub fn encrypt_lwe_ciphertext_with_public_key<Scalar, KeyCont, OutputCont, Gen>(
|
||||
lwe_ciphertext_add_assign(output, &tmp_zero_encryption);
|
||||
}
|
||||
|
||||
let body = output.get_mut_body();
|
||||
|
||||
if ciphertext_modulus.is_native_modulus() {
|
||||
*body.data = (*body.data).wrapping_add(encoded.0);
|
||||
} else {
|
||||
*body.data = (*body.data).wrapping_add(
|
||||
encoded
|
||||
.0
|
||||
.wrapping_mul(ciphertext_modulus.get_power_of_two_scaling_to_native_torus()),
|
||||
);
|
||||
}
|
||||
lwe_ciphertext_plaintext_add_assign(output, encoded);
|
||||
}
|
||||
|
||||
/// Encrypt an input plaintext in an output [`LWE ciphertext`](`LweCiphertext`) using a
|
||||
@@ -920,8 +1063,6 @@ pub fn encrypt_lwe_ciphertext_with_seeded_public_key<Scalar, KeyCont, OutputCont
|
||||
|
||||
let ciphertext_modulus = output.ciphertext_modulus();
|
||||
|
||||
assert!(ciphertext_modulus.is_compatible_with_native_modulus());
|
||||
|
||||
let mut tmp_zero_encryption =
|
||||
LweCiphertext::new(Scalar::ZERO, lwe_public_key.lwe_size(), ciphertext_modulus);
|
||||
|
||||
@@ -933,7 +1074,7 @@ pub fn encrypt_lwe_ciphertext_with_seeded_public_key<Scalar, KeyCont, OutputCont
|
||||
let (mut mask, body) = tmp_zero_encryption.get_mut_mask_and_body();
|
||||
random_generator
|
||||
.fill_slice_with_random_uniform_custom_mod(mask.as_mut(), ciphertext_modulus);
|
||||
if !ciphertext_modulus.is_native_modulus() {
|
||||
if ciphertext_modulus.is_non_native_power_of_two() {
|
||||
slice_wrapping_scalar_mul_assign(
|
||||
mask.as_mut(),
|
||||
ciphertext_modulus.get_power_of_two_scaling_to_native_torus(),
|
||||
@@ -947,17 +1088,7 @@ pub fn encrypt_lwe_ciphertext_with_seeded_public_key<Scalar, KeyCont, OutputCont
|
||||
lwe_ciphertext_add_assign(output, &tmp_zero_encryption);
|
||||
}
|
||||
|
||||
// Add encoded plaintext
|
||||
let body = output.get_mut_body();
|
||||
if ciphertext_modulus.is_native_modulus() {
|
||||
*body.data = (*body.data).wrapping_add(encoded.0);
|
||||
} else {
|
||||
*body.data = (*body.data).wrapping_add(
|
||||
encoded
|
||||
.0
|
||||
.wrapping_mul(ciphertext_modulus.get_power_of_two_scaling_to_native_torus()),
|
||||
);
|
||||
}
|
||||
lwe_ciphertext_plaintext_add_assign(output, encoded);
|
||||
}
|
||||
|
||||
/// Convenience function to share the core logic of the seeded LWE encryption between all functions
|
||||
@@ -1005,17 +1136,17 @@ pub fn encrypt_seeded_lwe_ciphertext_list_with_existing_generator<
|
||||
.fork_lwe_list_to_lwe::<Scalar>(output.lwe_ciphertext_count(), output.lwe_size())
|
||||
.unwrap();
|
||||
|
||||
for ((output_body, plaintext), mut loop_generator) in
|
||||
for ((mut output_body, plaintext), mut loop_generator) in
|
||||
output.iter_mut().zip(encoded.iter()).zip(gen_iter)
|
||||
{
|
||||
fill_lwe_mask_and_body_for_encryption(
|
||||
lwe_secret_key,
|
||||
&mut output_mask,
|
||||
output_body,
|
||||
&mut output_body,
|
||||
plaintext.into(),
|
||||
noise_parameters,
|
||||
&mut loop_generator,
|
||||
)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1166,17 +1297,17 @@ pub fn par_encrypt_seeded_lwe_ciphertext_list_with_existing_generator<
|
||||
.par_iter_mut()
|
||||
.zip(encoded.par_iter())
|
||||
.zip(gen_iter)
|
||||
.for_each(|((output_body, plaintext), mut loop_generator)| {
|
||||
.for_each(|((mut output_body, plaintext), mut loop_generator)| {
|
||||
let mut output_mask =
|
||||
LweMask::from_container(vec![Scalar::ZERO; lwe_dimension.0], ciphertext_modulus);
|
||||
fill_lwe_mask_and_body_for_encryption(
|
||||
lwe_secret_key,
|
||||
&mut output_mask,
|
||||
output_body,
|
||||
&mut output_body,
|
||||
plaintext.into(),
|
||||
noise_parameters,
|
||||
&mut loop_generator,
|
||||
)
|
||||
);
|
||||
});
|
||||
}
|
||||
|
||||
@@ -1300,11 +1431,11 @@ pub fn encrypt_seeded_lwe_ciphertext_with_existing_generator<Scalar, KeyCont, Ge
|
||||
fill_lwe_mask_and_body_for_encryption(
|
||||
lwe_secret_key,
|
||||
&mut mask,
|
||||
output.get_mut_body(),
|
||||
&mut output.get_mut_body(),
|
||||
encoded,
|
||||
noise_parameters,
|
||||
generator,
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
/// Encrypt an input plaintext in an output [`seeded LWE ciphertext`](`SeededLweCiphertext`).
|
||||
@@ -1393,7 +1524,7 @@ pub fn encrypt_seeded_lwe_ciphertext<Scalar, KeyCont, NoiseSeeder>(
|
||||
encoded,
|
||||
noise_parameters,
|
||||
&mut encryption_generator,
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
/// Allocate a new [`seeded LWE ciphertext`](`SeededLweCiphertext`) and encrypt an input plaintext
|
||||
@@ -1820,7 +1951,7 @@ pub fn encrypt_lwe_compact_ciphertext_list_with_compact_public_key<
|
||||
.for_each(|(dst, (&src, plaintext))| {
|
||||
*dst.data = src
|
||||
.wrapping_add(loop_generator.random_noise(body_noise_parameters))
|
||||
.wrapping_add(*plaintext.0)
|
||||
.wrapping_add(*plaintext.0);
|
||||
});
|
||||
},
|
||||
);
|
||||
@@ -2033,7 +2164,7 @@ pub fn par_encrypt_lwe_compact_ciphertext_list_with_compact_public_key<
|
||||
.for_each(|(dst, (&src, plaintext))| {
|
||||
*dst.data = src
|
||||
.wrapping_add(loop_generator.random_noise(body_noise_parameters))
|
||||
.wrapping_add(*plaintext.0)
|
||||
.wrapping_add(*plaintext.0);
|
||||
});
|
||||
},
|
||||
);
|
||||
|
||||
@@ -266,7 +266,7 @@ pub fn par_keyswitch_lwe_ciphertext<Scalar, KSKCont, InputCont, OutputCont>(
|
||||
input_lwe_ciphertext,
|
||||
output_lwe_ciphertext,
|
||||
thread_count,
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
/// Parallel variant of [`keyswitch_lwe_ciphertext`].
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
//! like addition, multiplication, etc.
|
||||
|
||||
use crate::core_crypto::algorithms::slice_algorithms::*;
|
||||
use crate::core_crypto::commons::ciphertext_modulus::CiphertextModulusKind;
|
||||
use crate::core_crypto::commons::numeric::UnsignedInteger;
|
||||
use crate::core_crypto::commons::traits::*;
|
||||
use crate::core_crypto::entities::*;
|
||||
@@ -71,6 +72,22 @@ pub fn lwe_ciphertext_add_assign<Scalar, LhsCont, RhsCont>(
|
||||
Scalar: UnsignedInteger,
|
||||
LhsCont: ContainerMut<Element = Scalar>,
|
||||
RhsCont: Container<Element = Scalar>,
|
||||
{
|
||||
let ciphertext_modulus = rhs.ciphertext_modulus();
|
||||
if ciphertext_modulus.is_compatible_with_native_modulus() {
|
||||
lwe_ciphertext_add_assign_native_mod_compatible(lhs, rhs);
|
||||
} else {
|
||||
lwe_ciphertext_add_assign_other_mod(lhs, rhs);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn lwe_ciphertext_add_assign_native_mod_compatible<Scalar, LhsCont, RhsCont>(
|
||||
lhs: &mut LweCiphertext<LhsCont>,
|
||||
rhs: &LweCiphertext<RhsCont>,
|
||||
) where
|
||||
Scalar: UnsignedInteger,
|
||||
LhsCont: ContainerMut<Element = Scalar>,
|
||||
RhsCont: Container<Element = Scalar>,
|
||||
{
|
||||
assert_eq!(
|
||||
lhs.ciphertext_modulus(),
|
||||
@@ -79,10 +96,37 @@ pub fn lwe_ciphertext_add_assign<Scalar, LhsCont, RhsCont>(
|
||||
lhs.ciphertext_modulus(),
|
||||
rhs.ciphertext_modulus()
|
||||
);
|
||||
let ciphertext_modulus = rhs.ciphertext_modulus();
|
||||
assert!(ciphertext_modulus.is_compatible_with_native_modulus());
|
||||
|
||||
slice_wrapping_add_assign(lhs.as_mut(), rhs.as_ref());
|
||||
}
|
||||
|
||||
pub fn lwe_ciphertext_add_assign_other_mod<Scalar, LhsCont, RhsCont>(
|
||||
lhs: &mut LweCiphertext<LhsCont>,
|
||||
rhs: &LweCiphertext<RhsCont>,
|
||||
) where
|
||||
Scalar: UnsignedInteger,
|
||||
LhsCont: ContainerMut<Element = Scalar>,
|
||||
RhsCont: Container<Element = Scalar>,
|
||||
{
|
||||
assert_eq!(
|
||||
lhs.ciphertext_modulus(),
|
||||
rhs.ciphertext_modulus(),
|
||||
"Mismatched moduli between lhs ({:?}) and rhs ({:?}) LweCiphertext",
|
||||
lhs.ciphertext_modulus(),
|
||||
rhs.ciphertext_modulus()
|
||||
);
|
||||
let ciphertext_modulus = rhs.ciphertext_modulus();
|
||||
assert!(!ciphertext_modulus.is_compatible_with_native_modulus());
|
||||
|
||||
slice_wrapping_add_assign_custom_mod(
|
||||
lhs.as_mut(),
|
||||
rhs.as_ref(),
|
||||
ciphertext_modulus.get_custom_modulus().cast_into(),
|
||||
);
|
||||
}
|
||||
|
||||
/// Add the right-hand side [`LWE ciphertext`](`LweCiphertext`) to the left-hand side [`LWE
|
||||
/// ciphertext`](`LweCiphertext`) writing the result in the output [`LWE
|
||||
/// ciphertext`](`LweCiphertext`).
|
||||
@@ -235,18 +279,50 @@ pub fn lwe_ciphertext_plaintext_add_assign<Scalar, InCont>(
|
||||
) where
|
||||
Scalar: UnsignedInteger,
|
||||
InCont: ContainerMut<Element = Scalar>,
|
||||
{
|
||||
let ciphertext_modulus = lhs.ciphertext_modulus();
|
||||
if ciphertext_modulus.is_compatible_with_native_modulus() {
|
||||
lwe_ciphertext_plaintext_add_assign_native_mod_compatible(lhs, rhs);
|
||||
} else {
|
||||
lwe_ciphertext_plaintext_add_assign_other_mod(lhs, rhs);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn lwe_ciphertext_plaintext_add_assign_native_mod_compatible<Scalar, InCont>(
|
||||
lhs: &mut LweCiphertext<InCont>,
|
||||
rhs: Plaintext<Scalar>,
|
||||
) where
|
||||
Scalar: UnsignedInteger,
|
||||
InCont: ContainerMut<Element = Scalar>,
|
||||
{
|
||||
let body = lhs.get_mut_body();
|
||||
let ciphertext_modulus = body.ciphertext_modulus();
|
||||
assert!(ciphertext_modulus.is_compatible_with_native_modulus());
|
||||
if ciphertext_modulus.is_native_modulus() {
|
||||
*body.data = (*body.data).wrapping_add(rhs.0);
|
||||
} else {
|
||||
*body.data = (*body.data).wrapping_add(
|
||||
rhs.0
|
||||
.wrapping_mul(ciphertext_modulus.get_power_of_two_scaling_to_native_torus()),
|
||||
);
|
||||
}
|
||||
|
||||
let plaintext = match ciphertext_modulus.kind() {
|
||||
CiphertextModulusKind::Native => rhs.0,
|
||||
// Manage power of 2 encoding
|
||||
CiphertextModulusKind::NonNativePowerOfTwo => rhs
|
||||
.0
|
||||
.wrapping_mul(ciphertext_modulus.get_power_of_two_scaling_to_native_torus()),
|
||||
CiphertextModulusKind::Other => unreachable!(),
|
||||
};
|
||||
|
||||
*body.data = (*body.data).wrapping_add(plaintext);
|
||||
}
|
||||
|
||||
pub fn lwe_ciphertext_plaintext_add_assign_other_mod<Scalar, InCont>(
|
||||
lhs: &mut LweCiphertext<InCont>,
|
||||
rhs: Plaintext<Scalar>,
|
||||
) where
|
||||
Scalar: UnsignedInteger,
|
||||
InCont: ContainerMut<Element = Scalar>,
|
||||
{
|
||||
let body = lhs.get_mut_body();
|
||||
let ciphertext_modulus = body.ciphertext_modulus();
|
||||
assert!(!ciphertext_modulus.is_compatible_with_native_modulus());
|
||||
*body.data = (*body.data)
|
||||
.wrapping_add_custom_mod(rhs.0, ciphertext_modulus.get_custom_modulus().cast_into());
|
||||
}
|
||||
|
||||
/// Add the right-hand side encoded [`Plaintext`] to the left-hand side [`LWE
|
||||
@@ -311,18 +387,50 @@ pub fn lwe_ciphertext_plaintext_sub_assign<Scalar, InCont>(
|
||||
) where
|
||||
Scalar: UnsignedInteger,
|
||||
InCont: ContainerMut<Element = Scalar>,
|
||||
{
|
||||
let ciphertext_modulus = lhs.ciphertext_modulus();
|
||||
if ciphertext_modulus.is_compatible_with_native_modulus() {
|
||||
lwe_ciphertext_plaintext_sub_assign_native_mod_compatible(lhs, rhs);
|
||||
} else {
|
||||
lwe_ciphertext_plaintext_sub_assign_other_mod(lhs, rhs);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn lwe_ciphertext_plaintext_sub_assign_native_mod_compatible<Scalar, InCont>(
|
||||
lhs: &mut LweCiphertext<InCont>,
|
||||
rhs: Plaintext<Scalar>,
|
||||
) where
|
||||
Scalar: UnsignedInteger,
|
||||
InCont: ContainerMut<Element = Scalar>,
|
||||
{
|
||||
let body = lhs.get_mut_body();
|
||||
let ciphertext_modulus = body.ciphertext_modulus();
|
||||
assert!(ciphertext_modulus.is_compatible_with_native_modulus());
|
||||
if ciphertext_modulus.is_native_modulus() {
|
||||
*body.data = (*body.data).wrapping_sub(rhs.0);
|
||||
} else {
|
||||
*body.data = (*body.data).wrapping_sub(
|
||||
rhs.0
|
||||
.wrapping_mul(ciphertext_modulus.get_power_of_two_scaling_to_native_torus()),
|
||||
);
|
||||
}
|
||||
|
||||
let plaintext = match ciphertext_modulus.kind() {
|
||||
CiphertextModulusKind::Native => rhs.0,
|
||||
// Manage power of 2 encoding
|
||||
CiphertextModulusKind::NonNativePowerOfTwo => rhs
|
||||
.0
|
||||
.wrapping_mul(ciphertext_modulus.get_power_of_two_scaling_to_native_torus()),
|
||||
CiphertextModulusKind::Other => unreachable!(),
|
||||
};
|
||||
|
||||
*body.data = (*body.data).wrapping_sub(plaintext);
|
||||
}
|
||||
|
||||
pub fn lwe_ciphertext_plaintext_sub_assign_other_mod<Scalar, InCont>(
|
||||
lhs: &mut LweCiphertext<InCont>,
|
||||
rhs: Plaintext<Scalar>,
|
||||
) where
|
||||
Scalar: UnsignedInteger,
|
||||
InCont: ContainerMut<Element = Scalar>,
|
||||
{
|
||||
let body = lhs.get_mut_body();
|
||||
let ciphertext_modulus = body.ciphertext_modulus();
|
||||
assert!(!ciphertext_modulus.is_compatible_with_native_modulus());
|
||||
*body.data = (*body.data)
|
||||
.wrapping_sub_custom_mod(rhs.0, ciphertext_modulus.get_custom_modulus().cast_into());
|
||||
}
|
||||
|
||||
/// Compute the opposite of the input [`LWE ciphertext`](`LweCiphertext`) and update it in place.
|
||||
|
||||
@@ -385,7 +385,7 @@ pub fn multi_bit_blind_rotate_assign<Scalar, InputCont, OutputCont, KeyCont>(
|
||||
polynomial_wrapping_monic_monomial_div_assign(
|
||||
&mut poly,
|
||||
MonomialDegree(monomial_degree),
|
||||
)
|
||||
);
|
||||
});
|
||||
|
||||
let fourier_multi_bit_ggsw_buffers = (0..thread_count.0)
|
||||
@@ -634,7 +634,7 @@ pub fn multi_bit_deterministic_blind_rotate_assign<Scalar, InputCont, OutputCont
|
||||
polynomial_wrapping_monic_monomial_div_assign(
|
||||
&mut poly,
|
||||
MonomialDegree(monomial_degree),
|
||||
)
|
||||
);
|
||||
});
|
||||
|
||||
let fourier_multi_bit_ggsw_buffers = (0..thread_count.0)
|
||||
@@ -1382,7 +1382,7 @@ pub fn std_multi_bit_blind_rotate_assign<Scalar, InputCont, OutputCont, KeyCont>
|
||||
polynomial_wrapping_monic_monomial_div_assign(
|
||||
&mut poly,
|
||||
MonomialDegree(monomial_degree),
|
||||
)
|
||||
);
|
||||
});
|
||||
|
||||
let fourier_multi_bit_ggsw_buffers = (0..thread_count.0)
|
||||
@@ -1666,7 +1666,7 @@ pub fn std_multi_bit_deterministic_blind_rotate_assign<Scalar, InputCont, Output
|
||||
polynomial_wrapping_monic_monomial_div_assign(
|
||||
&mut poly,
|
||||
MonomialDegree(monomial_degree),
|
||||
)
|
||||
);
|
||||
});
|
||||
|
||||
let fourier_multi_bit_ggsw_buffers = (0..thread_count.0)
|
||||
|
||||
@@ -373,7 +373,7 @@ pub fn keyswitch_lwe_ciphertext_list_and_pack_in_glwe_ciphertext<
|
||||
.as_mut_polynomial_list()
|
||||
.iter_mut()
|
||||
.for_each(|mut poly| {
|
||||
polynomial_wrapping_monic_monomial_mul_assign(&mut poly, MonomialDegree(degree))
|
||||
polynomial_wrapping_monic_monomial_mul_assign(&mut poly, MonomialDegree(degree));
|
||||
});
|
||||
slice_wrapping_add_assign(output_glwe_ciphertext.as_mut(), buffer.as_ref());
|
||||
}
|
||||
@@ -509,7 +509,7 @@ pub fn par_keyswitch_lwe_ciphertext_list_and_pack_in_glwe_ciphertext<
|
||||
input_lwe_ciphertext_list,
|
||||
output_glwe_ciphertext,
|
||||
thread_count,
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
/// Parallel variant of [`keyswitch_lwe_ciphertext_list_and_pack_in_glwe_ciphertext`].
|
||||
@@ -739,7 +739,7 @@ pub fn par_keyswitch_lwe_ciphertext_list_and_pack_in_glwe_ciphertext_with_thread
|
||||
polynomial_wrapping_monic_monomial_mul_assign(
|
||||
&mut poly,
|
||||
MonomialDegree(chunk_idx * chunk_size),
|
||||
)
|
||||
);
|
||||
});
|
||||
|
||||
buffer
|
||||
|
||||
@@ -282,7 +282,7 @@ pub fn private_functional_keyswitch_lwe_ciphertext_list_and_pack_in_glwe_ciphert
|
||||
.as_mut_polynomial_list()
|
||||
.iter_mut()
|
||||
.for_each(|mut poly| {
|
||||
polynomial_wrapping_monic_monomial_mul_assign(&mut poly, MonomialDegree(degree))
|
||||
polynomial_wrapping_monic_monomial_mul_assign(&mut poly, MonomialDegree(degree));
|
||||
});
|
||||
slice_wrapping_add_assign(output.as_mut(), buffer.as_ref());
|
||||
}
|
||||
|
||||
@@ -129,7 +129,7 @@ pub fn generate_lwe_private_functional_packing_keyswitch_key<
|
||||
&messages,
|
||||
noise_parameters,
|
||||
&mut loop_generator,
|
||||
)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -246,7 +246,7 @@ pub fn par_generate_lwe_private_functional_packing_keyswitch_key<
|
||||
&messages,
|
||||
noise_parameters,
|
||||
&mut loop_generator,
|
||||
)
|
||||
);
|
||||
},
|
||||
);
|
||||
}
|
||||
@@ -257,10 +257,10 @@ mod test {
|
||||
use crate::core_crypto::commons::math::random::Seed;
|
||||
use crate::core_crypto::prelude::*;
|
||||
|
||||
const NB_TESTS: usize = 10;
|
||||
|
||||
#[test]
|
||||
fn test_pfpksk_list_gen_equivalence() {
|
||||
const NB_TESTS: usize = 10;
|
||||
|
||||
for _ in 0..NB_TESTS {
|
||||
// DISCLAIMER: these toy example parameters are not guaranteed to be secure or yield
|
||||
// correct computations
|
||||
@@ -331,7 +331,7 @@ mod test {
|
||||
&mut encryption_generator,
|
||||
);
|
||||
|
||||
assert_eq!(par_cbs_pfpksk, ser_cbs_pfpksk)
|
||||
assert_eq!(par_cbs_pfpksk, ser_cbs_pfpksk);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1057,7 +1057,7 @@ pub fn programmable_bootstrap_lwe_ciphertext<Scalar, InputCont, OutputCont, AccC
|
||||
fourier_bsk,
|
||||
fft,
|
||||
stack,
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
/// Memory optimized version of [`programmable_bootstrap_lwe_ciphertext`], the caller must provide
|
||||
@@ -1367,7 +1367,7 @@ pub fn programmable_bootstrap_f128_lwe_ciphertext<Scalar, InputCont, OutputCont,
|
||||
fourier_bsk,
|
||||
fft,
|
||||
stack,
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
/// Memory optimized version of [`programmable_bootstrap_f128_lwe_ciphertext`], the caller must
|
||||
|
||||
@@ -37,7 +37,7 @@ pub fn generate_lwe_public_key<Scalar, InputKeyCont, OutputKeyCont, Gen>(
|
||||
PlaintextCount(output.zero_encryption_count().0),
|
||||
);
|
||||
|
||||
encrypt_lwe_ciphertext_list(lwe_secret_key, output, &zeros, noise_parameters, generator)
|
||||
encrypt_lwe_ciphertext_list(lwe_secret_key, output, &zeros, noise_parameters, generator);
|
||||
}
|
||||
|
||||
/// Allocate a new [`LWE public key`](`LwePublicKey`) and fill it with an actual public key
|
||||
@@ -95,7 +95,7 @@ pub fn par_generate_lwe_public_key<Scalar, InputKeyCont, OutputKeyCont, Gen>(
|
||||
PlaintextCount(output.zero_encryption_count().0),
|
||||
);
|
||||
|
||||
par_encrypt_lwe_ciphertext_list(lwe_secret_key, output, &zeros, noise_parameters, generator)
|
||||
par_encrypt_lwe_ciphertext_list(lwe_secret_key, output, &zeros, noise_parameters, generator);
|
||||
}
|
||||
|
||||
/// Parallel variant of [`allocate_and_generate_new_lwe_public_key`], it is recommended to use this
|
||||
@@ -227,7 +227,7 @@ pub fn par_generate_seeded_lwe_public_key<Scalar, InputKeyCont, OutputKeyCont, N
|
||||
&zeros,
|
||||
noise_parameters,
|
||||
noise_seeder,
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
/// Parallel variant of [`allocate_and_generate_new_seeded_lwe_public_key`], it is recommended to
|
||||
|
||||
@@ -62,5 +62,5 @@ pub fn generate_binary_lwe_secret_key<Scalar, InCont, Gen>(
|
||||
InCont: ContainerMut<Element = Scalar>,
|
||||
Gen: ByteRandomGenerator,
|
||||
{
|
||||
generator.fill_slice_with_random_uniform_binary(lwe_secret_key.as_mut())
|
||||
generator.fill_slice_with_random_uniform_binary(lwe_secret_key.as_mut());
|
||||
}
|
||||
|
||||
@@ -44,8 +44,7 @@ where
|
||||
{
|
||||
assert!(
|
||||
ciphertext_modulus.is_native_modulus(),
|
||||
"This operation currently only supports native moduli, got modulus {:?}",
|
||||
ciphertext_modulus
|
||||
"This operation currently only supports native moduli, got modulus {ciphertext_modulus:?}"
|
||||
);
|
||||
|
||||
let mut cbs_pfpksk_list = LwePrivateFunctionalPackingKeyswitchKeyListOwned::new(
|
||||
@@ -185,8 +184,7 @@ where
|
||||
{
|
||||
assert!(
|
||||
ciphertext_modulus.is_native_modulus(),
|
||||
"This operation currently only supports native moduli, got modulus {:?}",
|
||||
ciphertext_modulus
|
||||
"This operation currently only supports native moduli, got modulus {ciphertext_modulus:?}"
|
||||
);
|
||||
|
||||
let mut cbs_pfpksk_list = LwePrivateFunctionalPackingKeyswitchKeyListOwned::new(
|
||||
@@ -362,7 +360,7 @@ pub fn extract_bits_from_lwe_ciphertext_mem_optimized<
|
||||
number_of_bits_to_extract,
|
||||
fft,
|
||||
stack,
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
/// Return the required memory for [`extract_bits_from_lwe_ciphertext_mem_optimized`].
|
||||
@@ -699,7 +697,7 @@ pub fn circuit_bootstrap_boolean_vertical_packing_lwe_ciphertext_list_mem_optimi
|
||||
base_log_cbs,
|
||||
fft,
|
||||
stack,
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
|
||||
@@ -205,13 +205,14 @@ where
|
||||
mod test {
|
||||
use super::*;
|
||||
|
||||
const NB_TESTS: usize = 1_000_000_000;
|
||||
|
||||
#[test]
|
||||
fn test_divide_funcs() {
|
||||
use rand::Rng;
|
||||
|
||||
let mut rng = rand::thread_rng();
|
||||
|
||||
const NB_TESTS: usize = 1_000_000_000;
|
||||
const SCALING: f64 = u64::MAX as f64;
|
||||
for _ in 0..NB_TESTS {
|
||||
let num: f64 = rng.gen();
|
||||
|
||||
@@ -44,7 +44,7 @@ pub mod seeded_lwe_public_key_decompression;
|
||||
pub mod slice_algorithms;
|
||||
|
||||
#[cfg(test)]
|
||||
mod test;
|
||||
pub(crate) mod test;
|
||||
|
||||
// No pub use for slice and polynomial algorithms which would not interest higher level users
|
||||
// They can still be used via `use crate::core_crypto::algorithms::slice_algorithms::*;`
|
||||
|
||||
@@ -32,7 +32,7 @@ pub fn polynomial_wrapping_add_assign<Scalar, OutputCont, InputCont>(
|
||||
InputCont: Container<Element = Scalar>,
|
||||
{
|
||||
assert_eq!(lhs.polynomial_size(), rhs.polynomial_size());
|
||||
slice_wrapping_add_assign(lhs.as_mut(), rhs.as_ref())
|
||||
slice_wrapping_add_assign(lhs.as_mut(), rhs.as_ref());
|
||||
}
|
||||
|
||||
/// Subtract a polynomial to the output polynomial.
|
||||
@@ -61,7 +61,7 @@ pub fn polynomial_wrapping_sub_assign<Scalar, OutputCont, InputCont>(
|
||||
InputCont: Container<Element = Scalar>,
|
||||
{
|
||||
assert_eq!(lhs.polynomial_size(), rhs.polynomial_size());
|
||||
slice_wrapping_sub_assign(lhs.as_mut(), rhs.as_ref())
|
||||
slice_wrapping_sub_assign(lhs.as_mut(), rhs.as_ref());
|
||||
}
|
||||
|
||||
/// Add the sum of the element-wise product between two lists of polynomials to the output
|
||||
@@ -431,13 +431,6 @@ pub(crate) fn polynomial_wrapping_monic_monomial_mul_and_subtract<Scalar, Output
|
||||
OutputCont: ContainerMut<Element = Scalar>,
|
||||
InputCont: Container<Element = Scalar>,
|
||||
{
|
||||
assert!(
|
||||
output.polynomial_size() == input.polynomial_size(),
|
||||
"Output polynomial size {:?} is not the same as input polynomial size {:?}.",
|
||||
output.polynomial_size(),
|
||||
input.polynomial_size(),
|
||||
);
|
||||
|
||||
/// performs the operation: dst = -src - src_orig, with wrapping arithmetic
|
||||
fn copy_with_neg_and_subtract<Scalar: UnsignedInteger>(
|
||||
dst: &mut [Scalar],
|
||||
@@ -460,6 +453,13 @@ pub(crate) fn polynomial_wrapping_monic_monomial_mul_and_subtract<Scalar, Output
|
||||
}
|
||||
}
|
||||
|
||||
assert!(
|
||||
output.polynomial_size() == input.polynomial_size(),
|
||||
"Output polynomial size {:?} is not the same as input polynomial size {:?}.",
|
||||
output.polynomial_size(),
|
||||
input.polynomial_size(),
|
||||
);
|
||||
|
||||
let polynomial_size = output.polynomial_size().0;
|
||||
let remaining_degree = monomial_degree.0 % polynomial_size;
|
||||
|
||||
@@ -750,7 +750,7 @@ where
|
||||
for (lhs_degree, &lhs_elt) in p.iter().enumerate() {
|
||||
let res = &mut res[lhs_degree..];
|
||||
for (&rhs_elt, res) in q.iter().zip(res) {
|
||||
*res = (*res).wrapping_add(lhs_elt.wrapping_mul(rhs_elt))
|
||||
*res = (*res).wrapping_add(lhs_elt.wrapping_mul(rhs_elt));
|
||||
}
|
||||
}
|
||||
} else {
|
||||
@@ -942,41 +942,41 @@ mod test {
|
||||
|
||||
#[test]
|
||||
pub fn test_multiply_divide_unit_monomial_u32() {
|
||||
test_multiply_divide_unit_monomial::<u32>()
|
||||
test_multiply_divide_unit_monomial::<u32>();
|
||||
}
|
||||
|
||||
#[test]
|
||||
pub fn test_multiply_divide_unit_monomial_u64() {
|
||||
test_multiply_divide_unit_monomial::<u64>()
|
||||
test_multiply_divide_unit_monomial::<u64>();
|
||||
}
|
||||
|
||||
#[test]
|
||||
pub fn test_multiply_karatsuba_u32() {
|
||||
test_multiply_karatsuba::<u32>()
|
||||
test_multiply_karatsuba::<u32>();
|
||||
}
|
||||
|
||||
#[test]
|
||||
pub fn test_multiply_karatsuba_u64() {
|
||||
test_multiply_karatsuba::<u64>()
|
||||
test_multiply_karatsuba::<u64>();
|
||||
}
|
||||
|
||||
#[test]
|
||||
pub fn test_add_mul_u32() {
|
||||
test_add_mul::<u32>()
|
||||
test_add_mul::<u32>();
|
||||
}
|
||||
|
||||
#[test]
|
||||
pub fn test_add_mul_u64() {
|
||||
test_add_mul::<u64>()
|
||||
test_add_mul::<u64>();
|
||||
}
|
||||
|
||||
#[test]
|
||||
pub fn test_sub_mul_u32() {
|
||||
test_sub_mul::<u32>()
|
||||
test_sub_mul::<u32>();
|
||||
}
|
||||
|
||||
#[test]
|
||||
pub fn test_sub_mul_u64() {
|
||||
test_sub_mul::<u64>()
|
||||
test_sub_mul::<u64>();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -77,7 +77,7 @@ pub fn decompress_seeded_ggsw_ciphertext<Scalar, InputCont, OutputCont, Gen>(
|
||||
output_ggsw,
|
||||
input_seeded_ggsw,
|
||||
&mut generator,
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
/// Parallel variant of [`decompress_seeded_ggsw_ciphertext_with_existing_generator`].
|
||||
@@ -133,7 +133,7 @@ pub fn par_decompress_seeded_ggsw_ciphertext_with_existing_generator<
|
||||
);
|
||||
},
|
||||
);
|
||||
})
|
||||
});
|
||||
}
|
||||
|
||||
/// Parallel variant of [`decompress_seeded_ggsw_ciphertext`].
|
||||
@@ -151,5 +151,5 @@ pub fn par_decompress_seeded_ggsw_ciphertext<Scalar, InputCont, OutputCont, Gen>
|
||||
output_ggsw,
|
||||
input_seeded_ggsw,
|
||||
&mut generator,
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
@@ -50,7 +50,7 @@ pub fn decompress_seeded_ggsw_ciphertext_list_with_existing_generator<
|
||||
&mut ggsw_out,
|
||||
&ggsw_in,
|
||||
&mut loop_generator,
|
||||
)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -70,7 +70,7 @@ pub fn decompress_seeded_ggsw_ciphertext_list<Scalar, InputCont, OutputCont, Gen
|
||||
output_list,
|
||||
input_seeded_list,
|
||||
&mut generator,
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
/// Parallel variant of [`decompress_seeded_ggsw_ciphertext_list_with_existing_generator`].
|
||||
@@ -115,8 +115,8 @@ pub fn par_decompress_seeded_ggsw_ciphertext_list_with_existing_generator<
|
||||
&mut ggsw_out,
|
||||
&ggsw_in,
|
||||
&mut loop_generator,
|
||||
)
|
||||
})
|
||||
);
|
||||
});
|
||||
}
|
||||
|
||||
/// Parallel variant of [`decompress_seeded_ggsw_ciphertext_list`].
|
||||
@@ -134,5 +134,5 @@ pub fn par_decompress_seeded_ggsw_ciphertext_list<Scalar, InputCont, OutputCont,
|
||||
output_list,
|
||||
input_seeded_list,
|
||||
&mut generator,
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
@@ -65,5 +65,5 @@ pub fn decompress_seeded_glwe_ciphertext<Scalar, InputCont, OutputCont, Gen>(
|
||||
output_glwe,
|
||||
input_seeded_glwe,
|
||||
&mut generator,
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
@@ -65,5 +65,5 @@ pub fn decompress_seeded_glwe_ciphertext_list<Scalar, InputCont, OutputCont, Gen
|
||||
output_list,
|
||||
input_seeded_list,
|
||||
&mut generator,
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
@@ -31,7 +31,9 @@ pub fn decompress_seeded_lwe_bootstrap_key_with_existing_generator<
|
||||
output_bsk.ciphertext_modulus(),
|
||||
);
|
||||
|
||||
decompress_seeded_ggsw_ciphertext_list_with_existing_generator(output_bsk, input_bsk, generator)
|
||||
decompress_seeded_ggsw_ciphertext_list_with_existing_generator(
|
||||
output_bsk, input_bsk, generator,
|
||||
);
|
||||
}
|
||||
|
||||
/// Decompress a [`SeededLweBootstrapKey`], without consuming it, into a standard
|
||||
@@ -59,7 +61,7 @@ pub fn decompress_seeded_lwe_bootstrap_key<Scalar, InputCont, OutputCont, Gen>(
|
||||
output_bsk,
|
||||
input_bsk,
|
||||
&mut generator,
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
/// Parallel variant of [`decompress_seeded_lwe_bootstrap_key_with_existing_generator`].
|
||||
@@ -89,7 +91,7 @@ pub fn par_decompress_seeded_lwe_bootstrap_key_with_existing_generator<
|
||||
|
||||
par_decompress_seeded_ggsw_ciphertext_list_with_existing_generator(
|
||||
output_bsk, input_bsk, generator,
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
/// Parallel variant of [`decompress_seeded_lwe_bootstrap_key`]`.
|
||||
@@ -116,5 +118,5 @@ pub fn par_decompress_seeded_lwe_bootstrap_key<Scalar, InputCont, OutputCont, Ge
|
||||
output_bsk,
|
||||
input_bsk,
|
||||
&mut generator,
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
//! Module with primitives pertaining to [`SeededLweCiphertext`] decompression.
|
||||
|
||||
use crate::core_crypto::algorithms::slice_algorithms::slice_wrapping_scalar_mul_assign;
|
||||
use crate::core_crypto::commons::ciphertext_modulus::CiphertextModulusKind;
|
||||
use crate::core_crypto::commons::generators::MaskRandomGenerator;
|
||||
use crate::core_crypto::commons::traits::*;
|
||||
use crate::core_crypto::entities::*;
|
||||
@@ -26,16 +27,20 @@ pub fn decompress_seeded_lwe_ciphertext_with_existing_generator<Scalar, OutputCo
|
||||
);
|
||||
|
||||
let ciphertext_modulus = output_lwe.ciphertext_modulus();
|
||||
assert!(ciphertext_modulus.is_compatible_with_native_modulus());
|
||||
let (mut output_mask, output_body) = output_lwe.get_mut_mask_and_body();
|
||||
|
||||
// generate a uniformly random mask
|
||||
generator.fill_slice_with_random_mask_custom_mod(output_mask.as_mut(), ciphertext_modulus);
|
||||
if !ciphertext_modulus.is_native_modulus() {
|
||||
slice_wrapping_scalar_mul_assign(
|
||||
output_mask.as_mut(),
|
||||
ciphertext_modulus.get_power_of_two_scaling_to_native_torus(),
|
||||
);
|
||||
match ciphertext_modulus.kind() {
|
||||
// Manage the specific encoding for non native power of 2
|
||||
CiphertextModulusKind::NonNativePowerOfTwo => {
|
||||
slice_wrapping_scalar_mul_assign(
|
||||
output_mask.as_mut(),
|
||||
ciphertext_modulus.get_power_of_two_scaling_to_native_torus(),
|
||||
);
|
||||
}
|
||||
// Nothing to do
|
||||
CiphertextModulusKind::Native | CiphertextModulusKind::Other => (),
|
||||
}
|
||||
*output_body.data = *input_seeded_lwe.get_body().data;
|
||||
}
|
||||
@@ -64,5 +69,5 @@ pub fn decompress_seeded_lwe_ciphertext<Scalar, OutputCont, Gen>(
|
||||
output_lwe,
|
||||
input_seeded_lwe,
|
||||
&mut generator,
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
//! Module with primitives pertaining to [`SeededLweCiphertextList`] decompression.
|
||||
|
||||
use crate::core_crypto::algorithms::slice_algorithms::slice_wrapping_scalar_mul_assign;
|
||||
use crate::core_crypto::commons::ciphertext_modulus::CiphertextModulusKind;
|
||||
use crate::core_crypto::commons::generators::MaskRandomGenerator;
|
||||
use crate::core_crypto::commons::traits::*;
|
||||
use crate::core_crypto::entities::*;
|
||||
@@ -34,7 +35,6 @@ pub fn decompress_seeded_lwe_ciphertext_list_with_existing_generator<
|
||||
);
|
||||
|
||||
let ciphertext_modulus = output_list.ciphertext_modulus();
|
||||
assert!(ciphertext_modulus.is_compatible_with_native_modulus());
|
||||
|
||||
// Generator forking and decompression computations must match the SeededLweCiphertextList
|
||||
// encryption algorithm
|
||||
@@ -55,12 +55,16 @@ pub fn decompress_seeded_lwe_ciphertext_list_with_existing_generator<
|
||||
// Generate a uniformly random mask
|
||||
loop_generator
|
||||
.fill_slice_with_random_mask_custom_mod(output_mask.as_mut(), ciphertext_modulus);
|
||||
// Manage the power of 2 encoding we use
|
||||
if !ciphertext_modulus.is_native_modulus() {
|
||||
slice_wrapping_scalar_mul_assign(
|
||||
output_mask.as_mut(),
|
||||
ciphertext_modulus.get_power_of_two_scaling_to_native_torus(),
|
||||
);
|
||||
match ciphertext_modulus.kind() {
|
||||
// Manage the specific encoding for non native power of 2
|
||||
CiphertextModulusKind::NonNativePowerOfTwo => {
|
||||
slice_wrapping_scalar_mul_assign(
|
||||
output_mask.as_mut(),
|
||||
ciphertext_modulus.get_power_of_two_scaling_to_native_torus(),
|
||||
);
|
||||
}
|
||||
// Nothing to do
|
||||
CiphertextModulusKind::Native | CiphertextModulusKind::Other => (),
|
||||
}
|
||||
*output_body.data = *body_in.data;
|
||||
}
|
||||
@@ -91,7 +95,7 @@ pub fn decompress_seeded_lwe_ciphertext_list<Scalar, InputCont, OutputCont, Gen>
|
||||
output_list,
|
||||
input_seeded_list,
|
||||
&mut generator,
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
/// Parllel variant of [`decompress_seeded_lwe_ciphertext_list_with_existing_generator`].
|
||||
@@ -140,12 +144,16 @@ pub fn par_decompress_seeded_lwe_ciphertext_list_with_existing_generator<
|
||||
// Generate a uniformly random mask
|
||||
loop_generator
|
||||
.fill_slice_with_random_mask_custom_mod(output_mask.as_mut(), ciphertext_modulus);
|
||||
// Manage the power of 2 encoding we use
|
||||
if !ciphertext_modulus.is_native_modulus() {
|
||||
slice_wrapping_scalar_mul_assign(
|
||||
output_mask.as_mut(),
|
||||
ciphertext_modulus.get_power_of_two_scaling_to_native_torus(),
|
||||
);
|
||||
match ciphertext_modulus.kind() {
|
||||
// Manage the specific encoding for non native power of 2
|
||||
CiphertextModulusKind::NonNativePowerOfTwo => {
|
||||
slice_wrapping_scalar_mul_assign(
|
||||
output_mask.as_mut(),
|
||||
ciphertext_modulus.get_power_of_two_scaling_to_native_torus(),
|
||||
);
|
||||
}
|
||||
// Nothing to do
|
||||
CiphertextModulusKind::Native | CiphertextModulusKind::Other => (),
|
||||
}
|
||||
*output_body.data = *body_in.data;
|
||||
});
|
||||
@@ -175,5 +183,5 @@ pub fn par_decompress_seeded_lwe_ciphertext_list<Scalar, InputCont, OutputCont,
|
||||
output_list,
|
||||
input_seeded_list,
|
||||
&mut generator,
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
@@ -45,5 +45,5 @@ pub fn decompress_seeded_lwe_compact_public_key<Scalar, InputCont, OutputCont, G
|
||||
output_cpk,
|
||||
input_seeded_cpk,
|
||||
&mut generator,
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
@@ -26,7 +26,7 @@ pub fn decompress_seeded_lwe_keyswitch_key_with_existing_generator<
|
||||
&mut output_ksk.as_mut_lwe_ciphertext_list(),
|
||||
&input_ksk.as_seeded_lwe_ciphertext_list(),
|
||||
generator,
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
/// Decompress a [`SeededLweKeyswitchKey`], without consuming it, into a standard
|
||||
@@ -45,7 +45,7 @@ pub fn decompress_seeded_lwe_keyswitch_key<Scalar, InputCont, OutputCont, Gen>(
|
||||
output_ksk,
|
||||
input_ksk,
|
||||
&mut generator,
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
/// Parallel variant of [`decompress_seeded_lwe_keyswitch_key_with_existing_generator`].
|
||||
@@ -68,7 +68,7 @@ pub fn par_decompress_seeded_lwe_keyswitch_key_with_existing_generator<
|
||||
&mut output_ksk.as_mut_lwe_ciphertext_list(),
|
||||
&input_ksk.as_seeded_lwe_ciphertext_list(),
|
||||
generator,
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
/// Parallel variant of [`decompress_seeded_lwe_keyswitch_key`].
|
||||
@@ -86,5 +86,5 @@ pub fn par_decompress_seeded_lwe_keyswitch_key<Scalar, InputCont, OutputCont, Ge
|
||||
output_ksk,
|
||||
input_ksk,
|
||||
&mut generator,
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
@@ -102,7 +102,7 @@ pub fn decompress_seeded_lwe_multi_bit_bootstrap_key<Scalar, InputCont, OutputCo
|
||||
output_bsk,
|
||||
input_bsk,
|
||||
&mut generator,
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
/// Parallel variant of [`decompress_seeded_lwe_multi_bit_bootstrap_key_with_existing_generator`].
|
||||
@@ -174,9 +174,9 @@ pub fn par_decompress_seeded_lwe_multi_bit_bootstrap_key_with_existing_generator
|
||||
&mut inner_loop_generator,
|
||||
);
|
||||
},
|
||||
)
|
||||
);
|
||||
},
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
/// Parallel variant of [`decompress_seeded_lwe_multi_bit_bootstrap_key`].
|
||||
@@ -203,5 +203,5 @@ pub fn par_decompress_seeded_lwe_multi_bit_bootstrap_key<Scalar, InputCont, Outp
|
||||
output_bsk,
|
||||
input_bsk,
|
||||
&mut generator,
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
@@ -26,7 +26,7 @@ pub fn decompress_seeded_lwe_packing_keyswitch_key_with_existing_generator<
|
||||
&mut output_pksk.as_mut_glwe_ciphertext_list(),
|
||||
&input_pksk.as_seeded_glwe_ciphertext_list(),
|
||||
generator,
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
/// Decompress a [`SeededLwePackingKeyswitchKey`], without consuming it, into a standard
|
||||
@@ -45,5 +45,5 @@ pub fn decompress_seeded_lwe_packing_keyswitch_key<Scalar, InputCont, OutputCont
|
||||
output_pksk,
|
||||
input_pksk,
|
||||
&mut generator,
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
@@ -38,6 +38,31 @@ where
|
||||
})
|
||||
}
|
||||
|
||||
/// This primitive is meant to manage the dot product avoiding overflow on multiplication by casting
|
||||
/// to u128, for example for u64, avoiding overflow on each multiplication (as u64::MAX * u64::MAX <
|
||||
/// u128::MAX)
|
||||
pub fn slice_wrapping_dot_product_custom_mod<Scalar>(
|
||||
lhs: &[Scalar],
|
||||
rhs: &[Scalar],
|
||||
modulus: Scalar,
|
||||
) -> Scalar
|
||||
where
|
||||
Scalar: UnsignedInteger,
|
||||
{
|
||||
assert!(
|
||||
lhs.len() == rhs.len(),
|
||||
"lhs (len: {}) and rhs (len: {}) must have the same length",
|
||||
lhs.len(),
|
||||
rhs.len()
|
||||
);
|
||||
|
||||
lhs.iter()
|
||||
.zip(rhs.iter())
|
||||
.fold(Scalar::ZERO, |acc, (&left, &right)| {
|
||||
acc.wrapping_add_custom_mod(left.wrapping_mul_custom_mod(right, modulus), modulus)
|
||||
})
|
||||
}
|
||||
|
||||
/// Add a slice containing unsigned integers to another one element-wise.
|
||||
///
|
||||
/// # Note
|
||||
@@ -78,6 +103,33 @@ where
|
||||
.for_each(|(out, (&lhs, &rhs))| *out = lhs.wrapping_add(rhs));
|
||||
}
|
||||
|
||||
pub fn slice_wrapping_add_custom_mod<Scalar>(
|
||||
output: &mut [Scalar],
|
||||
lhs: &[Scalar],
|
||||
rhs: &[Scalar],
|
||||
custom_modulus: Scalar,
|
||||
) where
|
||||
Scalar: UnsignedInteger,
|
||||
{
|
||||
assert!(
|
||||
lhs.len() == rhs.len(),
|
||||
"lhs (len: {}) and rhs (len: {}) must have the same length",
|
||||
lhs.len(),
|
||||
rhs.len()
|
||||
);
|
||||
assert!(
|
||||
output.len() == lhs.len(),
|
||||
"output (len: {}) and rhs (len: {}) must have the same length",
|
||||
output.len(),
|
||||
lhs.len()
|
||||
);
|
||||
|
||||
output
|
||||
.iter_mut()
|
||||
.zip(lhs.iter().zip(rhs.iter()))
|
||||
.for_each(|(out, (&lhs, &rhs))| *out = lhs.wrapping_add_custom_mod(rhs, custom_modulus));
|
||||
}
|
||||
|
||||
/// Add a slice containing unsigned integers to another one element-wise and in place.
|
||||
///
|
||||
/// # Note
|
||||
@@ -110,6 +162,25 @@ where
|
||||
.for_each(|(lhs, &rhs)| *lhs = (*lhs).wrapping_add(rhs));
|
||||
}
|
||||
|
||||
pub fn slice_wrapping_add_assign_custom_mod<Scalar>(
|
||||
lhs: &mut [Scalar],
|
||||
rhs: &[Scalar],
|
||||
custom_modulus: Scalar,
|
||||
) where
|
||||
Scalar: UnsignedInteger,
|
||||
{
|
||||
assert!(
|
||||
lhs.len() == rhs.len(),
|
||||
"lhs (len: {}) and rhs (len: {}) must have the same length",
|
||||
lhs.len(),
|
||||
rhs.len()
|
||||
);
|
||||
|
||||
lhs.iter_mut()
|
||||
.zip(rhs.iter())
|
||||
.for_each(|(lhs, &rhs)| *lhs = (*lhs).wrapping_add_custom_mod(rhs, custom_modulus));
|
||||
}
|
||||
|
||||
/// Add a slice containing unsigned integers to another one mutiplied by a scalar.
|
||||
///
|
||||
/// Let *a*,*b* be two slices, let *c* be a scalar, this computes: *a <- a+bc*
|
||||
@@ -187,6 +258,33 @@ where
|
||||
.for_each(|(out, (&lhs, &rhs))| *out = lhs.wrapping_sub(rhs));
|
||||
}
|
||||
|
||||
pub fn slice_wrapping_sub_custom_mod<Scalar>(
|
||||
output: &mut [Scalar],
|
||||
lhs: &[Scalar],
|
||||
rhs: &[Scalar],
|
||||
custom_modulus: Scalar,
|
||||
) where
|
||||
Scalar: UnsignedInteger,
|
||||
{
|
||||
assert!(
|
||||
lhs.len() == rhs.len(),
|
||||
"lhs (len: {}) and rhs (len: {}) must have the same length",
|
||||
lhs.len(),
|
||||
rhs.len()
|
||||
);
|
||||
assert!(
|
||||
output.len() == lhs.len(),
|
||||
"output (len: {}) and rhs (len: {}) must have the same length",
|
||||
output.len(),
|
||||
lhs.len()
|
||||
);
|
||||
|
||||
output
|
||||
.iter_mut()
|
||||
.zip(lhs.iter().zip(rhs.iter()))
|
||||
.for_each(|(out, (&lhs, &rhs))| *out = lhs.wrapping_sub_custom_mod(rhs, custom_modulus));
|
||||
}
|
||||
|
||||
/// Subtract a slice containing unsigned integers to another one, element-wise and in place.
|
||||
///
|
||||
/// # Note
|
||||
@@ -219,6 +317,25 @@ where
|
||||
.for_each(|(lhs, &rhs)| *lhs = (*lhs).wrapping_sub(rhs));
|
||||
}
|
||||
|
||||
pub fn slice_wrapping_sub_assign_custom_mod<Scalar>(
|
||||
lhs: &mut [Scalar],
|
||||
rhs: &[Scalar],
|
||||
custom_modulus: Scalar,
|
||||
) where
|
||||
Scalar: UnsignedInteger,
|
||||
{
|
||||
assert!(
|
||||
lhs.len() == rhs.len(),
|
||||
"lhs (len: {}) and rhs (len: {}) must have the same length",
|
||||
lhs.len(),
|
||||
rhs.len()
|
||||
);
|
||||
|
||||
lhs.iter_mut()
|
||||
.zip(rhs.iter())
|
||||
.for_each(|(lhs, &rhs)| *lhs = (*lhs).wrapping_sub_custom_mod(rhs, custom_modulus));
|
||||
}
|
||||
|
||||
/// Subtract a slice containing unsigned integers to another one mutiplied by a scalar,
|
||||
/// element-wise and in place.
|
||||
///
|
||||
@@ -229,6 +346,11 @@ where
|
||||
/// Computations wrap around (similar to computing modulo $2^{n\_{bits}}$) when exceeding the
|
||||
/// unsigned integer capacity.
|
||||
///
|
||||
/// This functions has hardcoded cases for small values for `scalar` in $[-16, 16]$ which allows
|
||||
/// for specifically optimized code paths (a multiplication by a power of 2 can be changed to shift
|
||||
/// by the compiler), this yields significant performance improvements for the keyswitch which
|
||||
/// heavily relies on that primitive.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```
|
||||
@@ -245,15 +367,118 @@ pub fn slice_wrapping_sub_scalar_mul_assign<Scalar>(
|
||||
) where
|
||||
Scalar: UnsignedInteger,
|
||||
{
|
||||
struct Impl<'a, Scalar> {
|
||||
lhs: &'a mut [Scalar],
|
||||
rhs: &'a [Scalar],
|
||||
scalar: Scalar,
|
||||
}
|
||||
|
||||
impl<Scalar: UnsignedInteger> pulp::NullaryFnOnce for Impl<'_, Scalar> {
|
||||
type Output = ();
|
||||
|
||||
#[inline(always)]
|
||||
fn call(self) -> Self::Output {
|
||||
let Self { lhs, rhs, scalar } = self;
|
||||
|
||||
macro_rules! spec_constant {
|
||||
($constant: expr) => {
|
||||
if scalar == Scalar::cast_from($constant as u128) {
|
||||
for (lhs, &rhs) in lhs.iter_mut().zip(rhs.iter()) {
|
||||
*lhs = (*lhs).wrapping_sub(
|
||||
rhs.wrapping_mul(Scalar::cast_from($constant as u128)),
|
||||
)
|
||||
}
|
||||
return;
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
// Manage all values with hardcoded paths for values in [-16; 16]
|
||||
// This takes care of all keyswitch base logs <= 5
|
||||
// The negated value is handled in the spec constant to avoid bad surprises with the
|
||||
// constant type vs the Scalar type
|
||||
// UnsignedInteger is CastFrom<u128> by default, we give the constant in a readable form
|
||||
// as an i128, it then gets cast to u128
|
||||
spec_constant!(-16i128);
|
||||
spec_constant!(-15i128);
|
||||
spec_constant!(-14i128);
|
||||
spec_constant!(-13i128);
|
||||
spec_constant!(-12i128);
|
||||
spec_constant!(-11i128);
|
||||
spec_constant!(-10i128);
|
||||
spec_constant!(-9i128);
|
||||
spec_constant!(-8i128);
|
||||
spec_constant!(-7i128);
|
||||
spec_constant!(-6i128);
|
||||
spec_constant!(-5i128);
|
||||
spec_constant!(-4i128);
|
||||
spec_constant!(-3i128);
|
||||
spec_constant!(-2i128);
|
||||
spec_constant!(-1i128);
|
||||
spec_constant!(0i128);
|
||||
spec_constant!(1i128);
|
||||
spec_constant!(2i128);
|
||||
spec_constant!(3i128);
|
||||
spec_constant!(4i128);
|
||||
spec_constant!(5i128);
|
||||
spec_constant!(6i128);
|
||||
spec_constant!(7i128);
|
||||
spec_constant!(8i128);
|
||||
spec_constant!(9i128);
|
||||
spec_constant!(10i128);
|
||||
spec_constant!(11i128);
|
||||
spec_constant!(12i128);
|
||||
spec_constant!(13i128);
|
||||
spec_constant!(14i128);
|
||||
spec_constant!(15i128);
|
||||
spec_constant!(16i128);
|
||||
|
||||
// Fall back case, will likely be slower as the compiler cannot hard code optimized code
|
||||
// like filling with 0s for the 0 case, noop for the 1 case, shift left by 1 for 2, etc.
|
||||
for (lhs, &rhs) in lhs.iter_mut().zip(rhs.iter()) {
|
||||
*lhs = (*lhs).wrapping_sub(rhs.wrapping_mul(scalar));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Const evaluated
|
||||
assert!(
|
||||
Scalar::BITS <= 128,
|
||||
"Scalar has more than 128 bits, \
|
||||
specialized constants will not work properly for negative values."
|
||||
);
|
||||
|
||||
assert!(
|
||||
lhs.len() == rhs.len(),
|
||||
"lhs (len: {}) and rhs (len: {}) must have the same length",
|
||||
lhs.len(),
|
||||
rhs.len()
|
||||
);
|
||||
lhs.iter_mut()
|
||||
.zip(rhs.iter())
|
||||
.for_each(|(lhs, &rhs)| *lhs = (*lhs).wrapping_sub(rhs.wrapping_mul(scalar)));
|
||||
|
||||
pulp::Arch::new().dispatch(Impl { lhs, rhs, scalar });
|
||||
}
|
||||
|
||||
/// This primitive is meant to manage the sub_scalar_mul operation for values that were cast to a
|
||||
/// bigger type, for example u64 to u128, avoiding overflow on each multiplication (as u64::MAX *
|
||||
/// u64::MAX < u128::MAX )
|
||||
pub fn slice_wrapping_sub_scalar_mul_assign_custom_modulus<Scalar>(
|
||||
lhs: &mut [Scalar],
|
||||
rhs: &[Scalar],
|
||||
scalar: Scalar,
|
||||
modulus: Scalar,
|
||||
) where
|
||||
Scalar: UnsignedInteger,
|
||||
{
|
||||
assert!(
|
||||
lhs.len() == rhs.len(),
|
||||
"lhs (len: {}) and rhs (len: {}) must have the same length",
|
||||
lhs.len(),
|
||||
rhs.len()
|
||||
);
|
||||
lhs.iter_mut().zip(rhs.iter()).for_each(|(lhs, &rhs)| {
|
||||
*lhs =
|
||||
(*lhs).wrapping_sub_custom_mod(rhs.wrapping_mul_custom_mod(scalar, modulus), modulus);
|
||||
});
|
||||
}
|
||||
|
||||
/// Compute the opposite of a slice containing unsigned integers, element-wise and in place.
|
||||
@@ -280,6 +505,17 @@ where
|
||||
.for_each(|elt| *elt = (*elt).wrapping_neg());
|
||||
}
|
||||
|
||||
pub fn slice_wrapping_opposite_assign_custom_mod<Scalar>(
|
||||
slice: &mut [Scalar],
|
||||
custom_modulus: Scalar,
|
||||
) where
|
||||
Scalar: UnsignedInteger,
|
||||
{
|
||||
slice
|
||||
.iter_mut()
|
||||
.for_each(|elt| *elt = (*elt).wrapping_neg_custom_mod(custom_modulus));
|
||||
}
|
||||
|
||||
/// Multiply a slice containing unsigned integers by a scalar, element-wise and in place.
|
||||
///
|
||||
/// # Note
|
||||
@@ -304,6 +540,17 @@ where
|
||||
.for_each(|lhs| *lhs = (*lhs).wrapping_mul(rhs));
|
||||
}
|
||||
|
||||
pub fn slice_wrapping_scalar_mul_assign_custom_mod<Scalar>(
|
||||
lhs: &mut [Scalar],
|
||||
rhs: Scalar,
|
||||
custom_modulus: Scalar,
|
||||
) where
|
||||
Scalar: UnsignedInteger,
|
||||
{
|
||||
lhs.iter_mut()
|
||||
.for_each(|lhs| *lhs = (*lhs).wrapping_mul_custom_mod(rhs, custom_modulus));
|
||||
}
|
||||
|
||||
pub fn slice_wrapping_scalar_div_assign<Scalar>(lhs: &mut [Scalar], rhs: Scalar)
|
||||
where
|
||||
Scalar: UnsignedInteger,
|
||||
|
||||
@@ -6,6 +6,11 @@ use crate::core_crypto::commons::generators::{
|
||||
use crate::core_crypto::commons::math::random::{ActivatedRandomGenerator, CompressionSeed};
|
||||
use crate::core_crypto::commons::test_tools;
|
||||
|
||||
#[cfg(not(feature = "__coverage"))]
|
||||
const NB_TESTS: usize = 10;
|
||||
#[cfg(feature = "__coverage")]
|
||||
const NB_TESTS: usize = 1;
|
||||
|
||||
fn test_parallel_and_seeded_ggsw_encryption_equivalence<Scalar>(
|
||||
ciphertext_modulus: CiphertextModulus<Scalar>,
|
||||
) where
|
||||
@@ -27,8 +32,6 @@ fn test_parallel_and_seeded_ggsw_encryption_equivalence<Scalar>(
|
||||
let mut secret_generator =
|
||||
SecretRandomGenerator::<ActivatedRandomGenerator>::new(seeder.seed());
|
||||
|
||||
const NB_TESTS: usize = 10;
|
||||
|
||||
for _ in 0..NB_TESTS {
|
||||
// Create the GlweSecretKey
|
||||
let glwe_secret_key = allocate_and_generate_new_binary_glwe_secret_key(
|
||||
@@ -174,7 +177,7 @@ fn test_parallel_and_seeded_ggsw_encryption_equivalence_u64_custom_mod() {
|
||||
);
|
||||
}
|
||||
|
||||
fn ggsw_encrypt_decrypt_custom_mod<Scalar: UnsignedTorus>(params: TestParams<Scalar>) {
|
||||
fn ggsw_encrypt_decrypt_custom_mod<Scalar: UnsignedTorus>(params: ClassicTestParams<Scalar>) {
|
||||
let glwe_dimension = params.glwe_dimension;
|
||||
let polynomial_size = params.polynomial_size;
|
||||
let glwe_modular_std_dev = params.glwe_modular_std_dev;
|
||||
@@ -184,8 +187,6 @@ fn ggsw_encrypt_decrypt_custom_mod<Scalar: UnsignedTorus>(params: TestParams<Sca
|
||||
|
||||
let mut rsc = TestResources::new();
|
||||
|
||||
const NB_TESTS: usize = 10;
|
||||
|
||||
let mut msg = Scalar::ONE << decomposition_base_log.0;
|
||||
|
||||
while msg != Scalar::ZERO {
|
||||
@@ -232,13 +233,18 @@ fn ggsw_encrypt_decrypt_custom_mod<Scalar: UnsignedTorus>(params: TestParams<Sca
|
||||
|
||||
assert!(decoded.0 == msg);
|
||||
}
|
||||
|
||||
// In coverage, we break after one while loop iteration, changing message values does not
|
||||
// yield higher coverage
|
||||
#[cfg(feature = "__coverage")]
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
create_parametrized_test!(ggsw_encrypt_decrypt_custom_mod);
|
||||
|
||||
fn ggsw_par_encrypt_decrypt_custom_mod<Scalar: UnsignedTorus + Send + Sync>(
|
||||
params: TestParams<Scalar>,
|
||||
params: ClassicTestParams<Scalar>,
|
||||
) {
|
||||
let glwe_dimension = params.glwe_dimension;
|
||||
let polynomial_size = params.polynomial_size;
|
||||
@@ -249,8 +255,6 @@ fn ggsw_par_encrypt_decrypt_custom_mod<Scalar: UnsignedTorus + Send + Sync>(
|
||||
|
||||
let mut rsc = TestResources::new();
|
||||
|
||||
const NB_TESTS: usize = 10;
|
||||
|
||||
let mut msg = Scalar::ONE << decomposition_base_log.0;
|
||||
|
||||
while msg != Scalar::ZERO {
|
||||
@@ -297,12 +301,19 @@ fn ggsw_par_encrypt_decrypt_custom_mod<Scalar: UnsignedTorus + Send + Sync>(
|
||||
|
||||
assert!(decoded.0 == msg);
|
||||
}
|
||||
|
||||
// In coverage, we break after one while loop iteration, changing message values does not
|
||||
// yield higher coverage
|
||||
#[cfg(feature = "__coverage")]
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
create_parametrized_test!(ggsw_par_encrypt_decrypt_custom_mod);
|
||||
|
||||
fn ggsw_seeded_encrypt_decrypt_custom_mod<Scalar: UnsignedTorus>(params: TestParams<Scalar>) {
|
||||
fn ggsw_seeded_encrypt_decrypt_custom_mod<Scalar: UnsignedTorus>(
|
||||
params: ClassicTestParams<Scalar>,
|
||||
) {
|
||||
let glwe_dimension = params.glwe_dimension;
|
||||
let polynomial_size = params.polynomial_size;
|
||||
let glwe_modular_std_dev = params.glwe_modular_std_dev;
|
||||
@@ -312,8 +323,6 @@ fn ggsw_seeded_encrypt_decrypt_custom_mod<Scalar: UnsignedTorus>(params: TestPar
|
||||
|
||||
let mut rsc = TestResources::new();
|
||||
|
||||
const NB_TESTS: usize = 10;
|
||||
|
||||
let mut msg = Scalar::ONE << decomposition_base_log.0;
|
||||
|
||||
while msg != Scalar::ZERO {
|
||||
@@ -363,13 +372,18 @@ fn ggsw_seeded_encrypt_decrypt_custom_mod<Scalar: UnsignedTorus>(params: TestPar
|
||||
|
||||
assert!(decoded.0 == msg);
|
||||
}
|
||||
|
||||
// In coverage, we break after one while loop iteration, changing message values does not
|
||||
// yield higher coverage
|
||||
#[cfg(feature = "__coverage")]
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
create_parametrized_test!(ggsw_seeded_encrypt_decrypt_custom_mod);
|
||||
|
||||
fn ggsw_seeded_par_encrypt_decrypt_custom_mod<Scalar: UnsignedTorus + Sync + Send>(
|
||||
params: TestParams<Scalar>,
|
||||
params: ClassicTestParams<Scalar>,
|
||||
) {
|
||||
let glwe_dimension = params.glwe_dimension;
|
||||
let polynomial_size = params.polynomial_size;
|
||||
@@ -380,8 +394,6 @@ fn ggsw_seeded_par_encrypt_decrypt_custom_mod<Scalar: UnsignedTorus + Sync + Sen
|
||||
|
||||
let mut rsc = TestResources::new();
|
||||
|
||||
const NB_TESTS: usize = 10;
|
||||
|
||||
let mut msg = Scalar::ONE << decomposition_base_log.0;
|
||||
|
||||
while msg != Scalar::ZERO {
|
||||
@@ -431,6 +443,11 @@ fn ggsw_seeded_par_encrypt_decrypt_custom_mod<Scalar: UnsignedTorus + Sync + Sen
|
||||
|
||||
assert!(decoded.0 == msg);
|
||||
}
|
||||
|
||||
// In coverage, we break after one while loop iteration, changing message values does not
|
||||
// yield higher coverage
|
||||
#[cfg(feature = "__coverage")]
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,6 +1,13 @@
|
||||
use super::*;
|
||||
|
||||
fn glwe_encrypt_assign_decrypt_custom_mod<Scalar: UnsignedTorus>(params: TestParams<Scalar>) {
|
||||
#[cfg(not(feature = "__coverage"))]
|
||||
const NB_TESTS: usize = 10;
|
||||
#[cfg(feature = "__coverage")]
|
||||
const NB_TESTS: usize = 1;
|
||||
|
||||
fn glwe_encrypt_assign_decrypt_custom_mod<Scalar: UnsignedTorus>(
|
||||
params: ClassicTestParams<Scalar>,
|
||||
) {
|
||||
let glwe_dimension = params.glwe_dimension;
|
||||
let polynomial_size = params.polynomial_size;
|
||||
let glwe_modular_std_dev = params.glwe_modular_std_dev;
|
||||
@@ -10,7 +17,6 @@ fn glwe_encrypt_assign_decrypt_custom_mod<Scalar: UnsignedTorus>(params: TestPar
|
||||
|
||||
let mut rsc = TestResources::new();
|
||||
|
||||
const NB_TESTS: usize = 10;
|
||||
let msg_modulus = Scalar::ONE.shl(message_modulus_log.0);
|
||||
let mut msg = msg_modulus;
|
||||
let delta: Scalar = encoding_with_padding / msg_modulus;
|
||||
@@ -57,12 +63,17 @@ fn glwe_encrypt_assign_decrypt_custom_mod<Scalar: UnsignedTorus>(params: TestPar
|
||||
|
||||
assert!(decoded.iter().all(|&x| x == msg));
|
||||
}
|
||||
|
||||
// In coverage, we break after one while loop iteration, changing message values does not
|
||||
// yield higher coverage
|
||||
#[cfg(feature = "__coverage")]
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
create_parametrized_test!(glwe_encrypt_assign_decrypt_custom_mod);
|
||||
|
||||
fn glwe_encrypt_decrypt_custom_mod<Scalar: UnsignedTorus>(params: TestParams<Scalar>) {
|
||||
fn glwe_encrypt_decrypt_custom_mod<Scalar: UnsignedTorus>(params: ClassicTestParams<Scalar>) {
|
||||
let glwe_dimension = params.glwe_dimension;
|
||||
let polynomial_size = params.polynomial_size;
|
||||
let glwe_modular_std_dev = params.glwe_modular_std_dev;
|
||||
@@ -72,7 +83,6 @@ fn glwe_encrypt_decrypt_custom_mod<Scalar: UnsignedTorus>(params: TestParams<Sca
|
||||
|
||||
let mut rsc = TestResources::new();
|
||||
|
||||
const NB_TESTS: usize = 10;
|
||||
let msg_modulus = Scalar::ONE.shl(message_modulus_log.0);
|
||||
let mut msg = msg_modulus;
|
||||
let delta: Scalar = encoding_with_padding / msg_modulus;
|
||||
@@ -123,12 +133,17 @@ fn glwe_encrypt_decrypt_custom_mod<Scalar: UnsignedTorus>(params: TestParams<Sca
|
||||
|
||||
assert!(decoded.iter().all(|&x| x == msg));
|
||||
}
|
||||
|
||||
// In coverage, we break after one while loop iteration, changing message values does not
|
||||
// yield higher coverage
|
||||
#[cfg(feature = "__coverage")]
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
create_parametrized_test!(glwe_encrypt_decrypt_custom_mod);
|
||||
|
||||
fn glwe_list_encrypt_decrypt_custom_mod<Scalar: UnsignedTorus>(params: TestParams<Scalar>) {
|
||||
fn glwe_list_encrypt_decrypt_custom_mod<Scalar: UnsignedTorus>(params: ClassicTestParams<Scalar>) {
|
||||
let glwe_dimension = params.glwe_dimension;
|
||||
let polynomial_size = params.polynomial_size;
|
||||
let glwe_modular_std_dev = params.glwe_modular_std_dev;
|
||||
@@ -139,7 +154,6 @@ fn glwe_list_encrypt_decrypt_custom_mod<Scalar: UnsignedTorus>(params: TestParam
|
||||
|
||||
let mut rsc = TestResources::new();
|
||||
|
||||
const NB_TESTS: usize = 10;
|
||||
let msg_modulus = Scalar::ONE.shl(message_modulus_log.0);
|
||||
let mut msg = msg_modulus;
|
||||
let delta: Scalar = encoding_with_padding / msg_modulus;
|
||||
@@ -195,12 +209,19 @@ fn glwe_list_encrypt_decrypt_custom_mod<Scalar: UnsignedTorus>(params: TestParam
|
||||
|
||||
assert!(decoded.iter().all(|&x| x == msg));
|
||||
}
|
||||
|
||||
// In coverage, we break after one while loop iteration, changing message values does not
|
||||
// yield higher coverage
|
||||
#[cfg(feature = "__coverage")]
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
create_parametrized_test!(glwe_list_encrypt_decrypt_custom_mod);
|
||||
|
||||
fn glwe_trivial_encrypt_decrypt_custom_mod<Scalar: UnsignedTorus>(params: TestParams<Scalar>) {
|
||||
fn glwe_trivial_encrypt_decrypt_custom_mod<Scalar: UnsignedTorus>(
|
||||
params: ClassicTestParams<Scalar>,
|
||||
) {
|
||||
let glwe_dimension = params.glwe_dimension;
|
||||
let polynomial_size = params.polynomial_size;
|
||||
let ciphertext_modulus = params.ciphertext_modulus;
|
||||
@@ -209,7 +230,6 @@ fn glwe_trivial_encrypt_decrypt_custom_mod<Scalar: UnsignedTorus>(params: TestPa
|
||||
|
||||
let mut rsc = TestResources::new();
|
||||
|
||||
const NB_TESTS: usize = 10;
|
||||
let msg_modulus = Scalar::ONE.shl(message_modulus_log.0);
|
||||
let mut msg = msg_modulus;
|
||||
let delta: Scalar = encoding_with_padding / msg_modulus;
|
||||
@@ -254,13 +274,18 @@ fn glwe_trivial_encrypt_decrypt_custom_mod<Scalar: UnsignedTorus>(params: TestPa
|
||||
|
||||
assert!(decoded.iter().all(|&x| x == msg));
|
||||
}
|
||||
|
||||
// In coverage, we break after one while loop iteration, changing message values does not
|
||||
// yield higher coverage
|
||||
#[cfg(feature = "__coverage")]
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
create_parametrized_test!(glwe_trivial_encrypt_decrypt_custom_mod);
|
||||
|
||||
fn glwe_allocate_trivial_encrypt_decrypt_custom_mod<Scalar: UnsignedTorus>(
|
||||
params: TestParams<Scalar>,
|
||||
params: ClassicTestParams<Scalar>,
|
||||
) {
|
||||
let glwe_dimension = params.glwe_dimension;
|
||||
let polynomial_size = params.polynomial_size;
|
||||
@@ -270,7 +295,6 @@ fn glwe_allocate_trivial_encrypt_decrypt_custom_mod<Scalar: UnsignedTorus>(
|
||||
|
||||
let mut rsc = TestResources::new();
|
||||
|
||||
const NB_TESTS: usize = 10;
|
||||
let msg_modulus = Scalar::ONE.shl(message_modulus_log.0);
|
||||
let mut msg = msg_modulus;
|
||||
let delta: Scalar = encoding_with_padding / msg_modulus;
|
||||
@@ -308,12 +332,19 @@ fn glwe_allocate_trivial_encrypt_decrypt_custom_mod<Scalar: UnsignedTorus>(
|
||||
|
||||
assert!(output_plaintext_list.iter().all(|x| *x.0 == msg));
|
||||
}
|
||||
|
||||
// In coverage, we break after one while loop iteration, changing message values does not
|
||||
// yield higher coverage
|
||||
#[cfg(feature = "__coverage")]
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
create_parametrized_test!(glwe_allocate_trivial_encrypt_decrypt_custom_mod);
|
||||
|
||||
fn glwe_seeded_encrypt_decrypt_custom_mod<Scalar: UnsignedTorus>(params: TestParams<Scalar>) {
|
||||
fn glwe_seeded_encrypt_decrypt_custom_mod<Scalar: UnsignedTorus>(
|
||||
params: ClassicTestParams<Scalar>,
|
||||
) {
|
||||
let glwe_dimension = params.glwe_dimension;
|
||||
let polynomial_size = params.polynomial_size;
|
||||
let glwe_modular_std_dev = params.glwe_modular_std_dev;
|
||||
@@ -323,7 +354,6 @@ fn glwe_seeded_encrypt_decrypt_custom_mod<Scalar: UnsignedTorus>(params: TestPar
|
||||
|
||||
let mut rsc = TestResources::new();
|
||||
|
||||
const NB_TESTS: usize = 10;
|
||||
let msg_modulus = Scalar::ONE.shl(message_modulus_log.0);
|
||||
let mut msg = msg_modulus;
|
||||
let delta: Scalar = encoding_with_padding / msg_modulus;
|
||||
@@ -384,12 +414,19 @@ fn glwe_seeded_encrypt_decrypt_custom_mod<Scalar: UnsignedTorus>(params: TestPar
|
||||
|
||||
assert!(decoded.iter().all(|&x| x == msg));
|
||||
}
|
||||
|
||||
// In coverage, we break after one while loop iteration, changing message values does not
|
||||
// yield higher coverage
|
||||
#[cfg(feature = "__coverage")]
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
create_parametrized_test!(glwe_seeded_encrypt_decrypt_custom_mod);
|
||||
|
||||
fn glwe_seeded_list_encrypt_decrypt_custom_mod<Scalar: UnsignedTorus>(params: TestParams<Scalar>) {
|
||||
fn glwe_seeded_list_encrypt_decrypt_custom_mod<Scalar: UnsignedTorus>(
|
||||
params: ClassicTestParams<Scalar>,
|
||||
) {
|
||||
let glwe_dimension = params.glwe_dimension;
|
||||
let polynomial_size = params.polynomial_size;
|
||||
let glwe_modular_std_dev = params.glwe_modular_std_dev;
|
||||
@@ -400,7 +437,6 @@ fn glwe_seeded_list_encrypt_decrypt_custom_mod<Scalar: UnsignedTorus>(params: Te
|
||||
|
||||
let mut rsc = TestResources::new();
|
||||
|
||||
const NB_TESTS: usize = 10;
|
||||
let msg_modulus = Scalar::ONE.shl(message_modulus_log.0);
|
||||
let mut msg = msg_modulus;
|
||||
let delta: Scalar = encoding_with_padding / msg_modulus;
|
||||
@@ -459,6 +495,11 @@ fn glwe_seeded_list_encrypt_decrypt_custom_mod<Scalar: UnsignedTorus>(params: Te
|
||||
|
||||
assert!(decoded.iter().all(|&x| x == msg));
|
||||
}
|
||||
|
||||
// In coverage, we break after one while loop iteration, changing message values does not
|
||||
// yield higher coverage
|
||||
#[cfg(feature = "__coverage")]
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,6 +1,13 @@
|
||||
use super::*;
|
||||
|
||||
fn glwe_encrypt_add_assign_decrypt_custom_mod<Scalar: UnsignedTorus>(params: TestParams<Scalar>) {
|
||||
#[cfg(not(feature = "__coverage"))]
|
||||
const NB_TESTS: usize = 10;
|
||||
#[cfg(feature = "__coverage")]
|
||||
const NB_TESTS: usize = 1;
|
||||
|
||||
fn glwe_encrypt_add_assign_decrypt_custom_mod<Scalar: UnsignedTorus>(
|
||||
params: ClassicTestParams<Scalar>,
|
||||
) {
|
||||
let glwe_dimension = params.glwe_dimension;
|
||||
let polynomial_size = params.polynomial_size;
|
||||
let glwe_modular_std_dev = params.glwe_modular_std_dev;
|
||||
@@ -10,7 +17,6 @@ fn glwe_encrypt_add_assign_decrypt_custom_mod<Scalar: UnsignedTorus>(params: Tes
|
||||
|
||||
let mut rsc = TestResources::new();
|
||||
|
||||
const NB_TESTS: usize = 10;
|
||||
let msg_modulus = Scalar::ONE.shl(message_modulus_log.0);
|
||||
let mut msg = msg_modulus;
|
||||
let delta: Scalar = encoding_with_padding / msg_modulus;
|
||||
@@ -70,12 +76,17 @@ fn glwe_encrypt_add_assign_decrypt_custom_mod<Scalar: UnsignedTorus>(params: Tes
|
||||
|
||||
assert!(decoded.iter().all(|&x| x == (msg + msg) % msg_modulus));
|
||||
}
|
||||
|
||||
// In coverage, we break after one while loop iteration, changing message values does not
|
||||
// yield higher coverage
|
||||
#[cfg(feature = "__coverage")]
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
create_parametrized_test!(glwe_encrypt_add_assign_decrypt_custom_mod);
|
||||
|
||||
fn glwe_encrypt_add_decrypt_custom_mod<Scalar: UnsignedTorus>(params: TestParams<Scalar>) {
|
||||
fn glwe_encrypt_add_decrypt_custom_mod<Scalar: UnsignedTorus>(params: ClassicTestParams<Scalar>) {
|
||||
let glwe_dimension = params.glwe_dimension;
|
||||
let polynomial_size = params.polynomial_size;
|
||||
let glwe_modular_std_dev = params.glwe_modular_std_dev;
|
||||
@@ -85,7 +96,6 @@ fn glwe_encrypt_add_decrypt_custom_mod<Scalar: UnsignedTorus>(params: TestParams
|
||||
|
||||
let mut rsc = TestResources::new();
|
||||
|
||||
const NB_TESTS: usize = 10;
|
||||
let msg_modulus = Scalar::ONE.shl(message_modulus_log.0);
|
||||
let mut msg = msg_modulus;
|
||||
let delta: Scalar = encoding_with_padding / msg_modulus;
|
||||
@@ -147,13 +157,18 @@ fn glwe_encrypt_add_decrypt_custom_mod<Scalar: UnsignedTorus>(params: TestParams
|
||||
|
||||
assert!(decoded.iter().all(|&x| x == (msg + msg) % msg_modulus));
|
||||
}
|
||||
|
||||
// In coverage, we break after one while loop iteration, changing message values does not
|
||||
// yield higher coverage
|
||||
#[cfg(feature = "__coverage")]
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
create_parametrized_test!(glwe_encrypt_add_decrypt_custom_mod);
|
||||
|
||||
fn glwe_encrypt_plaintext_list_add_assign_decrypt_custom_mod<Scalar: UnsignedTorus>(
|
||||
params: TestParams<Scalar>,
|
||||
params: ClassicTestParams<Scalar>,
|
||||
) {
|
||||
let glwe_dimension = params.glwe_dimension;
|
||||
let polynomial_size = params.polynomial_size;
|
||||
@@ -164,7 +179,6 @@ fn glwe_encrypt_plaintext_list_add_assign_decrypt_custom_mod<Scalar: UnsignedTor
|
||||
|
||||
let mut rsc = TestResources::new();
|
||||
|
||||
const NB_TESTS: usize = 10;
|
||||
let msg_modulus = Scalar::ONE.shl(message_modulus_log.0);
|
||||
let mut msg = msg_modulus;
|
||||
let delta: Scalar = encoding_with_padding / msg_modulus;
|
||||
@@ -222,13 +236,18 @@ fn glwe_encrypt_plaintext_list_add_assign_decrypt_custom_mod<Scalar: UnsignedTor
|
||||
|
||||
assert!(decoded.iter().all(|&x| x == (msg + msg) % msg_modulus));
|
||||
}
|
||||
|
||||
// In coverage, we break after one while loop iteration, changing message values does not
|
||||
// yield higher coverage
|
||||
#[cfg(feature = "__coverage")]
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
create_parametrized_test!(glwe_encrypt_plaintext_list_add_assign_decrypt_custom_mod);
|
||||
|
||||
fn glwe_encrypt_plaintext_list_sub_assign_decrypt_custom_mod<Scalar: UnsignedTorus>(
|
||||
params: TestParams<Scalar>,
|
||||
params: ClassicTestParams<Scalar>,
|
||||
) {
|
||||
let glwe_dimension = params.glwe_dimension;
|
||||
let polynomial_size = params.polynomial_size;
|
||||
@@ -239,7 +258,6 @@ fn glwe_encrypt_plaintext_list_sub_assign_decrypt_custom_mod<Scalar: UnsignedTor
|
||||
|
||||
let mut rsc = TestResources::new();
|
||||
|
||||
const NB_TESTS: usize = 10;
|
||||
let msg_modulus = Scalar::ONE.shl(message_modulus_log.0);
|
||||
let mut msg = msg_modulus;
|
||||
let delta: Scalar = encoding_with_padding / msg_modulus;
|
||||
@@ -297,13 +315,18 @@ fn glwe_encrypt_plaintext_list_sub_assign_decrypt_custom_mod<Scalar: UnsignedTor
|
||||
|
||||
assert!(decoded.iter().all(|&x| x == Scalar::ZERO));
|
||||
}
|
||||
|
||||
// In coverage, we break after one while loop iteration, changing message values does not
|
||||
// yield higher coverage
|
||||
#[cfg(feature = "__coverage")]
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
create_parametrized_test!(glwe_encrypt_plaintext_list_sub_assign_decrypt_custom_mod);
|
||||
|
||||
fn glwe_encrypt_plaintext_add_assign_decrypt_custom_mod<Scalar: UnsignedTorus>(
|
||||
params: TestParams<Scalar>,
|
||||
params: ClassicTestParams<Scalar>,
|
||||
) {
|
||||
let glwe_dimension = params.glwe_dimension;
|
||||
let polynomial_size = params.polynomial_size;
|
||||
@@ -314,7 +337,6 @@ fn glwe_encrypt_plaintext_add_assign_decrypt_custom_mod<Scalar: UnsignedTorus>(
|
||||
|
||||
let mut rsc = TestResources::new();
|
||||
|
||||
const NB_TESTS: usize = 10;
|
||||
let msg_modulus = Scalar::ONE.shl(message_modulus_log.0);
|
||||
let mut msg = msg_modulus;
|
||||
let delta: Scalar = encoding_with_padding / msg_modulus;
|
||||
@@ -373,13 +395,18 @@ fn glwe_encrypt_plaintext_add_assign_decrypt_custom_mod<Scalar: UnsignedTorus>(
|
||||
|
||||
assert!(decoded.iter().all(|&x| x == (msg + msg) % msg_modulus));
|
||||
}
|
||||
|
||||
// In coverage, we break after one while loop iteration, changing message values does not
|
||||
// yield higher coverage
|
||||
#[cfg(feature = "__coverage")]
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
create_parametrized_test!(glwe_encrypt_plaintext_add_assign_decrypt_custom_mod);
|
||||
|
||||
fn glwe_encrypt_plaintext_sub_assign_decrypt_custom_mod<Scalar: UnsignedTorus>(
|
||||
params: TestParams<Scalar>,
|
||||
params: ClassicTestParams<Scalar>,
|
||||
) {
|
||||
let glwe_dimension = params.glwe_dimension;
|
||||
let polynomial_size = params.polynomial_size;
|
||||
@@ -390,7 +417,6 @@ fn glwe_encrypt_plaintext_sub_assign_decrypt_custom_mod<Scalar: UnsignedTorus>(
|
||||
|
||||
let mut rsc = TestResources::new();
|
||||
|
||||
const NB_TESTS: usize = 10;
|
||||
let msg_modulus = Scalar::ONE.shl(message_modulus_log.0);
|
||||
let mut msg = msg_modulus;
|
||||
let delta: Scalar = encoding_with_padding / msg_modulus;
|
||||
@@ -449,13 +475,18 @@ fn glwe_encrypt_plaintext_sub_assign_decrypt_custom_mod<Scalar: UnsignedTorus>(
|
||||
|
||||
assert!(decoded.iter().all(|&x| x == Scalar::ZERO));
|
||||
}
|
||||
|
||||
// In coverage, we break after one while loop iteration, changing message values does not
|
||||
// yield higher coverage
|
||||
#[cfg(feature = "__coverage")]
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
create_parametrized_test!(glwe_encrypt_plaintext_sub_assign_decrypt_custom_mod);
|
||||
|
||||
fn glwe_encrypt_opposite_assign_decrypt_custom_mod<Scalar: UnsignedTorus>(
|
||||
params: TestParams<Scalar>,
|
||||
params: ClassicTestParams<Scalar>,
|
||||
) {
|
||||
let glwe_dimension = params.glwe_dimension;
|
||||
let polynomial_size = params.polynomial_size;
|
||||
@@ -466,7 +497,6 @@ fn glwe_encrypt_opposite_assign_decrypt_custom_mod<Scalar: UnsignedTorus>(
|
||||
|
||||
let mut rsc = TestResources::new();
|
||||
|
||||
const NB_TESTS: usize = 10;
|
||||
let msg_modulus = Scalar::ONE.shl(message_modulus_log.0);
|
||||
let mut msg = msg_modulus;
|
||||
let delta: Scalar = encoding_with_padding / msg_modulus;
|
||||
@@ -526,13 +556,18 @@ fn glwe_encrypt_opposite_assign_decrypt_custom_mod<Scalar: UnsignedTorus>(
|
||||
.iter()
|
||||
.all(|&x| x == msg.wrapping_neg() % msg_modulus));
|
||||
}
|
||||
|
||||
// In coverage, we break after one while loop iteration, changing message values does not
|
||||
// yield higher coverage
|
||||
#[cfg(feature = "__coverage")]
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
create_parametrized_test!(glwe_encrypt_opposite_assign_decrypt_custom_mod);
|
||||
|
||||
fn glwe_encrypt_cleartext_mul_assign_decrypt_custom_mod<Scalar: UnsignedTorus>(
|
||||
params: TestParams<Scalar>,
|
||||
params: ClassicTestParams<Scalar>,
|
||||
) {
|
||||
let glwe_dimension = params.glwe_dimension;
|
||||
let polynomial_size = params.polynomial_size;
|
||||
@@ -543,7 +578,6 @@ fn glwe_encrypt_cleartext_mul_assign_decrypt_custom_mod<Scalar: UnsignedTorus>(
|
||||
|
||||
let mut rsc = TestResources::new();
|
||||
|
||||
const NB_TESTS: usize = 10;
|
||||
let msg_modulus = Scalar::ONE.shl(message_modulus_log.0);
|
||||
let mut msg = msg_modulus;
|
||||
let delta: Scalar = encoding_with_padding / msg_modulus;
|
||||
@@ -604,13 +638,18 @@ fn glwe_encrypt_cleartext_mul_assign_decrypt_custom_mod<Scalar: UnsignedTorus>(
|
||||
.iter()
|
||||
.all(|&x| x == (msg * cleartext.0) % msg_modulus));
|
||||
}
|
||||
|
||||
// In coverage, we break after one while loop iteration, changing message values does not
|
||||
// yield higher coverage
|
||||
#[cfg(feature = "__coverage")]
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
create_parametrized_test!(glwe_encrypt_cleartext_mul_assign_decrypt_custom_mod);
|
||||
|
||||
fn glwe_encrypt_cleartext_mul_decrypt_custom_mod<Scalar: UnsignedTorus>(
|
||||
params: TestParams<Scalar>,
|
||||
params: ClassicTestParams<Scalar>,
|
||||
) {
|
||||
let glwe_dimension = params.glwe_dimension;
|
||||
let polynomial_size = params.polynomial_size;
|
||||
@@ -621,7 +660,6 @@ fn glwe_encrypt_cleartext_mul_decrypt_custom_mod<Scalar: UnsignedTorus>(
|
||||
|
||||
let mut rsc = TestResources::new();
|
||||
|
||||
const NB_TESTS: usize = 10;
|
||||
let msg_modulus = Scalar::ONE.shl(message_modulus_log.0);
|
||||
let mut msg = msg_modulus;
|
||||
let delta: Scalar = encoding_with_padding / msg_modulus;
|
||||
@@ -684,12 +722,19 @@ fn glwe_encrypt_cleartext_mul_decrypt_custom_mod<Scalar: UnsignedTorus>(
|
||||
.iter()
|
||||
.all(|&x| x == (msg * cleartext.0) % msg_modulus));
|
||||
}
|
||||
|
||||
// In coverage, we break after one while loop iteration, changing message values does not
|
||||
// yield higher coverage
|
||||
#[cfg(feature = "__coverage")]
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
create_parametrized_test!(glwe_encrypt_cleartext_mul_decrypt_custom_mod);
|
||||
|
||||
fn glwe_encrypt_sub_assign_decrypt_custom_mod<Scalar: UnsignedTorus>(params: TestParams<Scalar>) {
|
||||
fn glwe_encrypt_sub_assign_decrypt_custom_mod<Scalar: UnsignedTorus>(
|
||||
params: ClassicTestParams<Scalar>,
|
||||
) {
|
||||
let glwe_dimension = params.glwe_dimension;
|
||||
let polynomial_size = params.polynomial_size;
|
||||
let glwe_modular_std_dev = params.glwe_modular_std_dev;
|
||||
@@ -699,7 +744,6 @@ fn glwe_encrypt_sub_assign_decrypt_custom_mod<Scalar: UnsignedTorus>(params: Tes
|
||||
|
||||
let mut rsc = TestResources::new();
|
||||
|
||||
const NB_TESTS: usize = 10;
|
||||
let msg_modulus = Scalar::ONE.shl(message_modulus_log.0);
|
||||
let mut msg = msg_modulus;
|
||||
let delta: Scalar = encoding_with_padding / msg_modulus;
|
||||
@@ -759,12 +803,17 @@ fn glwe_encrypt_sub_assign_decrypt_custom_mod<Scalar: UnsignedTorus>(params: Tes
|
||||
|
||||
assert!(decoded.iter().all(|&x| x == Scalar::ZERO));
|
||||
}
|
||||
|
||||
// In coverage, we break after one while loop iteration, changing message values does not
|
||||
// yield higher coverage
|
||||
#[cfg(feature = "__coverage")]
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
create_parametrized_test!(glwe_encrypt_sub_assign_decrypt_custom_mod);
|
||||
|
||||
fn glwe_encrypt_sub_decrypt_custom_mod<Scalar: UnsignedTorus>(params: TestParams<Scalar>) {
|
||||
fn glwe_encrypt_sub_decrypt_custom_mod<Scalar: UnsignedTorus>(params: ClassicTestParams<Scalar>) {
|
||||
let glwe_dimension = params.glwe_dimension;
|
||||
let polynomial_size = params.polynomial_size;
|
||||
let glwe_modular_std_dev = params.glwe_modular_std_dev;
|
||||
@@ -774,7 +823,6 @@ fn glwe_encrypt_sub_decrypt_custom_mod<Scalar: UnsignedTorus>(params: TestParams
|
||||
|
||||
let mut rsc = TestResources::new();
|
||||
|
||||
const NB_TESTS: usize = 10;
|
||||
let msg_modulus = Scalar::ONE.shl(message_modulus_log.0);
|
||||
let mut msg = msg_modulus;
|
||||
let delta: Scalar = encoding_with_padding / msg_modulus;
|
||||
@@ -836,6 +884,11 @@ fn glwe_encrypt_sub_decrypt_custom_mod<Scalar: UnsignedTorus>(params: TestParams
|
||||
|
||||
assert!(decoded.iter().all(|&x| x == Scalar::ZERO));
|
||||
}
|
||||
|
||||
// In coverage, we break after one while loop iteration, changing message values does not
|
||||
// yield higher coverage
|
||||
#[cfg(feature = "__coverage")]
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user