From 14da0ca00176d6fe2dd102eba17aec4f0e37ff07 Mon Sep 17 00:00:00 2001 From: Arthur Meyre Date: Mon, 12 Dec 2022 14:28:25 +0100 Subject: [PATCH] feat(integer): add concrete-integer as integer module --- .github/workflows/aws_tfhe_integer_tests.yml | 10 +- .github/workflows/aws_tfhe_tests.yml | 2 +- .github/workflows/cargo_build.yml | 8 +- .github/workflows/m1_tests.yml | 12 +- .github/workflows/trigger_aws_tests_on_pr.yml | 1 + .gitignore | 2 + Makefile | 42 +- README.md | 6 +- ci/slab.toml | 5 + scripts/integer-tests.sh | 58 + tfhe/Cargo.toml | 9 +- tfhe/benches/integer/bench.rs | 912 ++++++++++++++++ tfhe/docs/SUMMARY.md | 3 + tfhe/docs/integer/SUMMARY.md | 20 + .../integer/getting_started/first_circuit.md | 105 ++ .../integer/getting_started/installation.md | 11 + .../integer/getting_started/operation_list.md | 15 + .../getting_started/operation_types.md | 86 ++ .../integer/getting_started/parameters.md | 6 + tfhe/docs/integer/how_to/pbs.md | 51 + tfhe/docs/integer/introduction.md | 8 + .../integer/tutorials/circuit_evaluation.md | 120 +++ tfhe/docs/integer/tutorials/serialization.md | 78 ++ tfhe/src/integer/ciphertext/mod.rs | 57 + tfhe/src/integer/client_key/crt.rs | 67 ++ tfhe/src/integer/client_key/mod.rs | 434 ++++++++ tfhe/src/integer/client_key/radix.rs | 78 ++ tfhe/src/integer/client_key/utils.rs | 91 ++ tfhe/src/integer/keycache.rs | 43 + tfhe/src/integer/mod.rs | 135 +++ tfhe/src/integer/parameters/mod.rs | 125 +++ tfhe/src/integer/server_key/crt/add_crt.rs | 73 ++ tfhe/src/integer/server_key/crt/mod.rs | 102 ++ tfhe/src/integer/server_key/crt/mul_crt.rs | 88 ++ tfhe/src/integer/server_key/crt/neg_crt.rs | 94 ++ .../integer/server_key/crt/scalar_add_crt.rs | 212 ++++ .../integer/server_key/crt/scalar_mul_crt.rs | 226 ++++ .../integer/server_key/crt/scalar_sub_crt.rs | 204 ++++ tfhe/src/integer/server_key/crt/sub_crt.rs | 170 +++ tfhe/src/integer/server_key/crt/tests.rs | 273 +++++ .../server_key/crt_parallel/add_crt.rs | 115 ++ .../integer/server_key/crt_parallel/mod.rs | 108 ++ .../server_key/crt_parallel/mul_crt.rs | 108 ++ .../server_key/crt_parallel/neg_crt.rs | 84 ++ .../server_key/crt_parallel/scalar_add_crt.rs | 194 ++++ .../server_key/crt_parallel/scalar_mul_crt.rs | 199 ++++ .../server_key/crt_parallel/scalar_sub_crt.rs | 186 ++++ .../server_key/crt_parallel/sub_crt.rs | 158 +++ tfhe/src/integer/server_key/mod.rs | 92 ++ tfhe/src/integer/server_key/radix/add.rs | 245 +++++ .../integer/server_key/radix/bitwise_op.rs | 604 +++++++++++ tfhe/src/integer/server_key/radix/mod.rs | 117 +++ tfhe/src/integer/server_key/radix/mul.rs | 272 +++++ tfhe/src/integer/server_key/radix/neg.rs | 215 ++++ .../integer/server_key/radix/scalar_add.rs | 236 +++++ .../integer/server_key/radix/scalar_mul.rs | 369 +++++++ .../integer/server_key/radix/scalar_sub.rs | 233 ++++ tfhe/src/integer/server_key/radix/shift.rs | 219 ++++ tfhe/src/integer/server_key/radix/sub.rs | 307 ++++++ tfhe/src/integer/server_key/radix/tests.rs | 957 +++++++++++++++++ .../integer/server_key/radix_parallel/add.rs | 146 +++ .../server_key/radix_parallel/bitwise_op.rs | 172 +++ .../integer/server_key/radix_parallel/mod.rs | 89 ++ .../integer/server_key/radix_parallel/mul.rs | 334 ++++++ .../integer/server_key/radix_parallel/neg.rs | 37 + .../server_key/radix_parallel/scalar_add.rs | 74 ++ .../server_key/radix_parallel/scalar_mul.rs | 311 ++++++ .../server_key/radix_parallel/scalar_sub.rs | 46 + .../server_key/radix_parallel/shift.rs | 180 ++++ .../integer/server_key/radix_parallel/sub.rs | 101 ++ .../server_key/radix_parallel/tests.rs | 736 +++++++++++++ tfhe/src/integer/tests.rs | 21 + tfhe/src/integer/wopbs/mod.rs | 992 ++++++++++++++++++ tfhe/src/integer/wopbs/test.rs | 291 +++++ tfhe/src/lib.rs | 14 +- tfhe/src/shortint/engine/wopbs/mod.rs | 4 +- tfhe/src/shortint/wopbs/mod.rs | 4 +- tfhe/src/test_user_docs.rs | 15 + 78 files changed, 12599 insertions(+), 28 deletions(-) create mode 100755 scripts/integer-tests.sh create mode 100644 tfhe/benches/integer/bench.rs create mode 100644 tfhe/docs/integer/SUMMARY.md create mode 100644 tfhe/docs/integer/getting_started/first_circuit.md create mode 100644 tfhe/docs/integer/getting_started/installation.md create mode 100644 tfhe/docs/integer/getting_started/operation_list.md create mode 100644 tfhe/docs/integer/getting_started/operation_types.md create mode 100644 tfhe/docs/integer/getting_started/parameters.md create mode 100644 tfhe/docs/integer/how_to/pbs.md create mode 100644 tfhe/docs/integer/introduction.md create mode 100644 tfhe/docs/integer/tutorials/circuit_evaluation.md create mode 100644 tfhe/docs/integer/tutorials/serialization.md create mode 100644 tfhe/src/integer/ciphertext/mod.rs create mode 100644 tfhe/src/integer/client_key/crt.rs create mode 100644 tfhe/src/integer/client_key/mod.rs create mode 100644 tfhe/src/integer/client_key/radix.rs create mode 100644 tfhe/src/integer/client_key/utils.rs create mode 100644 tfhe/src/integer/keycache.rs create mode 100755 tfhe/src/integer/mod.rs create mode 100644 tfhe/src/integer/parameters/mod.rs create mode 100644 tfhe/src/integer/server_key/crt/add_crt.rs create mode 100644 tfhe/src/integer/server_key/crt/mod.rs create mode 100644 tfhe/src/integer/server_key/crt/mul_crt.rs create mode 100644 tfhe/src/integer/server_key/crt/neg_crt.rs create mode 100644 tfhe/src/integer/server_key/crt/scalar_add_crt.rs create mode 100644 tfhe/src/integer/server_key/crt/scalar_mul_crt.rs create mode 100644 tfhe/src/integer/server_key/crt/scalar_sub_crt.rs create mode 100644 tfhe/src/integer/server_key/crt/sub_crt.rs create mode 100644 tfhe/src/integer/server_key/crt/tests.rs create mode 100644 tfhe/src/integer/server_key/crt_parallel/add_crt.rs create mode 100644 tfhe/src/integer/server_key/crt_parallel/mod.rs create mode 100644 tfhe/src/integer/server_key/crt_parallel/mul_crt.rs create mode 100644 tfhe/src/integer/server_key/crt_parallel/neg_crt.rs create mode 100644 tfhe/src/integer/server_key/crt_parallel/scalar_add_crt.rs create mode 100644 tfhe/src/integer/server_key/crt_parallel/scalar_mul_crt.rs create mode 100644 tfhe/src/integer/server_key/crt_parallel/scalar_sub_crt.rs create mode 100644 tfhe/src/integer/server_key/crt_parallel/sub_crt.rs create mode 100644 tfhe/src/integer/server_key/mod.rs create mode 100644 tfhe/src/integer/server_key/radix/add.rs create mode 100644 tfhe/src/integer/server_key/radix/bitwise_op.rs create mode 100644 tfhe/src/integer/server_key/radix/mod.rs create mode 100644 tfhe/src/integer/server_key/radix/mul.rs create mode 100644 tfhe/src/integer/server_key/radix/neg.rs create mode 100644 tfhe/src/integer/server_key/radix/scalar_add.rs create mode 100644 tfhe/src/integer/server_key/radix/scalar_mul.rs create mode 100644 tfhe/src/integer/server_key/radix/scalar_sub.rs create mode 100644 tfhe/src/integer/server_key/radix/shift.rs create mode 100644 tfhe/src/integer/server_key/radix/sub.rs create mode 100644 tfhe/src/integer/server_key/radix/tests.rs create mode 100644 tfhe/src/integer/server_key/radix_parallel/add.rs create mode 100644 tfhe/src/integer/server_key/radix_parallel/bitwise_op.rs create mode 100644 tfhe/src/integer/server_key/radix_parallel/mod.rs create mode 100644 tfhe/src/integer/server_key/radix_parallel/mul.rs create mode 100644 tfhe/src/integer/server_key/radix_parallel/neg.rs create mode 100644 tfhe/src/integer/server_key/radix_parallel/scalar_add.rs create mode 100644 tfhe/src/integer/server_key/radix_parallel/scalar_mul.rs create mode 100644 tfhe/src/integer/server_key/radix_parallel/scalar_sub.rs create mode 100644 tfhe/src/integer/server_key/radix_parallel/shift.rs create mode 100644 tfhe/src/integer/server_key/radix_parallel/sub.rs create mode 100644 tfhe/src/integer/server_key/radix_parallel/tests.rs create mode 100644 tfhe/src/integer/tests.rs create mode 100644 tfhe/src/integer/wopbs/mod.rs create mode 100644 tfhe/src/integer/wopbs/test.rs diff --git a/.github/workflows/aws_tfhe_integer_tests.yml b/.github/workflows/aws_tfhe_integer_tests.yml index e4fe74c92..07a200ac5 100644 --- a/.github/workflows/aws_tfhe_integer_tests.yml +++ b/.github/workflows/aws_tfhe_integer_tests.yml @@ -30,9 +30,9 @@ on: type: string jobs: - shortint-tests: + integer-tests: concurrency: - group: ${{ github.ref }}_${{ github.event.inputs.instance_image_id }}_${{ github.event.inputs.instance_type }} + group: ${{ github.workflow }}_${{ github.ref }}_${{ github.event.inputs.instance_image_id }}_${{ github.event.inputs.instance_type }} cancel-in-progress: true runs-on: ${{ github.event.inputs.runner_name }} steps: @@ -56,6 +56,10 @@ jobs: toolchain: stable default: true + - name: Gen Keys if required + run: | + make gen_key_cache + - name: Run integer tests run: | make test_integer_ci @@ -68,6 +72,6 @@ jobs: SLACK_COLOR: ${{ job.status }} SLACK_CHANNEL: ${{ secrets.SLACK_CHANNEL }} SLACK_ICON: https://pbs.twimg.com/profile_images/1274014582265298945/OjBKP9kn_400x400.png - SLACK_MESSAGE: "Shortint tests finished with status: ${{ job.status }}. (${{ env.ACTION_RUN_URL }})" + SLACK_MESSAGE: "Integer tests finished with status: ${{ job.status }}. (${{ env.ACTION_RUN_URL }})" SLACK_USERNAME: ${{ secrets.BOT_USERNAME }} SLACK_WEBHOOK: ${{ secrets.SLACK_WEBHOOK }} diff --git a/.github/workflows/aws_tfhe_tests.yml b/.github/workflows/aws_tfhe_tests.yml index cdbd1a67a..5cdf6f27e 100644 --- a/.github/workflows/aws_tfhe_tests.yml +++ b/.github/workflows/aws_tfhe_tests.yml @@ -32,7 +32,7 @@ on: jobs: shortint-tests: concurrency: - group: ${{ github.ref }}_${{ github.event.inputs.instance_image_id }}_${{ github.event.inputs.instance_type }} + group: ${{ github.workflow }}_${{ github.ref }}_${{ github.event.inputs.instance_image_id }}_${{ github.event.inputs.instance_type }} cancel-in-progress: true runs-on: ${{ github.event.inputs.runner_name }} steps: diff --git a/.github/workflows/cargo_build.yml b/.github/workflows/cargo_build.yml index 2fa2110ef..a0b78c4af 100644 --- a/.github/workflows/cargo_build.yml +++ b/.github/workflows/cargo_build.yml @@ -40,9 +40,13 @@ jobs: run: | make build_shortint - - name: Build Release shortint and boolean + - name: Build Release integer run: | - make build_boolean_and_shortint + make build_integer + + - name: Build Release tfhe full + run: | + make build_tfhe_full - name: Build Release c_api run: | diff --git a/.github/workflows/m1_tests.yml b/.github/workflows/m1_tests.yml index dfd171149..6a7be4ba3 100644 --- a/.github/workflows/m1_tests.yml +++ b/.github/workflows/m1_tests.yml @@ -40,9 +40,13 @@ jobs: run: | make build_shortint - - name: Build Release shortint and boolean + - name: Build Release integer run: | - make build_boolean_and_shortint + make build_integer + + - name: Build Release tfhe full + run: | + make build_tfhe_full - name: Build Release c_api run: | @@ -75,6 +79,10 @@ jobs: run: | make test_shortint_ci + - name: Run integer tests + run: | + make test_integer_ci + remove_label: name: Remove m1_test label runs-on: ubuntu-latest diff --git a/.github/workflows/trigger_aws_tests_on_pr.yml b/.github/workflows/trigger_aws_tests_on_pr.yml index 7a14ca161..2dee81587 100644 --- a/.github/workflows/trigger_aws_tests_on_pr.yml +++ b/.github/workflows/trigger_aws_tests_on_pr.yml @@ -15,3 +15,4 @@ jobs: allow-repeats: true message: | @slab-ci cpu_test + @slab-ci cpu_integer_test diff --git a/.gitignore b/.gitignore index 2c01b436e..13088a25b 100644 --- a/.gitignore +++ b/.gitignore @@ -4,6 +4,8 @@ target/ # Path we use for internal-keycache during tests keys/ +# In case of symlinked keys +keys **/Cargo.lock **/*.bin diff --git a/Makefile b/Makefile index dc705b770..c5a18451a 100644 --- a/Makefile +++ b/Makefile @@ -69,10 +69,16 @@ clippy_shortint: install_rs_check_toolchain --features=$(TARGET_ARCH_FEATURE),shortint \ -p tfhe -- --no-deps -D warnings -.PHONY: clippy # Run clippy lints enabling the boolean, shortint -clippy: install_rs_check_toolchain +.PHONY: clippy_integer # Run clippy lints enabling the integer features +clippy_integer: install_rs_check_toolchain RUSTFLAGS="$(RUSTFLAGS)" cargo "$(CARGO_RS_CHECK_TOOLCHAIN)" clippy \ - --features=$(TARGET_ARCH_FEATURE),boolean,shortint \ + --features=$(TARGET_ARCH_FEATURE),integer \ + -p tfhe -- --no-deps -D warnings + +.PHONY: clippy # Run clippy lints enabling the boolean, shortint, integer +clippy: install_rs_check_toolchain + RUSTFLAGS="$(RUSTFLAGS)" cargo "$(CARGO_RS_CHECK_TOOLCHAIN)" clippy --all-targets \ + --features=$(TARGET_ARCH_FEATURE),boolean,shortint,integer \ -p tfhe -- --no-deps -D warnings .PHONY: clippy_c_api # Run clippy lints enabling the boolean, shortint and the C API @@ -95,11 +101,11 @@ clippy_tasks: .PHONY: clippy_all_targets # Run clippy lints on all targets (benches, examples, etc.) clippy_all_targets: RUSTFLAGS="$(RUSTFLAGS)" cargo "$(CARGO_RS_CHECK_TOOLCHAIN)" clippy --all-targets \ - --features=$(TARGET_ARCH_FEATURE),boolean,shortint,internal-keycache \ + --features=$(TARGET_ARCH_FEATURE),boolean,shortint,integer,internal-keycache \ -p tfhe -- --no-deps -D warnings .PHONY: clippy_all # Run all clippy targets -clippy_all: clippy clippy_boolean clippy_shortint clippy_all_targets clippy_c_api \ +clippy_all: clippy clippy_boolean clippy_shortint clippy_integer clippy_all_targets clippy_c_api \ clippy_js_wasm_api clippy_tasks .PHONY: gen_key_cache # Run the script to generate keys and cache them for shortint tests @@ -118,10 +124,15 @@ build_shortint: install_rs_build_toolchain RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_BUILD_TOOLCHAIN) build --release \ --features=$(TARGET_ARCH_FEATURE),shortint -p tfhe -.PHONY: build_boolean_and_shortint # Build with boolean and shortint enabled -build_boolean_and_shortint: install_rs_build_toolchain +.PHONY: build_integer # Build with integer enabled +build_integer: install_rs_build_toolchain RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_BUILD_TOOLCHAIN) build --release \ - --features=$(TARGET_ARCH_FEATURE),boolean,shortint -p tfhe + --features=$(TARGET_ARCH_FEATURE),integer -p tfhe + +.PHONY: build_tfhe_full # Build with boolean, shortint and integer enabled +build_tfhe_full: install_rs_build_toolchain + RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_BUILD_TOOLCHAIN) build --release \ + --features=$(TARGET_ARCH_FEATURE),boolean,shortint,integer -p tfhe .PHONY: build_c_api # Build the C API for boolean and shortint build_c_api: install_rs_build_toolchain @@ -165,17 +176,26 @@ test_shortint: install_rs_build_toolchain RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_BUILD_TOOLCHAIN) test --release \ --features=$(TARGET_ARCH_FEATURE),shortint,internal-keycache -p tfhe -- shortint:: +.PHONY: test_integer_ci # Run the tests for integer ci +test_integer_ci: install_rs_build_toolchain install_cargo_nextest + ./scripts/integer-tests.sh $(CARGO_RS_BUILD_TOOLCHAIN) + +.PHONY: test_integer # Run all the tests for integer +test_integer: install_rs_build_toolchain + RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_BUILD_TOOLCHAIN) test --release \ + --features=$(TARGET_ARCH_FEATURE),integer,internal-keycache -p tfhe -- integer:: + .PHONY: test_user_doc # Run tests from the .md documentation test_user_doc: install_rs_build_toolchain RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_BUILD_TOOLCHAIN) test --release --doc \ - --features=$(TARGET_ARCH_FEATURE),shortint,boolean,internal-keycache -p tfhe \ + --features=$(TARGET_ARCH_FEATURE),boolean,shortint,integer,internal-keycache -p tfhe \ -- test_user_docs:: .PHONY: doc # Build rust doc doc: install_rs_check_toolchain RUSTDOCFLAGS="--html-in-header katex-header.html -Dwarnings" \ cargo "$(CARGO_RS_CHECK_TOOLCHAIN)" doc \ - --features=$(TARGET_ARCH_FEATURE),boolean,shortint --no-deps + --features=$(TARGET_ARCH_FEATURE),boolean,shortint,integer --no-deps .PHONY: format_doc_latex # Format the documentation latex equations to avoid broken rendering. format_doc_latex: @@ -189,7 +209,7 @@ format_doc_latex: .PHONY: check_compile_tests # Build tests in debug without running them check_compile_tests: build_c_api RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_BUILD_TOOLCHAIN) test --no-run \ - --features=$(TARGET_ARCH_FEATURE),shortint,boolean,internal-keycache -p tfhe && \ + --features=$(TARGET_ARCH_FEATURE),boolean,shortint,integer,internal-keycache -p tfhe && \ ./scripts/c_api_tests.sh --build-only .PHONY: build_nodejs_test_docker # Build a docker image with tools to run nodejs tests for wasm API diff --git a/README.md b/README.md index e6a78da77..d2c403a6c 100644 --- a/README.md +++ b/README.md @@ -43,13 +43,13 @@ To use the latest version of `TFHE-rs` in your project, you first need to add it + For x86_64-based machines running Unix-like OSes: ```toml -tfhe = { version = "*", features = ["boolean", "shortint", "x86_64-unix"] } +tfhe = { version = "*", features = ["boolean", "shortint", "integer", "x86_64-unix"] } ``` + For Apple Silicon or aarch64-based machines running Unix-like OSes: ```toml -tfhe = { version = "*", features = ["boolean", "shortint", "aarch64-unix"] } +tfhe = { version = "*", features = ["boolean", "shortint", "integer", "aarch64-unix"] } ``` Note: users with ARM devices must use `TFHE-rs` by compiling using the `nightly` toolchain. @@ -58,7 +58,7 @@ Note: users with ARM devices must use `TFHE-rs` by compiling using the `nightly` running Windows: ```toml -tfhe = { version = "*", features = ["boolean", "shortint", "x86_64"] } +tfhe = { version = "*", features = ["boolean", "shortint", "integer", "x86_64"] } ``` Note: aarch64-based machines are not yet supported for Windows as it's currently missing an entropy source to be able to seed the [CSPRNGs](https://en.wikipedia.org/wiki/Cryptographically_secure_pseudorandom_number_generator) used in TFHE-rs diff --git a/ci/slab.toml b/ci/slab.toml index 6c653e972..76ac66947 100644 --- a/ci/slab.toml +++ b/ci/slab.toml @@ -13,6 +13,11 @@ workflow = "aws_tfhe_tests.yml" profile = "cpu-big" check_run_name = "CPU AWS Tests" +[command.cpu_integer_test] +workflow = "aws_tfhe_integer_tests.yml" +profile = "cpu-big" +check_run_name = "CPU Integer AWS Tests" + [command.shortint_bench] workflow = "shortint_benchmark.yml" profile = "bench" diff --git a/scripts/integer-tests.sh b/scripts/integer-tests.sh new file mode 100755 index 000000000..56cd04275 --- /dev/null +++ b/scripts/integer-tests.sh @@ -0,0 +1,58 @@ +#!/bin/bash + +set -e + +CURR_DIR="$(dirname "$0")" +ARCH_FEATURE="$("${CURR_DIR}/get_arch_feature.sh")" + +nproc_bin=nproc + +# macOS detects CPUs differently +if [[ $(uname) == "Darwin" ]]; then + nproc_bin="sysctl -n hw.logicalcpu" +fi + +n_threads="$(${nproc_bin})" + +if uname -a | grep "arm64"; then + if [[ $(uname) == "Darwin" ]]; then + # Keys are 4.7 gigs at max, CI M1 macs only has 8 gigs of RAM + n_threads=1 + fi +else + # Keys are 4.7 gigs at max, test machine has 32 gigs of RAM + n_threads=6 +fi + +# block pbs are too slow for high params +# mul_crt_4_4 is extremely flaky (~80% failure) +# test_wopbs_bivariate_crt_wopbs_param_message generate tables that are too big at the moment +# same for test_wopbs_bivariate_radix_wopbs_param_message_4_carry_4 +# test_integer_smart_mul_param_message_4_carry_4 is too slow +filter_expression=''\ +'test(/^integer::.*$/)'\ +'and not test(/.*_block_pbs(_base)?_param_message_[34]_carry_[34]$/)'\ +'and not test(~mul_crt_param_message_4_carry_4)'\ +'and not test(/.*test_wopbs_bivariate_crt_wopbs_param_message_[34]_carry_[34]$/)'\ +'and not test(/.*test_wopbs_bivariate_radix_wopbs_param_message_4_carry_4$/)'\ +'and not test(/.*test_integer_smart_mul_param_message_4_carry_4$/)' + +export RUSTFLAGS="-C target-cpu=native" + +cargo ${1:+"${1}"} nextest run \ + --tests \ + --release \ + --package tfhe \ + --profile ci \ + --features="${ARCH_FEATURE}",integer,internal-keycache \ + --test-threads "${n_threads}" \ + -E "$filter_expression" + +cargo ${1:+"${1}"} test \ + --release \ + --package tfhe \ + --features="${ARCH_FEATURE}",integer,internal-keycache \ + --doc \ + integer:: + +echo "Test ran in $SECONDS seconds" diff --git a/tfhe/Cargo.toml b/tfhe/Cargo.toml index 620befed5..2cbf420f5 100644 --- a/tfhe/Cargo.toml +++ b/tfhe/Cargo.toml @@ -60,6 +60,7 @@ getrandom = { version = "0.2.8", optional = true } [features] boolean = [] shortint = [] +integer = ["shortint"] internal-keycache = ["lazy_static", "fs2", "bincode"] __c_api = ["cbindgen", "bincode"] @@ -106,7 +107,7 @@ aarch64-unix = ["aarch64", "seeder_unix"] [package.metadata.docs.rs] # TODO: manage builds for docs.rs based on their documentation https://docs.rs/about -features = ["x86_64-unix", "boolean", "shortint"] +features = ["x86_64-unix", "boolean", "shortint", "integer"] rustdoc-args = ["--html-in-header", "katex-header.html"] ########### @@ -133,6 +134,12 @@ path = "benches/shortint/bench.rs" harness = false required-features = ["shortint", "internal-keycache"] +[[bench]] +name = "integer-bench" +path = "benches/integer/bench.rs" +harness = false +required-features = ["integer", "internal-keycache"] + [[bench]] name = "keygen" path = "benches/keygen/bench.rs" diff --git a/tfhe/benches/integer/bench.rs b/tfhe/benches/integer/bench.rs new file mode 100644 index 000000000..1c3ea8782 --- /dev/null +++ b/tfhe/benches/integer/bench.rs @@ -0,0 +1,912 @@ +#![allow(dead_code)] + +use criterion::{criterion_group, criterion_main, Criterion}; +use rand::Rng; +use tfhe::integer::client_key::radix_decomposition; +use tfhe::integer::keycache::KEY_CACHE; +use tfhe::integer::parameters::*; +use tfhe::integer::wopbs::WopbsKey; +use tfhe::integer::{gen_keys, RadixCiphertext, ServerKey}; +use tfhe::shortint::keycache::KEY_CACHE_WOPBS; +use tfhe::shortint::parameters::parameters_wopbs_message_carry::get_parameters_from_message_and_carry_wopbs; +use tfhe::shortint::parameters::{get_parameters_from_message_and_carry, DEFAULT_PARAMETERS}; + +criterion_group!( + to_be_reworked, + smart_block_mul, + radmodint_unchecked_mul, + radmodint_unchecked_mul_many_sizes, + crt, + // radmodint_wopbs, + // radmodint_wopbs_32_bits, + // radmodint_wopbs_16bits_param_2_2_8_blocks, + // radmodint_wopbs_16bits_param_4_4_4_blocks, + concrete_integer_unchecked_mul_crt_16_bits, + concrete_integer_unchecked_add_crt_16_bits, + concrete_integer_unchecked_clean_carry_crt_16_bits, + concrete_integer_unchecked_mul_crt_32_bits, + concrete_integer_unchecked_add_crt_32_bits, + concrete_integer_unchecked_clean_carry_crt_32_bits, +); + +#[allow(unused_imports)] +use tfhe::shortint::parameters::{ + PARAM_MESSAGE_1_CARRY_1, PARAM_MESSAGE_2_CARRY_2, PARAM_MESSAGE_3_CARRY_3, + PARAM_MESSAGE_4_CARRY_4, +}; + +macro_rules! named_param { + ($param:ident) => { + (stringify!($param), $param) + }; +} + +struct Parameters { + block_parameters: tfhe::shortint::Parameters, + num_block: usize, +} + +const BLOCK_4_MESSAGE_2_CARRY_2: Parameters = Parameters { + block_parameters: PARAM_MESSAGE_2_CARRY_2, + num_block: 4, +}; + +const BLOCK_4_MESSAGE_3_CARRY_3: Parameters = Parameters { + block_parameters: PARAM_MESSAGE_3_CARRY_3, + num_block: 4, +}; + +const SERVER_KEY_BENCH_PARAMS: [(&str, Parameters); 2] = [ + named_param!(BLOCK_4_MESSAGE_2_CARRY_2), + named_param!(BLOCK_4_MESSAGE_3_CARRY_3), +]; + +fn smart_neg(c: &mut Criterion) { + let mut bench_group = c.benchmark_group("smart_neg"); + + for (param_name, param) in SERVER_KEY_BENCH_PARAMS { + let (cks, sks) = KEY_CACHE.get_from_params(param.block_parameters); + + let mut rng = rand::thread_rng(); + + let modulus = (param.block_parameters.message_modulus.0 * param.num_block) as u64; + + let clear_0 = rng.gen::() % modulus; + + let mut ct = cks.encrypt_radix(clear_0, param.num_block); + + let bench_id = param_name; + bench_group.bench_function(bench_id, |b| { + b.iter(|| { + sks.smart_neg(&mut ct); + }) + }); + } + + bench_group.finish() +} + +fn full_propagate(c: &mut Criterion) { + let mut bench_group = c.benchmark_group("full_propagate"); + + for (param_name, param) in SERVER_KEY_BENCH_PARAMS { + let (cks, sks) = KEY_CACHE.get_from_params(param.block_parameters); + let mut rng = rand::thread_rng(); + + let modulus = (param.block_parameters.message_modulus.0 * param.num_block) as u64; + + let clear_0 = rng.gen::() % modulus; + + let mut ct = cks.encrypt_radix(clear_0, param.num_block); + + let bench_id = param_name; + bench_group.bench_function(bench_id, |b| { + b.iter(|| { + sks.full_propagate(&mut ct); + }) + }); + } + + bench_group.finish() +} + +fn bench_server_key_binary_function(c: &mut Criterion, bench_name: &str, binary_op: F) +where + F: Fn(&ServerKey, &mut RadixCiphertext, &mut RadixCiphertext), +{ + let mut bench_group = c.benchmark_group(bench_name); + + for (param_name, param) in SERVER_KEY_BENCH_PARAMS { + let (cks, sks) = KEY_CACHE.get_from_params(param.block_parameters); + + let mut rng = rand::thread_rng(); + + let modulus = (param.block_parameters.message_modulus.0 * param.num_block) as u64; + + let clear_0 = rng.gen::() % modulus; + let clear_1 = rng.gen::() % modulus; + + let mut ct_0 = cks.encrypt_radix(clear_0, param.num_block); + let mut ct_1 = cks.encrypt_radix(clear_1, param.num_block); + + let bench_id = format!("{bench_name}::{param_name}"); + bench_group.bench_function(&bench_id, |b| { + b.iter(|| { + binary_op(&sks, &mut ct_0, &mut ct_1); + }) + }); + } + + bench_group.finish() +} + +fn bench_server_key_binary_scalar_function(c: &mut Criterion, bench_name: &str, binary_op: F) +where + F: Fn(&ServerKey, &mut RadixCiphertext, u64), +{ + let mut bench_group = c.benchmark_group(bench_name); + + for (param_name, param) in SERVER_KEY_BENCH_PARAMS { + let (cks, sks) = KEY_CACHE.get_from_params(param.block_parameters); + + let mut rng = rand::thread_rng(); + + let modulus = (param.block_parameters.message_modulus.0 * param.num_block) as u64; + + let clear_0 = rng.gen::() % modulus; + let clear_1 = rng.gen::() % modulus; + + let mut ct_0 = cks.encrypt_radix(clear_0, param.num_block); + + let bench_id = format!("{bench_name}::{param_name}"); + bench_group.bench_function(&bench_id, |b| { + b.iter(|| { + binary_op(&sks, &mut ct_0, clear_1); + }) + }); + } + + bench_group.finish() +} + +macro_rules! define_server_key_bench_fn ( + ($server_key_method:ident) => { + fn $server_key_method(c: &mut Criterion) { + bench_server_key_binary_function( + c, + concat!("ServerKey::", stringify!($server_key_method)), + |server_key, lhs, rhs| { + server_key.$server_key_method(lhs, rhs); + }) + } + } +); + +macro_rules! define_server_key_bench_scalar_fn ( + ($server_key_method:ident) => { + fn $server_key_method(c: &mut Criterion) { + bench_server_key_binary_scalar_function( + c, + concat!("ServerKey::", stringify!($server_key_method)), + |server_key, lhs, rhs| { + server_key.$server_key_method(lhs, rhs); + }) + } + } +); + +define_server_key_bench_fn!(smart_add); +define_server_key_bench_fn!(smart_add_parallelized); +define_server_key_bench_fn!(smart_sub); +define_server_key_bench_fn!(smart_sub_parallelized); +define_server_key_bench_fn!(smart_mul); +define_server_key_bench_fn!(smart_mul_parallelized); +define_server_key_bench_fn!(smart_bitand); +define_server_key_bench_fn!(smart_bitand_parallelized); +define_server_key_bench_fn!(smart_bitor); +define_server_key_bench_fn!(smart_bitor_parallelized); +define_server_key_bench_fn!(smart_bitxor); +define_server_key_bench_fn!(smart_bitxor_parallelized); + +define_server_key_bench_fn!(unchecked_add); +define_server_key_bench_fn!(unchecked_sub); +define_server_key_bench_fn!(unchecked_mul); +define_server_key_bench_fn!(unchecked_mul_parallelized); +define_server_key_bench_fn!(unchecked_bitand); +define_server_key_bench_fn!(unchecked_bitor); +define_server_key_bench_fn!(unchecked_bitxor); + +define_server_key_bench_scalar_fn!(smart_scalar_add); +define_server_key_bench_scalar_fn!(smart_scalar_add_parallelized); +define_server_key_bench_scalar_fn!(smart_scalar_sub); +define_server_key_bench_scalar_fn!(smart_scalar_sub_parallelized); +define_server_key_bench_scalar_fn!(smart_scalar_mul); +define_server_key_bench_scalar_fn!(smart_scalar_mul_parallelized); + +define_server_key_bench_scalar_fn!(unchecked_scalar_add); +define_server_key_bench_scalar_fn!(unchecked_scalar_sub); +define_server_key_bench_scalar_fn!(unchecked_small_scalar_mul); + +criterion_group!( + smart_arithmetic_operation, + smart_neg, + smart_add, + smart_add_parallelized, + smart_sub, + smart_sub_parallelized, + smart_mul, + smart_mul_parallelized, + smart_bitand, + smart_bitand_parallelized, + smart_bitor, + smart_bitor_parallelized, + smart_bitxor, + smart_bitxor_parallelized, +); + +criterion_group!( + smart_scalar_arithmetic_operation, + smart_scalar_add, + smart_scalar_add_parallelized, + smart_scalar_sub, + smart_scalar_sub_parallelized, + smart_scalar_mul, + smart_scalar_mul_parallelized, +); + +criterion_group!( + unchecked_arithmetic_operation, + unchecked_add, + unchecked_sub, + unchecked_mul, + unchecked_mul_parallelized, + unchecked_bitand, + unchecked_bitor, + unchecked_bitxor, +); + +criterion_group!( + unchecked_scalar_arithmetic_operation, + unchecked_scalar_add, + unchecked_scalar_sub, + unchecked_small_scalar_mul, +); + +criterion_group!(misc, full_propagate,); + +criterion_main!( + smart_arithmetic_operation, + smart_scalar_arithmetic_operation, + unchecked_arithmetic_operation, + unchecked_scalar_arithmetic_operation, + misc, + to_be_reworked, +); + +fn smart_block_mul(c: &mut Criterion) { + let size = 4; + + // generate the server-client key set + let (cks, sks) = gen_keys(&DEFAULT_PARAMETERS); + + //RNG + let mut rng = rand::thread_rng(); + + let block_modulus = DEFAULT_PARAMETERS.message_modulus.0 as u64; + + // message_modulus^vec_length + let modulus = DEFAULT_PARAMETERS.message_modulus.0.pow(size as u32) as u64; + + let clear_0 = rng.gen::() % modulus; + + let clear_1 = rng.gen::() % block_modulus; + + // encryption of an integer + let mut ct_zero = cks.encrypt_radix(clear_0, size); + + // encryption of an integer + let ct_one = cks.encrypt_one_block(clear_1); + + //scalar mul + c.bench_function("Smart_Block_Mul", |b| { + b.iter(|| { + sks.smart_block_mul(&mut ct_zero, &ct_one, 0); + }) + }); +} + +fn crt(c: &mut Criterion) { + // generate the server-client key set + let (cks, sks) = gen_keys(&DEFAULT_PARAMETERS); + + //RNG + let mut rng = rand::thread_rng(); + + let basis = vec![2, 3, 5]; + let modulus = 30; // 30 = 2*3*5 + + // Define the cleartexts + let clear1 = rng.gen::() % modulus; + let clear2 = rng.gen::() % modulus; + + // Encrypt the integers + let mut ctxt_1 = cks.encrypt_crt(clear1, basis.clone()); + let mut ctxt_2 = cks.encrypt_crt(clear2, basis); + + //scalar mul + c.bench_function("CRT: Smart_Mul", |b| { + b.iter(|| { + sks.smart_crt_mul_assign(&mut ctxt_1, &mut ctxt_2); + }) + }); + c.bench_function("CRT: Smart_Add", |b| { + b.iter(|| { + sks.smart_crt_add_assign(&mut ctxt_1, &mut ctxt_2); + }) + }); +} + +fn radmodint_unchecked_mul(c: &mut Criterion) { + let size = 2; + + let param = DEFAULT_PARAMETERS; + let (cks, sks) = KEY_CACHE.get_from_params(param); + + println!("Chosen Parameter Set: {param:?}"); + + //RNG + let mut rng = rand::thread_rng(); + + // message_modulus^vec_length + let modulus = DEFAULT_PARAMETERS.message_modulus.0.pow(size as u32) as u64; + + // Define the cleartexts + let clear1 = rng.gen::() % modulus; + let clear2 = rng.gen::() % modulus; + + // Encrypt the integers + let mut ctxt_1 = cks.encrypt_radix(clear1, size); + let ctxt_2 = cks.encrypt_radix(clear2, size); + + //scalar mul + c.bench_function("Unchecked Mul + Full Propagate", |b| { + b.iter(|| { + sks.unchecked_mul(&ctxt_1, &ctxt_2); + sks.full_propagate(&mut ctxt_1); + }) + }); +} + +fn radmodint_unchecked_mul_many_sizes(c: &mut Criterion) { + //Change the number of sample + let mut group = c.benchmark_group("smaller-sample-count"); + group.sample_size(10); + + //At most 4bits + let max_message_space = 4; + + let message_spaces = [16]; + for msg_space in message_spaces { + let dec = radix_decomposition(msg_space, 2, max_message_space); + println!("radix decomposition = {dec:?}"); + for rad_decomp in dec.iter() { + //The carry space is at least equal to the msg_space + let carry_space = rad_decomp.msg_space; + + let param = + get_parameters_from_message_and_carry(1 << rad_decomp.msg_space, 1 << carry_space); + let (cks, sks) = KEY_CACHE.get_from_params(param); + + println!("Chosen Parameter Set: {param:?}"); + + //RNG + let mut rng = rand::thread_rng(); + + // Define the cleartexts + let clear1 = rng.gen::() % msg_space as u64; + let clear2 = rng.gen::() % msg_space as u64; + + // Encrypt the integers + + let mut ctxt_1 = cks.encrypt_radix(clear1, rad_decomp.block_number); + let ctxt_2 = cks.encrypt_radix(clear2, rad_decomp.block_number); + + println!( + "(Input Size {}; Carry_Space {}, Message_Space {}, Block Number {}): \ + Unchecked Mul\ + + \ + Full \ + Propagate ", + msg_space, carry_space, rad_decomp.msg_space, rad_decomp.block_number, + ); + let id = format!( + "(Integer-Mul-Propagate-Message_{}_Carry_{}_Input_{}_Block_{}):", + rad_decomp.msg_space, carry_space, msg_space, rad_decomp.block_number, + ); + + group.bench_function(&id, |b| { + b.iter(|| { + sks.unchecked_mul(&ctxt_1, &ctxt_2); + sks.full_propagate(&mut ctxt_1); + }) + }); + } + } +} +// +fn radmodint_wopbs(c: &mut Criterion) { + //Change the number of sample + let mut group = c.benchmark_group("smaller-sample-count"); + group.sample_size(10); + + //At most 4bits + let max_message_space = 4; + + let message_spaces = [16]; + for msg_space in message_spaces { + let dec = radix_decomposition(msg_space, 2, max_message_space); + println!("radix decomposition = {dec:?}"); + //for rad_decomp in dec.iter() { + let rad_decomp = dec[0]; + //The carry space is at least equal to the msg_space + let carry_space = rad_decomp.msg_space; + + let param = get_parameters_from_message_and_carry_wopbs( + 1 << rad_decomp.msg_space, + 1 << carry_space, + ); + //let (mut cks, mut sks) = KEY_CACHE.get_from_params(param); + let keys = KEY_CACHE_WOPBS.get_from_param((param, param)); + let (cks, _, wopbs_shortint) = (keys.client_key(), keys.server_key(), keys.wopbs_key()); + + println!("Chosen Parameter Set: {param:?}"); + + let cks = tfhe::integer::client_key::ClientKey::from(cks.clone()); + + let wopbs = WopbsKey::new_from_shortint(wopbs_shortint); + let mut rng = rand::thread_rng(); + + let delta = 63 - f64::log2((param.message_modulus.0 * param.carry_modulus.0) as f64) as u64; + // Define the cleartexts + let clear1 = rng.gen::() % msg_space as u64; + + // Encrypt the integers + let ctxt_1 = cks.encrypt_radix(clear1, rad_decomp.block_number); + + let nb_bit_to_extract = f64::log2((param.message_modulus.0 * param.carry_modulus.0) as f64) + as usize + * rad_decomp.block_number; + + let mut lut_size = param.polynomial_size.0; + if (1 << nb_bit_to_extract) > wopbs_shortint.param.polynomial_size.0 { + lut_size = 1 << nb_bit_to_extract; + } + + let mut lut_1: Vec = vec![]; + let mut lut_2: Vec = vec![]; + for _ in 0..lut_size { + lut_1.push( + (rng.gen::() % (param.message_modulus.0 * param.carry_modulus.0) as u64) + << delta, + ); + lut_2.push( + (rng.gen::() % (param.message_modulus.0 * param.carry_modulus.0) as u64) + << delta, + ); + } + let big_lut = vec![lut_1, lut_2]; + + println!( + "(Input Size {}; Carry_Space {}, Message_Space {}, Block Number {}): \ + WoPBS", + msg_space, carry_space, rad_decomp.msg_space, rad_decomp.block_number, + ); + let id = format!( + "(Integer-WoPBS-Message_{}_Carry_{}_Input_{}_Block_{}):", + rad_decomp.msg_space, carry_space, msg_space, rad_decomp.block_number, + ); + + group.bench_function(&id, |b| b.iter(|| wopbs.wopbs(&ctxt_1, &big_lut))); + } + //} +} + +fn radmodint_wopbs_16bits_param_2_2_8_blocks(c: &mut Criterion) { + //Change the number of sample + let param = PARAM_MESSAGE_2_CARRY_2_16_BITS; + let nb_block = 8; + let input = 16; + + let mut group = c.benchmark_group("smaller-sample-count"); + group.sample_size(10); + + println!("Chosen Parameter Set: {PARAM_MESSAGE_2_CARRY_2_16_BITS:?}"); + + let (cks, sks) = gen_keys(¶m); + let wopbs_key = WopbsKey::new_wopbs_key(&cks, &sks, ¶m); + + let mut rng = rand::thread_rng(); + let delta = 63 - f64::log2((param.message_modulus.0 * param.carry_modulus.0) as f64) as u64; + // Define the cleartexts + let clear1 = rng.gen::() % param.message_modulus.0 as u64; + + // Encrypt the integers + let ctxt_1 = cks.encrypt_radix(clear1, nb_block); + + let nb_bit_to_extract = + f64::log2((param.message_modulus.0 * param.carry_modulus.0) as f64) as usize * nb_block; + + let mut lut_size = param.polynomial_size.0; + if (1 << nb_bit_to_extract) > param.polynomial_size.0 { + lut_size = 1 << nb_block; + } + + let mut lut_1: Vec = vec![]; + let mut lut_2: Vec = vec![]; + for _ in 0..lut_size { + lut_1.push( + (rng.gen::() % (param.message_modulus.0 * param.carry_modulus.0) as u64) << delta, + ); + lut_2.push( + (rng.gen::() % (param.message_modulus.0 * param.carry_modulus.0) as u64) << delta, + ); + } + let big_lut = vec![lut_1, lut_2]; + + let id = format!( + "(Integer-WoPBS-Message_{}_Carry_{}_Input_{}_Block_{}):", + param.message_modulus.0, param.message_modulus.0, input, nb_block + ); + + group.bench_function(&id, |b| b.iter(|| wopbs_key.wopbs(&ctxt_1, &big_lut))); +} + +fn radmodint_wopbs_16bits_param_4_4_4_blocks(c: &mut Criterion) { + //Change the number of sample + let param = PARAM_MESSAGE_4_CARRY_4_16_BITS; + let nb_block = 4; + let input = 16; + + let mut group = c.benchmark_group("smaller-sample-count"); + group.sample_size(10); + + println!("Chosen Parameter Set: {param:?}"); + + let (cks, sks) = gen_keys(¶m); + let wopbs_key = WopbsKey::new_wopbs_key(&cks, &sks, ¶m); + + let mut rng = rand::thread_rng(); + let delta = 63 - f64::log2((param.message_modulus.0 * param.carry_modulus.0) as f64) as u64; + // Define the cleartexts + let clear1 = rng.gen::() % param.message_modulus.0 as u64; + + // Encrypt the integers + let ctxt_1 = cks.encrypt_radix(clear1, nb_block); + + let nb_bit_to_extract = + f64::log2((param.message_modulus.0 * param.carry_modulus.0) as f64) as usize * nb_block; + + let mut lut_size = param.polynomial_size.0; + if (1 << nb_bit_to_extract) > param.polynomial_size.0 { + lut_size = 1 << nb_block; + } + + let mut lut_1: Vec = vec![]; + let mut lut_2: Vec = vec![]; + for _ in 0..lut_size { + lut_1.push( + (rng.gen::() % (param.message_modulus.0 * param.carry_modulus.0) as u64) << delta, + ); + lut_2.push( + (rng.gen::() % (param.message_modulus.0 * param.carry_modulus.0) as u64) << delta, + ); + } + let big_lut = vec![lut_1, lut_2]; + + let id = format!( + "(Integer-WoPBS-Message_{}_Carry_{}_Input_{}_Block_{}):", + param.message_modulus.0, param.message_modulus.0, input, nb_block + ); + + group.bench_function(&id, |b| b.iter(|| wopbs_key.wopbs(&ctxt_1, &big_lut))); +} + +fn radmodint_wopbs_32_bits(c: &mut Criterion) { + //Change the number of sample + let vec_param = &[ + PARAM_MESSAGE_1_CARRY_1_32_BITS, + PARAM_MESSAGE_2_CARRY_2_32_BITS, + PARAM_MESSAGE_4_CARRY_4_32_BITS, + ]; + let vec_nb_block = &[32, 16, 8]; + let input = 16; + + let mut group = c.benchmark_group("smaller-sample-count"); + group.sample_size(10); + + for (param, nb_block) in vec_param.iter().zip(vec_nb_block.iter()) { + println!("Chosen Parameter Set: {param:?}"); + + let (cks, sks) = gen_keys(param); + let wopbs_key = WopbsKey::new_wopbs_key(&cks, &sks, param); + + let mut rng = rand::thread_rng(); + let delta = 63 - f64::log2((param.message_modulus.0 * param.carry_modulus.0) as f64) as u64; + // Define the cleartexts + let clear1 = rng.gen::() % param.message_modulus.0 as u64; + + // Encrypt the integers + let ctxt_1 = cks.encrypt_radix(clear1, *nb_block); + + let nb_bit_to_extract = + f64::log2((param.message_modulus.0 * param.carry_modulus.0) as f64) as usize * nb_block; + + let mut lut_size = param.polynomial_size.0; + if (1 << nb_bit_to_extract) > param.polynomial_size.0 { + lut_size = 1 << nb_block; + } + + let mut lut_1: Vec = vec![]; + let mut lut_2: Vec = vec![]; + for _ in 0..lut_size { + lut_1.push( + (rng.gen::() % (param.message_modulus.0 * param.carry_modulus.0) as u64) + << delta, + ); + lut_2.push( + (rng.gen::() % (param.message_modulus.0 * param.carry_modulus.0) as u64) + << delta, + ); + } + let big_lut = vec![lut_1, lut_2]; + + let id = format!( + "(Integer-WoPBS-Message_{}_Carry_{}_Input_{}_Block_{}):", + param.message_modulus.0, param.message_modulus.0, input, nb_block + ); + + group.bench_function(&id, |b| b.iter(|| wopbs_key.wopbs(&ctxt_1, &big_lut))); + } +} + +fn concrete_integer_unchecked_mul_crt_16_bits(c: &mut Criterion) { + let mut group = c.benchmark_group("smaller-sample-count"); + group.sample_size(10); + let param = tfhe::shortint::parameters::PARAM_MESSAGE_4_CARRY_4; + + let (cks, sks) = KEY_CACHE.get_from_params(param); + + println!("Chosen Parameter Set: {param:?}"); + + let basis = vec![8, 9, 11, 13, 7]; + let mut modulus = 1; + for b in basis.iter() { + modulus *= b; + } + + // + // let block_modulus = DEFAULT_PARAMETERS.message_modulus.0 as u64; + // + // // message_modulus^vec_length + // let modulus = DEFAULT_PARAMETERS.message_modulus.0.pow(size as u32) as u64; + + let clear_0 = 29 % modulus; + let clear_1 = 23 % modulus; + + // encryption of an integer + let mut ct_zero = cks.encrypt_crt(clear_0, basis.clone()); + let ct_one = cks.encrypt_crt(clear_1, basis); + + let id = "(bench_concrete_integer_unchecked_mul_crt_16_bits):"; + // add the two ciphertexts + group.bench_function(id, |b| { + b.iter(|| { + sks.unchecked_crt_mul_assign(&mut ct_zero, &ct_one); + }) + }); +} + +fn concrete_integer_unchecked_add_crt_16_bits(c: &mut Criterion) { + let mut group = c.benchmark_group("smaller-sample-count"); + group.sample_size(10); + let param = tfhe::shortint::parameters::PARAM_MESSAGE_4_CARRY_4; + + let (cks, sks) = KEY_CACHE.get_from_params(param); + + println!("Chosen Parameter Set: {param:?}"); + + let basis = vec![8, 9, 11, 13, 7]; + let mut modulus = 1; + for b in basis.iter() { + modulus *= b; + } + + //RN + // + // let block_modulus = DEFAULT_PARAMETERS.message_modulus.0 as u64; + // + // // message_modulus^vec_length + // let modulus = DEFAULT_PARAMETERS.message_modulus.0.pow(size as u32) as u64; + + let clear_0 = 29 % modulus; + let clear_1 = 23 % modulus; + + // encryption of an integer + let mut ct_zero = cks.encrypt_crt(clear_0, basis.clone()); + let ct_one = cks.encrypt_crt(clear_1, basis); + + let id = "(bench_concrete_integer_unchecked_add_crt_16_bits):"; + // add the two ciphertexts + group.bench_function(id, |b| { + b.iter(|| { + sks.unchecked_crt_add_assign(&mut ct_zero, &ct_one); + }) + }); +} + +fn concrete_integer_unchecked_clean_carry_crt_16_bits(c: &mut Criterion) { + let mut group = c.benchmark_group("smaller-sample-count"); + group.sample_size(10); + let param = tfhe::shortint::parameters::PARAM_MESSAGE_4_CARRY_4; + + // generate the server-client key set + //let (mut cks, mut sks) = + //gen_keys(&tfhe::shortint::parameters::PARAM_MESSAGE_4_CARRY_4, + //size); + + let (cks, sks) = KEY_CACHE.get_from_params(param); + + println!("Chosen Parameter Set: {param:?}"); + + let basis = vec![8, 9, 11, 13, 7]; + let mut modulus = 1; + for b in basis.iter() { + modulus *= b; + } + + //RN + // + // let block_modulus = DEFAULT_PARAMETERS.message_modulus.0 as u64; + // + // // message_modulus^vec_length + // let modulus = DEFAULT_PARAMETERS.message_modulus.0.pow(size as u32) as u64; + + let clear_0 = 29 % modulus; + + // encryption of an integer + let mut ct_zero = cks.encrypt_crt(clear_0, basis.clone()); + + let id = "(bench_concrete_integer_clean_carry_16_bits):"; + // add the two ciphertexts + group.bench_function(id, |b| { + b.iter(|| { + sks.pbs_crt_compliant_function_assign(&mut ct_zero, |x| x % basis[0]); + }) + }); +} + +fn concrete_integer_unchecked_mul_crt_32_bits(c: &mut Criterion) { + let mut group = c.benchmark_group("smaller-sample-count"); + group.sample_size(10); + let param = tfhe::shortint::parameters::PARAM_MESSAGE_4_CARRY_4; + + // generate the server-client key set + //let (mut cks, mut sks) = + //gen_keys(&tfhe::shortint::parameters::PARAM_MESSAGE_4_CARRY_4, + //size); + + let (cks, sks) = KEY_CACHE.get_from_params(param); + + println!("Chosen Parameter Set: {param:?}"); + + let basis = vec![43, 47, 37, 49, 29, 41]; + let mut modulus = 1; + for b in basis.iter() { + modulus *= b; + } + + // + // let block_modulus = DEFAULT_PARAMETERS.message_modulus.0 as u64; + // + // // message_modulus^vec_length + // let modulus = DEFAULT_PARAMETERS.message_modulus.0.pow(size as u32) as u64; + + let clear_0 = 29 % modulus; + let clear_1 = 23 % modulus; + + // encryption of an integer + let mut ct_zero = cks.encrypt_crt(clear_0, basis.clone()); + let ct_one = cks.encrypt_crt(clear_1, basis); + + let id = "(bench_concrete_integer_unchecked_mul_crt_32_bits):"; + // add the two ciphertexts + group.bench_function(id, |b| { + b.iter(|| { + sks.unchecked_crt_mul_assign(&mut ct_zero, &ct_one); + }) + }); +} + +fn concrete_integer_unchecked_add_crt_32_bits(c: &mut Criterion) { + let mut group = c.benchmark_group("smaller-sample-count"); + group.sample_size(10); + let param = tfhe::shortint::parameters::PARAM_MESSAGE_4_CARRY_4; + + // generate the server-client key set + //let (mut cks, mut sks) = + //gen_keys(&tfhe::shortint::parameters::PARAM_MESSAGE_4_CARRY_4, + //size); + + let (cks, sks) = KEY_CACHE.get_from_params(param); + + println!("Chosen Parameter Set: {param:?}"); + + let basis = vec![43, 47, 37, 49, 29, 41]; + let mut modulus = 1; + for b in basis.iter() { + modulus *= b; + } + + //RN + // + // let block_modulus = DEFAULT_PARAMETERS.message_modulus.0 as u64; + // + // // message_modulus^vec_length + // let modulus = DEFAULT_PARAMETERS.message_modulus.0.pow(size as u32) as u64; + + let clear_0 = 29 % modulus; + let clear_1 = 23 % modulus; + + // encryption of an integer + let mut ct_zero = cks.encrypt_crt(clear_0, basis.clone()); + let ct_one = cks.encrypt_crt(clear_1, basis); + + let id = "(bench_concrete_integer_unchecked_add_crt_32_bits):"; + // add the two ciphertexts + group.bench_function(id, |b| { + b.iter(|| { + sks.unchecked_crt_add_assign(&mut ct_zero, &ct_one); + }) + }); +} + +fn concrete_integer_unchecked_clean_carry_crt_32_bits(c: &mut Criterion) { + let mut group = c.benchmark_group("smaller-sample-count"); + group.sample_size(10); + let param = tfhe::shortint::parameters::PARAM_MESSAGE_4_CARRY_4; + + // generate the server-client key set + //let (mut cks, mut sks) = + //gen_keys(&tfhe::shortint::parameters::PARAM_MESSAGE_4_CARRY_4, + //size); + + let (cks, sks) = KEY_CACHE.get_from_params(param); + + println!("Chosen Parameter Set: {param:?}"); + + let basis = vec![43, 47, 37, 49, 29, 41]; + let mut modulus = 1; + for b in basis.iter() { + modulus *= b; + } + + //RN + // + // let block_modulus = DEFAULT_PARAMETERS.message_modulus.0 as u64; + // + // // message_modulus^vec_length + // let modulus = DEFAULT_PARAMETERS.message_modulus.0.pow(size as u32) as u64; + + let clear_0 = 29 % modulus; + + // encryption of an integer + let mut ct_zero = cks.encrypt_crt(clear_0, basis.clone()); + + let id = "(bench_concrete_integer_clean_carry_32_bits):"; + // add the two ciphertexts + group.bench_function(id, |b| { + b.iter(|| { + sks.pbs_crt_compliant_function_assign(&mut ct_zero, |x| x % basis[0]); + }) + }); +} diff --git a/tfhe/docs/SUMMARY.md b/tfhe/docs/SUMMARY.md index 2d4287771..cfeb1bc72 100644 --- a/tfhe/docs/SUMMARY.md +++ b/tfhe/docs/SUMMARY.md @@ -22,6 +22,9 @@ * [Cryptographic Parameters](shortint/parameters.md) * [Serialization/Deserialization](shortint/serialization.md) +## Integer +* [Summary](integer/SUMMARY.md) + ## C API * [Tutorial](c_api/tutorial.md) diff --git a/tfhe/docs/integer/SUMMARY.md b/tfhe/docs/integer/SUMMARY.md new file mode 100644 index 000000000..3f3b0da73 --- /dev/null +++ b/tfhe/docs/integer/SUMMARY.md @@ -0,0 +1,20 @@ +# TFHE-rs Integer User Guide + +[Introduction](introduction.md) + +# Getting Started + +[Installation](getting_started/installation.md) + +[Writing Your First Circuit](getting_started/first_circuit.md) + +[Types Of Operations](getting_started/operation_types.md) + +[List of Operations](getting_started/operation_list.md) + +[Cryptographic Parameters](getting_started/parameters.md) + + +# How to + +[Serialization / Deserialization](tutorials/serialization.md) diff --git a/tfhe/docs/integer/getting_started/first_circuit.md b/tfhe/docs/integer/getting_started/first_circuit.md new file mode 100644 index 000000000..fc803f36b --- /dev/null +++ b/tfhe/docs/integer/getting_started/first_circuit.md @@ -0,0 +1,105 @@ +# Writing Your First Circuit + + +## Key Types + +`integer` provides 2 basic key types: + - `ClientKey` + - `ServerKey` + +The `ClientKey` is the key that encrypts and decrypts messages, +thus this key is meant to be kept private and should never be shared. +This key is created from parameter values that will dictate both the security and efficiency +of computations. The parameters also set the maximum number of bits of message encrypted +in a ciphertext. + +The `ServerKey` is the key that is used to actually do the FHE computations. It contains (among other things) +a bootstrapping key and a keyswitching key. +This key is created from a `ClientKey` that needs to be shared to the server, therefore it is not +meant to be kept private. +A user with a `ServerKey` can compute on the encrypted data sent by the owner of the associated +`ClientKey`. + +To reflect that, computation/operation methods are tied to the `ServerKey` type. + + +## 1. Key Generation + +To generate the keys, a user needs two parameters: + - A set of `shortint` cryptographic parameters. + - The number of ciphertexts used to encrypt an integer (we call them "shortint blocks"). + + +For this example we are going to build a pair of keys that can encrypt an **8-bit** integer +by using **4** shortint blocks that store **2** bits of message each. + + +```rust +use tfhe::integer::gen_keys_radix; +use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + +fn main() { + // We generate a set of client/server keys, using the default parameters: + let num_block = 4; + let (client_key, server_key) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, num_block); +} +``` + + + +## 2. Encrypting values + + +Once we have our keys we can encrypt values: + +```rust +use tfhe::integer::gen_keys_radix; +use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + +fn main() { + // We generate a set of client/server keys, using the default parameters: + let num_block = 4; + let (client_key, server_key) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, num_block); + + let msg1 = 128; + let msg2 = 13; + + // We use the client key to encrypt two messages: + let ct_1 = client_key.encrypt(msg1); + let ct_2 = client_key.encrypt(msg2); +} +``` + +## 3. Computing and decrypting + +With our `server_key`, and encrypted values, we can now do an addition +and then decrypt the result. + +```rust +use tfhe::integer::gen_keys_radix; +use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + +fn main() { + // We generate a set of client/server keys, using the default parameters: + let num_block = 4; + let (client_key, server_key) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, num_block); + + let msg1 = 128; + let msg2 = 13; + + // message_modulus^vec_length + let modulus = client_key.parameters().message_modulus.0.pow(num_block as u32) as u64; + + // We use the client key to encrypt two messages: + let ct_1 = client_key.encrypt(msg1); + let ct_2 = client_key.encrypt(msg2); + + // We use the server public key to execute an integer circuit: + let ct_3 = server_key.unchecked_add(&ct_1, &ct_2); + + // We use the client key to decrypt the output of the circuit: + let output = client_key.decrypt(&ct_3); + + assert_eq!(output, (msg1 + msg2) % modulus); +} +``` diff --git a/tfhe/docs/integer/getting_started/installation.md b/tfhe/docs/integer/getting_started/installation.md new file mode 100644 index 000000000..bcb4b4a28 --- /dev/null +++ b/tfhe/docs/integer/getting_started/installation.md @@ -0,0 +1,11 @@ +# Installation + +## Cargo.toml + +To use `integer`, you will need to add TFHE-rs to the list of dependencies your project, by updating your `Cargo.toml` file. + +```toml +tfhe = { version = "0.2.0", features = ["integer", "x86_64-unix"] } +``` + +TODO doc \ No newline at end of file diff --git a/tfhe/docs/integer/getting_started/operation_list.md b/tfhe/docs/integer/getting_started/operation_list.md new file mode 100644 index 000000000..865619d74 --- /dev/null +++ b/tfhe/docs/integer/getting_started/operation_list.md @@ -0,0 +1,15 @@ +# List of available operations + +`integer` comes with a set of already implemented functions: + + +- addition between two ciphertexts +- addition between a ciphertext and an unencrypted scalar +- multiplication of a ciphertext by an unencrypted scalar +- bitwise shift `<<`, `>>` +- bitwise and, or and xor +- multiplication between two ciphertexts +- subtraction of a ciphertext by another ciphertext +- subtraction of a ciphertext by an unencrypted scalar +- negation of a ciphertext + diff --git a/tfhe/docs/integer/getting_started/operation_types.md b/tfhe/docs/integer/getting_started/operation_types.md new file mode 100644 index 000000000..282851390 --- /dev/null +++ b/tfhe/docs/integer/getting_started/operation_types.md @@ -0,0 +1,86 @@ +# How Integers are represented + + +In `integer`, the encrypted data is split amongst many ciphertexts +encrypted using the `shortint` library. + +This crate implements two ways to represent an integer: + - the Radix representation + - the CRT (Chinese Reminder Theorem) representation + +## Radix based Integers +The first possibility to represent a large integer is to use a radix-based decomposition on the +plaintexts. Let $$B \in \mathbb{N}$$ be a basis such that the size of $$B$$ is smaller (or equal) +to four bits. +Then, an integer $$m \in \mathbb{N}$$ can be written as $$m = m_0 + m_1*B + m_2*B^2 + ... $$, where +each $$m_i$$ is strictly smaller than $$B$$. Each $$m_i$$ is then independently encrypted. In +the end, an Integer ciphertext is defined as a set of Shortint ciphertexts. + +In practice, the definition of an Integer requires the basis and the number of blocks. This is +done at the key creation step. +```rust +use tfhe::integer::gen_keys_radix; +use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + +fn main() { + // We generate a set of client/server keys, using the default parameters: + let num_block = 4; + let (client_key, server_key) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, num_block); +} +``` + +In this example, the keys are dedicated to Integers decomposed as four blocks using the basis +$$B=2^2$$. Otherwise said, they allow to work on Integers modulus $$(2^2)^4 = 2^8$$. + + +In this representation, the correctness of operations requires to propagate the carries +between the ciphertext. This operation is costly since it relies on the computation of many +programmable bootstrapping over Shortints. + + +## CRT based Integers +The second approach to represent large integers is based on the Chinese Remainder Theorem. +In this cases, the basis $$B$$ is composed of several integers $$b_i$$, such that there are +pairwise coprime, and each b_i has a size smaller than four bits. Then, the Integer will be +defined modulus $$\prod b_i$$. For an integer $$m$$, its CRT decomposition is simply defined as +$$m % b_0, m % b_1, ...$$. Each part is then encrypted as a Shortint ciphertext. In +the end, an Integer ciphertext is defined as a set of Shortint ciphertexts. + +An example of such a basis +could be $$B = [2, 3, 5]$$. This means that the Integer is defined modulus $$2*3*5 = 30$$. + +This representation has many advantages: no carry propagation is required, so that only cleaning +the carry buffer of each ciphertexts is enough. This implies that operations can easily be +parallelized. Moreover, it allows to efficiently compute PBS in the case where the function is +CRT compliant. + +A variant of the CRT is proposed, where each block might be associated to a different key couple. +In the end, a keychain is required to the computations, but performance might be improved. + + + +# Types of operations + + +Much like `shortint`, the operations available via a `ServerKey` may come in different variants: + + - operations that take their inputs as encrypted values. + - scalar operations take at least one non-encrypted value as input. + +For example, the addition has both variants: + + - `ServerKey::unchecked_add` which takes two encrypted values and adds them. + - `ServerKey::unchecked_scalar_add` which takes an encrypted value and a clear value (the + so-called scalar) and adds them. + +Each operation may come in different 'flavors': + + - `unchecked`: Always does the operation, without checking if the result may exceed the capacity of + the plaintext space. + - `checked`: Checks are done before computing the operation, returning an error if operation + cannot be done safely. + - `smart`: Always does the operation, if the operation cannot be computed safely, the smart operation + will propagate the carry buffer to make the operation possible. + +Not all operations have these 3 flavors, as some of them are implemented in a way that the operation +is always possible without ever exceeding the plaintext space capacity. diff --git a/tfhe/docs/integer/getting_started/parameters.md b/tfhe/docs/integer/getting_started/parameters.md new file mode 100644 index 000000000..13313c96e --- /dev/null +++ b/tfhe/docs/integer/getting_started/parameters.md @@ -0,0 +1,6 @@ +# Use of parameters + + +`integer` does not come with its own set of parameters, instead it uses +parameters from the `shortint` crate. Currently, only the parameters +`PARAM_MESSAGE_{X}_CARRY_{X}` with `X` in [1,4] can be used in `integer`. diff --git a/tfhe/docs/integer/how_to/pbs.md b/tfhe/docs/integer/how_to/pbs.md new file mode 100644 index 000000000..7dfe5c4ae --- /dev/null +++ b/tfhe/docs/integer/how_to/pbs.md @@ -0,0 +1,51 @@ +# The tree programmable bootstrapping + +In `integer`, the user can evaluate any function on an encrypted ciphertext. To do so the user must first +create a `treepbs key`, choose a function to evaluate and give them as parameters to the `tree programmable bootstrapping`. + +Two versions of the tree pbs are implemented: the `standard` version that computes a result according to every encrypted +bit (message and carry), and the `base` version that only takes into account the message bits of each block. + +{% hint style="warning" %} + +The `tree pbs` is quite slow, therefore its use is currently restricted to two and three blocks integer ciphertexts. + +{% endhint %} + +```rust +use tfhe::integer::gen_keys_radix; +use tfhe::integer::wopbs::WopbsKey; +use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; +use tfhe::shortint::parameters::parameters_wopbs_message_carry::WOPBS_PARAM_MESSAGE_2_CARRY_2; + +fn main() { + let num_block = 2; + // Generate the client key and the server key: + let (cks, sks) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, num_block); + + let msg: u64 = 27; + let ct = cks.encrypt(msg); + + // message_modulus^vec_length + let modulus = cks.parameters().message_modulus.0.pow(num_block as u32) as u64; + + let wopbs_key = WopbsKey::new_wopbs_key(&cks.as_ref(), &sks, &WOPBS_PARAM_MESSAGE_2_CARRY_2); + + let f = |x: u64| x * x; + + // evaluate f + let ct = wopbs_key.keyswitch_to_wopbs_params(&sks, &ct); + let lut = wopbs_key.generate_lut_radix(&ct, f); + let ct_res = wopbs_key.wopbs(&ct, &lut); + let ct_res = wopbs_key.keyswitch_to_pbs_params(&ct_res); + + // decryption + let res = cks.decrypt(&ct_res); + + let clear = f(msg) % modulus; + assert_eq!(res, clear); +} +``` + +# The WOP programmable bootstrapping + diff --git a/tfhe/docs/integer/introduction.md b/tfhe/docs/integer/introduction.md new file mode 100644 index 000000000..df065b29e --- /dev/null +++ b/tfhe/docs/integer/introduction.md @@ -0,0 +1,8 @@ +# TFHE-rs Integer + +## Introduction + +`integer` is a module of TFHE-rs based on its `shortint` module, this crate provides +large precision integers by using multiple `shortint` ciphertexts. + +The intended target audience for this library is people who are somewhat familiar with cryptography. diff --git a/tfhe/docs/integer/tutorials/circuit_evaluation.md b/tfhe/docs/integer/tutorials/circuit_evaluation.md new file mode 100644 index 000000000..ef51d7b8a --- /dev/null +++ b/tfhe/docs/integer/tutorials/circuit_evaluation.md @@ -0,0 +1,120 @@ +# Circuit evaluation + +Let's try to do a circuit evaluation using the different flavours of operations we already introduced. +For a very small circuit, the `unchecked` flavour may be enough to do the computation correctly. +Otherwise, the `checked` and `smart` are the best options. + +As an example, let's do a scalar multiplication, a subtraction and an addition. + + +```rust +use tfhe::integer::gen_keys_radix; +use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + +fn main() { + let num_block = 4; + let (client_key, server_key) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, num_block); + + let msg1 = 12; + let msg2 = 11; + let msg3 = 9; + let scalar = 3; + + // message_modulus^vec_length + let modulus = client_key.parameters().message_modulus.0.pow(num_block as u32) as u64; + + // We use the client key to encrypt two messages: + let mut ct_1 = client_key.encrypt(msg1); + let ct_2 = client_key.encrypt(msg2); + let ct_3 = client_key.encrypt(msg2); + + server_key.unchecked_small_scalar_mul_assign(&mut ct_1, scalar); + + server_key.unchecked_sub_assign(&mut ct_1, &ct_2); + + server_key.unchecked_add_assign(&mut ct_1, &ct_3); + + // We use the client key to decrypt the output of the circuit: + let output = client_key.decrypt(&ct_1); + // The carry buffer has been overflowed, the result is not correct + assert_ne!(output, ((msg1 * scalar as u64 - msg2) + msg3) % modulus as u64); +} +``` + +During this computation the carry buffer has been overflowed and as all the operations were `unchecked` the output +may be incorrect. + +If we redo this same circuit but using the `checked` flavour, a panic will occur. + +```rust +use tfhe::integer::gen_keys_radix; +use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + +fn main() { + let num_block = 2; + let (client_key, server_key) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, num_block); + + let msg1 = 12; + let msg2 = 11; + let msg3 = 9; + let scalar = 3; + + // message_modulus^vec_length + let modulus = client_key.parameters().message_modulus.0.pow(num_block as u32) as u64; + + // We use the client key to encrypt two messages: + let mut ct_1 = client_key.encrypt(msg1); + let ct_2 = client_key.encrypt(msg2); + let ct_3 = client_key.encrypt(msg3); + + let result = server_key.checked_small_scalar_mul_assign(&mut ct_1, scalar); + assert!(result.is_ok()); + + let result = server_key.checked_sub_assign(&mut ct_1, &ct_2); + assert!(result.is_err()); + + // We use the client key to decrypt the output of the circuit: + // Only the scalar multiplication could be done + let output = client_key.decrypt(&ct_1); + assert_eq!(output, (msg1 * scalar) % modulus as u64); +} +``` + +Therefore the `checked` flavour permits to manually manage the overflow of the carry buffer +by raising an error if the correctness is not guaranteed. + +Lastly, using the `smart` flavour will output the correct result all the time. However, the computation may be slower +as the carry buffer may be propagated during the computations. + +```rust +use tfhe::integer::gen_keys_radix; +use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + +fn main() { + let num_block = 4; + let (client_key, server_key) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, num_block); + + let msg1 = 12; + let msg2 = 11; + let msg3 = 9; + let scalar = 3; + + // message_modulus^vec_length + let modulus = client_key.parameters().message_modulus.0.pow(num_block as u32) as u64; + + // We use the client key to encrypt two messages: + let mut ct_1 = client_key.encrypt(msg1); + let mut ct_2 = client_key.encrypt(msg2); + let mut ct_3 = client_key.encrypt(msg3); + + server_key.smart_scalar_mul_assign(&mut ct_1, scalar); + + server_key.smart_sub_assign(&mut ct_1, &mut ct_2); + + server_key.smart_add_assign(&mut ct_1, &mut ct_3); + + // We use the client key to decrypt the output of the circuit: + let output = client_key.decrypt(&ct_1); + assert_eq!(output, ((msg1 * scalar as u64 - msg2) + msg3) % modulus as u64); +} +``` \ No newline at end of file diff --git a/tfhe/docs/integer/tutorials/serialization.md b/tfhe/docs/integer/tutorials/serialization.md new file mode 100644 index 000000000..daf717688 --- /dev/null +++ b/tfhe/docs/integer/tutorials/serialization.md @@ -0,0 +1,78 @@ +# Serialization / Deserialization + +As explained in the introduction, some types (`Serverkey`, `Ciphertext`) are meant to be shared +with the server that does the computations. + +The easiest way to send these data to a server is to use the serialization and deserialization features. +concrete-integer uses the serde framework, serde's Serialize and Deserialize are implemented. + +To be able to serialize our data, we need to pick a [data format], for our use case, +[bincode] is a good choice, mainly because it is binary format. + + +```toml +# Cargo.toml + +[dependencies] +# ... +bincode = "1.3.3" +``` + + +```rust +// main.rs + +use bincode; + +use std::io::Cursor; +use tfhe::integer::{gen_keys_radix, ServerKey, RadixCiphertext}; +use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + + +fn main() -> Result<(), Box> { + // We generate a set of client/server keys, using the default parameters: + let num_block = 4; + let (client_key, server_key) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, num_block); + + let msg1 = 201; + let msg2 = 12; + + // message_modulus^vec_length + let modulus = client_key.parameters().message_modulus.0.pow(num_block as u32) as u64; + + let ct_1 = client_key.encrypt(msg1); + let ct_2 = client_key.encrypt(msg2); + + let mut serialized_data = Vec::new(); + bincode::serialize_into(&mut serialized_data, &server_key)?; + bincode::serialize_into(&mut serialized_data, &ct_1)?; + bincode::serialize_into(&mut serialized_data, &ct_2)?; + + // Simulate sending serialized data to a server and getting + // back the serialized result + let serialized_result = server_function(&serialized_data)?; + let result: RadixCiphertext = bincode::deserialize(&serialized_result)?; + + let output = client_key.decrypt(&result); + assert_eq!(output, (msg1 + msg2) % modulus); + Ok(()) +} + + +fn server_function(serialized_data: &[u8]) -> Result, Box> { + let mut serialized_data = Cursor::new(serialized_data); + let server_key: ServerKey = bincode::deserialize_from(&mut serialized_data)?; + let ct_1: RadixCiphertext = bincode::deserialize_from(&mut serialized_data)?; + let ct_2: RadixCiphertext = bincode::deserialize_from(&mut serialized_data)?; + + let result = server_key.unchecked_add(&ct_1, &ct_2); + + let serialized_result = bincode::serialize(&result)?; + + Ok(serialized_result) +} +``` + +[serde]: https://crates.io/crates/serde +[data format]: https://serde.rs/#data-formats +[bincode]: https://crates.io/crates/bincode diff --git a/tfhe/src/integer/ciphertext/mod.rs b/tfhe/src/integer/ciphertext/mod.rs new file mode 100644 index 000000000..3e90b2b47 --- /dev/null +++ b/tfhe/src/integer/ciphertext/mod.rs @@ -0,0 +1,57 @@ +//! This module implements the ciphertext structures. +use crate::shortint::Ciphertext as ShortintCiphertext; +use serde::{Deserialize, Serialize}; + +/// Structure containing a ciphertext in radix decomposition. +#[derive(Serialize, Clone, Deserialize)] +pub struct RadixCiphertext { + /// The blocks are stored from LSB to MSB + pub(crate) blocks: Vec, +} + +pub trait IntegerCiphertext: Clone { + fn from_blocks(blocks: Vec) -> Self; + fn blocks(&self) -> &[ShortintCiphertext]; + fn blocks_mut(&mut self) -> &mut [ShortintCiphertext]; + fn moduli(&self) -> Vec { + self.blocks() + .iter() + .map(|x| x.message_modulus.0 as u64) + .collect() + } +} + +impl IntegerCiphertext for RadixCiphertext { + fn blocks(&self) -> &[ShortintCiphertext] { + &self.blocks + } + fn blocks_mut(&mut self) -> &mut [ShortintCiphertext] { + &mut self.blocks + } + fn from_blocks(blocks: Vec) -> Self { + Self { blocks } + } +} + +impl IntegerCiphertext for CrtCiphertext { + fn blocks(&self) -> &[ShortintCiphertext] { + &self.blocks + } + fn blocks_mut(&mut self) -> &mut [ShortintCiphertext] { + &mut self.blocks + } + fn from_blocks(blocks: Vec) -> Self { + let moduli = blocks.iter().map(|x| x.message_modulus.0 as u64).collect(); + Self { blocks, moduli } + } +} + +/// Structure containing a ciphertext in CRT decomposition. +/// +/// For this CRT decomposition, each block is encrypted using +/// the same parameters. +#[derive(Serialize, Clone, Deserialize)] +pub struct CrtCiphertext { + pub(crate) blocks: Vec, + pub(crate) moduli: Vec, +} diff --git a/tfhe/src/integer/client_key/crt.rs b/tfhe/src/integer/client_key/crt.rs new file mode 100644 index 000000000..99ea74691 --- /dev/null +++ b/tfhe/src/integer/client_key/crt.rs @@ -0,0 +1,67 @@ +use super::ClientKey; +use crate::integer::CrtCiphertext; + +use serde::{Deserialize, Serialize}; + +/// Client key "specialized" for CRT decomposition. +/// +/// This key is a simple wrapper of the [ClientKey], +/// that only encrypt and decrypt in CRT decomposition. +/// +/// # Example +/// +/// ```rust +/// use tfhe::integer::CrtClientKey; +/// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; +/// +/// let basis = vec![2, 3, 5]; +/// let cks = CrtClientKey::new(PARAM_MESSAGE_2_CARRY_2, basis); +/// +/// let msg = 13_u64; +/// +/// // Encryption: +/// let ct = cks.encrypt(msg); +/// +/// // Decryption: +/// let dec = cks.decrypt(&ct); +/// assert_eq!(msg, dec); +/// ``` +#[derive(Serialize, Deserialize, PartialEq, Debug, Clone)] +pub struct CrtClientKey { + key: ClientKey, + moduli: Vec, +} + +impl AsRef for CrtClientKey { + fn as_ref(&self) -> &ClientKey { + &self.key + } +} + +impl CrtClientKey { + pub fn new(parameters: crate::shortint::Parameters, moduli: Vec) -> Self { + Self { + key: ClientKey::new(parameters), + moduli, + } + } + + pub fn encrypt(&self, message: u64) -> CrtCiphertext { + self.key.encrypt_crt(message, self.moduli.clone()) + } + + pub fn decrypt(&self, ciphertext: &CrtCiphertext) -> u64 { + self.key.decrypt_crt(ciphertext) + } + + /// Returns the parameters used by the client key. + pub fn parameters(&self) -> crate::shortint::Parameters { + self.key.parameters() + } +} + +impl From<(ClientKey, Vec)> for CrtClientKey { + fn from((key, moduli): (ClientKey, Vec)) -> Self { + Self { key, moduli } + } +} diff --git a/tfhe/src/integer/client_key/mod.rs b/tfhe/src/integer/client_key/mod.rs new file mode 100644 index 000000000..2f82c773a --- /dev/null +++ b/tfhe/src/integer/client_key/mod.rs @@ -0,0 +1,434 @@ +//! This module implements the generation of the client keys structs +//! +//! Client keys are the keys used to encrypt an decrypt data. +//! These are private and **MUST NOT** be shared. + +mod crt; +mod radix; +pub(crate) mod utils; + +use crate::integer::ciphertext::{CrtCiphertext, RadixCiphertext}; +use crate::integer::client_key::utils::i_crt; +use crate::shortint::parameters::MessageModulus; +use crate::shortint::{ + Ciphertext as ShortintCiphertext, ClientKey as ShortintClientKey, + Parameters as ShortintParameters, +}; +use serde::{Deserialize, Serialize}; +pub use utils::radix_decomposition; + +pub use crt::CrtClientKey; +pub use radix::RadixClientKey; + +/// A structure containing the client key, which must be kept secret. +/// +/// This key can be used to encrypt both in Radix and CRT +/// decompositions. +/// +/// Using this key, for both decompositions, each block will +/// use the same crypto parameters. +#[derive(Serialize, Deserialize, PartialEq, Debug, Clone)] +pub struct ClientKey { + pub(crate) key: ShortintClientKey, +} + +impl From for ClientKey { + fn from(key: ShortintClientKey) -> Self { + Self { key } + } +} + +impl From for ShortintClientKey { + fn from(key: ClientKey) -> ShortintClientKey { + key.key + } +} + +impl AsRef for ClientKey { + fn as_ref(&self) -> &ClientKey { + self + } +} + +impl ClientKey { + /// Creates a Client Key. + /// + /// # Example + /// + /// ```rust + /// use tfhe::integer::ClientKey; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // Generate the client key, that can encrypt in + /// // radix and crt decomposition, where each block of the decomposition + /// // have over 2 bits of message modulus. + /// let cks = ClientKey::new(PARAM_MESSAGE_2_CARRY_2); + /// ``` + pub fn new(parameter_set: ShortintParameters) -> Self { + Self { + key: ShortintClientKey::new(parameter_set), + } + } + + pub fn parameters(&self) -> ShortintParameters { + self.key.parameters + } + + /// Encrypts an integer in radix decomposition + /// + /// # Example + /// + /// ```rust + /// use tfhe::integer::ClientKey; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// let cks = ClientKey::new(PARAM_MESSAGE_2_CARRY_2); + /// let num_block = 4; + /// + /// let msg = 167_u64; + /// + /// // 2 * 4 = 8 bits of message + /// let ct = cks.encrypt_radix(msg, num_block); + /// + /// // Decryption + /// let dec = cks.decrypt_radix(&ct); + /// assert_eq!(msg, dec); + /// ``` + pub fn encrypt_radix(&self, message: u64, num_blocks: usize) -> RadixCiphertext { + let mut blocks = Vec::with_capacity(num_blocks); + + // Bits of message put to 1 + let mask = (self.key.parameters.message_modulus.0 - 1) as u64; + + let mut power = 1_u64; + // Put each decomposition into a new ciphertext + for _ in 0..num_blocks { + let decomp = (message & (mask * power)) / power; + + let ct = self.key.encrypt(decomp); + blocks.push(ct); + + // modulus to the power i + power *= self.key.parameters.message_modulus.0 as u64; + } + + RadixCiphertext { blocks } + } + + /// Encrypts an integer in radix decomposition without padding bit + /// + /// # Example + /// + /// ```rust + /// use tfhe::integer::ClientKey; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// let cks = ClientKey::new(PARAM_MESSAGE_2_CARRY_2); + /// let num_block = 4; + /// + /// let msg = 167_u64; + /// + /// // 2 * 4 = 8 bits of message + /// let ct = cks.encrypt_radix_without_padding(msg, num_block); + /// + /// // Decryption + /// let dec = cks.decrypt_radix_without_padding(&ct); + /// assert_eq!(msg, dec); + /// ``` + pub fn encrypt_radix_without_padding( + &self, + message: u64, + num_blocks: usize, + ) -> RadixCiphertext { + let mut blocks = Vec::with_capacity(num_blocks); + + // Bits of message put to 1 + let mask = (self.key.parameters.message_modulus.0 - 1) as u64; + + let mut power = 1_u64; + // Put each decomposition into a new ciphertext + for _ in 0..num_blocks { + let decomp = (message & (mask * power)) / power; + + // encryption + let ct = self.key.encrypt_without_padding(decomp); + blocks.push(ct); + + // modulus to the power i + power *= self.key.parameters.message_modulus.0 as u64; + } + + RadixCiphertext { blocks } + } + + /// Encrypts one block. + /// + /// This returns a shortint ciphertext. + /// + /// # Example + /// + /// ```rust + /// use tfhe::integer::ClientKey; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// let cks = ClientKey::new(PARAM_MESSAGE_2_CARRY_2); + /// let num_block = 4; + /// + /// let msg = 2_u64; + /// + /// // Encryption + /// let ct = cks.encrypt_one_block(msg); + /// + /// // Decryption + /// let dec = cks.decrypt_one_block(&ct); + /// assert_eq!(msg, dec); + /// ``` + pub fn encrypt_one_block(&self, message: u64) -> ShortintCiphertext { + self.key.encrypt(message) + } + + /// Decrypts one block. + /// + /// This takes a shortint ciphertext as input. + pub fn decrypt_one_block(&self, ct: &ShortintCiphertext) -> u64 { + self.key.decrypt(ct) + } + + /// Decrypts a ciphertext encrypting an radix integer + /// + /// # Example + /// + /// ```rust + /// use tfhe::integer::ClientKey; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// let cks = ClientKey::new(PARAM_MESSAGE_2_CARRY_2); + /// let num_block = 4; + /// + /// let msg = 191_u64; + /// + /// // Encryption + /// let ct = cks.encrypt_radix(msg, num_block); + /// + /// // Decryption + /// let dec = cks.decrypt_radix(&ct); + /// assert_eq!(msg, dec); + /// ``` + pub fn decrypt_radix(&self, ctxt: &RadixCiphertext) -> u64 { + let mut result = 0_u64; + let mut shift = 1_u64; + let modulus = self.parameters().message_modulus.0 as u64; + + for c_i in ctxt.blocks.iter() { + // decrypt the component i of the integer and multiply it by the radix product + let block_value = self.key.decrypt_message_and_carry(c_i).wrapping_mul(shift); + + // update the result + result = result.wrapping_add(block_value); + + // update the shift for the next iteration + shift = shift.wrapping_mul(modulus); + } + + let whole_modulus = modulus.pow(ctxt.blocks.len() as u32); + + result % whole_modulus + } + + /// Decrypts a ciphertext encrypting an radix integer encrypted without padding + /// + /// # Example + /// + /// ```rust + /// use tfhe::integer::ClientKey; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// let cks = ClientKey::new(PARAM_MESSAGE_2_CARRY_2); + /// let num_block = 4; + /// + /// let msg = 191_u64; + /// + /// // Encryption + /// let ct = cks.encrypt_radix_without_padding(msg, num_block); + /// + /// // Decryption + /// let dec = cks.decrypt_radix_without_padding(&ct); + /// assert_eq!(msg, dec); + /// ``` + pub fn decrypt_radix_without_padding(&self, ctxt: &RadixCiphertext) -> u64 { + let mut result = 0_u64; + let mut shift = 1_u64; + let modulus = self.parameters().message_modulus.0 as u64; + for c_i in ctxt.blocks.iter() { + // decrypt the component i of the integer and multiply it by the radix product + let block_value = self + .key + .decrypt_message_and_carry_without_padding(c_i) + .wrapping_mul(shift); + + // update the result + result = result.wrapping_add(block_value); + + // update the shift for the next iteration + shift = shift.wrapping_mul(modulus); + } + + let whole_modulus = modulus.pow(ctxt.blocks.len() as u32); + + result % whole_modulus + } + + /// Encrypts an integer using crt representation + /// + /// # Example + /// + /// ```rust + /// use tfhe::integer::ClientKey; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// let cks = ClientKey::new(PARAM_MESSAGE_2_CARRY_2); + /// + /// let msg = 13_u64; + /// + /// // Encryption: + /// let basis: Vec = vec![2, 3, 5]; + /// let ct = cks.encrypt_crt(msg, basis); + /// + /// // Decryption: + /// let dec = cks.decrypt_crt(&ct); + /// assert_eq!(msg, dec); + /// ``` + pub fn encrypt_crt(&self, message: u64, base_vec: Vec) -> CrtCiphertext { + let mut ctxt_vect = Vec::with_capacity(base_vec.len()); + + // Put each decomposition into a new ciphertext + for modulus in base_vec.iter().copied() { + // encryption + let ct = self + .key + .encrypt_with_message_modulus(message, MessageModulus(modulus as usize)); + + ctxt_vect.push(ct); + } + + CrtCiphertext { + blocks: ctxt_vect, + moduli: base_vec, + } + } + + /// Decrypts an integer in crt decomposition + /// + /// # Example + /// + /// ```rust + /// use tfhe::integer::ClientKey; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // Generate the client key and the server key: + /// let mut cks = ClientKey::new(PARAM_MESSAGE_2_CARRY_2); + /// + /// let msg = 27_u64; + /// let basis: Vec = vec![2, 3, 5]; + /// + /// // Encryption: + /// let mut ct = cks.encrypt_crt(msg, basis); + /// + /// // Decryption: + /// let dec = cks.decrypt_crt(&ct); + /// assert_eq!(msg, dec); + /// ``` + pub fn decrypt_crt(&self, ctxt: &CrtCiphertext) -> u64 { + let mut val: Vec = Vec::with_capacity(ctxt.blocks.len()); + + // Decrypting each block individually + for (c_i, b_i) in ctxt.blocks.iter().zip(ctxt.moduli.iter()) { + // decrypt the component i of the integer and multiply it by the radix product + val.push(self.key.decrypt_message_and_carry(c_i) % b_i); + } + println!("VAL DEC = {val:?}"); + + // Computing the inverse CRT to recompose the message + let result = i_crt(&ctxt.moduli, &val); + + let whole_modulus: u64 = ctxt.moduli.iter().copied().product(); + + result % whole_modulus + } + + /// Encrypts a small integer message using the client key and some moduli without padding bit. + /// + /// # Example + /// + /// ```rust + /// use tfhe::integer::ClientKey; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_3_CARRY_3; + /// + /// let cks = ClientKey::new(PARAM_MESSAGE_3_CARRY_3); + /// + /// let msg = 13_u64; + /// + /// // Encryption of one message: + /// let basis: Vec = vec![2, 3, 5]; + /// let ct = cks.encrypt_native_crt(msg, basis); + /// + /// // Decryption: + /// let dec = cks.decrypt_native_crt(&ct); + /// assert_eq!(msg, dec); + /// ``` + pub fn encrypt_native_crt(&self, message: u64, base_vec: Vec) -> CrtCiphertext { + //Empty vector of ciphertexts + let mut ct_vec = Vec::with_capacity(base_vec.len()); + + //Put each decomposition into a new ciphertext + for modulus in base_vec.iter() { + // encryption + let ct = self.key.encrypt_native_crt(message, *modulus as u8); + + // put it in the vector of ciphertexts + ct_vec.push(ct); + } + + CrtCiphertext { + blocks: ct_vec, + moduli: base_vec, + } + } + + /// Decrypts a ciphertext encrypting an integer message with some moduli basis without + /// padding bit. + /// + /// # Example + /// + /// ```rust + /// use tfhe::integer::ClientKey; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_3_CARRY_3; + /// + /// let cks = ClientKey::new(PARAM_MESSAGE_3_CARRY_3); + /// + /// let msg = 27_u64; + /// let basis: Vec = vec![2, 3, 5]; + /// // Encryption of one message: + /// let mut ct = cks.encrypt_native_crt(msg, basis); + /// + /// // Decryption: + /// let dec = cks.decrypt_native_crt(&ct); + /// assert_eq!(msg, dec); + /// ``` + pub fn decrypt_native_crt(&self, ct: &CrtCiphertext) -> u64 { + let mut val: Vec = vec![]; + + //Decrypting each block individually + for (c_i, b_i) in ct.blocks.iter().zip(ct.moduli.iter()) { + //decrypt the component i of the integer and multiply it by the radix product + val.push(self.key.decrypt_message_native_crt(c_i, *b_i as u8)); + } + + //Computing the inverse CRT to recompose the message + let result = i_crt(&ct.moduli, &val); + + let whole_modulus: u64 = ct.moduli.iter().copied().product(); + + result % whole_modulus + } +} diff --git a/tfhe/src/integer/client_key/radix.rs b/tfhe/src/integer/client_key/radix.rs new file mode 100644 index 000000000..3a309c363 --- /dev/null +++ b/tfhe/src/integer/client_key/radix.rs @@ -0,0 +1,78 @@ +//! Definition of the client key for radix decomposition + +use super::ClientKey; +use crate::integer::RadixCiphertext; +use crate::shortint::{Ciphertext as ShortintCiphertext, Parameters as ShortintParameters}; + +use serde::{Deserialize, Serialize}; + +/// Client key "specialized" for radix decomposition. +/// +/// This key is a simple wrapper of the [ClientKey], +/// that only encrypt and decrypt in radix decomposition. +/// +/// # Example +/// +/// ```rust +/// use tfhe::integer::RadixClientKey; +/// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; +/// +/// // 2 * 4 = 8 bits of message +/// let num_block = 4; +/// let cks = RadixClientKey::new(PARAM_MESSAGE_2_CARRY_2, num_block); +/// +/// let msg = 167_u64; +/// +/// let ct = cks.encrypt(msg); +/// +/// // Decryption +/// let dec = cks.decrypt(&ct); +/// assert_eq!(msg, dec); +/// ``` +#[derive(Serialize, Deserialize, PartialEq, Debug, Clone)] +pub struct RadixClientKey { + key: ClientKey, + num_blocks: usize, +} + +impl AsRef for RadixClientKey { + fn as_ref(&self) -> &ClientKey { + &self.key + } +} + +impl RadixClientKey { + pub fn new(parameters: ShortintParameters, num_blocks: usize) -> Self { + Self { + key: ClientKey::new(parameters), + num_blocks, + } + } + + pub fn encrypt(&self, message: u64) -> RadixCiphertext { + self.key.encrypt_radix(message, self.num_blocks) + } + + pub fn decrypt(&self, ciphertext: &RadixCiphertext) -> u64 { + self.key.decrypt_radix(ciphertext) + } + + /// Returns the parameters used by the client key. + pub fn parameters(&self) -> ShortintParameters { + self.key.parameters() + } + + pub fn encrypt_one_block(&self, message: u64) -> ShortintCiphertext { + self.key.encrypt_one_block(message) + } + + pub fn decrypt_one_block(&self, ct: &ShortintCiphertext) -> u64 { + self.key.decrypt_one_block(ct) + } +} + +impl From<(ClientKey, usize)> for RadixClientKey { + fn from((key, num_blocks): (ClientKey, usize)) -> Self { + Self { key, num_blocks } + } +} diff --git a/tfhe/src/integer/client_key/utils.rs b/tfhe/src/integer/client_key/utils.rs new file mode 100644 index 000000000..305df8fe6 --- /dev/null +++ b/tfhe/src/integer/client_key/utils.rs @@ -0,0 +1,91 @@ +use serde::{Deserialize, Serialize}; + +#[derive(Debug, PartialEq, Eq, Copy, Clone, Serialize, Deserialize)] +pub struct RadixDecomposition { + pub msg_space: usize, + pub block_number: usize, +} + +/// Computes possible radix decompositions +/// +/// Takes the number of bit of the message space as input and output a vector containing all the +/// correct +/// possible block decomposition assuming the same message space for all blocks. +/// Lower and upper bounds define the minimal and maximal space to be considered +/// Example: 6,2,4 -> \[\[2,3\], \[3,2\]\] : \[msg_space = 2 bits, block_number = 3\] +/// +/// # Example +/// +/// ```rust +/// use tfhe::integer::client_key::radix_decomposition; +/// let input_space = 16; +/// let min = 2; +/// let max = 4; +/// let decomp = radix_decomposition(input_space, min, max); +/// +/// // Check that 3 possible radix decompositions are provided +/// assert_eq!(decomp.len(), 3); +/// ``` +pub fn radix_decomposition( + input_space: usize, + min_space: usize, + max_space: usize, +) -> Vec { + let mut out: Vec = vec![]; + let mut max = max_space; + if max_space > input_space { + max = input_space; + } + for msg_space in min_space..max + 1 { + let mut block_number = input_space / msg_space; + //Manual ceil of the division + if input_space % msg_space != 0 { + block_number += 1; + } + out.push(RadixDecomposition { + msg_space, + block_number, + }) + } + out +} + +// Tools to compute the inverse Chinese Remainder Theorem +pub(crate) fn extended_euclid(f: i64, g: i64) -> (usize, Vec, Vec, Vec, Vec) { + let mut s: Vec = vec![1, 0]; + let mut t: Vec = vec![0, 1]; + let mut r: Vec = vec![f, g]; + let mut q: Vec = vec![0]; + let mut i = 1; + while r[i] != 0 { + q.push(r[i - 1] / r[i]); //q[i] + r.push(r[i - 1] - q[i] * r[i]); //r[i+1] + s.push(s[i - 1] - q[i] * s[i]); //s[i+1] + t.push(t[i - 1] - q[i] * t[i]); //t[i+1] + i += 1; + } + let l: usize = i - 1; + (l, r, s, t, q) +} + +pub(crate) fn i_crt(modulus: &[u64], val: &[u64]) -> u64 { + let big_mod = modulus.iter().product::(); + let mut c: Vec = vec![0; val.len()]; + let mut out: u64 = 0; + + for i in 0..val.len() { + let tmp_mod = big_mod / modulus[i]; + let (l, _, s, _, _) = extended_euclid(tmp_mod as i64, modulus[i] as i64); + let sl: u64 = if s[l] < 0 { + //a is positive + (s[l] % modulus[i] as i64 + modulus[i] as i64) as u64 + } else { + s[l] as u64 + }; + c[i] = val[i].wrapping_mul(sl); + c[i] %= modulus[i]; + + out = out.wrapping_add(c[i] * tmp_mod); + } + out % big_mod +} diff --git a/tfhe/src/integer/keycache.rs b/tfhe/src/integer/keycache.rs new file mode 100644 index 000000000..f28334d20 --- /dev/null +++ b/tfhe/src/integer/keycache.rs @@ -0,0 +1,43 @@ +use crate::shortint::Parameters; +use lazy_static::lazy_static; + +use crate::integer::wopbs::WopbsKey; +use crate::integer::{ClientKey, ServerKey}; + +#[derive(Default)] +pub struct IntegerKeyCache; + +impl IntegerKeyCache { + pub fn get_from_params(&self, params: Parameters) -> (ClientKey, ServerKey) { + let keys = crate::shortint::keycache::KEY_CACHE.get_from_param(params); + let (client_key, server_key) = (keys.client_key(), keys.server_key()); + + let client_key = ClientKey::from(client_key.clone()); + let server_key = ServerKey::from_shortint(&client_key, server_key.clone()); + + (client_key, server_key) + } + + pub fn get_shortint_from_params( + &self, + params: Parameters, + ) -> (crate::shortint::ClientKey, crate::shortint::ServerKey) { + let keys = crate::shortint::keycache::KEY_CACHE.get_from_param(params); + (keys.client_key().clone(), keys.server_key().clone()) + } +} + +#[derive(Default)] +pub struct WopbsKeyCache; + +impl WopbsKeyCache { + pub fn get_from_params(&self, params: (Parameters, Parameters)) -> WopbsKey { + let shortint_wops_key = crate::shortint::keycache::KEY_CACHE_WOPBS.get_from_param(params); + WopbsKey::from(shortint_wops_key.wopbs_key().clone()) + } +} + +lazy_static! { + pub static ref KEY_CACHE: IntegerKeyCache = Default::default(); + pub static ref KEY_CACHE_WOPBS: WopbsKeyCache = Default::default(); +} diff --git a/tfhe/src/integer/mod.rs b/tfhe/src/integer/mod.rs new file mode 100755 index 000000000..23a98caa8 --- /dev/null +++ b/tfhe/src/integer/mod.rs @@ -0,0 +1,135 @@ +//! # Description +//! +//! This library makes it possible to execute modular operations over encrypted integer. +//! +//! It allows to execute an integer circuit on an untrusted server because both circuit inputs +//! outputs are kept private. +//! +//! Data are encrypted on the client side, before being sent to the server. +//! On the server side every computation is performed on ciphertexts +//! +//! # Quick Example +//! +//! The following piece of code shows how to generate keys and run a integer circuit +//! homomorphically. +//! +//! ```rust +//! use crate::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; +//! use tfhe::integer::gen_keys_radix; +//! +//! //4 blocks for the radix decomposition +//! let number_of_blocks = 4; +//! // Modulus = (2^2)*4 = 2^8 (from the parameters chosen and the number of blocks +//! let modulus = 1 << 8; +//! +//! // Generation of the client/server keys, using the default parameters: +//! let (mut client_key, mut server_key) = +//! gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, number_of_blocks); +//! +//! let msg1 = 153; +//! let msg2 = 125; +//! +//! // Encryption of two messages using the client key: +//! let ct_1 = client_key.encrypt(msg1); +//! let ct_2 = client_key.encrypt(msg2); +//! +//! // Homomorphic evaluation of an integer circuit (here, an addition) using the server key: +//! let ct_3 = server_key.unchecked_add(&ct_1, &ct_2); +//! +//! // Decryption of the ciphertext using the client key: +//! let output = client_key.decrypt(&ct_3); +//! assert_eq!(output, (msg1 + msg2) % modulus); +//! ``` +//! +//! # Warning +//! This uses cryptographic parameters from the `concrete-shortint` crates. +//! Currently, the radix approach is only compatible with parameter sets such +//! that the message and carry buffers have the same size. +extern crate core; + +#[cfg(test)] +#[macro_use] +mod tests; + +pub mod ciphertext; +pub mod client_key; +#[cfg(any(test, feature = "internal-keycache"))] +pub mod keycache; +pub mod parameters; +pub mod server_key; +pub mod wopbs; + +pub use ciphertext::{CrtCiphertext, IntegerCiphertext, RadixCiphertext}; +pub use client_key::{ClientKey, CrtClientKey, RadixClientKey}; +pub use server_key::{CheckError, ServerKey}; + +/// Generate a couple of client and server keys with given parameters +/// +/// * the client key is used to encrypt and decrypt and has to be kept secret; +/// * the server key is used to perform homomorphic operations on the server side and it is meant to +/// be published (the client sends it to the server). +/// +/// ```rust +/// use tfhe::integer::gen_keys; +/// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; +/// +/// // generate the client key and the server key: +/// let (cks, sks) = gen_keys(&PARAM_MESSAGE_2_CARRY_2); +/// ``` +pub fn gen_keys( + parameters_set: &crate::shortint::parameters::Parameters, +) -> (ClientKey, ServerKey) { + #[cfg(any(test, feature = "internal-keycache"))] + { + keycache::KEY_CACHE.get_from_params(*parameters_set) + } + #[cfg(all(not(test), not(feature = "internal-keycache")))] + { + let cks = ClientKey::new(*parameters_set); + let sks = ServerKey::new(&cks); + + (cks, sks) + } +} + +/// Generate a couple of client and server keys with given parameters +/// +/// Contrary to [gen_keys], this returns a [RadixClientKey] +/// +/// ```rust +/// use tfhe::integer::gen_keys_radix; +/// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; +/// +/// // generate the client key and the server key: +/// let num_blocks = 4; +/// let (cks, sks) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, num_blocks); +/// ``` +pub fn gen_keys_radix( + parameters_set: &crate::shortint::parameters::Parameters, + num_blocks: usize, +) -> (RadixClientKey, ServerKey) { + let (cks, sks) = gen_keys(parameters_set); + + (RadixClientKey::from((cks, num_blocks)), sks) +} + +/// Generate a couple of client and server keys with given parameters +/// +/// Contrary to [gen_keys], this returns a [CrtClientKey] +/// +/// ```rust +/// use tfhe::integer::gen_keys_crt; +/// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; +/// +/// // generate the client key and the server key: +/// let basis = vec![2, 3, 5]; +/// let (cks, sks) = gen_keys_crt(&PARAM_MESSAGE_2_CARRY_2, basis); +/// ``` +pub fn gen_keys_crt( + parameters_set: &crate::shortint::parameters::Parameters, + basis: Vec, +) -> (CrtClientKey, ServerKey) { + let (cks, sks) = gen_keys(parameters_set); + + (CrtClientKey::from((cks, basis)), sks) +} diff --git a/tfhe/src/integer/parameters/mod.rs b/tfhe/src/integer/parameters/mod.rs new file mode 100644 index 000000000..f1028548a --- /dev/null +++ b/tfhe/src/integer/parameters/mod.rs @@ -0,0 +1,125 @@ +#![allow(clippy::excessive_precision)] +pub use crate::shortint::Parameters; + +use crate::shortint::parameters::{CarryModulus, MessageModulus}; +pub use crate::shortint::parameters::{ + DecompositionBaseLog, DecompositionLevelCount, DispersionParameter, GlweDimension, + LweDimension, PolynomialSize, StandardDev, +}; + +pub const ALL_PARAMETER_VEC_INTEGER_16_BITS: [Parameters; 2] = [ + PARAM_MESSAGE_4_CARRY_4_16_BITS, + PARAM_MESSAGE_2_CARRY_2_16_BITS, +]; + +pub const PARAM_MESSAGE_4_CARRY_4_16_BITS: Parameters = Parameters { + lwe_dimension: LweDimension(481), + glwe_dimension: GlweDimension(1), + polynomial_size: PolynomialSize(2048), + lwe_modular_std_dev: StandardDev(0.00061200133780220371345), + glwe_modular_std_dev: StandardDev(0.00000000000000022148688116005568513645324585951), + pbs_base_log: DecompositionBaseLog(9), + pbs_level: DecompositionLevelCount(4), + ks_level: DecompositionLevelCount(9), + ks_base_log: DecompositionBaseLog(1), + pfks_level: DecompositionLevelCount(4), + pfks_base_log: DecompositionBaseLog(9), + pfks_modular_std_dev: StandardDev(0.00000000000000022148688116005568513645324585951), + cbs_level: DecompositionLevelCount(4), + cbs_base_log: DecompositionBaseLog(6), + message_modulus: MessageModulus(16), + carry_modulus: CarryModulus(16), +}; + +pub const PARAM_MESSAGE_2_CARRY_2_16_BITS: Parameters = Parameters { + lwe_dimension: LweDimension(493), + glwe_dimension: GlweDimension(1), + polynomial_size: PolynomialSize(2048), + lwe_modular_std_dev: StandardDev(0.00049144710341316649172), + glwe_modular_std_dev: StandardDev(0.00000000000000022148688116005568513645324585951), + pbs_base_log: DecompositionBaseLog(16), + pbs_level: DecompositionLevelCount(2), + ks_level: DecompositionLevelCount(5), + ks_base_log: DecompositionBaseLog(2), + pfks_level: DecompositionLevelCount(2), + pfks_base_log: DecompositionBaseLog(16), + pfks_modular_std_dev: StandardDev(0.00000000000000022148688116005568513645324585951), + cbs_level: DecompositionLevelCount(6), + cbs_base_log: DecompositionBaseLog(3), + message_modulus: MessageModulus(4), + carry_modulus: CarryModulus(4), +}; + +pub const PARAM_MESSAGE_4_CARRY_4_32_BITS: Parameters = Parameters { + lwe_dimension: LweDimension(481), + glwe_dimension: GlweDimension(1), + polynomial_size: PolynomialSize(2048), + lwe_modular_std_dev: StandardDev(0.00061200133780220371345), + glwe_modular_std_dev: StandardDev(0.00000000000000022148688116005568513645324585951), + pbs_base_log: DecompositionBaseLog(9), + pbs_level: DecompositionLevelCount(4), + ks_level: DecompositionLevelCount(9), + ks_base_log: DecompositionBaseLog(1), + pfks_level: DecompositionLevelCount(4), + pfks_base_log: DecompositionBaseLog(9), + pfks_modular_std_dev: StandardDev(0.00000000000000022148688116005568513645324585951), + cbs_level: DecompositionLevelCount(4), + cbs_base_log: DecompositionBaseLog(6), + message_modulus: MessageModulus(16), + carry_modulus: CarryModulus(16), +}; +pub const PARAM_MESSAGE_2_CARRY_2_32_BITS: Parameters = Parameters { + lwe_dimension: LweDimension(481), + glwe_dimension: GlweDimension(1), + polynomial_size: PolynomialSize(2048), + lwe_modular_std_dev: StandardDev(0.00061200133780220371345), + glwe_modular_std_dev: StandardDev(0.00000000000000022148688116005568513645324585951), + pbs_base_log: DecompositionBaseLog(11), + pbs_level: DecompositionLevelCount(3), + ks_level: DecompositionLevelCount(9), + ks_base_log: DecompositionBaseLog(1), + pfks_level: DecompositionLevelCount(3), + pfks_base_log: DecompositionBaseLog(11), + pfks_modular_std_dev: StandardDev(0.00000000000000022148688116005568513645324585951), + cbs_level: DecompositionLevelCount(6), + cbs_base_log: DecompositionBaseLog(3), + message_modulus: MessageModulus(4), + carry_modulus: CarryModulus(4), +}; +pub const PARAM_MESSAGE_1_CARRY_1_32_BITS: Parameters = Parameters { + lwe_dimension: LweDimension(493), + glwe_dimension: GlweDimension(1), + polynomial_size: PolynomialSize(2048), + lwe_modular_std_dev: StandardDev(0.00049144710341316649172), + glwe_modular_std_dev: StandardDev(0.00000000000000022148688116005568513645324585951), + pbs_base_log: DecompositionBaseLog(15), + pbs_level: DecompositionLevelCount(2), + ks_level: DecompositionLevelCount(5), + ks_base_log: DecompositionBaseLog(2), + pfks_level: DecompositionLevelCount(2), + pfks_base_log: DecompositionBaseLog(15), + pfks_modular_std_dev: StandardDev(0.00000000000000022148688116005568513645324585951), + cbs_level: DecompositionLevelCount(5), + cbs_base_log: DecompositionBaseLog(3), + message_modulus: MessageModulus(2), + carry_modulus: CarryModulus(2), +}; + +pub const PARAM_4_BITS_5_BLOCKS: Parameters = Parameters { + lwe_dimension: LweDimension(667), + glwe_dimension: GlweDimension(2), + polynomial_size: PolynomialSize(1024), + lwe_modular_std_dev: StandardDev(0.0000000004168323308734758), + glwe_modular_std_dev: StandardDev(0.00000000000000000000000000000004905643852600863), + pbs_base_log: DecompositionBaseLog(7), + pbs_level: DecompositionLevelCount(6), + ks_base_log: DecompositionBaseLog(1), + ks_level: DecompositionLevelCount(14), + pfks_level: DecompositionLevelCount(6), + pfks_base_log: DecompositionBaseLog(7), + pfks_modular_std_dev: StandardDev(0.00000000000000000000000000000004905643852600863), + cbs_level: DecompositionLevelCount(7), + cbs_base_log: DecompositionBaseLog(4), + message_modulus: MessageModulus(16), + carry_modulus: CarryModulus(1), +}; diff --git a/tfhe/src/integer/server_key/crt/add_crt.rs b/tfhe/src/integer/server_key/crt/add_crt.rs new file mode 100644 index 000000000..747e0eed7 --- /dev/null +++ b/tfhe/src/integer/server_key/crt/add_crt.rs @@ -0,0 +1,73 @@ +use crate::integer::{CrtCiphertext, ServerKey}; + +impl ServerKey { + /// Computes homomorphically an addition between two ciphertexts encrypting integer values. + /// + /// # Example + /// + ///```rust + /// use tfhe::integer::gen_keys; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // Generate the client key and the server key: + /// let (cks, sks) = gen_keys(&PARAM_MESSAGE_2_CARRY_2); + /// + /// let clear_1 = 14; + /// let clear_2 = 14; + /// let basis = vec![2, 3, 5]; + /// // Encrypt two messages + /// let mut ctxt_1 = cks.encrypt_crt(clear_1, basis.clone()); + /// let mut ctxt_2 = cks.encrypt_crt(clear_2, basis); + /// + /// sks.smart_crt_add_assign(&mut ctxt_1, &mut ctxt_2); + /// + /// // Decrypt + /// let res = cks.decrypt_crt(&ctxt_1); + /// assert_eq!((clear_1 + clear_2) % 30, res); + /// ``` + pub fn smart_crt_add( + &self, + ct_left: &mut CrtCiphertext, + ct_right: &mut CrtCiphertext, + ) -> CrtCiphertext { + if !self.is_crt_add_possible(ct_left, ct_right) { + self.full_extract_message_assign(ct_left); + self.full_extract_message_assign(ct_right); + } + self.unchecked_crt_add(ct_left, ct_right) + } + + pub fn smart_crt_add_assign(&self, ct_left: &mut CrtCiphertext, ct_right: &mut CrtCiphertext) { + //If the ciphertext cannot be added together without exceeding the capacity of a ciphertext + if !self.is_crt_add_possible(ct_left, ct_right) { + self.full_extract_message_assign(ct_left); + self.full_extract_message_assign(ct_right); + } + self.unchecked_crt_add_assign(ct_left, ct_right); + } + + pub fn is_crt_add_possible(&self, ct_left: &CrtCiphertext, ct_right: &CrtCiphertext) -> bool { + for (ct_left_i, ct_right_i) in ct_left.blocks.iter().zip(ct_right.blocks.iter()) { + if !self.key.is_add_possible(ct_left_i, ct_right_i) { + return false; + } + } + true + } + + pub fn unchecked_crt_add_assign(&self, ct_left: &mut CrtCiphertext, ct_right: &CrtCiphertext) { + for (ct_left_i, ct_right_i) in ct_left.blocks.iter_mut().zip(ct_right.blocks.iter()) { + self.key.unchecked_add_assign(ct_left_i, ct_right_i); + } + } + + pub fn unchecked_crt_add( + &self, + ct_left: &CrtCiphertext, + ct_right: &CrtCiphertext, + ) -> CrtCiphertext { + let mut ct_res = ct_left.clone(); + self.unchecked_crt_add_assign(&mut ct_res, ct_right); + ct_res + } +} diff --git a/tfhe/src/integer/server_key/crt/mod.rs b/tfhe/src/integer/server_key/crt/mod.rs new file mode 100644 index 000000000..3f10f49bb --- /dev/null +++ b/tfhe/src/integer/server_key/crt/mod.rs @@ -0,0 +1,102 @@ +use crate::integer::ciphertext::CrtCiphertext; +use crate::integer::ServerKey; + +#[cfg(test)] +mod tests; + +mod add_crt; +mod mul_crt; +mod neg_crt; +mod scalar_add_crt; +mod scalar_mul_crt; +mod scalar_sub_crt; +mod sub_crt; + +impl ServerKey { + /// Extract all the messages. + /// + /// # Example + /// + ///```rust + /// use tfhe::integer::gen_keys; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // Generate the client key and the server key: + /// let (cks, sks) = gen_keys(&PARAM_MESSAGE_2_CARRY_2); + /// + /// let clear_1 = 14; + /// let clear_2 = 14; + /// let basis = vec![2, 3, 5]; + /// // Encrypt two messages + /// let mut ctxt_1 = cks.encrypt_crt(clear_1, basis.clone()); + /// let ctxt_2 = cks.encrypt_crt(clear_2, basis); + /// + /// // Compute homomorphically a multiplication + /// sks.unchecked_crt_add_assign(&mut ctxt_1, &ctxt_2); + /// + /// sks.full_extract_message_assign(&mut ctxt_1); + /// + /// // Decrypt + /// let res = cks.decrypt_crt(&ctxt_1); + /// assert_eq!((clear_1 + clear_2) % 30, res); + /// ``` + pub fn full_extract_message_assign(&self, ctxt: &mut CrtCiphertext) { + for ct_i in ctxt.blocks.iter_mut() { + self.key.message_extract_assign(ct_i); + } + } + + /// Computes a PBS for CRT-compliant functions. + /// + /// # Warning + /// This allows to compute programmable bootstrapping over integers under the condition that + /// the function is said to be CRT-compliant. This means that the function should be correct + /// when evaluated on each modular block independently (e.g. arithmetic functions). + /// + /// # Example + /// + /// ```rust + /// use tfhe::integer::gen_keys_crt; + /// use tfhe::shortint::parameters::DEFAULT_PARAMETERS; + /// + /// // Generate the client key and the server key: + /// let basis = vec![2, 3, 5]; + /// let (cks, sks) = gen_keys_crt(&DEFAULT_PARAMETERS, basis); + /// + /// let clear_1 = 28; + /// + /// let mut ctxt_1 = cks.encrypt(clear_1); + /// + /// // Compute homomorphically the crt-compliant PBS + /// sks.pbs_crt_compliant_function_assign(&mut ctxt_1, |x| x * x * x); + /// + /// // Decrypt + /// let res = cks.decrypt(&ctxt_1); + /// assert_eq!((clear_1 * clear_1 * clear_1) % 30, res); + /// ``` + pub fn pbs_crt_compliant_function_assign(&self, ct1: &mut CrtCiphertext, f: F) + where + F: Fn(u64) -> u64, + { + let basis = &ct1.moduli; + + let accumulators = basis + .iter() + .copied() + .map(|b| self.key.generate_accumulator(|x| f(x) % b)); + + for (block, acc) in ct1.blocks.iter_mut().zip(accumulators) { + self.key + .keyswitch_programmable_bootstrap_assign(block, &acc); + } + } + + pub fn pbs_crt_compliant_function(&self, ct1: &CrtCiphertext, f: F) -> CrtCiphertext + where + F: Fn(u64) -> u64, + { + let mut ct_res = ct1.clone(); + self.pbs_crt_compliant_function_assign(&mut ct_res, f); + ct_res + } +} diff --git a/tfhe/src/integer/server_key/crt/mul_crt.rs b/tfhe/src/integer/server_key/crt/mul_crt.rs new file mode 100644 index 000000000..237aa6c87 --- /dev/null +++ b/tfhe/src/integer/server_key/crt/mul_crt.rs @@ -0,0 +1,88 @@ +use crate::integer::{CrtCiphertext, ServerKey}; + +impl ServerKey { + /// Computes homomorphically a multiplication between two ciphertexts encrypting integer + /// values in the CRT decomposition. + /// + /// # Example + /// + /// ```rust + /// use tfhe::integer::gen_keys; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_3_CARRY_3; + /// + /// // Generate the client key and the server key: + /// let (cks, sks) = gen_keys(&PARAM_MESSAGE_3_CARRY_3); + /// + /// let clear_1 = 29; + /// let clear_2 = 23; + /// let basis = vec![2, 3, 5]; + /// // Encrypt two messages + /// let mut ctxt_1 = cks.encrypt_crt(clear_1, basis.clone()); + /// let ctxt_2 = cks.encrypt_crt(clear_2, basis); + /// + /// // Compute homomorphically a multiplication + /// sks.unchecked_crt_mul_assign(&mut ctxt_1, &ctxt_2); + /// // Decrypt + /// let res = cks.decrypt_crt(&ctxt_1); + /// assert_eq!((clear_1 * clear_2) % 30, res); + /// ``` + pub fn unchecked_crt_mul_assign(&self, ct_left: &mut CrtCiphertext, ct_right: &CrtCiphertext) { + for (ct_left, ct_right) in ct_left.blocks.iter_mut().zip(ct_right.blocks.iter()) { + self.key.unchecked_mul_lsb_assign(ct_left, ct_right); + } + } + + pub fn unchecked_crt_mul( + &self, + ct_left: &CrtCiphertext, + ct_right: &CrtCiphertext, + ) -> CrtCiphertext { + let mut ct_res = ct_left.clone(); + self.unchecked_crt_mul_assign(&mut ct_res, ct_right); + ct_res + } + + /// Computes homomorphically a multiplication between two ciphertexts encrypting integer + /// values in the CRT decomposition. + /// + /// This checks that the addition is possible. In the case where the carry buffers are full, + /// then it is automatically cleared to allow the operation. + /// + /// # Example + /// + /// ```rust + /// use tfhe::integer::gen_keys; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_3_CARRY_3; + /// + /// let (cks, sks) = gen_keys(&PARAM_MESSAGE_3_CARRY_3); + /// + /// let clear_1 = 29; + /// let clear_2 = 29; + /// let basis = vec![2, 3, 5]; + /// // Encrypt two messages + /// let mut ctxt_1 = cks.encrypt_crt(clear_1, basis.clone()); + /// let mut ctxt_2 = cks.encrypt_crt(clear_2, basis); + /// + /// // Compute homomorphically a multiplication + /// sks.smart_crt_mul_assign(&mut ctxt_1, &mut ctxt_2); + /// + /// // Decrypt + /// let res = cks.decrypt_crt(&ctxt_1); + /// assert_eq!((clear_1 * clear_2) % 30, res); + /// ``` + pub fn smart_crt_mul_assign(&self, ct_left: &mut CrtCiphertext, ct_right: &mut CrtCiphertext) { + for (block_left, block_right) in ct_left.blocks.iter_mut().zip(ct_right.blocks.iter_mut()) { + self.key.smart_mul_lsb_assign(block_left, block_right); + } + } + + pub fn smart_crt_mul( + &self, + ct_left: &CrtCiphertext, + ct_right: &mut CrtCiphertext, + ) -> CrtCiphertext { + let mut ct_res = ct_left.clone(); + self.smart_crt_mul_assign(&mut ct_res, ct_right); + ct_res + } +} diff --git a/tfhe/src/integer/server_key/crt/neg_crt.rs b/tfhe/src/integer/server_key/crt/neg_crt.rs new file mode 100644 index 000000000..cae641717 --- /dev/null +++ b/tfhe/src/integer/server_key/crt/neg_crt.rs @@ -0,0 +1,94 @@ +use crate::integer::{CrtCiphertext, ServerKey}; + +impl ServerKey { + /// Homomorphically computes the opposite of a ciphertext encrypting an integer message. + /// + /// This function computes the opposite of a message without checking if it exceeds the + /// capacity of the ciphertext. + /// + /// The result is returned as a new ciphertext. + /// # Example + /// + ///```rust + /// use tfhe::integer::gen_keys; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // Generate the client key and the server key: + /// let (cks, sks) = gen_keys(&PARAM_MESSAGE_2_CARRY_2); + /// + /// let clear = 14_u64; + /// let basis = vec![2, 3, 5]; + /// + /// let mut ctxt = cks.encrypt_crt(clear, basis.clone()); + /// + /// sks.unchecked_crt_neg_assign(&mut ctxt); + /// + /// // Decrypt + /// let res = cks.decrypt_crt(&ctxt); + /// assert_eq!(16, res); + /// ``` + pub fn unchecked_crt_neg(&self, ctxt: &CrtCiphertext) -> CrtCiphertext { + let mut result = ctxt.clone(); + + self.unchecked_crt_neg_assign(&mut result); + + result + } + + /// Homomorphically computes the opposite of a ciphertext encrypting an integer message. + /// + /// This function computes the opposite of a message without checking if it exceeds the + /// capacity of the ciphertext. + /// + /// The result is assigned to the `ct_left` ciphertext. + pub fn unchecked_crt_neg_assign(&self, ctxt: &mut CrtCiphertext) { + for ct_i in ctxt.blocks.iter_mut() { + self.key.unchecked_neg_assign(ct_i); + } + } + + /// Homomorphically computes the opposite of a ciphertext encrypting an integer message. + /// + /// # Example + /// + ///```rust + /// use tfhe::integer::gen_keys; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // Generate the client key and the server key: + /// let (cks, sks) = gen_keys(&PARAM_MESSAGE_2_CARRY_2); + /// + /// let clear = 14_u64; + /// let basis = vec![2, 3, 5]; + /// + /// let mut ctxt = cks.encrypt_crt(clear, basis.clone()); + /// + /// sks.smart_crt_neg_assign(&mut ctxt); + /// + /// // Decrypt + /// let res = cks.decrypt_crt(&ctxt); + /// assert_eq!(16, res); + /// ``` + pub fn smart_crt_neg_assign(&self, ctxt: &mut CrtCiphertext) { + if !self.is_crt_neg_possible(ctxt) { + self.full_extract_message_assign(ctxt); + } + self.unchecked_crt_neg_assign(ctxt); + } + + pub fn smart_crt_neg(&self, ctxt: &mut CrtCiphertext) -> CrtCiphertext { + if !self.is_crt_neg_possible(ctxt) { + self.full_extract_message_assign(ctxt); + } + self.unchecked_crt_neg(ctxt) + } + + pub fn is_crt_neg_possible(&self, ctxt: &CrtCiphertext) -> bool { + for ct_i in ctxt.blocks.iter() { + if !self.key.is_neg_possible(ct_i) { + return false; + } + } + true + } +} diff --git a/tfhe/src/integer/server_key/crt/scalar_add_crt.rs b/tfhe/src/integer/server_key/crt/scalar_add_crt.rs new file mode 100644 index 000000000..11799928b --- /dev/null +++ b/tfhe/src/integer/server_key/crt/scalar_add_crt.rs @@ -0,0 +1,212 @@ +use crate::integer::server_key::CheckError; +use crate::integer::server_key::CheckError::CarryFull; +use crate::integer::{CrtCiphertext, ServerKey}; + +impl ServerKey { + /// Computes homomorphically an addition between a scalar and a ciphertext. + /// + /// This function computes the operation without checking if it exceeds the capacity of the + /// ciphertext. + /// + /// The result is returned as a new ciphertext. + /// + /// # Example + /// + ///```rust + /// use tfhe::integer::gen_keys; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // Generate the client key and the server key: + /// let (cks, sks) = gen_keys(&PARAM_MESSAGE_2_CARRY_2); + /// + /// let clear_1 = 14; + /// let clear_2 = 14; + /// let basis = vec![2, 3, 5]; + /// // Encrypt two messages + /// let mut ctxt_1 = cks.encrypt_crt(clear_1, basis.clone()); + /// + /// sks.unchecked_crt_scalar_add_assign(&mut ctxt_1, clear_2); + /// + /// // Decrypt + /// let res = cks.decrypt_crt(&ctxt_1); + /// assert_eq!((clear_1 + clear_2) % 30, res); + /// ``` + pub fn unchecked_crt_scalar_add(&self, ct: &CrtCiphertext, scalar: u64) -> CrtCiphertext { + let mut result = ct.clone(); + self.unchecked_crt_scalar_add_assign(&mut result, scalar); + result + } + + /// Computes homomorphically an addition between a scalar and a ciphertext. + /// + /// This function computes the operation without checking if it exceeds the capacity of the + /// ciphertext. + /// + /// The result is assigned to the `ct_left` ciphertext. + pub fn unchecked_crt_scalar_add_assign(&self, ct: &mut CrtCiphertext, scalar: u64) { + //Add the crt representation of the scalar to the ciphertext + for (ct_i, mod_i) in ct.blocks.iter_mut().zip(ct.moduli.iter()) { + let scalar_i = scalar % mod_i; + + self.key.unchecked_scalar_add_assign(ct_i, scalar_i as u8); + } + } + + /// Verifies if a scalar can be added to a ciphertext. + /// + /// # Example + /// + ///```rust + /// use tfhe::integer::gen_keys; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // Generate the client key and the server key: + /// let (cks, sks) = gen_keys(&PARAM_MESSAGE_2_CARRY_2); + /// + /// let clear_1 = 14; + /// let clear_2 = 14; + /// let basis = vec![2, 3, 5]; + /// // Encrypt two messages + /// let mut ctxt_1 = cks.encrypt_crt(clear_1, basis.clone()); + /// + /// let tmp = sks.is_crt_scalar_add_possible(&mut ctxt_1, clear_2); + /// + /// assert_eq!(true, tmp); + /// ``` + pub fn is_crt_scalar_add_possible(&self, ct: &CrtCiphertext, scalar: u64) -> bool { + for (ct_i, mod_i) in ct.blocks.iter().zip(ct.moduli.iter()) { + let scalar_i = scalar % mod_i; + + if !self.key.is_scalar_add_possible(ct_i, scalar_i as u8) { + return false; + } + } + + true + } + + /// Computes homomorphically an addition between a scalar and a ciphertext. + /// + /// If the operation can be performed, the result is returned in a new ciphertext. + /// Otherwise [CheckError::CarryFull] is returned. + /// + /// # Example + /// + /// ```rust + /// # fn main() -> Result<(), Box> { + /// use tfhe::integer::gen_keys; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // Generate the client key and the server key: + /// let (cks, sks) = gen_keys(&PARAM_MESSAGE_2_CARRY_2); + /// + /// let clear_1 = 14; + /// let clear_2 = 14; + /// let basis = vec![2, 3, 5]; + /// // Encrypt two messages + /// let mut ctxt_1 = cks.encrypt_crt(clear_1, basis.clone()); + /// + /// sks.checked_crt_scalar_add_assign(&mut ctxt_1, clear_2); + /// + /// // Decrypt + /// let res = cks.decrypt_crt(&ctxt_1); + /// assert_eq!((clear_1 + clear_2) % 30, res); + /// # Ok(()) + /// # } + /// ``` + pub fn checked_crt_scalar_add( + &self, + ct: &CrtCiphertext, + scalar: u64, + ) -> Result { + if self.is_crt_scalar_add_possible(ct, scalar) { + Ok(self.unchecked_crt_scalar_add(ct, scalar)) + } else { + Err(CarryFull) + } + } + + /// Computes homomorphically an addition between a scalar and a ciphertext. + /// + /// If the operation can be performed, the result is stored in the `ct_left` ciphertext. + /// Otherwise [CheckError::CarryFull] is returned, and `ct_left` is not modified. + pub fn checked_crt_scalar_add_assign( + &self, + ct: &mut CrtCiphertext, + scalar: u64, + ) -> Result<(), CheckError> { + if self.is_crt_scalar_add_possible(ct, scalar) { + self.unchecked_crt_scalar_add_assign(ct, scalar); + Ok(()) + } else { + Err(CarryFull) + } + } + + /// Computes homomorphically the addition of ciphertext with a scalar. + /// + /// The result is returned in a new ciphertext. + /// + /// # Example + /// + ///```rust + /// use tfhe::integer::gen_keys; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // Generate the client key and the server key: + /// let (cks, sks) = gen_keys(&PARAM_MESSAGE_2_CARRY_2); + /// + /// let clear_1 = 14; + /// let clear_2 = 14; + /// let basis = vec![2, 3, 5]; + /// // Encrypt two messages + /// let mut ctxt_1 = cks.encrypt_crt(clear_1, basis.clone()); + /// + /// let ctxt = sks.smart_crt_scalar_add(&mut ctxt_1, clear_2); + /// + /// // Decrypt + /// let res = cks.decrypt_crt(&ctxt); + /// assert_eq!((clear_1 + clear_2) % 30, res); + /// ``` + pub fn smart_crt_scalar_add(&self, ct: &mut CrtCiphertext, scalar: u64) -> CrtCiphertext { + if !self.is_crt_scalar_add_possible(ct, scalar) { + self.full_extract_message_assign(ct); + } + + let mut ct = ct.clone(); + self.unchecked_crt_scalar_add_assign(&mut ct, scalar); + ct + } + + /// Computes homomorphically the addition of ciphertext with a scalar. + /// + /// The result is assigned to the `ct_left` ciphertext. + /// + /// # Example + /// + ///```rust + /// use tfhe::integer::gen_keys; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // Generate the client key and the server key: + /// let (cks, sks) = gen_keys(&PARAM_MESSAGE_2_CARRY_2); + /// + /// let clear_1 = 14; + /// let clear_2 = 14; + /// let basis = vec![2, 3, 5]; + /// // Encrypt two messages + /// let mut ctxt_1 = cks.encrypt_crt(clear_1, basis.clone()); + /// + /// sks.smart_crt_scalar_add_assign(&mut ctxt_1, clear_2); + /// + /// // Decrypt + /// let res = cks.decrypt_crt(&ctxt_1); + /// assert_eq!((clear_1 + clear_2) % 30, res); + /// ``` + pub fn smart_crt_scalar_add_assign(&self, ct: &mut CrtCiphertext, scalar: u64) { + if !self.is_crt_scalar_add_possible(ct, scalar) { + self.full_extract_message_assign(ct); + } + self.unchecked_crt_scalar_add_assign(ct, scalar); + } +} diff --git a/tfhe/src/integer/server_key/crt/scalar_mul_crt.rs b/tfhe/src/integer/server_key/crt/scalar_mul_crt.rs new file mode 100644 index 000000000..d062785a6 --- /dev/null +++ b/tfhe/src/integer/server_key/crt/scalar_mul_crt.rs @@ -0,0 +1,226 @@ +use crate::integer::server_key::CheckError; +use crate::integer::server_key::CheckError::CarryFull; +use crate::integer::{CrtCiphertext, ServerKey}; + +impl ServerKey { + /// Computes homomorphically a multiplication between a scalar and a ciphertext. + /// + /// This function computes the operation without checking if it exceeds the capacity of the + /// ciphertext. + /// + /// The result is returned as a new ciphertext. + /// + /// # Example + /// + ///```rust + /// use tfhe::integer::gen_keys; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // Generate the client key and the server key: + /// let (cks, sks) = gen_keys(&PARAM_MESSAGE_2_CARRY_2); + /// + /// let clear_1 = 14; + /// let clear_2 = 2; + /// let basis = vec![2, 3, 5]; + /// // Encrypt two messages + /// let mut ctxt_1 = cks.encrypt_crt(clear_1, basis.clone()); + /// + /// sks.unchecked_crt_scalar_mul_assign(&mut ctxt_1, clear_2); + /// + /// // Decrypt + /// let res = cks.decrypt_crt(&ctxt_1); + /// assert_eq!((clear_1 * clear_2) % 30, res); + /// ``` + pub fn unchecked_crt_scalar_mul(&self, ctxt: &CrtCiphertext, scalar: u64) -> CrtCiphertext { + let mut ct_result = ctxt.clone(); + self.unchecked_crt_scalar_mul_assign(&mut ct_result, scalar); + + ct_result + } + + pub fn unchecked_crt_scalar_mul_assign(&self, ctxt: &mut CrtCiphertext, scalar: u64) { + for (ct_i, mod_i) in ctxt.blocks.iter_mut().zip(ctxt.moduli.iter()) { + self.key + .unchecked_scalar_mul_assign(ct_i, (scalar % mod_i) as u8); + } + } + + ///Verifies if ct1 can be multiplied by scalar. + /// + /// # Example + /// + ///```rust + /// use tfhe::integer::gen_keys; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // Generate the client key and the server key: + /// let (cks, sks) = gen_keys(&PARAM_MESSAGE_2_CARRY_2); + /// + /// let clear_1 = 14; + /// let clear_2 = 2; + /// let basis = vec![2, 3, 5]; + /// // Encrypt two messages + /// let mut ctxt_1 = cks.encrypt_crt(clear_1, basis.clone()); + /// + /// let tmp = sks.is_crt_scalar_mul_possible(&mut ctxt_1, clear_2); + /// + /// assert_eq!(true, tmp); + /// ``` + pub fn is_crt_scalar_mul_possible(&self, ctxt: &CrtCiphertext, scalar: u64) -> bool { + for (ct_i, mod_i) in ctxt.blocks.iter().zip(ctxt.moduli.iter()) { + if !self + .key + .is_scalar_mul_possible(ct_i, (scalar % mod_i) as u8) + { + return false; + } + } + true + } + + /// Computes homomorphically a multiplication between a scalar and a ciphertext. + /// + /// If the operation can be performed, the result is returned in a new ciphertext. + /// Otherwise [CheckError::CarryFull] is returned. + /// + /// # Example + /// + /// ```rust + /// # fn main() -> Result<(), Box> { + /// use tfhe::integer::gen_keys; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // Generate the client key and the server key: + /// let (cks, sks) = gen_keys(&PARAM_MESSAGE_2_CARRY_2); + /// + /// let clear_1 = 14; + /// let clear_2 = 2; + /// let basis = vec![2, 3, 5]; + /// // Encrypt two messages + /// let mut ctxt_1 = cks.encrypt_crt(clear_1, basis.clone()); + /// + /// sks.checked_crt_scalar_mul_assign(&mut ctxt_1, clear_2); + /// + /// // Decrypt + /// let res = cks.decrypt_crt(&ctxt_1); + /// assert_eq!((clear_1 * clear_2) % 30, res); + /// # Ok(()) + /// # } + /// ``` + pub fn checked_crt_scalar_mul( + &self, + ct: &CrtCiphertext, + scalar: u64, + ) -> Result { + let mut ct_result = ct.clone(); + + // If the ciphertext cannot be multiplied without exceeding the capacity of a ciphertext + if self.is_crt_scalar_mul_possible(ct, scalar) { + ct_result = self.unchecked_crt_scalar_mul(&ct_result, scalar); + + Ok(ct_result) + } else { + Err(CarryFull) + } + } + + /// Computes homomorphically a multiplication between a scalar and a ciphertext. + /// + /// If the operation can be performed, the result is assigned to the ciphertext given + /// as parameter. + /// Otherwise [CheckError::CarryFull] is returned. + pub fn checked_crt_scalar_mul_assign( + &self, + ct: &mut CrtCiphertext, + scalar: u64, + ) -> Result<(), CheckError> { + // If the ciphertext cannot be multiplied without exceeding the capacity of a ciphertext + if self.is_crt_scalar_mul_possible(ct, scalar) { + self.unchecked_crt_scalar_mul_assign(ct, scalar); + Ok(()) + } else { + Err(CarryFull) + } + } + + /// Computes homomorphically a multiplication between a scalar and a ciphertext. + /// + /// `small` means the scalar value shall fit in a __shortint block__. + /// For example, if the parameters are PARAM_MESSAGE_2_CARRY_2, + /// the scalar should fit in 2 bits. + /// + /// The result is returned as a new ciphertext. + /// + /// # Example + /// + ///```rust + /// use tfhe::integer::gen_keys; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // Generate the client key and the server key: + /// let (cks, sks) = gen_keys(&PARAM_MESSAGE_2_CARRY_2); + /// + /// let clear_1 = 14; + /// let clear_2 = 14; + /// let basis = vec![2, 3, 5]; + /// // Encrypt two messages + /// let mut ctxt_1 = cks.encrypt_crt(clear_1, basis.clone()); + /// + /// let ctxt = sks.smart_crt_scalar_mul(&mut ctxt_1, clear_2); + /// + /// // Decrypt + /// let res = cks.decrypt_crt(&ctxt); + /// assert_eq!((clear_1 * clear_2) % 30, res); + /// ``` + pub fn smart_crt_scalar_mul(&self, ctxt: &mut CrtCiphertext, scalar: u64) -> CrtCiphertext { + if !self.is_crt_scalar_mul_possible(ctxt, scalar) { + self.full_extract_message_assign(ctxt); + } + self.unchecked_crt_scalar_mul(ctxt, scalar) + } + + /// Computes homomorphically a multiplication between a scalar and a ciphertext. + /// + /// `small` means the scalar shall value fit in a __shortint block__. + /// For example, if the parameters are PARAM_MESSAGE_2_CARRY_2, + /// the scalar should fit in 2 bits. + /// + /// The result is assigned to the input ciphertext + /// + /// # Example + /// + ///```rust + /// use tfhe::integer::gen_keys; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // Generate the client key and the server key: + /// let (cks, sks) = gen_keys(&PARAM_MESSAGE_2_CARRY_2); + /// + /// let clear_1 = 14; + /// let clear_2 = 14; + /// let basis = vec![2, 3, 5]; + /// // Encrypt two messages + /// let mut ctxt_1 = cks.encrypt_crt(clear_1, basis.clone()); + /// + /// sks.smart_crt_scalar_mul_assign(&mut ctxt_1, clear_2); + /// + /// // Decrypt + /// let res = cks.decrypt_crt(&ctxt_1); + /// assert_eq!((clear_1 * clear_2) % 30, res); + /// ``` + pub fn smart_crt_scalar_mul_assign(&self, ctxt: &mut CrtCiphertext, scalar: u64) { + if !self.is_crt_small_scalar_mul_possible(ctxt, scalar) { + self.full_extract_message_assign(ctxt); + } + self.unchecked_crt_scalar_mul_assign(ctxt, scalar); + } + + pub fn is_crt_small_scalar_mul_possible(&self, ctxt: &CrtCiphertext, scalar: u64) -> bool { + for ct_i in ctxt.blocks.iter() { + if !self.key.is_scalar_mul_possible(ct_i, scalar as u8) { + return false; + } + } + true + } +} diff --git a/tfhe/src/integer/server_key/crt/scalar_sub_crt.rs b/tfhe/src/integer/server_key/crt/scalar_sub_crt.rs new file mode 100644 index 000000000..edca2cecf --- /dev/null +++ b/tfhe/src/integer/server_key/crt/scalar_sub_crt.rs @@ -0,0 +1,204 @@ +use crate::integer::server_key::CheckError; +use crate::integer::server_key::CheckError::CarryFull; +use crate::integer::{CrtCiphertext, ServerKey}; + +impl ServerKey { + /// Computes homomorphically a subtraction between a ciphertext and a scalar. + /// + /// This function computes the operation without checking if it exceeds the capacity of the + /// ciphertext. + /// + /// The result is returned as a new ciphertext. + /// + /// # Example + /// + ///```rust + /// use tfhe::integer::gen_keys; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // Generate the client key and the server key: + /// let (cks, sks) = gen_keys(&PARAM_MESSAGE_2_CARRY_2); + /// + /// let clear_1 = 14; + /// let clear_2 = 7; + /// let basis = vec![2, 3, 5]; + /// // Encrypt two messages + /// let mut ctxt_1 = cks.encrypt_crt(clear_1, basis.clone()); + /// + /// sks.unchecked_crt_scalar_sub_assign(&mut ctxt_1, clear_2); + /// + /// // Decrypt + /// let res = cks.decrypt_crt(&ctxt_1); + /// assert_eq!((clear_1 - clear_2) % 30, res); + /// ``` + pub fn unchecked_crt_scalar_sub(&self, ct: &CrtCiphertext, scalar: u64) -> CrtCiphertext { + let mut result = ct.clone(); + self.unchecked_crt_scalar_sub_assign(&mut result, scalar); + result + } + + pub fn unchecked_crt_scalar_sub_assign(&self, ct: &mut CrtCiphertext, scalar: u64) { + //Put each decomposition into a new ciphertext + for (ct_i, mod_i) in ct.blocks.iter_mut().zip(ct.moduli.iter()) { + let neg_scalar = (mod_i - scalar % mod_i) % mod_i; + self.key + .unchecked_scalar_add_assign_crt(ct_i, neg_scalar as u8); + } + } + + /// Verifies if the subtraction of a ciphertext by scalar can be computed. + /// + /// # Example + /// + ///```rust + /// use tfhe::integer::gen_keys; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // Generate the client key and the server key: + /// let (cks, sks) = gen_keys(&PARAM_MESSAGE_2_CARRY_2); + /// + /// let clear_1 = 14; + /// let clear_2 = 7; + /// let basis = vec![2, 3, 5]; + /// // Encrypt two messages + /// let mut ctxt_1 = cks.encrypt_crt(clear_1, basis.clone()); + /// + /// let bit = sks.is_crt_scalar_sub_possible(&mut ctxt_1, clear_2); + /// + /// // Decrypt + /// let res = cks.decrypt_crt(&ctxt_1); + /// assert_eq!(true, bit); + /// ``` + pub fn is_crt_scalar_sub_possible(&self, ct: &CrtCiphertext, scalar: u64) -> bool { + for (ct_i, mod_i) in ct.blocks.iter().zip(ct.moduli.iter()) { + let neg_scalar = (mod_i - scalar % mod_i) % mod_i; + + if !self.key.is_scalar_add_possible(ct_i, neg_scalar as u8) { + return false; + } + } + true + } + + /// Computes homomorphically a subtraction of a ciphertext by a scalar. + /// + /// If the operation can be performed, the result is returned in a new ciphertext. + /// Otherwise [CheckError::CarryFull] is returned. + /// + /// # Example + /// + /// ```rust + /// # fn main() -> Result<(), Box> { + /// use tfhe::integer::gen_keys; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // Generate the client key and the server key: + /// let (cks, sks) = gen_keys(&PARAM_MESSAGE_2_CARRY_2); + /// + /// let clear_1 = 14; + /// let clear_2 = 8; + /// let basis = vec![2, 3, 5]; + /// + /// let mut ctxt_1 = cks.encrypt_crt(clear_1, basis.clone()); + /// + /// let ct_res = sks.checked_crt_scalar_sub(&mut ctxt_1, clear_2)?; + /// + /// // Decrypt: + /// let dec = cks.decrypt_crt(&ct_res); + /// assert_eq!((clear_1 - clear_2) % 30, dec); + /// # Ok(()) + /// # } + /// ``` + pub fn checked_crt_scalar_sub( + &self, + ct: &CrtCiphertext, + scalar: u64, + ) -> Result { + if self.is_crt_scalar_sub_possible(ct, scalar) { + Ok(self.unchecked_crt_scalar_sub(ct, scalar)) + } else { + Err(CarryFull) + } + } + + /// Computes homomorphically a subtraction of a ciphertext by a scalar. + /// + /// If the operation can be performed, the result is returned in a new ciphertext. + /// Otherwise [CheckError::CarryFull] is returned. + /// + /// # Example + /// + /// ```rust + /// # fn main() -> Result<(), Box> { + /// use tfhe::integer::gen_keys; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // Generate the client key and the server key: + /// let (cks, sks) = gen_keys(&PARAM_MESSAGE_2_CARRY_2); + /// + /// let clear_1 = 14; + /// let clear_2 = 7; + /// let basis = vec![2, 3, 5]; + /// + /// let mut ctxt_1 = cks.encrypt_crt(clear_1, basis.clone()); + /// + /// sks.checked_crt_scalar_sub_assign(&mut ctxt_1, clear_2)?; + /// + /// // Decrypt: + /// let dec = cks.decrypt_crt(&ctxt_1); + /// assert_eq!((clear_1 - clear_2) % 30, dec); + /// # Ok(()) + /// # } + /// ``` + pub fn checked_crt_scalar_sub_assign( + &self, + ct: &mut CrtCiphertext, + scalar: u64, + ) -> Result<(), CheckError> { + if self.is_crt_scalar_sub_possible(ct, scalar) { + self.unchecked_crt_scalar_sub_assign(ct, scalar); + Ok(()) + } else { + Err(CarryFull) + } + } + + /// Computes homomorphically a subtraction of a ciphertext by a scalar. + /// + /// # Example + /// + ///```rust + /// use tfhe::integer::gen_keys; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // Generate the client key and the server key: + /// let (cks, sks) = gen_keys(&PARAM_MESSAGE_2_CARRY_2); + /// + /// let clear_1 = 14; + /// let clear_2 = 7; + /// let basis = vec![2, 3, 5]; + /// // Encrypt two messages + /// let mut ctxt_1 = cks.encrypt_crt(clear_1, basis.clone()); + /// + /// sks.smart_crt_scalar_sub_assign(&mut ctxt_1, clear_2); + /// + /// // Decrypt + /// let res = cks.decrypt_crt(&ctxt_1); + /// assert_eq!((clear_1 - clear_2) % 30, res); + /// ``` + pub fn smart_crt_scalar_sub(&self, ct: &mut CrtCiphertext, scalar: u64) -> CrtCiphertext { + if !self.is_crt_scalar_sub_possible(ct, scalar) { + self.full_extract_message_assign(ct); + } + + self.unchecked_crt_scalar_sub(ct, scalar) + } + + pub fn smart_crt_scalar_sub_assign(&self, ct: &mut CrtCiphertext, scalar: u64) { + if !self.is_crt_scalar_sub_possible(ct, scalar) { + self.full_extract_message_assign(ct); + } + + self.unchecked_crt_scalar_sub_assign(ct, scalar); + } +} diff --git a/tfhe/src/integer/server_key/crt/sub_crt.rs b/tfhe/src/integer/server_key/crt/sub_crt.rs new file mode 100644 index 000000000..22fbbbcfc --- /dev/null +++ b/tfhe/src/integer/server_key/crt/sub_crt.rs @@ -0,0 +1,170 @@ +use crate::integer::{CrtCiphertext, ServerKey}; + +impl ServerKey { + /// Computes homomorphically a subtraction between two ciphertexts encrypting integer values. + /// + /// This function computes the subtraction without checking if it exceeds the capacity of the + /// ciphertext. + /// + /// The result is returned as a new ciphertext. + /// # Example + /// + ///```rust + /// use tfhe::integer::gen_keys; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // Generate the client key and the server key: + /// let (cks, sks) = gen_keys(&PARAM_MESSAGE_2_CARRY_2); + /// + /// let clear_1 = 14; + /// let clear_2 = 5; + /// let basis = vec![2, 3, 5]; + /// // Encrypt two messages + /// let mut ctxt_1 = cks.encrypt_crt(clear_1, basis.clone()); + /// let mut ctxt_2 = cks.encrypt_crt(clear_2, basis.clone()); + /// + /// let ctxt = sks.unchecked_crt_sub(&mut ctxt_1, &mut ctxt_2); + /// + /// // Decrypt + /// let res = cks.decrypt_crt(&ctxt); + /// assert_eq!((clear_1 - clear_2) % 30, res); + /// ``` + pub fn unchecked_crt_sub( + &self, + ctxt_left: &CrtCiphertext, + ctxt_right: &CrtCiphertext, + ) -> CrtCiphertext { + let mut result = ctxt_left.clone(); + self.unchecked_crt_sub_assign(&mut result, ctxt_right); + result + } + + /// Computes homomorphically a subtraction between two ciphertexts encrypting integer values. + /// + /// This function computes the subtraction without checking if it exceeds the capacity of the + /// ciphertext. + /// + /// The result is assigned to the `ct_left` ciphertext. + /// # Example + /// + ///```rust + /// use tfhe::integer::gen_keys; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // Generate the client key and the server key: + /// let (cks, sks) = gen_keys(&PARAM_MESSAGE_2_CARRY_2); + /// + /// let clear_1 = 14; + /// let clear_2 = 5; + /// let basis = vec![2, 3, 5]; + /// // Encrypt two messages + /// let mut ctxt_1 = cks.encrypt_crt(clear_1, basis.clone()); + /// let mut ctxt_2 = cks.encrypt_crt(clear_2, basis.clone()); + /// + /// let ctxt = sks.unchecked_crt_sub(&mut ctxt_1, &mut ctxt_2); + /// + /// // Decrypt + /// let res = cks.decrypt_crt(&ctxt); + /// assert_eq!((clear_1 - clear_2) % 30, res); + /// ``` + pub fn unchecked_crt_sub_assign( + &self, + ctxt_left: &mut CrtCiphertext, + ctxt_right: &CrtCiphertext, + ) { + let neg = self.unchecked_crt_neg(ctxt_right); + self.unchecked_crt_add_assign(ctxt_left, &neg); + } + + /// Computes homomorphically the subtraction between ct_left and ct_right. + /// + /// # Example + /// + ///```rust + /// use tfhe::integer::gen_keys; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // Generate the client key and the server key: + /// let (cks, sks) = gen_keys(&PARAM_MESSAGE_2_CARRY_2); + /// + /// let clear_1 = 14; + /// let clear_2 = 5; + /// let basis = vec![2, 3, 5]; + /// // Encrypt two messages + /// let mut ctxt_1 = cks.encrypt_crt(clear_1, basis.clone()); + /// let mut ctxt_2 = cks.encrypt_crt(clear_2, basis.clone()); + /// + /// let ctxt = sks.smart_crt_sub(&mut ctxt_1, &mut ctxt_2); + /// + /// // Decrypt + /// let res = cks.decrypt_crt(&ctxt); + /// assert_eq!((clear_1 - clear_2) % 30, res); + /// ``` + pub fn smart_crt_sub( + &self, + ctxt_left: &mut CrtCiphertext, + ctxt_right: &mut CrtCiphertext, + ) -> CrtCiphertext { + // If the ciphertext cannot be added together without exceeding the capacity of a ciphertext + if !self.is_crt_sub_possible(ctxt_left, ctxt_right) { + self.full_extract_message_assign(ctxt_left); + self.full_extract_message_assign(ctxt_right); + } + + let mut result = ctxt_left.clone(); + self.unchecked_crt_sub_assign(&mut result, ctxt_right); + + result + } + + /// Computes homomorphically the subtraction between ct_left and ct_right. + /// + /// # Example + /// + ///```rust + /// use tfhe::integer::gen_keys; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // Generate the client key and the server key: + /// let (cks, sks) = gen_keys(&PARAM_MESSAGE_2_CARRY_2); + /// + /// let clear_1 = 14; + /// let clear_2 = 5; + /// let basis = vec![2, 3, 5]; + /// // Encrypt two messages + /// let mut ctxt_1 = cks.encrypt_crt(clear_1, basis.clone()); + /// let mut ctxt_2 = cks.encrypt_crt(clear_2, basis.clone()); + /// + /// sks.smart_crt_sub_assign(&mut ctxt_1, &mut ctxt_2); + /// + /// // Decrypt + /// let res = cks.decrypt_crt(&ctxt_1); + /// assert_eq!((clear_1 - clear_2) % 30, res); + /// ``` + pub fn smart_crt_sub_assign( + &self, + ctxt_left: &mut CrtCiphertext, + ctxt_right: &mut CrtCiphertext, + ) { + // If the ciphertext cannot be added together without exceeding the capacity of a ciphertext + if !self.is_crt_sub_possible(ctxt_left, ctxt_right) { + self.full_extract_message_assign(ctxt_left); + self.full_extract_message_assign(ctxt_right); + } + + self.unchecked_crt_sub_assign(ctxt_left, ctxt_right); + } + + pub fn is_crt_sub_possible( + &self, + ctxt_left: &CrtCiphertext, + ctxt_right: &CrtCiphertext, + ) -> bool { + for (ct_left_i, ct_right_i) in ctxt_left.blocks.iter().zip(ctxt_right.blocks.iter()) { + if !self.key.is_sub_possible(ct_left_i, ct_right_i) { + return false; + } + } + true + } +} diff --git a/tfhe/src/integer/server_key/crt/tests.rs b/tfhe/src/integer/server_key/crt/tests.rs new file mode 100644 index 000000000..a40388761 --- /dev/null +++ b/tfhe/src/integer/server_key/crt/tests.rs @@ -0,0 +1,273 @@ +use crate::integer::keycache::KEY_CACHE; +use crate::shortint::parameters::*; +use crate::shortint::Parameters; +use rand::Rng; + +create_parametrized_test!(integer_unchecked_crt_mul); +create_parametrized_test!(integer_smart_crt_add); +create_parametrized_test!(integer_smart_crt_mul); +create_parametrized_test!(integer_smart_crt_neg); + +create_parametrized_test!(integer_smart_crt_scalar_add); + +create_parametrized_test!(integer_smart_crt_scalar_mul); +create_parametrized_test!(integer_smart_crt_scalar_sub); +create_parametrized_test!(integer_smart_crt_sub); + +/// Number of loop iteration within randomized tests +const NB_TEST: usize = 30; + +/// Smaller number of loop iteration within randomized test, +/// meant for test where the function tested is more expensive +const NB_TEST_SMALLER: usize = 10; + +fn make_basis(message_modulus: usize) -> Vec { + match message_modulus { + 2 => vec![2], + 3 => vec![2], + n if n < 8 => vec![2, 3], + n if n < 16 => vec![2, 5, 7], + _ => vec![3, 7, 13], + } +} + +fn integer_unchecked_crt_mul(param: Parameters) { + // generate the server-client key set + let (cks, sks) = KEY_CACHE.get_from_params(param); + + //RNG + let mut rng = rand::thread_rng(); + + // Define CRT basis, and global modulus + let basis = make_basis(param.message_modulus.0); + let modulus = basis.iter().product::(); + + for _ in 0..NB_TEST { + let clear_0 = rng.gen::() % modulus; + let clear_1 = rng.gen::() % modulus; + + // encryption of an integer + let mut ct_zero = cks.encrypt_crt(clear_0, basis.clone()); + let ct_one = cks.encrypt_crt(clear_1, basis.clone()); + + // add the two ciphertexts + sks.unchecked_crt_mul_assign(&mut ct_zero, &ct_one); + + // decryption of ct_res + let dec_res = cks.decrypt_crt(&ct_zero); + + // assert + assert_eq!((clear_0 * clear_1) % modulus, dec_res % modulus); + } +} + +fn integer_smart_crt_add(param: Parameters) { + // Define CRT basis, and global modulus + let basis = make_basis(param.message_modulus.0); + let modulus = basis.iter().product::(); + + let (cks, sks) = KEY_CACHE.get_from_params(param); + + //RNG + let mut rng = rand::thread_rng(); + + let mut clear_0 = rng.gen::() % modulus; + let clear_1 = rng.gen::() % modulus; + + // encryption of an integer + let mut ct_zero = cks.encrypt_crt(clear_0, basis.clone()); + let mut ct_one = cks.encrypt_crt(clear_1, basis); + + for _ in 0..NB_TEST { + // add the two ciphertexts + sks.smart_crt_add_assign(&mut ct_zero, &mut ct_one); + + // decryption of ct_res + let dec_res = cks.decrypt_crt(&ct_zero); + + // assert + clear_0 += clear_1; + assert_eq!(clear_0 % modulus, dec_res % modulus); + } +} + +fn integer_smart_crt_mul(param: Parameters) { + // generate the server-client key set + let (cks, sks) = KEY_CACHE.get_from_params(param); + + // Define CRT basis, and global modulus + let basis = make_basis(param.message_modulus.0); + let modulus = basis.iter().product::(); + + println!("BASIS = {basis:?}"); + + //RNG + let mut rng = rand::thread_rng(); + + let mut clear_0 = rng.gen::() % modulus; + let clear_1 = rng.gen::() % modulus; + + // encryption of an integer + let mut ct_zero = cks.encrypt_crt(clear_0, basis.clone()); + let mut ct_one = cks.encrypt_crt(clear_1, basis); + + for _ in 0..NB_TEST_SMALLER { + // mul the two ciphertexts + sks.smart_crt_mul_assign(&mut ct_zero, &mut ct_one); + + // decryption of ct_res + let dec_res = cks.decrypt_crt(&ct_zero); + + clear_0 = (clear_0 * clear_1) % modulus; + assert_eq!(clear_0 % modulus, dec_res % modulus); + } +} + +fn integer_smart_crt_neg(param: Parameters) { + // Define CRT basis, and global modulus + let basis = make_basis(param.message_modulus.0); + let modulus = basis.iter().product::(); + + let (cks, sks) = KEY_CACHE.get_from_params(param); + + //RNG + let mut rng = rand::thread_rng(); + + let mut clear_0 = rng.gen::() % modulus; + + // encryption of an integer + let mut ct_zero = cks.encrypt_crt(clear_0, basis); + + for _ in 0..NB_TEST { + // add the two ciphertexts + sks.smart_crt_neg_assign(&mut ct_zero); + + // decryption of ct_res + let dec_res = cks.decrypt_crt(&ct_zero); + + clear_0 = (modulus - clear_0) % modulus; + + // println!("clear = {}", clear_0); + // assert + assert_eq!(clear_0, dec_res); + } +} + +fn integer_smart_crt_scalar_add(param: Parameters) { + // Define CRT basis, and global modulus + let basis = make_basis(param.message_modulus.0); + let modulus = basis.iter().product::(); + + let (cks, sks) = KEY_CACHE.get_from_params(param); + + //RNG + let mut rng = rand::thread_rng(); + + for _ in 0..NB_TEST { + let clear_0 = rng.gen::() % modulus; + let clear_1 = rng.gen::() % modulus; + + // encryption of an integer + let mut ct_zero = cks.encrypt_crt(clear_0, basis.clone()); + + // add the two ciphertexts + sks.smart_crt_scalar_add_assign(&mut ct_zero, clear_1); + + // decryption of ct_res + let dec_res = cks.decrypt_crt(&ct_zero); + + // assert + assert_eq!((clear_0 + clear_1) % modulus, dec_res % modulus); + } +} + +fn integer_smart_crt_scalar_mul(param: Parameters) { + // Define CRT basis, and global modulus + let basis = make_basis(param.message_modulus.0); + let modulus = basis.iter().product::(); + + let (cks, sks) = KEY_CACHE.get_from_params(param); + + //RNG + let mut rng = rand::thread_rng(); + + for _ in 0..NB_TEST { + let clear_0 = rng.gen::() % modulus; + let clear_1 = rng.gen::() % modulus; + + // encryption of an integer + let mut ct_zero = cks.encrypt_crt(clear_0, basis.clone()); + + // add the two ciphertexts + sks.smart_crt_scalar_mul_assign(&mut ct_zero, clear_1); + + // decryption of ct_res + let dec_res = cks.decrypt_crt(&ct_zero); + + // assert + assert_eq!((clear_0 * clear_1) % modulus, dec_res % modulus); + } +} + +fn integer_smart_crt_scalar_sub(param: Parameters) { + // Define CRT basis, and global modulus + let basis = make_basis(param.message_modulus.0); + let modulus = basis.iter().product::(); + + let (cks, sks) = KEY_CACHE.get_from_params(param); + + //RNG + let mut rng = rand::thread_rng(); + + let mut clear_0 = rng.gen::() % modulus; + let clear_1 = rng.gen::() % modulus; + + // encryption of an integer + let mut ct_zero = cks.encrypt_crt(clear_0, basis); + + for _ in 0..NB_TEST { + // add the two ciphertexts + sks.smart_crt_scalar_sub_assign(&mut ct_zero, clear_1); + + // decryption of ct_res + let dec_res = cks.decrypt_crt(&ct_zero); + + // println!("clear_0 = {}, clear_1 = {}, modulus = {}", clear_0, clear_1, modulus); + + // assert + clear_0 = (clear_0 + modulus - clear_1) % modulus; + assert_eq!(clear_0, dec_res % modulus); + } +} + +fn integer_smart_crt_sub(param: Parameters) { + // Define CRT basis, and global modulus + let basis = make_basis(param.message_modulus.0); + let modulus = basis.iter().product::(); + + let (cks, sks) = KEY_CACHE.get_from_params(param); + + //RNG + let mut rng = rand::thread_rng(); + + let mut clear_0 = rng.gen::() % modulus; + let clear_1 = rng.gen::() % modulus; + + // encryption of an integer + let mut ct_zero = cks.encrypt_crt(clear_0, basis.clone()); + let mut ct_one = cks.encrypt_crt(clear_1, basis); + + for _ in 0..NB_TEST { + // sub the two ciphertexts + sks.smart_crt_sub_assign(&mut ct_zero, &mut ct_one); + + // decryption of ct_res + let dec_res = cks.decrypt_crt(&ct_zero); + + // println!("clear_0 = {}, clear_1 = {}, modulus = {}", clear_0, clear_1, modulus); + + // assert + clear_0 = (clear_0 + modulus - clear_1) % modulus; + assert_eq!(clear_0, dec_res); + } +} diff --git a/tfhe/src/integer/server_key/crt_parallel/add_crt.rs b/tfhe/src/integer/server_key/crt_parallel/add_crt.rs new file mode 100644 index 000000000..553a751d7 --- /dev/null +++ b/tfhe/src/integer/server_key/crt_parallel/add_crt.rs @@ -0,0 +1,115 @@ +use crate::integer::ciphertext::CrtCiphertext; +use crate::integer::ServerKey; +use rayon::prelude::*; + +impl ServerKey { + /// Computes homomorphically an addition between two ciphertexts encrypting integer + /// values in the CRT decomposition. + /// + /// # Example + /// + ///```rust + /// use tfhe::integer::gen_keys_crt; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // Generate the client key and the server key: + /// let basis = vec![2, 3, 5]; + /// let (cks, sks) = gen_keys_crt(&PARAM_MESSAGE_2_CARRY_2, basis); + /// + /// let clear_1 = 14; + /// let clear_2 = 14; + /// + /// // Encrypt two messages + /// let mut ctxt_1 = cks.encrypt(clear_1); + /// let ctxt_2 = cks.encrypt(clear_2); + /// + /// // Compute homomorphically a multiplication + /// sks.unchecked_crt_add_assign_parallelized(&mut ctxt_1, &ctxt_2); + /// + /// // Decrypt + /// let res = cks.decrypt(&ctxt_1); + /// assert_eq!((clear_1 + clear_2) % 30, res); + /// ``` + pub fn unchecked_crt_add_assign_parallelized( + &self, + ct_left: &mut CrtCiphertext, + ct_right: &CrtCiphertext, + ) { + ct_left + .blocks + .par_iter_mut() + .zip(&ct_right.blocks) + .for_each(|(ct_left, ct_right)| { + self.key.unchecked_add_assign(ct_left, ct_right); + }); + } + + pub fn unchecked_crt_add_parallelized( + &self, + ct_left: &CrtCiphertext, + ct_right: &CrtCiphertext, + ) -> CrtCiphertext { + let mut ct_res = ct_left.clone(); + self.unchecked_crt_add_assign_parallelized(&mut ct_res, ct_right); + ct_res + } + + /// Computes homomorphically an addition between two ciphertexts encrypting integer values in + /// the CRT decomposition. + /// + /// This checks that the addition is possible. In the case where the carry buffers are full, + /// then it is automatically cleared to allow the operation. + /// + /// # Example + /// + /// ```rust + /// use tfhe::integer::gen_keys_crt; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// let size = 4; + /// + /// // Generate the client key and the server key: + /// let basis = vec![2, 3, 5]; + /// let (cks, sks) = gen_keys_crt(&PARAM_MESSAGE_2_CARRY_2, basis); + /// + /// let clear_1 = 29; + /// let clear_2 = 29; + /// + /// // Encrypt two messages + /// let mut ctxt_1 = cks.encrypt(clear_1); + /// let mut ctxt_2 = cks.encrypt(clear_2); + /// + /// // Compute homomorphically a multiplication + /// sks.smart_crt_add_assign_parallelized(&mut ctxt_1, &mut ctxt_2); + /// + /// // Decrypt + /// let res = cks.decrypt(&ctxt_1); + /// assert_eq!((clear_1 + clear_2) % 30, res); + /// ``` + pub fn smart_crt_add_assign_parallelized( + &self, + ct_left: &mut CrtCiphertext, + ct_right: &mut CrtCiphertext, + ) { + if !self.is_crt_add_possible(ct_left, ct_right) { + rayon::join( + || self.full_extract_message_assign(ct_left), + || self.full_extract_message_assign(ct_right), + ); + } + self.unchecked_crt_add_assign_parallelized(ct_left, ct_right); + } + + pub fn smart_crt_add_parallelized( + &self, + ct_left: &mut CrtCiphertext, + ct_right: &mut CrtCiphertext, + ) -> CrtCiphertext { + if !self.is_crt_add_possible(ct_left, ct_right) { + rayon::join( + || self.full_extract_message_assign(ct_left), + || self.full_extract_message_assign(ct_right), + ); + } + self.unchecked_crt_add_parallelized(ct_left, ct_right) + } +} diff --git a/tfhe/src/integer/server_key/crt_parallel/mod.rs b/tfhe/src/integer/server_key/crt_parallel/mod.rs new file mode 100644 index 000000000..491a7e1a4 --- /dev/null +++ b/tfhe/src/integer/server_key/crt_parallel/mod.rs @@ -0,0 +1,108 @@ +mod add_crt; +mod mul_crt; +mod neg_crt; +mod scalar_add_crt; +mod scalar_mul_crt; +mod scalar_sub_crt; +mod sub_crt; + +use crate::integer::ciphertext::CrtCiphertext; +use crate::integer::ServerKey; +use rayon::prelude::*; + +impl ServerKey { + /// Extract all the messages. + /// + /// # Example + /// + ///```rust + /// use tfhe::integer::gen_keys; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // Generate the client key and the server key: + /// let (cks, sks) = gen_keys(&PARAM_MESSAGE_2_CARRY_2); + /// + /// let clear_1 = 14; + /// let clear_2 = 14; + /// let basis = vec![2, 3, 5]; + /// // Encrypt two messages + /// let mut ctxt_1 = cks.encrypt_crt(clear_1, basis.clone()); + /// let ctxt_2 = cks.encrypt_crt(clear_2, basis); + /// + /// // Compute homomorphically a multiplication + /// sks.unchecked_crt_add_assign(&mut ctxt_1, &ctxt_2); + /// + /// sks.full_extract_message_assign_parallelized(&mut ctxt_1); + /// + /// // Decrypt + /// let res = cks.decrypt_crt(&ctxt_1); + /// assert_eq!((clear_1 + clear_2) % 30, res); + /// ``` + pub fn full_extract_message_assign_parallelized(&self, ctxt: &mut CrtCiphertext) { + ctxt.blocks.par_iter_mut().for_each(|ct_i| { + self.key.message_extract_assign(ct_i); + }); + } + + /// Computes a PBS for CRT-compliant functions. + /// + /// # Warning + /// + /// This allows to compute programmable bootstrapping over integers under the condition that + /// the function is said to be CRT-compliant. This means that the function should be correct + /// when evaluated on each modular block independently (e.g. arithmetic functions). + /// + /// # Example + /// + /// ```rust + /// use tfhe::integer::gen_keys_crt; + /// use tfhe::shortint::parameters::DEFAULT_PARAMETERS; + /// + /// // Generate the client key and the server key: + /// let basis = vec![2, 3, 5]; + /// let (cks, sks) = gen_keys_crt(&DEFAULT_PARAMETERS, basis); + /// + /// let clear_1 = 28; + /// + /// let mut ctxt_1 = cks.encrypt(clear_1); + /// + /// // Compute homomorphically the crt-compliant PBS + /// sks.pbs_crt_compliant_function_assign_parallelized(&mut ctxt_1, |x| x * x * x); + /// + /// // Decrypt + /// let res = cks.decrypt(&ctxt_1); + /// assert_eq!((clear_1 * clear_1 * clear_1) % 30, res); + /// ``` + pub fn pbs_crt_compliant_function_assign_parallelized(&self, ct1: &mut CrtCiphertext, f: F) + where + F: Fn(u64) -> u64, + { + let basis = &ct1.moduli; + + let accumulators = basis + .iter() + .copied() + .map(|b| self.key.generate_accumulator(|x| f(x) % b)) + .collect::>(); + + ct1.blocks + .par_iter_mut() + .zip(&accumulators) + .for_each(|(block, acc)| { + self.key.keyswitch_programmable_bootstrap_assign(block, acc); + }); + } + + pub fn pbs_crt_compliant_function_parallelized( + &self, + ct1: &CrtCiphertext, + f: F, + ) -> CrtCiphertext + where + F: Fn(u64) -> u64, + { + let mut ct_res = ct1.clone(); + self.pbs_crt_compliant_function_assign_parallelized(&mut ct_res, f); + ct_res + } +} diff --git a/tfhe/src/integer/server_key/crt_parallel/mul_crt.rs b/tfhe/src/integer/server_key/crt_parallel/mul_crt.rs new file mode 100644 index 000000000..cc781d846 --- /dev/null +++ b/tfhe/src/integer/server_key/crt_parallel/mul_crt.rs @@ -0,0 +1,108 @@ +use crate::integer::ciphertext::CrtCiphertext; +use crate::integer::ServerKey; +use rayon::prelude::*; + +impl ServerKey { + /// Computes homomorphically a multiplication between two ciphertexts encrypting integer + /// values in the CRT decomposition. + /// + /// # Example + /// + /// ```rust + /// use tfhe::integer::gen_keys_crt; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_3_CARRY_3; + /// let size = 3; + /// + /// // Generate the client key and the server key: + /// let basis = vec![2, 3, 5]; + /// let (cks, sks) = gen_keys_crt(&PARAM_MESSAGE_3_CARRY_3, basis); + /// + /// let clear_1 = 29; + /// let clear_2 = 23; + /// + /// // Encrypt two messages + /// let mut ctxt_1 = cks.encrypt(clear_1); + /// let ctxt_2 = cks.encrypt(clear_2); + /// + /// // Compute homomorphically a multiplication + /// sks.unchecked_crt_mul_assign_parallelized(&mut ctxt_1, &ctxt_2); + /// // Decrypt + /// let res = cks.decrypt(&ctxt_1); + /// assert_eq!((clear_1 * clear_2) % 30, res); + /// ``` + pub fn unchecked_crt_mul_assign_parallelized( + &self, + ct_left: &mut CrtCiphertext, + ct_right: &CrtCiphertext, + ) { + ct_left + .blocks + .par_iter_mut() + .zip(&ct_right.blocks) + .for_each(|(ct_left, ct_right)| { + self.key.unchecked_mul_lsb_assign(ct_left, ct_right); + }); + } + + pub fn unchecked_crt_mul_parallelized( + &self, + ct_left: &CrtCiphertext, + ct_right: &CrtCiphertext, + ) -> CrtCiphertext { + let mut ct_res = ct_left.clone(); + self.unchecked_crt_mul_assign_parallelized(&mut ct_res, ct_right); + ct_res + } + + /// Computes homomorphically a multiplication between two ciphertexts encrypting integer + /// values in the CRT decomposition. + /// + /// This checks that the addition is possible. In the case where the carry buffers are full, + /// then it is automatically cleared to allow the operation. + /// + /// # Example + /// + /// ```rust + /// use tfhe::integer::gen_keys_crt; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_3_CARRY_3; + /// + /// let basis = vec![2, 3, 5]; + /// let (cks, sks) = gen_keys_crt(&PARAM_MESSAGE_3_CARRY_3, basis); + /// + /// let clear_1 = 29; + /// let clear_2 = 29; + /// + /// // Encrypt two messages + /// let mut ctxt_1 = cks.encrypt(clear_1); + /// let mut ctxt_2 = cks.encrypt(clear_2); + /// + /// // Compute homomorphically a multiplication + /// sks.smart_crt_mul_assign_parallelized(&mut ctxt_1, &mut ctxt_2); + /// + /// // Decrypt + /// let res = cks.decrypt(&ctxt_1); + /// assert_eq!((clear_1 * clear_2) % 30, res); + /// ``` + pub fn smart_crt_mul_assign_parallelized( + &self, + ct_left: &mut CrtCiphertext, + ct_right: &mut CrtCiphertext, + ) { + ct_left + .blocks + .par_iter_mut() + .zip(&mut ct_right.blocks) + .for_each(|(block_left, block_right)| { + self.key.smart_mul_lsb_assign(block_left, block_right); + }); + } + pub fn smart_crt_mul_parallelized( + &self, + ct_left: &CrtCiphertext, + ct_right: &mut CrtCiphertext, + ) -> CrtCiphertext { + let mut ct_res = ct_left.clone(); + self.smart_crt_mul_assign_parallelized(&mut ct_res, ct_right); + ct_res + } +} diff --git a/tfhe/src/integer/server_key/crt_parallel/neg_crt.rs b/tfhe/src/integer/server_key/crt_parallel/neg_crt.rs new file mode 100644 index 000000000..a7823e63c --- /dev/null +++ b/tfhe/src/integer/server_key/crt_parallel/neg_crt.rs @@ -0,0 +1,84 @@ +use crate::integer::{CrtCiphertext, ServerKey}; +use rayon::prelude::*; + +impl ServerKey { + /// Homomorphically computes the opposite of a ciphertext encrypting an integer message. + /// + /// This function computes the opposite of a message without checking if it exceeds the + /// capacity of the ciphertext. + /// + /// The result is returned as a new ciphertext. + /// # Example + /// + ///```rust + /// use tfhe::integer::gen_keys; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // Generate the client key and the server key: + /// let (cks, sks) = gen_keys(&PARAM_MESSAGE_2_CARRY_2); + /// + /// let clear = 14_u64; + /// let basis = vec![2, 3, 5]; + /// + /// let mut ctxt = cks.encrypt_crt(clear, basis.clone()); + /// + /// sks.unchecked_crt_neg_assign_parallelized(&mut ctxt); + /// + /// // Decrypt + /// let res = cks.decrypt_crt(&ctxt); + /// assert_eq!(16, res); + /// ``` + pub fn unchecked_crt_neg_parallelized(&self, ctxt: &CrtCiphertext) -> CrtCiphertext { + let mut result = ctxt.clone(); + self.unchecked_crt_neg_assign_parallelized(&mut result); + result + } + + /// Homomorphically computes the opposite of a ciphertext encrypting an integer message. + /// + /// This function computes the opposite of a message without checking if it exceeds the + /// capacity of the ciphertext. + /// + /// The result is assigned to the `ct_left` ciphertext. + pub fn unchecked_crt_neg_assign_parallelized(&self, ctxt: &mut CrtCiphertext) { + ctxt.blocks.par_iter_mut().for_each(|ct_i| { + self.key.unchecked_neg_assign(ct_i); + }); + } + + /// Homomorphically computes the opposite of a ciphertext encrypting an integer message. + /// + /// # Example + /// + ///```rust + /// use tfhe::integer::gen_keys; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // Generate the client key and the server key: + /// let (cks, sks) = gen_keys(&PARAM_MESSAGE_2_CARRY_2); + /// + /// let clear = 14_u64; + /// let basis = vec![2, 3, 5]; + /// + /// let mut ctxt = cks.encrypt_crt(clear, basis.clone()); + /// + /// sks.smart_crt_neg_assign_parallelized(&mut ctxt); + /// + /// // Decrypt + /// let res = cks.decrypt_crt(&ctxt); + /// assert_eq!(16, res); + /// ``` + pub fn smart_crt_neg_assign_parallelized(&self, ctxt: &mut CrtCiphertext) { + if !self.is_crt_neg_possible(ctxt) { + self.full_extract_message_assign_parallelized(ctxt); + } + self.unchecked_crt_neg_assign_parallelized(ctxt); + } + + pub fn smart_crt_neg_parallelized(&self, ctxt: &mut CrtCiphertext) -> CrtCiphertext { + if !self.is_crt_neg_possible(ctxt) { + self.full_extract_message_assign_parallelized(ctxt); + } + self.unchecked_crt_neg_parallelized(ctxt) + } +} diff --git a/tfhe/src/integer/server_key/crt_parallel/scalar_add_crt.rs b/tfhe/src/integer/server_key/crt_parallel/scalar_add_crt.rs new file mode 100644 index 000000000..c6081dc51 --- /dev/null +++ b/tfhe/src/integer/server_key/crt_parallel/scalar_add_crt.rs @@ -0,0 +1,194 @@ +use crate::integer::server_key::CheckError; +use crate::integer::server_key::CheckError::CarryFull; +use crate::integer::{CrtCiphertext, ServerKey}; +use rayon::prelude::*; + +impl ServerKey { + /// Computes homomorphically an addition between a scalar and a ciphertext. + /// + /// This function computes the operation without checking if it exceeds the capacity of the + /// ciphertext. + /// + /// The result is returned as a new ciphertext. + /// + /// # Example + /// + ///```rust + /// use tfhe::integer::gen_keys; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // Generate the client key and the server key: + /// let (cks, sks) = gen_keys(&PARAM_MESSAGE_2_CARRY_2); + /// + /// let clear_1 = 14; + /// let clear_2 = 14; + /// let basis = vec![2, 3, 5]; + /// // Encrypt two messages + /// let mut ctxt_1 = cks.encrypt_crt(clear_1, basis.clone()); + /// + /// sks.unchecked_crt_scalar_add_assign_parallelized(&mut ctxt_1, clear_2); + /// + /// // Decrypt + /// let res = cks.decrypt_crt(&ctxt_1); + /// assert_eq!((clear_1 + clear_2) % 30, res); + /// ``` + pub fn unchecked_crt_scalar_add_parallelized( + &self, + ct: &CrtCiphertext, + scalar: u64, + ) -> CrtCiphertext { + let mut result = ct.clone(); + self.unchecked_crt_scalar_add_assign_parallelized(&mut result, scalar); + result + } + + /// Computes homomorphically an addition between a scalar and a ciphertext. + /// + /// This function computes the operation without checking if it exceeds the capacity of the + /// ciphertext. + /// + /// The result is assigned to the `ct_left` ciphertext. + pub fn unchecked_crt_scalar_add_assign_parallelized( + &self, + ct: &mut CrtCiphertext, + scalar: u64, + ) { + //Add the crt representation of the scalar to the ciphertext + ct.blocks + .par_iter_mut() + .zip(ct.moduli.par_iter()) + .for_each(|(ct_i, mod_i)| { + let scalar_i = scalar % mod_i; + self.key.unchecked_scalar_add_assign(ct_i, scalar_i as u8); + }); + } + + /// Computes homomorphically an addition between a scalar and a ciphertext. + /// + /// If the operation can be performed, the result is returned in a new ciphertext. + /// Otherwise [CheckError::CarryFull] is returned. + /// + /// # Example + /// + /// ```rust + /// # fn main() -> Result<(), Box> { + /// use tfhe::integer::gen_keys; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // Generate the client key and the server key: + /// let (cks, sks) = gen_keys(&PARAM_MESSAGE_2_CARRY_2); + /// + /// let clear_1 = 14; + /// let clear_2 = 14; + /// let basis = vec![2, 3, 5]; + /// // Encrypt two messages + /// let mut ctxt_1 = cks.encrypt_crt(clear_1, basis.clone()); + /// + /// sks.checked_crt_scalar_add_assign_parallelized(&mut ctxt_1, clear_2); + /// + /// // Decrypt + /// let res = cks.decrypt_crt(&ctxt_1); + /// assert_eq!((clear_1 + clear_2) % 30, res); + /// # Ok(()) + /// # } + /// ``` + pub fn checked_crt_scalar_add_parallelized( + &self, + ct: &CrtCiphertext, + scalar: u64, + ) -> Result { + if self.is_crt_scalar_add_possible(ct, scalar) { + Ok(self.unchecked_crt_scalar_add_parallelized(ct, scalar)) + } else { + Err(CarryFull) + } + } + + /// Computes homomorphically an addition between a scalar and a ciphertext. + /// + /// If the operation can be performed, the result is stored in the `ct_left` ciphertext. + /// Otherwise [CheckError::CarryFull] is returned, and `ct_left` is not modified. + pub fn checked_crt_scalar_add_assign_parallelized( + &self, + ct: &mut CrtCiphertext, + scalar: u64, + ) -> Result<(), CheckError> { + if self.is_crt_scalar_add_possible(ct, scalar) { + self.unchecked_crt_scalar_add_assign_parallelized(ct, scalar); + Ok(()) + } else { + Err(CarryFull) + } + } + + /// Computes homomorphically the addition of ciphertext with a scalar. + /// + /// The result is returned in a new ciphertext. + /// + /// # Example + /// + ///```rust + /// use tfhe::integer::gen_keys; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // Generate the client key and the server key: + /// let (cks, sks) = gen_keys(&PARAM_MESSAGE_2_CARRY_2); + /// + /// let clear_1 = 14; + /// let clear_2 = 14; + /// let basis = vec![2, 3, 5]; + /// // Encrypt two messages + /// let mut ctxt_1 = cks.encrypt_crt(clear_1, basis.clone()); + /// + /// let ctxt = sks.smart_crt_scalar_add_parallelized(&mut ctxt_1, clear_2); + /// + /// // Decrypt + /// let res = cks.decrypt_crt(&ctxt); + /// assert_eq!((clear_1 + clear_2) % 30, res); + /// ``` + pub fn smart_crt_scalar_add_parallelized( + &self, + ct: &mut CrtCiphertext, + scalar: u64, + ) -> CrtCiphertext { + if !self.is_crt_scalar_add_possible(ct, scalar) { + self.full_extract_message_assign_parallelized(ct); + } + + let mut ct = ct.clone(); + self.unchecked_crt_scalar_add_assign_parallelized(&mut ct, scalar); + ct + } + + /// Computes homomorphically the addition of ciphertext with a scalar. + /// + /// The result is assigned to the `ct_left` ciphertext. + /// + /// # Example + /// + ///```rust + /// use tfhe::integer::gen_keys; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // Generate the client key and the server key: + /// let (cks, sks) = gen_keys(&PARAM_MESSAGE_2_CARRY_2); + /// + /// let clear_1 = 14; + /// let clear_2 = 14; + /// let basis = vec![2, 3, 5]; + /// // Encrypt two messages + /// let mut ctxt_1 = cks.encrypt_crt(clear_1, basis.clone()); + /// + /// sks.smart_crt_scalar_add_assign_parallelized(&mut ctxt_1, clear_2); + /// + /// // Decrypt + /// let res = cks.decrypt_crt(&ctxt_1); + /// assert_eq!((clear_1 + clear_2) % 30, res); + /// ``` + pub fn smart_crt_scalar_add_assign_parallelized(&self, ct: &mut CrtCiphertext, scalar: u64) { + if !self.is_crt_scalar_add_possible(ct, scalar) { + self.full_extract_message_assign_parallelized(ct); + } + self.unchecked_crt_scalar_add_assign_parallelized(ct, scalar); + } +} diff --git a/tfhe/src/integer/server_key/crt_parallel/scalar_mul_crt.rs b/tfhe/src/integer/server_key/crt_parallel/scalar_mul_crt.rs new file mode 100644 index 000000000..e91546b0e --- /dev/null +++ b/tfhe/src/integer/server_key/crt_parallel/scalar_mul_crt.rs @@ -0,0 +1,199 @@ +use crate::integer::server_key::CheckError; +use crate::integer::server_key::CheckError::CarryFull; +use crate::integer::{CrtCiphertext, ServerKey}; +use rayon::prelude::*; + +impl ServerKey { + /// Computes homomorphically a multiplication between a scalar and a ciphertext. + /// + /// This function computes the operation without checking if it exceeds the capacity of the + /// ciphertext. + /// + /// The result is returned as a new ciphertext. + /// + /// # Example + /// + ///```rust + /// use tfhe::integer::gen_keys; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // Generate the client key and the server key: + /// let (cks, sks) = gen_keys(&PARAM_MESSAGE_2_CARRY_2); + /// + /// let clear_1 = 14; + /// let clear_2 = 2; + /// let basis = vec![2, 3, 5]; + /// // Encrypt two messages + /// let mut ctxt_1 = cks.encrypt_crt(clear_1, basis.clone()); + /// + /// sks.unchecked_crt_scalar_mul_assign_parallelized(&mut ctxt_1, clear_2); + /// + /// // Decrypt + /// let res = cks.decrypt_crt(&ctxt_1); + /// assert_eq!((clear_1 * clear_2) % 30, res); + /// ``` + pub fn unchecked_crt_scalar_mul_parallelized( + &self, + ctxt: &CrtCiphertext, + scalar: u64, + ) -> CrtCiphertext { + let mut ct_result = ctxt.clone(); + self.unchecked_crt_scalar_mul_assign_parallelized(&mut ct_result, scalar); + ct_result + } + + pub fn unchecked_crt_scalar_mul_assign_parallelized( + &self, + ctxt: &mut CrtCiphertext, + scalar: u64, + ) { + ctxt.blocks + .par_iter_mut() + .zip(ctxt.moduli.par_iter()) + .for_each(|(ct_i, mod_i)| { + self.key + .unchecked_scalar_mul_assign(ct_i, (scalar % mod_i) as u8); + }); + } + + /// Computes homomorphically a multiplication between a scalar and a ciphertext. + /// + /// If the operation can be performed, the result is returned in a new ciphertext. + /// Otherwise [CheckError::CarryFull] is returned. + /// + /// # Example + /// + /// ```rust + /// # fn main() -> Result<(), Box> { + /// use tfhe::integer::gen_keys; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // Generate the client key and the server key: + /// let (cks, sks) = gen_keys(&PARAM_MESSAGE_2_CARRY_2); + /// + /// let clear_1 = 14; + /// let clear_2 = 2; + /// let basis = vec![2, 3, 5]; + /// // Encrypt two messages + /// let mut ctxt_1 = cks.encrypt_crt(clear_1, basis.clone()); + /// + /// sks.checked_crt_scalar_mul_assign_parallelized(&mut ctxt_1, clear_2); + /// + /// // Decrypt + /// let res = cks.decrypt_crt(&ctxt_1); + /// assert_eq!((clear_1 * clear_2) % 30, res); + /// # Ok(()) + /// # } + /// ``` + pub fn checked_crt_scalar_mul_parallelized( + &self, + ct: &CrtCiphertext, + scalar: u64, + ) -> Result { + let mut ct_result = ct.clone(); + + // If the ciphertext cannot be multiplied without exceeding the capacity of a ciphertext + if self.is_crt_scalar_mul_possible(ct, scalar) { + ct_result = self.unchecked_crt_scalar_mul(&ct_result, scalar); + + Ok(ct_result) + } else { + Err(CarryFull) + } + } + + /// Computes homomorphically a multiplication between a scalar and a ciphertext. + /// + /// If the operation can be performed, the result is assigned to the ciphertext given + /// as parameter. + /// Otherwise [CheckError::CarryFull] is returned. + pub fn checked_crt_scalar_mul_assign_parallelized( + &self, + ct: &mut CrtCiphertext, + scalar: u64, + ) -> Result<(), CheckError> { + // If the ciphertext cannot be multiplied without exceeding the capacity of a ciphertext + if self.is_crt_scalar_mul_possible(ct, scalar) { + self.unchecked_crt_scalar_mul_assign_parallelized(ct, scalar); + Ok(()) + } else { + Err(CarryFull) + } + } + + /// Computes homomorphically a multiplication between a scalar and a ciphertext. + /// + /// `small` means the scalar value shall fit in a __shortint block__. + /// For example, if the parameters are PARAM_MESSAGE_2_CARRY_2, + /// the scalar should fit in 2 bits. + /// + /// The result is returned as a new ciphertext. + /// + /// # Example + /// + ///```rust + /// use tfhe::integer::gen_keys; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // Generate the client key and the server key: + /// let (cks, sks) = gen_keys(&PARAM_MESSAGE_2_CARRY_2); + /// + /// let clear_1 = 14; + /// let clear_2 = 14; + /// let basis = vec![2, 3, 5]; + /// // Encrypt two messages + /// let mut ctxt_1 = cks.encrypt_crt(clear_1, basis.clone()); + /// + /// let ctxt = sks.smart_crt_scalar_mul_parallelized(&mut ctxt_1, clear_2); + /// + /// // Decrypt + /// let res = cks.decrypt_crt(&ctxt); + /// assert_eq!((clear_1 * clear_2) % 30, res); + /// ``` + pub fn smart_crt_scalar_mul_parallelized( + &self, + ctxt: &mut CrtCiphertext, + scalar: u64, + ) -> CrtCiphertext { + if !self.is_crt_scalar_mul_possible(ctxt, scalar) { + self.full_extract_message_assign_parallelized(ctxt); + } + self.unchecked_crt_scalar_mul(ctxt, scalar) + } + + /// Computes homomorphically a multiplication between a scalar and a ciphertext. + /// + /// `small` means the scalar shall value fit in a __shortint block__. + /// For example, if the parameters are PARAM_MESSAGE_2_CARRY_2, + /// the scalar should fit in 2 bits. + /// + /// The result is assigned to the input ciphertext + /// + /// # Example + /// + ///```rust + /// use tfhe::integer::gen_keys; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // Generate the client key and the server key: + /// let (cks, sks) = gen_keys(&PARAM_MESSAGE_2_CARRY_2); + /// + /// let clear_1 = 14; + /// let clear_2 = 14; + /// let basis = vec![2, 3, 5]; + /// // Encrypt two messages + /// let mut ctxt_1 = cks.encrypt_crt(clear_1, basis.clone()); + /// + /// sks.smart_crt_scalar_mul_assign_parallelized(&mut ctxt_1, clear_2); + /// + /// // Decrypt + /// let res = cks.decrypt_crt(&ctxt_1); + /// assert_eq!((clear_1 * clear_2) % 30, res); + /// ``` + pub fn smart_crt_scalar_mul_assign_parallelized(&self, ctxt: &mut CrtCiphertext, scalar: u64) { + if !self.is_crt_small_scalar_mul_possible(ctxt, scalar) { + self.full_extract_message_assign_parallelized(ctxt); + } + self.unchecked_crt_scalar_mul_assign_parallelized(ctxt, scalar); + } +} diff --git a/tfhe/src/integer/server_key/crt_parallel/scalar_sub_crt.rs b/tfhe/src/integer/server_key/crt_parallel/scalar_sub_crt.rs new file mode 100644 index 000000000..edd48e697 --- /dev/null +++ b/tfhe/src/integer/server_key/crt_parallel/scalar_sub_crt.rs @@ -0,0 +1,186 @@ +use crate::integer::server_key::CheckError; +use crate::integer::server_key::CheckError::CarryFull; +use crate::integer::{CrtCiphertext, ServerKey}; +use rayon::prelude::*; + +impl ServerKey { + /// Computes homomorphically a subtraction between a ciphertext and a scalar. + /// + /// This function computes the operation without checking if it exceeds the capacity of the + /// ciphertext. + /// + /// The result is returned as a new ciphertext. + /// + /// # Example + /// + ///```rust + /// use tfhe::integer::gen_keys; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // Generate the client key and the server key: + /// let (cks, sks) = gen_keys(&PARAM_MESSAGE_2_CARRY_2); + /// + /// let clear_1 = 14; + /// let clear_2 = 7; + /// let basis = vec![2, 3, 5]; + /// // Encrypt two messages + /// let mut ctxt_1 = cks.encrypt_crt(clear_1, basis.clone()); + /// + /// sks.unchecked_crt_scalar_sub_assign_parallelized(&mut ctxt_1, clear_2); + /// + /// // Decrypt + /// let res = cks.decrypt_crt(&ctxt_1); + /// assert_eq!((clear_1 - clear_2) % 30, res); + /// ``` + pub fn unchecked_crt_scalar_sub_parallelized( + &self, + ct: &CrtCiphertext, + scalar: u64, + ) -> CrtCiphertext { + let mut result = ct.clone(); + self.unchecked_crt_scalar_sub_assign_parallelized(&mut result, scalar); + result + } + + pub fn unchecked_crt_scalar_sub_assign_parallelized( + &self, + ct: &mut CrtCiphertext, + scalar: u64, + ) { + //Put each decomposition into a new ciphertext + ct.blocks + .par_iter_mut() + .zip(ct.moduli.par_iter()) + .for_each(|(ct_i, mod_i)| { + let neg_scalar = (mod_i - scalar % mod_i) % mod_i; + self.key + .unchecked_scalar_add_assign_crt(ct_i, neg_scalar as u8); + }); + } + + /// Computes homomorphically a subtraction of a ciphertext by a scalar. + /// + /// If the operation can be performed, the result is returned in a new ciphertext. + /// Otherwise [CheckError::CarryFull] is returned. + /// + /// # Example + /// + /// ```rust + /// # fn main() -> Result<(), Box> { + /// use tfhe::integer::gen_keys; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // Generate the client key and the server key: + /// let (cks, sks) = gen_keys(&PARAM_MESSAGE_2_CARRY_2); + /// + /// let clear_1 = 14; + /// let clear_2 = 8; + /// let basis = vec![2, 3, 5]; + /// + /// let mut ctxt_1 = cks.encrypt_crt(clear_1, basis.clone()); + /// + /// let ct_res = sks.checked_crt_scalar_sub_parallelized(&mut ctxt_1, clear_2)?; + /// + /// // Decrypt: + /// let dec = cks.decrypt_crt(&ct_res); + /// assert_eq!((clear_1 - clear_2) % 30, dec); + /// # Ok(()) + /// # } + /// ``` + pub fn checked_crt_scalar_sub_parallelized( + &self, + ct: &CrtCiphertext, + scalar: u64, + ) -> Result { + if self.is_crt_scalar_sub_possible(ct, scalar) { + Ok(self.unchecked_crt_scalar_sub_parallelized(ct, scalar)) + } else { + Err(CarryFull) + } + } + + /// Computes homomorphically a subtraction of a ciphertext by a scalar. + /// + /// If the operation can be performed, the result is returned in a new ciphertext. + /// Otherwise [CheckError::CarryFull] is returned. + /// + /// # Example + /// + /// ```rust + /// # fn main() -> Result<(), Box> { + /// use tfhe::integer::gen_keys; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // Generate the client key and the server key: + /// let (cks, sks) = gen_keys(&PARAM_MESSAGE_2_CARRY_2); + /// + /// let clear_1 = 14; + /// let clear_2 = 7; + /// let basis = vec![2, 3, 5]; + /// + /// let mut ctxt_1 = cks.encrypt_crt(clear_1, basis.clone()); + /// + /// sks.checked_crt_scalar_sub_assign_parallelized(&mut ctxt_1, clear_2)?; + /// + /// // Decrypt: + /// let dec = cks.decrypt_crt(&ctxt_1); + /// assert_eq!((clear_1 - clear_2) % 30, dec); + /// # Ok(()) + /// # } + /// ``` + pub fn checked_crt_scalar_sub_assign_parallelized( + &self, + ct: &mut CrtCiphertext, + scalar: u64, + ) -> Result<(), CheckError> { + if self.is_crt_scalar_sub_possible(ct, scalar) { + self.unchecked_crt_scalar_sub_assign_parallelized(ct, scalar); + Ok(()) + } else { + Err(CarryFull) + } + } + + /// Computes homomorphically a subtraction of a ciphertext by a scalar. + /// + /// # Example + /// + ///```rust + /// use tfhe::integer::gen_keys; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // Generate the client key and the server key: + /// let (cks, sks) = gen_keys(&PARAM_MESSAGE_2_CARRY_2); + /// + /// let clear_1 = 14; + /// let clear_2 = 7; + /// let basis = vec![2, 3, 5]; + /// // Encrypt two messages + /// let mut ctxt_1 = cks.encrypt_crt(clear_1, basis.clone()); + /// + /// sks.smart_crt_scalar_sub_assign_parallelized(&mut ctxt_1, clear_2); + /// + /// // Decrypt + /// let res = cks.decrypt_crt(&ctxt_1); + /// assert_eq!((clear_1 - clear_2) % 30, res); + /// ``` + pub fn smart_crt_scalar_sub_parallelized( + &self, + ct: &mut CrtCiphertext, + scalar: u64, + ) -> CrtCiphertext { + if !self.is_crt_scalar_sub_possible(ct, scalar) { + self.full_extract_message_assign_parallelized(ct); + } + + self.unchecked_crt_scalar_sub_parallelized(ct, scalar) + } + + pub fn smart_crt_scalar_sub_assign_parallelized(&self, ct: &mut CrtCiphertext, scalar: u64) { + if !self.is_crt_scalar_sub_possible(ct, scalar) { + self.full_extract_message_assign_parallelized(ct); + } + + self.unchecked_crt_scalar_sub_assign_parallelized(ct, scalar); + } +} diff --git a/tfhe/src/integer/server_key/crt_parallel/sub_crt.rs b/tfhe/src/integer/server_key/crt_parallel/sub_crt.rs new file mode 100644 index 000000000..33897c66e --- /dev/null +++ b/tfhe/src/integer/server_key/crt_parallel/sub_crt.rs @@ -0,0 +1,158 @@ +use crate::integer::{CrtCiphertext, ServerKey}; + +impl ServerKey { + /// Computes homomorphically a subtraction between two ciphertexts encrypting integer values. + /// + /// This function computes the subtraction without checking if it exceeds the capacity of the + /// ciphertext. + /// + /// The result is returned as a new ciphertext. + /// # Example + /// + ///```rust + /// use tfhe::integer::gen_keys; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // Generate the client key and the server key: + /// let (cks, sks) = gen_keys(&PARAM_MESSAGE_2_CARRY_2); + /// + /// let clear_1 = 14; + /// let clear_2 = 5; + /// let basis = vec![2, 3, 5]; + /// // Encrypt two messages + /// let mut ctxt_1 = cks.encrypt_crt(clear_1, basis.clone()); + /// let mut ctxt_2 = cks.encrypt_crt(clear_2, basis.clone()); + /// + /// let ctxt = sks.unchecked_crt_sub_parallelized(&mut ctxt_1, &mut ctxt_2); + /// + /// // Decrypt + /// let res = cks.decrypt_crt(&ctxt); + /// assert_eq!((clear_1 - clear_2) % 30, res); + /// ``` + pub fn unchecked_crt_sub_parallelized( + &self, + ctxt_left: &CrtCiphertext, + ctxt_right: &CrtCiphertext, + ) -> CrtCiphertext { + let mut result = ctxt_left.clone(); + self.unchecked_crt_sub_assign_parallelized(&mut result, ctxt_right); + result + } + + /// Computes homomorphically a subtraction between two ciphertexts encrypting integer values. + /// + /// This function computes the subtraction without checking if it exceeds the capacity of the + /// ciphertext. + /// + /// The result is assigned to the `ct_left` ciphertext. + /// # Example + /// + ///```rust + /// use tfhe::integer::gen_keys; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // Generate the client key and the server key: + /// let (cks, sks) = gen_keys(&PARAM_MESSAGE_2_CARRY_2); + /// + /// let clear_1 = 14; + /// let clear_2 = 5; + /// let basis = vec![2, 3, 5]; + /// // Encrypt two messages + /// let mut ctxt_1 = cks.encrypt_crt(clear_1, basis.clone()); + /// let mut ctxt_2 = cks.encrypt_crt(clear_2, basis.clone()); + /// + /// let ctxt = sks.unchecked_crt_sub_parallelized(&mut ctxt_1, &mut ctxt_2); + /// + /// // Decrypt + /// let res = cks.decrypt_crt(&ctxt); + /// assert_eq!((clear_1 - clear_2) % 30, res); + /// ``` + pub fn unchecked_crt_sub_assign_parallelized( + &self, + ctxt_left: &mut CrtCiphertext, + ctxt_right: &CrtCiphertext, + ) { + let neg = self.unchecked_crt_neg_parallelized(ctxt_right); + self.unchecked_crt_add_assign_parallelized(ctxt_left, &neg); + } + + /// Computes homomorphically the subtraction between ct_left and ct_right. + /// + /// # Example + /// + ///```rust + /// use tfhe::integer::gen_keys; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // Generate the client key and the server key: + /// let (cks, sks) = gen_keys(&PARAM_MESSAGE_2_CARRY_2); + /// + /// let clear_1 = 14; + /// let clear_2 = 5; + /// let basis = vec![2, 3, 5]; + /// // Encrypt two messages + /// let mut ctxt_1 = cks.encrypt_crt(clear_1, basis.clone()); + /// let mut ctxt_2 = cks.encrypt_crt(clear_2, basis.clone()); + /// + /// let ctxt = sks.smart_crt_sub_parallelized(&mut ctxt_1, &mut ctxt_2); + /// + /// // Decrypt + /// let res = cks.decrypt_crt(&ctxt); + /// assert_eq!((clear_1 - clear_2) % 30, res); + /// ``` + pub fn smart_crt_sub_parallelized( + &self, + ctxt_left: &mut CrtCiphertext, + ctxt_right: &mut CrtCiphertext, + ) -> CrtCiphertext { + // If the ciphertext cannot be added together without exceeding the capacity of a ciphertext + if !self.is_crt_sub_possible(ctxt_left, ctxt_right) { + rayon::join( + || self.full_extract_message_assign_parallelized(ctxt_left), + || self.full_extract_message_assign_parallelized(ctxt_right), + ); + } + + self.unchecked_crt_sub_parallelized(ctxt_left, ctxt_right) + } + + /// Computes homomorphically the subtraction between ct_left and ct_right. + /// + /// # Example + /// + ///```rust + /// use tfhe::integer::gen_keys; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // Generate the client key and the server key: + /// let (cks, sks) = gen_keys(&PARAM_MESSAGE_2_CARRY_2); + /// + /// let clear_1 = 14; + /// let clear_2 = 5; + /// let basis = vec![2, 3, 5]; + /// // Encrypt two messages + /// let mut ctxt_1 = cks.encrypt_crt(clear_1, basis.clone()); + /// let mut ctxt_2 = cks.encrypt_crt(clear_2, basis.clone()); + /// + /// sks.smart_crt_sub_assign_parallelized(&mut ctxt_1, &mut ctxt_2); + /// + /// // Decrypt + /// let res = cks.decrypt_crt(&ctxt_1); + /// assert_eq!((clear_1 - clear_2) % 30, res); + /// ``` + pub fn smart_crt_sub_assign_parallelized( + &self, + ctxt_left: &mut CrtCiphertext, + ctxt_right: &mut CrtCiphertext, + ) { + // If the ciphertext cannot be added together without exceeding the capacity of a ciphertext + if !self.is_crt_sub_possible(ctxt_left, ctxt_right) { + rayon::join( + || self.full_extract_message_assign_parallelized(ctxt_left), + || self.full_extract_message_assign_parallelized(ctxt_right), + ); + } + + self.unchecked_crt_sub_assign_parallelized(ctxt_left, ctxt_right); + } +} diff --git a/tfhe/src/integer/server_key/mod.rs b/tfhe/src/integer/server_key/mod.rs new file mode 100644 index 000000000..4cf9bdb50 --- /dev/null +++ b/tfhe/src/integer/server_key/mod.rs @@ -0,0 +1,92 @@ +//! Module with the definition of the ServerKey. +//! +//! This module implements the generation of the server public key, together with all the +//! available homomorphic integer operations. +mod crt; +mod crt_parallel; +mod radix; +mod radix_parallel; + +use crate::integer::client_key::ClientKey; +use crate::shortint::server_key::MaxDegree; +use serde::{Deserialize, Serialize}; + +/// Error returned when the carry buffer is full. +pub use crate::shortint::CheckError; + +/// A structure containing the server public key. +/// +/// The server key is generated by the client and is meant to be published: the client +/// sends it to the server so it can compute homomorphic integer circuits. +#[derive(Serialize, Deserialize, Clone)] +pub struct ServerKey { + pub(crate) key: crate::shortint::ServerKey, +} + +impl From for crate::shortint::ServerKey { + fn from(key: ServerKey) -> crate::shortint::ServerKey { + key.key + } +} + +impl ServerKey { + /// Generates a server key. + /// + /// # Example + /// + /// ```rust + /// use tfhe::integer::{ClientKey, ServerKey}; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // Generate the client key: + /// let cks = ClientKey::new(PARAM_MESSAGE_2_CARRY_2); + /// + /// // Generate the server key: + /// let sks = ServerKey::new(&cks); + /// ``` + pub fn new(cks: C) -> ServerKey + where + C: AsRef, + { + // It should remain just enough space to add a carry + let client_key = cks.as_ref(); + let max = (client_key.key.parameters.message_modulus.0 - 1) + * client_key.key.parameters.carry_modulus.0 + - 1; + + let sks = crate::shortint::server_key::ServerKey::new_with_max_degree( + &client_key.key, + MaxDegree(max), + ); + + ServerKey { key: sks } + } + + /// Creates a ServerKey from an already generated shortint::ServerKey. + /// + /// # Example + /// + /// ```rust + /// use tfhe::integer::{ClientKey, ServerKey}; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// let size = 4; + /// + /// // Generate the client key: + /// let cks = ClientKey::new(PARAM_MESSAGE_2_CARRY_2); + /// + /// // Generate the server key: + /// let sks = ServerKey::new(&cks); + /// ``` + pub fn from_shortint( + cks: &ClientKey, + mut key: crate::shortint::server_key::ServerKey, + ) -> ServerKey { + // It should remain just enough space add a carry + let max = + (cks.key.parameters.message_modulus.0 - 1) * cks.key.parameters.carry_modulus.0 - 1; + + key.max_degree = MaxDegree(max); + ServerKey { key } + } +} diff --git a/tfhe/src/integer/server_key/radix/add.rs b/tfhe/src/integer/server_key/radix/add.rs new file mode 100644 index 000000000..41d6650f3 --- /dev/null +++ b/tfhe/src/integer/server_key/radix/add.rs @@ -0,0 +1,245 @@ +use crate::integer::ciphertext::RadixCiphertext; +use crate::integer::server_key::CheckError; +use crate::integer::server_key::CheckError::CarryFull; +use crate::integer::ServerKey; + +impl ServerKey { + /// Computes homomorphically an addition between two ciphertexts encrypting integer values. + /// + /// This function computes the operation without checking if it exceeds the capacity of the + /// ciphertext. + /// + /// The result is returned as a new ciphertext. + /// + /// # Example + /// + /// ```rust + /// use tfhe::integer::gen_keys_radix; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // Generate the client key and the server key: + /// let num_blocks = 4; + /// let (cks, sks) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, num_blocks); + /// + /// let msg1 = 10; + /// let msg2 = 127; + /// + /// let ct1 = cks.encrypt(msg1); + /// let ct2 = cks.encrypt(msg2); + /// + /// // Compute homomorphically an addition: + /// let ct_res = sks.unchecked_add(&ct1, &ct2); + /// + /// // Decrypt: + /// let dec_result = cks.decrypt(&ct_res); + /// assert_eq!(dec_result, msg1 + msg2); + /// ``` + pub fn unchecked_add( + &self, + ct_left: &RadixCiphertext, + ct_right: &RadixCiphertext, + ) -> RadixCiphertext { + let mut result = ct_left.clone(); + self.unchecked_add_assign(&mut result, ct_right); + result + } + + /// Computes homomorphically an addition between two ciphertexts encrypting integer values. + /// + /// This function computes the operation without checking if it exceeds the capacity of the + /// ciphertext. + /// + /// The result is assigned to the `ct_left` ciphertext. + /// ```rust + /// use tfhe::integer::gen_keys_radix; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// let num_blocks = 4; + /// let (cks, sks) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, num_blocks); + /// + /// let msg1 = 28; + /// let msg2 = 127; + /// + /// let mut ct1 = cks.encrypt(msg1); + /// let ct2 = cks.encrypt(msg2); + /// + /// // Compute homomorphically an addition: + /// sks.unchecked_add_assign(&mut ct1, &ct2); + /// + /// // Decrypt: + /// let dec_ct1 = cks.decrypt(&ct1); + /// assert_eq!(dec_ct1, msg1 + msg2); + /// ``` + pub fn unchecked_add_assign(&self, ct_left: &mut RadixCiphertext, ct_right: &RadixCiphertext) { + for (ct_left_i, ct_right_i) in ct_left.blocks.iter_mut().zip(ct_right.blocks.iter()) { + self.key.unchecked_add_assign(ct_left_i, ct_right_i); + } + } + + /// Verifies if ct1 and ct2 can be added together. + /// + /// # Example + /// + ///```rust + /// use tfhe::integer::gen_keys_radix; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // Generate the client key and the server key: + /// let num_blocks = 4; + /// let (cks, sks) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, num_blocks); + /// + /// let msg1 = 46; + /// let msg2 = 87; + /// + /// let ct1 = cks.encrypt(msg1); + /// let ct2 = cks.encrypt(msg2); + /// + /// // Check if we can perform an addition + /// let res = sks.is_add_possible(&ct1, &ct2); + /// + /// assert_eq!(true, res); + /// ``` + pub fn is_add_possible(&self, ct_left: &RadixCiphertext, ct_right: &RadixCiphertext) -> bool { + for (ct_left_i, ct_right_i) in ct_left.blocks.iter().zip(ct_right.blocks.iter()) { + if !self.key.is_add_possible(ct_left_i, ct_right_i) { + return false; + } + } + true + } + + /// Computes homomorphically an addition between two ciphertexts encrypting integer values. + /// + /// If the operation can be performed, the result is returned in a new ciphertext. + /// Otherwise [CheckError::CarryFull] is returned. + /// + /// # Example + /// + /// ```rust + /// use tfhe::integer::gen_keys_radix; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // Generate the client key and the server key: + /// let num_blocks = 4; + /// let (cks, sks) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, num_blocks); + /// + /// let msg1 = 41; + /// let msg2 = 101; + /// + /// let ct1 = cks.encrypt(msg1); + /// let ct2 = cks.encrypt(msg2); + /// + /// // Compute homomorphically an addition: + /// let ct_res = sks.checked_add(&ct1, &ct2); + /// + /// match ct_res { + /// Err(x) => panic!("{:?}", x), + /// Ok(y) => { + /// let clear = cks.decrypt(&y); + /// assert_eq!(msg1 + msg2, clear); + /// } + /// } + /// ``` + pub fn checked_add( + &self, + ct_left: &RadixCiphertext, + ct_right: &RadixCiphertext, + ) -> Result { + if self.is_add_possible(ct_left, ct_right) { + let mut result = ct_left.clone(); + self.unchecked_add_assign(&mut result, ct_right); + + Ok(result) + } else { + Err(CarryFull) + } + } + + /// Computes homomorphically an addition between two ciphertexts encrypting integer values. + /// + /// If the operation can be performed, the result is stored in the `ct_left` ciphertext. + /// Otherwise [CheckError::CarryFull] is returned, and `ct_left` is not modified. + /// + /// # Example + /// + /// ```rust + /// use tfhe::integer::gen_keys_radix; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // Generate the client key and the server key: + /// let num_blocks = 4; + /// let (cks, sks) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, num_blocks); + /// + /// let msg1 = 41; + /// let msg2 = 101; + /// + /// let mut ct1 = cks.encrypt(msg1); + /// let ct2 = cks.encrypt(msg2); + /// + /// // Compute homomorphically an addition: + /// let res = sks.checked_add_assign(&mut ct1, &ct2); + /// + /// assert!(res.is_ok()); + /// + /// let clear = cks.decrypt(&ct1); + /// assert_eq!(msg1 + msg2, clear); + /// ``` + pub fn checked_add_assign( + &self, + ct_left: &mut RadixCiphertext, + ct_right: &RadixCiphertext, + ) -> Result<(), CheckError> { + if self.is_add_possible(ct_left, ct_right) { + self.unchecked_add_assign(ct_left, ct_right); + Ok(()) + } else { + Err(CarryFull) + } + } + + /// Computes homomorphically an addition between two ciphertexts encrypting integer values. + /// + /// # Example + /// + /// ```rust + /// use tfhe::integer::gen_keys_radix; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // Generate the client key and the server key: + /// let num_blocks = 4; + /// let (cks, sks) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, num_blocks); + /// + /// let msg1 = 14; + /// let msg2 = 97; + /// + /// let mut ct1 = cks.encrypt(msg1); + /// let mut ct2 = cks.encrypt(msg2); + /// + /// // Compute homomorphically an addition: + /// let ct_res = sks.smart_add(&mut ct1, &mut ct2); + /// + /// // Decrypt: + /// let dec_result = cks.decrypt(&ct_res); + /// assert_eq!(dec_result, msg1 + msg2); + /// ``` + pub fn smart_add( + &self, + ct_left: &mut RadixCiphertext, + ct_right: &mut RadixCiphertext, + ) -> RadixCiphertext { + if !self.is_add_possible(ct_left, ct_right) { + self.full_propagate(ct_left); + self.full_propagate(ct_right); + } + self.unchecked_add(ct_left, ct_right) + } + + pub fn smart_add_assign(&self, ct_left: &mut RadixCiphertext, ct_right: &mut RadixCiphertext) { + //If the ciphertext cannot be added together without exceeding the capacity of a ciphertext + if !self.is_add_possible(ct_left, ct_right) { + self.full_propagate(ct_left); + self.full_propagate(ct_right); + } + self.unchecked_add_assign(ct_left, ct_right); + } +} diff --git a/tfhe/src/integer/server_key/radix/bitwise_op.rs b/tfhe/src/integer/server_key/radix/bitwise_op.rs new file mode 100644 index 000000000..2d93c074f --- /dev/null +++ b/tfhe/src/integer/server_key/radix/bitwise_op.rs @@ -0,0 +1,604 @@ +use crate::integer::ciphertext::RadixCiphertext; +use crate::integer::ServerKey; +use crate::shortint::CheckError; +use crate::shortint::CheckError::CarryFull; + +impl ServerKey { + /// Computes homomorphically bitand between two ciphertexts encrypting integer values. + /// + /// This function computes the operation without checking if it exceeds the capacity of the + /// ciphertext. + /// + /// The result is returned as a new ciphertext. + /// + /// # Example + /// + /// ```rust + /// use tfhe::integer::gen_keys_radix; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // We have 4 * 2 = 8 bits of message + /// let size = 4; + /// let (cks, sks) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, size); + /// + /// let msg1 = 201; + /// let msg2 = 1; + /// + /// let ct1 = cks.encrypt(msg1); + /// let ct2 = cks.encrypt(msg2); + /// + /// // Compute homomorphically a bitwise and: + /// let ct_res = sks.unchecked_bitand(&ct1, &ct2); + /// + /// // Decrypt: + /// let dec = cks.decrypt(&ct_res); + /// assert_eq!(dec, msg1 & msg2); + /// ``` + pub fn unchecked_bitand( + &self, + ct_left: &RadixCiphertext, + ct_right: &RadixCiphertext, + ) -> RadixCiphertext { + let mut result = ct_left.clone(); + self.unchecked_bitand_assign(&mut result, ct_right); + result + } + + pub fn unchecked_bitand_assign( + &self, + ct_left: &mut RadixCiphertext, + ct_right: &RadixCiphertext, + ) { + for (ct_left_i, ct_right_i) in ct_left.blocks.iter_mut().zip(ct_right.blocks.iter()) { + self.key.unchecked_bitand_assign(ct_left_i, ct_right_i); + } + } + + /// Verifies if a bivariate functional pbs can be applied on ct_left and ct_right. + /// + /// # Example + /// + ///```rust + /// use tfhe::integer::gen_keys_radix; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// let size = 4; + /// + /// // Generate the client key and the server key: + /// let (cks, sks) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, size); + /// + /// let msg1 = 46; + /// let msg2 = 87; + /// + /// let ct1 = cks.encrypt(msg1); + /// let ct2 = cks.encrypt(msg2); + /// + /// let res = sks.is_functional_bivariate_pbs_possible(&ct1, &ct2); + /// + /// assert_eq!(true, res); + /// ``` + pub fn is_functional_bivariate_pbs_possible( + &self, + ct_left: &RadixCiphertext, + ct_right: &RadixCiphertext, + ) -> bool { + for (ct_left_i, ct_right_i) in ct_left.blocks.iter().zip(ct_right.blocks.iter()) { + if !self + .key + .is_functional_bivariate_pbs_possible(ct_left_i, ct_right_i) + { + return false; + } + } + true + } + + /// Computes homomorphically a bitand between two ciphertexts encrypting integer values. + /// + /// If the operation can be performed, the result is returned in a new ciphertext. + /// Otherwise [CheckError::CarryFull] is returned. + /// + /// # Example + /// + /// ```rust + /// use tfhe::integer::gen_keys_radix; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// let size = 4; + /// + /// // Generate the client key and the server key: + /// let (cks, sks) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, size); + /// + /// let msg1 = 41; + /// let msg2 = 101; + /// + /// let ct1 = cks.encrypt(msg1); + /// let ct2 = cks.encrypt(msg2); + /// + /// let ct_res = sks.checked_bitand(&ct1, &ct2); + /// + /// match ct_res { + /// Err(x) => panic!("{:?}", x), + /// Ok(y) => { + /// let clear = cks.decrypt(&y); + /// assert_eq!(msg1 & msg2, clear); + /// } + /// } + /// ``` + pub fn checked_bitand( + &self, + ct_left: &RadixCiphertext, + ct_right: &RadixCiphertext, + ) -> Result { + if self.is_functional_bivariate_pbs_possible(ct_left, ct_right) { + Ok(self.unchecked_bitand(ct_left, ct_right)) + } else { + Err(CarryFull) + } + } + + /// Computes homomorphically a bitand between two ciphertexts encrypting integer values. + /// + /// If the operation can be performed, the result is stored in the `ct_left` ciphertext. + /// Otherwise [CheckError::CarryFull] is returned, and `ct_left` is not modified. + /// + /// # Example + /// + /// ```rust + /// use tfhe::integer::gen_keys_radix; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// let size = 4; + /// + /// // Generate the client key and the server key: + /// let (cks, sks) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, size); + /// + /// let msg1 = 41; + /// let msg2 = 101; + /// + /// let mut ct1 = cks.encrypt(msg1); + /// let ct2 = cks.encrypt(msg2); + /// + /// let res = sks.checked_bitand_assign(&mut ct1, &ct2); + /// + /// assert!(res.is_ok()); + /// + /// let clear = cks.decrypt(&ct1); + /// assert_eq!(msg1 & msg2, clear); + /// ``` + pub fn checked_bitand_assign( + &self, + ct_left: &mut RadixCiphertext, + ct_right: &RadixCiphertext, + ) -> Result<(), CheckError> { + if self.is_functional_bivariate_pbs_possible(ct_left, ct_right) { + self.unchecked_bitand_assign(ct_left, ct_right); + Ok(()) + } else { + Err(CarryFull) + } + } + + /// Computes homomorphically a bitand between two ciphertexts encrypting integer values. + /// + /// # Example + /// + /// ```rust + /// use tfhe::integer::gen_keys_radix; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// let size = 4; + /// + /// // Generate the client key and the server key: + /// let (cks, sks) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, size); + /// + /// let msg1 = 14; + /// let msg2 = 97; + /// + /// let mut ct1 = cks.encrypt(msg1); + /// let mut ct2 = cks.encrypt(msg2); + /// + /// let ct_res = sks.smart_bitand(&mut ct1, &mut ct2); + /// + /// // Decrypt: + /// let dec_result = cks.decrypt(&ct_res); + /// assert_eq!(dec_result, msg1 & msg2); + /// ``` + pub fn smart_bitand( + &self, + ct_left: &mut RadixCiphertext, + ct_right: &mut RadixCiphertext, + ) -> RadixCiphertext { + if !self.is_functional_bivariate_pbs_possible(ct_left, ct_right) { + self.full_propagate(ct_left); + self.full_propagate(ct_right); + } + self.unchecked_bitand(ct_left, ct_right) + } + + pub fn smart_bitand_assign( + &self, + ct_left: &mut RadixCiphertext, + ct_right: &mut RadixCiphertext, + ) { + if !self.is_functional_bivariate_pbs_possible(ct_left, ct_right) { + self.full_propagate(ct_left); + self.full_propagate(ct_right); + } + self.unchecked_bitand_assign(ct_left, ct_right); + } + + /// Computes homomorphically bitor between two ciphertexts encrypting integer values. + /// + /// This function computes the operation without checking if it exceeds the capacity of the + /// ciphertext. + /// + /// The result is returned as a new ciphertext. + /// + /// # Example + /// + /// ```rust + /// use tfhe::integer::gen_keys_radix; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // We have 4 * 2 = 8 bits of message + /// let size = 4; + /// let (cks, sks) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, size); + /// + /// let msg1 = 200; + /// let msg2 = 1; + /// + /// let ct1 = cks.encrypt(msg1); + /// let ct2 = cks.encrypt(msg2); + /// + /// // Compute homomorphically a bitwise or: + /// let ct_res = sks.unchecked_bitor(&ct1, &ct2); + /// + /// // Decrypt: + /// let dec = cks.decrypt(&ct_res); + /// assert_eq!(dec, msg1 | msg2); + /// ``` + pub fn unchecked_bitor( + &self, + ct_left: &RadixCiphertext, + ct_right: &RadixCiphertext, + ) -> RadixCiphertext { + let mut result = ct_left.clone(); + self.unchecked_bitor_assign(&mut result, ct_right); + result + } + + pub fn unchecked_bitor_assign( + &self, + ct_left: &mut RadixCiphertext, + ct_right: &RadixCiphertext, + ) { + for (ct_left_i, ct_right_i) in ct_left.blocks.iter_mut().zip(ct_right.blocks.iter()) { + self.key.unchecked_bitor_assign(ct_left_i, ct_right_i); + } + } + + /// Computes homomorphically a bitor between two ciphertexts encrypting integer values. + /// + /// If the operation can be performed, the result is returned in a new ciphertext. + /// Otherwise [CheckError::CarryFull] is returned. + /// + /// # Example + /// + /// ```rust + /// use tfhe::integer::gen_keys_radix; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// let size = 4; + /// + /// // Generate the client key and the server key: + /// let (cks, sks) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, size); + /// + /// let msg1 = 41; + /// let msg2 = 101; + /// + /// let ct1 = cks.encrypt(msg1); + /// let ct2 = cks.encrypt(msg2); + /// + /// // Compute homomorphically an addition: + /// let ct_res = sks.checked_bitor(&ct1, &ct2); + /// + /// match ct_res { + /// Err(x) => panic!("{:?}", x), + /// Ok(y) => { + /// let clear = cks.decrypt(&y); + /// assert_eq!(msg1 | msg2, clear); + /// } + /// } + /// ``` + pub fn checked_bitor( + &self, + ct_left: &RadixCiphertext, + ct_right: &RadixCiphertext, + ) -> Result { + if self.is_functional_bivariate_pbs_possible(ct_left, ct_right) { + Ok(self.unchecked_bitor(ct_left, ct_right)) + } else { + Err(CarryFull) + } + } + + /// Computes homomorphically a bitand between two ciphertexts encrypting integer values. + /// + /// If the operation can be performed, the result is stored in the `ct_left` ciphertext. + /// Otherwise [CheckError::CarryFull] is returned, and `ct_left` is not modified. + /// + /// # Example + /// + /// ```rust + /// use tfhe::integer::gen_keys_radix; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// let size = 4; + /// + /// // Generate the client key and the server key: + /// let (cks, sks) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, size); + /// + /// let msg1 = 41; + /// let msg2 = 101; + /// + /// let mut ct1 = cks.encrypt(msg1); + /// let ct2 = cks.encrypt(msg2); + /// + /// // Compute homomorphically an addition: + /// let res = sks.checked_bitor_assign(&mut ct1, &ct2); + /// + /// assert!(res.is_ok()); + /// + /// let clear = cks.decrypt(&ct1); + /// assert_eq!(msg1 | msg2, clear); + /// ``` + pub fn checked_bitor_assign( + &self, + ct_left: &mut RadixCiphertext, + ct_right: &RadixCiphertext, + ) -> Result<(), CheckError> { + if self.is_functional_bivariate_pbs_possible(ct_left, ct_right) { + self.unchecked_bitor_assign(ct_left, ct_right); + Ok(()) + } else { + Err(CarryFull) + } + } + + /// Computes homomorphically a bitor between two ciphertexts encrypting integer values. + /// + /// # Example + /// + /// ```rust + /// use tfhe::integer::gen_keys_radix; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// let size = 4; + /// + /// // Generate the client key and the server key: + /// let (cks, sks) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, size); + /// + /// let msg1 = 14; + /// let msg2 = 97; + /// + /// let mut ct1 = cks.encrypt(msg1); + /// let mut ct2 = cks.encrypt(msg2); + /// + /// let ct_res = sks.smart_bitor(&mut ct1, &mut ct2); + /// + /// // Decrypt: + /// let dec_result = cks.decrypt(&ct_res); + /// assert_eq!(dec_result, msg1 | msg2); + /// ``` + pub fn smart_bitor( + &self, + ct_left: &mut RadixCiphertext, + ct_right: &mut RadixCiphertext, + ) -> RadixCiphertext { + if !self.is_functional_bivariate_pbs_possible(ct_left, ct_right) { + self.full_propagate(ct_left); + self.full_propagate(ct_right); + } + self.unchecked_bitor(ct_left, ct_right) + } + + pub fn smart_bitor_assign( + &self, + ct_left: &mut RadixCiphertext, + ct_right: &mut RadixCiphertext, + ) { + if !self.is_functional_bivariate_pbs_possible(ct_left, ct_right) { + self.full_propagate(ct_left); + self.full_propagate(ct_right); + } + self.unchecked_bitor_assign(ct_left, ct_right); + } + + /// Computes homomorphically bitxor between two ciphertexts encrypting integer values. + /// + /// This function computes the operation without checking if it exceeds the capacity of the + /// ciphertext. + /// + /// The result is returned as a new ciphertext. + /// + /// # Example + /// + /// ```rust + /// use tfhe::integer::gen_keys_radix; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // We have 4 * 2 = 8 bits of message + /// let size = 4; + /// let (cks, sks) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, size); + /// + /// let msg1 = 49; + /// let msg2 = 64; + /// + /// let ct1 = cks.encrypt(msg1); + /// let ct2 = cks.encrypt(msg2); + /// + /// // Compute homomorphically a bitwise xor: + /// let ct_res = sks.unchecked_bitxor(&ct1, &ct2); + /// + /// // Decrypt: + /// let dec = cks.decrypt(&ct_res); + /// assert_eq!(msg1 ^ msg2, dec); + /// ``` + pub fn unchecked_bitxor( + &self, + ct_left: &RadixCiphertext, + ct_right: &RadixCiphertext, + ) -> RadixCiphertext { + let mut result = ct_left.clone(); + self.unchecked_bitxor_assign(&mut result, ct_right); + result + } + + pub fn unchecked_bitxor_assign( + &self, + ct_left: &mut RadixCiphertext, + ct_right: &RadixCiphertext, + ) { + for (ct_left_i, ct_right_i) in ct_left.blocks.iter_mut().zip(ct_right.blocks.iter()) { + self.key.unchecked_bitxor_assign(ct_left_i, ct_right_i); + } + } + + /// Computes homomorphically a bitxor between two ciphertexts encrypting integer values. + /// + /// If the operation can be performed, the result is returned in a new ciphertext. + /// Otherwise [CheckError::CarryFull] is returned. + /// + /// # Example + /// + /// ```rust + /// use tfhe::integer::gen_keys_radix; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// let size = 4; + /// + /// // Generate the client key and the server key: + /// let (cks, sks) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, size); + /// + /// let msg1 = 41; + /// let msg2 = 101; + /// + /// let ct1 = cks.encrypt(msg1); + /// let ct2 = cks.encrypt(msg2); + /// + /// // Compute homomorphically an addition: + /// let ct_res = sks.checked_bitxor(&ct1, &ct2); + /// + /// match ct_res { + /// Err(x) => panic!("{:?}", x), + /// Ok(y) => { + /// let clear = cks.decrypt(&y); + /// assert_eq!(msg1 ^ msg2, clear); + /// } + /// } + /// ``` + pub fn checked_bitxor( + &self, + ct_left: &RadixCiphertext, + ct_right: &RadixCiphertext, + ) -> Result { + if self.is_functional_bivariate_pbs_possible(ct_left, ct_right) { + Ok(self.unchecked_bitxor(ct_left, ct_right)) + } else { + Err(CarryFull) + } + } + + /// Computes homomorphically a bitxor between two ciphertexts encrypting integer values. + /// + /// If the operation can be performed, the result is stored in the `ct_left` ciphertext. + /// Otherwise [CheckError::CarryFull] is returned, and `ct_left` is not modified. + /// + /// # Example + /// + /// ```rust + /// use tfhe::integer::gen_keys_radix; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// let size = 4; + /// + /// // Generate the client key and the server key: + /// let (cks, sks) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, size); + /// + /// let msg1 = 41; + /// let msg2 = 101; + /// + /// let mut ct1 = cks.encrypt(msg1); + /// let ct2 = cks.encrypt(msg2); + /// + /// // Compute homomorphically an addition: + /// let res = sks.checked_bitxor_assign(&mut ct1, &ct2); + /// + /// assert!(res.is_ok()); + /// + /// let clear = cks.decrypt(&ct1); + /// assert_eq!(msg1 ^ msg2, clear); + /// ``` + pub fn checked_bitxor_assign( + &self, + ct_left: &mut RadixCiphertext, + ct_right: &RadixCiphertext, + ) -> Result<(), CheckError> { + if self.is_functional_bivariate_pbs_possible(ct_left, ct_right) { + self.unchecked_bitxor_assign(ct_left, ct_right); + Ok(()) + } else { + Err(CarryFull) + } + } + + /// Computes homomorphically a bitxor between two ciphertexts encrypting integer values. + /// + /// # Example + /// + /// ```rust + /// use tfhe::integer::gen_keys_radix; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// let size = 4; + /// + /// // Generate the client key and the server key: + /// let (cks, sks) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, size); + /// + /// let msg1 = 14; + /// let msg2 = 97; + /// + /// let mut ct1 = cks.encrypt(msg1); + /// let mut ct2 = cks.encrypt(msg2); + /// + /// let ct_res = sks.smart_bitxor(&mut ct1, &mut ct2); + /// + /// // Decrypt: + /// let dec_result = cks.decrypt(&ct_res); + /// assert_eq!(dec_result, msg1 ^ msg2); + /// ``` + pub fn smart_bitxor( + &self, + ct_left: &mut RadixCiphertext, + ct_right: &mut RadixCiphertext, + ) -> RadixCiphertext { + if !self.is_functional_bivariate_pbs_possible(ct_left, ct_right) { + self.full_propagate(ct_left); + self.full_propagate(ct_right); + } + self.unchecked_bitxor(ct_left, ct_right) + } + + pub fn smart_bitxor_assign( + &self, + ct_left: &mut RadixCiphertext, + ct_right: &mut RadixCiphertext, + ) { + if !self.is_functional_bivariate_pbs_possible(ct_left, ct_right) { + self.full_propagate(ct_left); + self.full_propagate(ct_right); + } + self.unchecked_bitxor_assign(ct_left, ct_right); + } +} diff --git a/tfhe/src/integer/server_key/radix/mod.rs b/tfhe/src/integer/server_key/radix/mod.rs new file mode 100644 index 000000000..e054594c4 --- /dev/null +++ b/tfhe/src/integer/server_key/radix/mod.rs @@ -0,0 +1,117 @@ +mod add; +mod bitwise_op; +mod mul; +mod neg; +mod scalar_add; +mod scalar_mul; +mod scalar_sub; +mod shift; +mod sub; + +use super::ServerKey; + +use crate::integer::RadixCiphertext; + +#[cfg(test)] +mod tests; + +impl ServerKey { + /// Create a ciphertext filled with zeros + /// + /// # Example + /// + /// ```rust + /// use tfhe::integer::gen_keys_radix; + /// use tfhe::shortint::parameters::DEFAULT_PARAMETERS; + /// + /// let num_blocks = 4; + /// + /// // Generate the client key and the server key: + /// let (cks, sks) = gen_keys_radix(&DEFAULT_PARAMETERS, num_blocks); + /// + /// let ctxt = sks.create_trivial_zero_radix(num_blocks); + /// + /// // Decrypt: + /// let dec = cks.decrypt(&ctxt); + /// assert_eq!(0, dec); + /// ``` + pub fn create_trivial_zero_radix(&self, num_blocks: usize) -> RadixCiphertext { + let mut vec_res = Vec::with_capacity(num_blocks); + for _ in 0..num_blocks { + vec_res.push(self.key.create_trivial(0_u64)); + } + + RadixCiphertext { blocks: vec_res } + } + + /// Propagate the carry of the 'index' block to the next one. + /// + /// # Example + /// + ///```rust + /// use tfhe::integer::{gen_keys_radix, IntegerCiphertext}; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// let num_blocks = 4; + /// + /// // Generate the client key and the server key: + /// let (cks, sks) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, num_blocks); + /// + /// let msg = 7; + /// + /// let ct1 = cks.encrypt(msg); + /// let ct2 = cks.encrypt(msg); + /// + /// // Compute homomorphically an addition: + /// let mut ct_res = sks.unchecked_add(&ct1, &ct2); + /// sks.propagate(&mut ct_res, 0); + /// + /// // Decrypt one block: + /// let res = cks.decrypt_one_block(&ct_res.blocks()[1]); + /// assert_eq!(3, res); + /// ``` + pub fn propagate(&self, ctxt: &mut RadixCiphertext, index: usize) { + let carry = self.key.carry_extract(&ctxt.blocks[index]); + + ctxt.blocks[index] = self.key.message_extract(&ctxt.blocks[index]); + + //add the carry to the next block + if index < ctxt.blocks.len() - 1 { + self.key + .unchecked_add_assign(&mut ctxt.blocks[index + 1], &carry); + } + } + + /// Propagate all the carries. + /// + /// # Example + /// + ///```rust + /// use tfhe::integer::{gen_keys_radix, IntegerCiphertext}; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// let num_blocks = 4; + /// + /// // Generate the client key and the server key: + /// let (cks, sks) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, num_blocks); + /// + /// let msg = 10; + /// + /// let mut ct1 = cks.encrypt(msg); + /// let mut ct2 = cks.encrypt(msg); + /// + /// // Compute homomorphically an addition: + /// let mut ct_res = sks.unchecked_add(&mut ct1, &mut ct2); + /// sks.full_propagate(&mut ct_res); + /// + /// // Decrypt: + /// let res = cks.decrypt(&ct_res); + /// assert_eq!(msg + msg, res); + /// ``` + pub fn full_propagate(&self, ctxt: &mut RadixCiphertext) { + let len = ctxt.blocks.len(); + for i in 0..len { + self.propagate(ctxt, i); + } + } +} diff --git a/tfhe/src/integer/server_key/radix/mul.rs b/tfhe/src/integer/server_key/radix/mul.rs new file mode 100644 index 000000000..b4c382443 --- /dev/null +++ b/tfhe/src/integer/server_key/radix/mul.rs @@ -0,0 +1,272 @@ +use crate::integer::ciphertext::RadixCiphertext; +use crate::integer::ServerKey; + +impl ServerKey { + /// Computes homomorphically a multiplication between a ciphertext encrypting an integer value + /// and another encrypting a shortint value. + /// + /// This function computes the operation without checking if it exceeds the capacity of the + /// ciphertext. + /// + /// The result is assigned to the `ct_left` ciphertext. + /// + /// # Example + /// + ///```rust + /// use tfhe::integer::gen_keys_radix; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// let size = 4; + /// + /// // Generate the client key and the server key: + /// let (cks, sks) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, size); + /// + /// let clear_1 = 170; + /// let clear_2 = 3; + /// + /// // Encrypt two messages + /// let mut ct_left = cks.encrypt(clear_1); + /// let ct_right = cks.encrypt_one_block(clear_2); + /// + /// // Compute homomorphically a multiplication + /// sks.unchecked_block_mul_assign(&mut ct_left, &ct_right, 0); + /// + /// // Decrypt + /// let res = cks.decrypt(&ct_left); + /// assert_eq!((clear_1 * clear_2) % 256, res); + /// ``` + pub fn unchecked_block_mul_assign( + &self, + ct_left: &mut RadixCiphertext, + ct_right: &crate::shortint::Ciphertext, + index: usize, + ) { + *ct_left = self.unchecked_block_mul(ct_left, ct_right, index); + } + + /// Computes homomorphically a multiplication between a ciphertexts encrypting an integer + /// value and another encrypting a shortint value. + /// + /// This function computes the operation without checking if it exceeds the capacity of the + /// ciphertext. + /// + /// The result is returned as a new ciphertext. + /// + /// # Example + /// + ///```rust + /// use tfhe::integer::gen_keys_radix; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// let size = 4; + /// + /// // Generate the client key and the server key: + /// let (cks, sks) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, size); + /// + /// let clear_1 = 55; + /// let clear_2 = 3; + /// + /// // Encrypt two messages + /// let ct_left = cks.encrypt(clear_1); + /// let ct_right = cks.encrypt_one_block(clear_2); + /// + /// // Compute homomorphically a multiplication + /// let ct_res = sks.unchecked_block_mul(&ct_left, &ct_right, 0); + /// + /// // Decrypt + /// let res = cks.decrypt(&ct_res); + /// assert_eq!((clear_1 * clear_2) % 256, res); + /// ``` + pub fn unchecked_block_mul( + &self, + ct1: &RadixCiphertext, + ct2: &crate::shortint::Ciphertext, + index: usize, + ) -> RadixCiphertext { + let shifted_ct = self.blockshift(ct1, index); + + let mut result_lsb = shifted_ct.clone(); + let mut result_msb = shifted_ct; + + for res_lsb_i in result_lsb.blocks[index..].iter_mut() { + self.key.unchecked_mul_lsb_assign(res_lsb_i, ct2); + } + + let len = result_msb.blocks.len() - 1; + for res_msb_i in result_msb.blocks[index..len].iter_mut() { + self.key.unchecked_mul_msb_assign(res_msb_i, ct2); + } + + result_msb = self.blockshift(&result_msb, 1); + + self.unchecked_add(&result_lsb, &result_msb) + } + + /// Computes homomorphically a multiplication between a ciphertext encrypting integer value + /// and another encrypting a shortint value. + /// + /// The result is returned as a new ciphertext. + /// + /// # Example + /// + ///```rust + /// use tfhe::integer::gen_keys_radix; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// let size = 4; + /// + /// // Generate the client key and the server key: + /// let (cks, sks) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, size); + /// + /// let clear_1 = 170; + /// let clear_2 = 3; + /// + /// // Encrypt two messages + /// let mut ctxt_1 = cks.encrypt(clear_1); + /// let ctxt_2 = cks.encrypt_one_block(clear_2); + /// + /// // Compute homomorphically a multiplication + /// let ct_res = sks.smart_block_mul(&mut ctxt_1, &ctxt_2, 0); + /// + /// // Decrypt + /// let res = cks.decrypt(&ct_res); + /// assert_eq!((clear_1 * clear_2) % 256, res); + /// ``` + pub fn smart_block_mul( + &self, + ct1: &mut RadixCiphertext, + ct2: &crate::shortint::Ciphertext, + index: usize, + ) -> RadixCiphertext { + //Makes sure we can do the multiplications + self.full_propagate(ct1); + + let shifted_ct = self.blockshift(ct1, index); + + let mut result_lsb = shifted_ct.clone(); + let mut result_msb = shifted_ct; + + for res_lsb_i in result_lsb.blocks[index..].iter_mut() { + self.key.unchecked_mul_lsb_assign(res_lsb_i, ct2); + } + + let len = result_msb.blocks.len() - 1; + for res_msb_i in result_msb.blocks[index..len].iter_mut() { + self.key.unchecked_mul_msb_assign(res_msb_i, ct2); + } + + result_msb = self.blockshift(&result_msb, 1); + + self.smart_add(&mut result_lsb, &mut result_msb) + } + + pub fn smart_block_mul_assign( + &self, + ct1: &mut RadixCiphertext, + ct2: &crate::shortint::Ciphertext, + index: usize, + ) { + *ct1 = self.smart_block_mul(ct1, ct2, index); + } + + /// Computes homomorphically a multiplication between two ciphertexts encrypting integer values. + /// + /// This function computes the operation without checking if it exceeds the capacity of the + /// ciphertext. + /// + /// The result is assigned to the `ct_left` ciphertext. + /// # Example + /// + /// ```rust + /// use tfhe::integer::gen_keys_radix; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// let size = 4; + /// + /// // Generate the client key and the server key: + /// let (cks, sks) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, size); + /// + /// let clear_1 = 255; + /// let clear_2 = 143; + /// + /// // Encrypt two messages + /// let mut ctxt_1 = cks.encrypt(clear_1); + /// let ctxt_2 = cks.encrypt(clear_2); + /// + /// // Compute homomorphically a multiplication + /// let ct_res = sks.unchecked_mul(&mut ctxt_1, &ctxt_2); + /// + /// // Decrypt + /// let res = cks.decrypt(&ct_res); + /// assert_eq!((clear_1 * clear_2) % 256, res); + /// ``` + pub fn unchecked_mul_assign(&self, ct1: &mut RadixCiphertext, ct2: &RadixCiphertext) { + *ct1 = self.unchecked_mul(ct1, ct2); + } + + /// Computes homomorphically a multiplication between two ciphertexts encrypting integer values. + /// + /// This function computes the operation without checking if it exceeds the capacity of the + /// ciphertext. + /// + /// The result is returned as a new ciphertext. + pub fn unchecked_mul(&self, ct1: &RadixCiphertext, ct2: &RadixCiphertext) -> RadixCiphertext { + let mut result = self.create_trivial_zero_radix(ct1.blocks.len()); + + for (i, ct2_i) in ct2.blocks.iter().enumerate() { + let tmp = self.unchecked_block_mul(ct1, ct2_i, i); + + self.unchecked_add_assign(&mut result, &tmp); + } + + result + } + + /// Computes homomorphically a multiplication between two ciphertexts encrypting integer values. + /// + /// The result is assigned to the `ct_left` ciphertext. + /// + /// # Example + /// + /// ```rust + /// use tfhe::integer::gen_keys_radix; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// let size = 4; + /// + /// // Generate the client key and the server key: + /// let (cks, sks) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, size); + /// + /// let clear_1 = 170; + /// let clear_2 = 6; + /// + /// // Encrypt two messages + /// let mut ctxt_1 = cks.encrypt(clear_1); + /// let mut ctxt_2 = cks.encrypt(clear_2); + /// + /// // Compute homomorphically a multiplication + /// let ct_res = sks.smart_mul(&mut ctxt_1, &mut ctxt_2); + /// // Decrypt + /// let res = cks.decrypt(&ct_res); + /// assert_eq!((clear_1 * clear_2) % 256, res); + /// ``` + pub fn smart_mul_assign(&self, ct1: &mut RadixCiphertext, ct2: &mut RadixCiphertext) { + *ct1 = self.smart_mul(ct1, ct2); + } + + /// Computes homomorphically a multiplication between two ciphertexts encrypting integer values. + /// + /// The result is returned as a new ciphertext. + pub fn smart_mul( + &self, + ct1: &mut RadixCiphertext, + ct2: &mut RadixCiphertext, + ) -> RadixCiphertext { + self.full_propagate(ct1); + self.full_propagate(ct2); + + let mut result = self.create_trivial_zero_radix(ct1.blocks.len()); + + for (i, ct2_i) in ct2.blocks.iter().enumerate() { + let mut tmp = self.unchecked_block_mul(ct1, ct2_i, i); + self.smart_add_assign(&mut result, &mut tmp); + } + + result + } +} diff --git a/tfhe/src/integer/server_key/radix/neg.rs b/tfhe/src/integer/server_key/radix/neg.rs new file mode 100644 index 000000000..4eef372eb --- /dev/null +++ b/tfhe/src/integer/server_key/radix/neg.rs @@ -0,0 +1,215 @@ +use crate::integer::ciphertext::RadixCiphertext; +use crate::integer::server_key::CheckError; +use crate::integer::server_key::CheckError::CarryFull; +use crate::integer::ServerKey; + +impl ServerKey { + /// Homomorphically computes the opposite of a ciphertext encrypting an integer message. + /// + /// This function computes the opposite of a message without checking if it exceeds the + /// capacity of the ciphertext. + /// + /// The result is returned as a new ciphertext. + /// + /// # Example + /// + /// ```rust + /// // Encrypt two messages: + /// use tfhe::integer::gen_keys_radix; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // We have 4 * 2 = 8 bits of message + /// let size = 4; + /// let modulus = 1 << 8; + /// let (cks, sks) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, size); + /// + /// let msg = 159; + /// + /// // Encrypt a message + /// let mut ctxt = cks.encrypt(msg); + /// + /// // Compute homomorphically a negation + /// sks.unchecked_neg_assign(&mut ctxt); + /// + /// // Decrypt + /// let dec = cks.decrypt(&ctxt); + /// assert_eq!(modulus - msg, dec); + /// ``` + pub fn unchecked_neg(&self, ctxt: &RadixCiphertext) -> RadixCiphertext { + let mut result = ctxt.clone(); + + self.unchecked_neg_assign(&mut result); + + result + } + + /// Homomorphically computes the opposite of a ciphertext encrypting an integer message. + /// + /// This function computes the opposite of a message without checking if it exceeds the + /// capacity of the ciphertext. + /// + /// The result is assigned to the `ct_left` ciphertext. + pub fn unchecked_neg_assign(&self, ctxt: &mut RadixCiphertext) { + //z is used to make sure the negation doesn't fill the padding bit + let mut z; + let mut z_b; + + for i in 0..ctxt.blocks.len() { + let c_i = &mut ctxt.blocks[i]; + z = self.key.unchecked_neg_assign_with_z(c_i); + + // Subtract z/B to the next ciphertext to compensate for the addition of z + z_b = z / self.key.message_modulus.0 as u64; + + if i < ctxt.blocks.len() - 1 { + let c_j = &mut ctxt.blocks[i + 1]; + self.key.unchecked_scalar_add_assign(c_j, z_b as u8); + } + } + } + + /// Verifies if ct can be negated. + /// + /// # Example + /// + ///```rust + /// use tfhe::integer::gen_keys_radix; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // We have 4 * 2 = 8 bits of message + /// let size = 4; + /// let (cks, sks) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, size); + /// + /// let msg = 2; + /// + /// // Encrypt a message + /// let ctxt = cks.encrypt(msg); + /// + /// // Check if we can perform a negation + /// let res = sks.is_neg_possible(&ctxt); + /// + /// assert_eq!(true, res); + /// ``` + pub fn is_neg_possible(&self, ctxt: &RadixCiphertext) -> bool { + for ct_i in ctxt.blocks.iter() { + if !self.key.is_neg_possible(ct_i) { + return false; + } + } + true + } + + /// Homomorphically computes the opposite of a ciphertext encrypting an integer message. + /// + /// This function computes the opposite of a message without checking if it exceeds the + /// capacity of the ciphertext. + /// + /// The result is returned as a new ciphertext. + /// + /// # Example + /// + /// ```rust + /// use tfhe::integer::gen_keys_radix; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // We have 4 * 2 = 8 bits of message + /// let size = 4; + /// let (cks, sks) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, size); + /// + /// let msg = 1; + /// + /// // Encrypt a message + /// let ctxt = cks.encrypt(msg); + /// + /// // Compute homomorphically a negation: + /// let ct_res = sks.checked_neg(&ctxt); + /// + /// match ct_res { + /// Err(x) => panic!("{:?}", x), + /// Ok(y) => { + /// let clear = cks.decrypt(&y); + /// assert_eq!(255, clear); + /// } + /// } + /// ``` + pub fn checked_neg(&self, ctxt: &RadixCiphertext) -> Result { + //If the ciphertext cannot be negated without exceeding the capacity of a ciphertext + if self.is_neg_possible(ctxt) { + let mut result = ctxt.clone(); + self.unchecked_neg_assign(&mut result); + Ok(result) + } else { + Err(CarryFull) + } + } + + /// Homomorphically computes the opposite of a ciphertext encrypting an integer message. + /// + /// This function computes the opposite of a message without checking if it exceeds the + /// capacity of the ciphertext. + /// + /// # Example + /// + /// ```rust + /// use tfhe::integer::gen_keys_radix; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // We have 4 * 2 = 8 bits of message + /// let size = 4; + /// let modulus = 1 << 8; + /// let (cks, sks) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, size); + /// + /// let msg = 1; + /// + /// // Encrypt a message + /// let mut ct = cks.encrypt(msg); + /// + /// // Compute homomorphically a negation: + /// sks.checked_neg_assign(&mut ct); + /// + /// let clear_res = cks.decrypt(&ct); + /// assert_eq!(clear_res, (modulus - msg)); + /// ``` + pub fn checked_neg_assign(&self, ctxt: &mut RadixCiphertext) -> Result<(), CheckError> { + //If the ciphertext cannot be negated without exceeding the capacity of a ciphertext + if self.is_neg_possible(ctxt) { + self.unchecked_neg_assign(ctxt); + Ok(()) + } else { + Err(CarryFull) + } + } + + /// Homomorphically computes the opposite of a ciphertext encrypting an integer message. + /// + /// The result is returned as a new ciphertext. + /// + /// # Example + /// + /// ```rust + /// use tfhe::integer::gen_keys_radix; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // We have 4 * 2 = 8 bits of message + /// let size = 4; + /// let (cks, sks) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, size); + /// + /// let msg = 1; + /// + /// // Encrypt two messages: + /// let mut ctxt = cks.encrypt(msg); + /// + /// // Compute homomorphically a negation + /// let ct_res = sks.smart_neg(&mut ctxt); + /// + /// // Decrypt + /// let dec = cks.decrypt(&ct_res); + /// assert_eq!(255, dec); + /// ``` + pub fn smart_neg(&self, ctxt: &mut RadixCiphertext) -> RadixCiphertext { + if !self.is_neg_possible(ctxt) { + self.full_propagate(ctxt); + } + self.unchecked_neg(ctxt) + } +} diff --git a/tfhe/src/integer/server_key/radix/scalar_add.rs b/tfhe/src/integer/server_key/radix/scalar_add.rs new file mode 100644 index 000000000..cb9136b81 --- /dev/null +++ b/tfhe/src/integer/server_key/radix/scalar_add.rs @@ -0,0 +1,236 @@ +use crate::integer::ciphertext::RadixCiphertext; +use crate::integer::server_key::CheckError; +use crate::integer::server_key::CheckError::CarryFull; +use crate::integer::ServerKey; + +impl ServerKey { + /// Computes homomorphically an addition between a scalar and a ciphertext. + /// + /// This function computes the operation without checking if it exceeds the capacity of the + /// ciphertext. + /// + /// The result is returned as a new ciphertext. + /// + /// # Example + /// + /// ```rust + /// use tfhe::integer::gen_keys_radix; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // We have 4 * 2 = 8 bits of message + /// let size = 4; + /// let (cks, sks) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, size); + /// + /// let msg = 4; + /// let scalar = 40; + /// + /// let ct = cks.encrypt(msg); + /// + /// // Compute homomorphically an addition: + /// let ct_res = sks.unchecked_scalar_add(&ct, scalar); + /// + /// // Decrypt: + /// let dec = cks.decrypt(&ct_res); + /// assert_eq!(msg + scalar, dec); + /// ``` + pub fn unchecked_scalar_add(&self, ct: &RadixCiphertext, scalar: u64) -> RadixCiphertext { + let mut result = ct.clone(); + self.unchecked_scalar_add_assign(&mut result, scalar); + result + } + + /// Computes homomorphically an addition between a scalar and a ciphertext. + /// + /// This function computes the operation without checking if it exceeds the capacity of the + /// ciphertext. + /// + /// The result is assigned to the `ct_left` ciphertext. + pub fn unchecked_scalar_add_assign(&self, ct: &mut RadixCiphertext, scalar: u64) { + // Bits of message put to 1 + let mask = (self.key.message_modulus.0 - 1) as u64; + + let mut power = 1_u64; + // Put each decomposition into a new ciphertext + for ct_i in ct.blocks.iter_mut() { + let mut decomp = scalar & (mask * power); + decomp /= power; + + self.key.unchecked_scalar_add_assign(ct_i, decomp as u8); + + //modulus to the power i + power *= self.key.message_modulus.0 as u64; + } + } + + /// Verifies if a scalar can be added to a ciphertext. + /// + /// # Example + /// + ///```rust + /// use tfhe::integer::gen_keys_radix; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // We have 4 * 2 = 8 bits of message + /// let size = 4; + /// let (cks, sks) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, size); + /// + /// let msg = 2; + /// let scalar = 40; + /// + /// // Encrypt two messages: + /// let ct1 = cks.encrypt(msg); + /// let ct2 = cks.encrypt(msg); + /// + /// // Check if we can perform an addition + /// let res = sks.is_scalar_add_possible(&ct1, scalar); + /// + /// assert_eq!(true, res); + /// ``` + pub fn is_scalar_add_possible(&self, ct: &RadixCiphertext, scalar: u64) -> bool { + //Bits of message put to 1 + let mask = (self.key.message_modulus.0 - 1) as u64; + + let mut power = 1_u64; + + for ct_i in ct.blocks.iter() { + let mut decomp = scalar & (mask * power); + decomp /= power; + + if !self.key.is_scalar_add_possible(ct_i, decomp as u8) { + return false; + } + + //modulus to the power i + power *= self.key.message_modulus.0 as u64; + } + true + } + + /// Computes homomorphically an addition between a scalar and a ciphertext. + /// + /// If the operation can be performed, the result is returned in a new ciphertext. + /// Otherwise [CheckError::CarryFull] is returned. + /// + /// # Example + /// + /// ```rust + /// # fn main() -> Result<(), Box> { + /// use tfhe::integer::gen_keys_radix; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // We have 4 * 2 = 8 bits of message + /// let size = 4; + /// let (cks, sks) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, size); + /// + /// let msg = 4; + /// let scalar = 40; + /// + /// let mut ct = cks.encrypt(msg); + /// + /// // Compute homomorphically an addition: + /// let ct_res = sks.checked_scalar_add(&mut ct, scalar)?; + /// + /// // Decrypt: + /// let dec = cks.decrypt(&ct_res); + /// assert_eq!(msg + scalar, dec); + /// # Ok(()) + /// # } + /// ``` + pub fn checked_scalar_add( + &self, + ct: &RadixCiphertext, + scalar: u64, + ) -> Result { + if self.is_scalar_add_possible(ct, scalar) { + Ok(self.unchecked_scalar_add(ct, scalar)) + } else { + Err(CarryFull) + } + } + + /// Computes homomorphically an addition between a scalar and a ciphertext. + /// + /// If the operation can be performed, the result is stored in the `ct_left` ciphertext. + /// Otherwise [CheckError::CarryFull] is returned, and `ct_left` is not modified. + pub fn checked_scalar_add_assign( + &self, + ct: &mut RadixCiphertext, + scalar: u64, + ) -> Result<(), CheckError> { + if self.is_scalar_add_possible(ct, scalar) { + self.unchecked_scalar_add_assign(ct, scalar); + Ok(()) + } else { + Err(CarryFull) + } + } + + /// Computes homomorphically the addition of ciphertext with a scalar. + /// + /// The result is returned in a new ciphertext. + /// + /// # Example + /// + /// ```rust + /// use tfhe::integer::gen_keys_radix; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // We have 4 * 2 = 8 bits of message + /// let size = 4; + /// let (cks, sks) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, size); + /// + /// let msg = 4; + /// let scalar = 40; + /// + /// let mut ct = cks.encrypt(msg); + /// + /// // Compute homomorphically an addition: + /// let ct_res = sks.smart_scalar_add(&mut ct, scalar); + /// + /// // Decrypt: + /// let dec = cks.decrypt(&ct_res); + /// assert_eq!(msg + scalar, dec); + /// ``` + pub fn smart_scalar_add(&self, ct: &mut RadixCiphertext, scalar: u64) -> RadixCiphertext { + if !self.is_scalar_add_possible(ct, scalar) { + self.full_propagate(ct); + } + + let mut ct = ct.clone(); + self.unchecked_scalar_add_assign(&mut ct, scalar); + ct + } + + /// Computes homomorphically the addition of ciphertext with a scalar. + /// + /// The result is assigned to the `ct_left` ciphertext. + /// + /// # Example + /// + /// ```rust + /// use tfhe::integer::gen_keys_radix; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // We have 4 * 2 = 8 bits of message + /// let size = 4; + /// let (cks, sks) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, size); + /// + /// let msg = 129; + /// let scalar = 40; + /// + /// let mut ct = cks.encrypt(msg); + /// + /// // Compute homomorphically an addition: + /// sks.smart_scalar_add_assign(&mut ct, scalar); + /// + /// // Decrypt: + /// let dec = cks.decrypt(&ct); + /// assert_eq!(msg + scalar, dec); + /// ``` + pub fn smart_scalar_add_assign(&self, ct: &mut RadixCiphertext, scalar: u64) { + if !self.is_scalar_add_possible(ct, scalar) { + self.full_propagate(ct); + } + self.unchecked_scalar_add_assign(ct, scalar); + } +} diff --git a/tfhe/src/integer/server_key/radix/scalar_mul.rs b/tfhe/src/integer/server_key/radix/scalar_mul.rs new file mode 100644 index 000000000..8415201b6 --- /dev/null +++ b/tfhe/src/integer/server_key/radix/scalar_mul.rs @@ -0,0 +1,369 @@ +use crate::integer::ciphertext::RadixCiphertext; +use crate::integer::server_key::CheckError; +use crate::integer::server_key::CheckError::CarryFull; +use crate::integer::ServerKey; +use std::collections::BTreeMap; + +impl ServerKey { + /// Computes homomorphically a multiplication between a scalar and a ciphertext. + /// + /// This function computes the operation without checking if it exceeds the capacity of the + /// ciphertext. + /// + /// The result is returned as a new ciphertext. + /// + /// # Example + /// + /// ```rust + /// use tfhe::integer::gen_keys_radix; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // We have 4 * 2 = 8 bits of message + /// let size = 4; + /// let (cks, sks) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, size); + /// + /// let msg = 30; + /// let scalar = 3; + /// + /// let ct = cks.encrypt(msg); + /// + /// // Compute homomorphically a scalar multiplication: + /// let ct_res = sks.unchecked_small_scalar_mul(&ct, scalar); + /// + /// let clear = cks.decrypt(&ct_res); + /// assert_eq!(scalar * msg, clear); + /// ``` + pub fn unchecked_small_scalar_mul( + &self, + ctxt: &RadixCiphertext, + scalar: u64, + ) -> RadixCiphertext { + let mut ct_result = ctxt.clone(); + self.unchecked_small_scalar_mul_assign(&mut ct_result, scalar); + + ct_result + } + + pub fn unchecked_small_scalar_mul_assign(&self, ctxt: &mut RadixCiphertext, scalar: u64) { + for ct_i in ctxt.blocks.iter_mut() { + self.key.unchecked_scalar_mul_assign(ct_i, scalar as u8); + } + } + + ///Verifies if ct1 can be multiplied by scalar. + /// + /// # Example + /// + ///```rust + /// use tfhe::integer::gen_keys_radix; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // We have 4 * 2 = 8 bits of message + /// let size = 4; + /// let (cks, sks) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, size); + /// + /// let msg = 25; + /// let scalar1 = 3; + /// + /// let ct = cks.encrypt(msg); + /// + /// // Verification if the scalar multiplication can be computed: + /// let res = sks.is_small_scalar_mul_possible(&ct, scalar1); + /// + /// assert_eq!(true, res); + /// + /// let scalar2 = 7; + /// // Verification if the scalar multiplication can be computed: + /// let res = sks.is_small_scalar_mul_possible(&ct, scalar2); + /// assert_eq!(false, res); + /// ``` + pub fn is_small_scalar_mul_possible(&self, ctxt: &RadixCiphertext, scalar: u64) -> bool { + for ct_i in ctxt.blocks.iter() { + if !self.key.is_scalar_mul_possible(ct_i, scalar as u8) { + return false; + } + } + true + } + + /// Computes homomorphically a multiplication between a scalar and a ciphertext. + /// + /// If the operation can be performed, the result is returned in a new ciphertext. + /// Otherwise [CheckError::CarryFull] is returned. + /// + /// # Example + /// + /// ```rust + /// use tfhe::integer::gen_keys_radix; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // We have 4 * 2 = 8 bits of message + /// let size = 4; + /// let (cks, sks) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, size); + /// + /// let msg = 33; + /// let scalar = 3; + /// + /// let ct = cks.encrypt(msg); + /// + /// // Compute homomorphically a scalar multiplication: + /// let ct_res = sks.checked_small_scalar_mul(&ct, scalar); + /// + /// match ct_res { + /// Err(x) => panic!("{:?}", x), + /// Ok(y) => { + /// let clear = cks.decrypt(&y); + /// assert_eq!(msg * scalar, clear); + /// } + /// } + /// ``` + pub fn checked_small_scalar_mul( + &self, + ct: &RadixCiphertext, + scalar: u64, + ) -> Result { + let mut ct_result = ct.clone(); + + // If the ciphertext cannot be multiplied without exceeding the capacity of a ciphertext + if self.is_small_scalar_mul_possible(ct, scalar) { + ct_result = self.unchecked_small_scalar_mul(&ct_result, scalar); + + Ok(ct_result) + } else { + Err(CarryFull) + } + } + + /// Computes homomorphically a multiplication between a scalar and a ciphertext. + /// + /// If the operation can be performed, the result is assigned to the ciphertext given + /// as parameter. + /// Otherwise [CheckError::CarryFull] is returned. + /// + /// # Example + /// + /// ```rust + /// use tfhe::integer::gen_keys_radix; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // We have 4 * 2 = 8 bits of message + /// let size = 4; + /// let (cks, sks) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, size); + /// + /// let msg = 33; + /// let scalar = 3; + /// + /// let mut ct = cks.encrypt(msg); + /// + /// // Compute homomorphically a scalar multiplication: + /// sks.checked_small_scalar_mul_assign(&mut ct, scalar); + /// + /// let clear_res = cks.decrypt(&ct); + /// assert_eq!(clear_res, msg * scalar); + /// ``` + pub fn checked_small_scalar_mul_assign( + &self, + ct: &mut RadixCiphertext, + scalar: u64, + ) -> Result<(), CheckError> { + // If the ciphertext cannot be multiplied without exceeding the capacity of a ciphertext + if self.is_small_scalar_mul_possible(ct, scalar) { + self.unchecked_small_scalar_mul_assign(ct, scalar); + Ok(()) + } else { + Err(CarryFull) + } + } + + /// Computes homomorphically a multiplication between a scalar and a ciphertext. + /// + /// `small` means the scalar value shall fit in a __shortint block__. + /// For example, if the parameters are PARAM_MESSAGE_2_CARRY_2, + /// the scalar should fit in 2 bits. + /// + /// The result is returned as a new ciphertext. + /// + /// # Example + /// + /// ```rust + /// use tfhe::integer::gen_keys_radix; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // We have 4 * 2 = 8 bits of message + /// let modulus = 1 << 8; + /// let size = 4; + /// let (cks, sks) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, size); + /// + /// let msg = 13; + /// let scalar = 5; + /// + /// let mut ct = cks.encrypt(msg); + /// + /// // Compute homomorphically a scalar multiplication: + /// let ct_res = sks.smart_small_scalar_mul(&mut ct, scalar); + /// + /// // Decrypt: + /// let clear = cks.decrypt(&ct_res); + /// assert_eq!(msg * scalar % modulus, clear); + /// ``` + pub fn smart_small_scalar_mul( + &self, + ctxt: &mut RadixCiphertext, + scalar: u64, + ) -> RadixCiphertext { + if !self.is_small_scalar_mul_possible(ctxt, scalar) { + self.full_propagate(ctxt); + } + self.unchecked_small_scalar_mul(ctxt, scalar) + } + + /// Computes homomorphically a multiplication between a scalar and a ciphertext. + /// + /// `small` means the scalar shall value fit in a __shortint block__. + /// For example, if the parameters are PARAM_MESSAGE_2_CARRY_2, + /// the scalar should fit in 2 bits. + /// + /// The result is assigned to the input ciphertext + /// + /// # Example + /// + /// ```rust + /// use tfhe::integer::gen_keys_radix; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // We have 4 * 2 = 8 bits of message + /// let modulus = 1 << 8; + /// let size = 4; + /// let (cks, sks) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, size); + /// + /// let msg = 9; + /// let scalar = 3; + /// + /// let mut ct = cks.encrypt(msg); + /// + /// // Compute homomorphically a scalar multiplication: + /// sks.smart_small_scalar_mul_assign(&mut ct, scalar); + /// + /// // Decrypt: + /// let clear = cks.decrypt(&ct); + /// assert_eq!(msg * scalar % modulus, clear); + /// ``` + pub fn smart_small_scalar_mul_assign(&self, ctxt: &mut RadixCiphertext, scalar: u64) { + if !self.is_small_scalar_mul_possible(ctxt, scalar) { + self.full_propagate(ctxt); + } + self.unchecked_small_scalar_mul_assign(ctxt, scalar); + } + + /// # Example + /// + /// ```rust + /// use tfhe::integer::gen_keys_radix; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // We have 4 * 2 = 8 bits of message + /// let size = 4; + /// let (cks, sks) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, size); + /// + /// let msg = 1; + /// let power = 2; + /// + /// let ct = cks.encrypt(msg); + /// + /// // Compute homomorphically a scalar multiplication: + /// let ct_res = sks.blockshift(&ct, power); + /// + /// // Decrypt: + /// let clear = cks.decrypt(&ct_res); + /// assert_eq!(16, clear); + /// ``` + pub fn blockshift(&self, ctxt: &RadixCiphertext, shift: usize) -> RadixCiphertext { + let ctxt_zero = self.key.create_trivial(0_u64); + let mut result = ctxt.clone(); + + for res_i in result.blocks[..shift].iter_mut() { + *res_i = ctxt_zero.clone(); + } + + for (res_i, c_i) in result.blocks[shift..].iter_mut().zip(ctxt.blocks.iter()) { + *res_i = c_i.clone(); + } + result + } + + /// Computes homomorphically a multiplication between a scalar and a ciphertext. + /// + /// + /// # Example + /// + /// ```rust + /// use tfhe::integer::gen_keys_radix; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // We have 4 * 2 = 8 bits of message + /// let modulus = 1 << 8; + /// let size = 4; + /// let (cks, sks) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, size); + /// + /// let msg = 230; + /// let scalar = 376; + /// + /// let mut ct = cks.encrypt(msg); + /// + /// // Compute homomorphically a scalar multiplication: + /// let ct_res = sks.smart_scalar_mul(&mut ct, scalar); + /// + /// // Decrypt: + /// let clear = cks.decrypt(&ct_res); + /// assert_eq!(msg * scalar % modulus, clear); + /// ``` + pub fn smart_scalar_mul(&self, ctxt: &mut RadixCiphertext, scalar: u64) -> RadixCiphertext { + let mask = (self.key.message_modulus.0 - 1) as u64; + + //Propagate the carries before doing the multiplications + self.full_propagate(ctxt); + + //Store the computations + let mut map: BTreeMap = BTreeMap::new(); + + let mut result = self.create_trivial_zero_radix(ctxt.blocks.len()); + + let mut tmp; + + let mut b_i = 1_u64; + for i in 0..ctxt.blocks.len() { + //lambda = sum u_ib^i + let u_ib_i = scalar & (mask * b_i); + let u_i = u_ib_i / b_i; + + if u_i == 0 { + //update the power b^{i+1} + b_i *= self.key.message_modulus.0 as u64; + continue; + } else if u_i == 1 { + // tmp = ctxt * 1 * b^i + tmp = self.blockshift(ctxt, i); + } else { + tmp = map + .entry(u_i) + .or_insert_with(|| self.smart_small_scalar_mul(ctxt, u_i)) + .clone(); + + //tmp = ctxt* u_i * b^i + tmp = self.blockshift(&tmp, i); + } + + //update the result + result = self.smart_add(&mut result, &mut tmp); + + //update the power b^{i+1} + b_i *= self.key.message_modulus.0 as u64; + } + + result + } + + pub fn smart_scalar_mul_assign(&self, ctxt: &mut RadixCiphertext, scalar: u64) { + *ctxt = self.smart_scalar_mul(ctxt, scalar); + } +} diff --git a/tfhe/src/integer/server_key/radix/scalar_sub.rs b/tfhe/src/integer/server_key/radix/scalar_sub.rs new file mode 100644 index 000000000..2f5a25540 --- /dev/null +++ b/tfhe/src/integer/server_key/radix/scalar_sub.rs @@ -0,0 +1,233 @@ +use crate::integer::ciphertext::RadixCiphertext; +use crate::integer::server_key::CheckError; +use crate::integer::server_key::CheckError::CarryFull; +use crate::integer::ServerKey; + +impl ServerKey { + /// Computes homomorphically a subtraction between a ciphertext and a scalar. + /// + /// This function computes the operation without checking if it exceeds the capacity of the + /// ciphertext. + /// + /// The result is returned as a new ciphertext. + /// + /// # Example + /// + /// ```rust + /// use tfhe::integer::gen_keys_radix; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // We have 4 * 2 = 8 bits of message + /// let num_blocks = 4; + /// let (cks, sks) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, num_blocks); + /// + /// let msg = 40; + /// let scalar = 3; + /// + /// let ct = cks.encrypt(msg); + /// + /// // Compute homomorphically an addition: + /// let ct_res = sks.unchecked_scalar_sub(&ct, scalar); + /// + /// // Decrypt: + /// let dec = cks.decrypt(&ct_res); + /// assert_eq!(msg - scalar, dec); + /// ``` + pub fn unchecked_scalar_sub(&self, ct: &RadixCiphertext, scalar: u64) -> RadixCiphertext { + let mut result = ct.clone(); + self.unchecked_scalar_sub_assign(&mut result, scalar); + result + } + + pub fn unchecked_scalar_sub_assign(&self, ct: &mut RadixCiphertext, scalar: u64) { + //Bits of message put to 1 + let mask = (self.key.message_modulus.0 - 1) as u64; + + let modulus = self.key.message_modulus.0.pow(ct.blocks.len() as u32) as u64; + + let neg_scalar = scalar.wrapping_neg() % modulus; + + let mut power = 1_u64; + //Put each decomposition into a new ciphertext + for ct_i in ct.blocks.iter_mut() { + let mut decomp = neg_scalar & (mask * power); + decomp /= power; + + self.key.unchecked_scalar_add_assign(ct_i, decomp as u8); + + //modulus to the power i + power *= self.key.message_modulus.0 as u64; + } + } + + /// Verifies if the subtraction of a ciphertext by scalar can be computed. + /// + /// # Example + /// + ///```rust + /// use tfhe::integer::gen_keys_radix; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // We have 4 * 2 = 8 bits of message + /// let num_blocks = 4; + /// let (cks, sks) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, num_blocks); + /// + /// let msg = 40; + /// let scalar = 2; + /// + /// let ct1 = cks.encrypt(msg); + /// + /// // Check if we can perform an addition + /// let res = sks.is_scalar_sub_possible(&ct1, scalar); + /// + /// assert_eq!(true, res); + /// ``` + pub fn is_scalar_sub_possible(&self, ct: &RadixCiphertext, scalar: u64) -> bool { + //Bits of message put to 1 + let mask = (self.key.message_modulus.0 - 1) as u64; + + let modulus = self.key.message_modulus.0.pow(ct.blocks.len() as u32) as u64; + + let neg_scalar = scalar.wrapping_neg() % modulus; + + let mut power = 1_u64; + + for ct_i in ct.blocks.iter() { + let mut decomp = neg_scalar & (mask * power); + decomp /= power; + + if !self.key.is_scalar_add_possible(ct_i, decomp as u8) { + return false; + } + + //modulus to the power i + power *= self.key.message_modulus.0 as u64; + } + true + } + + /// Computes homomorphically a subtraction of a ciphertext by a scalar. + /// + /// If the operation can be performed, the result is returned in a new ciphertext. + /// Otherwise [CheckError::CarryFull] is returned. + /// + /// # Example + /// + /// ```rust + /// # fn main() -> Result<(), Box> { + /// use tfhe::integer::gen_keys_radix; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // We have 4 * 2 = 8 bits of message + /// let num_blocks = 4; + /// let (cks, sks) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, num_blocks); + /// + /// let msg = 40; + /// let scalar = 4; + /// + /// let ct = cks.encrypt(msg); + /// + /// // Compute tne subtraction: + /// let ct_res = sks.checked_scalar_sub(&ct, scalar)?; + /// + /// // Decrypt: + /// let dec = cks.decrypt(&ct_res); + /// assert_eq!(msg - scalar, dec); + /// # Ok(()) + /// # } + /// ``` + pub fn checked_scalar_sub( + &self, + ct: &RadixCiphertext, + scalar: u64, + ) -> Result { + if self.is_scalar_sub_possible(ct, scalar) { + Ok(self.unchecked_scalar_sub(ct, scalar)) + } else { + Err(CarryFull) + } + } + + /// Computes homomorphically a subtraction of a ciphertext by a scalar. + /// + /// If the operation can be performed, the result is returned in a new ciphertext. + /// Otherwise [CheckError::CarryFull] is returned. + /// + /// # Example + /// + /// ```rust + /// # fn main() -> Result<(), Box> { + /// use tfhe::integer::gen_keys_radix; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // We have 4 * 2 = 8 bits of message + /// let num_blocks = 4; + /// let (cks, sks) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, num_blocks); + /// + /// let msg = 232; + /// let scalar = 83; + /// + /// let mut ct = cks.encrypt(msg); + /// + /// // Compute tne subtraction: + /// sks.checked_scalar_sub_assign(&mut ct, scalar)?; + /// + /// // Decrypt: + /// let dec = cks.decrypt(&ct); + /// assert_eq!(msg - scalar, dec); + /// # Ok(()) + /// # } + /// ``` + pub fn checked_scalar_sub_assign( + &self, + ct: &mut RadixCiphertext, + scalar: u64, + ) -> Result<(), CheckError> { + if self.is_scalar_sub_possible(ct, scalar) { + self.unchecked_scalar_sub_assign(ct, scalar); + Ok(()) + } else { + Err(CarryFull) + } + } + + /// Computes homomorphically a subtraction of a ciphertext by a scalar. + /// + /// # Example + /// + /// ```rust + /// use tfhe::integer::gen_keys_radix; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // We have 4 * 2 = 8 bits of message + /// let num_blocks = 4; + /// let (cks, sks) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, num_blocks); + /// + /// let msg = 165; + /// let scalar = 112; + /// + /// let mut ct = cks.encrypt(msg); + /// + /// // Compute homomorphically an addition: + /// let ct_res = sks.smart_scalar_sub(&mut ct, scalar); + /// + /// // Decrypt: + /// let dec = cks.decrypt(&ct_res); + /// assert_eq!(msg - scalar, dec); + /// ``` + pub fn smart_scalar_sub(&self, ct: &mut RadixCiphertext, scalar: u64) -> RadixCiphertext { + if !self.is_scalar_sub_possible(ct, scalar) { + self.full_propagate(ct); + } + + self.unchecked_scalar_sub(ct, scalar) + } + + pub fn smart_scalar_sub_assign(&self, ct: &mut RadixCiphertext, scalar: u64) { + if !self.is_scalar_sub_possible(ct, scalar) { + self.full_propagate(ct); + } + + self.unchecked_scalar_sub_assign(ct, scalar); + } +} diff --git a/tfhe/src/integer/server_key/radix/shift.rs b/tfhe/src/integer/server_key/radix/shift.rs new file mode 100644 index 000000000..88064ff7c --- /dev/null +++ b/tfhe/src/integer/server_key/radix/shift.rs @@ -0,0 +1,219 @@ +use crate::integer::ciphertext::RadixCiphertext; +use crate::integer::ServerKey; + +impl ServerKey { + /// Shifts the blocks to the right. + /// + /// The result is returned as a new ciphertext. + /// + /// # Example + /// + /// ```rust + /// use tfhe::integer::gen_keys_radix; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // We have 4 * 2 = 8 bits of message + /// let num_blocks = 4; + /// let (cks, sks) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, num_blocks); + /// + /// let msg = 16; + /// let shift = 2; + /// + /// // Encrypt two messages: + /// let mut ct = cks.encrypt(msg); + /// + /// let ct_res = sks.blockshift_right(&mut ct, shift); + /// + /// let div = cks.parameters().message_modulus.0.pow(shift as u32) as u64; + /// + /// // Decrypt: + /// let clear = cks.decrypt(&ct_res); + /// assert_eq!(msg / div, clear); + /// ``` + pub fn blockshift_right(&self, ctxt: &RadixCiphertext, shift: usize) -> RadixCiphertext { + let mut result = self.create_trivial_zero_radix(ctxt.blocks.len()); + + let limit = result.blocks.len() - shift; + + for (res_i, c_i) in result.blocks[..limit] + .iter_mut() + .zip(ctxt.blocks[shift..].iter()) + { + *res_i = c_i.clone(); + } + + result + } + + pub fn blockshift_right_assign(&self, ctxt: &mut RadixCiphertext, shift: usize) { + *ctxt = self.blockshift_right(ctxt, shift); + } + + /// Computes homomorphically a right shift. + /// + /// The result is returned as a new ciphertext. + /// + /// # Example + /// + /// ```rust + /// use tfhe::integer::gen_keys_radix; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // We have 4 * 2 = 8 bits of message + /// let num_blocks = 4; + /// let (cks, sks) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, num_blocks); + /// + /// let msg = 128; + /// let shift = 2; + /// + /// let ct = cks.encrypt(msg); + /// + /// // Compute homomorphically a right shift: + /// let ct_res = sks.unchecked_scalar_right_shift(&ct, shift); + /// + /// // Decrypt: + /// let dec = cks.decrypt(&ct_res); + /// assert_eq!(msg >> shift, dec); + /// ``` + pub fn unchecked_scalar_right_shift( + &self, + ct: &RadixCiphertext, + shift: usize, + ) -> RadixCiphertext { + let mut result = ct.clone(); + self.unchecked_scalar_right_shift_assign(&mut result, shift); + result + } + + /// Computes homomorphically a right shift. + /// + /// The result is returned as a new ciphertext. + /// + /// # Example + /// + /// ```rust + /// use tfhe::integer::gen_keys_radix; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // We have 4 * 2 = 8 bits of message + /// let num_blocks = 4; + /// let (cks, sks) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, num_blocks); + /// + /// let msg = 18; + /// let shift = 4; + /// + /// let mut ct = cks.encrypt(msg); + /// + /// // Compute homomorphically a right shift: + /// sks.unchecked_scalar_right_shift_assign(&mut ct, shift); + /// + /// // Decrypt: + /// let dec = cks.decrypt(&ct); + /// assert_eq!(msg >> shift, dec); + /// ``` + pub fn unchecked_scalar_right_shift_assign(&self, ct: &mut RadixCiphertext, shift: usize) { + let tmp = self.key.message_modulus.0 as f64; + + //number of bits of message + let nb_bits = tmp.log2() as usize; + + // 2^u = 2^{p*q+r} = 2^{p*(q+1)}*2^{r-p} + let quotient = shift / nb_bits; + + //p-r + let modified_remainder = nb_bits - (shift % nb_bits); + + //if r == 0 + if modified_remainder == nb_bits { + self.full_propagate(ct); + self.blockshift_right_assign(ct, quotient); + } else { + // B/2^u = (B*2^{p-r}) / (2^{p*(q+1)}) + self.unchecked_scalar_left_shift_assign(ct, modified_remainder); + + // We partially propagate in order to not lose information + self.partial_propagate(ct); + self.blockshift_right_assign(ct, 1_usize); + + // We propagate the last block in order to not lose information + self.propagate(ct, ct.blocks.len() - 2); + self.blockshift_right_assign(ct, quotient); + } + } + + /// Propagates all carries except the last one. + /// For development purpose only. + fn partial_propagate(&self, ctxt: &mut RadixCiphertext) { + let len = ctxt.blocks.len() - 1; + for i in 0..len { + self.propagate(ctxt, i); + } + } + + /// Computes homomorphically a left shift by a scalar. + /// + /// The result is returned as a new ciphertext. + /// + /// # Example + /// + /// ```rust + /// use tfhe::integer::gen_keys_radix; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // We have 4 * 2 = 8 bits of message + /// let num_blocks = 4; + /// let (cks, sks) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, num_blocks); + /// + /// let msg = 21; + /// let shift = 2; + /// + /// let ct1 = cks.encrypt(msg); + /// + /// // Compute homomorphically a right shift: + /// let ct_res = sks.unchecked_scalar_left_shift(&ct1, shift); + /// + /// // Decrypt: + /// let dec = cks.decrypt(&ct_res); + /// assert_eq!(msg << shift, dec); + /// ``` + pub fn unchecked_scalar_left_shift( + &self, + ct_left: &RadixCiphertext, + shift: usize, + ) -> RadixCiphertext { + let mut result = ct_left.clone(); + self.unchecked_scalar_left_shift_assign(&mut result, shift); + result + } + + /// Computes homomorphically a left shift by a scalar. + /// + /// The result is assigned in the input ciphertext + /// + /// # Example + /// + /// ```rust + /// use tfhe::integer::gen_keys_radix; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // We have 4 * 2 = 8 bits of message + /// let num_blocks = 4; + /// let (cks, sks) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, num_blocks); + /// + /// let msg = 13; + /// let shift = 2; + /// + /// let mut ct = cks.encrypt(msg); + /// + /// // Compute homomorphically a right shift: + /// sks.unchecked_scalar_left_shift_assign(&mut ct, shift); + /// + /// // Decrypt: + /// let dec = cks.decrypt(&ct); + /// assert_eq!(msg << shift, dec); + /// ``` + pub fn unchecked_scalar_left_shift_assign(&self, ct: &mut RadixCiphertext, shift: usize) { + let tmp = 1_u64 << shift; + self.smart_scalar_mul_assign(ct, tmp); + } +} diff --git a/tfhe/src/integer/server_key/radix/sub.rs b/tfhe/src/integer/server_key/radix/sub.rs new file mode 100644 index 000000000..7ed220c50 --- /dev/null +++ b/tfhe/src/integer/server_key/radix/sub.rs @@ -0,0 +1,307 @@ +use crate::integer::ciphertext::RadixCiphertext; +use crate::integer::server_key::CheckError; +use crate::integer::server_key::CheckError::CarryFull; +use crate::integer::ServerKey; + +impl ServerKey { + /// Computes homomorphically a subtraction between two ciphertexts encrypting integer values. + /// + /// This function computes the subtraction without checking if it exceeds the capacity of the + /// ciphertext. + /// + /// The result is returned as a new ciphertext. + /// + /// # Example + /// + /// ```rust + /// use tfhe::integer::gen_keys_radix; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// let num_blocks = 4; + /// let (cks, sks) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, num_blocks); + /// + /// let msg_1 = 12; + /// let msg_2 = 10; + /// + /// // Encrypt two messages: + /// let ctxt_1 = cks.encrypt(msg_1); + /// let ctxt_2 = cks.encrypt(msg_2); + /// + /// // Compute homomorphically a subtraction: + /// let ct_res = sks.unchecked_sub(&ctxt_1, &ctxt_2); + /// + /// // Decrypt: + /// let dec_result = cks.decrypt(&ct_res); + /// assert_eq!(dec_result, msg_1 - msg_2); + /// ``` + pub fn unchecked_sub( + &self, + ctxt_left: &RadixCiphertext, + ctxt_right: &RadixCiphertext, + ) -> RadixCiphertext { + let mut result = ctxt_left.clone(); + self.unchecked_sub_assign(&mut result, ctxt_right); + result + } + + /// Computes homomorphically a subtraction between two ciphertexts encrypting integer values. + /// + /// This function computes the subtraction without checking if it exceeds the capacity of the + /// ciphertext. + /// + /// The result is assigned to the `ct_left` ciphertext. + /// + /// # Example + /// + /// ```rust + /// use tfhe::integer::gen_keys_radix; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // We have 4 * 2 = 8 bits of message + /// let num_blocks = 4; + /// let (cks, sks) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, num_blocks); + /// + /// let msg_1 = 128; + /// let msg_2 = 99; + /// + /// // Encrypt two messages: + /// let mut ctxt_1 = cks.encrypt(msg_1); + /// let ctxt_2 = cks.encrypt(msg_2); + /// + /// // Compute homomorphically a subtraction: + /// sks.unchecked_sub_assign(&mut ctxt_1, &ctxt_2); + /// + /// // Decrypt: + /// let dec_result = cks.decrypt(&ctxt_1); + /// assert_eq!(dec_result, msg_1 - msg_2); + /// ``` + pub fn unchecked_sub_assign( + &self, + ctxt_left: &mut RadixCiphertext, + ctxt_right: &RadixCiphertext, + ) { + let neg = self.unchecked_neg(ctxt_right); + self.unchecked_add_assign(ctxt_left, &neg); + } + + /// Verifies if ct_right can be subtracted to ct_left. + /// + /// # Example + /// + ///```rust + /// use tfhe::integer::gen_keys_radix; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // We have 4 * 2 = 8 bits of message + /// let num_blocks = 4; + /// let (cks, sks) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, num_blocks); + /// + /// let msg_1 = 182; + /// let msg_2 = 120; + /// + /// // Encrypt two messages: + /// let ctxt_1 = cks.encrypt(msg_1); + /// let ctxt_2 = cks.encrypt(msg_2); + /// + /// // Check if we can perform a subtraction + /// let res = sks.is_sub_possible(&ctxt_1, &ctxt_2); + /// + /// assert_eq!(true, res); + /// ``` + pub fn is_sub_possible( + &self, + ctxt_left: &RadixCiphertext, + ctxt_right: &RadixCiphertext, + ) -> bool { + for (ct_left_i, ct_right_i) in ctxt_left.blocks.iter().zip(ctxt_right.blocks.iter()) { + if !self.key.is_sub_possible(ct_left_i, ct_right_i) { + return false; + } + } + true + } + + /// Computes homomorphically a subtraction between two ciphertexts encrypting integer values. + /// + /// If the operation can be performed, the result is returned in a new ciphertext. + /// Otherwise [CheckError::CarryFull] is returned. + /// + /// The result is returned as a new ciphertext. + /// + /// # Example + /// + /// ```rust + /// use tfhe::integer::gen_keys_radix; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // We have 4 * 2 = 8 bits of message + /// let num_blocks = 4; + /// let (cks, sks) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, num_blocks); + /// + /// let msg = 1; + /// + /// // Encrypt two messages: + /// let ctxt_1 = cks.encrypt(msg); + /// let ctxt_2 = cks.encrypt(msg); + /// + /// // Compute homomorphically a subtraction: + /// let ct_res = sks.checked_sub(&ctxt_1, &ctxt_2); + /// + /// match ct_res { + /// Err(x) => panic!("{:?}", x), + /// Ok(y) => { + /// let clear = cks.decrypt(&y); + /// assert_eq!(0, clear); + /// } + /// } + /// ``` + pub fn checked_sub( + &self, + ctxt_left: &RadixCiphertext, + ctxt_right: &RadixCiphertext, + ) -> Result { + if self.is_sub_possible(ctxt_left, ctxt_right) { + Ok(self.unchecked_sub(ctxt_left, ctxt_right)) + } else { + Err(CarryFull) + } + } + + /// Computes homomorphically a subtraction between two ciphertexts encrypting integer values. + /// + /// If the operation can be performed, the result is returned in a new ciphertext. + /// Otherwise [CheckError::CarryFull] is returned. + /// + /// The result is assigned to the `ct_left` ciphertext. + /// + /// # Example + /// + /// ```rust + /// use tfhe::integer::gen_keys_radix; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// let num_blocks = 4; + /// + /// // Generate the client key and the server key: + /// let (cks, sks) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, num_blocks); + /// + /// let msg1 = 41u8; + /// let msg2 = 101u8; + /// + /// let mut ct1 = cks.encrypt(msg1 as u64); + /// let ct2 = cks.encrypt(msg2 as u64); + /// + /// // Compute homomorphically an addition: + /// let res = sks.checked_sub_assign(&mut ct1, &ct2); + /// + /// assert!(res.is_ok()); + /// + /// let clear = cks.decrypt(&ct1); + /// assert_eq!(msg1.wrapping_sub(msg2) as u64, clear); + /// ``` + pub fn checked_sub_assign( + &self, + ct_left: &mut RadixCiphertext, + ct_right: &RadixCiphertext, + ) -> Result<(), CheckError> { + if self.is_sub_possible(ct_left, ct_right) { + self.unchecked_sub_assign(ct_left, ct_right); + Ok(()) + } else { + Err(CarryFull) + } + } + + /// Computes homomorphically the subtraction between ct_left and ct_right. + /// + /// # Example + /// + /// ```rust + /// use tfhe::integer::gen_keys_radix; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // We have 4 * 2 = 8 bits of message + /// let num_blocks = 4; + /// let (cks, sks) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, num_blocks); + /// + /// let msg_1 = 120u8; + /// let msg_2 = 181u8; + /// + /// // Encrypt two messages: + /// let mut ctxt_1 = cks.encrypt(msg_1 as u64); + /// let mut ctxt_2 = cks.encrypt(msg_2 as u64); + /// + /// // Compute homomorphically a subtraction + /// let ct_res = sks.smart_sub(&mut ctxt_1, &mut ctxt_2); + /// + /// // Decrypt: + /// let res = cks.decrypt(&ct_res); + /// assert_eq!(msg_1.wrapping_sub(msg_2) as u64, res); + /// ``` + pub fn smart_sub( + &self, + ctxt_left: &mut RadixCiphertext, + ctxt_right: &mut RadixCiphertext, + ) -> RadixCiphertext { + // If the ciphertext cannot be negated without exceeding the capacity of a ciphertext + if !self.is_neg_possible(ctxt_right) { + self.full_propagate(ctxt_right); + } + + // If the ciphertext cannot be added together without exceeding the capacity of a ciphertext + if !self.is_sub_possible(ctxt_left, ctxt_right) { + self.full_propagate(ctxt_left); + self.full_propagate(ctxt_right); + } + + let mut result = ctxt_left.clone(); + self.unchecked_sub_assign(&mut result, ctxt_right); + + result + } + + /// Computes homomorphically the subtraction between ct_left and ct_right. + /// + /// # Example + /// + /// ```rust + /// use tfhe::integer::gen_keys_radix; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // We have 4 * 2 = 8 bits of message + /// let num_blocks = 4; + /// let (cks, sks) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, num_blocks); + /// + /// let msg_1 = 120u8; + /// let msg_2 = 181u8; + /// + /// // Encrypt two messages: + /// let mut ctxt_1 = cks.encrypt(msg_1 as u64); + /// let mut ctxt_2 = cks.encrypt(msg_2 as u64); + /// + /// // Compute homomorphically a subtraction + /// sks.smart_sub_assign(&mut ctxt_1, &mut ctxt_2); + /// + /// // Decrypt: + /// let res = cks.decrypt(&ctxt_1); + /// assert_eq!(msg_1.wrapping_sub(msg_2) as u64, res); + /// ``` + pub fn smart_sub_assign( + &self, + ctxt_left: &mut RadixCiphertext, + ctxt_right: &mut RadixCiphertext, + ) { + // If the ciphertext cannot be negated without exceeding the capacity of a ciphertext + if !self.is_neg_possible(ctxt_right) { + self.full_propagate(ctxt_right); + } + + // If the ciphertext cannot be added together without exceeding the capacity of a ciphertext + if !self.is_sub_possible(ctxt_left, ctxt_right) { + self.full_propagate(ctxt_left); + self.full_propagate(ctxt_right); + } + + self.unchecked_sub_assign(ctxt_left, ctxt_right); + } +} diff --git a/tfhe/src/integer/server_key/radix/tests.rs b/tfhe/src/integer/server_key/radix/tests.rs new file mode 100644 index 000000000..7d3b8164f --- /dev/null +++ b/tfhe/src/integer/server_key/radix/tests.rs @@ -0,0 +1,957 @@ +use crate::integer::keycache::KEY_CACHE; +use crate::shortint::parameters::*; +use crate::shortint::Parameters; +use rand::Rng; + +/// Number of loop iteration within randomized tests +const NB_TEST: usize = 30; + +/// Smaller number of loop iteration within randomized test, +/// meant for test where the function tested is more expensive +const NB_TEST_SMALLER: usize = 10; +const NB_CTXT: usize = 4; + +create_parametrized_test!(integer_encrypt_decrypt); +create_parametrized_test!(integer_unchecked_add); +create_parametrized_test!(integer_smart_add); +create_parametrized_test!(integer_unchecked_bitand); +create_parametrized_test!(integer_unchecked_bitor); +create_parametrized_test!(integer_unchecked_bitxor); +create_parametrized_test!(integer_smart_bitand); +create_parametrized_test!(integer_smart_bitor); +create_parametrized_test!(integer_smart_bitxor); +create_parametrized_test!(integer_unchecked_small_scalar_mul); +create_parametrized_test!(integer_smart_small_scalar_mul); +create_parametrized_test!(integer_blockshift); +create_parametrized_test!(integer_blockshift_right); +create_parametrized_test!(integer_smart_scalar_mul); +create_parametrized_test!(integer_unchecked_scalar_left_shift); +create_parametrized_test!(integer_unchecked_scalar_right_shift); +create_parametrized_test!(integer_unchecked_negation); +create_parametrized_test!(integer_smart_neg); +create_parametrized_test!(integer_unchecked_sub); +create_parametrized_test!(integer_smart_sub); +create_parametrized_test!(integer_unchecked_block_mul); +create_parametrized_test!(integer_smart_block_mul); +create_parametrized_test!(integer_smart_mul); + +create_parametrized_test!(integer_smart_scalar_sub); +create_parametrized_test!(integer_smart_scalar_add); +create_parametrized_test!(integer_unchecked_scalar_sub); +create_parametrized_test!(integer_unchecked_scalar_add); + +fn integer_encrypt_decrypt(param: Parameters) { + let (cks, _) = KEY_CACHE.get_from_params(param); + + // RNG + let mut rng = rand::thread_rng(); + + // message_modulus^vec_length + let modulus = param.message_modulus.0.pow(NB_CTXT as u32) as u64; + + for _ in 0..NB_TEST { + let clear = rng.gen::() % modulus; + + //encryption + let ct = cks.encrypt_radix(clear, NB_CTXT); + + // decryption + let dec = cks.decrypt_radix(&ct); + + // assert + assert_eq!(clear, dec); + } +} + +fn integer_unchecked_add(param: Parameters) { + let (cks, sks) = KEY_CACHE.get_from_params(param); + + //RNG + let mut rng = rand::thread_rng(); + + // message_modulus^vec_length + let modulus = param.message_modulus.0.pow(NB_CTXT as u32) as u64; + + for _ in 0..NB_TEST { + let clear_0 = rng.gen::() % modulus; + + let clear_1 = rng.gen::() % modulus; + + // encryption of an integer + let ctxt_0 = cks.encrypt_radix(clear_0, NB_CTXT); + + // encryption of an integer + let ctxt_1 = cks.encrypt_radix(clear_1, NB_CTXT); + + // add the two ciphertexts + let ct_res = sks.unchecked_add(&ctxt_0, &ctxt_1); + + // decryption of ct_res + let dec_res = cks.decrypt_radix(&ct_res); + + // assert + assert_eq!((clear_0 + clear_1) % modulus, dec_res); + } +} + +fn integer_smart_add(param: Parameters) { + let (cks, sks) = KEY_CACHE.get_from_params(param); + + //RNG + let mut rng = rand::thread_rng(); + + // message_modulus^vec_length + let modulus = param.message_modulus.0.pow(NB_CTXT as u32) as u64; + + let mut clear; + + for _ in 0..NB_TEST_SMALLER { + let clear_0 = rng.gen::() % modulus; + + let clear_1 = rng.gen::() % modulus; + + // encryption of an integer + let mut ctxt_0 = cks.encrypt_radix(clear_0, NB_CTXT); + + // encryption of an integer + let mut ctxt_1 = cks.encrypt_radix(clear_1, NB_CTXT); + + // add the two ciphertexts + let mut ct_res = sks.smart_add(&mut ctxt_0, &mut ctxt_1); + + clear = (clear_0 + clear_1) % modulus; + + // println!("clear_0 = {}, clear_1 = {}", clear_0, clear_1); + //add multiple times to raise the degree + for _ in 0..NB_TEST_SMALLER { + ct_res = sks.smart_add(&mut ct_res, &mut ctxt_0); + clear = (clear + clear_0) % modulus; + + // decryption of ct_res + let dec_res = cks.decrypt_radix(&ct_res); + + // println!("clear = {}, dec_res = {}", clear, dec_res); + // assert + assert_eq!(clear, dec_res); + } + } +} + +fn integer_unchecked_bitand(param: Parameters) { + let (cks, sks) = KEY_CACHE.get_from_params(param); + + //RNG + let mut rng = rand::thread_rng(); + + // message_modulus^vec_length + let modulus = param.message_modulus.0.pow(NB_CTXT as u32) as u64; + + for _ in 0..NB_TEST { + let clear_0 = rng.gen::() % modulus; + + let clear_1 = rng.gen::() % modulus; + + // encryption of an integer + let ctxt_0 = cks.encrypt_radix(clear_0, NB_CTXT); + + // encryption of an integer + let ctxt_1 = cks.encrypt_radix(clear_1, NB_CTXT); + + // add the two ciphertexts + let ct_res = sks.unchecked_bitand(&ctxt_0, &ctxt_1); + + // decryption of ct_res + let dec_res = cks.decrypt_radix(&ct_res); + + // assert + assert_eq!(clear_0 & clear_1, dec_res); + } +} + +fn integer_unchecked_bitor(param: Parameters) { + let (cks, sks) = KEY_CACHE.get_from_params(param); + + //RNG + let mut rng = rand::thread_rng(); + + // message_modulus^vec_length + let modulus = param.message_modulus.0.pow(NB_CTXT as u32) as u64; + + for _ in 0..NB_TEST { + let clear_0 = rng.gen::() % modulus; + + let clear_1 = rng.gen::() % modulus; + + // encryption of an integer + let ctxt_0 = cks.encrypt_radix(clear_0, NB_CTXT); + + // encryption of an integer + let ctxt_1 = cks.encrypt_radix(clear_1, NB_CTXT); + + // add the two ciphertexts + let ct_res = sks.unchecked_bitor(&ctxt_0, &ctxt_1); + + // decryption of ct_res + let dec_res = cks.decrypt_radix(&ct_res); + + // assert + assert_eq!(clear_0 | clear_1, dec_res); + } +} + +fn integer_unchecked_bitxor(param: Parameters) { + let (cks, sks) = KEY_CACHE.get_from_params(param); + + //RNG + let mut rng = rand::thread_rng(); + + // message_modulus^vec_length + let modulus = param.message_modulus.0.pow(NB_CTXT as u32) as u64; + + for _ in 0..NB_TEST { + let clear_0 = rng.gen::() % modulus; + + let clear_1 = rng.gen::() % modulus; + + // encryption of an integer + let ctxt_0 = cks.encrypt_radix(clear_0, NB_CTXT); + + // encryption of an integer + let ctxt_1 = cks.encrypt_radix(clear_1, NB_CTXT); + + // add the two ciphertexts + let ct_res = sks.unchecked_bitxor(&ctxt_0, &ctxt_1); + + // decryption of ct_res + let dec_res = cks.decrypt_radix(&ct_res); + + // assert + assert_eq!(clear_0 ^ clear_1, dec_res); + } +} + +fn integer_smart_bitand(param: Parameters) { + let (cks, sks) = KEY_CACHE.get_from_params(param); + + //RNG + let mut rng = rand::thread_rng(); + + // message_modulus^vec_length + let modulus = param.message_modulus.0.pow(NB_CTXT as u32) as u64; + + let mut clear; + + for _ in 0..NB_TEST_SMALLER { + let clear_0 = rng.gen::() % modulus; + + let clear_1 = rng.gen::() % modulus; + + // encryption of an integer + let mut ctxt_0 = cks.encrypt_radix(clear_0, NB_CTXT); + + // encryption of an integer + let mut ctxt_1 = cks.encrypt_radix(clear_1, NB_CTXT); + + // add the two ciphertexts + let mut ct_res = sks.smart_bitand(&mut ctxt_0, &mut ctxt_1); + + clear = clear_0 & clear_1; + + for _ in 0..NB_TEST_SMALLER { + let clear_2 = rng.gen::() % modulus; + + // encryption of an integer + let mut ctxt_2 = cks.encrypt_radix(clear_2, NB_CTXT); + + ct_res = sks.smart_bitand(&mut ct_res, &mut ctxt_2); + clear &= clear_2; + + // decryption of ct_res + let dec_res = cks.decrypt_radix(&ct_res); + + // assert + assert_eq!(clear, dec_res); + } + } +} + +fn integer_smart_bitor(param: Parameters) { + let (cks, sks) = KEY_CACHE.get_from_params(param); + + //RNG + let mut rng = rand::thread_rng(); + + // message_modulus^vec_length + let modulus = param.message_modulus.0.pow(NB_CTXT as u32) as u64; + + let mut clear; + + for _ in 0..NB_TEST_SMALLER { + let clear_0 = rng.gen::() % modulus; + + let clear_1 = rng.gen::() % modulus; + + // encryption of an integer + let mut ctxt_0 = cks.encrypt_radix(clear_0, NB_CTXT); + + // encryption of an integer + let mut ctxt_1 = cks.encrypt_radix(clear_1, NB_CTXT); + + // add the two ciphertexts + let mut ct_res = sks.smart_bitor(&mut ctxt_0, &mut ctxt_1); + + clear = (clear_0 | clear_1) % modulus; + + for _ in 0..1 { + let clear_2 = rng.gen::() % modulus; + + // encryption of an integer + let mut ctxt_2 = cks.encrypt_radix(clear_2, NB_CTXT); + + ct_res = sks.smart_bitor(&mut ct_res, &mut ctxt_2); + clear = (clear | clear_2) % modulus; + + // decryption of ct_res + let dec_res = cks.decrypt_radix(&ct_res); + + // assert + assert_eq!(clear, dec_res); + } + } +} + +fn integer_smart_bitxor(param: Parameters) { + let (cks, sks) = KEY_CACHE.get_from_params(param); + + //RNG + let mut rng = rand::thread_rng(); + + // message_modulus^vec_length + let modulus = param.message_modulus.0.pow(NB_CTXT as u32) as u64; + + let mut clear; + + for _ in 0..NB_TEST_SMALLER { + let clear_0 = rng.gen::() % modulus; + + let clear_1 = rng.gen::() % modulus; + + // encryption of an integer + let mut ctxt_0 = cks.encrypt_radix(clear_0, NB_CTXT); + + // encryption of an integer + let mut ctxt_1 = cks.encrypt_radix(clear_1, NB_CTXT); + + // add the two ciphertexts + let mut ct_res = sks.smart_bitxor(&mut ctxt_0, &mut ctxt_1); + + clear = (clear_0 ^ clear_1) % modulus; + + for _ in 0..NB_TEST_SMALLER { + let clear_2 = rng.gen::() % modulus; + + // encryption of an integer + let mut ctxt_2 = cks.encrypt_radix(clear_2, NB_CTXT); + + ct_res = sks.smart_bitxor(&mut ct_res, &mut ctxt_2); + clear = (clear ^ clear_2) % modulus; + + // decryption of ct_res + let dec_res = cks.decrypt_radix(&ct_res); + + // assert + assert_eq!(clear, dec_res); + } + } +} + +fn integer_unchecked_small_scalar_mul(param: Parameters) { + let (cks, sks) = KEY_CACHE.get_from_params(param); + + //RNG + let mut rng = rand::thread_rng(); + + // message_modulus^vec_length + let modulus = param.message_modulus.0.pow(NB_CTXT as u32) as u64; + + let scalar_modulus = param.message_modulus.0 as u64; + + for _ in 0..NB_TEST { + let clear = rng.gen::() % modulus; + + let scalar = rng.gen::() % scalar_modulus; + + // encryption of an integer + let ct = cks.encrypt_radix(clear, NB_CTXT); + + // add the two ciphertexts + let ct_res = sks.unchecked_small_scalar_mul(&ct, scalar); + + // decryption of ct_res + let dec_res = cks.decrypt_radix(&ct_res); + + // assert + assert_eq!((clear * scalar) % modulus, dec_res); + } +} + +fn integer_smart_small_scalar_mul(param: Parameters) { + let (cks, sks) = KEY_CACHE.get_from_params(param); + + //RNG + let mut rng = rand::thread_rng(); + + // message_modulus^vec_length + let modulus = param.message_modulus.0.pow(NB_CTXT as u32) as u64; + + let scalar_modulus = param.message_modulus.0 as u64; + + let mut clear_res; + for _ in 0..NB_TEST_SMALLER { + let clear = rng.gen::() % modulus; + + let scalar = rng.gen::() % scalar_modulus; + + // encryption of an integer + let mut ct = cks.encrypt_radix(clear, NB_CTXT); + + let mut ct_res = sks.smart_small_scalar_mul(&mut ct, scalar); + + clear_res = clear * scalar; + for _ in 0..NB_TEST_SMALLER { + // scalar multiplication + ct_res = sks.smart_small_scalar_mul(&mut ct_res, scalar); + clear_res *= scalar; + } + + // decryption of ct_res + let dec_res = cks.decrypt_radix(&ct_res); + + // assert + assert_eq!(clear_res % modulus, dec_res); + } +} + +fn integer_blockshift(param: Parameters) { + let (cks, sks) = KEY_CACHE.get_from_params(param); + + //RNG + let mut rng = rand::thread_rng(); + + // message_modulus^vec_length + let modulus = param.message_modulus.0.pow(NB_CTXT as u32) as u64; + + for _ in 0..NB_TEST { + let clear = rng.gen::() % modulus; + + let power = rng.gen::() % NB_CTXT as u64; + + // encryption of an integer + let ct = cks.encrypt_radix(clear, NB_CTXT); + + // add the two ciphertexts + let ct_res = sks.blockshift(&ct, power as usize); + + // decryption of ct_res + let dec_res = cks.decrypt_radix(&ct_res); + + // assert + assert_eq!( + (clear * param.message_modulus.0.pow(power as u32) as u64) % modulus, + dec_res + ); + } +} + +fn integer_blockshift_right(param: Parameters) { + let (cks, sks) = KEY_CACHE.get_from_params(param); + + //RNG + let mut rng = rand::thread_rng(); + + // message_modulus^vec_length + let modulus = param.message_modulus.0.pow(NB_CTXT as u32) as u64; + + for _ in 0..NB_TEST { + let clear = rng.gen::() % modulus; + + let power = rng.gen::() % NB_CTXT as u64; + + // encryption of an integer + let ct = cks.encrypt_radix(clear, NB_CTXT); + + // add the two ciphertexts + let ct_res = sks.blockshift_right(&ct, power as usize); + + // decryption of ct_res + let dec_res = cks.decrypt_radix(&ct_res); + + // assert + assert_eq!( + (clear / param.message_modulus.0.pow(power as u32) as u64) % modulus, + dec_res + ); + } +} + +fn integer_smart_scalar_mul(param: Parameters) { + let (cks, sks) = KEY_CACHE.get_from_params(param); + + //RNG + let mut rng = rand::thread_rng(); + + // message_modulus^vec_length + let modulus = param.message_modulus.0.pow(NB_CTXT as u32) as u64; + + for _ in 0..NB_TEST { + let clear = rng.gen::() % modulus; + + let scalar = rng.gen::() % modulus; + + // encryption of an integer + let mut ct = cks.encrypt_radix(clear, NB_CTXT); + + // scalar mul + let ct_res = sks.smart_scalar_mul(&mut ct, scalar); + + // decryption of ct_res + let dec_res = cks.decrypt_radix(&ct_res); + + // assert + assert_eq!((clear * scalar) % modulus, dec_res); + } +} + +fn integer_unchecked_scalar_left_shift(param: Parameters) { + let (cks, sks) = KEY_CACHE.get_from_params(param); + + //RNG + let mut rng = rand::thread_rng(); + + // message_modulus^vec_length + let modulus = param.message_modulus.0.pow(NB_CTXT as u32) as u64; + + //Nb of bits to shift + let tmp_f64 = param.message_modulus.0 as f64; + let nb_bits = tmp_f64.log2().floor() as usize * NB_CTXT; + + for _ in 0..NB_TEST { + let clear = rng.gen::() % modulus; + + let scalar = rng.gen::() % nb_bits; + + // encryption of an integer + let ct = cks.encrypt_radix(clear, NB_CTXT); + + // add the two ciphertexts + let ct_res = sks.unchecked_scalar_left_shift(&ct, scalar); + + // decryption of ct_res + let dec_res = cks.decrypt_radix(&ct_res); + + // assert + assert_eq!((clear << scalar) % modulus, dec_res); + } +} + +fn integer_unchecked_scalar_right_shift(param: Parameters) { + let (cks, sks) = KEY_CACHE.get_from_params(param); + + //RNG + let mut rng = rand::thread_rng(); + + // message_modulus^vec_length + let modulus = param.message_modulus.0.pow(NB_CTXT as u32) as u64; + + //Nb of bits to shift + let tmp_f64 = param.message_modulus.0 as f64; + let nb_bits = tmp_f64.log2().floor() as usize * NB_CTXT; + + for _ in 0..NB_TEST { + let clear = rng.gen::() % modulus; + + let scalar = rng.gen::() % nb_bits; + + // encryption of an integer + let ct = cks.encrypt_radix(clear, NB_CTXT); + + // add the two ciphertexts + let ct_res = sks.unchecked_scalar_right_shift(&ct, scalar); + + // decryption of ct_res + let dec_res = cks.decrypt_radix(&ct_res); + + // assert + assert_eq!(clear >> scalar, dec_res); + } +} + +fn integer_unchecked_negation(param: Parameters) { + let (cks, sks) = KEY_CACHE.get_from_params(param); + + //RNG + let mut rng = rand::thread_rng(); + + // message_modulus^vec_length + let modulus = param.message_modulus.0.pow(NB_CTXT as u32) as u64; + + for _ in 0..NB_TEST { + // Define the cleartexts + let clear = rng.gen::() % modulus; + + // println!("clear = {}", clear); + + // Encrypt the integers + let ctxt = cks.encrypt_radix(clear, NB_CTXT); + + // Negates the ctxt + let ct_tmp = sks.unchecked_neg(&ctxt); + + // Decrypt the result + let dec = cks.decrypt_radix(&ct_tmp); + + // Check the correctness + let clear_result = clear.wrapping_neg() % modulus; + + //println!("clear = {}", clear); + assert_eq!(clear_result, dec); + } +} + +fn integer_smart_neg(param: Parameters) { + let (cks, sks) = KEY_CACHE.get_from_params(param); + + //RNG + let mut rng = rand::thread_rng(); + + // message_modulus^vec_length + let modulus = param.message_modulus.0.pow(NB_CTXT as u32) as u64; + + for _ in 0..NB_TEST { + // Define the cleartexts + let clear = rng.gen::() % modulus; + + // Encrypt the integers + let mut ctxt = cks.encrypt_radix(clear, NB_CTXT); + + // Negates the ctxt + let ct_tmp = sks.smart_neg(&mut ctxt); + + // Decrypt the result + let dec = cks.decrypt_radix(&ct_tmp); + + // Check the correctness + let clear_result = clear.wrapping_neg() % modulus; + + assert_eq!(clear_result, dec); + } +} + +fn integer_unchecked_sub(param: Parameters) { + let (cks, sks) = KEY_CACHE.get_from_params(param); + + // RNG + let mut rng = rand::thread_rng(); + + // message_modulus^vec_length + let modulus = param.message_modulus.0.pow(NB_CTXT as u32) as u64; + + for _ in 0..NB_TEST { + // Define the cleartexts + let clear1 = rng.gen::() % modulus; + let clear2 = rng.gen::() % modulus; + + // Encrypt the integers + let ctxt_1 = cks.encrypt_radix(clear1, NB_CTXT); + let ctxt_2 = cks.encrypt_radix(clear2, NB_CTXT); + + // Add the ciphertext 1 and 2 + let ct_tmp = sks.unchecked_sub(&ctxt_1, &ctxt_2); + + // Decrypt the result + let dec = cks.decrypt_radix(&ct_tmp); + + // Check the correctness + let clear_result = (clear1 - clear2) % modulus; + assert_eq!(clear_result, dec); + } +} + +fn integer_smart_sub(param: Parameters) { + let (cks, sks) = KEY_CACHE.get_from_params(param); + + //RNG + let mut rng = rand::thread_rng(); + + // message_modulus^vec_length + let modulus = param.message_modulus.0.pow(NB_CTXT as u32) as u64; + + for _ in 0..NB_TEST_SMALLER { + // Define the cleartexts + let clear1 = rng.gen::() % modulus; + let clear2 = rng.gen::() % modulus; + + // Encrypt the integers + let ctxt_1 = cks.encrypt_radix(clear1, NB_CTXT); + let mut ctxt_2 = cks.encrypt_radix(clear2, NB_CTXT); + + let mut res = ctxt_1.clone(); + let mut clear = clear1; + + //subtract multiple times to raise the degree + for _ in 0..NB_TEST_SMALLER { + res = sks.smart_sub(&mut res, &mut ctxt_2); + clear = (clear - clear2) % modulus; + // println!("clear = {}, clear2 = {}", clear, cks.decrypt(&res)); + } + let dec = cks.decrypt_radix(&res); + + // Check the correctness + assert_eq!(clear, dec); + } +} + +fn integer_unchecked_block_mul(param: Parameters) { + let (cks, sks) = KEY_CACHE.get_from_params(param); + + //RNG + let mut rng = rand::thread_rng(); + + // message_modulus^vec_length + let modulus = param.message_modulus.0.pow(NB_CTXT as u32) as u64; + + let block_modulus = param.message_modulus.0 as u64; + + for _ in 0..NB_TEST { + let clear_0 = rng.gen::() % modulus; + + let clear_1 = rng.gen::() % block_modulus; + + // encryption of an integer + let ct_zero = cks.encrypt_radix(clear_0, NB_CTXT); + + // encryption of an integer + let ct_one = cks.encrypt_one_block(clear_1); + + // add the two ciphertexts + let ct_res = sks.unchecked_block_mul(&ct_zero, &ct_one, 0); + + // decryption of ct_res + let dec_res = cks.decrypt_radix(&ct_res); + + // assert + assert_eq!((clear_0 * clear_1) % modulus, dec_res); + } +} + +fn integer_smart_block_mul(param: Parameters) { + let (cks, sks) = KEY_CACHE.get_from_params(param); + + //RNG + let mut rng = rand::thread_rng(); + + // message_modulus^vec_length + let modulus = param.message_modulus.0.pow(NB_CTXT as u32) as u64; + + let block_modulus = param.message_modulus.0 as u64; + + for _ in 0..5 { + // Define the cleartexts + let clear1 = rng.gen::() % modulus; + let clear2 = rng.gen::() % block_modulus; + + // Encrypt the integers + let ctxt_1 = cks.encrypt_radix(clear1, NB_CTXT); + let ctxt_2 = cks.encrypt_one_block(clear2); + + let mut res = ctxt_1.clone(); + let mut clear = clear1; + + res = sks.smart_block_mul(&mut res, &ctxt_2, 0); + for _ in 0..5 { + res = sks.smart_block_mul(&mut res, &ctxt_2, 0); + clear = (clear * clear2) % modulus; + } + let dec = cks.decrypt_radix(&res); + + clear = (clear * clear2) % modulus; + + // Check the correctness + assert_eq!(clear, dec); + } +} + +fn integer_smart_mul(param: Parameters) { + let (cks, sks) = KEY_CACHE.get_from_params(param); + + //RNG + let mut rng = rand::thread_rng(); + + // message_modulus^vec_length + let modulus = param.message_modulus.0.pow(NB_CTXT as u32) as u64; + + for _ in 0..NB_TEST_SMALLER { + // Define the cleartexts + let clear1 = rng.gen::() % modulus; + let clear2 = rng.gen::() % modulus; + + // println!("clear1 = {}, clear2 = {}", clear1, clear2); + + // Encrypt the integers + let ctxt_1 = cks.encrypt_radix(clear1, NB_CTXT); + let mut ctxt_2 = cks.encrypt_radix(clear2, NB_CTXT); + + let mut res = ctxt_1.clone(); + let mut clear = clear1; + + res = sks.smart_mul(&mut res, &mut ctxt_2); + for _ in 0..5 { + res = sks.smart_mul(&mut res, &mut ctxt_2); + clear = (clear * clear2) % modulus; + } + let dec = cks.decrypt_radix(&res); + + clear = (clear * clear2) % modulus; + + // Check the correctness + assert_eq!(clear, dec); + } +} + +fn integer_unchecked_scalar_add(param: Parameters) { + let (cks, sks) = KEY_CACHE.get_from_params(param); + + //RNG + let mut rng = rand::thread_rng(); + + // message_modulus^vec_length + let modulus = param.message_modulus.0.pow(NB_CTXT as u32) as u64; + + for _ in 0..NB_TEST { + let clear_0 = rng.gen::() % modulus; + + let clear_1 = rng.gen::() % modulus; + + // encryption of an integer + let ctxt_0 = cks.encrypt_radix(clear_0, NB_CTXT); + + // add the two ciphertexts + let ct_res = sks.unchecked_scalar_add(&ctxt_0, clear_1); + + // decryption of ct_res + let dec_res = cks.decrypt_radix(&ct_res); + + // assert + assert_eq!((clear_0 + clear_1) % modulus, dec_res); + } +} + +fn integer_smart_scalar_add(param: Parameters) { + // generate the server-client key set + let (cks, sks) = KEY_CACHE.get_from_params(param); + + // message_modulus^vec_length + let modulus = param.message_modulus.0.pow(NB_CTXT as u32) as u64; + + let mut clear; + + // RNG + let mut rng = rand::thread_rng(); + + for _ in 0..NB_TEST_SMALLER { + let clear_0 = rng.gen::() % modulus; + + let clear_1 = rng.gen::() % modulus; + + // encryption of an integer + let mut ctxt_0 = cks.encrypt_radix(clear_0, NB_CTXT); + + // add the two ciphertexts + let mut ct_res = sks.smart_scalar_add(&mut ctxt_0, clear_1); + + clear = (clear_0 + clear_1) % modulus; + + // println!("clear_0 = {}, clear_1 = {}", clear_0, clear_1); + //add multiple times to raise the degree + for _ in 0..NB_TEST_SMALLER { + ct_res = sks.smart_scalar_add(&mut ct_res, clear_1); + clear = (clear + clear_1) % modulus; + + // decryption of ct_res + let dec_res = cks.decrypt_radix(&ct_res); + + // println!("clear = {}, dec_res = {}", clear, dec_res); + // assert + assert_eq!(clear, dec_res); + } + } +} + +fn integer_unchecked_scalar_sub(param: Parameters) { + let (cks, sks) = KEY_CACHE.get_from_params(param); + + //RNG + let mut rng = rand::thread_rng(); + + // message_modulus^vec_length + let modulus = param.message_modulus.0.pow(NB_CTXT as u32) as u64; + + for _ in 0..NB_TEST { + let clear_0 = rng.gen::() % modulus; + + let clear_1 = rng.gen::() % modulus; + + // encryption of an integer + let ctxt_0 = cks.encrypt_radix(clear_0, NB_CTXT); + + // add the two ciphertexts + let ct_res = sks.unchecked_scalar_sub(&ctxt_0, clear_1); + + // decryption of ct_res + let dec_res = cks.decrypt_radix(&ct_res); + + // assert + assert_eq!((clear_0 - clear_1) % modulus, dec_res); + } +} + +fn integer_smart_scalar_sub(param: Parameters) { + // generate the server-client key set + let (cks, sks) = KEY_CACHE.get_from_params(param); + + // message_modulus^vec_length + let modulus = param.message_modulus.0.pow(NB_CTXT as u32) as u64; + + let mut clear; + + // RNG + let mut rng = rand::thread_rng(); + + for _ in 0..NB_TEST_SMALLER { + let clear_0 = rng.gen::() % modulus; + + let clear_1 = rng.gen::() % modulus; + + // encryption of an integer + let mut ctxt_0 = cks.encrypt_radix(clear_0, NB_CTXT); + + // add the two ciphertexts + let mut ct_res = sks.smart_scalar_sub(&mut ctxt_0, clear_1); + + clear = (clear_0 - clear_1) % modulus; + + // println!("clear_0 = {}, clear_1 = {}", clear_0, clear_1); + //add multiple times to raise the degree + for _ in 0..NB_TEST_SMALLER { + ct_res = sks.smart_scalar_sub(&mut ct_res, clear_1); + clear = (clear - clear_1) % modulus; + + // decryption of ct_res + let dec_res = cks.decrypt_radix(&ct_res); + + // println!("clear = {}, dec_res = {}", clear, dec_res); + // assert + assert_eq!(clear, dec_res); + } + } +} diff --git a/tfhe/src/integer/server_key/radix_parallel/add.rs b/tfhe/src/integer/server_key/radix_parallel/add.rs new file mode 100644 index 000000000..bb9e41777 --- /dev/null +++ b/tfhe/src/integer/server_key/radix_parallel/add.rs @@ -0,0 +1,146 @@ +use std::sync::Mutex; + +use crate::integer::ciphertext::RadixCiphertext; +use crate::integer::ServerKey; + +impl ServerKey { + /// Computes homomorphically an addition between two ciphertexts encrypting integer values. + /// + /// # Warning + /// + /// - Multithreaded + /// + /// # Example + /// + /// ```rust + /// use tfhe::integer::gen_keys_radix; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // Generate the client key and the server key: + /// let num_blocks = 4; + /// let (cks, sks) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, num_blocks); + /// + /// let msg1 = 14; + /// let msg2 = 97; + /// + /// let mut ct1 = cks.encrypt(msg1); + /// let mut ct2 = cks.encrypt(msg2); + /// + /// // Compute homomorphically an addition: + /// let ct_res = sks.smart_add_parallelized(&mut ct1, &mut ct2); + /// + /// // Decrypt: + /// let dec_result = cks.decrypt(&ct_res); + /// assert_eq!(dec_result, msg1 + msg2); + /// ``` + pub fn smart_add_parallelized( + &self, + ct_left: &mut RadixCiphertext, + ct_right: &mut RadixCiphertext, + ) -> RadixCiphertext { + if !self.is_add_possible(ct_left, ct_right) { + rayon::join( + || self.full_propagate_parallelized(ct_left), + || self.full_propagate_parallelized(ct_right), + ); + } + self.unchecked_add(ct_left, ct_right) + } + + pub fn smart_add_assign_parallelized( + &self, + ct_left: &mut RadixCiphertext, + ct_right: &mut RadixCiphertext, + ) { + if !self.is_add_possible(ct_left, ct_right) { + rayon::join( + || self.full_propagate_parallelized(ct_left), + || self.full_propagate_parallelized(ct_right), + ); + } + self.unchecked_add_assign(ct_left, ct_right); + } + + /// op must be associative and commutative + pub fn smart_binary_op_seq_parallelized<'this, 'item>( + &'this self, + ct_seq: impl IntoIterator, + op: impl for<'a> Fn( + &'a ServerKey, + &'a mut RadixCiphertext, + &'a mut RadixCiphertext, + ) -> RadixCiphertext + + Sync, + ) -> Option { + enum CiphertextCow<'a> { + Borrowed(&'a mut RadixCiphertext), + Owned(RadixCiphertext), + } + impl CiphertextCow<'_> { + fn as_mut(&mut self) -> &mut RadixCiphertext { + match self { + CiphertextCow::Borrowed(b) => b, + CiphertextCow::Owned(o) => o, + } + } + } + + let ct_seq = ct_seq + .into_iter() + .map(CiphertextCow::Borrowed) + .collect::>(); + let op = &op; + + // overhead of dynamic dispatch is negligible compared to multithreading, PBS, etc. + // we defer all calls to a single implementation to avoid code bloat and long compile + // times + fn reduce_impl( + sks: &ServerKey, + mut ct_seq: Vec, + op: &(dyn for<'a> Fn( + &'a ServerKey, + &'a mut RadixCiphertext, + &'a mut RadixCiphertext, + ) -> RadixCiphertext + + Sync), + ) -> Option { + use rayon::prelude::*; + + if ct_seq.is_empty() { + None + } else { + // we repeatedly divide the number of terms by two by iteratively reducing + // consecutive terms in the array + while ct_seq.len() > 1 { + let results = + Mutex::new(Vec::::with_capacity(ct_seq.len() / 2)); + + // if the number of elements is odd, we skip the first element + let untouched_prefix = ct_seq.len() % 2; + let ct_seq_slice = &mut ct_seq[untouched_prefix..]; + + ct_seq_slice.par_chunks_mut(2).for_each(|chunk| { + let (first, second) = chunk.split_at_mut(1); + let first = &mut first[0]; + let second = &mut second[0]; + let result = op(sks, first.as_mut(), second.as_mut()); + results.lock().unwrap().push(result); + }); + + let results = results.into_inner().unwrap(); + ct_seq.truncate(untouched_prefix); + ct_seq.extend(results.into_iter().map(CiphertextCow::Owned)); + } + + let sum = ct_seq.pop().unwrap(); + + Some(match sum { + CiphertextCow::Borrowed(b) => b.clone(), + CiphertextCow::Owned(o) => o, + }) + } + } + + reduce_impl(self, ct_seq, op) + } +} diff --git a/tfhe/src/integer/server_key/radix_parallel/bitwise_op.rs b/tfhe/src/integer/server_key/radix_parallel/bitwise_op.rs new file mode 100644 index 000000000..3efa9b3d4 --- /dev/null +++ b/tfhe/src/integer/server_key/radix_parallel/bitwise_op.rs @@ -0,0 +1,172 @@ +use crate::integer::ciphertext::RadixCiphertext; +use crate::integer::ServerKey; + +impl ServerKey { + /// Computes homomorphically a bitand between two ciphertexts encrypting integer values. + /// + /// # Warning + /// + /// - Multithreaded + /// + /// # Example + /// + /// ```rust + /// use tfhe::integer::gen_keys_radix; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // Generate the client key and the server key: + /// let num_blocks = 4; + /// let (cks, sks) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, num_blocks); + /// + /// let msg1 = 14; + /// let msg2 = 97; + /// + /// let mut ct1 = cks.encrypt(msg1); + /// let mut ct2 = cks.encrypt(msg2); + /// + /// let ct_res = sks.smart_bitand_parallelized(&mut ct1, &mut ct2); + /// + /// // Decrypt: + /// let dec_result = cks.decrypt(&ct_res); + /// assert_eq!(dec_result, msg1 & msg2); + /// ``` + pub fn smart_bitand_parallelized( + &self, + ct_left: &mut RadixCiphertext, + ct_right: &mut RadixCiphertext, + ) -> RadixCiphertext { + if !self.is_functional_bivariate_pbs_possible(ct_left, ct_right) { + rayon::join( + || self.full_propagate_parallelized(ct_left), + || self.full_propagate_parallelized(ct_right), + ); + } + self.unchecked_bitand(ct_left, ct_right) + } + + pub fn smart_bitand_assign_parallelized( + &self, + ct_left: &mut RadixCiphertext, + ct_right: &mut RadixCiphertext, + ) { + if !self.is_functional_bivariate_pbs_possible(ct_left, ct_right) { + rayon::join( + || self.full_propagate_parallelized(ct_left), + || self.full_propagate_parallelized(ct_right), + ); + } + self.unchecked_bitand_assign(ct_left, ct_right); + } + + /// Computes homomorphically a bitor between two ciphertexts encrypting integer values. + /// + /// # Warning + /// + /// - Multithreaded + /// + /// # Example + /// + /// ```rust + /// use tfhe::integer::gen_keys_radix; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // Generate the client key and the server key: + /// let num_blocks = 4; + /// let (cks, sks) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, num_blocks); + /// + /// let msg1 = 14; + /// let msg2 = 97; + /// + /// let mut ct1 = cks.encrypt(msg1); + /// let mut ct2 = cks.encrypt(msg2); + /// + /// let ct_res = sks.smart_bitor(&mut ct1, &mut ct2); + /// + /// // Decrypt: + /// let dec_result = cks.decrypt(&ct_res); + /// assert_eq!(dec_result, msg1 | msg2); + /// ``` + pub fn smart_bitor_parallelized( + &self, + ct_left: &mut RadixCiphertext, + ct_right: &mut RadixCiphertext, + ) -> RadixCiphertext { + if !self.is_functional_bivariate_pbs_possible(ct_left, ct_right) { + rayon::join( + || self.full_propagate_parallelized(ct_left), + || self.full_propagate_parallelized(ct_right), + ); + } + self.unchecked_bitor(ct_left, ct_right) + } + + pub fn smart_bitor_assign_parallelized( + &self, + ct_left: &mut RadixCiphertext, + ct_right: &mut RadixCiphertext, + ) { + if !self.is_functional_bivariate_pbs_possible(ct_left, ct_right) { + rayon::join( + || self.full_propagate_parallelized(ct_left), + || self.full_propagate_parallelized(ct_right), + ); + } + self.unchecked_bitor_assign(ct_left, ct_right); + } + + /// Computes homomorphically a bitxor between two ciphertexts encrypting integer values. + /// + /// # Warning + /// + /// - Multithreaded + /// + /// # Example + /// + /// ```rust + /// use tfhe::integer::gen_keys_radix; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // Generate the client key and the server key: + /// let num_blocks = 4; + /// let (cks, sks) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, num_blocks); + /// + /// let msg1 = 14; + /// let msg2 = 97; + /// + /// let mut ct1 = cks.encrypt(msg1); + /// let mut ct2 = cks.encrypt(msg2); + /// + /// let ct_res = sks.smart_bitxor_parallelized(&mut ct1, &mut ct2); + /// + /// // Decrypt: + /// let dec_result = cks.decrypt(&ct_res); + /// assert_eq!(dec_result, msg1 ^ msg2); + /// ``` + pub fn smart_bitxor_parallelized( + &self, + ct_left: &mut RadixCiphertext, + ct_right: &mut RadixCiphertext, + ) -> RadixCiphertext { + if !self.is_functional_bivariate_pbs_possible(ct_left, ct_right) { + rayon::join( + || self.full_propagate_parallelized(ct_left), + || self.full_propagate_parallelized(ct_right), + ); + } + self.unchecked_bitxor(ct_left, ct_right) + } + + pub fn smart_bitxor_assign_parallelized( + &self, + ct_left: &mut RadixCiphertext, + ct_right: &mut RadixCiphertext, + ) { + if !self.is_functional_bivariate_pbs_possible(ct_left, ct_right) { + rayon::join( + || self.full_propagate_parallelized(ct_left), + || self.full_propagate_parallelized(ct_right), + ); + } + self.unchecked_bitxor_assign(ct_left, ct_right); + } +} diff --git a/tfhe/src/integer/server_key/radix_parallel/mod.rs b/tfhe/src/integer/server_key/radix_parallel/mod.rs new file mode 100644 index 000000000..35c70bce6 --- /dev/null +++ b/tfhe/src/integer/server_key/radix_parallel/mod.rs @@ -0,0 +1,89 @@ +mod add; +mod bitwise_op; +mod mul; +mod neg; +mod scalar_add; +mod scalar_mul; +mod scalar_sub; +mod shift; +mod sub; + +#[cfg(test)] +mod tests; + +use super::ServerKey; +use crate::integer::RadixCiphertext; + +// parallelized versions +impl ServerKey { + /// Propagate the carry of the 'index' block to the next one. + /// + /// # Example + /// + ///```rust + /// use tfhe::integer::{gen_keys_radix, IntegerCiphertext}; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // Generate the client key and the server key: + /// let num_blocks = 4; + /// let (cks, sks) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, num_blocks); + /// + /// let msg = 7; + /// + /// let ct1 = cks.encrypt(msg); + /// let ct2 = cks.encrypt(msg); + /// + /// // Compute homomorphically an addition: + /// let mut ct_res = sks.unchecked_add(&ct1, &ct2); + /// sks.propagate_parallelized(&mut ct_res, 0); + /// + /// // Decrypt one block: + /// let res = cks.decrypt_one_block(&ct_res.blocks()[1]); + /// assert_eq!(3, res); + /// ``` + pub fn propagate_parallelized(&self, ctxt: &mut RadixCiphertext, index: usize) { + let (carry, message) = rayon::join( + || self.key.carry_extract(&ctxt.blocks[index]), + || self.key.message_extract(&ctxt.blocks[index]), + ); + ctxt.blocks[index] = message; + + //add the carry to the next block + if index < ctxt.blocks.len() - 1 { + self.key + .unchecked_add_assign(&mut ctxt.blocks[index + 1], &carry); + } + } + + /// Propagate all the carries. + /// + /// # Example + /// + ///```rust + /// use tfhe::integer::gen_keys_radix; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // Generate the client key and the server key: + /// let num_blocks = 4; + /// let (cks, sks) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, num_blocks); + /// + /// let msg = 10; + /// + /// let mut ct1 = cks.encrypt(msg); + /// let mut ct2 = cks.encrypt(msg); + /// + /// // Compute homomorphically an addition: + /// let mut ct_res = sks.unchecked_add(&mut ct1, &mut ct2); + /// sks.full_propagate_parallelized(&mut ct_res); + /// + /// // Decrypt: + /// let res = cks.decrypt(&ct_res); + /// assert_eq!(msg + msg, res); + /// ``` + pub fn full_propagate_parallelized(&self, ctxt: &mut RadixCiphertext) { + let len = ctxt.blocks.len(); + for i in 0..len { + self.propagate_parallelized(ctxt, i); + } + } +} diff --git a/tfhe/src/integer/server_key/radix_parallel/mul.rs b/tfhe/src/integer/server_key/radix_parallel/mul.rs new file mode 100644 index 000000000..21c49158f --- /dev/null +++ b/tfhe/src/integer/server_key/radix_parallel/mul.rs @@ -0,0 +1,334 @@ +use std::sync::Mutex; + +use crate::integer::ciphertext::RadixCiphertext; +use crate::integer::ServerKey; +use rayon::prelude::*; + +impl ServerKey { + /// Computes homomorphically a multiplication between a ciphertext encrypting an integer value + /// and another encrypting a shortint value. + /// + /// This function computes the operation without checking if it exceeds the capacity of the + /// ciphertext. + /// + /// The result is assigned to the `ct_left` ciphertext. + /// + /// # Warning + /// + /// - Multithreaded + /// + /// # Example + /// + ///```rust + /// use tfhe::integer::gen_keys_radix; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // Generate the client key and the server key: + /// let num_blocks = 4; + /// let (cks, sks) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, num_blocks); + /// + /// let clear_1 = 170; + /// let clear_2 = 3; + /// + /// // Encrypt two messages + /// let mut ct_left = cks.encrypt(clear_1); + /// let ct_right = cks.encrypt_one_block(clear_2); + /// + /// // Compute homomorphically a multiplication + /// sks.unchecked_block_mul_assign_parallelized(&mut ct_left, &ct_right, 0); + /// + /// // Decrypt + /// let res = cks.decrypt(&ct_left); + /// assert_eq!((clear_1 * clear_2) % 256, res); + /// ``` + pub fn unchecked_block_mul_assign_parallelized( + &self, + ct_left: &mut RadixCiphertext, + ct_right: &crate::shortint::Ciphertext, + index: usize, + ) { + *ct_left = self.unchecked_block_mul_parallelized(ct_left, ct_right, index); + } + + /// Computes homomorphically a multiplication between a ciphertexts encrypting an integer + /// value and another encrypting a shortint value. + /// + /// This function computes the operation without checking if it exceeds the capacity of the + /// ciphertext. + /// + /// The result is returned as a new ciphertext. + /// + /// # Warning + /// + /// - Multithreaded + /// + /// # Example + /// + ///```rust + /// use tfhe::integer::gen_keys_radix; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // Generate the client key and the server key: + /// let num_blocks = 4; + /// let (cks, sks) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, num_blocks); + /// + /// let clear_1 = 55; + /// let clear_2 = 3; + /// + /// // Encrypt two messages + /// let ct_left = cks.encrypt(clear_1); + /// let ct_right = cks.encrypt_one_block(clear_2); + /// + /// // Compute homomorphically a multiplication + /// let ct_res = sks.unchecked_block_mul_parallelized(&ct_left, &ct_right, 0); + /// + /// // Decrypt + /// let res = cks.decrypt(&ct_res); + /// assert_eq!((clear_1 * clear_2) % 256, res); + /// ``` + pub fn unchecked_block_mul_parallelized( + &self, + ct1: &RadixCiphertext, + ct2: &crate::shortint::Ciphertext, + index: usize, + ) -> RadixCiphertext { + let shifted_ct = self.blockshift(ct1, index); + + let mut result_lsb = shifted_ct.clone(); + let mut result_msb = shifted_ct; + self.unchecked_block_mul_lsb_msb_parallelized(&mut result_lsb, &mut result_msb, ct2, index); + result_msb = self.blockshift(&result_msb, 1); + + self.unchecked_add(&result_lsb, &result_msb) + } + + /// Computes homomorphically a multiplication between a ciphertext encrypting integer value + /// and another encrypting a shortint value. + /// + /// The result is returned as a new ciphertext. + /// + /// # Warning + /// + /// - Multithreaded + /// + /// # Example + /// + ///```rust + /// use tfhe::integer::gen_keys_radix; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // Generate the client key and the server key: + /// let num_blocks = 4; + /// let (cks, sks) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, num_blocks); + /// + /// let clear_1 = 170; + /// let clear_2 = 3; + /// + /// // Encrypt two messages + /// let mut ctxt_1 = cks.encrypt(clear_1); + /// let ctxt_2 = cks.encrypt_one_block(clear_2); + /// + /// // Compute homomorphically a multiplication + /// let ct_res = sks.smart_block_mul_parallelized(&mut ctxt_1, &ctxt_2, 0); + /// + /// // Decrypt + /// let res = cks.decrypt(&ct_res); + /// assert_eq!((clear_1 * clear_2) % 256, res); + /// ``` + pub fn smart_block_mul_parallelized( + &self, + ct1: &mut RadixCiphertext, + ct2: &crate::shortint::Ciphertext, + index: usize, + ) -> RadixCiphertext { + //Makes sure we can do the multiplications + self.full_propagate_parallelized(ct1); + + let shifted_ct = self.blockshift(ct1, index); + + let mut result_lsb = shifted_ct.clone(); + let mut result_msb = shifted_ct; + self.unchecked_block_mul_lsb_msb_parallelized(&mut result_lsb, &mut result_msb, ct2, index); + result_msb = self.blockshift(&result_msb, 1); + + self.smart_add_parallelized(&mut result_lsb, &mut result_msb) + } + + fn unchecked_block_mul_lsb_msb_parallelized( + &self, + result_lsb: &mut RadixCiphertext, + result_msb: &mut RadixCiphertext, + ct2: &crate::shortint::Ciphertext, + index: usize, + ) { + let len = result_msb.blocks.len() - 1; + rayon::join( + || { + result_lsb.blocks[index..] + .par_iter_mut() + .for_each(|res_lsb_i| { + self.key.unchecked_mul_lsb_assign(res_lsb_i, ct2); + }); + }, + || { + result_msb.blocks[index..len] + .par_iter_mut() + .for_each(|res_msb_i| { + self.key.unchecked_mul_msb_assign(res_msb_i, ct2); + }); + }, + ); + } + + pub fn smart_block_mul_assign_parallelized( + &self, + ct1: &mut RadixCiphertext, + ct2: &crate::shortint::Ciphertext, + index: usize, + ) { + *ct1 = self.smart_block_mul_parallelized(ct1, ct2, index); + } + + /// Computes homomorphically a multiplication between two ciphertexts encrypting integer values. + /// + /// This function computes the operation without checking if it exceeds the capacity of the + /// ciphertext. + /// + /// The result is assigned to the `ct_left` ciphertext. + /// + /// # Warning + /// + /// - Multithreaded + /// + /// # Example + /// + /// ```rust + /// use tfhe::integer::gen_keys_radix; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // Generate the client key and the server key: + /// let num_blocks = 4; + /// let (cks, sks) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, num_blocks); + /// + /// let clear_1 = 255; + /// let clear_2 = 143; + /// + /// // Encrypt two messages + /// let mut ctxt_1 = cks.encrypt(clear_1); + /// let ctxt_2 = cks.encrypt(clear_2); + /// + /// // Compute homomorphically a multiplication + /// let ct_res = sks.unchecked_mul_parallelized(&mut ctxt_1, &ctxt_2); + /// + /// // Decrypt + /// let res = cks.decrypt(&ct_res); + /// assert_eq!((clear_1 * clear_2) % 256, res); + /// ``` + pub fn unchecked_mul_assign_parallelized( + &self, + ct1: &mut RadixCiphertext, + ct2: &RadixCiphertext, + ) { + *ct1 = self.unchecked_mul_parallelized(ct1, ct2); + } + + /// Computes homomorphically a multiplication between two ciphertexts encrypting integer values. + /// + /// This function computes the operation without checking if it exceeds the capacity of the + /// ciphertext. + /// + /// The result is returned as a new ciphertext. + /// + /// # Warning + /// + /// - Multithreaded + pub fn unchecked_mul_parallelized( + &self, + ct1: &mut RadixCiphertext, + ct2: &RadixCiphertext, + ) -> RadixCiphertext { + let mut result = self.create_trivial_zero_radix(ct1.blocks.len()); + + let terms = Mutex::new(Vec::new()); + + ct2.blocks.par_iter().enumerate().for_each(|(i, ct2_i)| { + let term = self.unchecked_block_mul_parallelized(ct1, ct2_i, i); + terms.lock().unwrap().push(term); + }); + + let terms = terms.into_inner().unwrap(); + + for term in terms { + self.unchecked_add_assign(&mut result, &term); + } + + result + } + + /// Computes homomorphically a multiplication between two ciphertexts encrypting integer values. + /// + /// The result is assigned to the `ct_left` ciphertext. + /// + /// # Warning + /// + /// - Multithreaded + /// + /// # Example + /// + /// ```rust + /// use tfhe::integer::gen_keys_radix; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // Generate the client key and the server key: + /// let num_blocks = 4; + /// let (cks, sks) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, num_blocks); + /// + /// let clear_1 = 170; + /// let clear_2 = 6; + /// + /// // Encrypt two messages + /// let mut ctxt_1 = cks.encrypt(clear_1); + /// let mut ctxt_2 = cks.encrypt(clear_2); + /// + /// // Compute homomorphically a multiplication + /// let ct_res = sks.smart_mul_parallelized(&mut ctxt_1, &mut ctxt_2); + /// // Decrypt + /// let res = cks.decrypt(&ct_res); + /// assert_eq!((clear_1 * clear_2) % 256, res); + /// ``` + pub fn smart_mul_assign_parallelized( + &self, + ct1: &mut RadixCiphertext, + ct2: &mut RadixCiphertext, + ) { + *ct1 = self.smart_mul_parallelized(ct1, ct2); + } + + /// Computes homomorphically a multiplication between two ciphertexts encrypting integer values. + /// + /// The result is returned as a new ciphertext. + /// + /// # Warning + /// + /// - Multithreaded + pub fn smart_mul_parallelized( + &self, + ct1: &mut RadixCiphertext, + ct2: &mut RadixCiphertext, + ) -> RadixCiphertext { + rayon::join( + || self.full_propagate_parallelized(ct1), + || self.full_propagate_parallelized(ct2), + ); + + let terms = Mutex::new(Vec::new()); + ct2.blocks.par_iter().enumerate().for_each(|(i, ct2_i)| { + let term = self.unchecked_block_mul_parallelized(ct1, ct2_i, i); + terms.lock().unwrap().push(term); + }); + let mut terms = terms.into_inner().unwrap(); + + self.smart_binary_op_seq_parallelized(&mut terms, ServerKey::smart_add_parallelized) + .unwrap_or_else(|| self.create_trivial_zero_radix(ct1.blocks.len())) + } +} diff --git a/tfhe/src/integer/server_key/radix_parallel/neg.rs b/tfhe/src/integer/server_key/radix_parallel/neg.rs new file mode 100644 index 000000000..12f706141 --- /dev/null +++ b/tfhe/src/integer/server_key/radix_parallel/neg.rs @@ -0,0 +1,37 @@ +use crate::integer::ciphertext::RadixCiphertext; +use crate::integer::ServerKey; + +impl ServerKey { + /// Homomorphically computes the opposite of a ciphertext encrypting an integer message. + /// + /// The result is returned as a new ciphertext. + /// + /// # Example + /// + /// ```rust + /// use tfhe::integer::gen_keys_radix; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // We have 4 * 2 = 8 bits of message + /// let size = 4; + /// let (cks, sks) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, size); + /// + /// let msg = 1; + /// + /// // Encrypt two messages: + /// let mut ctxt = cks.encrypt(msg); + /// + /// // Compute homomorphically a negation + /// let ct_res = sks.smart_neg_parallelized(&mut ctxt); + /// + /// // Decrypt + /// let dec = cks.decrypt(&ct_res); + /// assert_eq!(255, dec); + /// ``` + pub fn smart_neg_parallelized(&self, ctxt: &mut RadixCiphertext) -> RadixCiphertext { + if !self.is_neg_possible(ctxt) { + self.full_propagate_parallelized(ctxt); + } + self.unchecked_neg(ctxt) + } +} diff --git a/tfhe/src/integer/server_key/radix_parallel/scalar_add.rs b/tfhe/src/integer/server_key/radix_parallel/scalar_add.rs new file mode 100644 index 000000000..196ef1459 --- /dev/null +++ b/tfhe/src/integer/server_key/radix_parallel/scalar_add.rs @@ -0,0 +1,74 @@ +use crate::integer::ciphertext::RadixCiphertext; +use crate::integer::ServerKey; + +impl ServerKey { + /// Computes homomorphically the addition of ciphertext with a scalar. + /// + /// The result is returned in a new ciphertext. + /// + /// # Example + /// + /// ```rust + /// use tfhe::integer::gen_keys_radix; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // We have 4 * 2 = 8 bits of message + /// let size = 4; + /// let (cks, sks) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, size); + /// + /// let msg = 4; + /// let scalar = 40; + /// + /// let mut ct = cks.encrypt(msg); + /// + /// // Compute homomorphically an addition: + /// let ct_res = sks.smart_scalar_add_parallelized(&mut ct, scalar); + /// + /// // Decrypt: + /// let dec = cks.decrypt(&ct_res); + /// assert_eq!(msg + scalar, dec); + /// ``` + pub fn smart_scalar_add_parallelized( + &self, + ct: &mut RadixCiphertext, + scalar: u64, + ) -> RadixCiphertext { + if !self.is_scalar_add_possible(ct, scalar) { + self.full_propagate_parallelized(ct); + } + self.unchecked_scalar_add(ct, scalar) + } + + /// Computes homomorphically the addition of ciphertext with a scalar. + /// + /// The result is assigned to the `ct_left` ciphertext. + /// + /// # Example + /// + /// ```rust + /// use tfhe::integer::gen_keys_radix; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // We have 4 * 2 = 8 bits of message + /// let size = 4; + /// let (cks, sks) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, size); + /// + /// let msg = 129; + /// let scalar = 40; + /// + /// let mut ct = cks.encrypt(msg); + /// + /// // Compute homomorphically an addition: + /// sks.smart_scalar_add_assign_parallelized(&mut ct, scalar); + /// + /// // Decrypt: + /// let dec = cks.decrypt(&ct); + /// assert_eq!(msg + scalar, dec); + /// ``` + pub fn smart_scalar_add_assign_parallelized(&self, ct: &mut RadixCiphertext, scalar: u64) { + if !self.is_scalar_add_possible(ct, scalar) { + self.full_propagate_parallelized(ct); + } + self.unchecked_scalar_add_assign(ct, scalar); + } +} diff --git a/tfhe/src/integer/server_key/radix_parallel/scalar_mul.rs b/tfhe/src/integer/server_key/radix_parallel/scalar_mul.rs new file mode 100644 index 000000000..6b7993c2c --- /dev/null +++ b/tfhe/src/integer/server_key/radix_parallel/scalar_mul.rs @@ -0,0 +1,311 @@ +use crate::integer::ciphertext::RadixCiphertext; +use crate::integer::server_key::CheckError; +use crate::integer::server_key::CheckError::CarryFull; +use crate::integer::ServerKey; +use rayon::prelude::*; +use std::collections::HashMap; +use std::sync::Mutex; + +impl ServerKey { + /// Computes homomorphically a multiplication between a scalar and a ciphertext. + /// + /// This function computes the operation without checking if it exceeds the capacity of the + /// ciphertext. + /// + /// The result is returned as a new ciphertext. + /// + /// # Example + /// + /// ```rust + /// use tfhe::integer::gen_keys_radix; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // We have 4 * 2 = 8 bits of message + /// let size = 4; + /// let (cks, sks) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, size); + /// + /// let msg = 30; + /// let scalar = 3; + /// + /// let ct = cks.encrypt(msg); + /// + /// // Compute homomorphically a scalar multiplication: + /// let ct_res = sks.unchecked_small_scalar_mul_parallelized(&ct, scalar); + /// + /// let clear = cks.decrypt(&ct_res); + /// assert_eq!(scalar * msg, clear); + /// ``` + pub fn unchecked_small_scalar_mul_parallelized( + &self, + ctxt: &RadixCiphertext, + scalar: u64, + ) -> RadixCiphertext { + let mut ct_result = ctxt.clone(); + self.unchecked_small_scalar_mul_assign_parallelized(&mut ct_result, scalar); + ct_result + } + + pub fn unchecked_small_scalar_mul_assign_parallelized( + &self, + ctxt: &mut RadixCiphertext, + scalar: u64, + ) { + ctxt.blocks.par_iter_mut().for_each(|ct_i| { + self.key.unchecked_scalar_mul_assign(ct_i, scalar as u8); + }); + } + + /// Computes homomorphically a multiplication between a scalar and a ciphertext. + /// + /// If the operation can be performed, the result is returned in a new ciphertext. + /// Otherwise [CheckError::CarryFull] is returned. + /// + /// # Example + /// + /// ```rust + /// use tfhe::integer::gen_keys_radix; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // We have 4 * 2 = 8 bits of message + /// let size = 4; + /// let (cks, sks) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, size); + /// + /// let msg = 33; + /// let scalar = 3; + /// + /// let ct = cks.encrypt(msg); + /// + /// // Compute homomorphically a scalar multiplication: + /// let ct_res = sks.checked_small_scalar_mul_parallelized(&ct, scalar); + /// + /// match ct_res { + /// Err(x) => panic!("{:?}", x), + /// Ok(y) => { + /// let clear = cks.decrypt(&y); + /// assert_eq!(msg * scalar, clear); + /// } + /// } + /// ``` + pub fn checked_small_scalar_mul_parallelized( + &self, + ct: &RadixCiphertext, + scalar: u64, + ) -> Result { + // If the ciphertext cannot be multiplied without exceeding the capacity of a ciphertext + if self.is_small_scalar_mul_possible(ct, scalar) { + Ok(self.unchecked_small_scalar_mul_parallelized(ct, scalar)) + } else { + Err(CarryFull) + } + } + + /// Computes homomorphically a multiplication between a scalar and a ciphertext. + /// + /// If the operation can be performed, the result is assigned to the ciphertext given + /// as parameter. + /// Otherwise [CheckError::CarryFull] is returned. + /// + /// # Example + /// + /// ```rust + /// use tfhe::integer::gen_keys_radix; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // We have 4 * 2 = 8 bits of message + /// let size = 4; + /// let (cks, sks) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, size); + /// + /// let msg = 33; + /// let scalar = 3; + /// + /// let mut ct = cks.encrypt(msg); + /// + /// // Compute homomorphically a scalar multiplication: + /// sks.checked_small_scalar_mul_assign_parallelized(&mut ct, scalar); + /// + /// let clear_res = cks.decrypt(&ct); + /// assert_eq!(clear_res, msg * scalar); + /// ``` + pub fn checked_small_scalar_mul_assign_parallelized( + &self, + ct: &mut RadixCiphertext, + scalar: u64, + ) -> Result<(), CheckError> { + // If the ciphertext cannot be multiplied without exceeding the capacity of a ciphertext + if self.is_small_scalar_mul_possible(ct, scalar) { + self.unchecked_small_scalar_mul_assign_parallelized(ct, scalar); + Ok(()) + } else { + Err(CarryFull) + } + } + + /// Computes homomorphically a multiplication between a scalar and a ciphertext. + /// + /// `small` means the scalar value shall fit in a __shortint block__. + /// For example, if the parameters are PARAM_MESSAGE_2_CARRY_2, + /// the scalar should fit in 2 bits. + /// + /// The result is returned as a new ciphertext. + /// + /// # Example + /// + /// ```rust + /// use tfhe::integer::gen_keys_radix; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // We have 4 * 2 = 8 bits of message + /// let modulus = 1 << 8; + /// let size = 4; + /// let (cks, sks) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, size); + /// + /// let msg = 13; + /// let scalar = 5; + /// + /// let mut ct = cks.encrypt(msg); + /// + /// // Compute homomorphically a scalar multiplication: + /// let ct_res = sks.smart_small_scalar_mul_parallelized(&mut ct, scalar); + /// + /// // Decrypt: + /// let clear = cks.decrypt(&ct_res); + /// assert_eq!(msg * scalar % modulus, clear); + /// ``` + pub fn smart_small_scalar_mul_parallelized( + &self, + ctxt: &mut RadixCiphertext, + scalar: u64, + ) -> RadixCiphertext { + if !self.is_small_scalar_mul_possible(ctxt, scalar) { + self.full_propagate_parallelized(ctxt); + } + self.unchecked_small_scalar_mul_parallelized(ctxt, scalar) + } + + /// Computes homomorphically a multiplication between a scalar and a ciphertext. + /// + /// `small` means the scalar shall value fit in a __shortint block__. + /// For example, if the parameters are PARAM_MESSAGE_2_CARRY_2, + /// the scalar should fit in 2 bits. + /// + /// The result is assigned to the input ciphertext + /// + /// # Example + /// + /// ```rust + /// use tfhe::integer::gen_keys_radix; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // We have 4 * 2 = 8 bits of message + /// let modulus = 1 << 8; + /// let size = 4; + /// let (cks, sks) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, size); + /// + /// let msg = 9; + /// let scalar = 3; + /// + /// let mut ct = cks.encrypt(msg); + /// + /// // Compute homomorphically a scalar multiplication: + /// sks.smart_small_scalar_mul_assign_parallelized(&mut ct, scalar); + /// + /// // Decrypt: + /// let clear = cks.decrypt(&ct); + /// assert_eq!(msg * scalar % modulus, clear); + /// ``` + pub fn smart_small_scalar_mul_assign_parallelized( + &self, + ctxt: &mut RadixCiphertext, + scalar: u64, + ) { + if !self.is_small_scalar_mul_possible(ctxt, scalar) { + self.full_propagate_parallelized(ctxt); + } + self.unchecked_small_scalar_mul_assign_parallelized(ctxt, scalar); + } + + /// Computes homomorphically a multiplication between a scalar and a ciphertext. + /// + /// + /// # Example + /// + /// ```rust + /// use tfhe::integer::gen_keys_radix; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // We have 4 * 2 = 8 bits of message + /// let modulus = 1 << 8; + /// let size = 4; + /// let (cks, sks) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, size); + /// + /// let msg = 230; + /// let scalar = 376; + /// + /// let mut ct = cks.encrypt(msg); + /// + /// // Compute homomorphically a scalar multiplication: + /// let ct_res = sks.smart_scalar_mul_parallelized(&mut ct, scalar); + /// + /// // Decrypt: + /// let clear = cks.decrypt(&ct_res); + /// assert_eq!(msg * scalar % modulus, clear); + /// ``` + pub fn smart_scalar_mul_parallelized( + &self, + ct: &mut RadixCiphertext, + scalar: u64, + ) -> RadixCiphertext { + let zero = self.create_trivial_zero_radix(ct.blocks.len()); + if scalar == 0 || ct.blocks.is_empty() { + return zero; + } + + let b = self.key.message_modulus.0 as u64; + let n = ct.blocks.len(); + + //Propagate the carries before doing the multiplications + self.full_propagate_parallelized(ct); + let ct = &*ct; + + // key is the small scalar we multiply by + // value is the vector of blockshifts + let mut task_map = HashMap::>::new(); + + let mut b_i = 1_u64; + for i in 0..n { + let u_i = (scalar / b_i) % b; + task_map.entry(u_i).or_insert_with(Vec::new).push(i); + b_i *= b; + } + + let terms = Mutex::new(Vec::::new()); + task_map.par_iter().for_each(|(&u_i, blockshifts)| { + if u_i == 0 { + return; + } + + let blockshifts = &**blockshifts; + let min_blockshift = *blockshifts.iter().min().unwrap(); + + let mut tmp = ct.clone(); + if u_i != 1 { + tmp.blocks[0..n - min_blockshift] + .par_iter_mut() + .for_each(|ct_i| self.key.unchecked_scalar_mul_assign(ct_i, u_i as u8)); + } + + let tmp = &tmp; + blockshifts.par_iter().for_each(|&shift| { + let term = self.blockshift(tmp, shift); + terms.lock().unwrap().push(term); + }); + }); + let mut terms = terms.into_inner().unwrap(); + self.smart_binary_op_seq_parallelized(&mut terms, ServerKey::smart_add_parallelized) + .unwrap_or(zero) + } + + pub fn smart_scalar_mul_assign_parallelized(&self, ctxt: &mut RadixCiphertext, scalar: u64) { + *ctxt = self.smart_scalar_mul_parallelized(ctxt, scalar); + } +} diff --git a/tfhe/src/integer/server_key/radix_parallel/scalar_sub.rs b/tfhe/src/integer/server_key/radix_parallel/scalar_sub.rs new file mode 100644 index 000000000..887ca9d2b --- /dev/null +++ b/tfhe/src/integer/server_key/radix_parallel/scalar_sub.rs @@ -0,0 +1,46 @@ +use crate::integer::ciphertext::RadixCiphertext; +use crate::integer::ServerKey; + +impl ServerKey { + /// Computes homomorphically a subtraction of a ciphertext by a scalar. + /// + /// # Example + /// + /// ```rust + /// use tfhe::integer::gen_keys_radix; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // We have 4 * 2 = 8 bits of message + /// let size = 4; + /// let (cks, sks) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, size); + /// + /// let msg = 165; + /// let scalar = 112; + /// + /// let mut ct = cks.encrypt(msg); + /// + /// // Compute homomorphically an addition: + /// let ct_res = sks.smart_scalar_sub_parallelized(&mut ct, scalar); + /// + /// // Decrypt: + /// let dec = cks.decrypt(&ct_res); + /// assert_eq!(msg - scalar, dec); + /// ``` + pub fn smart_scalar_sub_parallelized( + &self, + ct: &mut RadixCiphertext, + scalar: u64, + ) -> RadixCiphertext { + if !self.is_scalar_sub_possible(ct, scalar) { + self.full_propagate_parallelized(ct); + } + self.unchecked_scalar_sub(ct, scalar) + } + + pub fn smart_scalar_sub_assign_parallelized(&self, ct: &mut RadixCiphertext, scalar: u64) { + if !self.is_scalar_sub_possible(ct, scalar) { + self.full_propagate_parallelized(ct); + } + self.unchecked_scalar_sub_assign(ct, scalar); + } +} diff --git a/tfhe/src/integer/server_key/radix_parallel/shift.rs b/tfhe/src/integer/server_key/radix_parallel/shift.rs new file mode 100644 index 000000000..9203d032b --- /dev/null +++ b/tfhe/src/integer/server_key/radix_parallel/shift.rs @@ -0,0 +1,180 @@ +use crate::integer::ciphertext::RadixCiphertext; +use crate::integer::ServerKey; + +impl ServerKey { + /// Computes homomorphically a right shift. + /// + /// The result is returned as a new ciphertext. + /// + /// # Example + /// + /// ```rust + /// use tfhe::integer::gen_keys_radix; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // We have 4 * 2 = 8 bits of message + /// let size = 4; + /// let (cks, sks) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, size); + /// + /// let msg = 128; + /// let shift = 2; + /// + /// let ct = cks.encrypt(msg); + /// + /// // Compute homomorphically a right shift: + /// let ct_res = sks.unchecked_scalar_right_shift_parallelized(&ct, shift); + /// + /// // Decrypt: + /// let dec = cks.decrypt(&ct_res); + /// assert_eq!(msg >> shift, dec); + /// ``` + pub fn unchecked_scalar_right_shift_parallelized( + &self, + ct: &RadixCiphertext, + shift: usize, + ) -> RadixCiphertext { + let mut result = ct.clone(); + self.unchecked_scalar_right_shift_assign_parallelized(&mut result, shift); + result + } + + /// Computes homomorphically a right shift. + /// + /// The result is returned as a new ciphertext. + /// + /// # Example + /// + /// ```rust + /// use tfhe::integer::gen_keys_radix; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // We have 4 * 2 = 8 bits of message + /// let size = 4; + /// let (cks, sks) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, size); + /// + /// let msg = 18; + /// let shift = 4; + /// + /// let mut ct = cks.encrypt(msg); + /// + /// // Compute homomorphically a right shift: + /// sks.unchecked_scalar_right_shift_assign_parallelized(&mut ct, shift); + /// + /// // Decrypt: + /// let dec = cks.decrypt(&ct); + /// assert_eq!(msg >> shift, dec); + /// ``` + pub fn unchecked_scalar_right_shift_assign_parallelized( + &self, + ct: &mut RadixCiphertext, + shift: usize, + ) { + let tmp = self.key.message_modulus.0 as f64; + + //number of bits of message + let nb_bits = tmp.log2() as usize; + + // 2^u = 2^{p*q+r} = 2^{p*(q+1)}*2^{r-p} + let quotient = shift / nb_bits; + + //p-r + let modified_remainder = nb_bits - (shift % nb_bits); + + //if r == 0 + if modified_remainder == nb_bits { + self.full_propagate_parallelized(ct); + self.blockshift_right_assign(ct, quotient); + } else { + // B/2^u = (B*2^{p-r}) / (2^{p*(q+1)}) + self.unchecked_scalar_left_shift_assign_parallelized(ct, modified_remainder); + + // We partially propagate in order to not lose information + self.partial_propagate_parallelized(ct); + self.blockshift_right_assign(ct, 1_usize); + + // We propagate the last block in order to not lose information + self.propagate_parallelized(ct, ct.blocks.len() - 2); + self.blockshift_right_assign(ct, quotient); + } + } + + /// Propagates all carries except the last one. + /// For development purpose only. + fn partial_propagate_parallelized(&self, ctxt: &mut RadixCiphertext) { + let len = ctxt.blocks.len() - 1; + for i in 0..len { + self.propagate_parallelized(ctxt, i); + } + } + + /// Computes homomorphically a left shift by a scalar. + /// + /// The result is returned as a new ciphertext. + /// + /// # Example + /// + /// ```rust + /// use tfhe::integer::gen_keys_radix; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // We have 4 * 2 = 8 bits of message + /// let size = 4; + /// let (cks, sks) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, size); + /// + /// let msg = 21; + /// let shift = 2; + /// + /// let ct1 = cks.encrypt(msg); + /// + /// // Compute homomorphically a right shift: + /// let ct_res = sks.unchecked_scalar_left_shift_parallelized(&ct1, shift); + /// + /// // Decrypt: + /// let dec = cks.decrypt(&ct_res); + /// assert_eq!(msg << shift, dec); + /// ``` + pub fn unchecked_scalar_left_shift_parallelized( + &self, + ct_left: &RadixCiphertext, + shift: usize, + ) -> RadixCiphertext { + let mut result = ct_left.clone(); + self.unchecked_scalar_left_shift_assign_parallelized(&mut result, shift); + result + } + + /// Computes homomorphically a left shift by a scalar. + /// + /// The result is assigned in the input ciphertext + /// + /// # Example + /// + /// ```rust + /// use tfhe::integer::gen_keys_radix; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // We have 4 * 2 = 8 bits of message + /// let size = 4; + /// let (cks, sks) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, size); + /// + /// let msg = 13; + /// let shift = 2; + /// + /// let mut ct = cks.encrypt(msg); + /// + /// // Compute homomorphically a right shift: + /// sks.unchecked_scalar_left_shift_assign_parallelized(&mut ct, shift); + /// + /// // Decrypt: + /// let dec = cks.decrypt(&ct); + /// assert_eq!(msg << shift, dec); + /// ``` + pub fn unchecked_scalar_left_shift_assign_parallelized( + &self, + ct: &mut RadixCiphertext, + shift: usize, + ) { + let tmp = 1_u64 << shift; + self.smart_scalar_mul_assign_parallelized(ct, tmp); + } +} diff --git a/tfhe/src/integer/server_key/radix_parallel/sub.rs b/tfhe/src/integer/server_key/radix_parallel/sub.rs new file mode 100644 index 000000000..3846af4e5 --- /dev/null +++ b/tfhe/src/integer/server_key/radix_parallel/sub.rs @@ -0,0 +1,101 @@ +use crate::integer::ciphertext::RadixCiphertext; +use crate::integer::ServerKey; + +impl ServerKey { + /// Computes homomorphically the subtraction between ct_left and ct_right. + /// + /// # Example + /// + /// ```rust + /// use tfhe::integer::gen_keys_radix; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // We have 4 * 2 = 8 bits of message + /// let size = 4; + /// let (cks, sks) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, size); + /// + /// let msg_1 = 120u8; + /// let msg_2 = 181u8; + /// + /// // Encrypt two messages: + /// let mut ctxt_1 = cks.encrypt(msg_1 as u64); + /// let mut ctxt_2 = cks.encrypt(msg_2 as u64); + /// + /// // Compute homomorphically a subtraction + /// let ct_res = sks.smart_sub_parallelized(&mut ctxt_1, &mut ctxt_2); + /// + /// // Decrypt: + /// let res = cks.decrypt(&ct_res); + /// assert_eq!(msg_1.wrapping_sub(msg_2) as u64, res); + /// ``` + pub fn smart_sub_parallelized( + &self, + ctxt_left: &mut RadixCiphertext, + ctxt_right: &mut RadixCiphertext, + ) -> RadixCiphertext { + // If the ciphertext cannot be negated without exceeding the capacity of a ciphertext + if !self.is_neg_possible(ctxt_right) { + self.full_propagate_parallelized(ctxt_right); + } + + // If the ciphertext cannot be added together without exceeding the capacity of a ciphertext + if !self.is_sub_possible(ctxt_left, ctxt_right) { + rayon::join( + || self.full_propagate_parallelized(ctxt_left), + || self.full_propagate_parallelized(ctxt_right), + ); + } + + let mut result = ctxt_left.clone(); + self.unchecked_sub_assign(&mut result, ctxt_right); + + result + } + + /// Computes homomorphically the subtraction between ct_left and ct_right. + /// + /// # Example + /// + /// ```rust + /// use tfhe::integer::gen_keys_radix; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // We have 4 * 2 = 8 bits of message + /// let size = 4; + /// let (cks, sks) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, size); + /// + /// let msg_1 = 120u8; + /// let msg_2 = 181u8; + /// + /// // Encrypt two messages: + /// let mut ctxt_1 = cks.encrypt(msg_1 as u64); + /// let mut ctxt_2 = cks.encrypt(msg_2 as u64); + /// + /// // Compute homomorphically a subtraction + /// sks.smart_sub_assign_parallelized(&mut ctxt_1, &mut ctxt_2); + /// + /// // Decrypt: + /// let res = cks.decrypt(&ctxt_1); + /// assert_eq!(msg_1.wrapping_sub(msg_2) as u64, res); + /// ``` + pub fn smart_sub_assign_parallelized( + &self, + ctxt_left: &mut RadixCiphertext, + ctxt_right: &mut RadixCiphertext, + ) { + // If the ciphertext cannot be negated without exceeding the capacity of a ciphertext + if !self.is_neg_possible(ctxt_right) { + self.full_propagate_parallelized(ctxt_right); + } + + // If the ciphertext cannot be added together without exceeding the capacity of a ciphertext + if !self.is_sub_possible(ctxt_left, ctxt_right) { + rayon::join( + || self.full_propagate_parallelized(ctxt_left), + || self.full_propagate_parallelized(ctxt_right), + ); + } + + self.unchecked_sub_assign(ctxt_left, ctxt_right); + } +} diff --git a/tfhe/src/integer/server_key/radix_parallel/tests.rs b/tfhe/src/integer/server_key/radix_parallel/tests.rs new file mode 100644 index 000000000..592d9cf25 --- /dev/null +++ b/tfhe/src/integer/server_key/radix_parallel/tests.rs @@ -0,0 +1,736 @@ +use crate::integer::keycache::KEY_CACHE; +use crate::integer::{RadixClientKey, ServerKey}; +use crate::shortint::parameters::*; +use crate::shortint::Parameters; +use paste::paste; +use rand::Rng; + +/// Number of loop iteration within randomized tests +const NB_TEST: usize = 30; + +/// Smaller number of loop iteration within randomized test, +/// meant for test where the function tested is more expensive +const NB_TEST_SMALLER: usize = 10; +const NB_CTXT: usize = 4; + +macro_rules! create_parametrized_test{ + ($name:ident { $($param:ident),* }) => { + paste! { + $( + #[test] + fn []() { + $name($param) + } + )* + } + }; + ($name:ident)=> { + create_parametrized_test!($name + { + PARAM_MESSAGE_1_CARRY_1, + PARAM_MESSAGE_2_CARRY_2, + PARAM_MESSAGE_3_CARRY_3, + PARAM_MESSAGE_4_CARRY_4 + }); + }; +} + +create_parametrized_test!(integer_smart_add); +create_parametrized_test!(integer_smart_add_sequence_multi_thread); +create_parametrized_test!(integer_smart_add_sequence_single_thread); +create_parametrized_test!(integer_smart_bitand); +create_parametrized_test!(integer_smart_bitor); +create_parametrized_test!(integer_smart_bitxor); +create_parametrized_test!(integer_unchecked_small_scalar_mul); +create_parametrized_test!(integer_smart_small_scalar_mul); +create_parametrized_test!(integer_smart_scalar_mul); +create_parametrized_test!(integer_unchecked_scalar_left_shift); +create_parametrized_test!(integer_unchecked_scalar_right_shift); +create_parametrized_test!(integer_smart_neg); +create_parametrized_test!(integer_smart_sub); +create_parametrized_test!(integer_unchecked_block_mul); +create_parametrized_test!(integer_smart_block_mul); +create_parametrized_test!(integer_smart_mul); +create_parametrized_test!(integer_smart_scalar_sub); +create_parametrized_test!(integer_smart_scalar_add); + +fn integer_smart_add(param: Parameters) { + let (cks, sks) = KEY_CACHE.get_from_params(param); + let cks = RadixClientKey::from((cks, NB_CTXT)); + + //RNG + let mut rng = rand::thread_rng(); + + // message_modulus^vec_length + let modulus = param.message_modulus.0.pow(NB_CTXT as u32) as u64; + + let mut clear; + + for _ in 0..NB_TEST_SMALLER { + let clear_0 = rng.gen::() % modulus; + + let clear_1 = rng.gen::() % modulus; + + // encryption of an integer + let mut ctxt_0 = cks.encrypt(clear_0); + + // encryption of an integer + let mut ctxt_1 = cks.encrypt(clear_1); + + // add the two ciphertexts + let mut ct_res = sks.smart_add_parallelized(&mut ctxt_0, &mut ctxt_1); + + clear = (clear_0 + clear_1) % modulus; + + // println!("clear_0 = {}, clear_1 = {}", clear_0, clear_1); + //add multiple times to raise the degree + for _ in 0..NB_TEST_SMALLER { + ct_res = sks.smart_add_parallelized(&mut ct_res, &mut ctxt_0); + clear = (clear + clear_0) % modulus; + + // decryption of ct_res + let dec_res = cks.decrypt(&ct_res); + + // println!("clear = {}, dec_res = {}", clear, dec_res); + // assert + assert_eq!(clear, dec_res); + } + } +} + +fn integer_smart_add_sequence_multi_thread(param: Parameters) { + let (cks, sks) = KEY_CACHE.get_from_params(param); + let cks = RadixClientKey::from((cks, NB_CTXT)); + + //RNG + let mut rng = rand::thread_rng(); + + // message_modulus^vec_length + let modulus = param.message_modulus.0.pow(NB_CTXT as u32) as u64; + + for len in [1, 2, 15, 16, 17, 64, 65] { + for _ in 0..NB_TEST_SMALLER { + let clears = (0..len) + .map(|_| rng.gen::() % modulus) + .collect::>(); + + // encryption of integers + let mut ctxts = clears + .iter() + .copied() + .map(|clear| cks.encrypt(clear)) + .collect::>(); + + // add the ciphertexts + let ct_res = sks + .smart_binary_op_seq_parallelized(&mut ctxts, ServerKey::smart_add_parallelized) + .unwrap(); + let ct_res = cks.decrypt(&ct_res); + let clear = clears.iter().sum::() % modulus; + + assert_eq!(ct_res, clear); + } + } +} + +fn integer_smart_add_sequence_single_thread(param: Parameters) { + let (cks, sks) = KEY_CACHE.get_from_params(param); + let cks = RadixClientKey::from((cks, NB_CTXT)); + + //RNG + let mut rng = rand::thread_rng(); + + // message_modulus^vec_length + let modulus = param.message_modulus.0.pow(NB_CTXT as u32) as u64; + + for len in [1, 2, 15, 16, 17] { + for _ in 0..NB_TEST_SMALLER { + let clears = (0..len) + .map(|_| rng.gen::() % modulus) + .collect::>(); + + // encryption of integers + let mut ctxts = clears + .iter() + .copied() + .map(|clear| cks.encrypt(clear)) + .collect::>(); + + // add the ciphertexts + let threadpool = rayon::ThreadPoolBuilder::new() + .num_threads(1) + .build() + .unwrap(); + + let ct_res = threadpool.install(|| { + sks.smart_binary_op_seq_parallelized(&mut ctxts, ServerKey::smart_add_parallelized) + .unwrap() + }); + let ct_res = cks.decrypt(&ct_res); + let clear = clears.iter().sum::() % modulus; + + assert_eq!(ct_res, clear); + } + } +} + +fn integer_smart_bitand(param: Parameters) { + let (cks, sks) = KEY_CACHE.get_from_params(param); + let cks = RadixClientKey::from((cks, NB_CTXT)); + + //RNG + let mut rng = rand::thread_rng(); + + // message_modulus^vec_length + let modulus = param.message_modulus.0.pow(NB_CTXT as u32) as u64; + + let mut clear; + + for _ in 0..NB_TEST_SMALLER { + let clear_0 = rng.gen::() % modulus; + + let clear_1 = rng.gen::() % modulus; + + // encryption of an integer + let mut ctxt_0 = cks.encrypt(clear_0); + + // encryption of an integer + let mut ctxt_1 = cks.encrypt(clear_1); + + // add the two ciphertexts + let mut ct_res = sks.smart_bitand_parallelized(&mut ctxt_0, &mut ctxt_1); + + clear = clear_0 & clear_1; + + for _ in 0..NB_TEST_SMALLER { + let clear_2 = rng.gen::() % modulus; + + // encryption of an integer + let mut ctxt_2 = cks.encrypt(clear_2); + + ct_res = sks.smart_bitand_parallelized(&mut ct_res, &mut ctxt_2); + clear &= clear_2; + + // decryption of ct_res + let dec_res = cks.decrypt(&ct_res); + + // assert + assert_eq!(clear, dec_res); + } + } +} + +fn integer_smart_bitor(param: Parameters) { + let (cks, sks) = KEY_CACHE.get_from_params(param); + let cks = RadixClientKey::from((cks, NB_CTXT)); + + //RNG + let mut rng = rand::thread_rng(); + + // message_modulus^vec_length + let modulus = param.message_modulus.0.pow(NB_CTXT as u32) as u64; + + let mut clear; + + for _ in 0..NB_TEST_SMALLER { + let clear_0 = rng.gen::() % modulus; + + let clear_1 = rng.gen::() % modulus; + + // encryption of an integer + let mut ctxt_0 = cks.encrypt(clear_0); + + // encryption of an integer + let mut ctxt_1 = cks.encrypt(clear_1); + + // add the two ciphertexts + let mut ct_res = sks.smart_bitor_parallelized(&mut ctxt_0, &mut ctxt_1); + + clear = (clear_0 | clear_1) % modulus; + + for _ in 0..1 { + let clear_2 = rng.gen::() % modulus; + + // encryption of an integer + let mut ctxt_2 = cks.encrypt(clear_2); + + ct_res = sks.smart_bitor_parallelized(&mut ct_res, &mut ctxt_2); + clear = (clear | clear_2) % modulus; + + // decryption of ct_res + let dec_res = cks.decrypt(&ct_res); + + // assert + assert_eq!(clear, dec_res); + } + } +} + +fn integer_smart_bitxor(param: Parameters) { + let (cks, sks) = KEY_CACHE.get_from_params(param); + let cks = RadixClientKey::from((cks, NB_CTXT)); + + //RNG + let mut rng = rand::thread_rng(); + + // message_modulus^vec_length + let modulus = param.message_modulus.0.pow(NB_CTXT as u32) as u64; + + let mut clear; + + for _ in 0..NB_TEST_SMALLER { + let clear_0 = rng.gen::() % modulus; + + let clear_1 = rng.gen::() % modulus; + + // encryption of an integer + let mut ctxt_0 = cks.encrypt(clear_0); + + // encryption of an integer + let mut ctxt_1 = cks.encrypt(clear_1); + + // add the two ciphertexts + let mut ct_res = sks.smart_bitxor_parallelized(&mut ctxt_0, &mut ctxt_1); + + clear = (clear_0 ^ clear_1) % modulus; + + for _ in 0..NB_TEST_SMALLER { + let clear_2 = rng.gen::() % modulus; + + // encryption of an integer + let mut ctxt_2 = cks.encrypt(clear_2); + + ct_res = sks.smart_bitxor_parallelized(&mut ct_res, &mut ctxt_2); + clear = (clear ^ clear_2) % modulus; + + // decryption of ct_res + let dec_res = cks.decrypt(&ct_res); + + // assert + assert_eq!(clear, dec_res); + } + } +} + +fn integer_unchecked_small_scalar_mul(param: Parameters) { + let (cks, sks) = KEY_CACHE.get_from_params(param); + let cks = RadixClientKey::from((cks, NB_CTXT)); + + //RNG + let mut rng = rand::thread_rng(); + + // message_modulus^vec_length + let modulus = param.message_modulus.0.pow(NB_CTXT as u32) as u64; + + let scalar_modulus = param.message_modulus.0 as u64; + + for _ in 0..NB_TEST { + let clear = rng.gen::() % modulus; + + let scalar = rng.gen::() % scalar_modulus; + + // encryption of an integer + let ct = cks.encrypt(clear); + + // add the two ciphertexts + let ct_res = sks.unchecked_small_scalar_mul_parallelized(&ct, scalar); + + // decryption of ct_res + let dec_res = cks.decrypt(&ct_res); + + // assert + assert_eq!((clear * scalar) % modulus, dec_res); + } +} + +fn integer_smart_small_scalar_mul(param: Parameters) { + let (cks, sks) = KEY_CACHE.get_from_params(param); + let cks = RadixClientKey::from((cks, NB_CTXT)); + + //RNG + let mut rng = rand::thread_rng(); + + // message_modulus^vec_length + let modulus = param.message_modulus.0.pow(NB_CTXT as u32) as u64; + + let scalar_modulus = param.message_modulus.0 as u64; + + let mut clear_res; + for _ in 0..NB_TEST_SMALLER { + let clear = rng.gen::() % modulus; + + let scalar = rng.gen::() % scalar_modulus; + + // encryption of an integer + let mut ct = cks.encrypt(clear); + + let mut ct_res = sks.smart_small_scalar_mul_parallelized(&mut ct, scalar); + + clear_res = clear * scalar; + for _ in 0..NB_TEST_SMALLER { + // scalar multiplication + ct_res = sks.smart_small_scalar_mul_parallelized(&mut ct_res, scalar); + clear_res *= scalar; + } + + // decryption of ct_res + let dec_res = cks.decrypt(&ct_res); + + // assert + assert_eq!(clear_res % modulus, dec_res); + } +} + +fn integer_smart_scalar_mul(param: Parameters) { + let (cks, sks) = KEY_CACHE.get_from_params(param); + let cks = RadixClientKey::from((cks, NB_CTXT)); + + //RNG + let mut rng = rand::thread_rng(); + + // message_modulus^vec_length + let modulus = param.message_modulus.0.pow(NB_CTXT as u32) as u64; + + for _ in 0..NB_TEST { + let clear = rng.gen::() % modulus; + + let scalar = rng.gen::() % modulus; + + // encryption of an integer + let mut ct = cks.encrypt(clear); + + // scalar mul + let ct_res = sks.smart_scalar_mul_parallelized(&mut ct, scalar); + + // decryption of ct_res + let dec_res = cks.decrypt(&ct_res); + + // assert + assert_eq!((clear * scalar) % modulus, dec_res); + } +} + +fn integer_unchecked_scalar_left_shift(param: Parameters) { + let (cks, sks) = KEY_CACHE.get_from_params(param); + let cks = RadixClientKey::from((cks, NB_CTXT)); + + //RNG + let mut rng = rand::thread_rng(); + + // message_modulus^vec_length + let modulus = param.message_modulus.0.pow(NB_CTXT as u32) as u64; + + //Nb of bits to shift + let tmp_f64 = param.message_modulus.0 as f64; + let nb_bits = tmp_f64.log2().floor() as usize * NB_CTXT; + + for _ in 0..NB_TEST { + let clear = rng.gen::() % modulus; + + let scalar = rng.gen::() % nb_bits; + + // encryption of an integer + let ct = cks.encrypt(clear); + + // add the two ciphertexts + let ct_res = sks.unchecked_scalar_left_shift_parallelized(&ct, scalar); + + // decryption of ct_res + let dec_res = cks.decrypt(&ct_res); + + // assert + assert_eq!((clear << scalar) % modulus, dec_res); + } +} + +fn integer_unchecked_scalar_right_shift(param: Parameters) { + let (cks, sks) = KEY_CACHE.get_from_params(param); + let cks = RadixClientKey::from((cks, NB_CTXT)); + + //RNG + let mut rng = rand::thread_rng(); + + // message_modulus^vec_length + let modulus = param.message_modulus.0.pow(NB_CTXT as u32) as u64; + + //Nb of bits to shift + let tmp_f64 = param.message_modulus.0 as f64; + let nb_bits = tmp_f64.log2().floor() as usize * NB_CTXT; + + for _ in 0..NB_TEST { + let clear = rng.gen::() % modulus; + + let scalar = rng.gen::() % nb_bits; + + // encryption of an integer + let ct = cks.encrypt(clear); + + // add the two ciphertexts + let ct_res = sks.unchecked_scalar_right_shift_parallelized(&ct, scalar); + + // decryption of ct_res + let dec_res = cks.decrypt(&ct_res); + + // assert + assert_eq!(clear >> scalar, dec_res); + } +} + +fn integer_smart_neg(param: Parameters) { + let (cks, sks) = KEY_CACHE.get_from_params(param); + let cks = RadixClientKey::from((cks, NB_CTXT)); + + //RNG + let mut rng = rand::thread_rng(); + + // message_modulus^vec_length + let modulus = param.message_modulus.0.pow(NB_CTXT as u32) as u64; + + for _ in 0..NB_TEST { + // Define the cleartexts + let clear = rng.gen::() % modulus; + + // Encrypt the integers + let mut ctxt = cks.encrypt(clear); + + // Negates the ctxt + let ct_tmp = sks.smart_neg_parallelized(&mut ctxt); + + // Decrypt the result + let dec = cks.decrypt(&ct_tmp); + + // Check the correctness + let clear_result = clear.wrapping_neg() % modulus; + + assert_eq!(clear_result, dec); + } +} + +fn integer_smart_sub(param: Parameters) { + let (cks, sks) = KEY_CACHE.get_from_params(param); + let cks = RadixClientKey::from((cks, NB_CTXT)); + + //RNG + let mut rng = rand::thread_rng(); + + // message_modulus^vec_length + let modulus = param.message_modulus.0.pow(NB_CTXT as u32) as u64; + + for _ in 0..NB_TEST_SMALLER { + // Define the cleartexts + let clear1 = rng.gen::() % modulus; + let clear2 = rng.gen::() % modulus; + + // Encrypt the integers + let ctxt_1 = cks.encrypt(clear1); + let mut ctxt_2 = cks.encrypt(clear2); + + let mut res = ctxt_1.clone(); + let mut clear = clear1; + + //subtract multiple times to raise the degree + for _ in 0..NB_TEST_SMALLER { + res = sks.smart_sub_parallelized(&mut res, &mut ctxt_2); + clear = (clear - clear2) % modulus; + // println!("clear = {}, clear2 = {}", clear, cks.decrypt(&res)); + } + let dec = cks.decrypt(&res); + + // Check the correctness + assert_eq!(clear, dec); + } +} + +fn integer_unchecked_block_mul(param: Parameters) { + let (cks, sks) = KEY_CACHE.get_from_params(param); + let cks = RadixClientKey::from((cks, NB_CTXT)); + + //RNG + let mut rng = rand::thread_rng(); + + // message_modulus^vec_length + let modulus = param.message_modulus.0.pow(NB_CTXT as u32) as u64; + + let block_modulus = param.message_modulus.0 as u64; + + for _ in 0..NB_TEST { + let clear_0 = rng.gen::() % modulus; + + let clear_1 = rng.gen::() % block_modulus; + + // encryption of an integer + let ct_zero = cks.encrypt(clear_0); + + // encryption of an integer + let ct_one = cks.encrypt_one_block(clear_1); + + // add the two ciphertexts + let ct_res = sks.unchecked_block_mul_parallelized(&ct_zero, &ct_one, 0); + + // decryption of ct_res + let dec_res = cks.decrypt(&ct_res); + + // assert + assert_eq!((clear_0 * clear_1) % modulus, dec_res); + } +} + +fn integer_smart_block_mul(param: Parameters) { + let (cks, sks) = KEY_CACHE.get_from_params(param); + let cks = RadixClientKey::from((cks, NB_CTXT)); + + //RNG + let mut rng = rand::thread_rng(); + + // message_modulus^vec_length + let modulus = param.message_modulus.0.pow(NB_CTXT as u32) as u64; + + let block_modulus = param.message_modulus.0 as u64; + + for _ in 0..5 { + // Define the cleartexts + let clear1 = rng.gen::() % modulus; + let clear2 = rng.gen::() % block_modulus; + + // Encrypt the integers + let ctxt_1 = cks.encrypt(clear1); + let ctxt_2 = cks.encrypt_one_block(clear2); + + let mut res = ctxt_1.clone(); + let mut clear = clear1; + + res = sks.smart_block_mul_parallelized(&mut res, &ctxt_2, 0); + for _ in 0..5 { + res = sks.smart_block_mul_parallelized(&mut res, &ctxt_2, 0); + clear = (clear * clear2) % modulus; + } + let dec = cks.decrypt(&res); + + clear = (clear * clear2) % modulus; + + // Check the correctness + assert_eq!(clear, dec); + } +} + +fn integer_smart_mul(param: Parameters) { + let (cks, sks) = KEY_CACHE.get_from_params(param); + let cks = RadixClientKey::from((cks, NB_CTXT)); + + //RNG + let mut rng = rand::thread_rng(); + + // message_modulus^vec_length + let modulus = param.message_modulus.0.pow(NB_CTXT as u32) as u64; + + for _ in 0..NB_TEST_SMALLER { + // Define the cleartexts + let clear1 = rng.gen::() % modulus; + let clear2 = rng.gen::() % modulus; + + // println!("clear1 = {}, clear2 = {}", clear1, clear2); + + // Encrypt the integers + let ctxt_1 = cks.encrypt(clear1); + let mut ctxt_2 = cks.encrypt(clear2); + + let mut res = ctxt_1.clone(); + let mut clear = clear1; + + res = sks.smart_mul_parallelized(&mut res, &mut ctxt_2); + for _ in 0..5 { + res = sks.smart_mul_parallelized(&mut res, &mut ctxt_2); + clear = (clear * clear2) % modulus; + } + let dec = cks.decrypt(&res); + + clear = (clear * clear2) % modulus; + + // Check the correctness + assert_eq!(clear, dec); + } +} + +fn integer_smart_scalar_add(param: Parameters) { + // generate the server-client key set + let (cks, sks) = KEY_CACHE.get_from_params(param); + let cks = RadixClientKey::from((cks, NB_CTXT)); + + // message_modulus^vec_length + let modulus = param.message_modulus.0.pow(NB_CTXT as u32) as u64; + + let mut clear; + + // RNG + let mut rng = rand::thread_rng(); + + for _ in 0..NB_TEST_SMALLER { + let clear_0 = rng.gen::() % modulus; + + let clear_1 = rng.gen::() % modulus; + + // encryption of an integer + let mut ctxt_0 = cks.encrypt(clear_0); + + // add the two ciphertexts + let mut ct_res = sks.smart_scalar_add_parallelized(&mut ctxt_0, clear_1); + + clear = (clear_0 + clear_1) % modulus; + + // println!("clear_0 = {}, clear_1 = {}", clear_0, clear_1); + //add multiple times to raise the degree + for _ in 0..NB_TEST_SMALLER { + ct_res = sks.smart_scalar_add_parallelized(&mut ct_res, clear_1); + clear = (clear + clear_1) % modulus; + + // decryption of ct_res + let dec_res = cks.decrypt(&ct_res); + + // println!("clear = {}, dec_res = {}", clear, dec_res); + // assert + assert_eq!(clear, dec_res); + } + } +} + +fn integer_smart_scalar_sub(param: Parameters) { + // generate the server-client key set + let (cks, sks) = KEY_CACHE.get_from_params(param); + let cks = RadixClientKey::from((cks, NB_CTXT)); + + // message_modulus^vec_length + let modulus = param.message_modulus.0.pow(NB_CTXT as u32) as u64; + + let mut clear; + + // RNG + let mut rng = rand::thread_rng(); + + for _ in 0..NB_TEST_SMALLER { + let clear_0 = rng.gen::() % modulus; + + let clear_1 = rng.gen::() % modulus; + + // encryption of an integer + let mut ctxt_0 = cks.encrypt(clear_0); + + // add the two ciphertexts + let mut ct_res = sks.smart_scalar_sub_parallelized(&mut ctxt_0, clear_1); + + clear = (clear_0 - clear_1) % modulus; + + // println!("clear_0 = {}, clear_1 = {}", clear_0, clear_1); + //add multiple times to raise the degree + for _ in 0..NB_TEST_SMALLER { + ct_res = sks.smart_scalar_sub_parallelized(&mut ct_res, clear_1); + clear = (clear - clear_1) % modulus; + + // decryption of ct_res + let dec_res = cks.decrypt(&ct_res); + + // println!("clear = {}, dec_res = {}", clear, dec_res); + // assert + assert_eq!(clear, dec_res); + } + } +} diff --git a/tfhe/src/integer/tests.rs b/tfhe/src/integer/tests.rs new file mode 100644 index 000000000..7caf23e19 --- /dev/null +++ b/tfhe/src/integer/tests.rs @@ -0,0 +1,21 @@ +macro_rules! create_parametrized_test{ + ($name:ident { $($param:ident),* }) => { + ::paste::paste! { + $( + #[test] + fn []() { + $name($param) + } + )* + } + }; + ($name:ident)=> { + create_parametrized_test!($name + { + PARAM_MESSAGE_1_CARRY_1, + PARAM_MESSAGE_2_CARRY_2, + PARAM_MESSAGE_3_CARRY_3, + PARAM_MESSAGE_4_CARRY_4 + }); + }; +} diff --git a/tfhe/src/integer/wopbs/mod.rs b/tfhe/src/integer/wopbs/mod.rs new file mode 100644 index 000000000..508418d52 --- /dev/null +++ b/tfhe/src/integer/wopbs/mod.rs @@ -0,0 +1,992 @@ +//! Module with the definition of the WopbsKey (WithOut padding PBS Key). +//! +//! This module implements the generation of another server public key, which allows to compute +//! an alternative version of the programmable bootstrapping. This does not require the use of a +//! bit of padding. +#[cfg(test)] +mod test; + +use crate::core_crypto::prelude::*; +use crate::integer::client_key::utils::i_crt; +use crate::integer::{ClientKey, CrtCiphertext, IntegerCiphertext, RadixCiphertext, ServerKey}; +use crate::shortint::ciphertext::Degree; +use rayon::prelude::*; + +use crate::shortint::Parameters; +use serde::{Deserialize, Serialize}; + +#[derive(Clone, Serialize, Deserialize)] +pub struct WopbsKey { + wopbs_key: crate::shortint::wopbs::WopbsKey, +} + +/// ```rust +/// use tfhe::integer::wopbs::{decode_radix, encode_radix}; +/// +/// let val = 11; +/// let basis = 2; +/// let nb_block = 5; +/// let radix = encode_radix(val, basis, nb_block); +/// +/// assert_eq!(val, decode_radix(radix, basis)); +/// ``` +pub fn encode_radix(val: u64, basis: u64, nb_block: u64) -> Vec { + let mut output = vec![]; + //Bits of message put to 1éfé + let mask = basis - 1; + + let mut power = 1_u64; + //Put each decomposition into a new ciphertext + for _ in 0..nb_block { + let mut decomp = val & (mask * power); + decomp /= power; + + // fill the vector with the message moduli + output.push(decomp); + + //modulus to the power i + power *= basis; + } + output +} + +pub fn encode_crt(val: u64, basis: &[u64]) -> Vec { + let mut output = vec![]; + //Put each decomposition into a new ciphertext + for i in basis { + output.push(val % i); + } + output +} + +//Concatenate two ciphertexts in one +//Used to compute bivariate wopbs +fn ciphertext_concatenation(ct1: &T, ct2: &T) -> T +where + T: IntegerCiphertext, +{ + let mut new_blocks = ct1.blocks().to_vec(); + new_blocks.extend_from_slice(ct2.blocks()); + T::from_blocks(new_blocks) +} + +pub fn encode_mix_radix(mut val: u64, basis: &[u64], modulus: u64) -> Vec { + let mut output = vec![]; + for basis in basis.iter() { + output.push(val % modulus); + val -= val % modulus; + let tmp = (val % (1 << basis)) >> (f64::log2(modulus as f64) as u64); + val >>= basis; + val += tmp; + } + output +} + +// Example: val = 5 = 0b101 , basis = [1,2] -> output = [1, 1] +/// ```rust +/// use tfhe::integer::wopbs::split_value_according_to_bit_basis; +/// // Generate the client key and the server key: +/// let val = 5; +/// let basis = vec![1, 2]; +/// assert_eq!(vec![1, 2], split_value_according_to_bit_basis(val, &basis)); +/// ``` +pub fn split_value_according_to_bit_basis(value: u64, basis: &[u64]) -> Vec { + let mut output = vec![]; + let mut tmp = value; + let mask = 1; + + for i in basis { + let mut tmp_output = 0; + for j in 0..*i { + let val = tmp & mask; + tmp_output += val << j; + tmp >>= 1; + } + output.push(tmp_output); + } + output +} + +/// ```rust +/// use tfhe::integer::wopbs::{decode_radix, encode_radix}; +/// +/// let val = 11; +/// let basis = 2; +/// let nb_block = 5; +/// assert_eq!(val, decode_radix(encode_radix(val, basis, nb_block), basis)); +/// ``` +pub fn decode_radix(val: Vec, basis: u64) -> u64 { + let mut result = 0_u64; + let mut shift = 1_u64; + for v_i in val.iter() { + //decrypt the component i of the integer and multiply it by the radix product + let tmp = v_i.wrapping_mul(shift); + + // update the result + result = result.wrapping_add(tmp); + + // update the shift for the next iteration + shift = shift.wrapping_mul(basis); + } + result +} + +impl From for WopbsKey { + fn from(wopbs_key: crate::shortint::wopbs::WopbsKey) -> Self { + Self { wopbs_key } + } +} + +impl WopbsKey { + /// Generates the server key required to compute a WoPBS from the client and the server keys. + /// # Example + /// ```rust + /// use tfhe::integer::gen_keys; + /// use tfhe::integer::wopbs::*; + /// use tfhe::shortint::parameters::parameters_wopbs_message_carry::WOPBS_PARAM_MESSAGE_1_CARRY_1; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_1_CARRY_1; + /// + /// // Generate the client key and the server key: + /// let (mut cks, mut sks) = gen_keys(&PARAM_MESSAGE_1_CARRY_1); + /// let mut wopbs_key = WopbsKey::new_wopbs_key(&cks, &sks, &WOPBS_PARAM_MESSAGE_1_CARRY_1); + /// ``` + pub fn new_wopbs_key(cks: &ClientKey, sks: &ServerKey, parameters: &Parameters) -> WopbsKey { + WopbsKey { + wopbs_key: crate::shortint::wopbs::WopbsKey::new_wopbs_key( + &cks.key, &sks.key, parameters, + ), + } + } + + pub fn new_from_shortint(wopbskey: &crate::shortint::wopbs::WopbsKey) -> WopbsKey { + let key = wopbskey.clone(); + WopbsKey { wopbs_key: key } + } + + pub fn new_wopbs_key_only_for_wopbs(cks: &ClientKey, sks: &ServerKey) -> WopbsKey { + WopbsKey { + wopbs_key: crate::shortint::wopbs::WopbsKey::new_wopbs_key_only_for_wopbs( + &cks.key, &sks.key, + ), + } + } + + /// Computes the WoP-PBS given the luts. + /// + /// This works for both RadixCiphertext and CrtCiphertext. + /// + /// # Example + /// + /// ```rust + /// use tfhe::integer::gen_keys; + /// use tfhe::integer::wopbs::*; + /// use tfhe::shortint::parameters::parameters_wopbs_message_carry::WOPBS_PARAM_MESSAGE_2_CARRY_2; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// let nb_block = 3; + /// //Generate the client key and the server key: + /// let (mut cks, mut sks) = gen_keys(&PARAM_MESSAGE_2_CARRY_2); + /// let mut wopbs_key = WopbsKey::new_wopbs_key(&cks, &sks, &WOPBS_PARAM_MESSAGE_2_CARRY_2); + /// let mut moduli = 1_u64; + /// for _ in 0..nb_block { + /// moduli *= cks.parameters().message_modulus.0 as u64; + /// } + /// let clear = 42 % moduli; + /// let ct = cks.encrypt_radix(clear as u64, nb_block); + /// let ct = wopbs_key.keyswitch_to_wopbs_params(&sks, &ct); + /// let lut = wopbs_key.generate_lut_radix(&ct, |x| x); + /// let ct_res = wopbs_key.wopbs(&ct, &lut); + /// let ct_res = wopbs_key.keyswitch_to_pbs_params(&ct_res); + /// let res = cks.decrypt_radix(&ct_res); + /// + /// assert_eq!(res, clear); + /// ``` + pub fn wopbs(&self, ct_in: &T, lut: &[Vec]) -> T + where + T: IntegerCiphertext, + { + let mut extracted_bits_blocks = Vec::with_capacity(ct_in.blocks().len()); + // Extraction of each bit for each block + for block in ct_in.blocks().iter() { + let delta = (1_usize << 63) + / (self.wopbs_key.param.message_modulus.0 * self.wopbs_key.param.carry_modulus.0); + let delta_log = DeltaLog(f64::log2(delta as f64) as usize); + let nb_bit_to_extract = f64::log2((block.degree.0 + 1) as f64).ceil() as usize; + + let extracted_bits = self + .wopbs_key + .extract_bits(delta_log, block, nb_bit_to_extract); + + extracted_bits_blocks.push(extracted_bits); + } + + extracted_bits_blocks.reverse(); + let vec_ct_out = self + .wopbs_key + .circuit_bootstrapping_vertical_packing(lut, &extracted_bits_blocks); + + let mut ct_vec_out = vec![]; + for (block, block_out) in ct_in.blocks().iter().zip(vec_ct_out.into_iter()) { + ct_vec_out.push(crate::shortint::Ciphertext { + ct: block_out, + degree: Degree(block.message_modulus.0 - 1), + message_modulus: block.message_modulus, + carry_modulus: block.carry_modulus, + }); + } + T::from_blocks(ct_vec_out) + } + + /// # Example + /// ```rust + /// use tfhe::integer::gen_keys; + /// use tfhe::integer::wopbs::WopbsKey; + /// use tfhe::shortint::parameters::parameters_wopbs_message_carry::WOPBS_PARAM_MESSAGE_2_CARRY_2; + /// + /// let nb_block = 3; + /// //Generate the client key and the server key: + /// let (mut cks, mut sks) = gen_keys(&WOPBS_PARAM_MESSAGE_2_CARRY_2); + /// let mut wopbs_key = WopbsKey::new_wopbs_key_only_for_wopbs(&cks, &sks); + /// let mut moduli = 1_u64; + /// for _ in 0..nb_block { + /// moduli *= cks.parameters().message_modulus.0 as u64; + /// } + /// let clear = 15 % moduli; + /// let ct = cks.encrypt_radix_without_padding(clear as u64, nb_block); + /// let lut = wopbs_key.generate_lut_radix_without_padding(&ct, |x| 2 * x); + /// let ct_res = wopbs_key.wopbs_without_padding(&ct, &lut); + /// let res = cks.decrypt_radix_without_padding(&ct_res); + /// + /// assert_eq!(res, (clear * 2) % moduli) + /// ``` + pub fn wopbs_without_padding(&self, ct_in: &T, lut: &[Vec]) -> T + where + T: IntegerCiphertext, + { + let mut extracted_bits_blocks = Vec::with_capacity(ct_in.blocks().len()); + let mut ct_in = ct_in.clone(); + // Extraction of each bit for each block + for block in ct_in.blocks_mut().iter_mut() { + let delta = (1_usize << 63) / (block.message_modulus.0 * block.carry_modulus.0 / 2); + let delta_log = DeltaLog(f64::log2(delta as f64) as usize); + let nb_bit_to_extract = + f64::log2((block.message_modulus.0 * block.carry_modulus.0) as f64) as usize; + + let extracted_bits = self + .wopbs_key + .extract_bits(delta_log, block, nb_bit_to_extract); + extracted_bits_blocks.push(extracted_bits); + } + + extracted_bits_blocks.reverse(); + + let vec_ct_out = self + .wopbs_key + .circuit_bootstrapping_vertical_packing(lut, &extracted_bits_blocks); + + let mut ct_vec_out = vec![]; + for (block, block_out) in ct_in.blocks().iter().zip(vec_ct_out.into_iter()) { + ct_vec_out.push(crate::shortint::Ciphertext { + ct: block_out, + degree: Degree(block.message_modulus.0 - 1), + message_modulus: block.message_modulus, + carry_modulus: block.carry_modulus, + }); + } + T::from_blocks(ct_vec_out) + } + + /// WOPBS for native CRT + /// # Example + /// ```rust + /// use tfhe::integer::gen_keys; + /// use tfhe::integer::parameters::PARAM_4_BITS_5_BLOCKS; + /// use tfhe::integer::wopbs::WopbsKey; + /// + /// let basis: Vec = vec![9, 11]; + /// + /// let param = PARAM_4_BITS_5_BLOCKS; + /// //Generate the client key and the server key: + /// let (cks, sks) = gen_keys(¶m); + /// let mut wopbs_key = WopbsKey::new_wopbs_key_only_for_wopbs(&cks, &sks); + /// + /// let mut msg_space = 1; + /// for modulus in basis.iter() { + /// msg_space *= modulus; + /// } + /// let clear = 42 % msg_space; // Encrypt the integers + /// let mut ct = cks.encrypt_native_crt(clear, basis.clone()); + /// let lut = wopbs_key.generate_lut_native_crt(&ct, |x| x); + /// let ct_res = wopbs_key.wopbs_native_crt(&mut ct, &lut); + /// let res = cks.decrypt_native_crt(&ct_res); + /// assert_eq!(res, clear); + /// ``` + pub fn wopbs_native_crt(&self, ct1: &CrtCiphertext, lut: &[Vec]) -> CrtCiphertext { + self.circuit_bootstrap_vertical_packing_native_crt(&[ct1.clone()], lut) + } + + /// # Example + /// ```rust + /// use tfhe::integer::gen_keys; + /// use tfhe::integer::wopbs::*; + /// use tfhe::shortint::parameters::parameters_wopbs_message_carry::WOPBS_PARAM_MESSAGE_2_CARRY_2; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// let nb_block = 3; + /// //Generate the client key and the server key: + /// let (cks, sks) = gen_keys(&PARAM_MESSAGE_2_CARRY_2); + /// + /// //Generate wopbs_v0 key /// + /// let wopbs_key = WopbsKey::new_wopbs_key(&cks, &sks, &WOPBS_PARAM_MESSAGE_2_CARRY_2); + /// let mut moduli = 1_u64; + /// for _ in 0..nb_block { + /// moduli *= cks.parameters().message_modulus.0 as u64; + /// } + /// let clear1 = 42 % moduli; + /// let clear2 = 24 % moduli; + /// let ct1 = cks.encrypt_radix(clear1 as u64, nb_block); + /// let ct2 = cks.encrypt_radix(clear2 as u64, nb_block); + /// + /// let ct1 = wopbs_key.keyswitch_to_wopbs_params(&sks, &ct1); + /// let ct2 = wopbs_key.keyswitch_to_wopbs_params(&sks, &ct2); + /// let lut = wopbs_key.generate_lut_bivariate_radix(&ct1, &ct2, |x, y| 2 * x * y); + /// let ct_res = wopbs_key.bivariate_wopbs_with_degree(&ct1, &ct2, &lut); + /// let ct_res = wopbs_key.keyswitch_to_pbs_params(&ct_res); + /// let res = cks.decrypt_radix(&ct_res); + /// + /// assert_eq!(res, (2 * clear1 * clear2) % moduli); + /// ``` + pub fn bivariate_wopbs_with_degree(&self, ct1: &T, ct2: &T, lut: &[Vec]) -> T + where + T: IntegerCiphertext, + { + let ct = ciphertext_concatenation(ct1, ct2); + self.wopbs(&ct, lut) + } + + /// # Example + /// + /// ```rust + /// use tfhe::integer::gen_keys; + /// use tfhe::integer::wopbs::*; + /// use tfhe::shortint::parameters::parameters_wopbs_message_carry::WOPBS_PARAM_MESSAGE_2_CARRY_2; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// let nb_block = 3; + /// //Generate the client key and the server key: + /// let (cks, sks) = gen_keys(&PARAM_MESSAGE_2_CARRY_2); + /// + /// //Generate wopbs_v0 key /// + /// let mut wopbs_key = WopbsKey::new_wopbs_key(&cks, &sks, &WOPBS_PARAM_MESSAGE_2_CARRY_2); + /// let mut moduli = 1_u64; + /// for _ in 0..nb_block { + /// moduli *= cks.parameters().message_modulus.0 as u64; + /// } + /// let clear = 42 % moduli; + /// let ct = cks.encrypt_radix(clear as u64, nb_block); + /// let ct = wopbs_key.keyswitch_to_wopbs_params(&sks, &ct); + /// let lut = wopbs_key.generate_lut_radix(&ct, |x| 2 * x); + /// let ct_res = wopbs_key.wopbs(&ct, &lut); + /// let ct_res = wopbs_key.keyswitch_to_pbs_params(&ct_res); + /// let res = cks.decrypt_radix(&ct_res); + /// + /// assert_eq!(res, (2 * clear) % moduli); + /// ``` + pub fn generate_lut_radix(&self, ct: &T, f: F) -> Vec> + where + F: Fn(u64) -> u64, + T: IntegerCiphertext, + { + let mut total_bit = 0; + let block_nb = ct.blocks().len(); + let mut modulus = 1; + + //This contains the basis of each block depending on the degree + let mut vec_deg_basis = vec![]; + + for (i, deg) in ct.moduli().iter().zip(ct.blocks().iter()) { + modulus *= i; + let b = f64::log2((deg.degree.0 + 1) as f64).ceil() as u64; + vec_deg_basis.push(b); + total_bit += b; + } + + let mut lut_size = 1 << total_bit; + if 1 << total_bit < self.wopbs_key.param.polynomial_size.0 as u64 { + lut_size = self.wopbs_key.param.polynomial_size.0; + } + let mut vec_lut = vec![vec![0; lut_size]; ct.blocks().len()]; + + let basis = ct.moduli()[0]; + let delta: u64 = (1 << 63) + / (self.wopbs_key.param.message_modulus.0 * self.wopbs_key.param.carry_modulus.0) + as u64; + + for lut_index_val in 0..(1 << total_bit) { + let encoded_with_deg_val = encode_mix_radix(lut_index_val, &vec_deg_basis, basis); + let decoded_val = decode_radix(encoded_with_deg_val.clone(), basis); + let f_val = f(decoded_val % modulus) % modulus; + let encoded_f_val = encode_radix(f_val, basis, block_nb as u64); + for lut_number in 0..block_nb { + vec_lut[lut_number][lut_index_val as usize] = encoded_f_val[lut_number] * delta; + } + } + vec_lut + } + + /// # Example + /// ```rust + /// use tfhe::integer::gen_keys; + /// use tfhe::integer::wopbs::WopbsKey; + /// use tfhe::shortint::parameters::parameters_wopbs_message_carry::WOPBS_PARAM_MESSAGE_2_CARRY_2; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// let nb_block = 3; + /// //Generate the client key and the server key: + /// let (cks, sks) = gen_keys(&PARAM_MESSAGE_2_CARRY_2); + /// //Generate wopbs_v0 key + /// let mut wopbs_key = WopbsKey::new_wopbs_key(&cks, &sks, &WOPBS_PARAM_MESSAGE_2_CARRY_2); + /// let mut moduli = 1_u64; + /// for _ in 0..nb_block { + /// moduli *= cks.parameters().message_modulus.0 as u64; + /// } + /// let clear = 15 % moduli; + /// let ct = cks.encrypt_radix_without_padding(clear as u64, nb_block); + /// let ct = wopbs_key.keyswitch_to_wopbs_params(&sks, &ct); + /// let lut = wopbs_key.generate_lut_radix_without_padding(&ct, |x| 2 * x); + /// let ct_res = wopbs_key.wopbs_without_padding(&ct, &lut); + /// let ct_res = wopbs_key.keyswitch_to_pbs_params(&ct_res); + /// let res = cks.decrypt_radix_without_padding(&ct_res); + /// + /// assert_eq!(res, (clear * 2) % moduli) + /// ``` + pub fn generate_lut_radix_without_padding(&self, ct: &T, f: F) -> Vec> + where + F: Fn(u64) -> u64, + T: IntegerCiphertext, + { + let log_message_modulus = f64::log2((self.wopbs_key.param.message_modulus.0) as f64) as u64; + let log_carry_modulus = f64::log2((self.wopbs_key.param.carry_modulus.0) as f64) as u64; + let log_basis = log_message_modulus + log_carry_modulus; + let delta = 64 - log_basis; + let nb_block = ct.blocks().len(); + let poly_size = self.wopbs_key.param.polynomial_size.0; + let mut lut_size = 1 << (nb_block * log_basis as usize); + if lut_size < poly_size { + lut_size = poly_size; + } + let mut vec_lut = vec![vec![0; lut_size]; nb_block]; + + for index in 0..lut_size { + // find the value represented by the index + let mut value = 0; + let mut tmp_index = index; + for i in 0..nb_block as u64 { + let tmp = tmp_index % (1 << (log_basis * (i + 1))); + tmp_index -= tmp; + value += tmp >> (log_carry_modulus * i); + } + + // fill the LUTs + for (block_index, lut_block) in vec_lut.iter_mut().enumerate().take(nb_block) { + lut_block[index] = ((f(value as u64) >> (log_carry_modulus * block_index as u64)) + % (1 << log_message_modulus)) + << delta + } + } + vec_lut + } + + /// generate lut for native CRT + /// # Example + /// + /// ```rust + /// use tfhe::integer::gen_keys; + /// use tfhe::integer::parameters::PARAM_4_BITS_5_BLOCKS; + /// use tfhe::integer::wopbs::WopbsKey; + /// + /// let basis: Vec = vec![9, 11]; + /// + /// let param = PARAM_4_BITS_5_BLOCKS; + /// //Generate the client key and the server key: + /// let (cks, sks) = gen_keys(¶m); + /// let mut wopbs_key = WopbsKey::new_wopbs_key_only_for_wopbs(&cks, &sks); + /// + /// let mut msg_space = 1; + /// for modulus in basis.iter() { + /// msg_space *= modulus; + /// } + /// let clear = 42 % msg_space; // Encrypt the integers + /// let mut ct = cks.encrypt_native_crt(clear, basis.clone()); + /// let lut = wopbs_key.generate_lut_native_crt(&ct, |x| x); + /// let ct_res = wopbs_key.wopbs_native_crt(&mut ct, &lut); + /// let res = cks.decrypt_native_crt(&ct_res); + /// assert_eq!(res, clear); + /// ``` + pub fn generate_lut_native_crt(&self, ct: &CrtCiphertext, f: F) -> Vec> + where + F: Fn(u64) -> u64, + { + let mut bit = vec![]; + let mut total_bit = 0; + let mut modulus = 1; + let basis: Vec<_> = ct.moduli(); + + for i in basis.iter() { + modulus *= i; + let b = f64::log2(*i as f64).ceil() as u64; + total_bit += b; + bit.push(b); + } + let mut lut_size = 1 << total_bit; + if 1 << total_bit < self.wopbs_key.param.polynomial_size.0 as u64 { + lut_size = self.wopbs_key.param.polynomial_size.0; + } + let mut vec_lut = vec![vec![0; lut_size]; basis.len()]; + + for value in 0..modulus { + let mut index_lut = 0; + let mut tmp = 1; + for (base, bit) in basis.iter().zip(bit.iter()) { + index_lut += (((value % base) << bit) / base) * tmp; + tmp <<= bit; + } + for (j, b) in basis.iter().enumerate() { + vec_lut[j][index_lut as usize] = + (((f(value) % b) as u128 * (1 << 64)) / *b as u128) as u64 + } + } + vec_lut + } + + /// generate LUt for crt + /// # Example + /// ```rust + /// use tfhe::integer::gen_keys; + /// use tfhe::integer::wopbs::*; + /// use tfhe::shortint::parameters::parameters_wopbs_message_carry::WOPBS_PARAM_MESSAGE_3_CARRY_3; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_3_CARRY_3; + /// + /// let basis: Vec = vec![5, 7]; + /// let nb_block = basis.len(); + /// + /// //Generate the client key and the server key: + /// let (cks, sks) = gen_keys(&PARAM_MESSAGE_3_CARRY_3); + /// let wopbs_key = WopbsKey::new_wopbs_key(&cks, &sks, &WOPBS_PARAM_MESSAGE_3_CARRY_3); + /// + /// let mut msg_space = 1; + /// for modulus in basis.iter() { + /// msg_space *= modulus; + /// } + /// let clear = 42 % msg_space; + /// let ct = cks.encrypt_crt(clear, basis.clone()); + /// let ct = wopbs_key.keyswitch_to_wopbs_params(&sks, &ct); + /// let lut = wopbs_key.generate_lut_crt(&ct, |x| x); + /// let ct_res = wopbs_key.wopbs(&ct, &lut); + /// let ct_res = wopbs_key.keyswitch_to_pbs_params(&ct_res); + /// let res = cks.decrypt_crt(&ct_res); + /// assert_eq!(res, clear); + /// ``` + pub fn generate_lut_crt(&self, ct: &CrtCiphertext, f: F) -> Vec> + where + F: Fn(u64) -> u64, + { + let mut bit = vec![]; + let mut total_bit = 0; + let mut modulus = 1; + let basis = ct.moduli(); + + for (i, deg) in basis.iter().zip(ct.blocks.iter()) { + modulus *= i; + let b = f64::log2((deg.degree.0 + 1) as f64).ceil() as u64; + total_bit += b; + bit.push(b); + } + let mut lut_size = 1 << total_bit; + if 1 << total_bit < self.wopbs_key.param.polynomial_size.0 as u64 { + lut_size = self.wopbs_key.param.polynomial_size.0; + } + let mut vec_lut = vec![vec![0; lut_size]; basis.len()]; + + for i in 0..(1 << total_bit) { + let mut value = i; + for (j, block) in ct.blocks.iter().enumerate() { + let deg = f64::log2((block.degree.0 + 1) as f64).ceil() as u64; + let delta: u64 = (1 << 63) + / (self.wopbs_key.param.message_modulus.0 + * self.wopbs_key.param.carry_modulus.0) as u64; + vec_lut[j][i as usize] = + ((f((value % (1 << deg)) % block.message_modulus.0 as u64)) + % block.message_modulus.0 as u64) + * delta; + value >>= deg; + } + } + vec_lut + } + + /// # Example + /// + /// ```rust + /// use tfhe::integer::gen_keys; + /// use tfhe::integer::wopbs::*; + /// use tfhe::shortint::parameters::parameters_wopbs_message_carry::WOPBS_PARAM_MESSAGE_2_CARRY_2; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// let nb_block = 3; + /// //Generate the client key and the server key: + /// let (cks, sks) = gen_keys(&PARAM_MESSAGE_2_CARRY_2); + /// + /// //Generate wopbs_v0 key /// + /// let wopbs_key = WopbsKey::new_wopbs_key(&cks, &sks, &WOPBS_PARAM_MESSAGE_2_CARRY_2); + /// let mut moduli = 1_u64; + /// for _ in 0..nb_block { + /// moduli *= cks.parameters().message_modulus.0 as u64; + /// } + /// let clear1 = 42 % moduli; + /// let clear2 = 24 % moduli; + /// let ct1 = cks.encrypt_radix(clear1 as u64, nb_block); + /// let ct2 = cks.encrypt_radix(clear2 as u64, nb_block); + /// + /// let ct1 = wopbs_key.keyswitch_to_wopbs_params(&sks, &ct1); + /// let ct2 = wopbs_key.keyswitch_to_wopbs_params(&sks, &ct2); + /// let lut = wopbs_key.generate_lut_bivariate_radix(&ct1, &ct2, |x, y| 2 * x * y); + /// let ct_res = wopbs_key.bivariate_wopbs_with_degree(&ct1, &ct2, &lut); + /// let ct_res = wopbs_key.keyswitch_to_pbs_params(&ct_res); + /// let res = cks.decrypt_radix(&ct_res); + /// + /// assert_eq!(res, (2 * clear1 * clear2) % moduli); + /// ``` + pub fn generate_lut_bivariate_radix( + &self, + ct1: &RadixCiphertext, + ct2: &RadixCiphertext, + f: F, + ) -> Vec> + where + F: Fn(u64, u64) -> u64, + { + let mut nb_bit_to_extract = vec![0; 2]; + let block_nb = ct1.blocks.len(); + //ct2 & ct1 should have the same basis + let basis = ct1.moduli(); + + //This contains the basis of each block depending on the degree + let mut vec_deg_basis = vec![vec![]; 2]; + + let mut modulus = 1; + for (ct_num, ct) in [ct1, ct2].iter().enumerate() { + modulus = 1; + for deg in ct.blocks.iter() { + modulus *= self.wopbs_key.param.message_modulus.0 as u64; + let b = f64::log2((deg.degree.0 + 1) as f64).ceil() as u64; + vec_deg_basis[ct_num].push(b); + nb_bit_to_extract[ct_num] += b; + } + } + + let total_bit: u64 = nb_bit_to_extract.iter().sum(); + + let mut lut_size = 1 << total_bit; + if 1 << total_bit < self.wopbs_key.param.polynomial_size.0 as u64 { + lut_size = self.wopbs_key.param.polynomial_size.0; + } + let mut vec_lut = vec![vec![0; lut_size]; basis.len()]; + let basis = ct1.moduli()[0]; + + let delta: u64 = (1 << 63) + / (self.wopbs_key.param.message_modulus.0 * self.wopbs_key.param.carry_modulus.0) + as u64; + + for lut_index_val in 0..(1 << total_bit) { + let split = vec![ + lut_index_val % (1 << nb_bit_to_extract[0]), + lut_index_val >> nb_bit_to_extract[0], + ]; + let mut decoded_val = vec![0; 2]; + for i in 0..2 { + let encoded_with_deg_val = encode_mix_radix(split[i], &vec_deg_basis[i], basis); + decoded_val[i] = decode_radix(encoded_with_deg_val.clone(), basis); + } + let f_val = f(decoded_val[0] % modulus, decoded_val[1] % modulus) % modulus; + let encoded_f_val = encode_radix(f_val, basis, block_nb as u64); + for lut_number in 0..block_nb { + vec_lut[lut_number][lut_index_val as usize] = encoded_f_val[lut_number] * delta; + } + } + vec_lut + } + + /// generate bivariate LUT for 'fake' CRT + /// + /// # Example + /// + /// ```rust + /// use tfhe::integer::gen_keys; + /// use tfhe::integer::wopbs::*; + /// use tfhe::shortint::parameters::parameters_wopbs_message_carry::WOPBS_PARAM_MESSAGE_3_CARRY_3; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_3_CARRY_3; + /// + /// let basis: Vec = vec![5, 7]; + /// //Generate the client key and the server key: + /// let (cks, sks) = gen_keys(&PARAM_MESSAGE_3_CARRY_3); + /// let wopbs_key = WopbsKey::new_wopbs_key(&cks, &sks, &WOPBS_PARAM_MESSAGE_3_CARRY_3); + /// + /// let mut msg_space = 1; + /// for modulus in basis.iter() { + /// msg_space *= modulus; + /// } + /// let clear1 = 42 % msg_space; // Encrypt the integers + /// let clear2 = 24 % msg_space; // Encrypt the integers + /// let ct1 = cks.encrypt_crt(clear1, basis.clone()); + /// let ct2 = cks.encrypt_crt(clear2, basis.clone()); + /// + /// let ct1 = wopbs_key.keyswitch_to_wopbs_params(&sks, &ct1); + /// let ct2 = wopbs_key.keyswitch_to_wopbs_params(&sks, &ct2); + /// + /// let lut = wopbs_key.generate_lut_bivariate_crt(&ct1, &ct2, |x, y| x * y * 2); + /// let ct_res = wopbs_key.bivariate_wopbs_with_degree(&ct1, &ct2, &lut); + /// let ct_res = wopbs_key.keyswitch_to_pbs_params(&ct_res); + /// let res = cks.decrypt_crt(&ct_res); + /// assert_eq!(res, (clear1 * clear2 * 2) % msg_space); + /// ``` + pub fn generate_lut_bivariate_crt( + &self, + ct1: &CrtCiphertext, + ct2: &CrtCiphertext, + f: F, + ) -> Vec> + where + F: Fn(u64, u64) -> u64, + { + let mut bit = vec![]; + let mut nb_bit_to_extract = vec![0; 2]; + let mut modulus = 1; + + //ct2 & ct1 should have the same basis + let basis = ct1.moduli(); + + for (ct_num, ct) in [ct1, ct2].iter().enumerate() { + for (i, deg) in basis.iter().zip(ct.blocks.iter()) { + modulus *= i; + let b = f64::log2((deg.degree.0 + 1) as f64).ceil() as u64; + nb_bit_to_extract[ct_num] += b; + bit.push(b); + } + } + + let total_bit: u64 = nb_bit_to_extract.iter().sum(); + + let mut lut_size = 1 << total_bit; + if 1 << total_bit < self.wopbs_key.param.polynomial_size.0 as u64 { + lut_size = self.wopbs_key.param.polynomial_size.0; + } + let mut vec_lut = vec![vec![0; lut_size]; basis.len()]; + + let delta: u64 = (1 << 63) + / (self.wopbs_key.param.message_modulus.0 * self.wopbs_key.param.carry_modulus.0) + as u64; + + for index in 0..(1 << total_bit) { + let mut split = encode_radix(index, 1 << nb_bit_to_extract[0], 2); + let mut crt_value = vec![vec![0; ct1.blocks.len()]; 2]; + for (j, base) in basis.iter().enumerate().take(ct1.blocks.len()) { + let deg_1 = f64::log2((ct1.blocks[j].degree.0 + 1) as f64).ceil() as u64; + let deg_2 = f64::log2((ct2.blocks[j].degree.0 + 1) as f64).ceil() as u64; + crt_value[0][j] = (split[0] % (1 << deg_1)) % base; + crt_value[1][j] = (split[1] % (1 << deg_2)) % base; + split[0] >>= deg_1; + split[1] >>= deg_2; + } + let value_1 = i_crt(&ct1.moduli(), &crt_value[0]); + let value_2 = i_crt(&ct2.moduli(), &crt_value[1]); + for (j, current_mod) in basis.iter().enumerate() { + let value = f(value_1, value_2) % current_mod; + vec_lut[j][index as usize] = (value % current_mod) * delta; + } + } + + vec_lut + } + + /// generate bivariate LUT for 'true' CRT + /// # Example + /// + /// ```rust + /// use tfhe::integer::gen_keys; + /// use tfhe::integer::parameters::PARAM_4_BITS_5_BLOCKS; + /// use tfhe::integer::wopbs::WopbsKey; + /// + /// let basis: Vec = vec![9, 11]; + /// + /// let param = PARAM_4_BITS_5_BLOCKS; + /// //Generate the client key and the server key: + /// let (cks, sks) = gen_keys(¶m); + /// let mut wopbs_key = WopbsKey::new_wopbs_key_only_for_wopbs(&cks, &sks); + /// + /// let mut msg_space = 1; + /// for modulus in basis.iter() { + /// msg_space *= modulus; + /// } + /// let clear1 = 42 % msg_space; + /// let clear2 = 24 % msg_space; + /// let mut ct1 = cks.encrypt_native_crt(clear1, basis.clone()); + /// let mut ct2 = cks.encrypt_native_crt(clear2, basis.clone()); + /// let lut = wopbs_key.generate_lut_bivariate_native_crt(&ct1, |x, y| x * y * 2); + /// let ct_res = wopbs_key.bivariate_wopbs_native_crt(&mut ct1, &mut ct2, &lut); + /// let res = cks.decrypt_native_crt(&ct_res); + /// assert_eq!(res, (clear1 * clear2 * 2) % msg_space); + /// ``` + pub fn generate_lut_bivariate_native_crt(&self, ct_1: &CrtCiphertext, f: F) -> Vec> + where + F: Fn(u64, u64) -> u64, + { + let mut bit = vec![]; + let mut total_bit = 0; + let mut modulus = 1; + let basis = ct_1.moduli(); + for i in basis.iter() { + modulus *= i; + let b = f64::log2(*i as f64).ceil() as u64; + total_bit += b; + bit.push(b); + } + let mut lut_size = 1 << (2 * total_bit); + if 1 << (2 * total_bit) < self.wopbs_key.param.polynomial_size.0 as u64 { + lut_size = self.wopbs_key.param.polynomial_size.0; + } + let mut vec_lut = vec![vec![0; lut_size]; basis.len()]; + + for value in 0..1 << (2 * total_bit) { + let value_1 = value % (1 << total_bit); + let value_2 = value >> total_bit; + let mut index_lut_1 = 0; + let mut index_lut_2 = 0; + let mut tmp = 1; + for (base, bit) in basis.iter().zip(bit.iter()) { + index_lut_1 += (((value_1 % base) << bit) / base) * tmp; + index_lut_2 += (((value_2 % base) << bit) / base) * tmp; + tmp <<= bit; + } + let index = (index_lut_2 << total_bit) + (index_lut_1); + for (j, b) in basis.iter().enumerate() { + vec_lut[j][index as usize] = + (((f(value_1, value_2) % b) as u128 * (1 << 64)) / *b as u128) as u64 + } + } + vec_lut + } + + /// bivariate WOPBS for native CRT + /// # Example + /// + /// ```rust + /// use tfhe::integer::gen_keys; + /// use tfhe::integer::parameters::PARAM_4_BITS_5_BLOCKS; + /// use tfhe::integer::wopbs::WopbsKey; + /// + /// let basis: Vec = vec![9, 11]; + /// + /// let param = PARAM_4_BITS_5_BLOCKS; + /// //Generate the client key and the server key: + /// let (cks, sks) = gen_keys(¶m); + /// let mut wopbs_key = WopbsKey::new_wopbs_key_only_for_wopbs(&cks, &sks); + /// + /// let mut msg_space = 1; + /// for modulus in basis.iter() { + /// msg_space *= modulus; + /// } + /// let clear1 = 42 % msg_space; + /// let clear2 = 24 % msg_space; + /// let mut ct1 = cks.encrypt_native_crt(clear1, basis.clone()); + /// let mut ct2 = cks.encrypt_native_crt(clear2, basis.clone()); + /// let lut = wopbs_key.generate_lut_bivariate_native_crt(&ct1, |x, y| x * y * 2); + /// let ct_res = wopbs_key.bivariate_wopbs_native_crt(&mut ct1, &mut ct2, &lut); + /// let res = cks.decrypt_native_crt(&ct_res); + /// assert_eq!(res, (clear1 * clear2 * 2) % msg_space); + /// ``` + pub fn bivariate_wopbs_native_crt( + &self, + ct1: &CrtCiphertext, + ct2: &CrtCiphertext, + lut: &[Vec], + ) -> CrtCiphertext { + self.circuit_bootstrap_vertical_packing_native_crt(&[ct1.clone(), ct2.clone()], lut) + } + + fn circuit_bootstrap_vertical_packing_native_crt( + &self, + vec_ct_in: &[T], + lut: &[Vec], + ) -> T + where + T: IntegerCiphertext, + { + let mut extracted_bits_blocks = vec![]; + for ct_in in vec_ct_in.iter() { + let mut ct_in = ct_in.clone(); + // Extraction of each bit for each block + for block in ct_in.blocks_mut().iter_mut() { + let nb_bit_to_extract = + f64::log2((block.message_modulus.0 * block.carry_modulus.0) as f64).ceil() + as usize; + let delta_log = DeltaLog(64 - nb_bit_to_extract); + + // trick ( ct - delta/2 + delta/2^4 ) + let lwe_size = block.ct.lwe_size().0; + let mut cont = vec![0u64; lwe_size]; + cont[lwe_size - 1] = + (1 << (64 - nb_bit_to_extract - 1)) - (1 << (64 - nb_bit_to_extract - 5)); + + let tmp_ciphertext = LweCiphertext::from_container(cont); + lwe_ciphertext_sub_assign(&mut block.ct, &tmp_ciphertext); + + let extracted_bits = + self.wopbs_key + .extract_bits(delta_log, block, nb_bit_to_extract); + extracted_bits_blocks.push(extracted_bits); + } + } + + extracted_bits_blocks.reverse(); + + let vec_ct_out = self + .wopbs_key + .circuit_bootstrapping_vertical_packing(lut, &extracted_bits_blocks); + + let mut ct_vec_out: Vec = vec![]; + for (block, block_out) in vec_ct_in[0].blocks().iter().zip(vec_ct_out.into_iter()) { + ct_vec_out.push(crate::shortint::Ciphertext { + ct: block_out, + degree: Degree(block.message_modulus.0 - 1), + message_modulus: block.message_modulus, + carry_modulus: block.carry_modulus, + }); + } + T::from_blocks(ct_vec_out) + } + + pub fn keyswitch_to_wopbs_params(&self, sks: &ServerKey, ct_in: &T) -> T + where + T: IntegerCiphertext, + { + let blocks: Vec<_> = ct_in + .blocks() + .par_iter() + .map(|block| self.wopbs_key.keyswitch_to_wopbs_params(&sks.key, block)) + .collect(); + T::from_blocks(blocks) + } + + pub fn keyswitch_to_pbs_params(&self, ct_in: &T) -> T + where + T: IntegerCiphertext, + { + let blocks: Vec<_> = ct_in + .blocks() + .par_iter() + .map(|block| self.wopbs_key.keyswitch_to_pbs_params(block)) + .collect(); + T::from_blocks(blocks) + } +} diff --git a/tfhe/src/integer/wopbs/test.rs b/tfhe/src/integer/wopbs/test.rs new file mode 100644 index 000000000..87169db38 --- /dev/null +++ b/tfhe/src/integer/wopbs/test.rs @@ -0,0 +1,291 @@ +#![allow(unused)] + +use crate::integer::gen_keys; +use crate::integer::parameters::*; +use crate::integer::wopbs::{encode_radix, WopbsKey}; +use crate::shortint::parameters::parameters_wopbs::*; +use crate::shortint::parameters::parameters_wopbs_message_carry::*; +use crate::shortint::parameters::{Parameters, *}; +use rand::Rng; +use std::cmp::max; + +use crate::integer::keycache::{KEY_CACHE, KEY_CACHE_WOPBS}; +use paste::paste; + +const NB_TEST: usize = 10; + +macro_rules! create_parametrized_test{ + ($name:ident { $( ($sks_param:ident, $wopbs_param:ident) ),* }) => { + paste! { + $( + #[test] + fn []() { + $name(($sks_param, $wopbs_param)) + } + )* + } + }; + ($name:ident)=> { + create_parametrized_test!($name + { + (PARAM_MESSAGE_2_CARRY_2, WOPBS_PARAM_MESSAGE_2_CARRY_2), + (PARAM_MESSAGE_3_CARRY_3, WOPBS_PARAM_MESSAGE_3_CARRY_3), + (PARAM_MESSAGE_4_CARRY_4, WOPBS_PARAM_MESSAGE_4_CARRY_4) + }); + }; +} + +create_parametrized_test!(wopbs_crt); +create_parametrized_test!(wopbs_bivariate_radix); +create_parametrized_test!(wopbs_bivariate_crt); +create_parametrized_test!(wopbs_radix); + +fn make_basis(message_modulus: usize) -> Vec { + match message_modulus { + 2 => vec![2], + 3 => vec![2], + n if n < 8 => vec![2, 3], + n if n < 16 => vec![2, 5, 7], + _ => vec![3, 7, 13], + } +} + +pub fn wopbs_native_crt() { + let mut rng = rand::thread_rng(); + + let basis: Vec = vec![2, 3]; + let nb_block = basis.len(); + + let params = ( + crate::shortint::parameters::parameters_wopbs::PARAM_4_BITS_5_BLOCKS, + crate::shortint::parameters::parameters_wopbs::PARAM_4_BITS_5_BLOCKS, + ); + + let (cks, mut sks) = gen_keys(¶ms.1); + let wopbs_key = WopbsKey::new_wopbs_key_only_for_wopbs(&cks, &sks); + + let mut msg_space = 1; + for modulus in basis.iter() { + msg_space *= modulus; + } + + let nb_test = 10; + + for _ in 0..nb_test { + let clear1 = rng.gen::() % msg_space; // Encrypt the integers + let mut ct1 = cks.encrypt_native_crt(clear1, basis.clone()); + + let lut = wopbs_key.generate_lut_native_crt(&ct1, |x| x); + + let ct_res = wopbs_key.wopbs_native_crt(&ct1, &lut); + let res = cks.decrypt_native_crt(&ct_res); + + assert_eq!(res, clear1); + } +} + +pub fn wopbs_native_crt_bivariate() { + let mut rng = rand::thread_rng(); + + let basis: Vec = vec![9, 11]; + + let nb_block = basis.len(); + + let params = ( + crate::shortint::parameters::parameters_wopbs::PARAM_4_BITS_5_BLOCKS, + crate::shortint::parameters::parameters_wopbs::PARAM_4_BITS_5_BLOCKS, + ); + + let (cks, mut sks) = gen_keys(¶ms.1); + let wopbs_key = KEY_CACHE_WOPBS.get_from_params(params); + + let mut msg_space = 1; + for modulus in basis.iter() { + msg_space *= modulus; + } + + let nb_test = 10; + let mut tmp = 0; + for _ in 0..nb_test { + let clear1 = rng.gen::() % msg_space; // Encrypt the integers + let clear2 = rng.gen::() % msg_space; // Encrypt the integers + let mut ct1 = cks.encrypt_native_crt(clear1, basis.clone()); + let mut ct2 = cks.encrypt_native_crt(clear2, basis.clone()); + + let lut = wopbs_key.generate_lut_bivariate_native_crt(&ct1, |x, y| x * y); + let ct_res = wopbs_key.bivariate_wopbs_native_crt(&ct1, &ct2, &lut); + let res = cks.decrypt_native_crt(&ct_res); + + if (clear1 * clear2) % msg_space != res { + tmp += 1; + } + } + assert_eq!(tmp, 0); +} + +// test wopbs fake crt with different degree for each Ct +pub fn wopbs_crt(params: (Parameters, Parameters)) { + let mut rng = rand::thread_rng(); + + let basis = make_basis(params.1.message_modulus.0); + + let nb_block = basis.len(); + + let (cks, mut sks) = gen_keys(¶ms.0); + let wopbs_key = KEY_CACHE_WOPBS.get_from_params(params); + + let mut msg_space = 1; + for modulus in basis.iter() { + msg_space *= modulus; + } + + let nb_test = 10; + let mut tmp = 0; + for _ in 0..nb_test { + let clear1 = rng.gen::() % msg_space; + let mut ct1 = cks.encrypt_crt(clear1, basis.clone()); + //artificially modify the degree + for ct in ct1.blocks.iter_mut() { + let degree = params.0.message_modulus.0 + * ((rng.gen::() % (params.0.carry_modulus.0 - 1)) + 1); + ct.degree.0 = degree; + } + let res = cks.decrypt_crt(&ct1); + + let ct1 = wopbs_key.keyswitch_to_wopbs_params(&sks, &ct1); + let lut = wopbs_key.generate_lut_crt(&ct1, |x| (x * x) + x); + let ct_res = wopbs_key.wopbs(&ct1, &lut); + let ct_res = wopbs_key.keyswitch_to_pbs_params(&ct_res); + + let res_wop = cks.decrypt_crt(&ct_res); + if ((res * res) + res) % msg_space != res_wop { + tmp += 1; + } + } + if tmp != 0 { + println!("failure rate {tmp:?}/{nb_test:?}"); + panic!() + } +} + +// test wopbs fake crt with different degree for each Ct +pub fn wopbs_radix(params: (Parameters, Parameters)) { + let mut rng = rand::thread_rng(); + + let nb_block = 2; + + let (cks, mut sks) = gen_keys(¶ms.0); + let wopbs_key = KEY_CACHE_WOPBS.get_from_params(params); + + let mut msg_space: u64 = params.0.message_modulus.0 as u64; + for modulus in 1..nb_block { + msg_space *= params.0.message_modulus.0 as u64; + } + + let nb_test = 10; + let mut tmp = 0; + for _ in 0..nb_test { + let clear1 = rng.gen::() % msg_space; + let mut ct1 = cks.encrypt_radix(clear1, nb_block); + + // //artificially modify the degree + let res = cks.decrypt_radix(&ct1); + let ct1 = wopbs_key.keyswitch_to_wopbs_params(&sks, &ct1); + let lut = wopbs_key.generate_lut_radix(&ct1, |x| x); + let ct_res = wopbs_key.wopbs(&ct1, &lut); + let ct_res = wopbs_key.keyswitch_to_pbs_params(&ct_res); + let res_wop = cks.decrypt_radix(&ct_res); + if res % msg_space != res_wop { + tmp += 1; + } + } + if tmp != 0 { + println!("failure rate {tmp:?}/{nb_test:?}"); + panic!() + } +} + +// test wopbs radix with different degree for each Ct +pub fn wopbs_bivariate_radix(params: (Parameters, Parameters)) { + let mut rng = rand::thread_rng(); + + let nb_block = 2; + + let (cks, mut sks) = gen_keys(¶ms.0); + let wopbs_key = KEY_CACHE_WOPBS.get_from_params(params); + + let mut msg_space: u64 = params.0.message_modulus.0 as u64; + for modulus in 1..nb_block { + msg_space *= params.0.message_modulus.0 as u64; + } + + let nb_test = 10; + + for _ in 0..nb_test { + let mut clear1 = rng.gen::() % msg_space; + let mut clear2 = rng.gen::() % msg_space; + + let mut ct1 = cks.encrypt_radix(clear1, nb_block); + let scalar = rng.gen::() % msg_space; + sks.smart_scalar_add_assign(&mut ct1, scalar); + let dec1 = cks.decrypt_radix(&ct1); + + let mut ct2 = cks.encrypt_radix(clear2, nb_block); + let scalar = rng.gen::() % msg_space; + sks.smart_scalar_add_assign(&mut ct2, scalar); + let dec2 = cks.decrypt_radix(&ct2); + + let ct1 = wopbs_key.keyswitch_to_wopbs_params(&sks, &ct1); + let ct2 = wopbs_key.keyswitch_to_wopbs_params(&sks, &ct2); + + let lut = wopbs_key.generate_lut_bivariate_radix(&ct1, &ct2, |x, y| x + y * x); + let ct_res = wopbs_key.bivariate_wopbs_with_degree(&ct1, &ct2, &lut); + let ct_res = wopbs_key.keyswitch_to_pbs_params(&ct_res); + + let res = cks.decrypt_radix(&ct_res); + assert_eq!(res, (dec1 + dec2 * dec1) % msg_space); + } +} + +// test wopbs bivariate fake crt with different degree for each Ct +pub fn wopbs_bivariate_crt(params: (Parameters, Parameters)) { + let mut rng = rand::thread_rng(); + + let basis = make_basis(params.1.message_modulus.0); + let modulus = basis.iter().product::(); + + let (cks, mut sks) = gen_keys(¶ms.0); + let wopbs_key = KEY_CACHE_WOPBS.get_from_params(params); + + let mut msg_space: u64 = 1; + for modulus in basis.iter() { + msg_space *= modulus; + } + + let nb_test = 10; + + for _ in 0..nb_test { + let clear1 = rng.gen::() % msg_space; + let clear2 = rng.gen::() % msg_space; + let mut ct1 = cks.encrypt_crt(clear1, basis.clone()); + let mut ct2 = cks.encrypt_crt(clear2, basis.clone()); + //artificially modify the degree + for (ct_1, ct_2) in ct1.blocks.iter_mut().zip(ct2.blocks.iter_mut()) { + let degree = params.0.message_modulus.0 + * ((rng.gen::() % (params.0.carry_modulus.0 - 1)) + 1); + ct_1.degree.0 = degree; + let degree = params.0.message_modulus.0 + * ((rng.gen::() % (params.0.carry_modulus.0 - 1)) + 1); + ct_2.degree.0 = degree; + } + + let ct1 = wopbs_key.keyswitch_to_wopbs_params(&sks, &ct1); + let ct2 = wopbs_key.keyswitch_to_wopbs_params(&sks, &ct2); + let lut = wopbs_key.generate_lut_bivariate_crt(&ct1, &ct2, |x, y| (x * y) + y); + let ct_res = wopbs_key.bivariate_wopbs_with_degree(&ct1, &ct2, &lut); + let ct_res = wopbs_key.keyswitch_to_pbs_params(&ct_res); + + let res = cks.decrypt_crt(&ct_res); + assert_eq!(res, ((clear1 * clear2) + clear2) % msg_space); + } +} diff --git a/tfhe/src/lib.rs b/tfhe/src/lib.rs index e6fb0d458..1fd5e3fb9 100644 --- a/tfhe/src/lib.rs +++ b/tfhe/src/lib.rs @@ -22,6 +22,13 @@ pub mod boolean; /// cbindgen:ignore pub mod core_crypto; +#[cfg(feature = "integer")] +/// Welcome to the TFHE-rs [`integer`](`crate::integer`) module documentation! +/// +/// # Special module attributes +/// cbindgen:ignore +pub mod integer; + #[cfg(feature = "shortint")] /// Welcome to the TFHE-rs [`shortint`](`crate::shortint`) module documentation! /// @@ -35,5 +42,10 @@ pub mod js_on_wasm_api; #[cfg(feature = "__wasm_api")] pub use js_on_wasm_api::*; -#[cfg(all(doctest, feature = "shortint", feature = "boolean"))] +#[cfg(all( + doctest, + feature = "shortint", + feature = "boolean", + feature = "integer" +))] mod test_user_docs; diff --git a/tfhe/src/shortint/engine/wopbs/mod.rs b/tfhe/src/shortint/engine/wopbs/mod.rs index 479eaf507..ebe6cb1df 100644 --- a/tfhe/src/shortint/engine/wopbs/mod.rs +++ b/tfhe/src/shortint/engine/wopbs/mod.rs @@ -485,8 +485,8 @@ impl ShortintEngine { pub fn circuit_bootstrapping_vertical_packing( &mut self, wopbs_key: &WopbsKey, - vec_lut: Vec>, - extracted_bits_blocks: Vec>, + vec_lut: &[Vec], + extracted_bits_blocks: &[LweCiphertextListOwned], ) -> Vec> { let lwe_size = extracted_bits_blocks[0].lwe_size(); diff --git a/tfhe/src/shortint/wopbs/mod.rs b/tfhe/src/shortint/wopbs/mod.rs index bca92e8cc..b40de552c 100644 --- a/tfhe/src/shortint/wopbs/mod.rs +++ b/tfhe/src/shortint/wopbs/mod.rs @@ -346,8 +346,8 @@ impl WopbsKey { /// # Warning Experimental pub fn circuit_bootstrapping_vertical_packing( &self, - vec_lut: Vec>, - extracted_bits_blocks: Vec>, + vec_lut: &[Vec], + extracted_bits_blocks: &[LweCiphertextListOwned], ) -> Vec> { ShortintEngine::with_thread_local_mut(|engine| { engine.circuit_bootstrapping_vertical_packing(self, vec_lut, extracted_bits_blocks) diff --git a/tfhe/src/test_user_docs.rs b/tfhe/src/test_user_docs.rs index 8437d61d9..669e2b23b 100644 --- a/tfhe/src/test_user_docs.rs +++ b/tfhe/src/test_user_docs.rs @@ -28,3 +28,18 @@ doctest!("../docs/core_crypto/tutorial.md", core_crypto_turorial); // "../docs/tutorials/circuit_evaluation.md", // circuit_evaluation // ); + +// Integer +doctest!( + "../docs/integer/getting_started/first_circuit.md", + integer_first_circuit +); +doctest!( + "../docs/integer/tutorials/serialization.md", + integer_serialization_tuto +); +doctest!( + "../docs/integer/tutorials/circuit_evaluation.md", + integer_circuit_evaluation +); +doctest!("../docs/integer/how_to/pbs.md", integer_pbs);